TUN-2555: origin/supervisor.go calls Authenticate
This commit is contained in:
parent
b499c0fdba
commit
5e7ca14412
|
@ -977,6 +977,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
EnvVars: []string{"TUNNEL_INTENT"},
|
EnvVars: []string{"TUNNEL_INTENT"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
}),
|
}),
|
||||||
|
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||||
|
Name: "use-reconnect-token",
|
||||||
|
Usage: "Test reestablishing connections with the new 'reconnect token' flow.",
|
||||||
|
EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"},
|
||||||
|
Hidden: true,
|
||||||
|
}),
|
||||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
Name: "dial-edge-timeout",
|
Name: "dial-edge-timeout",
|
||||||
Usage: "Maximum wait time to set up a connection with the edge",
|
Usage: "Maximum wait time to set up a connection with the edge",
|
||||||
|
@ -1044,7 +1050,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
||||||
Usage: "Absolute path of directory to save SSH host keys in",
|
Usage: "Absolute path of directory to save SSH host keys in",
|
||||||
EnvVars: []string{"HOST_KEY_PATH"},
|
EnvVars: []string{"HOST_KEY_PATH"},
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
|
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -275,6 +275,7 @@ func prepareTunnelConfig(
|
||||||
TlsConfig: toEdgeTLSConfig,
|
TlsConfig: toEdgeTLSConfig,
|
||||||
TransportLogger: transportLogger,
|
TransportLogger: transportLogger,
|
||||||
UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"),
|
UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"),
|
||||||
|
UseReconnectToken: c.Bool("use-reconnect-token"),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -92,3 +92,8 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
|
||||||
}
|
}
|
||||||
return b.BaseTime
|
return b.BaseTime
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Retries returns the number of retries consumed so far.
|
||||||
|
func (b *BackoffHandler) Retries() int {
|
||||||
|
return int(b.retries)
|
||||||
|
}
|
||||||
|
|
|
@ -2,15 +2,20 @@ package origin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
"github.com/cloudflare/cloudflared/signal"
|
"github.com/cloudflare/cloudflared/signal"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -20,11 +25,23 @@ const (
|
||||||
resolveTTL = time.Hour
|
resolveTTL = time.Hour
|
||||||
// Interval between registering new tunnels
|
// Interval between registering new tunnels
|
||||||
registrationInterval = time.Second
|
registrationInterval = time.Second
|
||||||
|
|
||||||
|
subsystemRefreshAuth = "refresh_auth"
|
||||||
|
// Maximum exponent for 'Authenticate' exponential backoff
|
||||||
|
refreshAuthMaxBackoff = 10
|
||||||
|
// Waiting time before retrying a failed 'Authenticate' connection
|
||||||
|
refreshAuthRetryDuration = time.Second * 10
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errJWTUnset = errors.New("JWT unset")
|
||||||
|
errEventDigestUnset = errors.New("event digest unset")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Supervisor struct {
|
type Supervisor struct {
|
||||||
config *TunnelConfig
|
cloudflaredUUID uuid.UUID
|
||||||
edgeIPs []*net.TCPAddr
|
config *TunnelConfig
|
||||||
|
edgeIPs []*net.TCPAddr
|
||||||
// nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
|
// nextUnusedEdgeIP is the index of the next addr k edgeIPs to try
|
||||||
nextUnusedEdgeIP int
|
nextUnusedEdgeIP int
|
||||||
lastResolve time.Time
|
lastResolve time.Time
|
||||||
|
@ -37,6 +54,12 @@ type Supervisor struct {
|
||||||
nextConnectedSignal chan struct{}
|
nextConnectedSignal chan struct{}
|
||||||
|
|
||||||
logger *logrus.Entry
|
logger *logrus.Entry
|
||||||
|
|
||||||
|
jwtLock *sync.RWMutex
|
||||||
|
jwt []byte
|
||||||
|
|
||||||
|
eventDigestLock *sync.RWMutex
|
||||||
|
eventDigest []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type resolveResult struct {
|
type resolveResult struct {
|
||||||
|
@ -49,18 +72,21 @@ type tunnelError struct {
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSupervisor(config *TunnelConfig) *Supervisor {
|
func NewSupervisor(config *TunnelConfig, u uuid.UUID) *Supervisor {
|
||||||
return &Supervisor{
|
return &Supervisor{
|
||||||
|
cloudflaredUUID: u,
|
||||||
config: config,
|
config: config,
|
||||||
tunnelErrors: make(chan tunnelError),
|
tunnelErrors: make(chan tunnelError),
|
||||||
tunnelsConnecting: map[int]chan struct{}{},
|
tunnelsConnecting: map[int]chan struct{}{},
|
||||||
logger: config.Logger.WithField("subsystem", "supervisor"),
|
logger: config.Logger.WithField("subsystem", "supervisor"),
|
||||||
|
jwtLock: &sync.RWMutex{},
|
||||||
|
eventDigestLock: &sync.RWMutex{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
|
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) error {
|
||||||
logger := s.config.Logger
|
logger := s.config.Logger
|
||||||
if err := s.initialize(ctx, connectedSignal, u); err != nil {
|
if err := s.initialize(ctx, connectedSignal); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var tunnelsWaiting []int
|
var tunnelsWaiting []int
|
||||||
|
@ -68,6 +94,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
|
||||||
var backoffTimer <-chan time.Time
|
var backoffTimer <-chan time.Time
|
||||||
tunnelsActive := s.config.HAConnections
|
tunnelsActive := s.config.HAConnections
|
||||||
|
|
||||||
|
refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
|
||||||
|
var refreshAuthBackoffTimer <-chan time.Time
|
||||||
|
if s.config.UseReconnectToken {
|
||||||
|
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
// Context cancelled
|
// Context cancelled
|
||||||
|
@ -103,10 +135,20 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
|
||||||
case <-backoffTimer:
|
case <-backoffTimer:
|
||||||
backoffTimer = nil
|
backoffTimer = nil
|
||||||
for _, index := range tunnelsWaiting {
|
for _, index := range tunnelsWaiting {
|
||||||
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), u)
|
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
|
||||||
}
|
}
|
||||||
tunnelsActive += len(tunnelsWaiting)
|
tunnelsActive += len(tunnelsWaiting)
|
||||||
tunnelsWaiting = nil
|
tunnelsWaiting = nil
|
||||||
|
// Time to call Authenticate
|
||||||
|
case <-refreshAuthBackoffTimer:
|
||||||
|
newTimer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Error("Authentication failed")
|
||||||
|
// Permanent failure. Leave the `select` without setting the
|
||||||
|
// channel to be non-null, so we'll never hit this case of the `select` again.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
refreshAuthBackoffTimer = newTimer
|
||||||
// Tunnel successfully connected
|
// Tunnel successfully connected
|
||||||
case <-s.nextConnectedSignal:
|
case <-s.nextConnectedSignal:
|
||||||
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
|
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
|
||||||
|
@ -127,7 +169,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, u
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) error {
|
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error {
|
||||||
logger := s.logger
|
logger := s.logger
|
||||||
|
|
||||||
edgeIPs, err := s.resolveEdgeIPs()
|
edgeIPs, err := s.resolveEdgeIPs()
|
||||||
|
@ -144,12 +186,12 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||||
s.lastResolve = time.Now()
|
s.lastResolve = time.Now()
|
||||||
// check entitlement and version too old error before attempting to register more tunnels
|
// check entitlement and version too old error before attempting to register more tunnels
|
||||||
s.nextUnusedEdgeIP = s.config.HAConnections
|
s.nextUnusedEdgeIP = s.config.HAConnections
|
||||||
go s.startFirstTunnel(ctx, connectedSignal, u)
|
go s.startFirstTunnel(ctx, connectedSignal)
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
<-s.tunnelErrors
|
<-s.tunnelErrors
|
||||||
// Error can't be nil. A nil error signals that initialization succeed
|
// Error can't be nil. A nil error signals that initialization succeed
|
||||||
return fmt.Errorf("context was canceled")
|
return ctx.Err()
|
||||||
case tunnelError := <-s.tunnelErrors:
|
case tunnelError := <-s.tunnelErrors:
|
||||||
return tunnelError.err
|
return tunnelError.err
|
||||||
case <-connectedSignal.Wait():
|
case <-connectedSignal.Wait():
|
||||||
|
@ -157,7 +199,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||||
// At least one successful connection, so start the rest
|
// At least one successful connection, so start the rest
|
||||||
for i := 1; i < s.config.HAConnections; i++ {
|
for i := 1; i < s.config.HAConnections; i++ {
|
||||||
ch := signal.New(make(chan struct{}))
|
ch := signal.New(make(chan struct{}))
|
||||||
go s.startTunnel(ctx, i, ch, u)
|
go s.startTunnel(ctx, i, ch)
|
||||||
time.Sleep(registrationInterval)
|
time.Sleep(registrationInterval)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -165,8 +207,8 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||||
|
|
||||||
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
||||||
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
||||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, u uuid.UUID) {
|
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal) {
|
||||||
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
|
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
|
||||||
defer func() {
|
defer func() {
|
||||||
s.tunnelErrors <- tunnelError{index: 0, err: err}
|
s.tunnelErrors <- tunnelError{index: 0, err: err}
|
||||||
}()
|
}()
|
||||||
|
@ -187,14 +229,14 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, u)
|
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal, s.cloudflaredUUID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
||||||
// s.tunnelErrors.
|
// s.tunnelErrors.
|
||||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, u uuid.UUID) {
|
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal) {
|
||||||
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, u)
|
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal, s.cloudflaredUUID)
|
||||||
s.tunnelErrors <- tunnelError{index: index, err: err}
|
s.tunnelErrors <- tunnelError{index: index, err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -252,3 +294,109 @@ func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
|
||||||
s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
|
s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
|
||||||
s.nextUnusedEdgeIP++
|
s.nextUnusedEdgeIP++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) ReconnectToken() ([]byte, error) {
|
||||||
|
s.jwtLock.RLock()
|
||||||
|
defer s.jwtLock.RUnlock()
|
||||||
|
if s.jwt == nil {
|
||||||
|
return nil, errJWTUnset
|
||||||
|
}
|
||||||
|
return s.jwt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) SetReconnectToken(jwt []byte) {
|
||||||
|
s.jwtLock.Lock()
|
||||||
|
defer s.jwtLock.Unlock()
|
||||||
|
s.jwt = jwt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) EventDigest() ([]byte, error) {
|
||||||
|
s.eventDigestLock.RLock()
|
||||||
|
defer s.eventDigestLock.RUnlock()
|
||||||
|
if s.eventDigest == nil {
|
||||||
|
return nil, errEventDigestUnset
|
||||||
|
}
|
||||||
|
return s.eventDigest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) SetEventDigest(eventDigest []byte) {
|
||||||
|
s.eventDigestLock.Lock()
|
||||||
|
defer s.eventDigestLock.Unlock()
|
||||||
|
s.eventDigest = eventDigest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) refreshAuth(
|
||||||
|
ctx context.Context,
|
||||||
|
backoff *BackoffHandler,
|
||||||
|
authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
|
||||||
|
) (retryTimer <-chan time.Time, err error) {
|
||||||
|
logger := s.config.Logger.WithField("subsystem", subsystemRefreshAuth)
|
||||||
|
authOutcome, err := authenticate(ctx, backoff.Retries())
|
||||||
|
if err != nil {
|
||||||
|
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||||
|
logger.WithError(err).Warnf("Retrying in %v", duration)
|
||||||
|
return backoff.BackoffTimer(), nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// clear backoff timer
|
||||||
|
backoff.SetGracePeriod()
|
||||||
|
|
||||||
|
switch outcome := authOutcome.(type) {
|
||||||
|
case tunnelpogs.AuthSuccess:
|
||||||
|
s.SetReconnectToken(outcome.JWT())
|
||||||
|
return timeAfter(outcome.RefreshAfter()), nil
|
||||||
|
case tunnelpogs.AuthUnknown:
|
||||||
|
return timeAfter(outcome.RefreshAfter()), nil
|
||||||
|
case tunnelpogs.AuthFail:
|
||||||
|
return nil, outcome
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("Unexpected outcome type %T", authOutcome)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
arbitraryEdgeIP := s.getEdgeIP(rand.Int())
|
||||||
|
edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer edgeConn.Close()
|
||||||
|
|
||||||
|
handler := h2mux.MuxedStreamFunc(func(*h2mux.MuxedStream) error {
|
||||||
|
// This callback is invoked by h2mux when the edge initiates a stream.
|
||||||
|
return nil // noop
|
||||||
|
})
|
||||||
|
muxerConfig := s.config.muxerConfig(handler)
|
||||||
|
muxerConfig.Logger = muxerConfig.Logger.WithField("subsystem", subsystemRefreshAuth)
|
||||||
|
muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
go muxer.Serve(ctx)
|
||||||
|
defer func() {
|
||||||
|
// If we don't wait for the muxer shutdown here, edgeConn.Close() runs before the muxer connections are done,
|
||||||
|
// and the user sees log noise: "error writing data", "connection closed unexpectedly"
|
||||||
|
<-muxer.Shutdown()
|
||||||
|
}()
|
||||||
|
|
||||||
|
tunnelServer, err := connection.NewRPCClient(ctx, muxer, s.logger.WithField("subsystem", subsystemRefreshAuth), openStreamTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer tunnelServer.Close()
|
||||||
|
|
||||||
|
const arbitraryConnectionID = uint8(0)
|
||||||
|
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
|
||||||
|
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
|
||||||
|
authResponse, err := tunnelServer.Authenticate(
|
||||||
|
ctx,
|
||||||
|
s.config.OriginCert,
|
||||||
|
s.config.Hostname,
|
||||||
|
registrationOptions,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return authResponse.Outcome(), nil
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,128 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRefreshAuthBackoff(t *testing.T) {
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.Level = logrus.ErrorLevel
|
||||||
|
|
||||||
|
var wait time.Duration
|
||||||
|
timeAfter = func(d time.Duration) <-chan time.Time {
|
||||||
|
wait = d
|
||||||
|
return time.After(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||||
|
backoff := &BackoffHandler{MaxRetries: 3}
|
||||||
|
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
return nil, fmt.Errorf("authentication failure")
|
||||||
|
}
|
||||||
|
|
||||||
|
// authentication failures should consume the backoff
|
||||||
|
for i := uint(0); i < backoff.MaxRetries; i++ {
|
||||||
|
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, retryChan)
|
||||||
|
assert.Equal(t, (1<<i)*time.Second, wait)
|
||||||
|
}
|
||||||
|
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, retryChan)
|
||||||
|
|
||||||
|
// now we actually make contact with the remote server
|
||||||
|
_, _ = s.refreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// The backoff timer should have been reset. To confirm this, make timeNow
|
||||||
|
// return a value after the backoff timer's grace period
|
||||||
|
timeNow = func() time.Time {
|
||||||
|
expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries)
|
||||||
|
return time.Now().Add(expectedGracePeriod * 2)
|
||||||
|
}
|
||||||
|
_, ok := backoff.GetBackoffDuration(context.Background())
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAuthSuccess(t *testing.T) {
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.Level = logrus.ErrorLevel
|
||||||
|
|
||||||
|
var wait time.Duration
|
||||||
|
timeAfter = func(d time.Duration) <-chan time.Time {
|
||||||
|
wait = d
|
||||||
|
return time.After(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||||
|
backoff := &BackoffHandler{MaxRetries: 3}
|
||||||
|
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, retryChan)
|
||||||
|
assert.Equal(t, 19*time.Hour, wait)
|
||||||
|
|
||||||
|
token, err := s.ReconnectToken()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("jwt"), token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAuthUnknown(t *testing.T) {
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.Level = logrus.ErrorLevel
|
||||||
|
|
||||||
|
var wait time.Duration
|
||||||
|
timeAfter = func(d time.Duration) <-chan time.Time {
|
||||||
|
wait = d
|
||||||
|
return time.After(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||||
|
backoff := &BackoffHandler{MaxRetries: 3}
|
||||||
|
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, retryChan)
|
||||||
|
assert.Equal(t, 19*time.Hour, wait)
|
||||||
|
|
||||||
|
token, err := s.ReconnectToken()
|
||||||
|
assert.Equal(t, errJWTUnset, err)
|
||||||
|
assert.Nil(t, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAuthFail(t *testing.T) {
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.Level = logrus.ErrorLevel
|
||||||
|
|
||||||
|
s := NewSupervisor(&TunnelConfig{Logger: logger}, uuid.New())
|
||||||
|
backoff := &BackoffHandler{MaxRetries: 3}
|
||||||
|
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
|
||||||
|
return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, retryChan)
|
||||||
|
|
||||||
|
token, err := s.ReconnectToken()
|
||||||
|
assert.Equal(t, errJWTUnset, err)
|
||||||
|
assert.Nil(t, token)
|
||||||
|
}
|
|
@ -34,6 +34,7 @@ import (
|
||||||
const (
|
const (
|
||||||
dialTimeout = 15 * time.Second
|
dialTimeout = 15 * time.Second
|
||||||
openStreamTimeout = 30 * time.Second
|
openStreamTimeout = 30 * time.Second
|
||||||
|
muxerTimeout = 5 * time.Second
|
||||||
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
||||||
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
||||||
DuplicateConnectionError = "EDUPCONN"
|
DuplicateConnectionError = "EDUPCONN"
|
||||||
|
@ -72,6 +73,9 @@ type TunnelConfig struct {
|
||||||
WSGI bool
|
WSGI bool
|
||||||
// OriginUrl may not be used if a user specifies a unix socket.
|
// OriginUrl may not be used if a user specifies a unix socket.
|
||||||
OriginUrl string
|
OriginUrl string
|
||||||
|
|
||||||
|
// feature-flag to use new edge reconnect tokens
|
||||||
|
UseReconnectToken bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type dupConnRegisterTunnelError struct{}
|
type dupConnRegisterTunnelError struct{}
|
||||||
|
@ -110,6 +114,18 @@ func (e clientRegisterTunnelError) Error() string {
|
||||||
return e.cause.Error()
|
return e.cause.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *TunnelConfig) muxerConfig(handler h2mux.MuxedStreamHandler) h2mux.MuxerConfig {
|
||||||
|
return h2mux.MuxerConfig{
|
||||||
|
Timeout: muxerTimeout,
|
||||||
|
Handler: handler,
|
||||||
|
IsClient: true,
|
||||||
|
HeartbeatInterval: c.HeartbeatInterval,
|
||||||
|
MaxHeartbeats: c.MaxHeartbeats,
|
||||||
|
Logger: c.TransportLogger.WithFields(log.Fields{}),
|
||||||
|
CompressionQuality: h2mux.CompressionSetting(c.CompressionQuality),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
|
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
|
||||||
policy := tunnelrpc.ExistingTunnelPolicy_balance
|
policy := tunnelrpc.ExistingTunnelPolicy_balance
|
||||||
if c.HAConnections <= 1 && c.LBPool == "" {
|
if c.HAConnections <= 1 && c.LBPool == "" {
|
||||||
|
@ -132,7 +148,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
|
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID) error {
|
||||||
return NewSupervisor(config).Run(ctx, connectedSignal, cloudflaredID)
|
return NewSupervisor(config, cloudflaredID).Run(ctx, connectedSignal)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServeTunnelLoop(ctx context.Context,
|
func ServeTunnelLoop(ctx context.Context,
|
||||||
|
@ -448,15 +464,7 @@ func NewTunnelHandler(ctx context.Context,
|
||||||
}
|
}
|
||||||
// Establish a muxed connection with the edge
|
// Establish a muxed connection with the edge
|
||||||
// Client mux handshake with agent server
|
// Client mux handshake with agent server
|
||||||
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
|
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, config.muxerConfig(h), h.metrics.activeStreams)
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
Handler: h,
|
|
||||||
IsClient: true,
|
|
||||||
HeartbeatInterval: config.HeartbeatInterval,
|
|
||||||
MaxHeartbeats: config.MaxHeartbeats,
|
|
||||||
Logger: config.TransportLogger.WithFields(log.Fields{}),
|
|
||||||
CompressionQuality: h2mux.CompressionSetting(config.CompressionQuality),
|
|
||||||
}, h.metrics.activeStreams)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", errors.Wrap(err, "Handshake with edge error")
|
return nil, "", errors.Wrap(err, "Handshake with edge error")
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
package pogs
|
package pogs
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,20 +18,20 @@ type AuthenticateResponse struct {
|
||||||
|
|
||||||
// Outcome turns the deserialized response of Authenticate into a programmer-friendly sum type.
|
// Outcome turns the deserialized response of Authenticate into a programmer-friendly sum type.
|
||||||
func (ar AuthenticateResponse) Outcome() AuthOutcome {
|
func (ar AuthenticateResponse) Outcome() AuthOutcome {
|
||||||
// If there was a network error, then cloudflared should retry later,
|
|
||||||
// because origintunneld couldn't prove whether auth was correct or not.
|
|
||||||
if ar.RetryableErr != "" {
|
|
||||||
return NewAuthUnknown(fmt.Errorf(ar.RetryableErr), ar.HoursUntilRefresh)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the user's authentication was unsuccessful, the server will return an error explaining why.
|
// If the user's authentication was unsuccessful, the server will return an error explaining why.
|
||||||
// cloudflared should fatal with this error.
|
// cloudflared should fatal with this error.
|
||||||
if ar.PermanentErr != "" {
|
if ar.PermanentErr != "" {
|
||||||
return NewAuthFail(fmt.Errorf(ar.PermanentErr))
|
return NewAuthFail(errors.New(ar.PermanentErr))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there was a network error, then cloudflared should retry later,
|
||||||
|
// because origintunneld couldn't prove whether auth was correct or not.
|
||||||
|
if ar.RetryableErr != "" {
|
||||||
|
return NewAuthUnknown(errors.New(ar.RetryableErr), ar.HoursUntilRefresh)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If auth succeeded, return the token and refresh it when instructed.
|
// If auth succeeded, return the token and refresh it when instructed.
|
||||||
if ar.PermanentErr == "" && len(ar.Jwt) > 0 {
|
if len(ar.Jwt) > 0 {
|
||||||
return NewAuthSuccess(ar.Jwt, ar.HoursUntilRefresh)
|
return NewAuthSuccess(ar.Jwt, ar.HoursUntilRefresh)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,6 +57,10 @@ func NewAuthSuccess(jwt []byte, hoursUntilRefresh uint8) AuthSuccess {
|
||||||
return AuthSuccess{jwt: jwt, hoursUntilRefresh: hoursUntilRefresh}
|
return AuthSuccess{jwt: jwt, hoursUntilRefresh: hoursUntilRefresh}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ao AuthSuccess) JWT() []byte {
|
||||||
|
return ao.jwt
|
||||||
|
}
|
||||||
|
|
||||||
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
|
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
|
||||||
func (ao AuthSuccess) RefreshAfter() time.Duration {
|
func (ao AuthSuccess) RefreshAfter() time.Duration {
|
||||||
return hoursToTime(ao.hoursUntilRefresh)
|
return hoursToTime(ao.hoursUntilRefresh)
|
||||||
|
@ -81,6 +85,10 @@ func NewAuthFail(err error) AuthFail {
|
||||||
return AuthFail{err: err}
|
return AuthFail{err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ao AuthFail) Error() string {
|
||||||
|
return ao.err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
// Serialize into an AuthenticateResponse which can be sent via Capnp
|
// Serialize into an AuthenticateResponse which can be sent via Capnp
|
||||||
func (ao AuthFail) Serialize() AuthenticateResponse {
|
func (ao AuthFail) Serialize() AuthenticateResponse {
|
||||||
return AuthenticateResponse{
|
return AuthenticateResponse{
|
||||||
|
@ -100,6 +108,10 @@ func NewAuthUnknown(err error, hoursUntilRefresh uint8) AuthUnknown {
|
||||||
return AuthUnknown{err: err, hoursUntilRefresh: hoursUntilRefresh}
|
return AuthUnknown{err: err, hoursUntilRefresh: hoursUntilRefresh}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ao AuthUnknown) Error() string {
|
||||||
|
return ao.err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
|
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
|
||||||
func (ao AuthUnknown) RefreshAfter() time.Duration {
|
func (ao AuthUnknown) RefreshAfter() time.Duration {
|
||||||
return hoursToTime(ao.hoursUntilRefresh)
|
return hoursToTime(ao.hoursUntilRefresh)
|
||||||
|
@ -117,4 +129,4 @@ func (ao AuthUnknown) isAuthOutcome() {}
|
||||||
|
|
||||||
func hoursToTime(hours uint8) time.Duration {
|
func hoursToTime(hours uint8) time.Duration {
|
||||||
return time.Duration(hours) * time.Hour
|
return time.Duration(hours) * time.Hour
|
||||||
}
|
}
|
Loading…
Reference in New Issue