diff --git a/ingress/origin_udp_proxy.go b/ingress/origin_udp_proxy.go index 836489be..f553e30d 100644 --- a/ingress/origin_udp_proxy.go +++ b/ingress/origin_udp_proxy.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net" + "net/netip" ) type UDPProxy interface { @@ -30,3 +31,16 @@ func DialUDP(dstIP net.IP, dstPort uint16) (UDPProxy, error) { return &udpProxy{udpConn}, nil } + +func DialUDPAddrPort(dest netip.AddrPort) (*net.UDPConn, error) { + addr := net.UDPAddrFromAddrPort(dest) + + // We use nil as local addr to force runtime to find the best suitable local address IP given the destination + // address as context. + udpConn, err := net.DialUDP("udp", nil, addr) + if err != nil { + return nil, fmt.Errorf("unable to create UDP proxy to origin (%v:%v): %w", dest.Addr(), dest.Port(), err) + } + + return udpConn, nil +} diff --git a/quic/v3/datagram.go b/quic/v3/datagram.go index d5c2ac1b..a17804c2 100644 --- a/quic/v3/datagram.go +++ b/quic/v3/datagram.go @@ -24,7 +24,7 @@ const ( datagramTypeLen = 1 // 1280 is the default datagram packet length used before MTU discovery: https://github.com/quic-go/quic-go/blob/v0.45.0/internal/protocol/params.go#L12 - maxDatagramLen = 1280 + maxDatagramPayloadLen = 1280 ) func parseDatagramType(data []byte) (DatagramType, error) { @@ -100,10 +100,10 @@ func (s *UDPSessionRegistrationDatagram) MarshalBinary() (data []byte, err error } var maxPayloadLen int if ipv6 { - maxPayloadLen = maxDatagramLen - sessionRegistrationIPv6DatagramHeaderLen + maxPayloadLen = maxDatagramPayloadLen + sessionRegistrationIPv6DatagramHeaderLen flags |= sessionRegistrationFlagsIPMask } else { - maxPayloadLen = maxDatagramLen - sessionRegistrationIPv4DatagramHeaderLen + maxPayloadLen = maxDatagramPayloadLen + sessionRegistrationIPv4DatagramHeaderLen } // Make sure that the payload being bundled can actually fit in the payload destination if len(s.Payload) > maxPayloadLen { @@ -195,7 +195,7 @@ const ( datagramPayloadHeaderLen = datagramTypeLen + datagramRequestIdLen // The maximum size that a proxied UDP payload can be in a [UDPSessionPayloadDatagram] - maxPayloadPlusHeaderLen = maxDatagramLen - datagramPayloadHeaderLen + maxPayloadPlusHeaderLen = maxDatagramPayloadLen + datagramPayloadHeaderLen ) // The datagram structure for UDPSessionPayloadDatagram is: @@ -270,7 +270,7 @@ const ( datagramSessionRegistrationResponseLen = datagramTypeLen + datagramRespTypeLen + datagramRequestIdLen + datagramRespErrMsgLen // The maximum size that an error message can be in a [UDPSessionRegistrationResponseDatagram]. - maxResponseErrorMessageLen = maxDatagramLen - datagramSessionRegistrationResponseLen + maxResponseErrorMessageLen = maxDatagramPayloadLen - datagramSessionRegistrationResponseLen ) // SessionRegistrationResp represents all of the responses that a UDP session registration response diff --git a/quic/v3/datagram_test.go b/quic/v3/datagram_test.go index b2e77f89..ff46ef24 100644 --- a/quic/v3/datagram_test.go +++ b/quic/v3/datagram_test.go @@ -21,7 +21,7 @@ func makePayload(size int) []byte { } func TestSessionRegistration_MarshalUnmarshal(t *testing.T) { - payload := makePayload(1254) + payload := makePayload(1280) tests := []*v3.UDPSessionRegistrationDatagram{ // Default (IPv4) { @@ -236,7 +236,7 @@ func TestSessionPayload(t *testing.T) { }) t.Run("payload size too large", func(t *testing.T) { - datagram := makePayload(17 + 1264) // 1263 is the largest payload size allowed + datagram := makePayload(17 + 1281) // 1280 is the largest payload size allowed err := v3.MarshalPayloadHeaderTo(testRequestID, datagram) if err != nil { t.Error(err) diff --git a/quic/v3/manager.go b/quic/v3/manager.go new file mode 100644 index 00000000..49c0fec1 --- /dev/null +++ b/quic/v3/manager.go @@ -0,0 +1,87 @@ +package v3 + +import ( + "errors" + "net" + "net/netip" + "sync" + + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/ingress" +) + +var ( + ErrSessionNotFound = errors.New("session not found") + ErrSessionBoundToOtherConn = errors.New("session is in use by another connection") +) + +type SessionManager interface { + // RegisterSession will register a new session if it does not already exist for the request ID. + // During new session creation, the session will also bind the UDP socket for the origin. + // If the session exists for a different connection, it will return [ErrSessionBoundToOtherConn]. + RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramWriter) (Session, error) + // GetSession returns an active session if available for the provided connection. + // If the session does not exist, it will return [ErrSessionNotFound]. If the session exists for a different + // connection, it will return [ErrSessionBoundToOtherConn]. + GetSession(requestID RequestID) (Session, error) + // UnregisterSession will remove a session from the current session manager. It will attempt to close the session + // before removal. + UnregisterSession(requestID RequestID) +} + +type DialUDP func(dest netip.AddrPort) (*net.UDPConn, error) + +type sessionManager struct { + sessions map[RequestID]Session + mutex sync.RWMutex + log *zerolog.Logger +} + +func NewSessionManager(log *zerolog.Logger, originDialer DialUDP) SessionManager { + return &sessionManager{ + sessions: make(map[RequestID]Session), + log: log, + } +} + +func (s *sessionManager) RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramWriter) (Session, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + // Check to make sure session doesn't already exist for requestID + _, exists := s.sessions[request.RequestID] + if exists { + return nil, ErrSessionBoundToOtherConn + } + // Attempt to bind the UDP socket for the new session + origin, err := ingress.DialUDPAddrPort(request.Dest) + if err != nil { + return nil, err + } + // Create and insert the new session in the map + session := NewSession(request.RequestID, request.IdleDurationHint, origin, conn, s.log) + s.sessions[request.RequestID] = session + return session, nil +} + +func (s *sessionManager) GetSession(requestID RequestID) (Session, error) { + s.mutex.RLock() + defer s.mutex.RUnlock() + session, exists := s.sessions[requestID] + if exists { + return session, nil + } + return nil, ErrSessionNotFound +} + +func (s *sessionManager) UnregisterSession(requestID RequestID) { + s.mutex.Lock() + defer s.mutex.Unlock() + // Get the session and make sure to close it if it isn't already closed + session, exists := s.sessions[requestID] + if exists { + // We ignore any errors when attempting to close the session + _ = session.Close() + } + delete(s.sessions, requestID) +} diff --git a/quic/v3/manager_test.go b/quic/v3/manager_test.go new file mode 100644 index 00000000..93e959dd --- /dev/null +++ b/quic/v3/manager_test.go @@ -0,0 +1,74 @@ +package v3_test + +import ( + "errors" + "net/netip" + "strings" + "testing" + "time" + + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/ingress" + v3 "github.com/cloudflare/cloudflared/quic/v3" +) + +func TestRegisterSession(t *testing.T) { + log := zerolog.Nop() + manager := v3.NewSessionManager(&log, ingress.DialUDPAddrPort) + + request := v3.UDPSessionRegistrationDatagram{ + RequestID: testRequestID, + Dest: netip.MustParseAddrPort("127.0.0.1:5000"), + Traced: false, + IdleDurationHint: 5 * time.Second, + Payload: nil, + } + session, err := manager.RegisterSession(&request, &noopEyeball{}) + if err != nil { + t.Fatalf("register session should've succeeded: %v", err) + } + if request.RequestID != session.ID() { + t.Fatalf("session id doesn't match: %v != %v", request.RequestID, session.ID()) + } + + // We shouldn't be able to register another session with the same request id + _, err = manager.RegisterSession(&request, &noopEyeball{}) + if !errors.Is(err, v3.ErrSessionBoundToOtherConn) { + t.Fatalf("session should not be able to be registered again: %v", err) + } + + // Get session + sessionGet, err := manager.GetSession(request.RequestID) + if err != nil { + t.Fatalf("get session failed: %v", err) + } + if session.ID() != sessionGet.ID() { + t.Fatalf("session's do not match: %v != %v", session.ID(), sessionGet.ID()) + } + + // Remove the session + manager.UnregisterSession(request.RequestID) + + // Get session should fail + _, err = manager.GetSession(request.RequestID) + if !errors.Is(err, v3.ErrSessionNotFound) { + t.Fatalf("get session failed: %v", err) + } + + // Closing the original session should return that the socket is already closed (by the session unregistration) + err = session.Close() + if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { + t.Fatalf("session should've closed without issue: %v", err) + } +} + +func TestGetSession_Empty(t *testing.T) { + log := zerolog.Nop() + manager := v3.NewSessionManager(&log, ingress.DialUDPAddrPort) + + _, err := manager.GetSession(testRequestID) + if !errors.Is(err, v3.ErrSessionNotFound) { + t.Fatalf("get session find no session: %v", err) + } +} diff --git a/quic/v3/muxer.go b/quic/v3/muxer.go new file mode 100644 index 00000000..fda16bbe --- /dev/null +++ b/quic/v3/muxer.go @@ -0,0 +1,8 @@ +package v3 + +// DatagramWriter provides the Muxer interface to create proper Datagrams when sending over a connection. +type DatagramWriter interface { + SendUDPSessionDatagram(datagram []byte) error + SendUDPSessionResponse(id RequestID, resp SessionRegistrationResp) error + //SendICMPPacket(packet packet.IP) error +} diff --git a/quic/v3/muxer_test.go b/quic/v3/muxer_test.go new file mode 100644 index 00000000..552281a5 --- /dev/null +++ b/quic/v3/muxer_test.go @@ -0,0 +1,50 @@ +package v3_test + +import v3 "github.com/cloudflare/cloudflared/quic/v3" + +type noopEyeball struct{} + +func (noopEyeball) SendUDPSessionDatagram(datagram []byte) error { + return nil +} + +func (noopEyeball) SendUDPSessionResponse(id v3.RequestID, resp v3.SessionRegistrationResp) error { + return nil +} + +type mockEyeball struct { + // datagram sent via SendUDPSessionDatagram + recvData chan []byte + // responses sent via SendUDPSessionResponse + recvResp chan struct { + id v3.RequestID + resp v3.SessionRegistrationResp + } +} + +func newMockEyeball() mockEyeball { + return mockEyeball{ + recvData: make(chan []byte, 1), + recvResp: make(chan struct { + id v3.RequestID + resp v3.SessionRegistrationResp + }, 1), + } +} + +func (m *mockEyeball) SendUDPSessionDatagram(datagram []byte) error { + b := make([]byte, len(datagram)) + copy(b, datagram) + m.recvData <- b + return nil +} + +func (m *mockEyeball) SendUDPSessionResponse(id v3.RequestID, resp v3.SessionRegistrationResp) error { + m.recvResp <- struct { + id v3.RequestID + resp v3.SessionRegistrationResp + }{ + id, resp, + } + return nil +} diff --git a/quic/v3/request.go b/quic/v3/request.go index 29509e83..d939b373 100644 --- a/quic/v3/request.go +++ b/quic/v3/request.go @@ -3,6 +3,7 @@ package v3 import ( "encoding/binary" "errors" + "fmt" ) const ( @@ -37,6 +38,10 @@ func RequestIDFromSlice(data []byte) (RequestID, error) { }, nil } +func (id RequestID) String() string { + return fmt.Sprintf("%016x%016x", id.hi, id.lo) +} + // Compare returns an integer comparing two IPs. // The result will be 0 if id == id2, -1 if id < id2, and +1 if id > id2. // The definition of "less than" is the same as the [RequestID.Less] method. diff --git a/quic/v3/session.go b/quic/v3/session.go new file mode 100644 index 00000000..e05a91c5 --- /dev/null +++ b/quic/v3/session.go @@ -0,0 +1,192 @@ +package v3 + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/rs/zerolog" +) + +const ( + // A default is provided in the case that the client does not provide a close idle timeout. + defaultCloseIdleAfter = 210 * time.Second + + // The maximum payload from the origin that we will be able to read. However, even though we will + // read 1500 bytes from the origin, we limit the amount of bytes to be proxied to less than + // this value (maxDatagramPayloadLen). + maxOriginUDPPacketSize = 1500 +) + +// SessionCloseErr indicates that the session's Close method was called. +var SessionCloseErr error = errors.New("session was closed") + +// SessionIdleErr is returned when the session was closed because there was no communication +// in either direction over the session for the timeout period. +type SessionIdleErr struct { + timeout time.Duration +} + +func (e SessionIdleErr) Error() string { + return fmt.Sprintf("session idle for %v", e.timeout) +} + +func (e SessionIdleErr) Is(target error) bool { + _, ok := target.(SessionIdleErr) + return ok +} + +func newSessionIdleErr(timeout time.Duration) error { + return SessionIdleErr{timeout} +} + +type Session interface { + io.WriteCloser + ID() RequestID + // Serve starts the event loop for processing UDP packets + Serve(ctx context.Context) error +} + +type session struct { + id RequestID + closeAfterIdle time.Duration + origin io.ReadWriteCloser + eyeball DatagramWriter + // activeAtChan is used to communicate the last read/write time + activeAtChan chan time.Time + closeChan chan error + log *zerolog.Logger +} + +func NewSession(id RequestID, closeAfterIdle time.Duration, origin io.ReadWriteCloser, eyeball DatagramWriter, log *zerolog.Logger) Session { + return &session{ + id: id, + closeAfterIdle: closeAfterIdle, + origin: origin, + eyeball: eyeball, + // activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will + // drop instead of blocking because last active time only needs to be an approximation + activeAtChan: make(chan time.Time, 1), + closeChan: make(chan error, 1), + log: log, + } +} + +func (s *session) ID() RequestID { + return s.id +} + +func (s *session) Serve(ctx context.Context) error { + go func() { + // QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975 + // This makes it safe to share readBuffer between iterations + readBuffer := [maxOriginUDPPacketSize + datagramPayloadHeaderLen]byte{} + // To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with + // the required datagram header information. We can reuse this buffer for this session since the header is the + // same for the each read. + MarshalPayloadHeaderTo(s.id, readBuffer[:datagramPayloadHeaderLen]) + for { + // Read from the origin UDP socket + n, err := s.origin.Read(readBuffer[datagramPayloadHeaderLen:]) + if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + s.log.Debug().Msg("Session (origin) connection closed") + } + if err != nil { + s.closeChan <- err + return + } + if n < 0 { + s.log.Warn().Int("packetSize", n).Msg("Session (origin) packet read was negative and was dropped") + continue + } + if n > maxDatagramPayloadLen { + s.log.Error().Int("packetSize", n).Msg("Session (origin) packet read was too large and was dropped") + continue + } + // Sending a packet to the session does block on the [quic.Connection], however, this is okay because it + // will cause back-pressure to the kernel buffer if the writes are not fast enough to the edge. + err = s.eyeball.SendUDPSessionDatagram(readBuffer[:datagramPayloadHeaderLen+n]) + if err != nil { + s.closeChan <- err + return + } + // Mark the session as active since we proxied a valid packet from the origin. + s.markActive() + } + }() + return s.waitForCloseCondition(ctx, s.closeAfterIdle) +} + +func (s *session) Write(payload []byte) (n int, err error) { + n, err = s.origin.Write(payload) + if err != nil { + s.log.Err(err).Msg("Failed to write payload to session (remote)") + return n, err + } + // Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer + if n < len(payload) { + s.log.Err(io.ErrShortWrite).Msg("Failed to write the full payload to session (remote)") + return n, io.ErrShortWrite + } + // Mark the session as active since we proxied a packet to the origin. + s.markActive() + return n, err +} + +// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there +// are many concurrent read/write. It is fine to lose some precision +func (s *session) markActive() { + select { + case s.activeAtChan <- time.Now(): + default: + } +} + +func (s *session) Close() error { + // Make sure that we only close the origin connection once + return sync.OnceValue(func() error { + // We don't want to block on sending to the close channel if it is already full + select { + case s.closeChan <- SessionCloseErr: + default: + } + return s.origin.Close() + })() +} + +func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error { + // Closing the session at the end cancels read so Serve() can return + defer s.Close() + if closeAfterIdle == 0 { + // provide deafult is caller doesn't specify one + closeAfterIdle = defaultCloseIdleAfter + } + + checkIdleTimer := time.NewTimer(closeAfterIdle) + defer checkIdleTimer.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case reason := <-s.closeChan: + return reason + case <-checkIdleTimer.C: + // The check idle timer will only return after an idle period since the last active + // operation (read or write). + return newSessionIdleErr(closeAfterIdle) + case <-s.activeAtChan: + // The session is still active, we want to reset the timer. First we have to stop the timer, drain the + // current value and then reset. It's okay if we lose some time on this operation as we don't need to + // close an idle session directly on-time. + if !checkIdleTimer.Stop() { + <-checkIdleTimer.C + } + checkIdleTimer.Reset(closeAfterIdle) + } + } +} diff --git a/quic/v3/session_fuzz_test.go b/quic/v3/session_fuzz_test.go new file mode 100644 index 00000000..0e4952c0 --- /dev/null +++ b/quic/v3/session_fuzz_test.go @@ -0,0 +1,23 @@ +package v3_test + +import ( + "testing" +) + +// FuzzSessionWrite verifies that we don't run into any panics when writing variable sized payloads to the origin. +func FuzzSessionWrite(f *testing.F) { + f.Fuzz(func(t *testing.T, b []byte) { + testSessionWrite(t, b) + }) +} + +// FuzzSessionServe verifies that we don't run into any panics when reading variable sized payloads from the origin. +func FuzzSessionServe(f *testing.F) { + f.Fuzz(func(t *testing.T, b []byte) { + // The origin transport read is bound to 1280 bytes + if len(b) > 1280 { + b = b[:1280] + } + testSessionServe_Origin(t, b) + }) +} diff --git a/quic/v3/session_test.go b/quic/v3/session_test.go new file mode 100644 index 00000000..8c25878d --- /dev/null +++ b/quic/v3/session_test.go @@ -0,0 +1,283 @@ +package v3_test + +import ( + "context" + "errors" + "net" + "slices" + "sync/atomic" + "testing" + "time" + + "github.com/rs/zerolog" + + v3 "github.com/cloudflare/cloudflared/quic/v3" +) + +var expectedContextCanceled = errors.New("expected context canceled") + +func TestSessionNew(t *testing.T) { + log := zerolog.Nop() + session := v3.NewSession(testRequestID, 5*time.Second, nil, &noopEyeball{}, &log) + if testRequestID != session.ID() { + t.Fatalf("session id doesn't match: %s != %s", testRequestID, session.ID()) + } +} + +func testSessionWrite(t *testing.T, payload []byte) { + log := zerolog.Nop() + origin := newTestOrigin(makePayload(1280)) + session := v3.NewSession(testRequestID, 5*time.Second, &origin, &noopEyeball{}, &log) + n, err := session.Write(payload) + if err != nil { + t.Fatal(err) + } + if n != len(payload) { + t.Fatal("unable to write the whole payload") + } + if !slices.Equal(payload, origin.write[:len(payload)]) { + t.Fatal("payload provided from origin and read value are not the same") + } +} + +func TestSessionWrite_Max(t *testing.T) { + payload := makePayload(1280) + testSessionWrite(t, payload) +} + +func TestSessionWrite_Min(t *testing.T) { + payload := makePayload(0) + testSessionWrite(t, payload) +} + +func TestSessionServe_OriginMax(t *testing.T) { + payload := makePayload(1280) + testSessionServe_Origin(t, payload) +} + +func TestSessionServe_OriginMin(t *testing.T) { + payload := makePayload(0) + testSessionServe_Origin(t, payload) +} + +func testSessionServe_Origin(t *testing.T, payload []byte) { + log := zerolog.Nop() + eyeball := newMockEyeball() + origin := newTestOrigin(payload) + session := v3.NewSession(testRequestID, 3*time.Second, &origin, &eyeball, &log) + defer session.Close() + + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(context.Canceled) + done := make(chan error) + go func() { + done <- session.Serve(ctx) + }() + + select { + case data := <-eyeball.recvData: + // check received data matches provided from origin + expectedData := makePayload(1500) + v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:]) + copy(expectedData[17:], payload) + if !slices.Equal(expectedData[:17+len(payload)], data) { + t.Fatal("expected datagram did not equal expected") + } + cancel(expectedContextCanceled) + case err := <-ctx.Done(): + // we expect the payload to return before the context to cancel on the session + t.Fatal(err) + } + + err := <-done + if !errors.Is(err, context.Canceled) { + t.Fatal(err) + } + if !errors.Is(context.Cause(ctx), expectedContextCanceled) { + t.Fatal(err) + } +} + +func TestSessionServe_OriginTooLarge(t *testing.T) { + log := zerolog.Nop() + eyeball := newMockEyeball() + payload := makePayload(1281) + origin := newTestOrigin(payload) + session := v3.NewSession(testRequestID, 2*time.Second, &origin, &eyeball, &log) + defer session.Close() + + done := make(chan error) + go func() { + done <- session.Serve(context.Background()) + }() + + select { + case data := <-eyeball.recvData: + // we never expect a read to make it here because the origin provided a payload that is too large + // for cloudflared to proxy and it will drop it. + t.Fatalf("we should never proxy a payload of this size: %d", len(data)) + case err := <-done: + if !errors.Is(err, v3.SessionIdleErr{}) { + t.Error(err) + } + } +} + +func TestSessionClose_Multiple(t *testing.T) { + log := zerolog.Nop() + origin := newTestOrigin(makePayload(128)) + session := v3.NewSession(testRequestID, 5*time.Second, &origin, &noopEyeball{}, &log) + err := session.Close() + if err != nil { + t.Fatal(err) + } + if !origin.closed.Load() { + t.Fatal("origin wasn't closed") + } + // subsequent closes shouldn't call close again or cause any errors + err = session.Close() + if err != nil { + t.Fatal(err) + } +} + +func TestSessionServe_IdleTimeout(t *testing.T) { + log := zerolog.Nop() + origin := newTestIdleOrigin(10 * time.Second) // Make idle time longer than closeAfterIdle + closeAfterIdle := 2 * time.Second + session := v3.NewSession(testRequestID, closeAfterIdle, &origin, &noopEyeball{}, &log) + err := session.Serve(context.Background()) + if !errors.Is(err, v3.SessionIdleErr{}) { + t.Fatal(err) + } + // session should be closed + if !origin.closed { + t.Fatalf("session should be closed after Serve returns") + } + // closing a session again should not return an error + err = session.Close() + if err != nil { + t.Fatal(err) + } +} + +func TestSessionServe_ParentContextCanceled(t *testing.T) { + log := zerolog.Nop() + // Make idle time and idle timeout longer than closeAfterIdle + origin := newTestIdleOrigin(10 * time.Second) + closeAfterIdle := 10 * time.Second + + session := v3.NewSession(testRequestID, closeAfterIdle, &origin, &noopEyeball{}, &log) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + err := session.Serve(ctx) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal(err) + } + // session should be closed + if !origin.closed { + t.Fatalf("session should be closed after Serve returns") + } + // closing a session again should not return an error + err = session.Close() + if err != nil { + t.Fatal(err) + } +} + +func TestSessionServe_ReadErrors(t *testing.T) { + log := zerolog.Nop() + origin := newTestErrOrigin(net.ErrClosed, nil) + session := v3.NewSession(testRequestID, 30*time.Second, &origin, &noopEyeball{}, &log) + err := session.Serve(context.Background()) + if !errors.Is(err, net.ErrClosed) { + t.Fatal(err) + } +} + +type testOrigin struct { + // bytes from Write + write []byte + // bytes provided to Read + read []byte + readOnce atomic.Bool + closed atomic.Bool +} + +func newTestOrigin(payload []byte) testOrigin { + return testOrigin{ + read: payload, + } +} + +func (o *testOrigin) Read(p []byte) (n int, err error) { + if o.closed.Load() { + return -1, net.ErrClosed + } + if o.readOnce.Load() { + // We only want to provide one read so all other reads will be blocked + time.Sleep(10 * time.Second) + } + o.readOnce.Store(true) + return copy(p, o.read), nil +} + +func (o *testOrigin) Write(p []byte) (n int, err error) { + if o.closed.Load() { + return -1, net.ErrClosed + } + o.write = make([]byte, len(p)) + copy(o.write, p) + return len(p), nil +} + +func (o *testOrigin) Close() error { + o.closed.Store(true) + return nil +} + +type testIdleOrigin struct { + duration time.Duration + closed bool +} + +func newTestIdleOrigin(d time.Duration) testIdleOrigin { + return testIdleOrigin{ + duration: d, + } +} + +func (o *testIdleOrigin) Read(p []byte) (n int, err error) { + time.Sleep(o.duration) + return 0, nil +} + +func (o *testIdleOrigin) Write(p []byte) (n int, err error) { + return 0, nil +} + +func (o *testIdleOrigin) Close() error { + o.closed = true + return nil +} + +type testErrOrigin struct { + readErr error + writeErr error +} + +func newTestErrOrigin(readErr error, writeErr error) testErrOrigin { + return testErrOrigin{readErr, writeErr} +} + +func (o *testErrOrigin) Read(p []byte) (n int, err error) { + return 0, o.readErr +} + +func (o *testErrOrigin) Write(p []byte) (n int, err error) { + return len(p), o.writeErr +} + +func (o *testErrOrigin) Close() error { + return nil +}