diff --git a/quic/v3/session.go b/quic/v3/session.go index fa1b1f6e..6836aed9 100644 --- a/quic/v3/session.go +++ b/quic/v3/session.go @@ -89,7 +89,10 @@ func NewSession( log *zerolog.Logger, ) Session { logger := log.With().Str(logFlowID, id.String()).Logger() - closeChan := make(chan error, 1) + // closeChan has two slots to allow for both writers (the closeFn and the Serve routine) to both be able to + // write to the channel without blocking since there is only ever one value read from the closeChan by the + // waitForCloseCondition. + closeChan := make(chan error, 2) session := &session{ id: id, closeAfterIdle: closeAfterIdle, diff --git a/quic/v3/session_test.go b/quic/v3/session_test.go index 8c570074..b739ca2d 100644 --- a/quic/v3/session_test.go +++ b/quic/v3/session_test.go @@ -3,13 +3,14 @@ package v3_test import ( "context" "errors" + "io" "net" "net/netip" "slices" - "sync/atomic" "testing" "time" + "github.com/fortytw2/leaktest" "github.com/rs/zerolog" v3 "github.com/cloudflare/cloudflared/quic/v3" @@ -32,45 +33,64 @@ func TestSessionNew(t *testing.T) { func testSessionWrite(t *testing.T, payload []byte) { log := zerolog.Nop() - origin := newTestOrigin(makePayload(1280)) - session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) + origin, server := net.Pipe() + defer origin.Close() + defer server.Close() + // Start origin server read + serverRead := make(chan []byte, 1) + go func() { + read := make([]byte, 1500) + server.Read(read[:]) + serverRead <- read + }() + // Create session and write to origin + session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) n, err := session.Write(payload) + defer session.Close() if err != nil { t.Fatal(err) } if n != len(payload) { t.Fatal("unable to write the whole payload") } - if !slices.Equal(payload, origin.write[:len(payload)]) { + + read := <-serverRead + if !slices.Equal(payload, read[:len(payload)]) { t.Fatal("payload provided from origin and read value are not the same") } } func TestSessionWrite_Max(t *testing.T) { + defer leaktest.Check(t)() payload := makePayload(1280) testSessionWrite(t, payload) } func TestSessionWrite_Min(t *testing.T) { + defer leaktest.Check(t)() payload := makePayload(0) testSessionWrite(t, payload) } func TestSessionServe_OriginMax(t *testing.T) { + defer leaktest.Check(t)() payload := makePayload(1280) testSessionServe_Origin(t, payload) } func TestSessionServe_OriginMin(t *testing.T) { + defer leaktest.Check(t)() payload := makePayload(0) testSessionServe_Origin(t, payload) } func testSessionServe_Origin(t *testing.T, payload []byte) { log := zerolog.Nop() + origin, server := net.Pipe() + defer origin.Close() + defer server.Close() eyeball := newMockEyeball() - origin := newTestOrigin(payload) - session := v3.NewSession(testRequestID, 3*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log) + session := v3.NewSession(testRequestID, 3*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log) defer session.Close() ctx, cancel := context.WithCancelCause(context.Background()) @@ -80,13 +100,19 @@ func testSessionServe_Origin(t *testing.T, payload []byte) { done <- session.Serve(ctx) }() + // Write from the origin server + _, err := server.Write(payload) + if err != nil { + t.Fatal(err) + } + select { case data := <-eyeball.recvData: // check received data matches provided from origin expectedData := makePayload(1500) v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:]) copy(expectedData[17:], payload) - if !slices.Equal(expectedData[:17+len(payload)], data) { + if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) { t.Fatal("expected datagram did not equal expected") } cancel(expectedContextCanceled) @@ -95,7 +121,7 @@ func testSessionServe_Origin(t *testing.T, payload []byte) { t.Fatal(err) } - err := <-done + err = <-done if !errors.Is(err, context.Canceled) { t.Fatal(err) } @@ -105,11 +131,14 @@ func testSessionServe_Origin(t *testing.T, payload []byte) { } func TestSessionServe_OriginTooLarge(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() eyeball := newMockEyeball() payload := makePayload(1281) - origin := newTestOrigin(payload) - session := v3.NewSession(testRequestID, 2*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log) + origin, server := net.Pipe() + defer origin.Close() + defer server.Close() + session := v3.NewSession(testRequestID, 2*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log) defer session.Close() done := make(chan error) @@ -117,6 +146,12 @@ func TestSessionServe_OriginTooLarge(t *testing.T) { done <- session.Serve(context.Background()) }() + // Attempt to write a payload too large from the origin + _, err := server.Write(payload) + if err != nil { + t.Fatal(err) + } + select { case data := <-eyeball.recvData: // we never expect a read to make it here because the origin provided a payload that is too large @@ -130,6 +165,7 @@ func TestSessionServe_OriginTooLarge(t *testing.T) { } func TestSessionServe_Migrate(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() eyeball := newMockEyeball() pipe1, pipe2 := net.Pipe() @@ -186,6 +222,7 @@ func TestSessionServe_Migrate(t *testing.T) { } func TestSessionServe_Migrate_CloseContext2(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() eyeball := newMockEyeball() pipe1, pipe2 := net.Pipe() @@ -245,39 +282,48 @@ func TestSessionServe_Migrate_CloseContext2(t *testing.T) { } func TestSessionClose_Multiple(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - origin := newTestOrigin(makePayload(128)) - session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) + origin, server := net.Pipe() + defer origin.Close() + defer server.Close() + session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) err := session.Close() if err != nil { t.Fatal(err) } - if !origin.closed.Load() { - t.Fatal("origin wasn't closed") + b := [1500]byte{} + _, err = server.Read(b[:]) + if !errors.Is(err, io.EOF) { + t.Fatalf("origin server connection should be closed: %s", err) } - // Reset the closed status to make sure it isn't closed again - origin.closed.Store(false) // subsequent closes shouldn't call close again or cause any errors err = session.Close() if err != nil { t.Fatal(err) } - if origin.closed.Load() { - t.Fatal("origin was incorrectly closed twice") + _, err = server.Read(b[:]) + if !errors.Is(err, io.EOF) { + t.Fatalf("origin server connection should still be closed: %s", err) } } func TestSessionServe_IdleTimeout(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - origin := newTestIdleOrigin(10 * time.Second) // Make idle time longer than closeAfterIdle + origin, server := net.Pipe() + defer origin.Close() + defer server.Close() closeAfterIdle := 2 * time.Second - session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) + session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) err := session.Serve(context.Background()) if !errors.Is(err, v3.SessionIdleErr{}) { t.Fatal(err) } // session should be closed - if !origin.closed { + b := [1500]byte{} + _, err = server.Read(b[:]) + if !errors.Is(err, io.EOF) { t.Fatalf("session should be closed after Serve returns") } // closing a session again should not return an error @@ -288,12 +334,14 @@ func TestSessionServe_IdleTimeout(t *testing.T) { } func TestSessionServe_ParentContextCanceled(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - // Make idle time and idle timeout longer than closeAfterIdle - origin := newTestIdleOrigin(10 * time.Second) + origin, server := net.Pipe() + defer origin.Close() + defer server.Close() closeAfterIdle := 10 * time.Second - session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) + session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() err := session.Serve(ctx) @@ -301,7 +349,9 @@ func TestSessionServe_ParentContextCanceled(t *testing.T) { t.Fatal(err) } // session should be closed - if !origin.closed { + b := [1500]byte{} + _, err = server.Read(b[:]) + if !errors.Is(err, io.EOF) { t.Fatalf("session should be closed after Serve returns") } // closing a session again should not return an error @@ -312,6 +362,7 @@ func TestSessionServe_ParentContextCanceled(t *testing.T) { } func TestSessionServe_ReadErrors(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() origin := newTestErrOrigin(net.ErrClosed, nil) session := v3.NewSession(testRequestID, 30*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) @@ -321,72 +372,6 @@ func TestSessionServe_ReadErrors(t *testing.T) { } } -type testOrigin struct { - // bytes from Write - write []byte - // bytes provided to Read - read []byte - readOnce atomic.Bool - closed atomic.Bool -} - -func newTestOrigin(payload []byte) testOrigin { - return testOrigin{ - read: payload, - } -} - -func (o *testOrigin) Read(p []byte) (n int, err error) { - if o.closed.Load() { - return -1, net.ErrClosed - } - if o.readOnce.Load() { - // We only want to provide one read so all other reads will be blocked - time.Sleep(10 * time.Second) - } - o.readOnce.Store(true) - return copy(p, o.read), nil -} - -func (o *testOrigin) Write(p []byte) (n int, err error) { - if o.closed.Load() { - return -1, net.ErrClosed - } - o.write = make([]byte, len(p)) - copy(o.write, p) - return len(p), nil -} - -func (o *testOrigin) Close() error { - o.closed.Store(true) - return nil -} - -type testIdleOrigin struct { - duration time.Duration - closed bool -} - -func newTestIdleOrigin(d time.Duration) testIdleOrigin { - return testIdleOrigin{ - duration: d, - } -} - -func (o *testIdleOrigin) Read(p []byte) (n int, err error) { - time.Sleep(o.duration) - return -1, nil -} - -func (o *testIdleOrigin) Write(p []byte) (n int, err error) { - return 0, nil -} - -func (o *testIdleOrigin) Close() error { - o.closed = true - return nil -} - type testErrOrigin struct { readErr error writeErr error