From 8eeb452cce3d6abf237b90dd02a9d249ac7fa292 Mon Sep 17 00:00:00 2001 From: cthuang Date: Tue, 18 Aug 2020 11:14:14 +0100 Subject: [PATCH] TUN-3268: Each connection has its own event digest to reconnect --- origin/metrics.go | 31 +-- origin/reconnect.go | 191 ++++++++++++++++++ .../{supervisor_test.go => reconnect_test.go} | 70 ++----- origin/supervisor.go | 122 ++--------- origin/tunnel.go | 79 +------- 5 files changed, 230 insertions(+), 263 deletions(-) create mode 100644 origin/reconnect.go rename origin/{supervisor_test.go => reconnect_test.go} (56%) diff --git a/origin/metrics.go b/origin/metrics.go index be64f533..4021041d 100644 --- a/origin/metrics.go +++ b/origin/metrics.go @@ -58,11 +58,9 @@ type TunnelMetrics struct { // oldServerLocations stores the last server the tunnel was connected to oldServerLocations map[string]string - regSuccess *prometheus.CounterVec - regFail *prometheus.CounterVec - authSuccess prometheus.Counter - authFail *prometheus.CounterVec - rpcFail *prometheus.CounterVec + regSuccess *prometheus.CounterVec + regFail *prometheus.CounterVec + rpcFail *prometheus.CounterVec muxerMetrics *muxerMetrics tunnelsHA tunnelsForHA @@ -456,27 +454,6 @@ func NewTunnelMetrics() *TunnelMetrics { ) prometheus.MustRegister(registerSuccess) - authSuccess := prometheus.NewCounter( - prometheus.CounterOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "tunnel_authenticate_success", - Help: "Count of successful tunnel authenticate", - }, - ) - prometheus.MustRegister(authSuccess) - - authFail := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "tunnel_authenticate_fail", - Help: "Count of tunnel authenticate errors by type", - }, - []string{"error"}, - ) - prometheus.MustRegister(authFail) - return &TunnelMetrics{ haConnections: haConnections, activeStreams: activeStreams, @@ -497,8 +474,6 @@ func NewTunnelMetrics() *TunnelMetrics { regFail: registerFail, rpcFail: rpcFail, userHostnamesCounts: userHostnamesCounts, - authSuccess: authSuccess, - authFail: authFail, } } diff --git a/origin/reconnect.go b/origin/reconnect.go new file mode 100644 index 00000000..2c7414aa --- /dev/null +++ b/origin/reconnect.go @@ -0,0 +1,191 @@ +package origin + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/tunnelrpc" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" +) + +var ( + errJWTUnset = errors.New("JWT unset") +) + +// reconnectTunnelCredentialManager is invoked by functions in tunnel.go to +// get/set parameters for ReconnectTunnel RPC calls. 
+type reconnectCredentialManager struct {
+	mu          sync.RWMutex
+	jwt         []byte
+	eventDigest map[uint8][]byte
+	connDigest  map[uint8][]byte
+	authSuccess prometheus.Counter
+	authFail    *prometheus.CounterVec
+}
+
+func newReconnectCredentialManager(namespace, subsystem string, haConnections int) *reconnectCredentialManager {
+	authSuccess := prometheus.NewCounter(
+		prometheus.CounterOpts{
+			Namespace: namespace,
+			Subsystem: subsystem,
+			Name:      "tunnel_authenticate_success",
+			Help:      "Count of successful tunnel authenticate",
+		},
+	)
+	authFail := prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Namespace: namespace,
+			Subsystem: subsystem,
+			Name:      "tunnel_authenticate_fail",
+			Help:      "Count of tunnel authenticate errors by type",
+		},
+		[]string{"error"},
+	)
+	prometheus.MustRegister(authSuccess, authFail)
+	return &reconnectCredentialManager{
+		eventDigest: make(map[uint8][]byte, haConnections),
+		connDigest:  make(map[uint8][]byte, haConnections),
+		authSuccess: authSuccess,
+		authFail:    authFail,
+	}
+}
+
+func (cm *reconnectCredentialManager) ReconnectToken() ([]byte, error) {
+	cm.mu.RLock()
+	defer cm.mu.RUnlock()
+	if cm.jwt == nil {
+		return nil, errJWTUnset
+	}
+	return cm.jwt, nil
+}
+
+func (cm *reconnectCredentialManager) SetReconnectToken(jwt []byte) {
+	cm.mu.Lock()
+	defer cm.mu.Unlock()
+	cm.jwt = jwt
+}
+
+func (cm *reconnectCredentialManager) EventDigest(connID uint8) ([]byte, error) {
+	cm.mu.RLock()
+	defer cm.mu.RUnlock()
+	digest, ok := cm.eventDigest[connID]
+	if !ok {
+		return nil, fmt.Errorf("no event digest for connection %v", connID)
+	}
+	return digest, nil
+}
+
+func (cm *reconnectCredentialManager) SetEventDigest(connID uint8, digest []byte) {
+	cm.mu.Lock()
+	defer cm.mu.Unlock()
+	cm.eventDigest[connID] = digest
+}
+
+func (cm *reconnectCredentialManager) ConnDigest(connID uint8) ([]byte, error) {
+	cm.mu.RLock()
+	defer cm.mu.RUnlock()
+	digest, ok := cm.connDigest[connID]
+	if !ok {
+		return nil, fmt.Errorf("no connection digest for connection %v", connID)
+	}
+	return digest, nil
+}
+
+func (cm *reconnectCredentialManager) SetConnDigest(connID uint8, digest []byte) {
+	cm.mu.Lock()
+	defer cm.mu.Unlock()
+	cm.connDigest[connID] = digest
+}
+
+func (cm *reconnectCredentialManager) RefreshAuth(
+	ctx context.Context,
+	backoff *BackoffHandler,
+	authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
+) (retryTimer <-chan time.Time, err error) {
+	authOutcome, err := authenticate(ctx, backoff.Retries())
+	if err != nil {
+		cm.authFail.WithLabelValues(err.Error()).Inc()
+		if _, ok := backoff.GetBackoffDuration(ctx); ok {
+			return backoff.BackoffTimer(), nil
+		}
+		return nil, err
+	}
+	// clear backoff timer
+	backoff.SetGracePeriod()
+
+	switch outcome := authOutcome.(type) {
+	case tunnelpogs.AuthSuccess:
+		cm.SetReconnectToken(outcome.JWT())
+		cm.authSuccess.Inc()
+		return timeAfter(outcome.RefreshAfter()), nil
+	case tunnelpogs.AuthUnknown:
+		duration := outcome.RefreshAfter()
+		cm.authFail.WithLabelValues(outcome.Error()).Inc()
+		return timeAfter(duration), nil
+	case tunnelpogs.AuthFail:
+		cm.authFail.WithLabelValues(outcome.Error()).Inc()
+		return nil, outcome
+	default:
+		err := fmt.Errorf("refresh_auth: Unexpected outcome type %T", authOutcome)
+		cm.authFail.WithLabelValues(err.Error()).Inc()
+		return nil, err
+	}
+}
+
+func ReconnectTunnel(
+	ctx context.Context,
+	muxer *h2mux.Muxer,
+	config *TunnelConfig,
+	logger logger.Service,
+	connectionID uint8,
+	originLocalAddr string,
+	uuid uuid.UUID,
+
credentialManager *reconnectCredentialManager, +) error { + token, err := credentialManager.ReconnectToken() + if err != nil { + return err + } + eventDigest, err := credentialManager.EventDigest(connectionID) + if err != nil { + return err + } + connDigest, err := credentialManager.ConnDigest(connectionID) + if err != nil { + return err + } + + config.TransportLogger.Debug("initiating RPC stream to reconnect") + tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout) + if err != nil { + // RPC stream open error + return newClientRegisterTunnelError(err, config.Metrics.rpcFail, reconnect) + } + defer tunnelServer.Close() + // Request server info without blocking tunnel registration; must use capnp library directly. + serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { + return nil + }) + LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger) + registration := tunnelServer.ReconnectTunnel( + ctx, + token, + eventDigest, + connDigest, + config.Hostname, + config.RegistrationOptions(connectionID, originLocalAddr, uuid), + ) + if registrationErr := registration.DeserializeError(); registrationErr != nil { + // ReconnectTunnel RPC failure + return processRegisterTunnelError(registrationErr, config.Metrics, reconnect) + } + return processRegistrationSuccess(config, logger, connectionID, registration, reconnect, credentialManager) +} diff --git a/origin/supervisor_test.go b/origin/reconnect_test.go similarity index 56% rename from origin/supervisor_test.go rename to origin/reconnect_test.go index 21eeec60..d3cf6dbe 100644 --- a/origin/supervisor_test.go +++ b/origin/reconnect_test.go @@ -7,51 +7,19 @@ import ( "testing" "time" - "github.com/google/uuid" - "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/assert" - "github.com/cloudflare/cloudflared/logger" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) -func testConfig(logger logger.Service) *TunnelConfig { - metrics := TunnelMetrics{} - - metrics.authSuccess = prometheus.NewCounter( - prometheus.CounterOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "tunnel_authenticate_success", - Help: "Count of successful tunnel authenticate", - }, - ) - - metrics.authFail = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: metricsNamespace, - Subsystem: tunnelSubsystem, - Name: "tunnel_authenticate_fail", - Help: "Count of tunnel authenticate errors by type", - }, - []string{"error"}, - ) - return &TunnelConfig{Logger: logger, Metrics: &metrics} -} - func TestRefreshAuthBackoff(t *testing.T) { - logger := logger.NewOutputWriter(logger.NewMockWriteManager()) + rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4) var wait time.Duration timeAfter = func(d time.Duration) <-chan time.Time { wait = d return time.After(d) } - - s, err := NewSupervisor(testConfig(logger), uuid.New()) - if !assert.NoError(t, err) { - t.FailNow() - } backoff := &BackoffHandler{MaxRetries: 3} auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) { return nil, fmt.Errorf("authentication failure") @@ -59,17 +27,17 @@ func TestRefreshAuthBackoff(t *testing.T) { // authentication failures should consume the backoff for i := uint(0); i < backoff.MaxRetries; i++ { - retryChan, err := s.refreshAuth(context.Background(), backoff, auth) + retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth) 
assert.NoError(t, err) assert.NotNil(t, retryChan) assert.Equal(t, (1< s.config.HAConnections } -func (s *Supervisor) ReconnectToken() ([]byte, error) { - s.jwtLock.RLock() - defer s.jwtLock.RUnlock() - if s.jwt == nil { - return nil, errJWTUnset - } - return s.jwt, nil -} - -func (s *Supervisor) SetReconnectToken(jwt []byte) { - s.jwtLock.Lock() - defer s.jwtLock.Unlock() - s.jwt = jwt -} - -func (s *Supervisor) EventDigest() ([]byte, error) { - s.eventDigestLock.RLock() - defer s.eventDigestLock.RUnlock() - if s.eventDigest == nil { - return nil, errEventDigestUnset - } - return s.eventDigest, nil -} - -func (s *Supervisor) SetEventDigest(eventDigest []byte) { - s.eventDigestLock.Lock() - defer s.eventDigestLock.Unlock() - s.eventDigest = eventDigest -} - -func (s *Supervisor) ConnDigest(connID uint8) ([]byte, error) { - s.connDigestLock.RLock() - defer s.connDigestLock.RUnlock() - digest, ok := s.connDigest[connID] - if !ok { - return nil, fmt.Errorf("no connection digest for connection %v", connID) - } - return digest, nil -} - -func (s *Supervisor) SetConnDigest(connID uint8, connDigest []byte) { - s.connDigestLock.Lock() - defer s.connDigestLock.Unlock() - s.connDigest[connID] = connDigest -} - -func (s *Supervisor) refreshAuth( - ctx context.Context, - backoff *BackoffHandler, - authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error), -) (retryTimer <-chan time.Time, err error) { - logger := s.config.Logger - authOutcome, err := authenticate(ctx, backoff.Retries()) - if err != nil { - s.config.Metrics.authFail.WithLabelValues(err.Error()).Inc() - if duration, ok := backoff.GetBackoffDuration(ctx); ok { - logger.Debugf("refresh_auth: Retrying in %v: %s", duration, err) - return backoff.BackoffTimer(), nil - } - return nil, err - } - // clear backoff timer - backoff.SetGracePeriod() - - switch outcome := authOutcome.(type) { - case tunnelpogs.AuthSuccess: - s.SetReconnectToken(outcome.JWT()) - s.config.Metrics.authSuccess.Inc() - return timeAfter(outcome.RefreshAfter()), nil - case tunnelpogs.AuthUnknown: - duration := outcome.RefreshAfter() - s.config.Metrics.authFail.WithLabelValues(outcome.Error()).Inc() - logger.Debugf("refresh_auth: Retrying in %v: %s", duration, outcome) - return timeAfter(duration), nil - case tunnelpogs.AuthFail: - s.config.Metrics.authFail.WithLabelValues(outcome.Error()).Inc() - return nil, outcome - default: - err := fmt.Errorf("refresh_auth: Unexpected outcome type %T", authOutcome) - s.config.Metrics.authFail.WithLabelValues(err.Error()).Inc() - return nil, err - } -} - func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) { arbitraryEdgeIP, err := s.edgeIPs.GetAddrForRPC() if err != nil { diff --git a/origin/tunnel.go b/origin/tunnel.go index 5086ae84..07985069 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -90,16 +90,6 @@ type TunnelConfig struct { ReplaceExisting bool } -// ReconnectTunnelCredentialManager is invoked by functions in this file to -// get/set parameters for ReconnectTunnel RPC calls. 
-type ReconnectTunnelCredentialManager interface { - ReconnectToken() ([]byte, error) - EventDigest() ([]byte, error) - SetEventDigest(eventDigest []byte) - ConnDigest(connID uint8) ([]byte, error) - SetConnDigest(connID uint8, connDigest []byte) -} - type dupConnRegisterTunnelError struct{} var errDuplicationConnection = &dupConnRegisterTunnelError{} @@ -209,7 +199,7 @@ func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSigna } func ServeTunnelLoop(ctx context.Context, - credentialManager ReconnectTunnelCredentialManager, + credentialManager *reconnectCredentialManager, config *TunnelConfig, addr *net.TCPAddr, connectionIndex uint8, @@ -255,7 +245,7 @@ func ServeTunnelLoop(ctx context.Context, func ServeTunnel( ctx context.Context, - credentialManager ReconnectTunnelCredentialManager, + credentialManager *reconnectCredentialManager, config *TunnelConfig, logger logger.Service, addr *net.TCPAddr, @@ -310,24 +300,12 @@ func ServeTunnel( } if config.UseReconnectToken && connectedFuse.Value() { - token, tokenErr := credentialManager.ReconnectToken() - eventDigest, eventDigestErr := credentialManager.EventDigest() - // if we have both credentials, we can reconnect - if tokenErr == nil && eventDigestErr == nil { - var connDigest []byte - if digest, connDigestErr := credentialManager.ConnDigest(connectionIndex); connDigestErr == nil { - connDigest = digest - } - - return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager) + err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager) + if err == nil { + return nil } // log errors and proceed to RegisterTunnel - if tokenErr != nil { - logger.Errorf("Couldn't get reconnect token: %s", tokenErr) - } - if eventDigestErr != nil { - logger.Errorf("Couldn't get event digest: %s", eventDigestErr) - } + logger.Errorf("Couldn't reconnect connection %d. Reregistering it instead. 
Error was: %v", connectionIndex, err) } return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID) }) @@ -482,7 +460,7 @@ func UnregisterConnection( func RegisterTunnel( ctx context.Context, - credentialManager ReconnectTunnelCredentialManager, + credentialManager *reconnectCredentialManager, muxer *h2mux.Muxer, config *TunnelConfig, logger logger.Service, @@ -512,56 +490,17 @@ func RegisterTunnel( // RegisterTunnel RPC failure return processRegisterTunnelError(registrationErr, config.Metrics, register) } - credentialManager.SetEventDigest(registration.EventDigest) + credentialManager.SetEventDigest(connectionID, registration.EventDigest) return processRegistrationSuccess(config, logger, connectionID, registration, register, credentialManager) } -func ReconnectTunnel( - ctx context.Context, - token []byte, - eventDigest, connDigest []byte, - muxer *h2mux.Muxer, - config *TunnelConfig, - logger logger.Service, - connectionID uint8, - originLocalAddr string, - uuid uuid.UUID, - credentialManager ReconnectTunnelCredentialManager, -) error { - config.TransportLogger.Debug("initiating RPC stream to reconnect") - tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout) - if err != nil { - // RPC stream open error - return newClientRegisterTunnelError(err, config.Metrics.rpcFail, reconnect) - } - defer tunnelServer.Close() - // Request server info without blocking tunnel registration; must use capnp library directly. - serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { - return nil - }) - LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger) - registration := tunnelServer.ReconnectTunnel( - ctx, - token, - eventDigest, - connDigest, - config.Hostname, - config.RegistrationOptions(connectionID, originLocalAddr, uuid), - ) - if registrationErr := registration.DeserializeError(); registrationErr != nil { - // ReconnectTunnel RPC failure - return processRegisterTunnelError(registrationErr, config.Metrics, reconnect) - } - return processRegistrationSuccess(config, logger, connectionID, registration, reconnect, credentialManager) -} - func processRegistrationSuccess( config *TunnelConfig, logger logger.Service, connectionID uint8, registration *tunnelpogs.TunnelRegistration, name registerRPCName, - credentialManager ReconnectTunnelCredentialManager, + credentialManager *reconnectCredentialManager, ) error { for _, logLine := range registration.LogLines { logger.Info(logLine)
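
Illustrative sketch (not part of the diff above): a hypothetical test, assumed to sit in origin/reconnect_test.go next to TestRefreshAuthBackoff, exercising the per-connection digest bookkeeping this patch introduces. The test name and digest values are made up; it relies only on identifiers added by the patch (newReconnectCredentialManager, SetEventDigest/EventDigest, SetConnDigest/ConnDigest, ReconnectToken, errJWTUnset) plus the testing and testify/assert imports the test file already uses.

func TestPerConnectionDigests(t *testing.T) {
	// Use the test name as namespace/subsystem so the prometheus counters
	// registered by newReconnectCredentialManager don't collide with the
	// ones registered by other tests in the package.
	rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)

	// No JWT has been stored yet, so there is nothing to reconnect with.
	_, err := rcm.ReconnectToken()
	assert.Equal(t, errJWTUnset, err)

	// Each connection keeps its own event digest and connection digest.
	rcm.SetEventDigest(0, []byte("event-digest-0"))
	rcm.SetConnDigest(0, []byte("conn-digest-0"))

	eventDigest, err := rcm.EventDigest(0)
	assert.NoError(t, err)
	assert.Equal(t, []byte("event-digest-0"), eventDigest)

	connDigest, err := rcm.ConnDigest(0)
	assert.NoError(t, err)
	assert.Equal(t, []byte("conn-digest-0"), connDigest)

	// Connection 1 has registered nothing, so its lookups fail
	// independently of connection 0.
	_, err = rcm.EventDigest(1)
	assert.Error(t, err)
	_, err = rcm.ConnDigest(1)
	assert.Error(t, err)
}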