TUN-3268: Each connection has its own event digest to reconnect

cthuang 2020-08-18 11:14:14 +01:00
parent 9323844ea7
commit 8eeb452cce
5 changed files with 230 additions and 263 deletions
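In short: before this commit, the Supervisor held one shared event digest (plus a map of connection digests) and handed them to every reconnecting tunnel. This commit moves all reconnect state into a new reconnectCredentialManager in origin/reconnect.go, with the event digest now stored per connection. A minimal sketch of the shape of the change (the before/after types are illustrative only; the real types are Supervisor and reconnectCredentialManager in the diffs below):

    // Illustrative only: how reconnect state is stored before vs. after.
    type before struct { // fields previously on Supervisor
        jwt         []byte           // shared reconnect token
        eventDigest []byte           // ONE event digest for every HA connection
        connDigest  map[uint8][]byte // connection digests, keyed by connection ID
    }

    type after struct { // fields on the new reconnectCredentialManager
        jwt         []byte           // still shared
        eventDigest map[uint8][]byte // now also keyed by connection ID
        connDigest  map[uint8][]byte
    }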

@@ -58,11 +58,9 @@ type TunnelMetrics struct {
    // oldServerLocations stores the last server the tunnel was connected to
    oldServerLocations map[string]string
    regSuccess *prometheus.CounterVec
    regFail *prometheus.CounterVec
-   authSuccess prometheus.Counter
-   authFail *prometheus.CounterVec
    rpcFail *prometheus.CounterVec
    muxerMetrics *muxerMetrics
    tunnelsHA tunnelsForHA
@@ -456,27 +454,6 @@ func NewTunnelMetrics() *TunnelMetrics {
    )
    prometheus.MustRegister(registerSuccess)
-   authSuccess := prometheus.NewCounter(
-       prometheus.CounterOpts{
-           Namespace: metricsNamespace,
-           Subsystem: tunnelSubsystem,
-           Name:      "tunnel_authenticate_success",
-           Help:      "Count of successful tunnel authenticate",
-       },
-   )
-   prometheus.MustRegister(authSuccess)
-   authFail := prometheus.NewCounterVec(
-       prometheus.CounterOpts{
-           Namespace: metricsNamespace,
-           Subsystem: tunnelSubsystem,
-           Name:      "tunnel_authenticate_fail",
-           Help:      "Count of tunnel authenticate errors by type",
-       },
-       []string{"error"},
-   )
-   prometheus.MustRegister(authFail)
    return &TunnelMetrics{
        haConnections: haConnections,
        activeStreams: activeStreams,
@@ -497,8 +474,6 @@ func NewTunnelMetrics() *TunnelMetrics {
        regFail: registerFail,
        rpcFail: rpcFail,
        userHostnamesCounts: userHostnamesCounts,
-       authSuccess: authSuccess,
-       authFail: authFail,
    }
}

origin/reconnect.go (new file, 191 additions)

@@ -0,0 +1,191 @@
package origin

import (
    "context"
    "errors"
    "fmt"
    "sync"
    "time"

    "github.com/cloudflare/cloudflared/connection"
    "github.com/cloudflare/cloudflared/h2mux"
    "github.com/cloudflare/cloudflared/logger"
    "github.com/cloudflare/cloudflared/tunnelrpc"
    tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
    "github.com/google/uuid"
    "github.com/prometheus/client_golang/prometheus"
)

var (
    errJWTUnset = errors.New("JWT unset")
)

// reconnectCredentialManager is invoked by functions in tunnel.go to
// get/set parameters for ReconnectTunnel RPC calls.
type reconnectCredentialManager struct {
    mu          sync.RWMutex
    jwt         []byte
    eventDigest map[uint8][]byte
    connDigest  map[uint8][]byte
    authSuccess prometheus.Counter
    authFail    *prometheus.CounterVec
}

func newReconnectCredentialManager(namespace, subsystem string, haConnections int) *reconnectCredentialManager {
    authSuccess := prometheus.NewCounter(
        prometheus.CounterOpts{
            Namespace: namespace,
            Subsystem: subsystem,
            Name:      "tunnel_authenticate_success",
            Help:      "Count of successful tunnel authenticate",
        },
    )
    authFail := prometheus.NewCounterVec(
        prometheus.CounterOpts{
            Namespace: namespace,
            Subsystem: subsystem,
            Name:      "tunnel_authenticate_fail",
            Help:      "Count of tunnel authenticate errors by type",
        },
        []string{"error"},
    )
    prometheus.MustRegister(authSuccess, authFail)
    return &reconnectCredentialManager{
        eventDigest: make(map[uint8][]byte, haConnections),
        connDigest:  make(map[uint8][]byte, haConnections),
        authSuccess: authSuccess,
        authFail:    authFail,
    }
}

func (cm *reconnectCredentialManager) ReconnectToken() ([]byte, error) {
    cm.mu.RLock()
    defer cm.mu.RUnlock()
    if cm.jwt == nil {
        return nil, errJWTUnset
    }
    return cm.jwt, nil
}

func (cm *reconnectCredentialManager) SetReconnectToken(jwt []byte) {
    cm.mu.Lock()
    defer cm.mu.Unlock()
    cm.jwt = jwt
}

func (cm *reconnectCredentialManager) EventDigest(connID uint8) ([]byte, error) {
    cm.mu.RLock()
    defer cm.mu.RUnlock()
    digest, ok := cm.eventDigest[connID]
    if !ok {
        return nil, fmt.Errorf("no event digest for connection %v", connID)
    }
    return digest, nil
}

func (cm *reconnectCredentialManager) SetEventDigest(connID uint8, digest []byte) {
    cm.mu.Lock()
    defer cm.mu.Unlock()
    cm.eventDigest[connID] = digest
}

func (cm *reconnectCredentialManager) ConnDigest(connID uint8) ([]byte, error) {
    cm.mu.RLock()
    defer cm.mu.RUnlock()
    digest, ok := cm.connDigest[connID]
    if !ok {
        return nil, fmt.Errorf("no connection digest for connection %v", connID)
    }
    return digest, nil
}

func (cm *reconnectCredentialManager) SetConnDigest(connID uint8, digest []byte) {
    cm.mu.Lock()
    defer cm.mu.Unlock()
    cm.connDigest[connID] = digest
}

func (cm *reconnectCredentialManager) RefreshAuth(
    ctx context.Context,
    backoff *BackoffHandler,
    authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
) (retryTimer <-chan time.Time, err error) {
    authOutcome, err := authenticate(ctx, backoff.Retries())
    if err != nil {
        cm.authFail.WithLabelValues(err.Error()).Inc()
        if _, ok := backoff.GetBackoffDuration(ctx); ok {
            return backoff.BackoffTimer(), nil
        }
        return nil, err
    }
    // clear backoff timer
    backoff.SetGracePeriod()
    switch outcome := authOutcome.(type) {
    case tunnelpogs.AuthSuccess:
        cm.SetReconnectToken(outcome.JWT())
        cm.authSuccess.Inc()
        return timeAfter(outcome.RefreshAfter()), nil
    case tunnelpogs.AuthUnknown:
        duration := outcome.RefreshAfter()
        cm.authFail.WithLabelValues(outcome.Error()).Inc()
        return timeAfter(duration), nil
    case tunnelpogs.AuthFail:
        cm.authFail.WithLabelValues(outcome.Error()).Inc()
        return nil, outcome
    default:
        err := fmt.Errorf("refresh_auth: Unexpected outcome type %T", authOutcome)
        cm.authFail.WithLabelValues(err.Error()).Inc()
        return nil, err
    }
}

func ReconnectTunnel(
    ctx context.Context,
    muxer *h2mux.Muxer,
    config *TunnelConfig,
    logger logger.Service,
    connectionID uint8,
    originLocalAddr string,
    uuid uuid.UUID,
    credentialManager *reconnectCredentialManager,
) error {
    token, err := credentialManager.ReconnectToken()
    if err != nil {
        return err
    }
    eventDigest, err := credentialManager.EventDigest(connectionID)
    if err != nil {
        return err
    }
    connDigest, err := credentialManager.ConnDigest(connectionID)
    if err != nil {
        return err
    }
    config.TransportLogger.Debug("initiating RPC stream to reconnect")
    tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
    if err != nil {
        // RPC stream open error
        return newClientRegisterTunnelError(err, config.Metrics.rpcFail, reconnect)
    }
    defer tunnelServer.Close()
    // Request server info without blocking tunnel registration; must use capnp library directly.
    serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
        return nil
    })
    LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
    registration := tunnelServer.ReconnectTunnel(
        ctx,
        token,
        eventDigest,
        connDigest,
        config.Hostname,
        config.RegistrationOptions(connectionID, originLocalAddr, uuid),
    )
    if registrationErr := registration.DeserializeError(); registrationErr != nil {
        // ReconnectTunnel RPC failure
        return processRegisterTunnelError(registrationErr, config.Metrics, reconnect)
    }
    return processRegistrationSuccess(config, logger, connectionID, registration, reconnect, credentialManager)
}
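For orientation, a hypothetical round-trip with the new manager (not code from this commit): each connection stores and retrieves its own digests, so one connection's reconnect can no longer clobber another's. Note that newReconnectCredentialManager calls prometheus.MustRegister, which panics on duplicate registration in one process; that is why the tests below construct managers with t.Name() as namespace and subsystem.

    // Hypothetical usage sketch; connection IDs and digest bytes are made up.
    func exampleDigestRoundTrip() error {
        rcm := newReconnectCredentialManager("example_ns", "example_sub", 2)

        // After each connection registers, remember what the edge returned.
        rcm.SetEventDigest(0, []byte("event-digest-conn-0"))
        rcm.SetEventDigest(1, []byte("event-digest-conn-1"))

        // On reconnect, a connection reads back only its own digest...
        if _, err := rcm.EventDigest(0); err != nil {
            return err
        }
        // ...and an unknown connection ID is an error, not another
        // connection's stale digest.
        _, err := rcm.EventDigest(9)
        return err // "no event digest for connection 9"
    }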


@@ -7,51 +7,19 @@ import (
    "testing"
    "time"
-   "github.com/google/uuid"
-   "github.com/prometheus/client_golang/prometheus"
    "github.com/stretchr/testify/assert"
-   "github.com/cloudflare/cloudflared/logger"
    tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
-func testConfig(logger logger.Service) *TunnelConfig {
-   metrics := TunnelMetrics{}
-   metrics.authSuccess = prometheus.NewCounter(
-       prometheus.CounterOpts{
-           Namespace: metricsNamespace,
-           Subsystem: tunnelSubsystem,
-           Name:      "tunnel_authenticate_success",
-           Help:      "Count of successful tunnel authenticate",
-       },
-   )
-   metrics.authFail = prometheus.NewCounterVec(
-       prometheus.CounterOpts{
-           Namespace: metricsNamespace,
-           Subsystem: tunnelSubsystem,
-           Name:      "tunnel_authenticate_fail",
-           Help:      "Count of tunnel authenticate errors by type",
-       },
-       []string{"error"},
-   )
-   return &TunnelConfig{Logger: logger, Metrics: &metrics}
-}
func TestRefreshAuthBackoff(t *testing.T) {
-   logger := logger.NewOutputWriter(logger.NewMockWriteManager())
+   rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
    var wait time.Duration
    timeAfter = func(d time.Duration) <-chan time.Time {
        wait = d
        return time.After(d)
    }
-   s, err := NewSupervisor(testConfig(logger), uuid.New())
-   if !assert.NoError(t, err) {
-       t.FailNow()
-   }
    backoff := &BackoffHandler{MaxRetries: 3}
    auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
        return nil, fmt.Errorf("authentication failure")
@@ -59,17 +27,17 @@ func TestRefreshAuthBackoff(t *testing.T) {
    // authentication failures should consume the backoff
    for i := uint(0); i < backoff.MaxRetries; i++ {
-       retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
+       retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
        assert.NoError(t, err)
        assert.NotNil(t, retryChan)
        assert.Equal(t, (1<<i)*time.Second, wait)
    }
-   retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
+   retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
    assert.Error(t, err)
    assert.Nil(t, retryChan)
    // now we actually make contact with the remote server
-   _, _ = s.refreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
+   _, _ = rcm.RefreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
        return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
    })
@@ -84,7 +52,7 @@ func TestRefreshAuthBackoff(t *testing.T) {
}
func TestRefreshAuthSuccess(t *testing.T) {
-   logger := logger.NewOutputWriter(logger.NewMockWriteManager())
+   rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
    var wait time.Duration
    timeAfter = func(d time.Duration) <-chan time.Time {
@@ -92,27 +60,23 @@ func TestRefreshAuthSuccess(t *testing.T) {
        return time.After(d)
    }
-   s, err := NewSupervisor(testConfig(logger), uuid.New())
-   if !assert.NoError(t, err) {
-       t.FailNow()
-   }
    backoff := &BackoffHandler{MaxRetries: 3}
    auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
        return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
    }
-   retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
+   retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
    assert.NoError(t, err)
    assert.NotNil(t, retryChan)
    assert.Equal(t, 19*time.Hour, wait)
-   token, err := s.ReconnectToken()
+   token, err := rcm.ReconnectToken()
    assert.NoError(t, err)
    assert.Equal(t, []byte("jwt"), token)
}
func TestRefreshAuthUnknown(t *testing.T) {
-   logger := logger.NewOutputWriter(logger.NewMockWriteManager())
+   rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
    var wait time.Duration
    timeAfter = func(d time.Duration) <-chan time.Time {
@@ -120,42 +84,34 @@ func TestRefreshAuthUnknown(t *testing.T) {
        return time.After(d)
    }
-   s, err := NewSupervisor(testConfig(logger), uuid.New())
-   if !assert.NoError(t, err) {
-       t.FailNow()
-   }
    backoff := &BackoffHandler{MaxRetries: 3}
    auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
        return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
    }
-   retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
+   retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
    assert.NoError(t, err)
    assert.NotNil(t, retryChan)
    assert.Equal(t, 19*time.Hour, wait)
-   token, err := s.ReconnectToken()
+   token, err := rcm.ReconnectToken()
    assert.Equal(t, errJWTUnset, err)
    assert.Nil(t, token)
}
func TestRefreshAuthFail(t *testing.T) {
-   logger := logger.NewOutputWriter(logger.NewMockWriteManager())
-   s, err := NewSupervisor(testConfig(logger), uuid.New())
-   if !assert.NoError(t, err) {
-       t.FailNow()
-   }
+   rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
    backoff := &BackoffHandler{MaxRetries: 3}
    auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
        return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
    }
-   retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
+   retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
    assert.Error(t, err)
    assert.Nil(t, retryChan)
-   token, err := s.ReconnectToken()
+   token, err := rcm.ReconnectToken()
    assert.Equal(t, errJWTUnset, err)
    assert.Nil(t, token)
}
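A note on the test seam used above: these tests rely on timeAfter being a package-level variable wrapping time.After, so they can intercept the duration each RefreshAuth call schedules. The pattern in isolation (assuming `var timeAfter = time.After` somewhere in this package, which is what the tests imply):

    // Sketch of the timer seam, condensed from the tests above.
    func TestBackoffScheduleExample(t *testing.T) {
        var wait time.Duration
        timeAfter = func(d time.Duration) <-chan time.Time {
            wait = d             // record what the code under test asked for
            return time.After(d) // still hand back a real timer channel
        }

        rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
        backoff := &BackoffHandler{MaxRetries: 3}
        failingAuth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
            return nil, errors.New("authentication failure")
        }

        // The i-th consecutive failure schedules a (1<<i)-second retry:
        // 1s, 2s, 4s, then RefreshAuth returns the error instead of a timer.
        for i := uint(0); i < backoff.MaxRetries; i++ {
            _, _ = rcm.RefreshAuth(context.Background(), backoff, failingAuth)
            assert.Equal(t, (1<<i)*time.Second, wait)
        }
    }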


@@ -3,9 +3,7 @@ package origin
import (
    "context"
    "errors"
-   "fmt"
    "net"
-   "sync"
    "time"
    "github.com/google/uuid"
@@ -37,7 +35,6 @@ const (
)
var (
-   errJWTUnset = errors.New("JWT unset")
    errEventDigestUnset = errors.New("event digest unset")
)
@@ -58,14 +55,7 @@ type Supervisor struct {
    logger logger.Service
-   jwtLock sync.RWMutex
-   jwt []byte
-   eventDigestLock sync.RWMutex
-   eventDigest []byte
-   connDigestLock sync.RWMutex
-   connDigest map[uint8][]byte
+   reconnectCredentialManager *reconnectCredentialManager
    bufferPool *buffer.Pool
}
@@ -95,14 +85,14 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
    }
    return &Supervisor{
        cloudflaredUUID: cloudflaredUUID,
        config: config,
        edgeIPs: edgeIPs,
        tunnelErrors: make(chan tunnelError),
        tunnelsConnecting: map[int]chan struct{}{},
        logger: config.Logger,
-       connDigest: make(map[uint8][]byte),
+       reconnectCredentialManager: newReconnectCredentialManager(metricsNamespace, tunnelSubsystem, config.HAConnections),
        bufferPool: buffer.NewPool(512 * 1024),
    }, nil
}
@@ -121,7 +111,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
    var refreshAuthBackoffTimer <-chan time.Time
    if s.config.UseReconnectToken {
-       if timer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
+       if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
            refreshAuthBackoffTimer = timer
        } else {
            logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
@@ -164,7 +154,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
            tunnelsWaiting = nil
        // Time to call Authenticate
        case <-refreshAuthBackoffTimer:
-           newTimer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate)
+           newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
            if err != nil {
                logger.Errorf("supervisor: Authentication failed: %s", err)
                // Permanent failure. Leave the `select` without setting the
@@ -237,7 +227,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
        return
    }
-   err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
+   err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
    // If the first tunnel disconnects, keep restarting it.
    edgeErrors := 0
    for s.unusedIPs() {
@@ -260,7 +250,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
            return
        }
    }
-   err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
+   err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
}
}
@@ -279,7 +269,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
    if err != nil {
        return
    }
-   err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
+   err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
}
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
@@ -305,90 +295,6 @@ func (s *Supervisor) unusedIPs() bool {
    return s.edgeIPs.AvailableAddrs() > s.config.HAConnections
}
-func (s *Supervisor) ReconnectToken() ([]byte, error) {
-   s.jwtLock.RLock()
-   defer s.jwtLock.RUnlock()
-   if s.jwt == nil {
-       return nil, errJWTUnset
-   }
-   return s.jwt, nil
-}
-func (s *Supervisor) SetReconnectToken(jwt []byte) {
-   s.jwtLock.Lock()
-   defer s.jwtLock.Unlock()
-   s.jwt = jwt
-}
-func (s *Supervisor) EventDigest() ([]byte, error) {
-   s.eventDigestLock.RLock()
-   defer s.eventDigestLock.RUnlock()
-   if s.eventDigest == nil {
-       return nil, errEventDigestUnset
-   }
-   return s.eventDigest, nil
-}
-func (s *Supervisor) SetEventDigest(eventDigest []byte) {
-   s.eventDigestLock.Lock()
-   defer s.eventDigestLock.Unlock()
-   s.eventDigest = eventDigest
-}
-func (s *Supervisor) ConnDigest(connID uint8) ([]byte, error) {
-   s.connDigestLock.RLock()
-   defer s.connDigestLock.RUnlock()
-   digest, ok := s.connDigest[connID]
-   if !ok {
-       return nil, fmt.Errorf("no connection digest for connection %v", connID)
-   }
-   return digest, nil
-}
-func (s *Supervisor) SetConnDigest(connID uint8, connDigest []byte) {
-   s.connDigestLock.Lock()
-   defer s.connDigestLock.Unlock()
-   s.connDigest[connID] = connDigest
-}
-func (s *Supervisor) refreshAuth(
-   ctx context.Context,
-   backoff *BackoffHandler,
-   authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
-) (retryTimer <-chan time.Time, err error) {
-   logger := s.config.Logger
-   authOutcome, err := authenticate(ctx, backoff.Retries())
-   if err != nil {
-       s.config.Metrics.authFail.WithLabelValues(err.Error()).Inc()
-       if duration, ok := backoff.GetBackoffDuration(ctx); ok {
-           logger.Debugf("refresh_auth: Retrying in %v: %s", duration, err)
-           return backoff.BackoffTimer(), nil
-       }
-       return nil, err
-   }
-   // clear backoff timer
-   backoff.SetGracePeriod()
-   switch outcome := authOutcome.(type) {
-   case tunnelpogs.AuthSuccess:
-       s.SetReconnectToken(outcome.JWT())
-       s.config.Metrics.authSuccess.Inc()
-       return timeAfter(outcome.RefreshAfter()), nil
-   case tunnelpogs.AuthUnknown:
-       duration := outcome.RefreshAfter()
-       s.config.Metrics.authFail.WithLabelValues(outcome.Error()).Inc()
-       logger.Debugf("refresh_auth: Retrying in %v: %s", duration, outcome)
-       return timeAfter(duration), nil
-   case tunnelpogs.AuthFail:
-       s.config.Metrics.authFail.WithLabelValues(outcome.Error()).Inc()
-       return nil, outcome
-   default:
-       err := fmt.Errorf("refresh_auth: Unexpected outcome type %T", authOutcome)
-       s.config.Metrics.authFail.WithLabelValues(err.Error()).Inc()
-       return nil, err
-   }
-}
func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
    arbitraryEdgeIP, err := s.edgeIPs.GetAddrForRPC()
    if err != nil {
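The supervisor keeps the same refresh-driven event loop; only the receiver changed. The contract of RefreshAuth is: return a timer channel to wait on (next scheduled refresh or backoff retry), or an error when refreshing should stop. A sketch of a driving loop under that contract (illustrative; the real Run also multiplexes tunnel errors and shutdown, and runRefreshLoop is a hypothetical name):

    // Sketch only: how RefreshAuth's returned channel drives re-authentication.
    func runRefreshLoop(
        ctx context.Context,
        rcm *reconnectCredentialManager,
        authenticate func(context.Context, int) (tunnelpogs.AuthOutcome, error),
    ) error {
        backoff := &BackoffHandler{MaxRetries: 3} // MaxRetries value borrowed from the tests
        timer, err := rcm.RefreshAuth(ctx, backoff, authenticate)
        for err == nil {
            select {
            case <-ctx.Done():
                return ctx.Err()
            case <-timer:
                // AuthSuccess schedules the next refresh at RefreshAfter();
                // a transient failure schedules an exponential backoff retry.
                timer, err = rcm.RefreshAuth(ctx, backoff, authenticate)
            }
        }
        return err // AuthFail, or retries exhausted: stop refreshing
    }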


@@ -90,16 +90,6 @@ type TunnelConfig struct {
    ReplaceExisting bool
}
-// ReconnectTunnelCredentialManager is invoked by functions in this file to
-// get/set parameters for ReconnectTunnel RPC calls.
-type ReconnectTunnelCredentialManager interface {
-   ReconnectToken() ([]byte, error)
-   EventDigest() ([]byte, error)
-   SetEventDigest(eventDigest []byte)
-   ConnDigest(connID uint8) ([]byte, error)
-   SetConnDigest(connID uint8, connDigest []byte)
-}
type dupConnRegisterTunnelError struct{}
var errDuplicationConnection = &dupConnRegisterTunnelError{}
@@ -209,7 +199,7 @@ func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSigna
}
func ServeTunnelLoop(ctx context.Context,
-   credentialManager ReconnectTunnelCredentialManager,
+   credentialManager *reconnectCredentialManager,
    config *TunnelConfig,
    addr *net.TCPAddr,
    connectionIndex uint8,
@@ -255,7 +245,7 @@ func ServeTunnelLoop(ctx context.Context,
func ServeTunnel(
    ctx context.Context,
-   credentialManager ReconnectTunnelCredentialManager,
+   credentialManager *reconnectCredentialManager,
    config *TunnelConfig,
    logger logger.Service,
    addr *net.TCPAddr,
@@ -310,24 +300,12 @@ func ServeTunnel(
    }
    if config.UseReconnectToken && connectedFuse.Value() {
-       token, tokenErr := credentialManager.ReconnectToken()
-       eventDigest, eventDigestErr := credentialManager.EventDigest()
-       // if we have both credentials, we can reconnect
-       if tokenErr == nil && eventDigestErr == nil {
-           var connDigest []byte
-           if digest, connDigestErr := credentialManager.ConnDigest(connectionIndex); connDigestErr == nil {
-               connDigest = digest
-           }
-           return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
+       err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
+       if err == nil {
+           return nil
        }
        // log errors and proceed to RegisterTunnel
-       if tokenErr != nil {
-           logger.Errorf("Couldn't get reconnect token: %s", tokenErr)
-       }
-       if eventDigestErr != nil {
-           logger.Errorf("Couldn't get event digest: %s", eventDigestErr)
-       }
+       logger.Errorf("Couldn't reconnect connection %d. Reregistering it instead. Error was: %v", connectionIndex, err)
    }
    return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID)
})
@@ -482,7 +460,7 @@ func UnregisterConnection(
func RegisterTunnel(
    ctx context.Context,
-   credentialManager ReconnectTunnelCredentialManager,
+   credentialManager *reconnectCredentialManager,
    muxer *h2mux.Muxer,
    config *TunnelConfig,
    logger logger.Service,
@@ -512,56 +490,17 @@ func RegisterTunnel(
        // RegisterTunnel RPC failure
        return processRegisterTunnelError(registrationErr, config.Metrics, register)
    }
-   credentialManager.SetEventDigest(registration.EventDigest)
+   credentialManager.SetEventDigest(connectionID, registration.EventDigest)
    return processRegistrationSuccess(config, logger, connectionID, registration, register, credentialManager)
}
-func ReconnectTunnel(
-   ctx context.Context,
-   token []byte,
-   eventDigest, connDigest []byte,
-   muxer *h2mux.Muxer,
-   config *TunnelConfig,
-   logger logger.Service,
-   connectionID uint8,
-   originLocalAddr string,
-   uuid uuid.UUID,
-   credentialManager ReconnectTunnelCredentialManager,
-) error {
-   config.TransportLogger.Debug("initiating RPC stream to reconnect")
-   tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
-   if err != nil {
-       // RPC stream open error
-       return newClientRegisterTunnelError(err, config.Metrics.rpcFail, reconnect)
-   }
-   defer tunnelServer.Close()
-   // Request server info without blocking tunnel registration; must use capnp library directly.
-   serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
-       return nil
-   })
-   LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
-   registration := tunnelServer.ReconnectTunnel(
-       ctx,
-       token,
-       eventDigest,
-       connDigest,
-       config.Hostname,
-       config.RegistrationOptions(connectionID, originLocalAddr, uuid),
-   )
-   if registrationErr := registration.DeserializeError(); registrationErr != nil {
-       // ReconnectTunnel RPC failure
-       return processRegisterTunnelError(registrationErr, config.Metrics, reconnect)
-   }
-   return processRegistrationSuccess(config, logger, connectionID, registration, reconnect, credentialManager)
-}
func processRegistrationSuccess(
    config *TunnelConfig,
    logger logger.Service,
    connectionID uint8,
    registration *tunnelpogs.TunnelRegistration,
    name registerRPCName,
-   credentialManager ReconnectTunnelCredentialManager,
+   credentialManager *reconnectCredentialManager,
) error {
    for _, logLine := range registration.LogLines {
        logger.Info(logLine)
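One behavioral consequence worth noting: ReconnectTunnel now performs all three credential lookups (token, event digest, connection digest) before opening the RPC stream, so a connection with no stored state falls back to RegisterTunnel without any network round-trip, and the old per-credential logging in ServeTunnel collapses into a single error message. A condensed view of the decision (this mirrors the ServeTunnel closure in the diff above; not standalone code):

    // Inside ServeTunnel's serve goroutine (condensed from the diff above):
    if config.UseReconnectToken && connectedFuse.Value() {
        // Fails fast (no RPC) if the JWT or this connection's digests are unset.
        err := ReconnectTunnel(serveCtx, handler.muxer, config, logger,
            connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
        if err == nil {
            return nil // reconnected using this connection's own digests
        }
        logger.Errorf("Couldn't reconnect connection %d. Reregistering it instead. Error was: %v", connectionIndex, err)
    }
    // No reconnect credentials, or the reconnect failed: register from scratch.
    return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config,
        logger, connectionIndex, originLocalAddr, cloudflaredUUID)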