139 lines
3.7 KiB
Go
139 lines
3.7 KiB
Go
package supervisor
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
|
|
"github.com/cloudflare/cloudflared/retry"
|
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
)
|
|
|
|
// errJWTUnset is returned by ReconnectToken when no reconnect token has been
// stored yet.
var errJWTUnset = errors.New("JWT unset")
|
|
|
|
// reconnectTunnelCredentialManager is invoked by functions in tunnel.go to
|
|
// get/set parameters for ReconnectTunnel RPC calls.
|
|
type reconnectCredentialManager struct {
|
|
mu sync.RWMutex
|
|
jwt []byte
|
|
eventDigest map[uint8][]byte
|
|
connDigest map[uint8][]byte
|
|
authSuccess prometheus.Counter
|
|
authFail *prometheus.CounterVec
|
|
}
|
|
|
|
func newReconnectCredentialManager(namespace, subsystem string, haConnections int) *reconnectCredentialManager {
|
|
authSuccess := prometheus.NewCounter(
|
|
prometheus.CounterOpts{
|
|
Namespace: namespace,
|
|
Subsystem: subsystem,
|
|
Name: "tunnel_authenticate_success",
|
|
Help: "Count of successful tunnel authenticate",
|
|
},
|
|
)
|
|
authFail := prometheus.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Namespace: namespace,
|
|
Subsystem: subsystem,
|
|
Name: "tunnel_authenticate_fail",
|
|
Help: "Count of tunnel authenticate errors by type",
|
|
},
|
|
[]string{"error"},
|
|
)
|
|
prometheus.MustRegister(authSuccess, authFail)
|
|
return &reconnectCredentialManager{
|
|
eventDigest: make(map[uint8][]byte, haConnections),
|
|
connDigest: make(map[uint8][]byte, haConnections),
|
|
authSuccess: authSuccess,
|
|
authFail: authFail,
|
|
}
|
|
}
|
|
|
|
func (cm *reconnectCredentialManager) ReconnectToken() ([]byte, error) {
|
|
cm.mu.RLock()
|
|
defer cm.mu.RUnlock()
|
|
if cm.jwt == nil {
|
|
return nil, errJWTUnset
|
|
}
|
|
return cm.jwt, nil
|
|
}
|
|
|
|
func (cm *reconnectCredentialManager) SetReconnectToken(jwt []byte) {
|
|
cm.mu.Lock()
|
|
defer cm.mu.Unlock()
|
|
cm.jwt = jwt
|
|
}
|
|
|
|
func (cm *reconnectCredentialManager) EventDigest(connID uint8) ([]byte, error) {
|
|
cm.mu.RLock()
|
|
defer cm.mu.RUnlock()
|
|
digest, ok := cm.eventDigest[connID]
|
|
if !ok {
|
|
return nil, fmt.Errorf("no event digest for connection %v", connID)
|
|
}
|
|
return digest, nil
|
|
}
|
|
|
|
func (cm *reconnectCredentialManager) SetEventDigest(connID uint8, digest []byte) {
|
|
cm.mu.Lock()
|
|
defer cm.mu.Unlock()
|
|
cm.eventDigest[connID] = digest
|
|
}
|
|
|
|
func (cm *reconnectCredentialManager) ConnDigest(connID uint8) ([]byte, error) {
|
|
cm.mu.RLock()
|
|
defer cm.mu.RUnlock()
|
|
digest, ok := cm.connDigest[connID]
|
|
if !ok {
|
|
return nil, fmt.Errorf("no connection digest for connection %v", connID)
|
|
}
|
|
return digest, nil
|
|
}
|
|
|
|
func (cm *reconnectCredentialManager) SetConnDigest(connID uint8, digest []byte) {
|
|
cm.mu.Lock()
|
|
defer cm.mu.Unlock()
|
|
cm.connDigest[connID] = digest
|
|
}
|
|
|
|
func (cm *reconnectCredentialManager) RefreshAuth(
|
|
ctx context.Context,
|
|
backoff *retry.BackoffHandler,
|
|
authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
|
|
) (retryTimer <-chan time.Time, err error) {
|
|
authOutcome, err := authenticate(ctx, backoff.Retries())
|
|
if err != nil {
|
|
cm.authFail.WithLabelValues(err.Error()).Inc()
|
|
if _, ok := backoff.GetMaxBackoffDuration(ctx); ok {
|
|
return backoff.BackoffTimer(), nil
|
|
}
|
|
return nil, err
|
|
}
|
|
// clear backoff timer
|
|
backoff.SetGracePeriod()
|
|
|
|
switch outcome := authOutcome.(type) {
|
|
case tunnelpogs.AuthSuccess:
|
|
cm.SetReconnectToken(outcome.JWT())
|
|
cm.authSuccess.Inc()
|
|
return retry.Clock.After(outcome.RefreshAfter()), nil
|
|
case tunnelpogs.AuthUnknown:
|
|
duration := outcome.RefreshAfter()
|
|
cm.authFail.WithLabelValues(outcome.Error()).Inc()
|
|
return retry.Clock.After(duration), nil
|
|
case tunnelpogs.AuthFail:
|
|
cm.authFail.WithLabelValues(outcome.Error()).Inc()
|
|
return nil, outcome
|
|
default:
|
|
err := fmt.Errorf("refresh_auth: Unexpected outcome type %T", authOutcome)
|
|
cm.authFail.WithLabelValues(err.Error()).Inc()
|
|
return nil, err
|
|
}
|
|
}
|