TUN-8817: Increase close session channel by one since there are two writers
When closing a session, there are two possible signals that will occur, one from the outside, indicating that the session is idle and needs to be closed, and the internal error condition that will be unblocked with a net.ErrClosed when the connection underneath is closed. Both of these routines write to the session's closeChan. Once the reader for the closeChan reads one value, it will immediately return. This means that the channel is a one-shot and one of the two writers will get stuck unless the size of the channel is increased to accomodate for the second write to the channel. With the channel size increased to two, the second writer (whichever loses the race to write) will now be unblocked to end their go routine and return. Closes TUN-8817
This commit is contained in:
parent
1859d742a8
commit
bc9c5d2e6e
quic/v3
|
@ -89,7 +89,10 @@ func NewSession(
|
||||||
log *zerolog.Logger,
|
log *zerolog.Logger,
|
||||||
) Session {
|
) Session {
|
||||||
logger := log.With().Str(logFlowID, id.String()).Logger()
|
logger := log.With().Str(logFlowID, id.String()).Logger()
|
||||||
closeChan := make(chan error, 1)
|
// closeChan has two slots to allow for both writers (the closeFn and the Serve routine) to both be able to
|
||||||
|
// write to the channel without blocking since there is only ever one value read from the closeChan by the
|
||||||
|
// waitForCloseCondition.
|
||||||
|
closeChan := make(chan error, 2)
|
||||||
session := &session{
|
session := &session{
|
||||||
id: id,
|
id: id,
|
||||||
closeAfterIdle: closeAfterIdle,
|
closeAfterIdle: closeAfterIdle,
|
||||||
|
|
|
@ -3,13 +3,14 @@ package v3_test
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fortytw2/leaktest"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
v3 "github.com/cloudflare/cloudflared/quic/v3"
|
v3 "github.com/cloudflare/cloudflared/quic/v3"
|
||||||
|
@ -32,45 +33,64 @@ func TestSessionNew(t *testing.T) {
|
||||||
|
|
||||||
func testSessionWrite(t *testing.T, payload []byte) {
|
func testSessionWrite(t *testing.T, payload []byte) {
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
origin := newTestOrigin(makePayload(1280))
|
origin, server := net.Pipe()
|
||||||
session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
defer origin.Close()
|
||||||
|
defer server.Close()
|
||||||
|
// Start origin server read
|
||||||
|
serverRead := make(chan []byte, 1)
|
||||||
|
go func() {
|
||||||
|
read := make([]byte, 1500)
|
||||||
|
server.Read(read[:])
|
||||||
|
serverRead <- read
|
||||||
|
}()
|
||||||
|
// Create session and write to origin
|
||||||
|
session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||||
n, err := session.Write(payload)
|
n, err := session.Write(payload)
|
||||||
|
defer session.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if n != len(payload) {
|
if n != len(payload) {
|
||||||
t.Fatal("unable to write the whole payload")
|
t.Fatal("unable to write the whole payload")
|
||||||
}
|
}
|
||||||
if !slices.Equal(payload, origin.write[:len(payload)]) {
|
|
||||||
|
read := <-serverRead
|
||||||
|
if !slices.Equal(payload, read[:len(payload)]) {
|
||||||
t.Fatal("payload provided from origin and read value are not the same")
|
t.Fatal("payload provided from origin and read value are not the same")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionWrite_Max(t *testing.T) {
|
func TestSessionWrite_Max(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
payload := makePayload(1280)
|
payload := makePayload(1280)
|
||||||
testSessionWrite(t, payload)
|
testSessionWrite(t, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionWrite_Min(t *testing.T) {
|
func TestSessionWrite_Min(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
payload := makePayload(0)
|
payload := makePayload(0)
|
||||||
testSessionWrite(t, payload)
|
testSessionWrite(t, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionServe_OriginMax(t *testing.T) {
|
func TestSessionServe_OriginMax(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
payload := makePayload(1280)
|
payload := makePayload(1280)
|
||||||
testSessionServe_Origin(t, payload)
|
testSessionServe_Origin(t, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionServe_OriginMin(t *testing.T) {
|
func TestSessionServe_OriginMin(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
payload := makePayload(0)
|
payload := makePayload(0)
|
||||||
testSessionServe_Origin(t, payload)
|
testSessionServe_Origin(t, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSessionServe_Origin(t *testing.T, payload []byte) {
|
func testSessionServe_Origin(t *testing.T, payload []byte) {
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
|
origin, server := net.Pipe()
|
||||||
|
defer origin.Close()
|
||||||
|
defer server.Close()
|
||||||
eyeball := newMockEyeball()
|
eyeball := newMockEyeball()
|
||||||
origin := newTestOrigin(payload)
|
session := v3.NewSession(testRequestID, 3*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
|
||||||
session := v3.NewSession(testRequestID, 3*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
|
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancelCause(context.Background())
|
ctx, cancel := context.WithCancelCause(context.Background())
|
||||||
|
@ -80,13 +100,19 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
|
||||||
done <- session.Serve(ctx)
|
done <- session.Serve(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Write from the origin server
|
||||||
|
_, err := server.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case data := <-eyeball.recvData:
|
case data := <-eyeball.recvData:
|
||||||
// check received data matches provided from origin
|
// check received data matches provided from origin
|
||||||
expectedData := makePayload(1500)
|
expectedData := makePayload(1500)
|
||||||
v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
|
v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
|
||||||
copy(expectedData[17:], payload)
|
copy(expectedData[17:], payload)
|
||||||
if !slices.Equal(expectedData[:17+len(payload)], data) {
|
if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
|
||||||
t.Fatal("expected datagram did not equal expected")
|
t.Fatal("expected datagram did not equal expected")
|
||||||
}
|
}
|
||||||
cancel(expectedContextCanceled)
|
cancel(expectedContextCanceled)
|
||||||
|
@ -95,7 +121,7 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := <-done
|
err = <-done
|
||||||
if !errors.Is(err, context.Canceled) {
|
if !errors.Is(err, context.Canceled) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -105,11 +131,14 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionServe_OriginTooLarge(t *testing.T) {
|
func TestSessionServe_OriginTooLarge(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
eyeball := newMockEyeball()
|
eyeball := newMockEyeball()
|
||||||
payload := makePayload(1281)
|
payload := makePayload(1281)
|
||||||
origin := newTestOrigin(payload)
|
origin, server := net.Pipe()
|
||||||
session := v3.NewSession(testRequestID, 2*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
|
defer origin.Close()
|
||||||
|
defer server.Close()
|
||||||
|
session := v3.NewSession(testRequestID, 2*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
|
|
||||||
done := make(chan error)
|
done := make(chan error)
|
||||||
|
@ -117,6 +146,12 @@ func TestSessionServe_OriginTooLarge(t *testing.T) {
|
||||||
done <- session.Serve(context.Background())
|
done <- session.Serve(context.Background())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Attempt to write a payload too large from the origin
|
||||||
|
_, err := server.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case data := <-eyeball.recvData:
|
case data := <-eyeball.recvData:
|
||||||
// we never expect a read to make it here because the origin provided a payload that is too large
|
// we never expect a read to make it here because the origin provided a payload that is too large
|
||||||
|
@ -130,6 +165,7 @@ func TestSessionServe_OriginTooLarge(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionServe_Migrate(t *testing.T) {
|
func TestSessionServe_Migrate(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
eyeball := newMockEyeball()
|
eyeball := newMockEyeball()
|
||||||
pipe1, pipe2 := net.Pipe()
|
pipe1, pipe2 := net.Pipe()
|
||||||
|
@ -186,6 +222,7 @@ func TestSessionServe_Migrate(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
|
func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
eyeball := newMockEyeball()
|
eyeball := newMockEyeball()
|
||||||
pipe1, pipe2 := net.Pipe()
|
pipe1, pipe2 := net.Pipe()
|
||||||
|
@ -245,39 +282,48 @@ func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionClose_Multiple(t *testing.T) {
|
func TestSessionClose_Multiple(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
origin := newTestOrigin(makePayload(128))
|
origin, server := net.Pipe()
|
||||||
session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
defer origin.Close()
|
||||||
|
defer server.Close()
|
||||||
|
session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||||
err := session.Close()
|
err := session.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !origin.closed.Load() {
|
b := [1500]byte{}
|
||||||
t.Fatal("origin wasn't closed")
|
_, err = server.Read(b[:])
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatalf("origin server connection should be closed: %s", err)
|
||||||
}
|
}
|
||||||
// Reset the closed status to make sure it isn't closed again
|
|
||||||
origin.closed.Store(false)
|
|
||||||
// subsequent closes shouldn't call close again or cause any errors
|
// subsequent closes shouldn't call close again or cause any errors
|
||||||
err = session.Close()
|
err = session.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if origin.closed.Load() {
|
_, err = server.Read(b[:])
|
||||||
t.Fatal("origin was incorrectly closed twice")
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatalf("origin server connection should still be closed: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionServe_IdleTimeout(t *testing.T) {
|
func TestSessionServe_IdleTimeout(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
origin := newTestIdleOrigin(10 * time.Second) // Make idle time longer than closeAfterIdle
|
origin, server := net.Pipe()
|
||||||
|
defer origin.Close()
|
||||||
|
defer server.Close()
|
||||||
closeAfterIdle := 2 * time.Second
|
closeAfterIdle := 2 * time.Second
|
||||||
session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||||
err := session.Serve(context.Background())
|
err := session.Serve(context.Background())
|
||||||
if !errors.Is(err, v3.SessionIdleErr{}) {
|
if !errors.Is(err, v3.SessionIdleErr{}) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
// session should be closed
|
// session should be closed
|
||||||
if !origin.closed {
|
b := [1500]byte{}
|
||||||
|
_, err = server.Read(b[:])
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
t.Fatalf("session should be closed after Serve returns")
|
t.Fatalf("session should be closed after Serve returns")
|
||||||
}
|
}
|
||||||
// closing a session again should not return an error
|
// closing a session again should not return an error
|
||||||
|
@ -288,12 +334,14 @@ func TestSessionServe_IdleTimeout(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionServe_ParentContextCanceled(t *testing.T) {
|
func TestSessionServe_ParentContextCanceled(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
// Make idle time and idle timeout longer than closeAfterIdle
|
origin, server := net.Pipe()
|
||||||
origin := newTestIdleOrigin(10 * time.Second)
|
defer origin.Close()
|
||||||
|
defer server.Close()
|
||||||
closeAfterIdle := 10 * time.Second
|
closeAfterIdle := 10 * time.Second
|
||||||
|
|
||||||
session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err := session.Serve(ctx)
|
err := session.Serve(ctx)
|
||||||
|
@ -301,7 +349,9 @@ func TestSessionServe_ParentContextCanceled(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
// session should be closed
|
// session should be closed
|
||||||
if !origin.closed {
|
b := [1500]byte{}
|
||||||
|
_, err = server.Read(b[:])
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
t.Fatalf("session should be closed after Serve returns")
|
t.Fatalf("session should be closed after Serve returns")
|
||||||
}
|
}
|
||||||
// closing a session again should not return an error
|
// closing a session again should not return an error
|
||||||
|
@ -312,6 +362,7 @@ func TestSessionServe_ParentContextCanceled(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionServe_ReadErrors(t *testing.T) {
|
func TestSessionServe_ReadErrors(t *testing.T) {
|
||||||
|
defer leaktest.Check(t)()
|
||||||
log := zerolog.Nop()
|
log := zerolog.Nop()
|
||||||
origin := newTestErrOrigin(net.ErrClosed, nil)
|
origin := newTestErrOrigin(net.ErrClosed, nil)
|
||||||
session := v3.NewSession(testRequestID, 30*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
session := v3.NewSession(testRequestID, 30*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||||
|
@ -321,72 +372,6 @@ func TestSessionServe_ReadErrors(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type testOrigin struct {
|
|
||||||
// bytes from Write
|
|
||||||
write []byte
|
|
||||||
// bytes provided to Read
|
|
||||||
read []byte
|
|
||||||
readOnce atomic.Bool
|
|
||||||
closed atomic.Bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTestOrigin(payload []byte) testOrigin {
|
|
||||||
return testOrigin{
|
|
||||||
read: payload,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *testOrigin) Read(p []byte) (n int, err error) {
|
|
||||||
if o.closed.Load() {
|
|
||||||
return -1, net.ErrClosed
|
|
||||||
}
|
|
||||||
if o.readOnce.Load() {
|
|
||||||
// We only want to provide one read so all other reads will be blocked
|
|
||||||
time.Sleep(10 * time.Second)
|
|
||||||
}
|
|
||||||
o.readOnce.Store(true)
|
|
||||||
return copy(p, o.read), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *testOrigin) Write(p []byte) (n int, err error) {
|
|
||||||
if o.closed.Load() {
|
|
||||||
return -1, net.ErrClosed
|
|
||||||
}
|
|
||||||
o.write = make([]byte, len(p))
|
|
||||||
copy(o.write, p)
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *testOrigin) Close() error {
|
|
||||||
o.closed.Store(true)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type testIdleOrigin struct {
|
|
||||||
duration time.Duration
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTestIdleOrigin(d time.Duration) testIdleOrigin {
|
|
||||||
return testIdleOrigin{
|
|
||||||
duration: d,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *testIdleOrigin) Read(p []byte) (n int, err error) {
|
|
||||||
time.Sleep(o.duration)
|
|
||||||
return -1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *testIdleOrigin) Write(p []byte) (n int, err error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *testIdleOrigin) Close() error {
|
|
||||||
o.closed = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type testErrOrigin struct {
|
type testErrOrigin struct {
|
||||||
readErr error
|
readErr error
|
||||||
writeErr error
|
writeErr error
|
||||||
|
|
Loading…
Reference in New Issue