From 6aa48d2eb2ced44016db0c63a880e20d5a1b1e5e Mon Sep 17 00:00:00 2001 From: Nick Vollmar Date: Fri, 6 Dec 2019 15:32:15 -0600 Subject: [PATCH] TUN-2554: cloudflared calls ReconnectTunnel --- origin/supervisor.go | 26 ++++++--- origin/tunnel.go | 87 ++++++++++++++++++++++++++---- tunnelrpc/pogs/reconnect_tunnel.go | 10 ++-- 3 files changed, 103 insertions(+), 20 deletions(-) diff --git a/origin/supervisor.go b/origin/supervisor.go index 8bb8d046..f5d08ed4 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -31,6 +31,8 @@ const ( refreshAuthMaxBackoff = 10 // Waiting time before retrying a failed 'Authenticate' connection refreshAuthRetryDuration = time.Second * 10 + // Maximum time to make an Authenticate RPC + authTokenTimeout = time.Second * 30 ) var ( @@ -90,14 +92,21 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er return err } var tunnelsWaiting []int + tunnelsActive := s.config.HAConnections + backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true} var backoffTimer <-chan time.Time - tunnelsActive := s.config.HAConnections refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true} var refreshAuthBackoffTimer <-chan time.Time + if s.config.UseReconnectToken { - refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration) + if timer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil { + refreshAuthBackoffTimer = timer + } else { + logger.WithError(err).Errorf("initial refreshAuth failed, retrying in %v", refreshAuthRetryDuration) + refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration) + } } for { @@ -169,11 +178,11 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er } } +// Returns nil if initialization succeeded, else the initialization error. func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error { logger := s.logger edgeIPs, err := s.resolveEdgeIPs() - if err != nil { logger.Infof("ResolveEdgeIPs err") return err @@ -190,7 +199,6 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig select { case <-ctx.Done(): <-s.tunnelErrors - // Error can't be nil. A nil error signals that initialization succeed return ctx.Err() case tunnelError := <-s.tunnelErrors: return tunnelError.err @@ -208,7 +216,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig // startTunnel starts the first tunnel connection. The resulting error will be sent on // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) { - err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID) + err := ServeTunnelLoop(ctx, s, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID) defer func() { s.tunnelErrors <- tunnelError{index: 0, err: err} }() @@ -229,14 +237,14 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign default: return } - err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID) + err = ServeTunnelLoop(ctx, s, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID) } } // startTunnel starts a new tunnel connection. The resulting error will be sent on // s.tunnelErrors. func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) { - err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID) + err := ServeTunnelLoop(ctx, s, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID) s.tunnelErrors <- tunnelError{index: index, err: err} } @@ -347,7 +355,9 @@ func (s *Supervisor) refreshAuth( s.SetReconnectToken(outcome.JWT()) return timeAfter(outcome.RefreshAfter()), nil case tunnelpogs.AuthUnknown: - return timeAfter(outcome.RefreshAfter()), nil + duration := outcome.RefreshAfter() + logger.WithError(outcome).Warnf("Retrying in %v", duration) + return timeAfter(duration), nil case tunnelpogs.AuthFail: return nil, outcome default: diff --git a/origin/tunnel.go b/origin/tunnel.go index 04b77792..3a05e347 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -78,6 +78,14 @@ type TunnelConfig struct { UseReconnectToken bool } +// ReconnectTunnelCredentialManager is invoked by functions in this file to +// get/set parameters for ReconnectTunnel RPC calls. +type ReconnectTunnelCredentialManager interface { + ReconnectToken() ([]byte, error) + EventDigest() ([]byte, error) + SetEventDigest(eventDigest []byte) +} + type dupConnRegisterTunnelError struct{} func (e dupConnRegisterTunnelError) Error() string { @@ -152,6 +160,7 @@ func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSigna } func ServeTunnelLoop(ctx context.Context, + credentialManager ReconnectTunnelCredentialManager, config *TunnelConfig, addr *net.TCPAddr, connectionID uint8, @@ -173,6 +182,7 @@ func ServeTunnelLoop(ctx context.Context, for { err, recoverable := ServeTunnel( ctx, + credentialManager, config, connectionLogger, addr, connectionID, @@ -193,6 +203,7 @@ func ServeTunnelLoop(ctx context.Context, func ServeTunnel( ctx context.Context, + credentialManager ReconnectTunnelCredentialManager, config *TunnelConfig, logger *log.Entry, addr *net.TCPAddr, @@ -237,13 +248,30 @@ func ServeTunnel( errGroup, serveCtx := errgroup.WithContext(ctx) - errGroup.Go(func() error { - err := RegisterTunnel(serveCtx, handler.muxer, config, logger, connectionID, originLocalIP, u) - if err == nil { - connectedFuse.Fuse(true) - backoff.SetGracePeriod() + errGroup.Go(func() (err error) { + defer func() { + if err == nil { + connectedFuse.Fuse(true) + backoff.SetGracePeriod() + } + }() + + if config.UseReconnectToken && connectedFuse.Value() { + token, tokenErr := credentialManager.ReconnectToken() + eventDigest, eventDigestErr := credentialManager.EventDigest() + // if we have both credentials, we can reconnect + if tokenErr == nil && eventDigestErr == nil { + return ReconnectTunnel(ctx, token, eventDigest, handler.muxer, config, logger, connectionID, originLocalIP, u) + } + // log errors and proceed to RegisterTunnel + if tokenErr != nil { + logger.WithError(tokenErr).Error("Couldn't get reconnect token") + } + if eventDigestErr != nil { + logger.WithError(eventDigestErr).Error("Couldn't get event digest") + } } - return err + return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionID, originLocalIP, u) }) errGroup.Go(func() error { @@ -304,6 +332,7 @@ func ServeTunnel( func RegisterTunnel( ctx context.Context, + credentialManager ReconnectTunnelCredentialManager, muxer *h2mux.Muxer, config *TunnelConfig, logger *log.Entry, @@ -329,12 +358,52 @@ func RegisterTunnel( config.Hostname, config.RegistrationOptions(connectionID, originLocalIP, uuid), ) - if registrationErr := registration.DeserializeError(); registrationErr != nil { // RegisterTunnel RPC failure return processRegisterTunnelError(registrationErr, config.Metrics) } + credentialManager.SetEventDigest(registration.EventDigest) + return processRegistrationSuccess(config, logger, connectionID, registration) +} +func ReconnectTunnel( + ctx context.Context, + token []byte, + eventDigest []byte, + muxer *h2mux.Muxer, + config *TunnelConfig, + logger *log.Entry, + connectionID uint8, + originLocalIP string, + uuid uuid.UUID, +) error { + config.TransportLogger.Debug("initiating RPC stream to reconnect") + tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger.WithField("subsystem", "rpc-reconnect"), openStreamTimeout) + if err != nil { + // RPC stream open error + return newClientRegisterTunnelError(err, config.Metrics.rpcFail) + } + defer tunnelServer.Close() + // Request server info without blocking tunnel registration; must use capnp library directly. + serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { + return nil + }) + LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger) + registration := tunnelServer.ReconnectTunnel( + ctx, + token, + eventDigest, + config.Hostname, + config.RegistrationOptions(connectionID, originLocalIP, uuid), + ) + if registrationErr := registration.DeserializeError(); registrationErr != nil { + // ReconnectTunnel RPC failure + return processRegisterTunnelError(registrationErr, config.Metrics) + } + return processRegistrationSuccess(config, logger, connectionID, registration) +} + +func processRegistrationSuccess(config *TunnelConfig, logger *log.Entry, connectionID uint8, registration *tunnelpogs.TunnelRegistration) error { for _, logLine := range registration.LogLines { logger.Info(logLine) } @@ -378,13 +447,13 @@ func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error { logger.Debug("initiating RPC stream to unregister") ctx := context.Background() - ts, err := connection.NewRPCClient(ctx, muxer, logger.WithField("subsystem", "rpc-unregister"), openStreamTimeout) + tunnelServer, err := connection.NewRPCClient(ctx, muxer, logger.WithField("subsystem", "rpc-unregister"), openStreamTimeout) if err != nil { // RPC stream open error return err } // gracePeriod is encoded in int64 using capnproto - return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds()) + return tunnelServer.UnregisterTunnel(ctx, gracePeriod.Nanoseconds()) } func LogServerInfo( diff --git a/tunnelrpc/pogs/reconnect_tunnel.go b/tunnelrpc/pogs/reconnect_tunnel.go index d3f73528..5a4c4159 100644 --- a/tunnelrpc/pogs/reconnect_tunnel.go +++ b/tunnelrpc/pogs/reconnect_tunnel.go @@ -46,7 +46,7 @@ func (c TunnelServer_PogsClient) ReconnectTunnel( eventDigest []byte, hostname string, options *RegistrationOptions, -) (*TunnelRegistration, error) { +) *TunnelRegistration { client := tunnelrpc.TunnelServer{Client: c.Client} promise := client.ReconnectTunnel(ctx, func(p tunnelrpc.TunnelServer_reconnectTunnel_Params) error { err := p.SetJwt(jwt) @@ -73,7 +73,11 @@ func (c TunnelServer_PogsClient) ReconnectTunnel( }) retval, err := promise.Result().Struct() if err != nil { - return nil, err + return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize() } - return UnmarshalTunnelRegistration(retval) + registration, err := UnmarshalTunnelRegistration(retval) + if err != nil { + return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize() + } + return registration }