TUN-8667: Add datagram v3 session manager

New session manager leverages similar functionality that was previously
provided with datagram v2, with the distinct difference that the sessions
are registered via QUIC Datagrams and unregistered via timeouts only; the
sessions will no longer attempt to unregister sessions remotely with the
edge service.

The Session Manager is shared across all QUIC connections that cloudflared
uses to connect to the edge (typically 4). This will help cloudflared be
able to monitor all sessions across the connections and help correlate
in the future if sessions migrate across connections.

The UDP payload size is still limited to 1280 bytes across all OS's. Any
UDP packet that provides a payload size of greater than 1280 will cause
cloudflared to report (as it currently does) a log error and drop the packet.

Closes TUN-8667
This commit is contained in:
Devin Carr 2024-10-31 14:05:15 -07:00
parent 599ba52750
commit 6a6c890700
11 changed files with 743 additions and 7 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
) )
type UDPProxy interface { type UDPProxy interface {
@ -30,3 +31,16 @@ func DialUDP(dstIP net.IP, dstPort uint16) (UDPProxy, error) {
return &udpProxy{udpConn}, nil return &udpProxy{udpConn}, nil
} }
func DialUDPAddrPort(dest netip.AddrPort) (*net.UDPConn, error) {
addr := net.UDPAddrFromAddrPort(dest)
// We use nil as local addr to force runtime to find the best suitable local address IP given the destination
// address as context.
udpConn, err := net.DialUDP("udp", nil, addr)
if err != nil {
return nil, fmt.Errorf("unable to create UDP proxy to origin (%v:%v): %w", dest.Addr(), dest.Port(), err)
}
return udpConn, nil
}

View File

@ -24,7 +24,7 @@ const (
datagramTypeLen = 1 datagramTypeLen = 1
// 1280 is the default datagram packet length used before MTU discovery: https://github.com/quic-go/quic-go/blob/v0.45.0/internal/protocol/params.go#L12 // 1280 is the default datagram packet length used before MTU discovery: https://github.com/quic-go/quic-go/blob/v0.45.0/internal/protocol/params.go#L12
maxDatagramLen = 1280 maxDatagramPayloadLen = 1280
) )
func parseDatagramType(data []byte) (DatagramType, error) { func parseDatagramType(data []byte) (DatagramType, error) {
@ -100,10 +100,10 @@ func (s *UDPSessionRegistrationDatagram) MarshalBinary() (data []byte, err error
} }
var maxPayloadLen int var maxPayloadLen int
if ipv6 { if ipv6 {
maxPayloadLen = maxDatagramLen - sessionRegistrationIPv6DatagramHeaderLen maxPayloadLen = maxDatagramPayloadLen + sessionRegistrationIPv6DatagramHeaderLen
flags |= sessionRegistrationFlagsIPMask flags |= sessionRegistrationFlagsIPMask
} else { } else {
maxPayloadLen = maxDatagramLen - sessionRegistrationIPv4DatagramHeaderLen maxPayloadLen = maxDatagramPayloadLen + sessionRegistrationIPv4DatagramHeaderLen
} }
// Make sure that the payload being bundled can actually fit in the payload destination // Make sure that the payload being bundled can actually fit in the payload destination
if len(s.Payload) > maxPayloadLen { if len(s.Payload) > maxPayloadLen {
@ -195,7 +195,7 @@ const (
datagramPayloadHeaderLen = datagramTypeLen + datagramRequestIdLen datagramPayloadHeaderLen = datagramTypeLen + datagramRequestIdLen
// The maximum size that a proxied UDP payload can be in a [UDPSessionPayloadDatagram] // The maximum size that a proxied UDP payload can be in a [UDPSessionPayloadDatagram]
maxPayloadPlusHeaderLen = maxDatagramLen - datagramPayloadHeaderLen maxPayloadPlusHeaderLen = maxDatagramPayloadLen + datagramPayloadHeaderLen
) )
// The datagram structure for UDPSessionPayloadDatagram is: // The datagram structure for UDPSessionPayloadDatagram is:
@ -270,7 +270,7 @@ const (
datagramSessionRegistrationResponseLen = datagramTypeLen + datagramRespTypeLen + datagramRequestIdLen + datagramRespErrMsgLen datagramSessionRegistrationResponseLen = datagramTypeLen + datagramRespTypeLen + datagramRequestIdLen + datagramRespErrMsgLen
// The maximum size that an error message can be in a [UDPSessionRegistrationResponseDatagram]. // The maximum size that an error message can be in a [UDPSessionRegistrationResponseDatagram].
maxResponseErrorMessageLen = maxDatagramLen - datagramSessionRegistrationResponseLen maxResponseErrorMessageLen = maxDatagramPayloadLen - datagramSessionRegistrationResponseLen
) )
// SessionRegistrationResp represents all of the responses that a UDP session registration response // SessionRegistrationResp represents all of the responses that a UDP session registration response

View File

@ -21,7 +21,7 @@ func makePayload(size int) []byte {
} }
func TestSessionRegistration_MarshalUnmarshal(t *testing.T) { func TestSessionRegistration_MarshalUnmarshal(t *testing.T) {
payload := makePayload(1254) payload := makePayload(1280)
tests := []*v3.UDPSessionRegistrationDatagram{ tests := []*v3.UDPSessionRegistrationDatagram{
// Default (IPv4) // Default (IPv4)
{ {
@ -236,7 +236,7 @@ func TestSessionPayload(t *testing.T) {
}) })
t.Run("payload size too large", func(t *testing.T) { t.Run("payload size too large", func(t *testing.T) {
datagram := makePayload(17 + 1264) // 1263 is the largest payload size allowed datagram := makePayload(17 + 1281) // 1280 is the largest payload size allowed
err := v3.MarshalPayloadHeaderTo(testRequestID, datagram) err := v3.MarshalPayloadHeaderTo(testRequestID, datagram)
if err != nil { if err != nil {
t.Error(err) t.Error(err)

87
quic/v3/manager.go Normal file
View File

@ -0,0 +1,87 @@
package v3
import (
"errors"
"net"
"net/netip"
"sync"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/ingress"
)
var (
ErrSessionNotFound = errors.New("session not found")
ErrSessionBoundToOtherConn = errors.New("session is in use by another connection")
)
type SessionManager interface {
// RegisterSession will register a new session if it does not already exist for the request ID.
// During new session creation, the session will also bind the UDP socket for the origin.
// If the session exists for a different connection, it will return [ErrSessionBoundToOtherConn].
RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramWriter) (Session, error)
// GetSession returns an active session if available for the provided connection.
// If the session does not exist, it will return [ErrSessionNotFound]. If the session exists for a different
// connection, it will return [ErrSessionBoundToOtherConn].
GetSession(requestID RequestID) (Session, error)
// UnregisterSession will remove a session from the current session manager. It will attempt to close the session
// before removal.
UnregisterSession(requestID RequestID)
}
type DialUDP func(dest netip.AddrPort) (*net.UDPConn, error)
type sessionManager struct {
sessions map[RequestID]Session
mutex sync.RWMutex
log *zerolog.Logger
}
func NewSessionManager(log *zerolog.Logger, originDialer DialUDP) SessionManager {
return &sessionManager{
sessions: make(map[RequestID]Session),
log: log,
}
}
func (s *sessionManager) RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramWriter) (Session, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
// Check to make sure session doesn't already exist for requestID
_, exists := s.sessions[request.RequestID]
if exists {
return nil, ErrSessionBoundToOtherConn
}
// Attempt to bind the UDP socket for the new session
origin, err := ingress.DialUDPAddrPort(request.Dest)
if err != nil {
return nil, err
}
// Create and insert the new session in the map
session := NewSession(request.RequestID, request.IdleDurationHint, origin, conn, s.log)
s.sessions[request.RequestID] = session
return session, nil
}
func (s *sessionManager) GetSession(requestID RequestID) (Session, error) {
s.mutex.RLock()
defer s.mutex.RUnlock()
session, exists := s.sessions[requestID]
if exists {
return session, nil
}
return nil, ErrSessionNotFound
}
func (s *sessionManager) UnregisterSession(requestID RequestID) {
s.mutex.Lock()
defer s.mutex.Unlock()
// Get the session and make sure to close it if it isn't already closed
session, exists := s.sessions[requestID]
if exists {
// We ignore any errors when attempting to close the session
_ = session.Close()
}
delete(s.sessions, requestID)
}

74
quic/v3/manager_test.go Normal file
View File

@ -0,0 +1,74 @@
package v3_test
import (
"errors"
"net/netip"
"strings"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/ingress"
v3 "github.com/cloudflare/cloudflared/quic/v3"
)
func TestRegisterSession(t *testing.T) {
log := zerolog.Nop()
manager := v3.NewSessionManager(&log, ingress.DialUDPAddrPort)
request := v3.UDPSessionRegistrationDatagram{
RequestID: testRequestID,
Dest: netip.MustParseAddrPort("127.0.0.1:5000"),
Traced: false,
IdleDurationHint: 5 * time.Second,
Payload: nil,
}
session, err := manager.RegisterSession(&request, &noopEyeball{})
if err != nil {
t.Fatalf("register session should've succeeded: %v", err)
}
if request.RequestID != session.ID() {
t.Fatalf("session id doesn't match: %v != %v", request.RequestID, session.ID())
}
// We shouldn't be able to register another session with the same request id
_, err = manager.RegisterSession(&request, &noopEyeball{})
if !errors.Is(err, v3.ErrSessionBoundToOtherConn) {
t.Fatalf("session should not be able to be registered again: %v", err)
}
// Get session
sessionGet, err := manager.GetSession(request.RequestID)
if err != nil {
t.Fatalf("get session failed: %v", err)
}
if session.ID() != sessionGet.ID() {
t.Fatalf("session's do not match: %v != %v", session.ID(), sessionGet.ID())
}
// Remove the session
manager.UnregisterSession(request.RequestID)
// Get session should fail
_, err = manager.GetSession(request.RequestID)
if !errors.Is(err, v3.ErrSessionNotFound) {
t.Fatalf("get session failed: %v", err)
}
// Closing the original session should return that the socket is already closed (by the session unregistration)
err = session.Close()
if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
t.Fatalf("session should've closed without issue: %v", err)
}
}
func TestGetSession_Empty(t *testing.T) {
log := zerolog.Nop()
manager := v3.NewSessionManager(&log, ingress.DialUDPAddrPort)
_, err := manager.GetSession(testRequestID)
if !errors.Is(err, v3.ErrSessionNotFound) {
t.Fatalf("get session find no session: %v", err)
}
}

8
quic/v3/muxer.go Normal file
View File

@ -0,0 +1,8 @@
package v3
// DatagramWriter provides the Muxer interface to create proper Datagrams when sending over a connection.
type DatagramWriter interface {
SendUDPSessionDatagram(datagram []byte) error
SendUDPSessionResponse(id RequestID, resp SessionRegistrationResp) error
//SendICMPPacket(packet packet.IP) error
}

50
quic/v3/muxer_test.go Normal file
View File

@ -0,0 +1,50 @@
package v3_test
import v3 "github.com/cloudflare/cloudflared/quic/v3"
type noopEyeball struct{}
func (noopEyeball) SendUDPSessionDatagram(datagram []byte) error {
return nil
}
func (noopEyeball) SendUDPSessionResponse(id v3.RequestID, resp v3.SessionRegistrationResp) error {
return nil
}
type mockEyeball struct {
// datagram sent via SendUDPSessionDatagram
recvData chan []byte
// responses sent via SendUDPSessionResponse
recvResp chan struct {
id v3.RequestID
resp v3.SessionRegistrationResp
}
}
func newMockEyeball() mockEyeball {
return mockEyeball{
recvData: make(chan []byte, 1),
recvResp: make(chan struct {
id v3.RequestID
resp v3.SessionRegistrationResp
}, 1),
}
}
func (m *mockEyeball) SendUDPSessionDatagram(datagram []byte) error {
b := make([]byte, len(datagram))
copy(b, datagram)
m.recvData <- b
return nil
}
func (m *mockEyeball) SendUDPSessionResponse(id v3.RequestID, resp v3.SessionRegistrationResp) error {
m.recvResp <- struct {
id v3.RequestID
resp v3.SessionRegistrationResp
}{
id, resp,
}
return nil
}

View File

@ -3,6 +3,7 @@ package v3
import ( import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
) )
const ( const (
@ -37,6 +38,10 @@ func RequestIDFromSlice(data []byte) (RequestID, error) {
}, nil }, nil
} }
func (id RequestID) String() string {
return fmt.Sprintf("%016x%016x", id.hi, id.lo)
}
// Compare returns an integer comparing two IPs. // Compare returns an integer comparing two IPs.
// The result will be 0 if id == id2, -1 if id < id2, and +1 if id > id2. // The result will be 0 if id == id2, -1 if id < id2, and +1 if id > id2.
// The definition of "less than" is the same as the [RequestID.Less] method. // The definition of "less than" is the same as the [RequestID.Less] method.

192
quic/v3/session.go Normal file
View File

@ -0,0 +1,192 @@
package v3
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/rs/zerolog"
)
const (
// A default is provided in the case that the client does not provide a close idle timeout.
defaultCloseIdleAfter = 210 * time.Second
// The maximum payload from the origin that we will be able to read. However, even though we will
// read 1500 bytes from the origin, we limit the amount of bytes to be proxied to less than
// this value (maxDatagramPayloadLen).
maxOriginUDPPacketSize = 1500
)
// SessionCloseErr indicates that the session's Close method was called.
var SessionCloseErr error = errors.New("session was closed")
// SessionIdleErr is returned when the session was closed because there was no communication
// in either direction over the session for the timeout period.
type SessionIdleErr struct {
timeout time.Duration
}
func (e SessionIdleErr) Error() string {
return fmt.Sprintf("session idle for %v", e.timeout)
}
func (e SessionIdleErr) Is(target error) bool {
_, ok := target.(SessionIdleErr)
return ok
}
func newSessionIdleErr(timeout time.Duration) error {
return SessionIdleErr{timeout}
}
type Session interface {
io.WriteCloser
ID() RequestID
// Serve starts the event loop for processing UDP packets
Serve(ctx context.Context) error
}
type session struct {
id RequestID
closeAfterIdle time.Duration
origin io.ReadWriteCloser
eyeball DatagramWriter
// activeAtChan is used to communicate the last read/write time
activeAtChan chan time.Time
closeChan chan error
log *zerolog.Logger
}
func NewSession(id RequestID, closeAfterIdle time.Duration, origin io.ReadWriteCloser, eyeball DatagramWriter, log *zerolog.Logger) Session {
return &session{
id: id,
closeAfterIdle: closeAfterIdle,
origin: origin,
eyeball: eyeball,
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
// drop instead of blocking because last active time only needs to be an approximation
activeAtChan: make(chan time.Time, 1),
closeChan: make(chan error, 1),
log: log,
}
}
func (s *session) ID() RequestID {
return s.id
}
func (s *session) Serve(ctx context.Context) error {
go func() {
// QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975
// This makes it safe to share readBuffer between iterations
readBuffer := [maxOriginUDPPacketSize + datagramPayloadHeaderLen]byte{}
// To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with
// the required datagram header information. We can reuse this buffer for this session since the header is the
// same for the each read.
MarshalPayloadHeaderTo(s.id, readBuffer[:datagramPayloadHeaderLen])
for {
// Read from the origin UDP socket
n, err := s.origin.Read(readBuffer[datagramPayloadHeaderLen:])
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
s.log.Debug().Msg("Session (origin) connection closed")
}
if err != nil {
s.closeChan <- err
return
}
if n < 0 {
s.log.Warn().Int("packetSize", n).Msg("Session (origin) packet read was negative and was dropped")
continue
}
if n > maxDatagramPayloadLen {
s.log.Error().Int("packetSize", n).Msg("Session (origin) packet read was too large and was dropped")
continue
}
// Sending a packet to the session does block on the [quic.Connection], however, this is okay because it
// will cause back-pressure to the kernel buffer if the writes are not fast enough to the edge.
err = s.eyeball.SendUDPSessionDatagram(readBuffer[:datagramPayloadHeaderLen+n])
if err != nil {
s.closeChan <- err
return
}
// Mark the session as active since we proxied a valid packet from the origin.
s.markActive()
}
}()
return s.waitForCloseCondition(ctx, s.closeAfterIdle)
}
func (s *session) Write(payload []byte) (n int, err error) {
n, err = s.origin.Write(payload)
if err != nil {
s.log.Err(err).Msg("Failed to write payload to session (remote)")
return n, err
}
// Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer
if n < len(payload) {
s.log.Err(io.ErrShortWrite).Msg("Failed to write the full payload to session (remote)")
return n, io.ErrShortWrite
}
// Mark the session as active since we proxied a packet to the origin.
s.markActive()
return n, err
}
// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there
// are many concurrent read/write. It is fine to lose some precision
func (s *session) markActive() {
select {
case s.activeAtChan <- time.Now():
default:
}
}
func (s *session) Close() error {
// Make sure that we only close the origin connection once
return sync.OnceValue(func() error {
// We don't want to block on sending to the close channel if it is already full
select {
case s.closeChan <- SessionCloseErr:
default:
}
return s.origin.Close()
})()
}
func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error {
// Closing the session at the end cancels read so Serve() can return
defer s.Close()
if closeAfterIdle == 0 {
// provide deafult is caller doesn't specify one
closeAfterIdle = defaultCloseIdleAfter
}
checkIdleTimer := time.NewTimer(closeAfterIdle)
defer checkIdleTimer.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case reason := <-s.closeChan:
return reason
case <-checkIdleTimer.C:
// The check idle timer will only return after an idle period since the last active
// operation (read or write).
return newSessionIdleErr(closeAfterIdle)
case <-s.activeAtChan:
// The session is still active, we want to reset the timer. First we have to stop the timer, drain the
// current value and then reset. It's okay if we lose some time on this operation as we don't need to
// close an idle session directly on-time.
if !checkIdleTimer.Stop() {
<-checkIdleTimer.C
}
checkIdleTimer.Reset(closeAfterIdle)
}
}
}

