TUN-2555: origin/supervisor.go calls Authenticate

This commit is contained in:
Nick Vollmar 2019-12-04 11:22:08 -06:00
parent b499c0fdba
commit 5e7ca14412
7 changed files with 344 additions and 37 deletions

View File

@ -977,6 +977,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
EnvVars: []string{"TUNNEL_INTENT"},
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "use-reconnect-token",
Usage: "Test reestablishing connections with the new 'reconnect token' flow.",
EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"},
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "dial-edge-timeout",
Usage: "Maximum wait time to set up a connection with the edge",
@ -1044,7 +1050,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Usage: "Absolute path of directory to save SSH host keys in",
EnvVars: []string{"HOST_KEY_PATH"},
Hidden: true,
}),
}
}
}

View File

@ -275,6 +275,7 @@ func prepareTunnelConfig(
TlsConfig: toEdgeTLSConfig,
TransportLogger: transportLogger,
UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"),
UseReconnectToken: c.Bool("use-reconnect-token"),
}, nil
}

View File

@ -92,3 +92,8 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
}
return b.BaseTime
}
// Retries returns the number of retries consumed so far, i.e. the handler's
// internal retries counter converted to int.
func (b *BackoffHandler) Retries() int {
return int(b.retries)
}

View File

@ -2,15 +2,20 @@ package origin
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"sync"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/signal"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
@ -20,11 +25,23 @@ const (
resolveTTL = time.Hour
// Interval between registering new tunnels
registrationInterval = time.Second
subsystemRefreshAuth = "refresh_auth"
// Maximum exponent for 'Authenticate' exponential backoff
refreshAuthMaxBackoff = 10
// Waiting time before retrying a failed 'Authenticate' connection
refreshAuthRetryDuration = time.Second * 10
)
// Sentinel errors reported by the Supervisor accessors when the requested
// value has not been stored yet.
var (
// errJWTUnset is returned by Supervisor.ReconnectToken while the stored JWT is nil.
errJWTUnset = errors.New("JWT unset")
// errEventDigestUnset is returned by Supervisor.EventDigest while the stored event digest is nil.
errEventDigestUnset = errors.New("event digest unset")
)
type Supervisor struct {
config *TunnelConfig
edgeIPs []*net.TCPAddr
cloudflaredUUID uuid.UUID
config *TunnelConfig
edgeIPs []*net.TCPAddr
// nextUnusedEdgeIP is the index of the next addr in edgeIPs to try
nextUnusedEdgeIP int
lastResolve time.Time
@ -37,6 +54,12 @@ type Supervisor struct {
nextConnectedSignal chan struct{}
logger *logrus.Entry
jwtLock *sync.RWMutex
jwt []byte
eventDigestLock *sync.RWMutex
eventDigest []byte
}
type resolveResult struct {
@ -49,18 +72,21 @@ type tunnelError struct {
err error
}
func NewSupervisor(config *TunnelConfig) *Supervisor {
func NewSupervisor(config *TunnelConfig, u uuid.UUID) *Supervisor {
return &Supervisor{
cloudflaredUUID: u,
config: config,
tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{},
logger: config.Logger.WithField("subsystem", "supervisor"),
jwtLock: &sync.RWMutex{},
eventDigestLock: &sync.RWMutex{},
}
}
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
logger := s.config.Logger
if err := s.initialize(ctx, connectedSignal, u); err != nil {
if err := s.initialize(ctx, connectedSignal); err != nil {
return err
}
var tunnelsWaiting []int
@ -68,6 +94,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
var backoffTimer <-chan time.Time
tunnelsActive := s.config.HAConnections
refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
var refreshAuthBackoffTimer <-chan time.Time
if s.config.UseReconnectToken {
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
}
for {
select {
// Context cancelled
@ -103,10 +135,20 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
case <-backoffTimer:
backoffTimer = nil
for _, index := range tunnelsWaiting {
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u)
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
}
tunnelsActive += len(tunnelsWaiting)
tunnelsWaiting = nil
// Time to call Authenticate
case <-refreshAuthBackoffTimer:
newTimer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate)
if err != nil {
logger.WithError(err).Error("Authentication failed")
// Permanent failure. Leave the `select` without replacing the timer
// channel, so we'll never hit this case of the `select` again.
continue
}
refreshAuthBackoffTimer = newTimer
// Tunnel successfully connected
case <-s.nextConnectedSignal:
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
@ -127,7 +169,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
}
}
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
logger := s.logger
edgeIPs, err := s.resolveEdgeIPs()
@ -144,12 +186,12 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
s.lastResolve = time.Now()
// check entitlement and version too old error before attempting to register more tunnels
s.nextUnusedEdgeIP = s.config.HAConnections
go s.startFirstTunnel(ctx, connectedSignal, u)
go s.startFirstTunnel(ctx, connectedSignal)
select {
case <-ctx.Done():
<-s.tunnelErrors
// Error can't be nil. A nil error signals that initialization succeeded
return fmt.Errorf("context was canceled")
return ctx.Err()
case tunnelError := <-s.tunnelErrors:
return tunnelError.err
case <-connectedSignal.Wait():
@ -157,7 +199,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// At least one successful connection, so start the rest
for i := 1; i < s.config.HAConnections; i++ {
ch := signal.New(make(chan struct{}))
go s.startTunnel(ctx, i, ch, u)
go s.startTunnel(ctx, i, ch)
time.Sleep(registrationInterval)
}
return nil
@ -165,8 +207,8 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
// startFirstTunnel starts the first tunnel connection. The resulting error will be sent on
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeeds.
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
defer func() {
s.tunnelErrors <- tunnelError{index: 0, err: err}
}()
@ -187,14 +229,14 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
default:
return
}
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
}
}
// startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u)
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID)
s.tunnelErrors <- tunnelError{index: index, err: err}
}
@ -252,3 +294,109 @@ func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
s.nextUnusedEdgeIP++
}
// ReconnectToken returns the JWT currently held by the supervisor, or
// errJWTUnset when none has been stored (the field is nil). The field is read
// under jwtLock's read lock; the returned slice is the stored slice itself,
// not a copy.
func (s *Supervisor) ReconnectToken() ([]byte, error) {
	s.jwtLock.RLock()
	token := s.jwt
	s.jwtLock.RUnlock()

	if token == nil {
		return nil, errJWTUnset
	}
	return token, nil
}
// SetReconnectToken stores jwt as the supervisor's reconnect token, replacing
// whatever was held before. The write is made while holding jwtLock.
func (s *Supervisor) SetReconnectToken(jwt []byte) {
	s.jwtLock.Lock()
	s.jwt = jwt
	s.jwtLock.Unlock()
}
// EventDigest returns the event digest currently held by the supervisor, or
// errEventDigestUnset when none has been stored (the field is nil). The field
// is read under eventDigestLock's read lock; the returned slice is the stored
// slice itself, not a copy.
func (s *Supervisor) EventDigest() ([]byte, error) {
	s.eventDigestLock.RLock()
	digest := s.eventDigest
	s.eventDigestLock.RUnlock()

	if digest == nil {
		return nil, errEventDigestUnset
	}
	return digest, nil
}
// SetEventDigest stores eventDigest on the supervisor, replacing whatever was
// held before. The write is made while holding eventDigestLock.
func (s *Supervisor) SetEventDigest(eventDigest []byte) {
	s.eventDigestLock.Lock()
	s.eventDigest = eventDigest
	s.eventDigestLock.Unlock()
}
// refreshAuth runs a single Authenticate attempt and tells the caller when the
// next attempt should happen.
//
// authenticate is called with the number of retries backoff has consumed so
// far. Its result is handled as follows:
//   - authenticate returns an error: if backoff still grants a retry, the error
//     is logged and backoff's timer is returned with a nil error; otherwise the
//     error itself is returned with a nil timer.
//   - AuthSuccess: the JWT is stored via SetReconnectToken and a timer for
//     outcome.RefreshAfter() is returned.
//   - AuthUnknown: the outcome's error is logged as a warning and a timer for
//     outcome.RefreshAfter() is returned.
//   - AuthFail: the outcome is returned as the error, with a nil timer.
//   - any other outcome type: an error naming the unexpected type.
//
// Whenever err is non-nil, retryTimer is nil.
func (s *Supervisor) refreshAuth(
	ctx context.Context,
	backoff *BackoffHandler,
	authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
) (retryTimer <-chan time.Time, err error) {
	logger := s.config.Logger.WithField("subsystem", subsystemRefreshAuth)

	authOutcome, err := authenticate(ctx, backoff.Retries())
	if err != nil {
		if duration, ok := backoff.GetBackoffDuration(ctx); ok {
			logger.WithError(err).Warnf("Retrying in %v", duration)
			return backoff.BackoffTimer(), nil
		}
		return nil, err
	}
	// The call itself went through, so clear the backoff timer.
	backoff.SetGracePeriod()

	switch outcome := authOutcome.(type) {
	case tunnelpogs.AuthSuccess:
		s.SetReconnectToken(outcome.JWT())
		return timeAfter(outcome.RefreshAfter()), nil
	case tunnelpogs.AuthUnknown:
		// The server could not decide; surface its reason instead of dropping
		// it silently, then try again when instructed.
		duration := outcome.RefreshAfter()
		logger.WithError(outcome).Warnf("Auth outcome unknown. Retrying in %v", duration)
		return timeAfter(duration), nil
	case tunnelpogs.AuthFail:
		return nil, outcome
	default:
		return nil, fmt.Errorf("Unexpected outcome type %T", authOutcome)
	}
}
// authenticate opens a dedicated connection to the edge, performs the h2mux
// handshake on it, and issues a single Authenticate RPC using the configured
// origin cert, hostname and registration options. It returns the outcome
// derived from the RPC response, or the first error hit while dialing,
// handshaking, creating the RPC client or making the call.
//
// numPreviousAttempts is forwarded to the edge in the registration options.
//
// Everything opened here is torn down before returning. Deferred calls run in
// reverse order: the RPC client is closed, then the muxer shutdown is awaited,
// then the edge connection is closed.
func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
// Any edge address will do for this call.
// NOTE(review): presumably getEdgeIP maps the random index into range — confirm.
arbitraryEdgeIP := s.getEdgeIP(rand.Int())
edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP)
if err != nil {
return nil, err
}
defer edgeConn.Close()
handler := h2mux.MuxedStreamFunc(func(*h2mux.MuxedStream) error {
// This callback is invoked by h2mux when the edge initiates a stream.
return nil // noop
})
// Same muxer settings as the config provides, with the logger tagged for this subsystem.
muxerConfig := s.config.muxerConfig(handler)
muxerConfig.Logger = muxerConfig.Logger.WithField("subsystem", subsystemRefreshAuth)
muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams)
if err != nil {
return nil, err
}
// Serve runs in the background; its return value is not inspected here.
go muxer.Serve(ctx)
defer func() {
// If we don't wait for the muxer shutdown here, edgeConn.Close() runs before the muxer connections are done,
// and the user sees log noise: "error writing data", "connection closed unexpectedly"
<-muxer.Shutdown()
}()
tunnelServer, err := connection.NewRPCClient(ctx, muxer, s.logger.WithField("subsystem", subsystemRefreshAuth), openStreamTimeout)
if err != nil {
return nil, err
}
defer tunnelServer.Close()
// This connection is not one of the numbered tunnel connections, so a fixed ID is used.
const arbitraryConnectionID = uint8(0)
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
// NOTE(review): the uint8 conversion wraps for values above 255 — confirm callers stay below that.
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
authResponse, err := tunnelServer.Authenticate(
ctx,
s.config.OriginCert,
s.config.Hostname,
registrationOptions,
)
if err != nil {
return nil, err
}
return authResponse.Outcome(), nil
}

128
origin/supervisor_test.go Normal file
View File

@ -0,0 +1,128 @@
package origin
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// TestRefreshAuthBackoff checks that failed authenticate calls consume the
// backoff (doubling the wait each time), that refreshAuth reports an error
// once the retries are exhausted, and that an attempt which reaches the server
// resets the backoff.
func TestRefreshAuthBackoff(t *testing.T) {
	logger := logrus.New()
	logger.Level = logrus.ErrorLevel

	// timeAfter and timeNow are package-level hooks. Put the originals back
	// when the test ends so the overrides below do not leak into other tests.
	origTimeAfter, origTimeNow := timeAfter, timeNow
	defer func() {
		timeAfter, timeNow = origTimeAfter, origTimeNow
	}()

	// Record the duration refreshAuth/backoff asks to wait for.
	var wait time.Duration
	timeAfter = func(d time.Duration) <-chan time.Time {
		wait = d
		return time.After(d)
	}

	s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
	backoff := &BackoffHandler{MaxRetries: 3}
	auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
		return nil, fmt.Errorf("authentication failure")
	}

	// authentication failures should consume the backoff
	for i := uint(0); i < backoff.MaxRetries; i++ {
		retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
		assert.NoError(t, err)
		assert.NotNil(t, retryChan)
		assert.Equal(t, (1<<i)*time.Second, wait)
	}
	retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
	assert.Error(t, err)
	assert.Nil(t, retryChan)

	// now we actually make contact with the remote server
	_, _ = s.refreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
		return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
	})

	// The backoff timer should have been reset. To confirm this, make timeNow
	// return a value after the backoff timer's grace period
	timeNow = func() time.Time {
		expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries)
		return time.Now().Add(expectedGracePeriod * 2)
	}
	_, ok := backoff.GetBackoffDuration(context.Background())
	assert.True(t, ok)
}
// TestRefreshAuthSuccess checks that an AuthSuccess outcome schedules the next
// refresh after the advertised number of hours and stores the returned JWT as
// the reconnect token.
func TestRefreshAuthSuccess(t *testing.T) {
	logger := logrus.New()
	logger.Level = logrus.ErrorLevel

	// timeAfter is a package-level hook. Put the original back when the test
	// ends so the override below does not leak into other tests.
	origTimeAfter := timeAfter
	defer func() { timeAfter = origTimeAfter }()

	// Record the duration refreshAuth asks to wait for.
	var wait time.Duration
	timeAfter = func(d time.Duration) <-chan time.Time {
		wait = d
		return time.After(d)
	}

	s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
	backoff := &BackoffHandler{MaxRetries: 3}
	auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
		return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
	}

	retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
	assert.NoError(t, err)
	assert.NotNil(t, retryChan)
	assert.Equal(t, 19*time.Hour, wait)

	token, err := s.ReconnectToken()
	assert.NoError(t, err)
	assert.Equal(t, []byte("jwt"), token)
}
// TestRefreshAuthUnknown checks that an AuthUnknown outcome schedules the next
// refresh after the advertised number of hours and leaves the reconnect token
// unset.
func TestRefreshAuthUnknown(t *testing.T) {
	logger := logrus.New()
	logger.Level = logrus.ErrorLevel

	// timeAfter is a package-level hook. Put the original back when the test
	// ends so the override below does not leak into other tests.
	origTimeAfter := timeAfter
	defer func() { timeAfter = origTimeAfter }()

	// Record the duration refreshAuth asks to wait for.
	var wait time.Duration
	timeAfter = func(d time.Duration) <-chan time.Time {
		wait = d
		return time.After(d)
	}

	s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
	backoff := &BackoffHandler{MaxRetries: 3}
	auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
		return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
	}

	retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
	assert.NoError(t, err)
	assert.NotNil(t, retryChan)
	assert.Equal(t, 19*time.Hour, wait)

	token, err := s.ReconnectToken()
	assert.Equal(t, errJWTUnset, err)
	assert.Nil(t, token)
}
// TestRefreshAuthFail checks that an AuthFail outcome is reported as an error
// with no retry timer, and that no reconnect token gets stored.
func TestRefreshAuthFail(t *testing.T) {
	quietLogger := logrus.New()
	quietLogger.Level = logrus.ErrorLevel

	supervisor := NewSupervisor(&TunnelConfig{Logger: quietLogger}, uuid.New())
	handler := &BackoffHandler{MaxRetries: 3}
	failingAuth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
		return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
	}

	timer, refreshErr := supervisor.refreshAuth(context.Background(), handler, failingAuth)
	assert.Error(t, refreshErr)
	assert.Nil(t, timer)

	jwt, tokenErr := supervisor.ReconnectToken()
	assert.Equal(t, errJWTUnset, tokenErr)
	assert.Nil(t, jwt)
}

