TUN-2555: origin/supervisor.go calls Authenticate
This commit is contained in:
		
							parent
							
								
									b499c0fdba
								
							
						
					
					
						commit
						5e7ca14412
					
				|  | @ -977,6 +977,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag { | ||||||
| 			EnvVars: []string{"TUNNEL_INTENT"}, | 			EnvVars: []string{"TUNNEL_INTENT"}, | ||||||
| 			Hidden:  true, | 			Hidden:  true, | ||||||
| 		}), | 		}), | ||||||
|  | 		altsrc.NewBoolFlag(&cli.BoolFlag{ | ||||||
|  | 			Name:    "use-reconnect-token", | ||||||
|  | 			Usage:   "Test reestablishing connections with the new 'reconnect token' flow.", | ||||||
|  | 			EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"}, | ||||||
|  | 			Hidden:  true, | ||||||
|  | 		}), | ||||||
| 		altsrc.NewDurationFlag(&cli.DurationFlag{ | 		altsrc.NewDurationFlag(&cli.DurationFlag{ | ||||||
| 			Name:    "dial-edge-timeout", | 			Name:    "dial-edge-timeout", | ||||||
| 			Usage:   "Maximum wait time to set up a connection with the edge", | 			Usage:   "Maximum wait time to set up a connection with the edge", | ||||||
|  | @ -1044,7 +1050,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag { | ||||||
| 			Usage:   "Absolute path of directory to save SSH host keys in", | 			Usage:   "Absolute path of directory to save SSH host keys in", | ||||||
| 			EnvVars: []string{"HOST_KEY_PATH"}, | 			EnvVars: []string{"HOST_KEY_PATH"}, | ||||||
| 			Hidden:  true, | 			Hidden:  true, | ||||||
| 
 |  | ||||||
| 		}), | 		}), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | @ -275,6 +275,7 @@ func prepareTunnelConfig( | ||||||
| 		TlsConfig:            toEdgeTLSConfig, | 		TlsConfig:            toEdgeTLSConfig, | ||||||
| 		TransportLogger:      transportLogger, | 		TransportLogger:      transportLogger, | ||||||
| 		UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"), | 		UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"), | ||||||
|  | 		UseReconnectToken:    c.Bool("use-reconnect-token"), | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -92,3 +92,8 @@ func (b BackoffHandler) GetBaseTime() time.Duration { | ||||||
| 	} | 	} | ||||||
| 	return b.BaseTime | 	return b.BaseTime | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // Retries returns the number of retries consumed so far.
 | ||||||
