diff --git a/connection/errors.go b/connection/errors.go index 9d095c93..df3cfe97 100644 --- a/connection/errors.go +++ b/connection/errors.go @@ -41,13 +41,13 @@ func serverRegistrationErrorFromRPC(err error) ServerRegisterTunnelError { } } -type MuxerShutdownError struct{} +type muxerShutdownError struct{} -func (e MuxerShutdownError) Error() string { +func (e muxerShutdownError) Error() string { return "muxer shutdown" } -var errMuxerStopped = MuxerShutdownError{} +var errMuxerStopped = muxerShutdownError{} func isHandshakeErrRecoverable(err error, connIndex uint8, observer *Observer) bool { log := observer.log.With(). diff --git a/connection/h2mux.go b/connection/h2mux.go index aaecf6a2..4a5d7f00 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -104,7 +104,15 @@ func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *Nam h.controlLoop(serveCtx, connectedFuse, true) return nil }) - return errGroup.Wait() + + err := errGroup.Wait() + if err == errMuxerStopped { + if h.stoppedGracefully { + return nil + } + h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unexpected muxer shutdown") + } + return err } func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error { @@ -136,11 +144,15 @@ func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel h.controlLoop(serveCtx, connectedFuse, false) return nil }) - return errGroup.Wait() -} -func (h *h2muxConnection) StoppedGracefully() bool { - return h.stoppedGracefully + err := errGroup.Wait() + if err == errMuxerStopped { + if h.stoppedGracefully { + return nil + } + h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unexpected muxer shutdown") + } + return err } func (h *h2muxConnection) serveMuxer(ctx context.Context) error { diff --git a/connection/http2.go b/connection/http2.go index bb54a2b8..9c4e4522 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -34,12 +34,14 @@ type http2Connection struct { observer *Observer connIndexStr string connIndex uint8 - wg *sync.WaitGroup // newRPCClientFunc allows us to mock RPCs during testing - newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient + newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient + + activeRequestsWG sync.WaitGroup connectedFuse ConnectedFuse gracefulShutdownC chan struct{} stoppedGracefully bool + controlStreamErr error // result of running control stream handler } func NewHTTP2Connection( @@ -63,7 +65,6 @@ func NewHTTP2Connection( observer: observer, connIndexStr: uint8ToString(connIndex), connIndex: connIndex, - wg: &sync.WaitGroup{}, newRPCClientFunc: newRegistrationRPCClient, connectedFuse: connectedFuse, gracefulShutdownC: gracefulShutdownC, @@ -80,15 +81,20 @@ func (c *http2Connection) Serve(ctx context.Context) error { Handler: c, }) - if !c.stoppedGracefully { + switch { + case c.stoppedGracefully: + return nil + case c.controlStreamErr != nil: + return c.controlStreamErr + default: + c.observer.log.Info().Uint8(LogFieldConnIndex, c.connIndex).Msg("Lost connection with the edge") return errEdgeConnectionClosed } - return nil } func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { - c.wg.Add(1) - defer c.wg.Done() + c.activeRequestsWG.Add(1) + defer c.activeRequestsWG.Done() respWriter := &http2RespWriter{ r: r.Body, @@ -105,6 +111,7 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { if isControlStreamUpgrade(r) { respWriter.shouldFlush = true err = c.serveControlStream(r.Context(), respWriter) + c.controlStreamErr = err } else if isWebsocketUpgrade(r) { respWriter.shouldFlush = true stripWebsocketUpgradeHeader(r) @@ -118,10 +125,6 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (c *http2Connection) StoppedGracefully() bool { - return c.stoppedGracefully -} - func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error { rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log) defer rpcClient.Close() @@ -146,7 +149,7 @@ func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *ht func (c *http2Connection) close() { // Wait for all serve HTTP handlers to return - c.wg.Wait() + c.activeRequestsWG.Wait() c.conn.Close() } diff --git a/origin/tunnel.go b/origin/tunnel.go index f972d967..fd810aac 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -214,6 +214,7 @@ func waitForBackoff( config.Observer.SendReconnect(connIndex) log.Info(). Err(err). + Uint8(connection.LogFieldConnIndex, connIndex). Msgf("Retrying connection in %s seconds", duration) protobackoff.Backoff(ctx) @@ -238,9 +239,11 @@ func waitForBackoff( return nil } +// ServeTunnel runs a single tunnel connection, returns nil on graceful shutdown, +// on error returns a flag indicating if error can be retried func ServeTunnel( ctx context.Context, - log *zerolog.Logger, + connLong *zerolog.Logger, credentialManager *reconnectCredentialManager, config *TunnelConfig, addr *net.TCPAddr, @@ -275,11 +278,12 @@ func ServeTunnel( fuse: fuse, backoff: backoff, } + if protocol == connection.HTTP2 { connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries)) - return ServeHTTP2( + err = ServeHTTP2( ctx, - log, + connLong, config, edgeConn, connOptions, @@ -288,24 +292,64 @@ func ServeTunnel( reconnectCh, gracefulShutdownC, ) + } else { + err = ServeH2mux( + ctx, + connLong, + credentialManager, + config, + edgeConn, + connIndex, + connectedFuse, + cloudflaredUUID, + reconnectCh, + gracefulShutdownC, + ) } - return ServeH2mux( - ctx, - log, - credentialManager, - config, - edgeConn, - connIndex, - connectedFuse, - cloudflaredUUID, - reconnectCh, - gracefulShutdownC, - ) + + if err != nil { + switch err := err.(type) { + case connection.DupConnRegisterTunnelError: + // don't retry this connection anymore, let supervisor pick a new address + return err, false + case connection.ServerRegisterTunnelError: + connLong.Err(err).Msg("Register tunnel error from server side") + // Don't send registration error return from server to Sentry. They are + // logged on server side + if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 { + connLong.Error().Msg(activeIncidentsMsg(incidents)) + } + return err.Cause, !err.Permanent + case ReconnectSignal: + connLong.Info(). + Uint8(connection.LogFieldConnIndex, connIndex). + Msgf("Restarting connection due to reconnect signal in %s", err.Delay) + err.DelayBeforeReconnect() + return err, true + default: + if err == context.Canceled { + connLong.Debug().Err(err).Msgf("Serve tunnel error") + return err, false + } + connLong.Err(err).Msgf("Serve tunnel error") + _, permanent := err.(unrecoverableError) + return err, !permanent + } + } + return nil, false +} + +type unrecoverableError struct { + err error +} + +func (r unrecoverableError) Error() string { + return r.err.Error() } func ServeH2mux( ctx context.Context, - log *zerolog.Logger, + connLog *zerolog.Logger, credentialManager *reconnectCredentialManager, config *TunnelConfig, edgeConn net.Conn, @@ -314,8 +358,8 @@ func ServeH2mux( cloudflaredUUID uuid.UUID, reconnectCh chan ReconnectSignal, gracefulShutdownC chan struct{}, -) (err error, recoverable bool) { - config.Log.Debug().Msgf("Connecting via h2mux") +) error { + connLog.Debug().Msgf("Connecting via h2mux") // Returns error from parsing the origin URL or handshake errors handler, err, recoverable := connection.NewH2muxConnection( config.ConnectionConfig, @@ -326,12 +370,15 @@ func ServeH2mux( gracefulShutdownC, ) if err != nil { - return err, recoverable + if !recoverable { + return unrecoverableError{err} + } + return err } errGroup, serveCtx := errgroup.WithContext(ctx) - errGroup.Go(func() (err error) { + errGroup.Go(func() error { if config.NamedTunnel != nil { connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries)) return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse) @@ -340,49 +387,16 @@ func ServeH2mux( return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse) }) - errGroup.Go(listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)) + errGroup.Go(func() error { + return listenReconnect(serveCtx, reconnectCh, gracefulShutdownC) + }) - err = errGroup.Wait() - if err != nil { - switch err := err.(type) { - case connection.DupConnRegisterTunnelError: - // don't retry this connection anymore, let supervisor pick a new address - return err, false - case connection.ServerRegisterTunnelError: - log.Err(err).Msg("Register tunnel error from server side") - // Don't send registration error return from server to Sentry. They are - // logged on server side - if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 { - log.Error().Msg(activeIncidentsMsg(incidents)) - } - return err.Cause, !err.Permanent - case connection.MuxerShutdownError: - if handler.StoppedGracefully() { - return nil, false - } - log.Info().Msg("Unexpected muxer shutdown") - return err, true - case ReconnectSignal: - log.Info(). - Uint8(connection.LogFieldConnIndex, connIndex). - Msgf("Restarting connection due to reconnect signal in %s", err.Delay) - err.DelayBeforeReconnect() - return err, true - default: - if err == context.Canceled { - log.Debug().Err(err).Msgf("Serve tunnel error") - return err, false - } - log.Err(err).Msgf("Serve tunnel error") - return err, true - } - } - return nil, true + return errGroup.Wait() } func ServeHTTP2( ctx context.Context, - log *zerolog.Logger, + connLog *zerolog.Logger, config *TunnelConfig, tlsServerConn net.Conn, connOptions *tunnelpogs.ConnectionOptions, @@ -390,8 +404,8 @@ func ServeHTTP2( connectedFuse connection.ConnectedFuse, reconnectCh chan ReconnectSignal, gracefulShutdownC chan struct{}, -) (err error, recoverable bool) { - log.Debug().Msgf("Connecting via http2") +) error { + connLog.Debug().Msgf("Connecting via http2") h2conn := connection.NewHTTP2Connection( tlsServerConn, config.ConnectionConfig, @@ -408,25 +422,26 @@ func ServeHTTP2( return h2conn.Serve(serveCtx) }) - errGroup.Go(listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)) + errGroup.Go(func() error { + err := listenReconnect(serveCtx, reconnectCh, gracefulShutdownC) + if err != nil { + // forcefully break the connection (this is only used for testing) + _ = tlsServerConn.Close() + } + return err + }) - err = errGroup.Wait() - if err != nil { - return err, true - } - return nil, false + return errGroup.Wait() } -func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh chan struct{}) func() error { - return func() error { - select { - case reconnect := <-reconnectCh: - return reconnect - case <-gracefulShutdownCh: - return nil - case <-ctx.Done(): - return nil - } +func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh chan struct{}) error { + select { + case reconnect := <-reconnectCh: + return reconnect + case <-gracefulShutdownCh: + return nil + case <-ctx.Done(): + return nil } }