View File

@ -34,6 +34,7 @@ import (
const (
dialTimeout = 15 * time.Second
openStreamTimeout = 30 * time.Second
muxerTimeout = 5 * time.Second
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
TagHeaderNamePrefix = "Cf-Warp-Tag-"
DuplicateConnectionError = "EDUPCONN"
@ -72,6 +73,9 @@ type TunnelConfig struct {
WSGI bool
// OriginUrl may not be used if a user specifies a unix socket.
OriginUrl string
// feature-flag to use new edge reconnect tokens
UseReconnectToken bool
}
type dupConnRegisterTunnelError struct{}
@ -110,6 +114,18 @@ func (e clientRegisterTunnelError) Error() string {
return e.cause.Error()
}
// muxerConfig builds the client-side h2mux.MuxerConfig for this tunnel
// configuration. handler is installed as the muxer's stream handler; the
// heartbeat settings, transport logger and compression quality are taken from
// c, and the timeout is the package-level muxerTimeout.
func (c *TunnelConfig) muxerConfig(handler h2mux.MuxedStreamHandler) h2mux.MuxerConfig {
	var cfg h2mux.MuxerConfig
	cfg.Timeout = muxerTimeout
	cfg.Handler = handler
	cfg.IsClient = true
	cfg.HeartbeatInterval = c.HeartbeatInterval
	cfg.MaxHeartbeats = c.MaxHeartbeats
	cfg.Logger = c.TransportLogger.WithFields(log.Fields{})
	cfg.CompressionQuality = h2mux.CompressionSetting(c.CompressionQuality)
	return cfg
}
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
policy := tunnelrpc.ExistingTunnelPolicy_balance
if c.HAConnections <= 1 && c.LBPool == "" {
@ -132,7 +148,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
}
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
return NewSupervisor(config).Run(ctx, connectedSignal, cloudflaredID)
return NewSupervisor(config, cloudflaredID).Run(ctx, connectedSignal)
}
func ServeTunnelLoop(ctx context.Context,
@ -448,15 +464,7 @@ func NewTunnelHandler(ctx context.Context,
}
// Establish a muxed connection with the edge
// Client mux handshake with agent server
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
Timeout: 5 * time.Second,
Handler: h,
IsClient: true,
HeartbeatInterval: config.HeartbeatInterval,
MaxHeartbeats: config.MaxHeartbeats,
Logger: config.TransportLogger.WithFields(log.Fields{}),
CompressionQuality: h2mux.CompressionSetting(config.CompressionQuality),
}, h.metrics.activeStreams)
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, config.muxerConfig(h), h.metrics.activeStreams)
if err != nil {
return nil, "", errors.Wrap(err, "Handshake with edge error")
}

