diff --git a/quic/v3/muxer.go b/quic/v3/muxer.go index 79081762..4107a845 100644 --- a/quic/v3/muxer.go +++ b/quic/v3/muxer.go @@ -267,7 +267,7 @@ func (c *datagramConn) handleSessionMigration(requestID RequestID, logger *zerol // Migrate the session to use this edge connection instead of the currently running one. // We also pass in this connection's logger to override the existing logger for the session. - session.Migrate(c, c.logger) + session.Migrate(c, c.conn.Context(), c.logger) // Send another registration response since the session is already active err = c.SendUDPSessionResponse(requestID, ResponseOk) diff --git a/quic/v3/muxer_test.go b/quic/v3/muxer_test.go index 1b2149f9..ac9bf883 100644 --- a/quic/v3/muxer_test.go +++ b/quic/v3/muxer_test.go @@ -619,12 +619,14 @@ func newMockSession() mockSession { } } -func (m *mockSession) ID() v3.RequestID { return testRequestID } -func (m *mockSession) RemoteAddr() net.Addr { return testOriginAddr } -func (m *mockSession) LocalAddr() net.Addr { return testLocalAddr } -func (m *mockSession) ConnectionID() uint8 { return 0 } -func (m *mockSession) Migrate(conn v3.DatagramConn, log *zerolog.Logger) { m.migrated <- conn.ID() } -func (m *mockSession) ResetIdleTimer() {} +func (m *mockSession) ID() v3.RequestID { return testRequestID } +func (m *mockSession) RemoteAddr() net.Addr { return testOriginAddr } +func (m *mockSession) LocalAddr() net.Addr { return testLocalAddr } +func (m *mockSession) ConnectionID() uint8 { return 0 } +func (m *mockSession) Migrate(conn v3.DatagramConn, ctx context.Context, log *zerolog.Logger) { + m.migrated <- conn.ID() +} +func (m *mockSession) ResetIdleTimer() {} func (m *mockSession) Serve(ctx context.Context) error { close(m.served) diff --git a/quic/v3/session.go b/quic/v3/session.go index 7ebe02a7..57641c90 100644 --- a/quic/v3/session.go +++ b/quic/v3/session.go @@ -55,7 +55,7 @@ type Session interface { RemoteAddr() net.Addr LocalAddr() net.Addr ResetIdleTimer() - Migrate(eyeball DatagramConn, logger *zerolog.Logger) + Migrate(eyeball DatagramConn, ctx context.Context, logger *zerolog.Logger) // Serve starts the event loop for processing UDP packets Serve(ctx context.Context) error } @@ -70,6 +70,7 @@ type session struct { // activeAtChan is used to communicate the last read/write time activeAtChan chan time.Time closeChan chan error + contextChan chan context.Context metrics Metrics log *zerolog.Logger } @@ -96,8 +97,10 @@ func NewSession( // drop instead of blocking because last active time only needs to be an approximation activeAtChan: make(chan time.Time, 1), closeChan: make(chan error, 1), - metrics: metrics, - log: &logger, + // contextChan is an unbounded channel to help enforce one active migration of a session at a time. + contextChan: make(chan context.Context), + metrics: metrics, + log: &logger, } session.eyeball.Store(&eyeball) return session @@ -120,11 +123,12 @@ func (s *session) ConnectionID() uint8 { return eyeball.ID() } -func (s *session) Migrate(eyeball DatagramConn, logger *zerolog.Logger) { +func (s *session) Migrate(eyeball DatagramConn, ctx context.Context, logger *zerolog.Logger) { current := *(s.eyeball.Load()) // Only migrate if the connection ids are different. if current.ID() != eyeball.ID() { s.eyeball.Store(&eyeball) + s.contextChan <- ctx log := logger.With().Str(logFlowID, s.id.String()).Logger() s.log = &log } @@ -225,6 +229,7 @@ func (s *session) Close() error { } func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error { + connCtx := ctx // Closing the session at the end cancels read so Serve() can return defer s.Close() if closeAfterIdle == 0 { @@ -237,8 +242,14 @@ func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time for { select { - case <-ctx.Done(): - return ctx.Err() + case <-connCtx.Done(): + return connCtx.Err() + case newContext := <-s.contextChan: + // During migration of a session, we need to make sure that the context of the new connection is used instead + // of the old connection context. This will ensure that when the old connection goes away, this session will + // still be active on the existing connection. + connCtx = newContext + continue case reason := <-s.closeChan: return reason case <-checkIdleTimer.C: diff --git a/quic/v3/session_test.go b/quic/v3/session_test.go index 1e31962a..f47ceb14 100644 --- a/quic/v3/session_test.go +++ b/quic/v3/session_test.go @@ -137,14 +137,28 @@ func TestSessionServe_Migrate(t *testing.T) { defer session.Close() done := make(chan error) + eyeball1Ctx, cancel := context.WithCancelCause(context.Background()) go func() { - done <- session.Serve(context.Background()) + done <- session.Serve(eyeball1Ctx) }() // Migrate the session to a new connection before origin sends data eyeball2 := newMockEyeball() eyeball2.connID = 1 - session.Migrate(&eyeball2, &log) + eyeball2Ctx := context.Background() + session.Migrate(&eyeball2, eyeball2Ctx, &log) + + // Cancel the origin eyeball context; this should not cancel the session + contextCancelErr := errors.New("context canceled for first eyeball connection") + cancel(contextCancelErr) + select { + case <-done: + t.Fatalf("expected session to still be running") + default: + } + if context.Cause(eyeball1Ctx) != contextCancelErr { + t.Fatalf("first eyeball context should be cancelled manually: %+v", context.Cause(eyeball1Ctx)) + } // Origin sends data payload2 := []byte{0xde} @@ -166,6 +180,68 @@ func TestSessionServe_Migrate(t *testing.T) { if !errors.Is(err, v3.SessionIdleErr{}) { t.Error(err) } + if eyeball2Ctx.Err() != nil { + t.Fatalf("second eyeball context should be not be cancelled") + } +} + +func TestSessionServe_Migrate_CloseContext2(t *testing.T) { + log := zerolog.Nop() + eyeball := newMockEyeball() + pipe1, pipe2 := net.Pipe() + session := v3.NewSession(testRequestID, 2*time.Second, pipe2, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log) + defer session.Close() + + done := make(chan error) + eyeball1Ctx, cancel := context.WithCancelCause(context.Background()) + go func() { + done <- session.Serve(eyeball1Ctx) + }() + + // Migrate the session to a new connection before origin sends data + eyeball2 := newMockEyeball() + eyeball2.connID = 1 + eyeball2Ctx, cancel2 := context.WithCancelCause(context.Background()) + session.Migrate(&eyeball2, eyeball2Ctx, &log) + + // Cancel the origin eyeball context; this should not cancel the session + contextCancelErr := errors.New("context canceled for first eyeball connection") + cancel(contextCancelErr) + select { + case <-done: + t.Fatalf("expected session to still be running") + default: + } + if context.Cause(eyeball1Ctx) != contextCancelErr { + t.Fatalf("first eyeball context should be cancelled manually: %+v", context.Cause(eyeball1Ctx)) + } + + // Origin sends data + payload2 := []byte{0xde} + pipe1.Write(payload2) + + // Expect write to eyeball2 + data := <-eyeball2.recvData + if len(data) <= 17 || !slices.Equal(payload2, data[17:]) { + t.Fatalf("expected data to write to eyeball2 after migration: %+v", data) + } + + select { + case data := <-eyeball.recvData: + t.Fatalf("expected no data to write to eyeball1 after migration: %+v", data) + default: + } + + // Close the connection2 context manually + contextCancel2Err := errors.New("context canceled for second eyeball connection") + cancel2(contextCancel2Err) + err := <-done + if err != context.Canceled { + t.Fatalf("session Serve should be done: %+v", err) + } + if context.Cause(eyeball2Ctx) != contextCancel2Err { + t.Fatalf("second eyeball context should have been cancelled manually: %+v", context.Cause(eyeball2Ctx)) + } } func TestSessionClose_Multiple(t *testing.T) {