|  | func (b *BackoffHandler) Retries() int { | ||||||
|  | 	return int(b.retries) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -2,15 +2,20 @@ package origin | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"math/rand" | ||||||
| 	"net" | 	"net" | ||||||
|  | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/google/uuid" | 	"github.com/google/uuid" | ||||||
| 	"github.com/sirupsen/logrus" | 	"github.com/sirupsen/logrus" | ||||||
| 
 | 
 | ||||||
| 	"github.com/cloudflare/cloudflared/connection" | 	"github.com/cloudflare/cloudflared/connection" | ||||||
|  | 	"github.com/cloudflare/cloudflared/h2mux" | ||||||
| 	"github.com/cloudflare/cloudflared/signal" | 	"github.com/cloudflare/cloudflared/signal" | ||||||
|  | 	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  | @ -20,11 +25,23 @@ const ( | ||||||
| 	resolveTTL = time.Hour | 	resolveTTL = time.Hour | ||||||
| 	// Interval between registering new tunnels
 | 	// Interval between registering new tunnels
 | ||||||
| 	registrationInterval = time.Second | 	registrationInterval = time.Second | ||||||
|  | 
 | ||||||
|  | 	subsystemRefreshAuth = "refresh_auth" | ||||||
|  | 	// Maximum exponent for 'Authenticate' exponential backoff
 | ||||||
|  | 	refreshAuthMaxBackoff = 10 | ||||||
|  | 	// Waiting time before retrying a failed 'Authenticate' connection
 | ||||||
|  | 	refreshAuthRetryDuration = time.Second * 10 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | var ( | ||||||
|  | 	errJWTUnset         = errors.New("JWT unset") | ||||||
|  | 	errEventDigestUnset = errors.New("event digest unset") | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Supervisor struct { | type Supervisor struct { | ||||||
| 	config  *TunnelConfig | 	cloudflaredUUID uuid.UUID | ||||||
| 	edgeIPs []*net.TCPAddr | 	config          *TunnelConfig | ||||||
|  | 	edgeIPs         []*net.TCPAddr | ||||||
| 	// nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
 | 	// nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
 | ||||||
| 	nextUnusedEdgeIP  int | 	nextUnusedEdgeIP  int | ||||||
| 	lastResolve       time.Time | 	lastResolve       time.Time | ||||||
|  | @ -37,6 +54,12 @@ type Supervisor struct { | ||||||
| 	nextConnectedSignal chan struct{} | 	nextConnectedSignal chan struct{} | ||||||
| 
 | 
 | ||||||
| 	logger *logrus.Entry | 	logger *logrus.Entry | ||||||
|  | 
 | ||||||
|  | 	jwtLock *sync.RWMutex | ||||||
|  | 	jwt     []byte | ||||||
|  | 
 | ||||||
|  | 	eventDigestLock *sync.RWMutex | ||||||
|  | 	eventDigest     []byte | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type resolveResult struct { | type resolveResult struct { | ||||||
|  | @ -49,18 +72,21 @@ type tunnelError struct { | ||||||
| 	err   error | 	err   error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewSupervisor(config *TunnelConfig) *Supervisor { | func NewSupervisor(config *TunnelConfig, u uuid.UUID) *Supervisor { | ||||||
| 	return &Supervisor{ | 	return &Supervisor{ | ||||||
|  | 		cloudflaredUUID:   u, | ||||||
| 		config:            config, | 		config:            config, | ||||||
| 		tunnelErrors:      make(chan tunnelError), | 		tunnelErrors:      make(chan tunnelError), | ||||||
| 		tunnelsConnecting: map[int]chan struct{}{}, | 		tunnelsConnecting: map[int]chan struct{}{}, | ||||||
| 		logger:            config.Logger.WithField("subsystem", "supervisor"), | 		logger:            config.Logger.WithField("subsystem", "supervisor"), | ||||||
|  | 		jwtLock:           &sync.RWMutex{}, | ||||||
|  | 		eventDigestLock:   &sync.RWMutex{}, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error { | func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error { | ||||||
| 	logger := s.config.Logger | 	logger := s.config.Logger | ||||||
| 	if err := s.initialize(ctx, connectedSignal, u); err != nil { | 	if err := s.initialize(ctx, connectedSignal); err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	var tunnelsWaiting []int | 	var tunnelsWaiting []int | ||||||
|  | @ -68,6 +94,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u | ||||||
| 	var backoffTimer <-chan time.Time | 	var backoffTimer <-chan time.Time | ||||||
| 	tunnelsActive := s.config.HAConnections | 	tunnelsActive := s.config.HAConnections | ||||||
| 
 | 
 | ||||||
|  | 	refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true} | ||||||
|  | 	var refreshAuthBackoffTimer <-chan time.Time | ||||||
|  | 	if s.config.UseReconnectToken { | ||||||
|  | 		refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	for { | 	for { | ||||||
| 		select { | 		select { | ||||||
| 		// Context cancelled
 | 		// Context cancelled
 | ||||||
|  | @ -103,10 +135,20 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u | ||||||
| 		case <-backoffTimer: | 		case <-backoffTimer: | ||||||
| 			backoffTimer = nil | 			backoffTimer = nil | ||||||
| 			for _, index := range tunnelsWaiting { | 			for _, index := range tunnelsWaiting { | ||||||
| 				go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u) | 				go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index)) | ||||||
| 			} | 			} | ||||||
| 			tunnelsActive += len(tunnelsWaiting) | 			tunnelsActive += len(tunnelsWaiting) | ||||||
| 			tunnelsWaiting = nil | 			tunnelsWaiting = nil | ||||||
|  | 		// Time to call Authenticate
 | ||||||
|  | 		case <-refreshAuthBackoffTimer: | ||||||
|  | 			newTimer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate) | ||||||
|  | 			if err != nil { | ||||||
|  | 				logger.WithError(err).Error("Authentication failed") | ||||||
|  | 				// Permanent failure. Leave the `select` without setting the
 | ||||||
|  | 				// channel to be non-null, so we'll never hit this case of the `select` again.
 | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			refreshAuthBackoffTimer = newTimer | ||||||
| 		// Tunnel successfully connected
 | 		// Tunnel successfully connected
 | ||||||
| 		case <-s.nextConnectedSignal: | 		case <-s.nextConnectedSignal: | ||||||
| 			if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 { | 			if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 { | ||||||
|  | @ -127,7 +169,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error { | func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error { | ||||||
| 	logger := s.logger | 	logger := s.logger | ||||||
| 
 | 
 | ||||||
| 	edgeIPs, err := s.resolveEdgeIPs() | 	edgeIPs, err := s.resolveEdgeIPs() | ||||||
|  | @ -144,12 +186,12 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig | ||||||
| 	s.lastResolve = time.Now() | 	s.lastResolve = time.Now() | ||||||
| 	// check entitlement and version too old error before attempting to register more tunnels
 | 	// check entitlement and version too old error before attempting to register more tunnels
 | ||||||
| 	s.nextUnusedEdgeIP = s.config.HAConnections | 	s.nextUnusedEdgeIP = s.config.HAConnections | ||||||
| 	go s.startFirstTunnel(ctx, connectedSignal, u) | 	go s.startFirstTunnel(ctx, connectedSignal) | ||||||
| 	select { | 	select { | ||||||
| 	case <-ctx.Done(): | 	case <-ctx.Done(): | ||||||
| 		<-s.tunnelErrors | 		<-s.tunnelErrors | ||||||
| 		// Error can't be nil. A nil error signals that initialization succeed
 | 		// Error can't be nil. A nil error signals that initialization succeed
 | ||||||
| 		return fmt.Errorf("context was canceled") | 		return ctx.Err() | ||||||
| 	case tunnelError := <-s.tunnelErrors: | 	case tunnelError := <-s.tunnelErrors: | ||||||
| 		return tunnelError.err | 		return tunnelError.err | ||||||
| 	case <-connectedSignal.Wait(): | 	case <-connectedSignal.Wait(): | ||||||
|  | @ -157,7 +199,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig | ||||||
| 	// At least one successful connection, so start the rest
 | 	// At least one successful connection, so start the rest
 | ||||||
| 	for i := 1; i < s.config.HAConnections; i++ { | 	for i := 1; i < s.config.HAConnections; i++ { | ||||||
| 		ch := signal.New(make(chan struct{})) | 		ch := signal.New(make(chan struct{})) | ||||||
| 		go s.startTunnel(ctx, i, ch, u) | 		go s.startTunnel(ctx, i, ch) | ||||||
| 		time.Sleep(registrationInterval) | 		time.Sleep(registrationInterval) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
|  | @ -165,8 +207,8 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig | ||||||
| 
 | 
 | ||||||
| // startTunnel starts the first tunnel connection. The resulting error will be sent on
 | // startTunnel starts the first tunnel connection. The resulting error will be sent on
 | ||||||
| // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
 | // s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
 | ||||||
| func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) { | func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) { | ||||||
| 	err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) | 	err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID) | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		s.tunnelErrors <- tunnelError{index: 0, err: err} | 		s.tunnelErrors <- tunnelError{index: 0, err: err} | ||||||
| 	}() | 	}() | ||||||
|  | @ -187,14 +229,14 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign | ||||||
| 		default: | 		default: | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u) | 		err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // startTunnel starts a new tunnel connection. The resulting error will be sent on
 | // startTunnel starts a new tunnel connection. The resulting error will be sent on
 | ||||||
| // s.tunnelErrors.
 | // s.tunnelErrors.
 | ||||||
| func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) { | func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) { | ||||||
| 	err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u) | 	err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID) | ||||||
| 	s.tunnelErrors <- tunnelError{index: index, err: err} | 	s.tunnelErrors <- tunnelError{index: index, err: err} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -252,3 +294,109 @@ func (s *Supervisor) replaceEdgeIP(badIPIndex int) { | ||||||
| 	s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP] | 	s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP] | ||||||
| 	s.nextUnusedEdgeIP++ | 	s.nextUnusedEdgeIP++ | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func (s *Supervisor) ReconnectToken() ([]byte, error) { | ||||||
|  | 	s.jwtLock.RLock() | ||||||
|  | 	defer s.jwtLock.RUnlock() | ||||||
|  | 	if s.jwt == nil { | ||||||
|  | 		return nil, errJWTUnset | ||||||
|  | 	} | ||||||
|  | 	return s.jwt, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Supervisor) SetReconnectToken(jwt []byte) { | ||||||
|  | 	s.jwtLock.Lock() | ||||||
|  | 	defer s.jwtLock.Unlock() | ||||||
|  | 	s.jwt = jwt | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Supervisor) EventDigest() ([]byte, error) { | ||||||
|  | 	s.eventDigestLock.RLock() | ||||||
|  | 	defer s.eventDigestLock.RUnlock() | ||||||
|  | 	if s.eventDigest == nil { | ||||||
|  | 		return nil, errEventDigestUnset | ||||||
|  | 	} | ||||||
|  | 	return s.eventDigest, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Supervisor) SetEventDigest(eventDigest []byte) { | ||||||
|  | 	s.eventDigestLock.Lock() | ||||||
|  | 	defer s.eventDigestLock.Unlock() | ||||||
|  | 	s.eventDigest = eventDigest | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Supervisor) refreshAuth( | ||||||
|  | 	ctx context.Context, | ||||||
|  | 	backoff *BackoffHandler, | ||||||
|  | 	authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error), | ||||||
|  | ) (retryTimer <-chan time.Time, err error) { | ||||||
|  | 	logger := s.config.Logger.WithField("subsystem", subsystemRefreshAuth) | ||||||
|  | 	authOutcome, err := authenticate(ctx, backoff.Retries()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if duration, ok := backoff.GetBackoffDuration(ctx); ok { | ||||||
|  | 			logger.WithError(err).Warnf("Retrying in %v", duration) | ||||||
|  | 			return backoff.BackoffTimer(), nil | ||||||
|  | 		} | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	// clear backoff timer
 | ||||||
|  | 	backoff.SetGracePeriod() | ||||||
|  | 
 | ||||||
|  | 	switch outcome := authOutcome.(type) { | ||||||
|  | 	case tunnelpogs.AuthSuccess: | ||||||
|  | 		s.SetReconnectToken(outcome.JWT()) | ||||||
|  | 		return timeAfter(outcome.RefreshAfter()), nil | ||||||
|  | 	case tunnelpogs.AuthUnknown: | ||||||
|  | 		return timeAfter(outcome.RefreshAfter()), nil | ||||||
|  | 	case tunnelpogs.AuthFail: | ||||||
|  | 		return nil, outcome | ||||||
|  | 	default: | ||||||
|  | 		return nil, fmt.Errorf("Unexpected outcome type %T", authOutcome) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) { | ||||||
|  | 	arbitraryEdgeIP := s.getEdgeIP(rand.Int()) | ||||||
|  | 	edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	defer edgeConn.Close() | ||||||
|  | 
 | ||||||
|  | 	handler := h2mux.MuxedStreamFunc(func(*h2mux.MuxedStream) error { | ||||||
|  | 		// This callback is invoked by h2mux when the edge initiates a stream.
 | ||||||
|  | 		return nil // noop
 | ||||||
|  | 	}) | ||||||
|  | 	muxerConfig := s.config.muxerConfig(handler) | ||||||
|  | 	muxerConfig.Logger = muxerConfig.Logger.WithField("subsystem", subsystemRefreshAuth) | ||||||
|  | 	muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	go muxer.Serve(ctx) | ||||||
|  | 	defer func() { | ||||||
|  | 		// If we don't wait for the muxer shutdown here, edgeConn.Close() runs before the muxer connections are done,
 | ||||||
|  | 		// and the user sees log noise: "error writing data", "connection closed unexpectedly"
 | ||||||
|  | 		<-muxer.Shutdown() | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	tunnelServer, err := connection.NewRPCClient(ctx, muxer, s.logger.WithField("subsystem", subsystemRefreshAuth), openStreamTimeout) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	defer tunnelServer.Close() | ||||||
|  | 
 | ||||||
|  | 	const arbitraryConnectionID = uint8(0) | ||||||
|  | 	registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID) | ||||||
|  | 	registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts) | ||||||
|  | 	authResponse, err := tunnelServer.Authenticate( | ||||||
|  | 		ctx, | ||||||
|  | 		s.config.OriginCert, | ||||||
|  | 		s.config.Hostname, | ||||||
|  | 		registrationOptions, | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return authResponse.Outcome(), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,128 @@ | ||||||
|  | package origin | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/google/uuid" | ||||||
|  | 	"github.com/sirupsen/logrus" | ||||||
|  | 	"github.com/stretchr/testify/assert" | ||||||
|  | 
 | ||||||
|  | 	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestRefreshAuthBackoff(t *testing.T) { | ||||||
|  | 	logger := logrus.New() | ||||||
|  | 	logger.Level = logrus.ErrorLevel | ||||||
|  | 
 | ||||||
|  | 	var wait time.Duration | ||||||
|  | 	timeAfter = func(d time.Duration) <-chan time.Time { | ||||||
|  | 		wait = d | ||||||
|  | 		return time.After(d) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New()) | ||||||
|  | 	backoff := &BackoffHandler{MaxRetries: 3} | ||||||
|  | 	auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) { | ||||||
|  | 		return nil, fmt.Errorf("authentication failure") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// authentication failures should consume the backoff
 | ||||||
|  | 	for i := uint(0); i < backoff.MaxRetries; i++ { | ||||||
|  | 		retryChan, err := s.refreshAuth(context.Background(), backoff, auth) | ||||||
|  | 		assert.NoError(t, err) | ||||||
|  | 		assert.NotNil(t, retryChan) | ||||||
|  | 		assert.Equal(t, (1<<i)*time.Second, wait) | ||||||
|  | 	} | ||||||
|  | 	retryChan, err := s.refreshAuth(context.Background(), backoff, auth) | ||||||
|  | 	assert.Error(t, err) | ||||||
|  | 	assert.Nil(t, retryChan) | ||||||
|  | 
 | ||||||
|  | 	// now we actually make contact with the remote server
 | ||||||
|  | 	_, _ = s.refreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) { | ||||||
|  | 		return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	// The backoff timer should have been reset. To confirm this, make timeNow
 | ||||||
|  | 	// return a value after the backoff timer's grace period
 | ||||||
|  | 	timeNow = func() time.Time { | ||||||
|  | 		expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries) | ||||||
|  | 		return time.Now().Add(expectedGracePeriod * 2) | ||||||
|  | 	} | ||||||
|  | 	_, ok := backoff.GetBackoffDuration(context.Background()) | ||||||
|  | 	assert.True(t, ok) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestRefreshAuthSuccess(t *testing.T) { | ||||||
|  | 	logger := logrus.New() | ||||||
|  | 	logger.Level = logrus.ErrorLevel | ||||||
|  | 
 | ||||||
|  | 	var wait time.Duration | ||||||
|  | 	timeAfter = func(d time.Duration) <-chan time.Time { | ||||||
|  | 		wait = d | ||||||
|  | 		return time.After(d) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New()) | ||||||
|  | 	backoff := &BackoffHandler{MaxRetries: 3} | ||||||
|  | 	auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) { | ||||||
|  | 		return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	retryChan, err := s.refreshAuth(context.Background(), backoff, auth) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.NotNil(t, retryChan) | ||||||
|  | 	assert.Equal(t, 19*time.Hour, wait) | ||||||
|  | 
 | ||||||
|  | 	token, err := s.ReconnectToken() | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.Equal(t, []byte("jwt"), token) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestRefreshAuthUnknown(t *testing.T) { | ||||||
|  | 	logger := logrus.New() | ||||||
|  | 	logger.Level = logrus.ErrorLevel | ||||||
|  | 
 | ||||||
|  | 	var wait time.Duration | ||||||
|  | 	timeAfter = func(d time.Duration) <-chan time.Time { | ||||||
|  | 		wait = d | ||||||
|  | 		return time.After(d) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New()) | ||||||
|  | 	backoff := &BackoffHandler{MaxRetries: 3} | ||||||
|  | 	auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) { | ||||||
|  | 		return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	retryChan, err := s.refreshAuth(context.Background(), backoff, auth) | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	assert.NotNil(t, retryChan) | ||||||
|  | 	assert.Equal(t, 19*time.Hour, wait) | ||||||
|  | 
 | ||||||
|  | 	token, err := s.ReconnectToken() | ||||||
|  | 	assert.Equal(t, errJWTUnset, err) | ||||||
|  | 	assert.Nil(t, token) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestRefreshAuthFail(t *testing.T) { | ||||||
|  | 	logger := logrus.New() | ||||||
|  | 	logger.Level = logrus.ErrorLevel | ||||||
|  | 
 | ||||||
|  | 	s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New()) | ||||||
|  | 	backoff := &BackoffHandler{MaxRetries: 3} | ||||||
|  | 	auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) { | ||||||
|  | 		return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	retryChan, err := s.refreshAuth(context.Background(), backoff, auth) | ||||||
|  | 	assert.Error(t, err) | ||||||
|  | 	assert.Nil(t, retryChan) | ||||||
|  | 
 | ||||||
|  | 	token, err := s.ReconnectToken() | ||||||
|  | 	assert.Equal(t, errJWTUnset, err) | ||||||
|  | 	assert.Nil(t, token) | ||||||
|  | } | ||||||
|  | @ -34,6 +34,7 @@ import ( | ||||||
| const ( | const ( | ||||||
| 	dialTimeout              = 15 * time.Second | 	dialTimeout              = 15 * time.Second | ||||||
| 	openStreamTimeout        = 30 * time.Second | 	openStreamTimeout        = 30 * time.Second | ||||||
|  | 	muxerTimeout             = 5 * time.Second | ||||||
| 	lbProbeUserAgentPrefix   = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" | 	lbProbeUserAgentPrefix   = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" | ||||||
| 	TagHeaderNamePrefix      = "Cf-Warp-Tag-" | 	TagHeaderNamePrefix      = "Cf-Warp-Tag-" | ||||||
| 	DuplicateConnectionError = "EDUPCONN" | 	DuplicateConnectionError = "EDUPCONN" | ||||||
|  | @ -72,6 +73,9 @@ type TunnelConfig struct { | ||||||
| 	WSGI                 bool | 	WSGI                 bool | ||||||
| 	// OriginUrl may not be used if a user specifies a unix socket.
 | 	// OriginUrl may not be used if a user specifies a unix socket.
 | ||||||
| 	OriginUrl string | 	OriginUrl string | ||||||
|  | 
 | ||||||
|  | 	// feature-flag to use new edge reconnect tokens
 | ||||||
|  | 	UseReconnectToken bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type dupConnRegisterTunnelError struct{} | type dupConnRegisterTunnelError struct{} | ||||||
|  | @ -110,6 +114,18 @@ func (e clientRegisterTunnelError) Error() string { | ||||||
| 	return e.cause.Error() | 	return e.cause.Error() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (c *TunnelConfig) muxerConfig(handler h2mux.MuxedStreamHandler) h2mux.MuxerConfig { | ||||||
|  | 	return h2mux.MuxerConfig{ | ||||||
|  | 		Timeout:            muxerTimeout, | ||||||
|  | 		Handler:            handler, | ||||||
|  | 		IsClient:           true, | ||||||
|  | 		HeartbeatInterval:  c.HeartbeatInterval, | ||||||
|  | 		MaxHeartbeats:      c.MaxHeartbeats, | ||||||
|  | 		Logger:             c.TransportLogger.WithFields(log.Fields{}), | ||||||
|  | 		CompressionQuality: h2mux.CompressionSetting(c.CompressionQuality), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { | func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { | ||||||
| 	policy := tunnelrpc.ExistingTunnelPolicy_balance | 	policy := tunnelrpc.ExistingTunnelPolicy_balance | ||||||
| 	if c.HAConnections <= 1 && c.LBPool == "" { | 	if c.HAConnections <= 1 && c.LBPool == "" { | ||||||
|  | @ -132,7 +148,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error { | func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error { | ||||||
| 	return NewSupervisor(config).Run(ctx, connectedSignal, cloudflaredID) | 	return NewSupervisor(config, cloudflaredID).Run(ctx, connectedSignal) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func ServeTunnelLoop(ctx context.Context, | func ServeTunnelLoop(ctx context.Context, | ||||||
|  | @ -448,15 +464,7 @@ func NewTunnelHandler(ctx context.Context, | ||||||
| 	} | 	} | ||||||
| 	// Establish a muxed connection with the edge
 | 	// Establish a muxed connection with the edge
 | ||||||
| 	// Client mux handshake with agent server
 | 	// Client mux handshake with agent server
 | ||||||
| 	h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{ | 	h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, config.muxerConfig(h), h.metrics.activeStreams) | ||||||
| 		Timeout:            5 * time.Second, |  | ||||||
| 		Handler:            h, |  | ||||||
| 		IsClient:           true, |  | ||||||
| 		HeartbeatInterval:  config.HeartbeatInterval, |  | ||||||
| 		MaxHeartbeats:      config.MaxHeartbeats, |  | ||||||
| 		Logger:             config.TransportLogger.WithFields(log.Fields{}), |  | ||||||
| 		CompressionQuality: h2mux.CompressionSetting(config.CompressionQuality), |  | ||||||
| 	}, h.metrics.activeStreams) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, "", errors.Wrap(err, "Handshake with edge error") | 		return nil, "", errors.Wrap(err, "Handshake with edge error") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -1,7 +1,7 @@ | ||||||
| package pogs | package pogs | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"errors" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -18,20 +18,20 @@ type AuthenticateResponse struct { | ||||||
| 
 | 
 | ||||||
| // Outcome turns the deserialized response of Authenticate into a programmer-friendly sum type.
 | // Outcome turns the deserialized response of Authenticate into a programmer-friendly sum type.
 | ||||||
| func (ar AuthenticateResponse) Outcome() AuthOutcome { | func (ar AuthenticateResponse) Outcome() AuthOutcome { | ||||||
| 	// If there was a network error, then cloudflared should retry later,
 |  | ||||||
| 	// because origintunneld couldn't prove whether auth was correct or not.
 |  | ||||||
| 	if ar.RetryableErr != "" { |  | ||||||
| 		return NewAuthUnknown(fmt.Errorf(ar.RetryableErr), ar.HoursUntilRefresh) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// If the user's authentication was unsuccessful, the server will return an error explaining why.
 | 	// If the user's authentication was unsuccessful, the server will return an error explaining why.
 | ||||||
| 	// cloudflared should fatal with this error.
 | 	// cloudflared should fatal with this error.
 | ||||||
| 	if ar.PermanentErr != "" { | 	if ar.PermanentErr != "" { | ||||||
| 		return NewAuthFail(fmt.Errorf(ar.PermanentErr)) | 		return NewAuthFail(errors.New(ar.PermanentErr)) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// If there was a network error, then cloudflared should retry later,
 | ||||||
|  | 	// because origintunneld couldn't prove whether auth was correct or not.
 | ||||||
|  | 	if ar.RetryableErr != "" { | ||||||
|  | 		return NewAuthUnknown(errors.New(ar.RetryableErr), ar.HoursUntilRefresh) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// If auth succeeded, return the token and refresh it when instructed.
 | 	// If auth succeeded, return the token and refresh it when instructed.
 | ||||||
| 	if ar.PermanentErr == "" && len(ar.Jwt) > 0 { | 	if len(ar.Jwt) > 0 { | ||||||
| 		return NewAuthSuccess(ar.Jwt, ar.HoursUntilRefresh) | 		return NewAuthSuccess(ar.Jwt, ar.HoursUntilRefresh) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -57,6 +57,10 @@ func NewAuthSuccess(jwt []byte, hoursUntilRefresh uint8) AuthSuccess { | ||||||
| 	return AuthSuccess{jwt: jwt, hoursUntilRefresh: hoursUntilRefresh} | 	return AuthSuccess{jwt: jwt, hoursUntilRefresh: hoursUntilRefresh} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (ao AuthSuccess) JWT() []byte { | ||||||
|  | 	return ao.jwt | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
 | // RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
 | ||||||
| func (ao AuthSuccess) RefreshAfter() time.Duration { | func (ao AuthSuccess) RefreshAfter() time.Duration { | ||||||
| 	return hoursToTime(ao.hoursUntilRefresh) | 	return hoursToTime(ao.hoursUntilRefresh) | ||||||
|  | @ -81,6 +85,10 @@ func NewAuthFail(err error) AuthFail { | ||||||
| 	return AuthFail{err: err} | 	return AuthFail{err: err} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (ao AuthFail) Error() string { | ||||||
|  | 	return ao.err.Error() | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // Serialize into an AuthenticateResponse which can be sent via Capnp
 | // Serialize into an AuthenticateResponse which can be sent via Capnp
 | ||||||
| func (ao AuthFail) Serialize() AuthenticateResponse { | func (ao AuthFail) Serialize() AuthenticateResponse { | ||||||
| 	return AuthenticateResponse{ | 	return AuthenticateResponse{ | ||||||
|  | @ -100,6 +108,10 @@ func NewAuthUnknown(err error, hoursUntilRefresh uint8) AuthUnknown { | ||||||
| 	return AuthUnknown{err: err, hoursUntilRefresh: hoursUntilRefresh} | 	return AuthUnknown{err: err, hoursUntilRefresh: hoursUntilRefresh} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (ao AuthUnknown) Error() string { | ||||||
|  | 	return ao.err.Error() | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
 | // RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
 | ||||||
| func (ao AuthUnknown) RefreshAfter() time.Duration { | func (ao AuthUnknown) RefreshAfter() time.Duration { | ||||||
| 	return hoursToTime(ao.hoursUntilRefresh) | 	return hoursToTime(ao.hoursUntilRefresh) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue