diff --git a/connection/quic_datagram_v3.go b/connection/quic_datagram_v3.go index 3921d66c..2aab8966 100644 --- a/connection/quic_datagram_v3.go +++ b/connection/quic_datagram_v3.go @@ -24,9 +24,10 @@ type datagramV3Connection struct { func NewDatagramV3Connection(ctx context.Context, conn quic.Connection, sessionManager cfdquic.SessionManager, + index uint8, logger *zerolog.Logger, ) DatagramSessionHandler { - datagramMuxer := cfdquic.NewDatagramConn(conn, sessionManager, logger) + datagramMuxer := cfdquic.NewDatagramConn(conn, sessionManager, index, logger) return &datagramV3Connection{ conn, diff --git a/quic/v3/datagram.go b/quic/v3/datagram.go index bafeb15e..3c45e6b2 100644 --- a/quic/v3/datagram.go +++ b/quic/v3/datagram.go @@ -284,8 +284,6 @@ const ( ResponseDestinationUnreachable SessionRegistrationResp = 0x01 // Session registration was unable to bind to a local UDP socket. ResponseUnableToBindSocket SessionRegistrationResp = 0x02 - // Session registration is already bound to another connection. - ResponseSessionAlreadyConnected SessionRegistrationResp = 0x03 // Session registration failed with an unexpected error but provided a message. ResponseErrorWithMsg SessionRegistrationResp = 0xff ) diff --git a/quic/v3/manager.go b/quic/v3/manager.go index 49c0fec1..57314728 100644 --- a/quic/v3/manager.go +++ b/quic/v3/manager.go @@ -12,15 +12,19 @@ import ( ) var ( - ErrSessionNotFound = errors.New("session not found") + // ErrSessionNotFound indicates that a session has not been registered yet for the request id. + ErrSessionNotFound = errors.New("session not found") + // ErrSessionBoundToOtherConn is returned when a registration already exists for a different connection. ErrSessionBoundToOtherConn = errors.New("session is in use by another connection") + // ErrSessionAlreadyRegistered is returned when a registration already exists for this connection. + ErrSessionAlreadyRegistered = errors.New("session is already registered for this connection") ) type SessionManager interface { // RegisterSession will register a new session if it does not already exist for the request ID. // During new session creation, the session will also bind the UDP socket for the origin. // If the session exists for a different connection, it will return [ErrSessionBoundToOtherConn]. - RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramWriter) (Session, error) + RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramConn) (Session, error) // GetSession returns an active session if available for the provided connection. // If the session does not exist, it will return [ErrSessionNotFound]. If the session exists for a different // connection, it will return [ErrSessionBoundToOtherConn]. @@ -45,12 +49,14 @@ func NewSessionManager(log *zerolog.Logger, originDialer DialUDP) SessionManager } } -func (s *sessionManager) RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramWriter) (Session, error) { +func (s *sessionManager) RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramConn) (Session, error) { s.mutex.Lock() defer s.mutex.Unlock() // Check to make sure session doesn't already exist for requestID - _, exists := s.sessions[request.RequestID] - if exists { + if session, exists := s.sessions[request.RequestID]; exists { + if conn.ID() == session.ConnectionID() { + return nil, ErrSessionAlreadyRegistered + } return nil, ErrSessionBoundToOtherConn } // Attempt to bind the UDP socket for the new session diff --git a/quic/v3/manager_test.go b/quic/v3/manager_test.go index 93e959dd..0d93ac2f 100644 --- a/quic/v3/manager_test.go +++ b/quic/v3/manager_test.go @@ -34,8 +34,14 @@ func TestRegisterSession(t *testing.T) { // We shouldn't be able to register another session with the same request id _, err = manager.RegisterSession(&request, &noopEyeball{}) + if !errors.Is(err, v3.ErrSessionAlreadyRegistered) { + t.Fatalf("session is already registered for this connection: %v", err) + } + + // We shouldn't be able to register another session with the same request id for a different connection + _, err = manager.RegisterSession(&request, &noopEyeball{connID: 1}) if !errors.Is(err, v3.ErrSessionBoundToOtherConn) { - t.Fatalf("session should not be able to be registered again: %v", err) + t.Fatalf("session is already registered for a separate connection: %v", err) } // Get session diff --git a/quic/v3/muxer.go b/quic/v3/muxer.go index 7fd0c151..e34dd27b 100644 --- a/quic/v3/muxer.go +++ b/quic/v3/muxer.go @@ -19,6 +19,8 @@ type DatagramConn interface { DatagramWriter // Serve provides a server interface to process and handle incoming QUIC datagrams and demux their datagram v3 payloads. Serve(context.Context) error + // ID indicates connection index identifier + ID() uint8 } // DatagramWriter provides the Muxer interface to create proper Datagrams when sending over a connection. @@ -41,6 +43,7 @@ type QuicConnection interface { type datagramConn struct { conn QuicConnection + index uint8 sessionManager SessionManager logger *zerolog.Logger @@ -48,10 +51,11 @@ type datagramConn struct { readErrors chan error } -func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, logger *zerolog.Logger) DatagramConn { +func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, index uint8, logger *zerolog.Logger) DatagramConn { log := logger.With().Uint8("datagramVersion", 3).Logger() return &datagramConn{ conn: conn, + index: index, sessionManager: sessionManager, logger: &log, datagrams: make(chan []byte, demuxChanCapacity), @@ -59,6 +63,10 @@ func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, logger } } +func (c datagramConn) ID() uint8 { + return c.index +} + func (c *datagramConn) SendUDPSessionDatagram(datagram []byte) error { return c.conn.SendDatagram(datagram) } @@ -163,9 +171,20 @@ func (c *datagramConn) Serve(ctx context.Context) error { // This method handles new registrations of a session and the serve loop for the session. func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, datagram *UDPSessionRegistrationDatagram) { session, err := c.sessionManager.RegisterSession(datagram, c) - if err != nil { + switch err { + case nil: + // Continue as normal + case ErrSessionAlreadyRegistered: + // Session is already registered and likely the response got lost + c.handleSessionAlreadyRegistered(datagram.RequestID) + return + case ErrSessionBoundToOtherConn: + // Session is already registered but to a different connection + c.handleSessionMigration(datagram.RequestID) + return + default: c.logger.Err(err).Msgf("session registration failure") - c.handleSessionRegistrationFailure(datagram.RequestID, err) + c.handleSessionRegistrationFailure(datagram.RequestID) return } // Make sure to eventually remove the session from the session manager when the session is closed @@ -197,17 +216,49 @@ func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, da c.logger.Err(err).Msgf("session was closed with an error") } -func (c *datagramConn) handleSessionRegistrationFailure(requestID RequestID, regErr error) { - var errResp SessionRegistrationResp - switch regErr { - case ErrSessionBoundToOtherConn: - errResp = ResponseSessionAlreadyConnected - default: - errResp = ResponseUnableToBindSocket - } - err := c.SendUDPSessionResponse(requestID, errResp) +func (c *datagramConn) handleSessionAlreadyRegistered(requestID RequestID) { + // Send another registration response since the session is already active + err := c.SendUDPSessionResponse(requestID, ResponseOk) if err != nil { - c.logger.Err(err).Msgf("unable to send session registration error response (%d)", errResp) + c.logger.Err(err).Msgf("session registration failure: unable to send an additional session registration response") + return + } + + session, err := c.sessionManager.GetSession(requestID) + if err != nil { + // If for some reason we can not find the session after attempting to register it, we can just return + // instead of trying to reset the idle timer for it. + return + } + // The session is already running in another routine so we want to restart the idle timeout since no proxied + // packets have come down yet. + session.ResetIdleTimer() +} + +func (c *datagramConn) handleSessionMigration(requestID RequestID) { + // We need to migrate the currently running session to this edge connection. + session, err := c.sessionManager.GetSession(requestID) + if err != nil { + // If for some reason we can not find the session after attempting to register it, we can just return + // instead of trying to reset the idle timer for it. + return + } + + // Migrate the session to use this edge connection instead of the currently running one. + session.Migrate(c) + + // Send another registration response since the session is already active + err = c.SendUDPSessionResponse(requestID, ResponseOk) + if err != nil { + c.logger.Err(err).Msgf("session registration failure: unable to send an additional session registration response") + return + } +} + +func (c *datagramConn) handleSessionRegistrationFailure(requestID RequestID) { + err := c.SendUDPSessionResponse(requestID, ResponseUnableToBindSocket) + if err != nil { + c.logger.Err(err).Msgf("unable to send session registration error response (%d)", ResponseUnableToBindSocket) } } diff --git a/quic/v3/muxer_test.go b/quic/v3/muxer_test.go index b2cb7e06..b80ad172 100644 --- a/quic/v3/muxer_test.go +++ b/quic/v3/muxer_test.go @@ -17,17 +17,19 @@ import ( v3 "github.com/cloudflare/cloudflared/quic/v3" ) -type noopEyeball struct{} - -func (noopEyeball) SendUDPSessionDatagram(datagram []byte) error { - return nil +type noopEyeball struct { + connID uint8 } +func (noopEyeball) Serve(ctx context.Context) error { return nil } +func (n noopEyeball) ID() uint8 { return n.connID } +func (noopEyeball) SendUDPSessionDatagram(datagram []byte) error { return nil } func (noopEyeball) SendUDPSessionResponse(id v3.RequestID, resp v3.SessionRegistrationResp) error { return nil } type mockEyeball struct { + connID uint8 // datagram sent via SendUDPSessionDatagram recvData chan []byte // responses sent via SendUDPSessionResponse @@ -39,6 +41,7 @@ type mockEyeball struct { func newMockEyeball() mockEyeball { return mockEyeball{ + connID: 0, recvData: make(chan []byte, 1), recvResp: make(chan struct { id v3.RequestID @@ -47,6 +50,9 @@ func newMockEyeball() mockEyeball { } } +func (mockEyeball) Serve(ctx context.Context) error { return nil } +func (m *mockEyeball) ID() uint8 { return m.connID } + func (m *mockEyeball) SendUDPSessionDatagram(datagram []byte) error { b := make([]byte, len(datagram)) copy(b, datagram) @@ -66,7 +72,7 @@ func (m *mockEyeball) SendUDPSessionResponse(id v3.RequestID, resp v3.SessionReg func TestDatagramConn_New(t *testing.T) { log := zerolog.Nop() - conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&log, ingress.DialUDPAddrPort), &log) + conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&log, ingress.DialUDPAddrPort), 0, &log) if conn == nil { t.Fatal("expected valid connection") } @@ -75,7 +81,7 @@ func TestDatagramConn_New(t *testing.T) { func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) { log := zerolog.Nop() quic := newMockQuicConn() - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&log, ingress.DialUDPAddrPort), &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&log, ingress.DialUDPAddrPort), 0, &log) payload := []byte{0xef, 0xef} conn.SendUDPSessionDatagram(payload) @@ -88,7 +94,7 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) { func TestDatagramConn_SendUDPSessionResponse(t *testing.T) { log := zerolog.Nop() quic := newMockQuicConn() - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&log, ingress.DialUDPAddrPort), &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&log, ingress.DialUDPAddrPort), 0, &log) conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable) resp := <-quic.recv @@ -109,7 +115,7 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) { func TestDatagramConnServe_ApplicationClosed(t *testing.T) { log := zerolog.Nop() quic := newMockQuicConn() - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&log, ingress.DialUDPAddrPort), &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&log, ingress.DialUDPAddrPort), 0, &log) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -125,7 +131,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() quic.ctx = ctx - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&log, ingress.DialUDPAddrPort), &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&log, ingress.DialUDPAddrPort), 0, &log) err := conn.Serve(context.Background()) if !errors.Is(err, context.DeadlineExceeded) { @@ -136,7 +142,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) { func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) { log := zerolog.Nop() quic := &mockQuicConnReadError{err: net.ErrClosed} - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&log, ingress.DialUDPAddrPort), &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&log, ingress.DialUDPAddrPort), 0, &log) err := conn.Serve(context.Background()) if !errors.Is(err, net.ErrClosed) { @@ -171,7 +177,7 @@ func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) { log := zerolog.New(logOutput) quic := newMockQuicConn() quic.send <- test.input - conn := v3.NewDatagramConn(quic, &mockSessionManager{}, &log) + conn := v3.NewDatagramConn(quic, &mockSessionManager{}, 0, &log) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -212,7 +218,7 @@ func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) { quic := newMockQuicConn() expectedErr := errors.New("unable to register session") sessionManager := mockSessionManager{expectedRegErr: expectedErr} - conn := v3.NewDatagramConn(quic, &sessionManager, &log) + conn := v3.NewDatagramConn(quic, &sessionManager, 0, &log) // Setup the muxer ctx, cancel := context.WithCancelCause(context.Background()) @@ -234,19 +240,12 @@ func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) { t.Fatal(err) } - if resp.RequestID != testRequestID && resp.ResponseType != v3.ResponseUnableToBindSocket { + if resp.RequestID != testRequestID || resp.ResponseType != v3.ResponseUnableToBindSocket { t.Fatalf("expected registration response failure") } // Cancel the muxer Serve context and make sure it closes with the expected error - cancel(expectedContextCanceled) - err = <-done - if !errors.Is(err, context.Canceled) { - t.Fatal(err) - } - if !errors.Is(context.Cause(ctx), expectedContextCanceled) { - t.Fatal(err) - } + assertContextClosed(t, ctx, done, cancel) } func TestDatagramConnServe(t *testing.T) { @@ -254,7 +253,7 @@ func TestDatagramConnServe(t *testing.T) { quic := newMockQuicConn() session := newMockSession() sessionManager := mockSessionManager{session: &session} - conn := v3.NewDatagramConn(quic, &sessionManager, &log) + conn := v3.NewDatagramConn(quic, &sessionManager, 0, &log) // Setup the muxer ctx, cancel := context.WithCancelCause(context.Background()) @@ -276,7 +275,7 @@ func TestDatagramConnServe(t *testing.T) { t.Fatal(err) } - if resp.RequestID != testRequestID && resp.ResponseType != v3.ResponseOk { + if resp.RequestID != testRequestID || resp.ResponseType != v3.ResponseOk { t.Fatalf("expected registration response ok") } @@ -291,21 +290,160 @@ func TestDatagramConnServe(t *testing.T) { } // Cancel the muxer Serve context and make sure it closes with the expected error - cancel(expectedContextCanceled) - err = <-done - if !errors.Is(err, context.Canceled) { + assertContextClosed(t, ctx, done, cancel) +} + +func TestDatagramConnServe_RegisterTwice(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + session := newMockSession() + sessionManager := mockSessionManager{session: &session} + conn := v3.NewDatagramConn(quic, &sessionManager, 0, &log) + + // Setup the muxer + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("other error")) + done := make(chan error, 1) + go func() { + done <- conn.Serve(ctx) + }() + + // Send new session registration + datagram := newRegisterSessionDatagram(testRequestID) + quic.send <- datagram + + // Wait for session registration response with success + datagram = <-quic.recv + var resp v3.UDPSessionRegistrationResponseDatagram + err := resp.UnmarshalBinary(datagram) + if err != nil { t.Fatal(err) } - if !errors.Is(context.Cause(ctx), expectedContextCanceled) { + + if resp.RequestID != testRequestID || resp.ResponseType != v3.ResponseOk { + t.Fatalf("expected registration response ok") + } + + // Set the session manager to return already registered + sessionManager.expectedRegErr = v3.ErrSessionAlreadyRegistered + // Send the registration again as if we didn't receive it at the edge + datagram = newRegisterSessionDatagram(testRequestID) + quic.send <- datagram + + // Wait for session registration response with success + datagram = <-quic.recv + err = resp.UnmarshalBinary(datagram) + if err != nil { t.Fatal(err) } + + if resp.RequestID != testRequestID || resp.ResponseType != v3.ResponseOk { + t.Fatalf("expected registration response ok") + } + + // We expect the session to be served + timer := time.NewTimer(15 * time.Second) + defer timer.Stop() + select { + case <-session.served: + break + case <-timer.C: + t.Fatalf("expected session serve to be called") + } + + // Cancel the muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx, done, cancel) +} + +func TestDatagramConnServe_MigrateConnection(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + session := newMockSession() + sessionManager := mockSessionManager{session: &session} + conn := v3.NewDatagramConn(quic, &sessionManager, 0, &log) + quic2 := newMockQuicConn() + conn2 := v3.NewDatagramConn(quic2, &sessionManager, 1, &log) + + // Setup the muxer + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("other error")) + done := make(chan error, 1) + go func() { + done <- conn.Serve(ctx) + }() + + ctx2, cancel2 := context.WithCancelCause(context.Background()) + defer cancel2(errors.New("other error")) + done2 := make(chan error, 1) + go func() { + done2 <- conn2.Serve(ctx2) + }() + + // Send new session registration + datagram := newRegisterSessionDatagram(testRequestID) + quic.send <- datagram + + // Wait for session registration response with success + datagram = <-quic.recv + var resp v3.UDPSessionRegistrationResponseDatagram + err := resp.UnmarshalBinary(datagram) + if err != nil { + t.Fatal(err) + } + + if resp.RequestID != testRequestID || resp.ResponseType != v3.ResponseOk { + t.Fatalf("expected registration response ok") + } + + // Set the session manager to return already registered to another connection + sessionManager.expectedRegErr = v3.ErrSessionBoundToOtherConn + // Send the registration again as if we didn't receive it at the edge for a new connection + datagram = newRegisterSessionDatagram(testRequestID) + quic2.send <- datagram + + // Wait for session registration response with success + datagram = <-quic2.recv + err = resp.UnmarshalBinary(datagram) + if err != nil { + t.Fatal(err) + } + + if resp.RequestID != testRequestID || resp.ResponseType != v3.ResponseOk { + t.Fatalf("expected registration response ok") + } + + // We expect the session to be served + timer := time.NewTimer(15 * time.Second) + defer timer.Stop() + select { + case <-session.served: + break + case <-timer.C: + t.Fatalf("expected session serve to be called") + } + + // Expect session to be migrated + select { + case id := <-session.migrated: + if id != conn2.ID() { + t.Fatalf("expected session to be migrated to connection 2") + } + case <-timer.C: + t.Fatalf("expected session migration to be called") + } + + // Cancel the muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx, done, cancel) + // Cancel the second muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx2, done2, cancel2) } func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) { log := zerolog.Nop() quic := newMockQuicConn() + // mockSessionManager will return the ErrSessionNotFound for any session attempting to be queried by the muxer sessionManager := mockSessionManager{session: nil, expectedGetErr: v3.ErrSessionNotFound} - conn := v3.NewDatagramConn(quic, &sessionManager, &log) + conn := v3.NewDatagramConn(quic, &sessionManager, 0, &log) // Setup the muxer ctx, cancel := context.WithCancelCause(context.Background()) @@ -319,15 +457,13 @@ func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) { datagram := newSessionPayloadDatagram(testRequestID, []byte{0xef, 0xef}) quic.send <- datagram + // Since the muxer should eventually discard a failed registration request, there is no side-effect + // that the registration was failed beyond the muxer accepting the registration request. As such, the + // test can only ensure that the quic.send channel was consumed and that the muxer closes normally + // afterwards with the expected context cancelled trigger. + // Cancel the muxer Serve context and make sure it closes with the expected error - cancel(expectedContextCanceled) - err := <-done - if !errors.Is(err, context.Canceled) { - t.Fatal(err) - } - if !errors.Is(context.Cause(ctx), expectedContextCanceled) { - t.Fatal(err) - } + assertContextClosed(t, ctx, done, cancel) } func TestDatagramConnServe_Payload(t *testing.T) { @@ -335,7 +471,7 @@ func TestDatagramConnServe_Payload(t *testing.T) { quic := newMockQuicConn() session := newMockSession() sessionManager := mockSessionManager{session: &session} - conn := v3.NewDatagramConn(quic, &sessionManager, &log) + conn := v3.NewDatagramConn(quic, &sessionManager, 0, &log) // Setup the muxer ctx, cancel := context.WithCancelCause(context.Background()) @@ -357,14 +493,7 @@ func TestDatagramConnServe_Payload(t *testing.T) { } // Cancel the muxer Serve context and make sure it closes with the expected error - cancel(expectedContextCanceled) - err := <-done - if !errors.Is(err, context.Canceled) { - t.Fatal(err) - } - if !errors.Is(context.Cause(ctx), expectedContextCanceled) { - t.Fatal(err) - } + assertContextClosed(t, ctx, done, cancel) } func newRegisterSessionDatagram(id v3.RequestID) []byte { @@ -402,6 +531,18 @@ func newSessionPayloadDatagram(id v3.RequestID, payload []byte) []byte { return datagram } +// Cancel the provided context and make sure it closes with the expected cancellation error +func assertContextClosed(t *testing.T, ctx context.Context, done <-chan error, cancel context.CancelCauseFunc) { + cancel(expectedContextCanceled) + err := <-done + if !errors.Is(err, context.Canceled) { + t.Fatal(err) + } + if !errors.Is(context.Cause(ctx), expectedContextCanceled) { + t.Fatal(err) + } +} + type mockQuicConn struct { ctx context.Context send chan []byte @@ -454,7 +595,7 @@ type mockSessionManager struct { expectedGetErr error } -func (m *mockSessionManager) RegisterSession(request *v3.UDPSessionRegistrationDatagram, conn v3.DatagramWriter) (v3.Session, error) { +func (m *mockSessionManager) RegisterSession(request *v3.UDPSessionRegistrationDatagram, conn v3.DatagramConn) (v3.Session, error) { return m.session, m.expectedRegErr } @@ -465,14 +606,16 @@ func (m *mockSessionManager) GetSession(requestID v3.RequestID) (v3.Session, err func (m *mockSessionManager) UnregisterSession(requestID v3.RequestID) {} type mockSession struct { - served chan struct{} - recv chan []byte + served chan struct{} + migrated chan uint8 + recv chan []byte } func newMockSession() mockSession { return mockSession{ - served: make(chan struct{}), - recv: make(chan []byte, 1), + served: make(chan struct{}), + migrated: make(chan uint8, 2), + recv: make(chan []byte, 1), } } @@ -480,6 +623,13 @@ func (m *mockSession) ID() v3.RequestID { return testRequestID } +func (m *mockSession) ConnectionID() uint8 { + return 0 +} + +func (m *mockSession) Migrate(conn v3.DatagramConn) { m.migrated <- conn.ID() } +func (m *mockSession) ResetIdleTimer() {} + func (m *mockSession) Serve(ctx context.Context) error { close(m.served) return v3.SessionCloseErr diff --git a/quic/v3/session.go b/quic/v3/session.go index 13c42ae3..0146e90d 100644 --- a/quic/v3/session.go +++ b/quic/v3/session.go @@ -7,6 +7,7 @@ import ( "io" "net" "sync" + "sync/atomic" "time" "github.com/rs/zerolog" @@ -47,6 +48,9 @@ func newSessionIdleErr(timeout time.Duration) error { type Session interface { io.WriteCloser ID() RequestID + ConnectionID() uint8 + ResetIdleTimer() + Migrate(eyeball DatagramConn) // Serve starts the event loop for processing UDP packets Serve(ctx context.Context) error } @@ -55,31 +59,48 @@ type session struct { id RequestID closeAfterIdle time.Duration origin io.ReadWriteCloser - eyeball DatagramWriter + eyeball atomic.Pointer[DatagramConn] // activeAtChan is used to communicate the last read/write time activeAtChan chan time.Time closeChan chan error log *zerolog.Logger } -func NewSession(id RequestID, closeAfterIdle time.Duration, origin io.ReadWriteCloser, eyeball DatagramWriter, log *zerolog.Logger) Session { - return &session{ +func NewSession(id RequestID, closeAfterIdle time.Duration, origin io.ReadWriteCloser, eyeball DatagramConn, log *zerolog.Logger) Session { + session := &session{ id: id, closeAfterIdle: closeAfterIdle, origin: origin, - eyeball: eyeball, + eyeball: atomic.Pointer[DatagramConn]{}, // activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will // drop instead of blocking because last active time only needs to be an approximation activeAtChan: make(chan time.Time, 1), closeChan: make(chan error, 1), log: log, } + session.eyeball.Store(&eyeball) + return session } func (s *session) ID() RequestID { return s.id } +func (s *session) ConnectionID() uint8 { + eyeball := *(s.eyeball.Load()) + return eyeball.ID() +} + +func (s *session) Migrate(eyeball DatagramConn) { + current := *(s.eyeball.Load()) + // Only migrate if the connection ids are different. + if current.ID() != eyeball.ID() { + s.eyeball.Store(&eyeball) + } + // The session is already running so we want to restart the idle timeout since no proxied packets have come down yet. + s.markActive() +} + func (s *session) Serve(ctx context.Context) error { go func() { // QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975 @@ -107,9 +128,12 @@ func (s *session) Serve(ctx context.Context) error { s.log.Error().Int("packetSize", n).Msg("Session (origin) packet read was too large and was dropped") continue } + // We need to synchronize on the eyeball in-case that the connection was migrated. This should be rarely a point + // of lock contention, as a migration can only happen during startup of a session before traffic flow. + eyeball := *(s.eyeball.Load()) // Sending a packet to the session does block on the [quic.Connection], however, this is okay because it // will cause back-pressure to the kernel buffer if the writes are not fast enough to the edge. - err = s.eyeball.SendUDPSessionDatagram(readBuffer[:DatagramPayloadHeaderLen+n]) + err = eyeball.SendUDPSessionDatagram(readBuffer[:DatagramPayloadHeaderLen+n]) if err != nil { s.closeChan <- err return @@ -137,6 +161,14 @@ func (s *session) Write(payload []byte) (n int, err error) { return n, err } +// ResetIdleTimer will restart the current idle timer. +// +// This public method is used to allow operators of sessions the ability to extend the session using information that is +// known external to the session itself. +func (s *session) ResetIdleTimer() { + s.markActive() +} + // Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there // are many concurrent read/write. It is fine to lose some precision func (s *session) markActive() { diff --git a/quic/v3/session_test.go b/quic/v3/session_test.go index 8c25878d..c14f2bb7 100644 --- a/quic/v3/session_test.go +++ b/quic/v3/session_test.go @@ -123,6 +123,45 @@ func TestSessionServe_OriginTooLarge(t *testing.T) { } } +func TestSessionServe_Migrate(t *testing.T) { + log := zerolog.Nop() + eyeball := newMockEyeball() + pipe1, pipe2 := net.Pipe() + session := v3.NewSession(testRequestID, 2*time.Second, pipe2, &eyeball, &log) + defer session.Close() + + done := make(chan error) + go func() { + done <- session.Serve(context.Background()) + }() + + // Migrate the session to a new connection before origin sends data + eyeball2 := newMockEyeball() + eyeball2.connID = 1 + session.Migrate(&eyeball2) + + // Origin sends data + payload2 := []byte{0xde} + pipe1.Write(payload2) + + // Expect write to eyeball2 + data := <-eyeball2.recvData + if len(data) <= 17 || !slices.Equal(payload2, data[17:]) { + t.Fatalf("expected data to write to eyeball2 after migration: %+v", data) + } + + select { + case data := <-eyeball.recvData: + t.Fatalf("expected no data to write to eyeball1 after migration: %+v", data) + default: + } + + err := <-done + if !errors.Is(err, v3.SessionIdleErr{}) { + t.Error(err) + } +} + func TestSessionClose_Multiple(t *testing.T) { log := zerolog.Nop() origin := newTestOrigin(makePayload(128)) @@ -249,7 +288,7 @@ func newTestIdleOrigin(d time.Duration) testIdleOrigin { func (o *testIdleOrigin) Read(p []byte) (n int, err error) { time.Sleep(o.duration) - return 0, nil + return -1, nil } func (o *testIdleOrigin) Write(p []byte) (n int, err error) { diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 3e6ca86a..13644f58 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -606,6 +606,7 @@ func (e *EdgeTunnelServer) serveQUIC( ctx, conn, e.sessionManager, + connIndex, connLogger.Logger(), ) } else {