View File

@ -1,7 +1,7 @@
package pogs
import (
"fmt"
"errors"
"time"
)
@ -18,20 +18,20 @@ type AuthenticateResponse struct {
// Outcome turns the deserialized response of Authenticate into a programmer-friendly sum type.
func (ar AuthenticateResponse) Outcome() AuthOutcome {
// If there was a network error, then cloudflared should retry later,
// because origintunneld couldn't prove whether auth was correct or not.
if ar.RetryableErr != "" {
return NewAuthUnknown(fmt.Errorf(ar.RetryableErr), ar.HoursUntilRefresh)
}
// If the user's authentication was unsuccessful, the server will return an error explaining why.
// cloudflared should fatal with this error.
if ar.PermanentErr != "" {
return NewAuthFail(fmt.Errorf(ar.PermanentErr))
return NewAuthFail(errors.New(ar.PermanentErr))
}
// If there was a network error, then cloudflared should retry later,
// because origintunneld couldn't prove whether auth was correct or not.
if ar.RetryableErr != "" {
return NewAuthUnknown(errors.New(ar.RetryableErr), ar.HoursUntilRefresh)
}
// If auth succeeded, return the token and refresh it when instructed.
if ar.PermanentErr == "" && len(ar.Jwt) > 0 {
if len(ar.Jwt) > 0 {
return NewAuthSuccess(ar.Jwt, ar.HoursUntilRefresh)
}
@ -57,6 +57,10 @@ func NewAuthSuccess(jwt []byte, hoursUntilRefresh uint8) AuthSuccess {
return AuthSuccess{jwt: jwt, hoursUntilRefresh: hoursUntilRefresh}
}
// JWT returns the token carried by this successful auth outcome. The returned
// slice is the stored slice, not a copy.
func (ao AuthSuccess) JWT() []byte {
return ao.jwt
}
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
func (ao AuthSuccess) RefreshAfter() time.Duration {
return hoursToTime(ao.hoursUntilRefresh)
@ -81,6 +85,10 @@ func NewAuthFail(err error) AuthFail {
return AuthFail{err: err}
}
// Error returns the message of the wrapped error, which lets an AuthFail be
// used as an error value.
// NOTE(review): this panics if ao.err is nil — confirm NewAuthFail is never
// called with a nil error.
func (ao AuthFail) Error() string {
return ao.err.Error()
}
// Serialize into an AuthenticateResponse which can be sent via Capnp
func (ao AuthFail) Serialize() AuthenticateResponse {
return AuthenticateResponse{
@ -100,6 +108,10 @@ func NewAuthUnknown(err error, hoursUntilRefresh uint8) AuthUnknown {
return AuthUnknown{err: err, hoursUntilRefresh: hoursUntilRefresh}
}
// Error returns the message of the wrapped error, which lets an AuthUnknown be
// used as an error value.
// NOTE(review): this panics if ao.err is nil — confirm NewAuthUnknown is never
// called with a nil error.
func (ao AuthUnknown) Error() string {
return ao.err.Error()
}
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
func (ao AuthUnknown) RefreshAfter() time.Duration {
return hoursToTime(ao.hoursUntilRefresh)
@ -117,4 +129,4 @@ func (ao AuthUnknown) isAuthOutcome() {}
// hoursToTime converts a whole number of hours into the equivalent
// time.Duration.
func hoursToTime(hours uint8) time.Duration {
	return time.Hour * time.Duration(hours)
}
}