TUN-3268: Each connection has its own event digest to reconnect
This commit is contained in:
parent 9323844ea7
commit 8eeb452cce
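
Before this change, the Supervisor kept a single shared event digest (its connection digests were already a per-connection map), so the diff below moves the JWT, event digests, and connection digests into a new reconnectCredentialManager whose digest maps are keyed by HA connection ID. A minimal standalone sketch of that idea (simplified names, not the actual cloudflared types):

    package main

    import (
        "fmt"
        "sync"
    )

    // digestStore mirrors the shape of the new manager: one digest per
    // connection, keyed by connection ID, behind a single RWMutex.
    type digestStore struct {
        mu          sync.RWMutex
        eventDigest map[uint8][]byte
    }

    func (s *digestStore) set(connID uint8, digest []byte) {
        s.mu.Lock()
        defer s.mu.Unlock()
        s.eventDigest[connID] = digest
    }

    func (s *digestStore) get(connID uint8) ([]byte, bool) {
        s.mu.RLock()
        defer s.mu.RUnlock()
        d, ok := s.eventDigest[connID]
        return d, ok
    }

    func main() {
        store := &digestStore{eventDigest: make(map[uint8][]byte)}
        store.set(0, []byte("digest for connection 0"))
        store.set(1, []byte("digest for connection 1"))
        d, _ := store.get(1)
        fmt.Printf("connection 1 reconnects with %q\n", d) // its own digest, not connection 0's
    }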
@@ -60,8 +60,6 @@ type TunnelMetrics struct {
 	regSuccess   *prometheus.CounterVec
 	regFail      *prometheus.CounterVec
-	authSuccess  prometheus.Counter
-	authFail     *prometheus.CounterVec
 	rpcFail      *prometheus.CounterVec
 
 	muxerMetrics *muxerMetrics
@@ -456,27 +454,6 @@ func NewTunnelMetrics() *TunnelMetrics {
 	)
 	prometheus.MustRegister(registerSuccess)
-
-	authSuccess := prometheus.NewCounter(
-		prometheus.CounterOpts{
-			Namespace: metricsNamespace,
-			Subsystem: tunnelSubsystem,
-			Name:      "tunnel_authenticate_success",
-			Help:      "Count of successful tunnel authenticate",
-		},
-	)
-	prometheus.MustRegister(authSuccess)
-
-	authFail := prometheus.NewCounterVec(
-		prometheus.CounterOpts{
-			Namespace: metricsNamespace,
-			Subsystem: tunnelSubsystem,
-			Name:      "tunnel_authenticate_fail",
-			Help:      "Count of tunnel authenticate errors by type",
-		},
-		[]string{"error"},
-	)
-	prometheus.MustRegister(authFail)
-
 	return &TunnelMetrics{
 		haConnections: haConnections,
 		activeStreams: activeStreams,
@@ -497,8 +474,6 @@ func NewTunnelMetrics() *TunnelMetrics {
 		regFail:             registerFail,
 		rpcFail:             rpcFail,
 		userHostnamesCounts: userHostnamesCounts,
-		authSuccess:         authSuccess,
-		authFail:            authFail,
 	}
 }
@@ -0,0 +1,191 @@
+package origin
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/cloudflare/cloudflared/connection"
+	"github.com/cloudflare/cloudflared/h2mux"
+	"github.com/cloudflare/cloudflared/logger"
+	"github.com/cloudflare/cloudflared/tunnelrpc"
+	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
+	"github.com/google/uuid"
+	"github.com/prometheus/client_golang/prometheus"
+)
+
+var (
+	errJWTUnset = errors.New("JWT unset")
+)
+
+// reconnectCredentialManager is invoked by functions in tunnel.go to
+// get/set parameters for ReconnectTunnel RPC calls.
+type reconnectCredentialManager struct {
+	mu          sync.RWMutex
+	jwt         []byte
+	eventDigest map[uint8][]byte
+	connDigest  map[uint8][]byte
+	authSuccess prometheus.Counter
+	authFail    *prometheus.CounterVec
+}
+
+func newReconnectCredentialManager(namespace, subsystem string, haConnections int) *reconnectCredentialManager {
+	authSuccess := prometheus.NewCounter(
+		prometheus.CounterOpts{
+			Namespace: namespace,
+			Subsystem: subsystem,
+			Name:      "tunnel_authenticate_success",
+			Help:      "Count of successful tunnel authenticate",
+		},
+	)
+	authFail := prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Namespace: namespace,
+			Subsystem: subsystem,
+			Name:      "tunnel_authenticate_fail",
+			Help:      "Count of tunnel authenticate errors by type",
+		},
+		[]string{"error"},
+	)
+	prometheus.MustRegister(authSuccess, authFail)
+	return &reconnectCredentialManager{
+		eventDigest: make(map[uint8][]byte, haConnections),
+		connDigest:  make(map[uint8][]byte, haConnections),
+		authSuccess: authSuccess,
+		authFail:    authFail,
+	}
+}
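
A note on the constructor: prometheus.MustRegister panics if a collector with the same fully-qualified name is registered twice, which is why namespace and subsystem are parameters rather than hard-coded. Production code passes the package constants while each test passes t.Name(), as the hunks below show; sketched side by side (assuming this file's package context):

    // in NewSupervisor: one manager shared by all HA connections
    manager := newReconnectCredentialManager(metricsNamespace, tunnelSubsystem, config.HAConnections)

    // in a test: a unique namespace per test avoids duplicate registration
    rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)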
+
+func (cm *reconnectCredentialManager) ReconnectToken() ([]byte, error) {
+	cm.mu.RLock()
+	defer cm.mu.RUnlock()
+	if cm.jwt == nil {
+		return nil, errJWTUnset
+	}
+	return cm.jwt, nil
+}
+
+func (cm *reconnectCredentialManager) SetReconnectToken(jwt []byte) {
+	cm.mu.Lock()
+	defer cm.mu.Unlock()
+	cm.jwt = jwt
+}
+
+func (cm *reconnectCredentialManager) EventDigest(connID uint8) ([]byte, error) {
+	cm.mu.RLock()
+	defer cm.mu.RUnlock()
+	digest, ok := cm.eventDigest[connID]
+	if !ok {
+		return nil, fmt.Errorf("no event digest for connection %v", connID)
+	}
+	return digest, nil
+}
+
+func (cm *reconnectCredentialManager) SetEventDigest(connID uint8, digest []byte) {
+	cm.mu.Lock()
+	defer cm.mu.Unlock()
+	cm.eventDigest[connID] = digest
+}
+
+func (cm *reconnectCredentialManager) ConnDigest(connID uint8) ([]byte, error) {
+	cm.mu.RLock()
+	defer cm.mu.RUnlock()
+	digest, ok := cm.connDigest[connID]
+	if !ok {
+		return nil, fmt.Errorf("no connection digest for connection %v", connID)
+	}
+	return digest, nil
+}
+
+func (cm *reconnectCredentialManager) SetConnDigest(connID uint8, digest []byte) {
+	cm.mu.Lock()
+	defer cm.mu.Unlock()
+	cm.connDigest[connID] = digest
+}
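
How these accessors fit together across a connection's life, sketched in this package's context (error handling elided; exactly where the connection digest is first stored is outside this diff excerpt):

    // RegisterTunnel, on success, records the event digest under its connection ID:
    credentialManager.SetEventDigest(connectionID, registration.EventDigest)

    // ReconnectTunnel later reads back everything it needs for that same ID:
    token, _ := credentialManager.ReconnectToken()
    eventDigest, _ := credentialManager.EventDigest(connectionID)
    connDigest, _ := credentialManager.ConnDigest(connectionID)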
+
+func (cm *reconnectCredentialManager) RefreshAuth(
+	ctx context.Context,
+	backoff *BackoffHandler,
+	authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
+) (retryTimer <-chan time.Time, err error) {
+	authOutcome, err := authenticate(ctx, backoff.Retries())
+	if err != nil {
+		cm.authFail.WithLabelValues(err.Error()).Inc()
+		if _, ok := backoff.GetBackoffDuration(ctx); ok {
+			return backoff.BackoffTimer(), nil
+		}
+		return nil, err
+	}
+	// clear backoff timer
+	backoff.SetGracePeriod()
+
+	switch outcome := authOutcome.(type) {
+	case tunnelpogs.AuthSuccess:
+		cm.SetReconnectToken(outcome.JWT())
+		cm.authSuccess.Inc()
+		return timeAfter(outcome.RefreshAfter()), nil
+	case tunnelpogs.AuthUnknown:
+		duration := outcome.RefreshAfter()
+		cm.authFail.WithLabelValues(outcome.Error()).Inc()
+		return timeAfter(duration), nil
+	case tunnelpogs.AuthFail:
+		cm.authFail.WithLabelValues(outcome.Error()).Inc()
+		return nil, outcome
+	default:
+		err := fmt.Errorf("refresh_auth: Unexpected outcome type %T", authOutcome)
+		cm.authFail.WithLabelValues(err.Error()).Inc()
+		return nil, err
+	}
+}
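
RefreshAuth returns a channel rather than blocking: the caller re-arms the timer each time it fires, and a nil channel with a non-nil error means a permanent failure. A simplified sketch of the polling pattern, modeled on Supervisor.Run further down in this diff (in this package's context):

    timer, err := cm.RefreshAuth(ctx, backoff, authenticate)
    if err != nil {
        return err // AuthFail or exhausted backoff: stop refreshing
    }
    for {
        select {
        case <-timer:
            newTimer, err := cm.RefreshAuth(ctx, backoff, authenticate)
            if err != nil {
                return err
            }
            timer = newTimer
        case <-ctx.Done():
            return ctx.Err()
        }
    }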
+
+func ReconnectTunnel(
+	ctx context.Context,
+	muxer *h2mux.Muxer,
+	config *TunnelConfig,
+	logger logger.Service,
+	connectionID uint8,
+	originLocalAddr string,
+	uuid uuid.UUID,
+	credentialManager *reconnectCredentialManager,
+) error {
+	token, err := credentialManager.ReconnectToken()
+	if err != nil {
+		return err
+	}
+	eventDigest, err := credentialManager.EventDigest(connectionID)
+	if err != nil {
+		return err
+	}
+	connDigest, err := credentialManager.ConnDigest(connectionID)
+	if err != nil {
+		return err
+	}
+
+	config.TransportLogger.Debug("initiating RPC stream to reconnect")
+	tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
+	if err != nil {
+		// RPC stream open error
+		return newClientRegisterTunnelError(err, config.Metrics.rpcFail, reconnect)
+	}
+	defer tunnelServer.Close()
+	// Request server info without blocking tunnel registration; must use capnp library directly.
+	serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
+		return nil
+	})
+	LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
+	registration := tunnelServer.ReconnectTunnel(
+		ctx,
+		token,
+		eventDigest,
+		connDigest,
+		config.Hostname,
+		config.RegistrationOptions(connectionID, originLocalAddr, uuid),
+	)
+	if registrationErr := registration.DeserializeError(); registrationErr != nil {
+		// ReconnectTunnel RPC failure
+		return processRegisterTunnelError(registrationErr, config.Metrics, reconnect)
+	}
+	return processRegistrationSuccess(config, logger, connectionID, registration, reconnect, credentialManager)
+}
@@ -7,51 +7,19 @@ import (
 	"testing"
 	"time"
 
-	"github.com/google/uuid"
-	"github.com/prometheus/client_golang/prometheus"
 	"github.com/stretchr/testify/assert"
 
-	"github.com/cloudflare/cloudflared/logger"
 	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 )
 
-func testConfig(logger logger.Service) *TunnelConfig {
-	metrics := TunnelMetrics{}
-
-	metrics.authSuccess = prometheus.NewCounter(
-		prometheus.CounterOpts{
-			Namespace: metricsNamespace,
-			Subsystem: tunnelSubsystem,
-			Name:      "tunnel_authenticate_success",
-			Help:      "Count of successful tunnel authenticate",
-		},
-	)
-
-	metrics.authFail = prometheus.NewCounterVec(
-		prometheus.CounterOpts{
-			Namespace: metricsNamespace,
-			Subsystem: tunnelSubsystem,
-			Name:      "tunnel_authenticate_fail",
-			Help:      "Count of tunnel authenticate errors by type",
-		},
-		[]string{"error"},
-	)
-	return &TunnelConfig{Logger: logger, Metrics: &metrics}
-}
-
 func TestRefreshAuthBackoff(t *testing.T) {
-	logger := logger.NewOutputWriter(logger.NewMockWriteManager())
+	rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
 
 	var wait time.Duration
 	timeAfter = func(d time.Duration) <-chan time.Time {
 		wait = d
 		return time.After(d)
 	}
 
-	s, err := NewSupervisor(testConfig(logger), uuid.New())
-	if !assert.NoError(t, err) {
-		t.FailNow()
-	}
 	backoff := &BackoffHandler{MaxRetries: 3}
 	auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
 		return nil, fmt.Errorf("authentication failure")
@@ -59,17 +27,17 @@ func TestRefreshAuthBackoff(t *testing.T) {
 
 	// authentication failures should consume the backoff
 	for i := uint(0); i < backoff.MaxRetries; i++ {
-		retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
+		retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
 		assert.NoError(t, err)
 		assert.NotNil(t, retryChan)
 		assert.Equal(t, (1<<i)*time.Second, wait)
 	}
-	retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
+	retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
 	assert.Error(t, err)
 	assert.Nil(t, retryChan)
 
 	// now we actually make contact with the remote server
-	_, _ = s.refreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
+	_, _ = rcm.RefreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
 		return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
 	})
@@ -84,7 +52,7 @@ func TestRefreshAuthBackoff(t *testing.T) {
 }
 
 func TestRefreshAuthSuccess(t *testing.T) {
-	logger := logger.NewOutputWriter(logger.NewMockWriteManager())
+	rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
 
 	var wait time.Duration
 	timeAfter = func(d time.Duration) <-chan time.Time {
@@ -92,27 +60,23 @@ func TestRefreshAuthSuccess(t *testing.T) {
 		return time.After(d)
 	}
 
-	s, err := NewSupervisor(testConfig(logger), uuid.New())
-	if !assert.NoError(t, err) {
-		t.FailNow()
-	}
 	backoff := &BackoffHandler{MaxRetries: 3}
 	auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
 		return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
 	}
 
-	retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
+	retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
 	assert.NoError(t, err)
 	assert.NotNil(t, retryChan)
 	assert.Equal(t, 19*time.Hour, wait)
 
-	token, err := s.ReconnectToken()
+	token, err := rcm.ReconnectToken()
 	assert.NoError(t, err)
 	assert.Equal(t, []byte("jwt"), token)
 }
 
 func TestRefreshAuthUnknown(t *testing.T) {
-	logger := logger.NewOutputWriter(logger.NewMockWriteManager())
+	rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
 
 	var wait time.Duration
 	timeAfter = func(d time.Duration) <-chan time.Time {
@@ -120,42 +84,34 @@ func TestRefreshAuthUnknown(t *testing.T) {
 		return time.After(d)
 	}
 
-	s, err := NewSupervisor(testConfig(logger), uuid.New())
-	if !assert.NoError(t, err) {
-		t.FailNow()
-	}
 	backoff := &BackoffHandler{MaxRetries: 3}
 	auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
 		return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
 	}
 
-	retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
+	retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
 	assert.NoError(t, err)
 	assert.NotNil(t, retryChan)
 	assert.Equal(t, 19*time.Hour, wait)
 
-	token, err := s.ReconnectToken()
+	token, err := rcm.ReconnectToken()
 	assert.Equal(t, errJWTUnset, err)
 	assert.Nil(t, token)
 }
 
 func TestRefreshAuthFail(t *testing.T) {
-	logger := logger.NewOutputWriter(logger.NewMockWriteManager())
+	rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
 
-	s, err := NewSupervisor(testConfig(logger), uuid.New())
-	if !assert.NoError(t, err) {
-		t.FailNow()
-	}
 	backoff := &BackoffHandler{MaxRetries: 3}
 	auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
 		return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
 	}
 
-	retryChan, err := s.refreshAuth(context.Background(), backoff, auth)
+	retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
 	assert.Error(t, err)
 	assert.Nil(t, retryChan)
 
-	token, err := s.ReconnectToken()
+	token, err := rcm.ReconnectToken()
 	assert.Equal(t, errJWTUnset, err)
 	assert.Nil(t, token)
 }
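
These tests can observe RefreshAuth's retry schedule because timeAfter is a package-level variable rather than a direct time.After call, so reassigning it captures the requested duration. A sketch of that seam (the var declaration is implied by these tests, not shown in this diff):

    var timeAfter = time.After

    // in a test:
    var wait time.Duration
    timeAfter = func(d time.Duration) <-chan time.Time {
        wait = d // record the delay the code under test asked for
        return time.After(d)
    }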
@@ -3,9 +3,7 @@ package origin
 import (
 	"context"
 	"errors"
-	"fmt"
 	"net"
-	"sync"
 	"time"
 
 	"github.com/google/uuid"
@@ -37,7 +35,6 @@ const (
 )
 
 var (
-	errJWTUnset         = errors.New("JWT unset")
 	errEventDigestUnset = errors.New("event digest unset")
 )
 
@@ -58,14 +55,7 @@ type Supervisor struct {
 
 	logger logger.Service
 
-	jwtLock sync.RWMutex
-	jwt     []byte
-
-	eventDigestLock sync.RWMutex
-	eventDigest     []byte
-
-	connDigestLock sync.RWMutex
-	connDigest     map[uint8][]byte
+	reconnectCredentialManager *reconnectCredentialManager
 
 	bufferPool *buffer.Pool
 }
@@ -101,7 +91,7 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
 		tunnelErrors:      make(chan tunnelError),
 		tunnelsConnecting: map[int]chan struct{}{},
 		logger:            config.Logger,
-		connDigest:        make(map[uint8][]byte),
+		reconnectCredentialManager: newReconnectCredentialManager(metricsNamespace, tunnelSubsystem, config.HAConnections),
 		bufferPool:        buffer.NewPool(512 * 1024),
 	}, nil
 }
@@ -121,7 +111,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
 	var refreshAuthBackoffTimer <-chan time.Time
 
 	if s.config.UseReconnectToken {
-		if timer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
+		if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
 			refreshAuthBackoffTimer = timer
 		} else {
 			logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
@@ -164,7 +154,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
 			tunnelsWaiting = nil
 		// Time to call Authenticate
 		case <-refreshAuthBackoffTimer:
-			newTimer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate)
+			newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
 			if err != nil {
 				logger.Errorf("supervisor: Authentication failed: %s", err)
 				// Permanent failure. Leave the `select` without setting the
@@ -237,7 +227,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
 		return
 	}
 
-	err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
+	err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
 	// If the first tunnel disconnects, keep restarting it.
 	edgeErrors := 0
 	for s.unusedIPs() {
@@ -260,7 +250,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
 				return
 			}
 		}
-		err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
+		err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
 	}
 }
@@ -279,7 +269,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
 	if err != nil {
 		return
 	}
-	err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
+	err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
 }
 
 func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
@@ -305,90 +295,6 @@ func (s *Supervisor) unusedIPs() bool {
 	return s.edgeIPs.AvailableAddrs() > s.config.HAConnections
 }
 
-func (s *Supervisor) ReconnectToken() ([]byte, error) {
-	s.jwtLock.RLock()
-	defer s.jwtLock.RUnlock()
-	if s.jwt == nil {
-		return nil, errJWTUnset
-	}
-	return s.jwt, nil
-}
-
-func (s *Supervisor) SetReconnectToken(jwt []byte) {
-	s.jwtLock.Lock()
-	defer s.jwtLock.Unlock()
-	s.jwt = jwt
-}
-
-func (s *Supervisor) EventDigest() ([]byte, error) {
-	s.eventDigestLock.RLock()
-	defer s.eventDigestLock.RUnlock()
-	if s.eventDigest == nil {
-		return nil, errEventDigestUnset
-	}
-	return s.eventDigest, nil
-}
-
-func (s *Supervisor) SetEventDigest(eventDigest []byte) {
-	s.eventDigestLock.Lock()
-	defer s.eventDigestLock.Unlock()
-	s.eventDigest = eventDigest
-}
-
-func (s *Supervisor) ConnDigest(connID uint8) ([]byte, error) {
-	s.connDigestLock.RLock()
-	defer s.connDigestLock.RUnlock()
-	digest, ok := s.connDigest[connID]
-	if !ok {
-		return nil, fmt.Errorf("no connection digest for connection %v", connID)
-	}
-	return digest, nil
-}
-
-func (s *Supervisor) SetConnDigest(connID uint8, connDigest []byte) {
-	s.connDigestLock.Lock()
-	defer s.connDigestLock.Unlock()
-	s.connDigest[connID] = connDigest
-}
-
-func (s *Supervisor) refreshAuth(
-	ctx context.Context,
-	backoff *BackoffHandler,
-	authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
-) (retryTimer <-chan time.Time, err error) {
-	logger := s.config.Logger
-	authOutcome, err := authenticate(ctx, backoff.Retries())
-	if err != nil {
-		s.config.Metrics.authFail.WithLabelValues(err.Error()).Inc()
-		if duration, ok := backoff.GetBackoffDuration(ctx); ok {
-			logger.Debugf("refresh_auth: Retrying in %v: %s", duration, err)
-			return backoff.BackoffTimer(), nil
-		}
-		return nil, err
-	}
-	// clear backoff timer
-	backoff.SetGracePeriod()
-
-	switch outcome := authOutcome.(type) {
-	case tunnelpogs.AuthSuccess:
-		s.SetReconnectToken(outcome.JWT())
-		s.config.Metrics.authSuccess.Inc()
-		return timeAfter(outcome.RefreshAfter()), nil
-	case tunnelpogs.AuthUnknown:
-		duration := outcome.RefreshAfter()
-		s.config.Metrics.authFail.WithLabelValues(outcome.Error()).Inc()
-		logger.Debugf("refresh_auth: Retrying in %v: %s", duration, outcome)
-		return timeAfter(duration), nil
-	case tunnelpogs.AuthFail:
-		s.config.Metrics.authFail.WithLabelValues(outcome.Error()).Inc()
-		return nil, outcome
-	default:
-		err := fmt.Errorf("refresh_auth: Unexpected outcome type %T", authOutcome)
-		s.config.Metrics.authFail.WithLabelValues(err.Error()).Inc()
-		return nil, err
-	}
-}
-
 func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
 	arbitraryEdgeIP, err := s.edgeIPs.GetAddrForRPC()
 	if err != nil {
@@ -90,16 +90,6 @@ type TunnelConfig struct {
 	ReplaceExisting bool
 }
 
-// ReconnectTunnelCredentialManager is invoked by functions in this file to
-// get/set parameters for ReconnectTunnel RPC calls.
-type ReconnectTunnelCredentialManager interface {
-	ReconnectToken() ([]byte, error)
-	EventDigest() ([]byte, error)
-	SetEventDigest(eventDigest []byte)
-	ConnDigest(connID uint8) ([]byte, error)
-	SetConnDigest(connID uint8, connDigest []byte)
-}
-
 type dupConnRegisterTunnelError struct{}
 
 var errDuplicationConnection = &dupConnRegisterTunnelError{}
@@ -209,7 +199,7 @@ func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSigna
 }
 
 func ServeTunnelLoop(ctx context.Context,
-	credentialManager ReconnectTunnelCredentialManager,
+	credentialManager *reconnectCredentialManager,
 	config *TunnelConfig,
 	addr *net.TCPAddr,
 	connectionIndex uint8,
@@ -255,7 +245,7 @@ func ServeTunnelLoop(ctx context.Context,
 
 func ServeTunnel(
 	ctx context.Context,
-	credentialManager ReconnectTunnelCredentialManager,
+	credentialManager *reconnectCredentialManager,
 	config *TunnelConfig,
 	logger logger.Service,
 	addr *net.TCPAddr,
@@ -310,24 +300,12 @@ func ServeTunnel(
 	}
 
 	if config.UseReconnectToken && connectedFuse.Value() {
-		token, tokenErr := credentialManager.ReconnectToken()
-		eventDigest, eventDigestErr := credentialManager.EventDigest()
-		// if we have both credentials, we can reconnect
-		if tokenErr == nil && eventDigestErr == nil {
-			var connDigest []byte
-			if digest, connDigestErr := credentialManager.ConnDigest(connectionIndex); connDigestErr == nil {
-				connDigest = digest
-			}
-
-			return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
-		}
+		err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
+		if err == nil {
+			return nil
+		}
 		// log errors and proceed to RegisterTunnel
-		if tokenErr != nil {
-			logger.Errorf("Couldn't get reconnect token: %s", tokenErr)
-		}
-		if eventDigestErr != nil {
-			logger.Errorf("Couldn't get event digest: %s", eventDigestErr)
-		}
+		logger.Errorf("Couldn't reconnect connection %d. Reregistering it instead. Error was: %v", connectionIndex, err)
 	}
 	return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID)
 })
@@ -482,7 +460,7 @@ func UnregisterConnection(
 
 func RegisterTunnel(
 	ctx context.Context,
-	credentialManager ReconnectTunnelCredentialManager,
+	credentialManager *reconnectCredentialManager,
 	muxer *h2mux.Muxer,
 	config *TunnelConfig,
 	logger logger.Service,
@@ -512,56 +490,17 @@ func RegisterTunnel(
 		// RegisterTunnel RPC failure
 		return processRegisterTunnelError(registrationErr, config.Metrics, register)
 	}
-	credentialManager.SetEventDigest(registration.EventDigest)
+	credentialManager.SetEventDigest(connectionID, registration.EventDigest)
 	return processRegistrationSuccess(config, logger, connectionID, registration, register, credentialManager)
 }
 
-func ReconnectTunnel(
-	ctx context.Context,
-	token []byte,
-	eventDigest, connDigest []byte,
-	muxer *h2mux.Muxer,
-	config *TunnelConfig,
-	logger logger.Service,
-	connectionID uint8,
-	originLocalAddr string,
-	uuid uuid.UUID,
-	credentialManager ReconnectTunnelCredentialManager,
-) error {
-	config.TransportLogger.Debug("initiating RPC stream to reconnect")
-	tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
-	if err != nil {
-		// RPC stream open error
-		return newClientRegisterTunnelError(err, config.Metrics.rpcFail, reconnect)
-	}
-	defer tunnelServer.Close()
-	// Request server info without blocking tunnel registration; must use capnp library directly.
-	serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
-		return nil
-	})
-	LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger)
-	registration := tunnelServer.ReconnectTunnel(
-		ctx,
-		token,
-		eventDigest,
-		connDigest,
-		config.Hostname,
-		config.RegistrationOptions(connectionID, originLocalAddr, uuid),
-	)
-	if registrationErr := registration.DeserializeError(); registrationErr != nil {
-		// ReconnectTunnel RPC failure
-		return processRegisterTunnelError(registrationErr, config.Metrics, reconnect)
-	}
-	return processRegistrationSuccess(config, logger, connectionID, registration, reconnect, credentialManager)
-}
-
 func processRegistrationSuccess(
 	config *TunnelConfig,
 	logger logger.Service,
 	connectionID uint8,
 	registration *tunnelpogs.TunnelRegistration,
 	name registerRPCName,
-	credentialManager ReconnectTunnelCredentialManager,
+	credentialManager *reconnectCredentialManager,
 ) error {
 	for _, logLine := range registration.LogLines {
 		logger.Info(logLine)