From 5e7ca1441207bfc26eac8ef31e1c39a2614e276a Mon Sep 17 00:00:00 2001 From: Nick Vollmar Date: Wed, 4 Dec 2019 11:22:08 -0600 Subject: [PATCH] TUN-2555: origin/supervisor.go calls Authenticate --- cmd/cloudflared/tunnel/cmd.go | 9 +- cmd/cloudflared/tunnel/configuration.go | 1 + origin/backoffhandler.go | 5 + origin/supervisor.go | 178 ++++++++++++++++++++++-- origin/supervisor_test.go | 128 +++++++++++++++++ origin/tunnel.go | 28 ++-- tunnelrpc/pogs/auth_outcome.go | 32 +++-- 7 files changed, 344 insertions(+), 37 deletions(-) create mode 100644 origin/supervisor_test.go diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 176a9889..62c6c1c4 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -977,6 +977,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag { EnvVars: []string{"TUNNEL_INTENT"}, Hidden: true, }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: "use-reconnect-token", + Usage: "Test reestablishing connections with the new 'reconnect token' flow.", + EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"}, + Hidden: true, + }), altsrc.NewDurationFlag(&cli.DurationFlag{ Name: "dial-edge-timeout", Usage: "Maximum wait time to set up a connection with the edge", @@ -1044,7 +1050,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag { Usage: "Absolute path of directory to save SSH host keys in", EnvVars: []string{"HOST_KEY_PATH"}, Hidden: true, - }), } -} \ No newline at end of file +} diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 8986c9f9..ef24751b 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -275,6 +275,7 @@ func prepareTunnelConfig( TlsConfig: toEdgeTLSConfig, TransportLogger: transportLogger, UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"), + UseReconnectToken: c.Bool("use-reconnect-token"), }, nil } diff --git a/origin/backoffhandler.go b/origin/backoffhandler.go index 97bb9ad8..8ff9752b 100644 --- a/origin/backoffhandler.go +++ b/origin/backoffhandler.go @@ -92,3 +92,8 @@ func (b BackoffHandler) GetBaseTime() time.Duration { } return b.BaseTime } + +// Retries returns the number of retries consumed so far. +func (b *BackoffHandler) Retries() int { + return int(b.retries) +} diff --git a/origin/supervisor.go b/origin/supervisor.go index cea9f9b3..8bb8d046 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -2,15 +2,20 @@ package origin import ( "context" + "errors" "fmt" + "math/rand" "net" + "sync" "time" "github.com/google/uuid" "github.com/sirupsen/logrus" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/signal" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) const ( @@ -20,11 +25,23 @@ const ( resolveTTL = time.Hour // Interval between registering new tunnels registrationInterval = time.Second + + subsystemRefreshAuth = "refresh_auth" + // Maximum exponent for 'Authenticate' exponential backoff + refreshAuthMaxBackoff = 10 + // Waiting time before retrying a failed 'Authenticate' connection + refreshAuthRetryDuration = time.Second * 10 +) + +var ( + errJWTUnset = errors.New("JWT unset") + errEventDigestUnset = errors.New("event digest unset") ) type Supervisor struct { - config *TunnelConfig - edgeIPs []*net.TCPAddr + cloudflaredUUID uuid.UUID + config *TunnelConfig + edgeIPs []*net.TCPAddr // nextUnusedEdgeIP is the index of the next addr k edgeIPs to try nextUnusedEdgeIP int lastResolve time.Time @@ -37,6 +54,12 @@ type Supervisor struct { nextConnectedSignal chan struct{} logger *logrus.Entry + + jwtLock *sync.RWMutex + jwt []byte + + eventDigestLock *sync.RWMutex + eventDigest []byte } type resolveResult struct { @@ -49,18 +72,21 @@ type tunnelError struct { err error } -func NewSupervisor(config *TunnelConfig) *Supervisor { +func NewSupervisor(config *TunnelConfig, u uuid.UUID) *Supervisor { return &Supervisor{ + cloudflaredUUID: u, config: config, tunnelErrors: make(chan tunnelError), tunnelsConnecting: map[int]chan struct{}{}, logger: config.Logger.WithField("subsystem", "supervisor"), + jwtLock: &sync.RWMutex{}, + eventDigestLock: &sync.RWMutex{}, } } -func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error { +func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error { logger := s.config.Logger - if err := s.initialize(ctx, connectedSignal, u); err != nil { + if err := s.initialize(ctx, connectedSignal); err != nil { return err } var tunnelsWaiting []int @@ -68,6 +94,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u var backoffTimer <-chan time.Time tunnelsActive := s.config.HAConnections + refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true} + var refreshAuthBackoffTimer <-chan time.Time + if s.config.UseReconnectToken { + refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration) + } + for { select { // Context cancelled @@ -103,10 +135,20 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u case <-backoffTimer: backoffTimer = nil for _, index := range tunnelsWaiting { - go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u) + go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index)) } tunnelsActive += len(tunnelsWaiting) tunnelsWaiting = nil + // Time to call Authenticate + case <-refreshAuthBackoffTimer: + newTimer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate) + if err != nil { + logger.WithError(err).Error("Authentication failed") + // Permanent failure. Leave the `select` without setting the + // channel to be non-null, so we'll never hit this case of the `select` again. + continue + } + refreshAuthBackoffTimer = newTimer // Tunnel successfully connected case <-s.nextConnectedSignal: if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 { @@ -127,7 +169,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u } } -func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error { +func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error { logger := s.logger edgeIPs, err := s.resolveEdgeIPs() @@ -144,12 +186,12 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig s.lastResolve = time.Now() // check entitlement and version too old error before attempting to register more tunnels s.nextUnusedEdgeIP = s.config.HAConnections - go s.startFirstTunnel(ctx, connectedSignal, u) + go s.startFirstTunnel(ctx, connectedSignal) select { case <-ctx.Done(): <-s.tunnelErrors // Error can't be nil. A nil error signals that initialization succeed - return fmt.Errorf("context was canceled") + return ctx.Err() case tunnelError := <-s.tunnelErrors: return tunnelError.err case <-connectedSignal.Wait(): @@ -157,7 +199,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig // At least one successful connection, so start the rest for i := 1; i < s.config.HAConnections; i++ { ch := signal.New(make(chan struct{})) - go s.startTunnel(ctx, i, ch, u) + go s.startTunnel(ctx, i, ch) time.Sleep(registrationInterval) } return nil @@ -165,8 +207,8 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig // startTunnel starts the first tunnel connection. The resulting error will be sent on // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed -func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) { - err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) +func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) { + err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID) defer func() { s.tunnelErrors <- tunnelError{index: 0, err: err} }() @@ -187,14 +229,14 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign default: return } - err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) + err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID) } } // startTunnel starts a new tunnel connection. The resulting error will be sent on // s.tunnelErrors. -func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) { - err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u) +func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) { + err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID) s.tunnelErrors <- tunnelError{index: index, err: err} } @@ -252,3 +294,109 @@ func (s *Supervisor) replaceEdgeIP(badIPIndex int) { s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP] s.nextUnusedEdgeIP++ } + +func (s *Supervisor) ReconnectToken() ([]byte, error) { + s.jwtLock.RLock() + defer s.jwtLock.RUnlock() + if s.jwt == nil { + return nil, errJWTUnset + } + return s.jwt, nil +} + +func (s *Supervisor) SetReconnectToken(jwt []byte) { + s.jwtLock.Lock() + defer s.jwtLock.Unlock() + s.jwt = jwt +} + +func (s *Supervisor) EventDigest() ([]byte, error) { + s.eventDigestLock.RLock() + defer s.eventDigestLock.RUnlock() + if s.eventDigest == nil { + return nil, errEventDigestUnset + } + return s.eventDigest, nil +} + +func (s *Supervisor) SetEventDigest(eventDigest []byte) { + s.eventDigestLock.Lock() + defer s.eventDigestLock.Unlock() + s.eventDigest = eventDigest +} + +func (s *Supervisor) refreshAuth( + ctx context.Context, + backoff *BackoffHandler, + authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error), +) (retryTimer <-chan time.Time, err error) { + logger := s.config.Logger.WithField("subsystem", subsystemRefreshAuth) + authOutcome, err := authenticate(ctx, backoff.Retries()) + if err != nil { + if duration, ok := backoff.GetBackoffDuration(ctx); ok { + logger.WithError(err).Warnf("Retrying in %v", duration) + return backoff.BackoffTimer(), nil + } + return nil, err + } + // clear backoff timer + backoff.SetGracePeriod() + + switch outcome := authOutcome.(type) { + case tunnelpogs.AuthSuccess: + s.SetReconnectToken(outcome.JWT()) + return timeAfter(outcome.RefreshAfter()), nil + case tunnelpogs.AuthUnknown: + return timeAfter(outcome.RefreshAfter()), nil + case tunnelpogs.AuthFail: + return nil, outcome + default: + return nil, fmt.Errorf("Unexpected outcome type %T", authOutcome) + } +} + +func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) { + arbitraryEdgeIP := s.getEdgeIP(rand.Int()) + edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP) + if err != nil { + return nil, err + } + defer edgeConn.Close() + + handler := h2mux.MuxedStreamFunc(func(*h2mux.MuxedStream) error { + // This callback is invoked by h2mux when the edge initiates a stream. + return nil // noop + }) + muxerConfig := s.config.muxerConfig(handler) + muxerConfig.Logger = muxerConfig.Logger.WithField("subsystem", subsystemRefreshAuth) + muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams) + if err != nil { + return nil, err + } + go muxer.Serve(ctx) + defer func() { + // If we don't wait for the muxer shutdown here, edgeConn.Close() runs before the muxer connections are done, + // and the user sees log noise: "error writing data", "connection closed unexpectedly" + <-muxer.Shutdown() + }() + + tunnelServer, err := connection.NewRPCClient(ctx, muxer, s.logger.WithField("subsystem", subsystemRefreshAuth), openStreamTimeout) + if err != nil { + return nil, err + } + defer tunnelServer.Close() + + const arbitraryConnectionID = uint8(0) + registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID) + registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts) + authResponse, err := tunnelServer.Authenticate( + ctx, + s.config.OriginCert, + s.config.Hostname, + registrationOptions, + ) + if err != nil { + return nil, err + } + return authResponse.Outcome(), nil +} diff --git a/origin/supervisor_test.go b/origin/supervisor_test.go new file mode 100644 index 00000000..559c0400 --- /dev/null +++ b/origin/supervisor_test.go @@ -0,0 +1,128 @@ +package origin + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/google/uuid" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" +) + +func TestRefreshAuthBackoff(t *testing.T) { + logger := logrus.New() + logger.Level = logrus.ErrorLevel + + var wait time.Duration + timeAfter = func(d time.Duration) <-chan time.Time { + wait = d + return time.After(d) + } + + s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New()) + backoff := &BackoffHandler{MaxRetries: 3} + auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) { + return nil, fmt.Errorf("authentication failure") + } + + // authentication failures should consume the backoff + for i := uint(0); i < backoff.MaxRetries; i++ { + retryChan, err := s.refreshAuth(context.Background(), backoff, auth) + assert.NoError(t, err) + assert.NotNil(t, retryChan) + assert.Equal(t, (1< 0 { + if len(ar.Jwt) > 0 { return NewAuthSuccess(ar.Jwt, ar.HoursUntilRefresh) } @@ -57,6 +57,10 @@ func NewAuthSuccess(jwt []byte, hoursUntilRefresh uint8) AuthSuccess { return AuthSuccess{jwt: jwt, hoursUntilRefresh: hoursUntilRefresh} } +func (ao AuthSuccess) JWT() []byte { + return ao.jwt +} + // RefreshAfter is how long cloudflared should wait before rerunning Authenticate. func (ao AuthSuccess) RefreshAfter() time.Duration { return hoursToTime(ao.hoursUntilRefresh) @@ -81,6 +85,10 @@ func NewAuthFail(err error) AuthFail { return AuthFail{err: err} } +func (ao AuthFail) Error() string { + return ao.err.Error() +} + // Serialize into an AuthenticateResponse which can be sent via Capnp func (ao AuthFail) Serialize() AuthenticateResponse { return AuthenticateResponse{ @@ -100,6 +108,10 @@ func NewAuthUnknown(err error, hoursUntilRefresh uint8) AuthUnknown { return AuthUnknown{err: err, hoursUntilRefresh: hoursUntilRefresh} } +func (ao AuthUnknown) Error() string { + return ao.err.Error() +} + // RefreshAfter is how long cloudflared should wait before rerunning Authenticate. func (ao AuthUnknown) RefreshAfter() time.Duration { return hoursToTime(ao.hoursUntilRefresh) @@ -117,4 +129,4 @@ func (ao AuthUnknown) isAuthOutcome() {} func hoursToTime(hours uint8) time.Duration { return time.Duration(hours) * time.Hour -} +} \ No newline at end of file