TUN-7065: Remove classic tunnel creation

This commit is contained in:
Devin Carr 2023-02-06 09:13:05 -08:00
parent bd046677e5
commit ae46af9236
9 changed files with 80 additions and 270 deletions

View File

@ -155,8 +155,6 @@ func action(graceShutdownC chan struct{}) cli.ActionFunc {
if isEmptyInvocation(c) { if isEmptyInvocation(c) {
return handleServiceMode(c, graceShutdownC) return handleServiceMode(c, graceShutdownC)
} }
tags := make(map[string]string)
tags["hostname"] = c.String("hostname")
func() { func() {
defer sentry.Recover() defer sentry.Recover()
err = tunnel.TunnelCommand(c) err = tunnel.TunnelCommand(c)

View File

@ -36,6 +36,7 @@ import (
"github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns" "github.com/cloudflare/cloudflared/tunneldns"
"github.com/cloudflare/cloudflared/validation"
) )
const ( const (
@ -100,6 +101,7 @@ var (
routeFailMsg = fmt.Sprintf("failed to provision routing, please create it manually via Cloudflare dashboard or UI; "+ routeFailMsg = fmt.Sprintf("failed to provision routing, please create it manually via Cloudflare dashboard or UI; "+
"most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+ "most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+
"any existing DNS records for this hostname.", overwriteDNSFlag) "any existing DNS records for this hostname.", overwriteDNSFlag)
deprecatedClassicTunnelErr = fmt.Errorf("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)")
) )
func Flags() []cli.Flag { func Flags() []cli.Flag {
@ -176,23 +178,40 @@ func TunnelCommand(c *cli.Context) error {
return err return err
} }
if name := c.String("name"); name != "" { // Start a named tunnel // Run a adhoc named tunnel
// Allows for the creation, routing (optional), and startup of a tunnel in one command
// --name required
// --url or --hello-world required
// --hostname optional
if name := c.String("name"); name != "" {
hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil {
return errors.Wrap(err, "Invalid hostname provided")
}
url := c.String("url")
if url == hostname && url != "" && hostname != "" {
return fmt.Errorf("hostname and url shouldn't match. See --help for more information")
}
return runAdhocNamedTunnel(sc, name, c.String(CredFileFlag)) return runAdhocNamedTunnel(sc, name, c.String(CredFileFlag))
} }
// Unauthenticated named tunnel on <random>.<quick-tunnels-service>.com // Run a quick tunnel
// A unauthenticated named tunnel hosted on <random>.<quick-tunnels-service>.com
// We don't support running proxy-dns and a quick tunnel at the same time as the same process
shouldRunQuickTunnel := c.IsSet("url") || c.IsSet("hello-world") shouldRunQuickTunnel := c.IsSet("url") || c.IsSet("hello-world")
if !dnsProxyStandAlone(c, nil) && c.String("hostname") == "" && c.String("quick-service") != "" && shouldRunQuickTunnel { if !c.IsSet("proxy-dns") && c.String("quick-service") != "" && shouldRunQuickTunnel {
return RunQuickTunnel(sc) return RunQuickTunnel(sc)
} }
// If user provides a config, check to see if they meant to use `tunnel run` instead
if ref := config.GetConfiguration().TunnelID; ref != "" { if ref := config.GetConfiguration().TunnelID; ref != "" {
return fmt.Errorf("Use `cloudflared tunnel run` to start tunnel %s", ref) return fmt.Errorf("Use `cloudflared tunnel run` to start tunnel %s", ref)
} }
// Start a classic tunnel if hostname is specified. // Classic tunnel usage is no longer supported
if c.String("hostname") != "" { if c.String("hostname") != "" {
return runClassicTunnel(sc) return deprecatedClassicTunnelErr
} }
if c.IsSet("proxy-dns") { if c.IsSet("proxy-dns") {
@ -237,11 +256,6 @@ func runAdhocNamedTunnel(sc *subcommandContext, name, credentialsOutputPath stri
return nil return nil
} }
// runClassicTunnel creates a "classic" non-named tunnel
func runClassicTunnel(sc *subcommandContext) error {
return StartServer(sc.c, buildInfo, nil, sc.log)
}
func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) { func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) {
if hostname := c.String("hostname"); hostname != "" { if hostname := c.String("hostname"); hostname != "" {
if lbPool := c.String("lb-pool"); lbPool != "" { if lbPool := c.String("lb-pool"); lbPool != "" {
@ -343,21 +357,13 @@ func StartServer(
errC <- autoupdater.Run(ctx) errC <- autoupdater.Run(ctx)
}() }()
// Serve DNS proxy stand-alone if no hostname or tag or app is going to run // Serve DNS proxy stand-alone if no tunnel type (quick, adhoc, named) is going to run
if dnsProxyStandAlone(c, namedTunnel) { if dnsProxyStandAlone(c, namedTunnel) {
connectedSignal.Notify() connectedSignal.Notify()
// no grace period, handle SIGINT/SIGTERM immediately // no grace period, handle SIGINT/SIGTERM immediately
return waitToShutdown(&wg, cancel, errC, graceShutdownC, 0, log) return waitToShutdown(&wg, cancel, errC, graceShutdownC, 0, log)
} }
url := c.String("url")
hostname := c.String("hostname")
if url == hostname && url != "" && hostname != "" {
errText := "hostname and url shouldn't match. See --help for more information"
log.Error().Msg(errText)
return fmt.Errorf(errText)
}
logTransport := logger.CreateTransportLoggerFromContext(c, logger.EnableTerminalLog) logTransport := logger.CreateTransportLoggerFromContext(c, logger.EnableTerminalLog)
observer := connection.NewObserver(log, logTransport) observer := connection.NewObserver(log, logTransport)

View File

@ -32,7 +32,6 @@ import (
"github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tlsconfig"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
) )
const LogFieldOriginCertPath = "originCertPath" const LogFieldOriginCertPath = "originCertPath"
@ -43,8 +42,6 @@ var (
serviceUrl = developerPortal + "/reference/service/" serviceUrl = developerPortal + "/reference/service/"
argumentsUrl = developerPortal + "/reference/arguments/" argumentsUrl = developerPortal + "/reference/arguments/"
LogFieldHostname = "hostname"
secretFlags = [2]*altsrc.StringFlag{credentialsContentsFlag, tunnelTokenFlag} secretFlags = [2]*altsrc.StringFlag{credentialsContentsFlag, tunnelTokenFlag}
defaultFeatures = []string{supervisor.FeatureAllowRemoteConfig, supervisor.FeatureSerializedHeaders, supervisor.FeatureDatagramV2, supervisor.FeatureQUICSupportEOF} defaultFeatures = []string{supervisor.FeatureAllowRemoteConfig, supervisor.FeatureSerializedHeaders, supervisor.FeatureDatagramV2, supervisor.FeatureQUICSupportEOF}
@ -127,7 +124,10 @@ func isSecretEnvVar(key string) bool {
} }
func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelProperties) bool { func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelProperties) bool {
return c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world") && namedTunnel == nil) return c.IsSet("proxy-dns") &&
!(c.IsSet("name") || // adhoc-named tunnel
c.IsSet("hello-world") || // quick or named tunnel
namedTunnel != nil) // named tunnel
} }
func findOriginCert(originCertPath string, log *zerolog.Logger) (string, error) { func findOriginCert(originCertPath string, log *zerolog.Logger) (string, error) {
@ -193,37 +193,19 @@ func prepareTunnelConfig(
observer *connection.Observer, observer *connection.Observer,
namedTunnel *connection.NamedTunnelProperties, namedTunnel *connection.NamedTunnelProperties,
) (*supervisor.TunnelConfig, *orchestration.Config, error) { ) (*supervisor.TunnelConfig, *orchestration.Config, error) {
isNamedTunnel := namedTunnel != nil clientID, err := uuid.NewRandom()
configHostname := c.String("hostname")
hostname, err := validation.ValidateHostname(configHostname)
if err != nil { if err != nil {
log.Err(err).Str(LogFieldHostname, configHostname).Msg("Invalid hostname") return nil, nil, errors.Wrap(err, "can't generate connector UUID")
return nil, nil, errors.Wrap(err, "Invalid hostname")
} }
clientID := c.String("id") log.Info().Msgf("Generated Connector ID: %s", clientID)
if !c.IsSet("id") {
clientID, err = generateRandomClientID(log)
if err != nil {
return nil, nil, err
}
}
tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil { if err != nil {
log.Err(err).Msg("Tag parse failure") log.Err(err).Msg("Tag parse failure")
return nil, nil, errors.Wrap(err, "Tag parse failure") return nil, nil, errors.Wrap(err, "Tag parse failure")
} }
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID.String()})
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
var (
ingressRules ingress.Ingress
classicTunnel *connection.ClassicTunnelProperties
)
transportProtocol := c.String("protocol") transportProtocol := c.String("protocol")
needPQ := c.Bool("post-quantum") needPQ := c.Bool("post-quantum")
if needPQ { if needPQ {
if FipsEnabled { if FipsEnabled {
@ -238,79 +220,55 @@ func prepareTunnelConfig(
protocolFetcher := edgediscovery.ProtocolPercentage protocolFetcher := edgediscovery.ProtocolPercentage
cfg := config.GetConfiguration() features := append(c.StringSlice("features"), defaultFeatures...)
if isNamedTunnel { if needPQ {
clientUUID, err := uuid.NewRandom() features = append(features, supervisor.FeaturePostQuantum)
if err != nil { }
return nil, nil, errors.Wrap(err, "can't generate connector UUID") if c.IsSet(TunnelTokenFlag) {
} if transportProtocol == connection.AutoSelectFlag {
log.Info().Msgf("Generated Connector ID: %s", clientUUID) protocolFetcher = func() (edgediscovery.ProtocolPercents, error) {
features := append(c.StringSlice("features"), defaultFeatures...) // If the Tunnel is remotely managed and no protocol is set, we prefer QUIC, but still allow fall-back.
if needPQ { preferQuic := []edgediscovery.ProtocolPercent{
features = append(features, supervisor.FeaturePostQuantum) {
} Protocol: connection.QUIC.String(),
if c.IsSet(TunnelTokenFlag) { Percentage: 100,
if transportProtocol == connection.AutoSelectFlag { },
protocolFetcher = func() (edgediscovery.ProtocolPercents, error) { {
// If the Tunnel is remotely managed and no protocol is set, we prefer QUIC, but still allow fall-back. Protocol: connection.HTTP2.String(),
preferQuic := []edgediscovery.ProtocolPercent{ Percentage: 100,
{ },
Protocol: connection.QUIC.String(),
Percentage: 100,
},
{
Protocol: connection.HTTP2.String(),
Percentage: 100,
},
}
return preferQuic, nil
} }
return preferQuic, nil
} }
log.Info().Msg("Will be fetching remotely managed configuration from Cloudflare API. Defaulting to protocol: quic")
} }
namedTunnel.Client = tunnelpogs.ClientInfo{ log.Info().Msg("Will be fetching remotely managed configuration from Cloudflare API. Defaulting to protocol: quic")
ClientID: clientUUID[:], }
Features: dedup(features), namedTunnel.Client = tunnelpogs.ClientInfo{
Version: info.Version(), ClientID: clientID[:],
Arch: info.OSArch(), Features: dedup(features),
} Version: info.Version(),
ingressRules, err = ingress.ParseIngress(cfg) Arch: info.OSArch(),
if err != nil && err != ingress.ErrNoIngressRules { }
return nil, nil, err cfg := config.GetConfiguration()
} ingressRules, err := ingress.ParseIngress(cfg)
if !ingressRules.IsEmpty() && c.IsSet("url") { if err != nil && err != ingress.ErrNoIngressRules {
return nil, nil, err
}
if c.IsSet("url") {
// Ingress rules cannot be provided with --url flag
if !ingressRules.IsEmpty() {
return nil, nil, ingress.ErrURLIncompatibleWithIngress return nil, nil, ingress.ErrURLIncompatibleWithIngress
} } else {
} else { // Only for quick or adhoc tunnels will we attempt to parse:
// --url, --hello-world, or --unix-socket flag for a tunnel ingress rule
originCertPath := c.String("origincert") ingressRules, err = ingress.NewSingleOrigin(c, false)
originCertLog := log.With(). if err != nil {
Str(LogFieldOriginCertPath, originCertPath). return nil, nil, err
Logger() }
originCert, err := getOriginCert(originCertPath, &originCertLog)
if err != nil {
return nil, nil, errors.Wrap(err, "Error getting origin cert")
}
classicTunnel = &connection.ClassicTunnelProperties{
Hostname: hostname,
OriginCert: originCert,
// turn off use of reconnect token and auth refresh when using named tunnels
UseReconnectToken: !isNamedTunnel && c.Bool("use-reconnect-token"),
} }
} }
// Convert single-origin configuration into multi-origin configuration. protocolSelector, err := connection.NewProtocolSelector(transportProtocol, cfg.WarpRouting.Enabled, namedTunnel, protocolFetcher, supervisor.ResolveTTL, log, c.Bool("post-quantum"))
if ingressRules.IsEmpty() {
ingressRules, err = ingress.NewSingleOrigin(c, !isNamedTunnel)
if err != nil {
return nil, nil, err
}
}
warpRoutingEnabled := isWarpRoutingEnabled(cfg.WarpRouting, isNamedTunnel)
protocolSelector, err := connection.NewProtocolSelector(transportProtocol, warpRoutingEnabled, namedTunnel, protocolFetcher, supervisor.ResolveTTL, log, c.Bool("post-quantum"))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -362,7 +320,7 @@ func prepareTunnelConfig(
GracePeriod: gracePeriod, GracePeriod: gracePeriod,
ReplaceExisting: c.Bool("force"), ReplaceExisting: c.Bool("force"),
OSArch: info.OSArch(), OSArch: info.OSArch(),
ClientID: clientID, ClientID: clientID.String(),
EdgeAddrs: c.StringSlice("edge"), EdgeAddrs: c.StringSlice("edge"),
Region: c.String("region"), Region: c.String("region"),
EdgeIPVersion: edgeIPVersion, EdgeIPVersion: edgeIPVersion,
@ -379,7 +337,6 @@ func prepareTunnelConfig(
Retries: uint(c.Int("retries")), Retries: uint(c.Int("retries")),
RunFromTerminal: isRunningFromTerminal(), RunFromTerminal: isRunningFromTerminal(),
NamedTunnel: namedTunnel, NamedTunnel: namedTunnel,
ClassicTunnel: classicTunnel,
MuxerConfig: muxerConfig, MuxerConfig: muxerConfig,
ProtocolSelector: protocolSelector, ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs, EdgeTLSConfigs: edgeTLSConfigs,
@ -421,10 +378,6 @@ func gracePeriod(c *cli.Context) (time.Duration, error) {
return period, nil return period, nil
} }
func isWarpRoutingEnabled(warpConfig config.WarpRoutingConfig, isNamedTunnel bool) bool {
return warpConfig.Enabled && isNamedTunnel
}
func isRunningFromTerminal() bool { func isRunningFromTerminal() bool {
return terminal.IsTerminal(int(os.Stdout.Fd())) return terminal.IsTerminal(int(os.Stdout.Fd()))
} }

View File

@ -12,18 +12,12 @@ class TestProxyDns:
def test_proxy_dns_with_named_tunnel(self, tmp_path, component_tests_config): def test_proxy_dns_with_named_tunnel(self, tmp_path, component_tests_config):
run_test_scenario(tmp_path, component_tests_config, CfdModes.NAMED, run_proxy_dns=True) run_test_scenario(tmp_path, component_tests_config, CfdModes.NAMED, run_proxy_dns=True)
def test_proxy_dns_with_classic_tunnel(self, tmp_path, component_tests_config):
run_test_scenario(tmp_path, component_tests_config, CfdModes.CLASSIC, run_proxy_dns=True)
def test_proxy_dns_alone(self, tmp_path, component_tests_config): def test_proxy_dns_alone(self, tmp_path, component_tests_config):
run_test_scenario(tmp_path, component_tests_config, CfdModes.PROXY_DNS, run_proxy_dns=True) run_test_scenario(tmp_path, component_tests_config, CfdModes.PROXY_DNS, run_proxy_dns=True)
def test_named_tunnel_alone(self, tmp_path, component_tests_config): def test_named_tunnel_alone(self, tmp_path, component_tests_config):
run_test_scenario(tmp_path, component_tests_config, CfdModes.NAMED, run_proxy_dns=False) run_test_scenario(tmp_path, component_tests_config, CfdModes.NAMED, run_proxy_dns=False)
def test_classic_tunnel_alone(self, tmp_path, component_tests_config):
run_test_scenario(tmp_path, component_tests_config, CfdModes.CLASSIC, run_proxy_dns=False)
def run_test_scenario(tmp_path, component_tests_config, cfd_mode, run_proxy_dns): def run_test_scenario(tmp_path, component_tests_config, cfd_mode, run_proxy_dns):
expect_proxy_dns = run_proxy_dns expect_proxy_dns = run_proxy_dns
@ -33,10 +27,6 @@ def run_test_scenario(tmp_path, component_tests_config, cfd_mode, run_proxy_dns)
expect_tunnel = True expect_tunnel = True
pre_args = ["tunnel"] pre_args = ["tunnel"]
args = ["run"] args = ["run"]
elif cfd_mode == CfdModes.CLASSIC:
expect_tunnel = True
pre_args = []
args = []
elif cfd_mode == CfdModes.PROXY_DNS: elif cfd_mode == CfdModes.PROXY_DNS:
expect_proxy_dns = True expect_proxy_dns = True
pre_args = [] pre_args = []

View File

@ -33,13 +33,6 @@ class TestReconnect:
# Repeat the test multiple times because some issues only occur after multiple reconnects # Repeat the test multiple times because some issues only occur after multiple reconnects
self.assert_reconnect(config, cloudflared, 5) self.assert_reconnect(config, cloudflared, 5)
def test_classic_reconnect(self, tmp_path, component_tests_config):
extra_config = copy.copy(self.extra_config)
extra_config["hello-world"] = True
config = component_tests_config(additional_config=extra_config, cfd_mode=CfdModes.CLASSIC)
with start_cloudflared(tmp_path, config, cfd_args=[], new_process=True, allow_input=True, capture_output=False) as cloudflared:
self.assert_reconnect(config, cloudflared, 1)
def send_reconnect(self, cloudflared, secs): def send_reconnect(self, cloudflared, secs):
# Although it is recommended to use the Popen.communicate method, we cannot # Although it is recommended to use the Popen.communicate method, we cannot
# use it because it blocks on reading stdout and stderr until EOF is reached # use it because it blocks on reading stdout and stderr until EOF is reached

View File

@ -123,46 +123,6 @@ func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *Nam
return err return err
} }
func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelProperties, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error {
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return h.serveMuxer(serveCtx)
})
errGroup.Go(func() (err error) {
defer func() {
if err == nil {
connectedFuse.Connected()
}
}()
if classicTunnel.UseReconnectToken && connectedFuse.IsConnected() {
err := h.reconnectTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
if err == nil {
return nil
}
// log errors and proceed to RegisterTunnel
h.observer.log.Err(err).
Uint8(LogFieldConnIndex, h.connIndex).
Msg("Couldn't reconnect connection. Re-registering it instead.")
}
return h.registerTunnel(ctx, credentialManager, classicTunnel, registrationOptions)
})
errGroup.Go(func() error {
h.controlLoop(serveCtx, connectedFuse, false)
return nil
})
err := errGroup.Wait()
if err == errMuxerStopped {
if h.stoppedGracefully {
return nil
}
h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unexpected muxer shutdown")
}
return err
}
func (h *h2muxConnection) serveMuxer(ctx context.Context) error { func (h *h2muxConnection) serveMuxer(ctx context.Context) error {
// All routines should stop when muxer finish serving. When muxer is shutdown // All routines should stop when muxer finish serving. When muxer is shutdown
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown // gracefully, it doesn't return an error, so we need to return errMuxerShutdown

View File

@ -71,9 +71,8 @@ type Ingress struct {
} }
// NewSingleOrigin constructs an Ingress set with only one rule, constructed from // NewSingleOrigin constructs an Ingress set with only one rule, constructed from
// legacy CLI parameters like --url or --no-chunked-encoding. // CLI parameters for quick tunnels like --url or --no-chunked-encoding.
func NewSingleOrigin(c *cli.Context, allowURLFromArgs bool) (Ingress, error) { func NewSingleOrigin(c *cli.Context, allowURLFromArgs bool) (Ingress, error) {
service, err := parseSingleOriginService(c, allowURLFromArgs) service, err := parseSingleOriginService(c, allowURLFromArgs)
if err != nil { if err != nil {
return Ingress{}, err return Ingress{}, err

View File

@ -3,22 +3,18 @@ package supervisor
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net" "net"
"strings" "strings"
"time" "time"
"github.com/google/uuid"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/orchestration"
"github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/retry"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/tunnelstate" "github.com/cloudflare/cloudflared/tunnelstate"
) )
@ -40,7 +36,6 @@ const (
// Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and // Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and
// reconnects them if they disconnect. // reconnects them if they disconnect.
type Supervisor struct { type Supervisor struct {
cloudflaredUUID uuid.UUID
config *TunnelConfig config *TunnelConfig
orchestrator *orchestration.Orchestrator orchestrator *orchestration.Orchestrator
edgeIPs *edgediscovery.Edge edgeIPs *edgediscovery.Edge
@ -57,7 +52,6 @@ type Supervisor struct {
logTransport *zerolog.Logger logTransport *zerolog.Logger
reconnectCredentialManager *reconnectCredentialManager reconnectCredentialManager *reconnectCredentialManager
useReconnectToken bool
reconnectCh chan ReconnectSignal reconnectCh chan ReconnectSignal
gracefulShutdownC <-chan struct{} gracefulShutdownC <-chan struct{}
@ -71,13 +65,9 @@ type tunnelError struct {
} }
func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrator, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) { func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrator, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
cloudflaredUUID, err := uuid.NewRandom()
if err != nil {
return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err)
}
isStaticEdge := len(config.EdgeAddrs) > 0 isStaticEdge := len(config.EdgeAddrs) > 0
var err error
var edgeIPs *edgediscovery.Edge var edgeIPs *edgediscovery.Edge
if isStaticEdge { // static edge addresses if isStaticEdge { // static edge addresses
edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs) edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs)
@ -97,7 +87,6 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
edgeTunnelServer := EdgeTunnelServer{ edgeTunnelServer := EdgeTunnelServer{
config: config, config: config,
cloudflaredUUID: cloudflaredUUID,
orchestrator: orchestrator, orchestrator: orchestrator,
credentialManager: reconnectCredentialManager, credentialManager: reconnectCredentialManager,
edgeAddrs: edgeIPs, edgeAddrs: edgeIPs,
@ -108,13 +97,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
connAwareLogger: log, connAwareLogger: log,
} }
useReconnectToken := false
if config.ClassicTunnel != nil {
useReconnectToken = config.ClassicTunnel.UseReconnectToken
}
return &Supervisor{ return &Supervisor{
cloudflaredUUID: cloudflaredUUID,
config: config, config: config,
orchestrator: orchestrator, orchestrator: orchestrator,
edgeIPs: edgeIPs, edgeIPs: edgeIPs,
@ -125,7 +108,6 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
log: log, log: log,
logTransport: config.LogTransport, logTransport: config.LogTransport,
reconnectCredentialManager: reconnectCredentialManager, reconnectCredentialManager: reconnectCredentialManager,
useReconnectToken: useReconnectToken,
reconnectCh: reconnectCh, reconnectCh: reconnectCh,
gracefulShutdownC: gracefulShutdownC, gracefulShutdownC: gracefulShutdownC,
}, nil }, nil
@ -159,20 +141,6 @@ func (s *Supervisor) Run(
backoff := retry.BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true} backoff := retry.BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
var backoffTimer <-chan time.Time var backoffTimer <-chan time.Time
refreshAuthBackoff := &retry.BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
var refreshAuthBackoffTimer <-chan time.Time
if s.useReconnectToken {
if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
refreshAuthBackoffTimer = timer
} else {
s.log.Logger().Err(err).
Dur("refreshAuthRetryDuration", refreshAuthRetryDuration).
Msgf("supervisor: initial refreshAuth failed, retrying in %v", refreshAuthRetryDuration)
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
}
}
shuttingDown := false shuttingDown := false
for { for {
select { select {
@ -219,16 +187,6 @@ func (s *Supervisor) Run(
} }
tunnelsActive += len(tunnelsWaiting) tunnelsActive += len(tunnelsWaiting)
tunnelsWaiting = nil tunnelsWaiting = nil
// Time to call Authenticate
case <-refreshAuthBackoffTimer:
newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
if err != nil {
s.log.Logger().Err(err).Msg("supervisor: Authentication failed")
// Permanent failure. Leave the `select` without setting the
// channel to be non-null, so we'll never hit this case of the `select` again.
continue
}
refreshAuthBackoffTimer = newTimer
// Tunnel successfully connected // Tunnel successfully connected
case <-s.nextConnectedSignal: case <-s.nextConnectedSignal:
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 { if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
@ -377,44 +335,3 @@ func (s *Supervisor) waitForNextTunnel(index int) bool {
func (s *Supervisor) unusedIPs() bool { func (s *Supervisor) unusedIPs() bool {
return s.edgeIPs.AvailableAddrs() > s.config.HAConnections return s.edgeIPs.AvailableAddrs() > s.config.HAConnections
} }
func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
arbitraryEdgeIP, err := s.edgeIPs.GetAddrForRPC()
if err != nil {
return nil, err
}
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.EdgeTLSConfigs[connection.H2mux], arbitraryEdgeIP.TCP)
if err != nil {
return nil, err
}
defer edgeConn.Close()
handler := h2mux.MuxedStreamFunc(func(*h2mux.MuxedStream) error {
// This callback is invoked by h2mux when the edge initiates a stream.
return nil // noop
})
muxerConfig := s.config.MuxerConfig.H2MuxerConfig(handler, s.logTransport)
muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig, h2mux.ActiveStreams)
if err != nil {
return nil, err
}
go muxer.Serve(ctx)
defer func() {
// If we don't wait for the muxer shutdown here, edgeConn.Close() runs before the muxer connections are done,
// and the user sees log noise: "error writing data", "connection closed unexpectedly"
<-muxer.Shutdown()
}()
stream, err := muxer.OpenRPCStream(ctx)
if err != nil {
return nil, err
}
rpcClient := connection.NewTunnelServerClient(ctx, stream, s.log.Logger())
defer rpcClient.Close()
const arbitraryConnectionID = uint8(0)
registrationOptions := s.config.registrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions)
}

View File

@ -68,7 +68,6 @@ type TunnelConfig struct {
PQKexIdx int PQKexIdx int
NamedTunnel *connection.NamedTunnelProperties NamedTunnel *connection.NamedTunnelProperties
ClassicTunnel *connection.ClassicTunnelProperties
MuxerConfig *connection.MuxerConfig MuxerConfig *connection.MuxerConfig
ProtocolSelector connection.ProtocolSelector ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config EdgeTLSConfigs map[connection.Protocol]*tls.Config
@ -204,7 +203,6 @@ func (f *ipAddrFallback) ShouldGetNewAddress(connIndex uint8, err error) (needsN
type EdgeTunnelServer struct { type EdgeTunnelServer struct {
config *TunnelConfig config *TunnelConfig
cloudflaredUUID uuid.UUID
orchestrator *orchestration.Orchestrator orchestrator *orchestration.Orchestrator
credentialManager *reconnectCredentialManager credentialManager *reconnectCredentialManager
edgeAddrHandler EdgeAddrHandler edgeAddrHandler EdgeAddrHandler
@ -577,12 +575,8 @@ func (e *EdgeTunnelServer) serveH2mux(
errGroup, serveCtx := errgroup.WithContext(ctx) errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
if e.config.NamedTunnel != nil { connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries()))
connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries())) return handler.ServeNamedTunnel(serveCtx, e.config.NamedTunnel, connOptions, connectedFuse)
return handler.ServeNamedTunnel(serveCtx, e.config.NamedTunnel, connOptions, connectedFuse)
}
registrationOptions := e.config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), e.cloudflaredUUID)
return handler.ServeClassicTunnel(serveCtx, e.config.ClassicTunnel, e.credentialManager, registrationOptions, connectedFuse)
}) })
errGroup.Go(func() error { errGroup.Go(func() error {