From bc9c5d2e6ee75ee826986c62b84724a6ab043723 Mon Sep 17 00:00:00 2001 From: Devin Carr Date: Tue, 17 Dec 2024 14:55:09 -0800 Subject: [PATCH] TUN-8817: Increase close session channel by one since there are two writers When closing a session, there are two possible signals that will occur, one from the outside, indicating that the session is idle and needs to be closed, and the internal error condition that will be unblocked with a net.ErrClosed when the connection underneath is closed. Both of these routines write to the session's closeChan. Once the reader for the closeChan reads one value, it will immediately return. This means that the channel is a one-shot and one of the two writers will get stuck unless the size of the channel is increased to accomodate for the second write to the channel. With the channel size increased to two, the second writer (whichever loses the race to write) will now be unblocked to end their go routine and return. Closes TUN-8817 --- quic/v3/session.go | 5 +- quic/v3/session_test.go | 167 ++++++++++++++++++---------------------- 2 files changed, 80 insertions(+), 92 deletions(-) diff --git a/quic/v3/session.go b/quic/v3/session.go index fa1b1f6e..6836aed9 100644 --- a/quic/v3/session.go +++ b/quic/v3/session.go @@ -89,7 +89,10 @@ func NewSession( log *zerolog.Logger, ) Session { logger := log.With().Str(logFlowID, id.String()).Logger() - closeChan := make(chan error, 1) + // closeChan has two slots to allow for both writers (the closeFn and the Serve routine) to both be able to + // write to the channel without blocking since there is only ever one value read from the closeChan by the + // waitForCloseCondition. + closeChan := make(chan error, 2) session := &session{ id: id, closeAfterIdle: closeAfterIdle, diff --git a/quic/v3/session_test.go b/quic/v3/session_test.go index 8c570074..b739ca2d 100644 --- a/quic/v3/session_test.go +++ b/quic/v3/session_test.go @@ -3,13 +3,14 @@ package v3_test import ( "context" "errors" + "io" "net" "net/netip" "slices" - "sync/atomic" "testing" "time" + "github.com/fortytw2/leaktest" "github.com/rs/zerolog" v3 "github.com/cloudflare/cloudflared/quic/v3" @@ -32,45 +33,64 @@ func TestSessionNew(t *testing.T) { func testSessionWrite(t *testing.T, payload []byte) { log := zerolog.Nop() - origin := newTestOrigin(makePayload(1280)) - session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) + origin, server := net.Pipe() + defer origin.Close() + defer server.Close() + // Start origin server read + serverRead := make(chan []byte, 1) + go func() { + read := make([]byte, 1500) + server.Read(read[:]) + serverRead <- read + }() + // Create session and write to origin + session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) n, err := session.Write(payload) + defer session.Close() if err != nil { t.Fatal(err) } if n != len(payload) { t.Fatal("unable to write the whole payload") } - if !slices.Equal(payload, origin.write[:len(payload)]) { + + read := <-serverRead + if !slices.Equal(payload, read[:len(payload)]) { t.Fatal("payload provided from origin and read value are not the same") } } func TestSessionWrite_Max(t *testing.T) { + defer leaktest.Check(t)() payload := makePayload(1280) testSessionWrite(t, payload) } func TestSessionWrite_Min(t *testing.T) { + defer leaktest.Check(t)() payload := makePayload(0) testSessionWrite(t, payload) } func TestSessionServe_OriginMax(t *testing.T) { + defer leaktest.Check(t)() payload := makePayload(1280) testSessionServe_Origin(t, payload) } func TestSessionServe_OriginMin(t *testing.T) { + defer leaktest.Check(t)() payload := makePayload(0) testSessionServe_Origin(t, payload) } func testSessionServe_Origin(t *testing.T, payload []byte) { log := zerolog.Nop() + origin, server := net.Pipe() + defer origin.Close() + defer server.Close() eyeball := newMockEyeball() - origin := newTestOrigin(payload) - session := v3.NewSession(testRequestID, 3*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log) + session := v3.NewSession(testRequestID, 3*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log) defer session.Close() ctx, cancel := context.WithCancelCause(context.Background()) @@ -80,13 +100,19 @@ func testSessionServe_Origin(t *testing.T, payload []byte) { done <- session.Serve(ctx) }() + // Write from the origin server + _, err := server.Write(payload) + if err != nil { + t.Fatal(err) + } + select { case data := <-eyeball.recvData: // check received data matches provided from origin expectedData := makePayload(1500) v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:]) copy(expectedData[17:], payload) - if !slices.Equal(expectedData[:17+len(payload)], data) { + if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) { t.Fatal("expected datagram did not equal expected") } cancel(expectedContextCanceled) @@ -95,7 +121,7 @@ func testSessionServe_Origin(t *testing.T, payload []byte) { t.Fatal(err) } - err := <-done + err = <-done if !errors.Is(err, context.Canceled) { t.Fatal(err) } @@ -105,11 +131,14 @@ func testSessionServe_Origin(t *testing.T, payload []byte) { } func TestSessionServe_OriginTooLarge(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() eyeball := newMockEyeball() payload := makePayload(1281) - origin := newTestOrigin(payload) - session := v3.NewSession(testRequestID, 2*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log) + origin, server := net.Pipe() + defer origin.Close() + defer server.Close() + session := v3.NewSession(testRequestID, 2*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log) defer session.Close() done := make(chan error) @@ -117,6 +146,12 @@ func TestSessionServe_OriginTooLarge(t *testing.T) { done <- session.Serve(context.Background()) }() + // Attempt to write a payload too large from the origin + _, err := server.Write(payload) + if err != nil { + t.Fatal(err) + } + select { case data := <-eyeball.recvData: // we never expect a read to make it here because the origin provided a payload that is too large @@ -130,6 +165,7 @@ func TestSessionServe_OriginTooLarge(t *testing.T) { } func TestSessionServe_Migrate(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() eyeball := newMockEyeball() pipe1, pipe2 := net.Pipe() @@ -186,6 +222,7 @@ func TestSessionServe_Migrate(t *testing.T) { } func TestSessionServe_Migrate_CloseContext2(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() eyeball := newMockEyeball() pipe1, pipe2 := net.Pipe() @@ -245,39 +282,48 @@ func TestSessionServe_Migrate_CloseContext2(t *testing.T) { } func TestSessionClose_Multiple(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - origin := newTestOrigin(makePayload(128)) - session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) + origin, server := net.Pipe() + defer origin.Close() + defer server.Close() + session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) err := session.Close() if err != nil { t.Fatal(err) } - if !origin.closed.Load() { - t.Fatal("origin wasn't closed") + b := [1500]byte{} + _, err = server.Read(b[:]) + if !errors.Is(err, io.EOF) { + t.Fatalf("origin server connection should be closed: %s", err) } - // Reset the closed status to make sure it isn't closed again - origin.closed.Store(false) // subsequent closes shouldn't call close again or cause any errors err = session.Close() if err != nil { t.Fatal(err) } - if origin.closed.Load() { - t.Fatal("origin was incorrectly closed twice") + _, err = server.Read(b[:]) + if !errors.Is(err, io.EOF) { + t.Fatalf("origin server connection should still be closed: %s", err) } } func TestSessionServe_IdleTimeout(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - origin := newTestIdleOrigin(10 * time.Second) // Make idle time longer than closeAfterIdle + origin, server := net.Pipe() + defer origin.Close() + defer server.Close() closeAfterIdle := 2 * time.Second - session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) + session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) err := session.Serve(context.Background()) if !errors.Is(err, v3.SessionIdleErr{}) { t.Fatal(err) } // session should be closed - if !origin.closed { + b := [1500]byte{} + _, err = server.Read(b[:]) + if !errors.Is(err, io.EOF) { t.Fatalf("session should be closed after Serve returns") } // closing a session again should not return an error @@ -288,12 +334,14 @@ func TestSessionServe_IdleTimeout(t *testing.T) { } func TestSessionServe_ParentContextCanceled(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() - // Make idle time and idle timeout longer than closeAfterIdle - origin := newTestIdleOrigin(10 * time.Second) + origin, server := net.Pipe() + defer origin.Close() + defer server.Close() closeAfterIdle := 10 * time.Second - session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) + session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() err := session.Serve(ctx) @@ -301,7 +349,9 @@ func TestSessionServe_ParentContextCanceled(t *testing.T) { t.Fatal(err) } // session should be closed - if !origin.closed { + b := [1500]byte{} + _, err = server.Read(b[:]) + if !errors.Is(err, io.EOF) { t.Fatalf("session should be closed after Serve returns") } // closing a session again should not return an error @@ -312,6 +362,7 @@ func TestSessionServe_ParentContextCanceled(t *testing.T) { } func TestSessionServe_ReadErrors(t *testing.T) { + defer leaktest.Check(t)() log := zerolog.Nop() origin := newTestErrOrigin(net.ErrClosed, nil) session := v3.NewSession(testRequestID, 30*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) @@ -321,72 +372,6 @@ func TestSessionServe_ReadErrors(t *testing.T) { } } -type testOrigin struct { - // bytes from Write - write []byte - // bytes provided to Read - read []byte - readOnce atomic.Bool - closed atomic.Bool -} - -func newTestOrigin(payload []byte) testOrigin { - return testOrigin{ - read: payload, - } -} - -func (o *testOrigin) Read(p []byte) (n int, err error) { - if o.closed.Load() { - return -1, net.ErrClosed - } - if o.readOnce.Load() { - // We only want to provide one read so all other reads will be blocked - time.Sleep(10 * time.Second) - } - o.readOnce.Store(true) - return copy(p, o.read), nil -} - -func (o *testOrigin) Write(p []byte) (n int, err error) { - if o.closed.Load() { - return -1, net.ErrClosed - } - o.write = make([]byte, len(p)) - copy(o.write, p) - return len(p), nil -} - -func (o *testOrigin) Close() error { - o.closed.Store(true) - return nil -} - -type testIdleOrigin struct { - duration time.Duration - closed bool -} - -func newTestIdleOrigin(d time.Duration) testIdleOrigin { - return testIdleOrigin{ - duration: d, - } -} - -func (o *testIdleOrigin) Read(p []byte) (n int, err error) { - time.Sleep(o.duration) - return -1, nil -} - -func (o *testIdleOrigin) Write(p []byte) (n int, err error) { - return 0, nil -} - -func (o *testIdleOrigin) Close() error { - o.closed = true - return nil -} - type testErrOrigin struct { readErr error writeErr error