TUN-8817: Increase close session channel by one since there are two writers

When closing a session, there are two possible signals that will occur,
one from the outside, indicating that the session is idle and needs to
be closed, and the internal error condition that will be unblocked
with a net.ErrClosed when the connection underneath is closed. Both of
these routines write to the session's closeChan.

Once the reader for the closeChan reads one value, it will immediately
return. This means that the channel is a one-shot and one of the two
writers will get stuck unless the size of the channel is increased to
accomodate for the second write to the channel.

With the channel size increased to two, the second writer (whichever
loses the race to write) will now be unblocked to end their go routine
and return.

Closes TUN-8817
This commit is contained in:
Devin Carr 2024-12-17 14:55:09 -08:00
parent 1859d742a8
commit bc9c5d2e6e
2 changed files with 80 additions and 92 deletions

View File

@ -89,7 +89,10 @@ func NewSession(
log *zerolog.Logger, log *zerolog.Logger,
) Session { ) Session {
logger := log.With().Str(logFlowID, id.String()).Logger() logger := log.With().Str(logFlowID, id.String()).Logger()
closeChan := make(chan error, 1) // closeChan has two slots to allow for both writers (the closeFn and the Serve routine) to both be able to
// write to the channel without blocking since there is only ever one value read from the closeChan by the
// waitForCloseCondition.
closeChan := make(chan error, 2)
session := &session{ session := &session{
id: id, id: id,
closeAfterIdle: closeAfterIdle, closeAfterIdle: closeAfterIdle,

View File

@ -3,13 +3,14 @@ package v3_test
import ( import (
"context" "context"
"errors" "errors"
"io"
"net" "net"
"net/netip" "net/netip"
"slices" "slices"
"sync/atomic"
"testing" "testing"
"time" "time"
"github.com/fortytw2/leaktest"
"github.com/rs/zerolog" "github.com/rs/zerolog"
v3 "github.com/cloudflare/cloudflared/quic/v3" v3 "github.com/cloudflare/cloudflared/quic/v3"
@ -32,45 +33,64 @@ func TestSessionNew(t *testing.T) {
func testSessionWrite(t *testing.T, payload []byte) { func testSessionWrite(t *testing.T, payload []byte) {
log := zerolog.Nop() log := zerolog.Nop()
origin := newTestOrigin(makePayload(1280)) origin, server := net.Pipe()
session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) defer origin.Close()
defer server.Close()
// Start origin server read
serverRead := make(chan []byte, 1)
go func() {
read := make([]byte, 1500)
server.Read(read[:])
serverRead <- read
}()
// Create session and write to origin
session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
n, err := session.Write(payload) n, err := session.Write(payload)
defer session.Close()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if n != len(payload) { if n != len(payload) {
t.Fatal("unable to write the whole payload") t.Fatal("unable to write the whole payload")
} }
if !slices.Equal(payload, origin.write[:len(payload)]) {
read := <-serverRead
if !slices.Equal(payload, read[:len(payload)]) {
t.Fatal("payload provided from origin and read value are not the same") t.Fatal("payload provided from origin and read value are not the same")
} }
} }
func TestSessionWrite_Max(t *testing.T) { func TestSessionWrite_Max(t *testing.T) {
defer leaktest.Check(t)()
payload := makePayload(1280) payload := makePayload(1280)
testSessionWrite(t, payload) testSessionWrite(t, payload)
} }
func TestSessionWrite_Min(t *testing.T) { func TestSessionWrite_Min(t *testing.T) {
defer leaktest.Check(t)()
payload := makePayload(0) payload := makePayload(0)
testSessionWrite(t, payload) testSessionWrite(t, payload)
} }
func TestSessionServe_OriginMax(t *testing.T) { func TestSessionServe_OriginMax(t *testing.T) {
defer leaktest.Check(t)()
payload := makePayload(1280) payload := makePayload(1280)
testSessionServe_Origin(t, payload) testSessionServe_Origin(t, payload)
} }
func TestSessionServe_OriginMin(t *testing.T) { func TestSessionServe_OriginMin(t *testing.T) {
defer leaktest.Check(t)()
payload := makePayload(0) payload := makePayload(0)
testSessionServe_Origin(t, payload) testSessionServe_Origin(t, payload)
} }
func testSessionServe_Origin(t *testing.T, payload []byte) { func testSessionServe_Origin(t *testing.T, payload []byte) {
log := zerolog.Nop() log := zerolog.Nop()
origin, server := net.Pipe()
defer origin.Close()
defer server.Close()
eyeball := newMockEyeball() eyeball := newMockEyeball()
origin := newTestOrigin(payload) session := v3.NewSession(testRequestID, 3*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
session := v3.NewSession(testRequestID, 3*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
defer session.Close() defer session.Close()
ctx, cancel := context.WithCancelCause(context.Background()) ctx, cancel := context.WithCancelCause(context.Background())
@ -80,13 +100,19 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
done <- session.Serve(ctx) done <- session.Serve(ctx)
}() }()
// Write from the origin server
_, err := server.Write(payload)
if err != nil {
t.Fatal(err)
}
select { select {
case data := <-eyeball.recvData: case data := <-eyeball.recvData:
// check received data matches provided from origin // check received data matches provided from origin
expectedData := makePayload(1500) expectedData := makePayload(1500)
v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:]) v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
copy(expectedData[17:], payload) copy(expectedData[17:], payload)
if !slices.Equal(expectedData[:17+len(payload)], data) { if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
t.Fatal("expected datagram did not equal expected") t.Fatal("expected datagram did not equal expected")
} }
cancel(expectedContextCanceled) cancel(expectedContextCanceled)
@ -95,7 +121,7 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
t.Fatal(err) t.Fatal(err)
} }
err := <-done err = <-done
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {
t.Fatal(err) t.Fatal(err)
} }
@ -105,11 +131,14 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
} }
func TestSessionServe_OriginTooLarge(t *testing.T) { func TestSessionServe_OriginTooLarge(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop() log := zerolog.Nop()
eyeball := newMockEyeball() eyeball := newMockEyeball()
payload := makePayload(1281) payload := makePayload(1281)
origin := newTestOrigin(payload) origin, server := net.Pipe()
session := v3.NewSession(testRequestID, 2*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log) defer origin.Close()
defer server.Close()
session := v3.NewSession(testRequestID, 2*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
defer session.Close() defer session.Close()
done := make(chan error) done := make(chan error)
@ -117,6 +146,12 @@ func TestSessionServe_OriginTooLarge(t *testing.T) {
done <- session.Serve(context.Background()) done <- session.Serve(context.Background())
}() }()
// Attempt to write a payload too large from the origin
_, err := server.Write(payload)
if err != nil {
t.Fatal(err)
}
select { select {
case data := <-eyeball.recvData: case data := <-eyeball.recvData:
// we never expect a read to make it here because the origin provided a payload that is too large // we never expect a read to make it here because the origin provided a payload that is too large
@ -130,6 +165,7 @@ func TestSessionServe_OriginTooLarge(t *testing.T) {
} }
func TestSessionServe_Migrate(t *testing.T) { func TestSessionServe_Migrate(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop() log := zerolog.Nop()
eyeball := newMockEyeball() eyeball := newMockEyeball()
pipe1, pipe2 := net.Pipe() pipe1, pipe2 := net.Pipe()
@ -186,6 +222,7 @@ func TestSessionServe_Migrate(t *testing.T) {
} }
func TestSessionServe_Migrate_CloseContext2(t *testing.T) { func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop() log := zerolog.Nop()
eyeball := newMockEyeball() eyeball := newMockEyeball()
pipe1, pipe2 := net.Pipe() pipe1, pipe2 := net.Pipe()
@ -245,39 +282,48 @@ func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
} }
func TestSessionClose_Multiple(t *testing.T) { func TestSessionClose_Multiple(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop() log := zerolog.Nop()
origin := newTestOrigin(makePayload(128)) origin, server := net.Pipe()
session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) defer origin.Close()
defer server.Close()
session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
err := session.Close() err := session.Close()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !origin.closed.Load() { b := [1500]byte{}
t.Fatal("origin wasn't closed") _, err = server.Read(b[:])
if !errors.Is(err, io.EOF) {
t.Fatalf("origin server connection should be closed: %s", err)
} }
// Reset the closed status to make sure it isn't closed again
origin.closed.Store(false)
// subsequent closes shouldn't call close again or cause any errors // subsequent closes shouldn't call close again or cause any errors
err = session.Close() err = session.Close()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if origin.closed.Load() { _, err = server.Read(b[:])
t.Fatal("origin was incorrectly closed twice") if !errors.Is(err, io.EOF) {
t.Fatalf("origin server connection should still be closed: %s", err)
} }
} }
func TestSessionServe_IdleTimeout(t *testing.T) { func TestSessionServe_IdleTimeout(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop() log := zerolog.Nop()
origin := newTestIdleOrigin(10 * time.Second) // Make idle time longer than closeAfterIdle origin, server := net.Pipe()
defer origin.Close()
defer server.Close()
closeAfterIdle := 2 * time.Second closeAfterIdle := 2 * time.Second
session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
err := session.Serve(context.Background()) err := session.Serve(context.Background())
if !errors.Is(err, v3.SessionIdleErr{}) { if !errors.Is(err, v3.SessionIdleErr{}) {
t.Fatal(err) t.Fatal(err)
} }
// session should be closed // session should be closed
if !origin.closed { b := [1500]byte{}
_, err = server.Read(b[:])
if !errors.Is(err, io.EOF) {
t.Fatalf("session should be closed after Serve returns") t.Fatalf("session should be closed after Serve returns")
} }
// closing a session again should not return an error // closing a session again should not return an error
@ -288,12 +334,14 @@ func TestSessionServe_IdleTimeout(t *testing.T) {
} }
func TestSessionServe_ParentContextCanceled(t *testing.T) { func TestSessionServe_ParentContextCanceled(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop() log := zerolog.Nop()
// Make idle time and idle timeout longer than closeAfterIdle origin, server := net.Pipe()
origin := newTestIdleOrigin(10 * time.Second) defer origin.Close()
defer server.Close()
closeAfterIdle := 10 * time.Second closeAfterIdle := 10 * time.Second
session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
err := session.Serve(ctx) err := session.Serve(ctx)
@ -301,7 +349,9 @@ func TestSessionServe_ParentContextCanceled(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
// session should be closed // session should be closed
if !origin.closed { b := [1500]byte{}
_, err = server.Read(b[:])
if !errors.Is(err, io.EOF) {
t.Fatalf("session should be closed after Serve returns") t.Fatalf("session should be closed after Serve returns")
} }
// closing a session again should not return an error // closing a session again should not return an error
@ -312,6 +362,7 @@ func TestSessionServe_ParentContextCanceled(t *testing.T) {
} }
func TestSessionServe_ReadErrors(t *testing.T) { func TestSessionServe_ReadErrors(t *testing.T) {
defer leaktest.Check(t)()
log := zerolog.Nop() log := zerolog.Nop()
origin := newTestErrOrigin(net.ErrClosed, nil) origin := newTestErrOrigin(net.ErrClosed, nil)
session := v3.NewSession(testRequestID, 30*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log) session := v3.NewSession(testRequestID, 30*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
@ -321,72 +372,6 @@ func TestSessionServe_ReadErrors(t *testing.T) {
} }
} }
type testOrigin struct {
// bytes from Write
write []byte
// bytes provided to Read
read []byte
readOnce atomic.Bool
closed atomic.Bool
}
func newTestOrigin(payload []byte) testOrigin {
return testOrigin{
read: payload,
}
}
func (o *testOrigin) Read(p []byte) (n int, err error) {
if o.closed.Load() {
return -1, net.ErrClosed
}
if o.readOnce.Load() {
// We only want to provide one read so all other reads will be blocked
time.Sleep(10 * time.Second)
}
o.readOnce.Store(true)
return copy(p, o.read), nil
}
func (o *testOrigin) Write(p []byte) (n int, err error) {
if o.closed.Load() {
return -1, net.ErrClosed
}
o.write = make([]byte, len(p))
copy(o.write, p)
return len(p), nil
}
func (o *testOrigin) Close() error {
o.closed.Store(true)
return nil
}
type testIdleOrigin struct {
duration time.Duration
closed bool
}
func newTestIdleOrigin(d time.Duration) testIdleOrigin {
return testIdleOrigin{
duration: d,
}
}
func (o *testIdleOrigin) Read(p []byte) (n int, err error) {
time.Sleep(o.duration)
return -1, nil
}
func (o *testIdleOrigin) Write(p []byte) (n int, err error) {
return 0, nil
}
func (o *testIdleOrigin) Close() error {
o.closed = true
return nil
}
type testErrOrigin struct { type testErrOrigin struct {
readErr error readErr error
writeErr error writeErr error