View File

@ -0,0 +1,23 @@
package v3_test
import (
"testing"
)
// FuzzSessionWrite verifies that we don't run into any panics when writing variable sized payloads to the origin.
func FuzzSessionWrite(f *testing.F) {
f.Fuzz(func(t *testing.T, b []byte) {
testSessionWrite(t, b)
})
}
// FuzzSessionServe verifies that we don't run into any panics when reading variable sized payloads from the origin.
func FuzzSessionServe(f *testing.F) {
f.Fuzz(func(t *testing.T, b []byte) {
// The origin transport read is bound to 1280 bytes
if len(b) > 1280 {
b = b[:1280]
}
testSessionServe_Origin(t, b)
})
}

283
quic/v3/session_test.go Normal file
View File

@ -0,0 +1,283 @@
package v3_test
import (
"context"
"errors"
"net"
"slices"
"sync/atomic"
"testing"
"time"
"github.com/rs/zerolog"
v3 "github.com/cloudflare/cloudflared/quic/v3"
)
var expectedContextCanceled = errors.New("expected context canceled")
func TestSessionNew(t *testing.T) {
log := zerolog.Nop()
session := v3.NewSession(testRequestID, 5*time.Second, nil, &noopEyeball{}, &log)
if testRequestID != session.ID() {
t.Fatalf("session id doesn't match: %s != %s", testRequestID, session.ID())
}
}
func testSessionWrite(t *testing.T, payload []byte) {
log := zerolog.Nop()
origin := newTestOrigin(makePayload(1280))
session := v3.NewSession(testRequestID, 5*time.Second, &origin, &noopEyeball{}, &log)
n, err := session.Write(payload)
if err != nil {
t.Fatal(err)
}
if n != len(payload) {
t.Fatal("unable to write the whole payload")
}
if !slices.Equal(payload, origin.write[:len(payload)]) {
t.Fatal("payload provided from origin and read value are not the same")
}
}
func TestSessionWrite_Max(t *testing.T) {
payload := makePayload(1280)
testSessionWrite(t, payload)
}
func TestSessionWrite_Min(t *testing.T) {
payload := makePayload(0)
testSessionWrite(t, payload)
}
func TestSessionServe_OriginMax(t *testing.T) {
payload := makePayload(1280)
testSessionServe_Origin(t, payload)
}
func TestSessionServe_OriginMin(t *testing.T) {
payload := makePayload(0)
testSessionServe_Origin(t, payload)
}
func testSessionServe_Origin(t *testing.T, payload []byte) {
log := zerolog.Nop()
eyeball := newMockEyeball()
origin := newTestOrigin(payload)
session := v3.NewSession(testRequestID, 3*time.Second, &origin, &eyeball, &log)
defer session.Close()
ctx, cancel := context.WithCancelCause(context.Background())
defer cancel(context.Canceled)
done := make(chan error)
go func() {
done <- session.Serve(ctx)
}()
select {
case data := <-eyeball.recvData:
// check received data matches provided from origin
expectedData := makePayload(1500)
v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
copy(expectedData[17:], payload)
if !slices.Equal(expectedData[:17+len(payload)], data) {
t.Fatal("expected datagram did not equal expected")
}
cancel(expectedContextCanceled)
case err := <-ctx.Done():
// we expect the payload to return before the context to cancel on the session
t.Fatal(err)
}
err := <-done
if !errors.Is(err, context.Canceled) {
t.Fatal(err)
}
if !errors.Is(context.Cause(ctx), expectedContextCanceled) {
t.Fatal(err)
}
}
func TestSessionServe_OriginTooLarge(t *testing.T) {
log := zerolog.Nop()
eyeball := newMockEyeball()
payload := makePayload(1281)
origin := newTestOrigin(payload)
session := v3.NewSession(testRequestID, 2*time.Second, &origin, &eyeball, &log)
defer session.Close()
done := make(chan error)
go func() {
done <- session.Serve(context.Background())
}()
select {
case data := <-eyeball.recvData:
// we never expect a read to make it here because the origin provided a payload that is too large
// for cloudflared to proxy and it will drop it.
t.Fatalf("we should never proxy a payload of this size: %d", len(data))
case err := <-done:
if !errors.Is(err, v3.SessionIdleErr{}) {
t.Error(err)
}
}
}
func TestSessionClose_Multiple(t *testing.T) {
log := zerolog.Nop()
origin := newTestOrigin(makePayload(128))
session := v3.NewSession(testRequestID, 5*time.Second, &origin, &noopEyeball{}, &log)
err := session.Close()
if err != nil {
t.Fatal(err)
}
if !origin.closed.Load() {
t.Fatal("origin wasn't closed")
}
// subsequent closes shouldn't call close again or cause any errors
err = session.Close()
if err != nil {
t.Fatal(err)
}
}
func TestSessionServe_IdleTimeout(t *testing.T) {
log := zerolog.Nop()
origin := newTestIdleOrigin(10 * time.Second) // Make idle time longer than closeAfterIdle
closeAfterIdle := 2 * time.Second
session := v3.NewSession(testRequestID, closeAfterIdle, &origin, &noopEyeball{}, &log)
err := session.Serve(context.Background())
if !errors.Is(err, v3.SessionIdleErr{}) {
t.Fatal(err)
}
// session should be closed
if !origin.closed {
t.Fatalf("session should be closed after Serve returns")
}
// closing a session again should not return an error
err = session.Close()
if err != nil {
t.Fatal(err)
}
}
func TestSessionServe_ParentContextCanceled(t *testing.T) {
log := zerolog.Nop()
// Make idle time and idle timeout longer than closeAfterIdle
origin := newTestIdleOrigin(10 * time.Second)
closeAfterIdle := 10 * time.Second
session := v3.NewSession(testRequestID, closeAfterIdle, &origin, &noopEyeball{}, &log)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err := session.Serve(ctx)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal(err)
}
// session should be closed
if !origin.closed {
t.Fatalf("session should be closed after Serve returns")
}
// closing a session again should not return an error
err = session.Close()
if err != nil {
t.Fatal(err)
}
}
func TestSessionServe_ReadErrors(t *testing.T) {
log := zerolog.Nop()
origin := newTestErrOrigin(net.ErrClosed, nil)
session := v3.NewSession(testRequestID, 30*time.Second, &origin, &noopEyeball{}, &log)
err := session.Serve(context.Background())
if !errors.Is(err, net.ErrClosed) {
t.Fatal(err)
}
}
type testOrigin struct {
// bytes from Write
write []byte
// bytes provided to Read
read []byte
readOnce atomic.Bool
closed atomic.Bool
}
func newTestOrigin(payload []byte) testOrigin {
return testOrigin{
read: payload,
}
}
func (o *testOrigin) Read(p []byte) (n int, err error) {
if o.closed.Load() {
return -1, net.ErrClosed
}
if o.readOnce.Load() {
// We only want to provide one read so all other reads will be blocked
time.Sleep(10 * time.Second)
}
o.readOnce.Store(true)
return copy(p, o.read), nil
}
func (o *testOrigin) Write(p []byte) (n int, err error) {
if o.closed.Load() {
return -1, net.ErrClosed
}
o.write = make([]byte, len(p))
copy(o.write, p)
return len(p), nil
}
func (o *testOrigin) Close() error {
o.closed.Store(true)
return nil
}
type testIdleOrigin struct {
duration time.Duration
closed bool
}
func newTestIdleOrigin(d time.Duration) testIdleOrigin {
return testIdleOrigin{
duration: d,
}
}
func (o *testIdleOrigin) Read(p []byte) (n int, err error) {
time.Sleep(o.duration)
return 0, nil
}
func (o *testIdleOrigin) Write(p []byte) (n int, err error) {
return 0, nil
}
func (o *testIdleOrigin) Close() error {
o.closed = true
return nil
}
type testErrOrigin struct {
readErr error
writeErr error
}
func newTestErrOrigin(readErr error, writeErr error) testErrOrigin {
return testErrOrigin{readErr, writeErr}
}
func (o *testErrOrigin) Read(p []byte) (n int, err error) {
return 0, o.readErr
}
func (o *testErrOrigin) Write(p []byte) (n int, err error) {
return len(p), o.writeErr
}
func (o *testErrOrigin) Close() error {
return nil
}