From e22422aafb411c6952ffdb0b2e9094b75c68209b Mon Sep 17 00:00:00 2001 From: cthuang Date: Mon, 7 Feb 2022 09:42:07 +0000 Subject: [PATCH] TUN-5749: Refactor cloudflared to pave way for reconfigurable ingress - Split origin into supervisor and proxy packages - Create configManager to handle dynamic config --- cmd/cloudflared/tunnel/cmd.go | 18 ++-- cmd/cloudflared/tunnel/configuration.go | 90 +++++++++---------- cmd/cloudflared/tunnel/quick_tunnel.go | 2 +- cmd/cloudflared/tunnel/subcommand_context.go | 2 +- connection/connection.go | 11 ++- connection/connection_test.go | 23 ++++- connection/control.go | 24 ++--- connection/h2mux.go | 19 ++-- connection/h2mux_test.go | 2 +- connection/http2.go | 20 ++--- connection/http2_test.go | 14 +-- connection/protocol.go | 2 +- connection/protocol_test.go | 46 +++++----- connection/quic.go | 10 +-- connection/quic_test.go | 3 +- connection/rpc.go | 24 ++--- {origin => proxy}/metrics.go | 11 +-- {origin => proxy}/pool.go | 2 +- {origin => proxy}/proxy.go | 8 +- {origin => proxy}/proxy_posix_test.go | 2 +- {origin => proxy}/proxy_test.go | 10 +-- .../cloudflare_status_page.go | 2 +- .../cloudflare_status_page_test.go | 2 +- supervisor/configmanager.go | 55 ++++++++++++ {origin => supervisor}/conn_aware_logger.go | 2 +- {origin => supervisor}/external_control.go | 2 +- supervisor/metrics.go | 27 ++++++ {origin => supervisor}/reconnect.go | 2 +- {origin => supervisor}/reconnect_test.go | 2 +- {origin => supervisor}/supervisor.go | 11 ++- {origin => supervisor}/tunnel.go | 79 +++++++++------- {origin => supervisor}/tunnel_test.go | 8 +- {origin => supervisor}/tunnelsforha.go | 2 +- 33 files changed, 317 insertions(+), 220 deletions(-) rename {origin => proxy}/metrics.go (85%) rename {origin => proxy}/pool.go (96%) rename {origin => proxy}/proxy.go (98%) rename {origin => proxy}/proxy_posix_test.go (98%) rename {origin => proxy}/proxy_test.go (98%) rename {origin => supervisor}/cloudflare_status_page.go (99%) rename {origin => supervisor}/cloudflare_status_page_test.go (99%) create mode 100644 supervisor/configmanager.go rename {origin => supervisor}/conn_aware_logger.go (97%) rename {origin => supervisor}/external_control.go (95%) create mode 100644 supervisor/metrics.go rename {origin => supervisor}/reconnect.go (99%) rename {origin => supervisor}/reconnect_test.go (99%) rename {origin => supervisor}/supervisor.go (96%) rename {origin => supervisor}/tunnel.go (90%) rename {origin => supervisor}/tunnel_test.go (96%) rename {origin => supervisor}/tunnelsforha.go (98%) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 8833e209..728bb8b7 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -31,8 +31,8 @@ import ( "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/metrics" - "github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/signal" + "github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tunneldns" ) @@ -223,7 +223,7 @@ func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) { func StartServer( c *cli.Context, info *cliutil.BuildInfo, - namedTunnel *connection.NamedTunnelConfig, + namedTunnel *connection.NamedTunnelProperties, log *zerolog.Logger, isUIEnabled bool, ) error { @@ -333,7 +333,7 @@ func StartServer( observer.SendURL(quickTunnelURL) } - tunnelConfig, ingressRules, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel) + tunnelConfig, dynamicConfig, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel) if err != nil { log.Err(err).Msg("Couldn't start tunnel") return err @@ -353,11 +353,11 @@ func StartServer( errC <- metrics.ServeMetrics(metricsListener, ctx.Done(), readinessServer, quickTunnelURL, log) }() - if err := ingressRules.StartOrigins(&wg, log, ctx.Done(), errC); err != nil { + if err := dynamicConfig.Ingress.StartOrigins(&wg, log, ctx.Done(), errC); err != nil { return err } - reconnectCh := make(chan origin.ReconnectSignal, 1) + reconnectCh := make(chan supervisor.ReconnectSignal, 1) if c.IsSet("stdin-control") { log.Info().Msg("Enabling control through stdin") go stdinControl(reconnectCh, log) @@ -369,7 +369,7 @@ func StartServer( wg.Done() log.Info().Msg("Tunnel server stopped") }() - errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, reconnectCh, graceShutdownC) + errC <- supervisor.StartTunnelDaemon(ctx, tunnelConfig, dynamicConfig, connectedSignal, reconnectCh, graceShutdownC) }() if isUIEnabled { @@ -377,7 +377,7 @@ func StartServer( info.Version(), hostname, metricsListener.Addr().String(), - &ingressRules, + dynamicConfig.Ingress, tunnelConfig.HAConnections, ) app := tunnelUI.Launch(ctx, log, logTransport) @@ -998,7 +998,7 @@ func configureProxyDNSFlags(shouldHide bool) []cli.Flag { } } -func stdinControl(reconnectCh chan origin.ReconnectSignal, log *zerolog.Logger) { +func stdinControl(reconnectCh chan supervisor.ReconnectSignal, log *zerolog.Logger) { for { scanner := bufio.NewScanner(os.Stdin) for scanner.Scan() { @@ -1009,7 +1009,7 @@ func stdinControl(reconnectCh chan origin.ReconnectSignal, log *zerolog.Logger) case "": break case "reconnect": - var reconnect origin.ReconnectSignal + var reconnect supervisor.ReconnectSignal if len(parts) > 1 { var err error if reconnect.Delay, err = time.ParseDuration(parts[1]); err != nil { diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 0d29cc7d..61e65a8f 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -23,7 +23,7 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/ingress" - "github.com/cloudflare/cloudflared/origin" + "github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/tlsconfig" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/validation" @@ -87,7 +87,7 @@ func logClientOptions(c *cli.Context, log *zerolog.Logger) { } } -func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelConfig) bool { +func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelProperties) bool { return c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world") && namedTunnel == nil) } @@ -152,44 +152,44 @@ func prepareTunnelConfig( info *cliutil.BuildInfo, log, logTransport *zerolog.Logger, observer *connection.Observer, - namedTunnel *connection.NamedTunnelConfig, -) (*origin.TunnelConfig, ingress.Ingress, error) { + namedTunnel *connection.NamedTunnelProperties, +) (*supervisor.TunnelConfig, *supervisor.DynamicConfig, error) { isNamedTunnel := namedTunnel != nil configHostname := c.String("hostname") hostname, err := validation.ValidateHostname(configHostname) if err != nil { log.Err(err).Str(LogFieldHostname, configHostname).Msg("Invalid hostname") - return nil, ingress.Ingress{}, errors.Wrap(err, "Invalid hostname") + return nil, nil, errors.Wrap(err, "Invalid hostname") } clientID := c.String("id") if !c.IsSet("id") { clientID, err = generateRandomClientID(log) if err != nil { - return nil, ingress.Ingress{}, err + return nil, nil, err } } tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) if err != nil { log.Err(err).Msg("Tag parse failure") - return nil, ingress.Ingress{}, errors.Wrap(err, "Tag parse failure") + return nil, nil, errors.Wrap(err, "Tag parse failure") } tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) var ( ingressRules ingress.Ingress - classicTunnel *connection.ClassicTunnelConfig + classicTunnel *connection.ClassicTunnelProperties ) cfg := config.GetConfiguration() if isNamedTunnel { clientUUID, err := uuid.NewRandom() if err != nil { - return nil, ingress.Ingress{}, errors.Wrap(err, "can't generate connector UUID") + return nil, nil, errors.Wrap(err, "can't generate connector UUID") } log.Info().Msgf("Generated Connector ID: %s", clientUUID) - features := append(c.StringSlice("features"), origin.FeatureSerializedHeaders) + features := append(c.StringSlice("features"), supervisor.FeatureSerializedHeaders) namedTunnel.Client = tunnelpogs.ClientInfo{ ClientID: clientUUID[:], Features: dedup(features), @@ -198,10 +198,10 @@ func prepareTunnelConfig( } ingressRules, err = ingress.ParseIngress(cfg) if err != nil && err != ingress.ErrNoIngressRules { - return nil, ingress.Ingress{}, err + return nil, nil, err } if !ingressRules.IsEmpty() && c.IsSet("url") { - return nil, ingress.Ingress{}, ingress.ErrURLIncompatibleWithIngress + return nil, nil, ingress.ErrURLIncompatibleWithIngress } } else { @@ -212,10 +212,10 @@ func prepareTunnelConfig( originCert, err := getOriginCert(originCertPath, &originCertLog) if err != nil { - return nil, ingress.Ingress{}, errors.Wrap(err, "Error getting origin cert") + return nil, nil, errors.Wrap(err, "Error getting origin cert") } - classicTunnel = &connection.ClassicTunnelConfig{ + classicTunnel = &connection.ClassicTunnelProperties{ Hostname: hostname, OriginCert: originCert, // turn off use of reconnect token and auth refresh when using named tunnels @@ -227,20 +227,14 @@ func prepareTunnelConfig( if ingressRules.IsEmpty() { ingressRules, err = ingress.NewSingleOrigin(c, !isNamedTunnel) if err != nil { - return nil, ingress.Ingress{}, err + return nil, nil, err } } - var warpRoutingService *ingress.WarpRoutingService warpRoutingEnabled := isWarpRoutingEnabled(cfg.WarpRouting, isNamedTunnel) - if warpRoutingEnabled { - warpRoutingService = ingress.NewWarpRoutingService() - log.Info().Msgf("Warp-routing is enabled") - } - - protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, origin.ResolveTTL, log) + protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, supervisor.ResolveTTL, log) if err != nil { - return nil, ingress.Ingress{}, err + return nil, nil, err } log.Info().Msgf("Initial protocol %s", protocolSelector.Current()) @@ -248,11 +242,11 @@ func prepareTunnelConfig( for _, p := range connection.ProtocolList { tlsSettings := p.TLSSettings() if tlsSettings == nil { - return nil, ingress.Ingress{}, fmt.Errorf("%s has unknown TLS settings", p) + return nil, nil, fmt.Errorf("%s has unknown TLS settings", p) } edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, tlsSettings.ServerName) if err != nil { - return nil, ingress.Ingress{}, errors.Wrap(err, "unable to create TLS config to connect with edge") + return nil, nil, errors.Wrap(err, "unable to create TLS config to connect with edge") } if len(tlsSettings.NextProtos) > 0 { edgeTLSConfig.NextProtos = tlsSettings.NextProtos @@ -260,15 +254,9 @@ func prepareTunnelConfig( edgeTLSConfigs[p] = edgeTLSConfig } - originProxy := origin.NewOriginProxy(ingressRules, warpRoutingService, tags, log) gracePeriod, err := gracePeriod(c) if err != nil { - return nil, ingress.Ingress{}, err - } - connectionConfig := &connection.Config{ - OriginProxy: originProxy, - GracePeriod: gracePeriod, - ReplaceExisting: c.Bool("force"), + return nil, nil, err } muxerConfig := &connection.MuxerConfig{ HeartbeatInterval: c.Duration("heartbeat-interval"), @@ -279,21 +267,22 @@ func prepareTunnelConfig( MetricsUpdateFreq: c.Duration("metrics-update-freq"), } - return &origin.TunnelConfig{ - ConnectionConfig: connectionConfig, - OSArch: info.OSArch(), - ClientID: clientID, - EdgeAddrs: c.StringSlice("edge"), - Region: c.String("region"), - HAConnections: c.Int("ha-connections"), - IncidentLookup: origin.NewIncidentLookup(), - IsAutoupdated: c.Bool("is-autoupdated"), - LBPool: c.String("lb-pool"), - Tags: tags, - Log: log, - LogTransport: logTransport, - Observer: observer, - ReportedVersion: info.Version(), + tunnelConfig := &supervisor.TunnelConfig{ + GracePeriod: gracePeriod, + ReplaceExisting: c.Bool("force"), + OSArch: info.OSArch(), + ClientID: clientID, + EdgeAddrs: c.StringSlice("edge"), + Region: c.String("region"), + HAConnections: c.Int("ha-connections"), + IncidentLookup: supervisor.NewIncidentLookup(), + IsAutoupdated: c.Bool("is-autoupdated"), + LBPool: c.String("lb-pool"), + Tags: tags, + Log: log, + LogTransport: logTransport, + Observer: observer, + ReportedVersion: info.Version(), // Note TUN-3758 , we use Int because UInt is not supported with altsrc Retries: uint(c.Int("retries")), RunFromTerminal: isRunningFromTerminal(), @@ -302,7 +291,12 @@ func prepareTunnelConfig( MuxerConfig: muxerConfig, ProtocolSelector: protocolSelector, EdgeTLSConfigs: edgeTLSConfigs, - }, ingressRules, nil + } + dynamicConfig := &supervisor.DynamicConfig{ + Ingress: &ingressRules, + WarpRoutingEnabled: warpRoutingEnabled, + } + return tunnelConfig, dynamicConfig, nil } func gracePeriod(c *cli.Context) (time.Duration, error) { diff --git a/cmd/cloudflared/tunnel/quick_tunnel.go b/cmd/cloudflared/tunnel/quick_tunnel.go index 08b5ff78..e514b4ad 100644 --- a/cmd/cloudflared/tunnel/quick_tunnel.go +++ b/cmd/cloudflared/tunnel/quick_tunnel.go @@ -77,7 +77,7 @@ func RunQuickTunnel(sc *subcommandContext) error { return StartServer( sc.c, buildInfo, - &connection.NamedTunnelConfig{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname}, + &connection.NamedTunnelProperties{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname}, sc.log, sc.isUIEnabled, ) diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index cb5b15be..03d8c796 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -304,7 +304,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error { return StartServer( sc.c, buildInfo, - &connection.NamedTunnelConfig{Credentials: credentials}, + &connection.NamedTunnelProperties{Credentials: credentials}, sc.log, sc.isUIEnabled, ) diff --git a/connection/connection.go b/connection/connection.go index 2a57229f..07649983 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -25,13 +25,12 @@ const ( var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) -type Config struct { - OriginProxy OriginProxy - GracePeriod time.Duration - ReplaceExisting bool +type ConfigManager interface { + Update(version int32, config []byte) *pogs.UpdateConfigurationResponse + GetOriginProxy() OriginProxy } -type NamedTunnelConfig struct { +type NamedTunnelProperties struct { Credentials Credentials Client pogs.ClientInfo QuickTunnelUrl string @@ -52,7 +51,7 @@ func (c *Credentials) Auth() pogs.TunnelAuth { } } -type ClassicTunnelConfig struct { +type ClassicTunnelProperties struct { Hostname string OriginCert []byte // feature-flag to use new edge reconnect tokens diff --git a/connection/connection_test.go b/connection/connection_test.go index 5106422f..3b83269e 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -14,18 +14,19 @@ import ( "github.com/stretchr/testify/assert" "github.com/cloudflare/cloudflared/ingress" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/websocket" ) const ( - largeFileSize = 2 * 1024 * 1024 + largeFileSize = 2 * 1024 * 1024 + testGracePeriod = time.Millisecond * 100 ) var ( unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil) - testConfig = &Config{ - OriginProxy: &mockOriginProxy{}, - GracePeriod: time.Millisecond * 100, + testConfigManager = &mockConfigManager{ + originProxy: &mockOriginProxy{}, } log = zerolog.Nop() testOriginURL = &url.URL{ @@ -43,6 +44,20 @@ type testRequest struct { isProxyError bool } +type mockConfigManager struct { + originProxy OriginProxy +} + +func (*mockConfigManager) Update(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { + return &tunnelpogs.UpdateConfigurationResponse{ + LastAppliedVersion: version, + } +} + +func (mcr *mockConfigManager) GetOriginProxy() OriginProxy { + return mcr.originProxy +} + type mockOriginProxy struct{} func (moc *mockOriginProxy) ProxyHTTP( diff --git a/connection/control.go b/connection/control.go index c0c6a1d7..2467e80a 100644 --- a/connection/control.go +++ b/connection/control.go @@ -16,9 +16,9 @@ type RPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) Na type controlStream struct { observer *Observer - connectedFuse ConnectedFuse - namedTunnelConfig *NamedTunnelConfig - connIndex uint8 + connectedFuse ConnectedFuse + namedTunnelProperties *NamedTunnelProperties + connIndex uint8 newRPCClientFunc RPCClientFunc @@ -39,7 +39,7 @@ type ControlStreamHandler interface { func NewControlStream( observer *Observer, connectedFuse ConnectedFuse, - namedTunnelConfig *NamedTunnelConfig, + namedTunnelConfig *NamedTunnelProperties, connIndex uint8, newRPCClientFunc RPCClientFunc, gracefulShutdownC <-chan struct{}, @@ -49,13 +49,13 @@ func NewControlStream( newRPCClientFunc = newRegistrationRPCClient } return &controlStream{ - observer: observer, - connectedFuse: connectedFuse, - namedTunnelConfig: namedTunnelConfig, - newRPCClientFunc: newRPCClientFunc, - connIndex: connIndex, - gracefulShutdownC: gracefulShutdownC, - gracePeriod: gracePeriod, + observer: observer, + connectedFuse: connectedFuse, + namedTunnelProperties: namedTunnelConfig, + newRPCClientFunc: newRPCClientFunc, + connIndex: connIndex, + gracefulShutdownC: gracefulShutdownC, + gracePeriod: gracePeriod, } } @@ -66,7 +66,7 @@ func (c *controlStream) ServeControlStream( ) error { rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log) - if err := rpcClient.RegisterConnection(ctx, c.namedTunnelConfig, connOptions, c.connIndex, c.observer); err != nil { + if err := rpcClient.RegisterConnection(ctx, c.namedTunnelProperties, connOptions, c.connIndex, c.observer); err != nil { rpcClient.Close() return err } diff --git a/connection/h2mux.go b/connection/h2mux.go index 1e7c652b..8401c81a 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -22,9 +22,10 @@ const ( ) type h2muxConnection struct { - config *Config - muxerConfig *MuxerConfig - muxer *h2mux.Muxer + configManager ConfigManager + gracePeriod time.Duration + muxerConfig *MuxerConfig + muxer *h2mux.Muxer // connectionID is only used by metrics, and prometheus requires labels to be string connIndexStr string connIndex uint8 @@ -60,7 +61,8 @@ func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, log *zerolog.Lo // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error func NewH2muxConnection( - config *Config, + configManager ConfigManager, + gracePeriod time.Duration, muxerConfig *MuxerConfig, edgeConn net.Conn, connIndex uint8, @@ -68,7 +70,8 @@ func NewH2muxConnection( gracefulShutdownC <-chan struct{}, ) (*h2muxConnection, error, bool) { h := &h2muxConnection{ - config: config, + configManager: configManager, + gracePeriod: gracePeriod, muxerConfig: muxerConfig, connIndexStr: uint8ToString(connIndex), connIndex: connIndex, @@ -88,7 +91,7 @@ func NewH2muxConnection( return h, nil, false } -func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error { +func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelProperties, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error { errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { return h.serveMuxer(serveCtx) @@ -117,7 +120,7 @@ func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *Nam return err } -func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error { +func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelProperties, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error { errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { return h.serveMuxer(serveCtx) @@ -224,7 +227,7 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error { sourceConnectionType = TypeWebsocket } - err := h.config.OriginProxy.ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket) + err := h.configManager.GetOriginProxy().ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket) if err != nil { respWriter.WriteErrorResponse() } diff --git a/connection/h2mux_test.go b/connection/h2mux_test.go index 83a39589..35ab447c 100644 --- a/connection/h2mux_test.go +++ b/connection/h2mux_test.go @@ -48,7 +48,7 @@ func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) { }() var connIndex = uint8(0) testObserver := NewObserver(&log, &log, false) - h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver, nil) + h2muxConn, err, _ := NewH2muxConnection(testConfigManager, testGracePeriod, testMuxerConfig, originConn, connIndex, testObserver, nil) require.NoError(t, err) return h2muxConn, <-edgeMuxChan } diff --git a/connection/http2.go b/connection/http2.go index c0ab8f23..1e1f517a 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -30,12 +30,12 @@ var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed") // HTTP2Connection represents a net.Conn that uses HTTP2 frames to proxy traffic from the edge to cloudflared on the // origin. type HTTP2Connection struct { - conn net.Conn - server *http2.Server - config *Config - connOptions *tunnelpogs.ConnectionOptions - observer *Observer - connIndex uint8 + conn net.Conn + server *http2.Server + configManager ConfigManager + connOptions *tunnelpogs.ConnectionOptions + observer *Observer + connIndex uint8 // newRPCClientFunc allows us to mock RPCs during testing newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient @@ -49,7 +49,7 @@ type HTTP2Connection struct { // NewHTTP2Connection returns a new instance of HTTP2Connection. func NewHTTP2Connection( conn net.Conn, - config *Config, + configManager ConfigManager, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, @@ -61,7 +61,7 @@ func NewHTTP2Connection( server: &http2.Server{ MaxConcurrentStreams: MaxConcurrentStreams, }, - config: config, + configManager: configManager, connOptions: connOptions, observer: observer, connIndex: connIndex, @@ -116,7 +116,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { case TypeWebsocket, TypeHTTP: stripWebsocketUpgradeHeader(r) - if err := c.config.OriginProxy.ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil { + if err := c.configManager.GetOriginProxy().ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil { err := fmt.Errorf("Failed to proxy HTTP: %w", err) c.log.Error().Err(err) respWriter.WriteErrorResponse() @@ -131,7 +131,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { } rws := NewHTTPResponseReadWriterAcker(respWriter, r) - if err := c.config.OriginProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ + if err := c.configManager.GetOriginProxy().ProxyTCP(r.Context(), rws, &TCPRequest{ Dest: host, CFRay: FindCfRayHeader(r), LBProbe: IsLBProbeRequest(r), diff --git a/connection/http2_test.go b/connection/http2_test.go index 1a475f88..c405d1ee 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -35,7 +35,7 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { controlStream := NewControlStream( obs, mockConnectedFuse{}, - &NamedTunnelConfig{}, + &NamedTunnelProperties{}, connIndex, nil, nil, @@ -43,8 +43,8 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { ) return NewHTTP2Connection( cfdConn, - // OriginProxy is set in testConfig - testConfig, + // OriginProxy is set in testConfigManager + testConfigManager, &pogs.ConnectionOptions{}, obs, connIndex, @@ -132,7 +132,7 @@ type mockNamedTunnelRPCClient struct { func (mc mockNamedTunnelRPCClient) RegisterConnection( c context.Context, - config *NamedTunnelConfig, + properties *NamedTunnelProperties, options *tunnelpogs.ConnectionOptions, connIndex uint8, observer *Observer, @@ -313,7 +313,7 @@ func TestServeControlStream(t *testing.T) { controlStream := NewControlStream( obs, mockConnectedFuse{}, - &NamedTunnelConfig{}, + &NamedTunnelProperties{}, 1, rpcClientFactory.newMockRPCClient, nil, @@ -363,7 +363,7 @@ func TestFailRegistration(t *testing.T) { controlStream := NewControlStream( obs, mockConnectedFuse{}, - &NamedTunnelConfig{}, + &NamedTunnelProperties{}, http2Conn.connIndex, rpcClientFactory.newMockRPCClient, nil, @@ -409,7 +409,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) { controlStream := NewControlStream( obs, mockConnectedFuse{}, - &NamedTunnelConfig{}, + &NamedTunnelProperties{}, http2Conn.connIndex, rpcClientFactory.newMockRPCClient, shutdownC, diff --git a/connection/protocol.go b/connection/protocol.go index 399e6d9d..b94bb80e 100644 --- a/connection/protocol.go +++ b/connection/protocol.go @@ -195,7 +195,7 @@ type PercentageFetcher func() (edgediscovery.ProtocolPercents, error) func NewProtocolSelector( protocolFlag string, warpRoutingEnabled bool, - namedTunnel *NamedTunnelConfig, + namedTunnel *NamedTunnelProperties, fetchFunc PercentageFetcher, ttl time.Duration, log *zerolog.Logger, diff --git a/connection/protocol_test.go b/connection/protocol_test.go index 9bb8c50c..9ab5aae3 100644 --- a/connection/protocol_test.go +++ b/connection/protocol_test.go @@ -16,7 +16,7 @@ const ( ) var ( - testNamedTunnelConfig = &NamedTunnelConfig{ + testNamedTunnelProperties = &NamedTunnelProperties{ Credentials: Credentials{ AccountTag: "testAccountTag", }, @@ -51,7 +51,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback bool expectedFallback Protocol warpRoutingEnabled bool - namedTunnelConfig *NamedTunnelConfig + namedTunnelConfig *NamedTunnelProperties fetchFunc PercentageFetcher wantErr bool }{ @@ -66,35 +66,35 @@ func TestNewProtocolSelector(t *testing.T) { protocol: "h2mux", expectedProtocol: H2mux, fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil }, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel over http2", protocol: "http2", expectedProtocol: HTTP2, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel http2 disabled still gets http2 because it is manually picked", protocol: "http2", expectedProtocol: HTTP2, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel quic disabled still gets quic because it is manually picked", protocol: "quic", expectedProtocol: QUIC, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel quic and http2 disabled", protocol: "auto", expectedProtocol: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel quic disabled", @@ -104,21 +104,21 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: true, expectedFallback: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel auto all http2 disabled", protocol: "auto", expectedProtocol: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel auto to h2mux", protocol: "auto", expectedProtocol: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel auto to http2", @@ -127,7 +127,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: true, expectedFallback: H2mux, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "named tunnel auto to quic", @@ -136,7 +136,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: true, expectedFallback: HTTP2, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "warp routing requesting h2mux", @@ -145,7 +145,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: false, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1", @@ -154,7 +154,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: false, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "warp routing http2", @@ -163,7 +163,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: false, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "warp routing quic", @@ -173,7 +173,7 @@ func TestNewProtocolSelector(t *testing.T) { expectedFallback: HTTP2Warp, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "warp routing auto", @@ -182,7 +182,7 @@ func TestNewProtocolSelector(t *testing.T) { hasFallback: false, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { name: "warp routing auto- quic", @@ -192,7 +192,7 @@ func TestNewProtocolSelector(t *testing.T) { expectedFallback: HTTP2Warp, fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), warpRoutingEnabled: true, - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, }, { // None named tunnel can only use h2mux, so specifying an unknown protocol is not an error @@ -204,14 +204,14 @@ func TestNewProtocolSelector(t *testing.T) { name: "named tunnel unknown protocol", protocol: "unknown", fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, wantErr: true, }, { name: "named tunnel fetch error", protocol: "auto", fetchFunc: mockFetcher(true), - namedTunnelConfig: testNamedTunnelConfig, + namedTunnelConfig: testNamedTunnelProperties, expectedProtocol: HTTP2, wantErr: false, }, @@ -237,7 +237,7 @@ func TestNewProtocolSelector(t *testing.T) { func TestAutoProtocolSelectorRefresh(t *testing.T) { fetcher := dynamicMockFetcher{} - selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log) + selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log) assert.NoError(t, err) assert.Equal(t, H2mux, selector.Current()) @@ -267,7 +267,7 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) { func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { fetcher := dynamicMockFetcher{} // Since the user chooses http2 on purpose, we always stick to it. - selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log) + selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log) assert.NoError(t, err) assert.Equal(t, HTTP2, selector.Current()) @@ -297,7 +297,7 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { func TestProtocolSelectorRefreshTTL(t *testing.T) { fetcher := dynamicMockFetcher{} fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} - selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log) + selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), time.Hour, &log) assert.NoError(t, err) assert.Equal(t, QUIC, selector.Current()) diff --git a/connection/quic.go b/connection/quic.go index f9d3331d..d0d38daa 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -36,7 +36,7 @@ const ( type QUICConnection struct { session quic.Session logger *zerolog.Logger - httpProxy OriginProxy + configManager ConfigManager sessionManager datagramsession.Manager controlStreamHandler ControlStreamHandler connOptions *tunnelpogs.ConnectionOptions @@ -47,7 +47,7 @@ func NewQUICConnection( quicConfig *quic.Config, edgeAddr net.Addr, tlsConfig *tls.Config, - httpProxy OriginProxy, + configManager ConfigManager, connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler ControlStreamHandler, logger *zerolog.Logger, @@ -66,7 +66,7 @@ func NewQUICConnection( return &QUICConnection{ session: session, - httpProxy: httpProxy, + configManager: configManager, logger: logger, sessionManager: sessionManager, controlStreamHandler: controlStreamHandler, @@ -183,10 +183,10 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) } w := newHTTPResponseAdapter(stream) - return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket) + return q.configManager.GetOriginProxy().ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket) case quicpogs.ConnectionTypeTCP: rwa := &streamReadWriteAcker{stream} - return q.httpProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest}) + return q.configManager.GetOriginProxy().ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest}) } return nil } diff --git a/connection/quic_test.go b/connection/quic_test.go index 349e9210..4eecc2a1 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -627,13 +627,12 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection NextProtos: []string{"argotunnel"}, } // Start a mock httpProxy - originProxy := &mockOriginProxyWithRequest{} log := zerolog.New(os.Stdout) qc, err := NewQUICConnection( testQUICConfig, udpListenerAddr, tlsClientConfig, - originProxy, + &mockConfigManager{originProxy: &mockOriginProxyWithRequest{}}, &tunnelpogs.ConnectionOptions{}, fakeControlStream{}, &log, diff --git a/connection/rpc.go b/connection/rpc.go index e8eb6f4a..937604b3 100644 --- a/connection/rpc.go +++ b/connection/rpc.go @@ -37,7 +37,7 @@ func NewTunnelServerClient( } } -func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) { +func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelProperties, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) { authResp, err := tsc.client.Authenticate(ctx, classicTunnel.OriginCert, classicTunnel.Hostname, registrationOptions) if err != nil { return nil, err @@ -54,7 +54,7 @@ func (tsc *tunnelServerClient) Close() { type NamedTunnelRPCClient interface { RegisterConnection( c context.Context, - config *NamedTunnelConfig, + config *NamedTunnelProperties, options *tunnelpogs.ConnectionOptions, connIndex uint8, observer *Observer, @@ -86,15 +86,15 @@ func newRegistrationRPCClient( func (rsc *registrationServerClient) RegisterConnection( ctx context.Context, - config *NamedTunnelConfig, + properties *NamedTunnelProperties, options *tunnelpogs.ConnectionOptions, connIndex uint8, observer *Observer, ) error { conn, err := rsc.client.RegisterConnection( ctx, - config.Credentials.Auth(), - config.Credentials.TunnelID, + properties.Credentials.Auth(), + properties.Credentials.TunnelID, connIndex, options, ) @@ -137,7 +137,7 @@ const ( authenticate rpcName = " authenticate" ) -func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error { +func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelProperties, registrationOptions *tunnelpogs.RegistrationOptions) error { h.observer.sendRegisteringEvent(registrationOptions.ConnectionID) stream, err := h.newRPCStream(ctx, register) @@ -174,7 +174,7 @@ type CredentialManager interface { func (h *h2muxConnection) processRegistrationSuccess( registration *tunnelpogs.TunnelRegistration, name rpcName, - credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, + credentialManager CredentialManager, classicTunnel *ClassicTunnelProperties, ) error { for _, logLine := range registration.LogLines { h.observer.log.Info().Msg(logLine) @@ -205,7 +205,7 @@ func (h *h2muxConnection) processRegisterTunnelError(err tunnelpogs.TunnelRegist } } -func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error { +func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelProperties, registrationOptions *tunnelpogs.RegistrationOptions) error { token, err := credentialManager.ReconnectToken() if err != nil { return err @@ -264,7 +264,7 @@ func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelSe func (h *h2muxConnection) registerNamedTunnel( ctx context.Context, - namedTunnel *NamedTunnelConfig, + namedTunnel *NamedTunnelProperties, connOptions *tunnelpogs.ConnectionOptions, ) error { stream, err := h.newRPCStream(ctx, register) @@ -283,7 +283,7 @@ func (h *h2muxConnection) registerNamedTunnel( func (h *h2muxConnection) unregister(isNamedTunnel bool) { h.observer.sendUnregisteringEvent(h.connIndex) - unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod) + unregisterCtx, cancel := context.WithTimeout(context.Background(), h.gracePeriod) defer cancel() stream, err := h.newRPCStream(unregisterCtx, unregister) @@ -296,13 +296,13 @@ func (h *h2muxConnection) unregister(isNamedTunnel bool) { rpcClient := h.newRPCClientFunc(unregisterCtx, stream, h.observer.log) defer rpcClient.Close() - rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod) + rpcClient.GracefulShutdown(unregisterCtx, h.gracePeriod) } else { rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer.log) defer rpcClient.Close() // gracePeriod is encoded in int64 using capnproto - _ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds()) + _ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.gracePeriod.Nanoseconds()) } h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unregistered tunnel connection") diff --git a/origin/metrics.go b/proxy/metrics.go similarity index 85% rename from origin/metrics.go rename to proxy/metrics.go index 1e54f271..e5406681 100644 --- a/origin/metrics.go +++ b/proxy/metrics.go @@ -1,4 +1,4 @@ -package origin +package proxy import ( "github.com/prometheus/client_golang/prometheus" @@ -43,14 +43,6 @@ var ( Help: "Count of error proxying to origin", }, ) - haConnections = prometheus.NewGauge( - prometheus.GaugeOpts{ - Namespace: connection.MetricsNamespace, - Subsystem: connection.TunnelSubsystem, - Name: "ha_connections", - Help: "Number of active ha connections", - }, - ) ) func init() { @@ -59,7 +51,6 @@ func init() { concurrentRequests, responseByCode, requestErrors, - haConnections, ) } diff --git a/origin/pool.go b/proxy/pool.go similarity index 96% rename from origin/pool.go rename to proxy/pool.go index 396a4a76..fe2cf4a5 100644 --- a/origin/pool.go +++ b/proxy/pool.go @@ -1,4 +1,4 @@ -package origin +package proxy import ( "sync" diff --git a/origin/proxy.go b/proxy/proxy.go similarity index 98% rename from origin/proxy.go rename to proxy/proxy.go index dca03746..b849967d 100644 --- a/origin/proxy.go +++ b/proxy/proxy.go @@ -1,4 +1,4 @@ -package origin +package proxy import ( "bufio" @@ -28,7 +28,7 @@ const ( // Proxy represents a means to Proxy between cloudflared and the origin services. type Proxy struct { - ingressRules ingress.Ingress + ingressRules *ingress.Ingress warpRouting *ingress.WarpRoutingService tags []tunnelpogs.Tag log *zerolog.Logger @@ -37,7 +37,7 @@ type Proxy struct { // NewOriginProxy returns a new instance of the Proxy struct. func NewOriginProxy( - ingressRules ingress.Ingress, + ingressRules *ingress.Ingress, warpRouting *ingress.WarpRoutingService, tags []tunnelpogs.Tag, log *zerolog.Logger, @@ -139,7 +139,7 @@ func (p *Proxy) ProxyTCP( return nil } -func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) { +func ruleField(ing *ingress.Ingress, ruleNum int) (ruleID string, srv string) { srv = ing.Rules[ruleNum].Service.String() if ing.IsSingleRule() { return "", srv diff --git a/origin/proxy_posix_test.go b/proxy/proxy_posix_test.go similarity index 98% rename from origin/proxy_posix_test.go rename to proxy/proxy_posix_test.go index 1b649a43..40d070c7 100644 --- a/origin/proxy_posix_test.go +++ b/proxy/proxy_posix_test.go @@ -1,7 +1,7 @@ //go:build !windows // +build !windows -package origin +package proxy import ( "io/ioutil" diff --git a/origin/proxy_test.go b/proxy/proxy_test.go similarity index 98% rename from origin/proxy_test.go rename to proxy/proxy_test.go index e4184d7a..8ccc7624 100644 --- a/origin/proxy_test.go +++ b/proxy/proxy_test.go @@ -1,4 +1,4 @@ -package origin +package proxy import ( "bytes" @@ -135,7 +135,7 @@ func TestProxySingleOrigin(t *testing.T) { errC := make(chan error) require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC)) - proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, &log) + proxy := NewOriginProxy(&ingressRule, unusedWarpRoutingService, testTags, &log) t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxySSE", testProxySSE(proxy)) @@ -345,7 +345,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat var wg sync.WaitGroup require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC)) - proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log) + proxy := NewOriginProxy(&ingress, unusedWarpRoutingService, testTags, &log) for _, test := range tests { responseWriter := newMockHTTPRespWriter() @@ -394,7 +394,7 @@ func TestProxyError(t *testing.T) { log := zerolog.Nop() - proxy := NewOriginProxy(ing, unusedWarpRoutingService, testTags, &log) + proxy := NewOriginProxy(&ing, unusedWarpRoutingService, testTags, &log) responseWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) @@ -637,7 +637,7 @@ func TestConnections(t *testing.T) { var wg sync.WaitGroup errC := make(chan error) ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) - proxy := NewOriginProxy(ingressRule, test.args.warpRoutingService, testTags, logger) + proxy := NewOriginProxy(&ingressRule, test.args.warpRoutingService, testTags, logger) dest := ln.Addr().String() req, err := http.NewRequest( diff --git a/origin/cloudflare_status_page.go b/supervisor/cloudflare_status_page.go similarity index 99% rename from origin/cloudflare_status_page.go rename to supervisor/cloudflare_status_page.go index dfa9143a..93d9e849 100644 --- a/origin/cloudflare_status_page.go +++ b/supervisor/cloudflare_status_page.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "encoding/json" diff --git a/origin/cloudflare_status_page_test.go b/supervisor/cloudflare_status_page_test.go similarity index 99% rename from origin/cloudflare_status_page_test.go rename to supervisor/cloudflare_status_page_test.go index 21985dcc..a86fb63f 100644 --- a/origin/cloudflare_status_page_test.go +++ b/supervisor/cloudflare_status_page_test.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "testing" diff --git a/supervisor/configmanager.go b/supervisor/configmanager.go new file mode 100644 index 00000000..b1067c83 --- /dev/null +++ b/supervisor/configmanager.go @@ -0,0 +1,55 @@ +package supervisor + +import ( + "sync" + + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/proxy" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" +) + +type configManager struct { + currentVersion int32 + // Only used by UpdateConfig + updateLock sync.Mutex + // TODO: TUN-5698: Make proxy atomic.Value + proxy *proxy.Proxy + config *DynamicConfig + tags []tunnelpogs.Tag + log *zerolog.Logger +} + +func newConfigManager(config *DynamicConfig, tags []tunnelpogs.Tag, log *zerolog.Logger) *configManager { + var warpRoutingService *ingress.WarpRoutingService + if config.WarpRoutingEnabled { + warpRoutingService = ingress.NewWarpRoutingService() + log.Info().Msgf("Warp-routing is enabled") + } + + return &configManager{ + // Lowest possible version, any remote configuration will have version higher than this + currentVersion: 0, + proxy: proxy.NewOriginProxy(config.Ingress, warpRoutingService, tags, log), + config: config, + log: log, + } +} + +func (cm *configManager) Update(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { + // TODO: TUN-5698: make ingress configurable + return &tunnelpogs.UpdateConfigurationResponse{ + LastAppliedVersion: cm.currentVersion, + } +} + +func (cm *configManager) GetOriginProxy() connection.OriginProxy { + return cm.proxy +} + +type DynamicConfig struct { + Ingress *ingress.Ingress + WarpRoutingEnabled bool +} diff --git a/origin/conn_aware_logger.go b/supervisor/conn_aware_logger.go similarity index 97% rename from origin/conn_aware_logger.go rename to supervisor/conn_aware_logger.go index b8021121..6e717588 100644 --- a/origin/conn_aware_logger.go +++ b/supervisor/conn_aware_logger.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "github.com/rs/zerolog" diff --git a/origin/external_control.go b/supervisor/external_control.go similarity index 95% rename from origin/external_control.go rename to supervisor/external_control.go index cd9ef364..f170cde2 100644 --- a/origin/external_control.go +++ b/supervisor/external_control.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "time" diff --git a/supervisor/metrics.go b/supervisor/metrics.go new file mode 100644 index 00000000..e6d50cdd --- /dev/null +++ b/supervisor/metrics.go @@ -0,0 +1,27 @@ +package supervisor + +import ( + "github.com/prometheus/client_golang/prometheus" + + "github.com/cloudflare/cloudflared/connection" +) + +// Metrics uses connection.MetricsNamespace(aka cloudflared) as namespace and connection.TunnelSubsystem +// (tunnel) as subsystem to keep them consistent with the previous qualifier. + +var ( + haConnections = prometheus.NewGauge( + prometheus.GaugeOpts{ + Namespace: connection.MetricsNamespace, + Subsystem: connection.TunnelSubsystem, + Name: "ha_connections", + Help: "Number of active ha connections", + }, + ) +) + +func init() { + prometheus.MustRegister( + haConnections, + ) +} diff --git a/origin/reconnect.go b/supervisor/reconnect.go similarity index 99% rename from origin/reconnect.go rename to supervisor/reconnect.go index 8b43977b..040c2714 100644 --- a/origin/reconnect.go +++ b/supervisor/reconnect.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "context" diff --git a/origin/reconnect_test.go b/supervisor/reconnect_test.go similarity index 99% rename from origin/reconnect_test.go rename to supervisor/reconnect_test.go index fb2a1df9..593d16d1 100644 --- a/origin/reconnect_test.go +++ b/supervisor/reconnect_test.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "context" diff --git a/origin/supervisor.go b/supervisor/supervisor.go similarity index 96% rename from origin/supervisor.go rename to supervisor/supervisor.go index 304fc3c2..a87384a8 100644 --- a/origin/supervisor.go +++ b/supervisor/supervisor.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "context" @@ -37,6 +37,7 @@ const ( // reconnects them if they disconnect. type Supervisor struct { cloudflaredUUID uuid.UUID + configManager *configManager config *TunnelConfig edgeIPs *edgediscovery.Edge tunnelErrors chan tunnelError @@ -64,7 +65,7 @@ type tunnelError struct { err error } -func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) { +func NewSupervisor(config *TunnelConfig, dynamiConfig *DynamicConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) { cloudflaredUUID, err := uuid.NewRandom() if err != nil { return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err) @@ -88,6 +89,7 @@ func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, grace return &Supervisor{ cloudflaredUUID: cloudflaredUUID, config: config, + configManager: newConfigManager(dynamiConfig, config.Tags, config.Log), edgeIPs: edgeIPs, tunnelErrors: make(chan tunnelError), tunnelsConnecting: map[int]chan struct{}{}, @@ -242,6 +244,7 @@ func (s *Supervisor) startFirstTunnel( err = ServeTunnelLoop( ctx, s.reconnectCredentialManager, + s.configManager, s.config, addr, s.log, @@ -276,6 +279,7 @@ func (s *Supervisor) startFirstTunnel( err = ServeTunnelLoop( ctx, s.reconnectCredentialManager, + s.configManager, s.config, addr, s.log, @@ -310,6 +314,7 @@ func (s *Supervisor) startTunnel( err = ServeTunnelLoop( ctx, s.reconnectCredentialManager, + s.configManager, s.config, addr, s.log, @@ -380,7 +385,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) defer rpcClient.Close() const arbitraryConnectionID = uint8(0) - registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID) + registrationOptions := s.config.registrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID) registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts) return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions) } diff --git a/origin/tunnel.go b/supervisor/tunnel.go similarity index 90% rename from origin/tunnel.go rename to supervisor/tunnel.go index 99cb5987..64047863 100644 --- a/origin/tunnel.go +++ b/supervisor/tunnel.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "context" @@ -34,32 +34,33 @@ const ( ) type TunnelConfig struct { - ConnectionConfig *connection.Config - OSArch string - ClientID string - CloseConnOnce *sync.Once // Used to close connectedSignal no more than once - EdgeAddrs []string - Region string - HAConnections int - IncidentLookup IncidentLookup - IsAutoupdated bool - LBPool string - Tags []tunnelpogs.Tag - Log *zerolog.Logger - LogTransport *zerolog.Logger - Observer *connection.Observer - ReportedVersion string - Retries uint - RunFromTerminal bool + GracePeriod time.Duration + ReplaceExisting bool + OSArch string + ClientID string + CloseConnOnce *sync.Once // Used to close connectedSignal no more than once + EdgeAddrs []string + Region string + HAConnections int + IncidentLookup IncidentLookup + IsAutoupdated bool + LBPool string + Tags []tunnelpogs.Tag + Log *zerolog.Logger + LogTransport *zerolog.Logger + Observer *connection.Observer + ReportedVersion string + Retries uint + RunFromTerminal bool - NamedTunnel *connection.NamedTunnelConfig - ClassicTunnel *connection.ClassicTunnelConfig + NamedTunnel *connection.NamedTunnelProperties + ClassicTunnel *connection.ClassicTunnelProperties MuxerConfig *connection.MuxerConfig ProtocolSelector connection.ProtocolSelector EdgeTLSConfigs map[connection.Protocol]*tls.Config } -func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { +func (c *TunnelConfig) registrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { policy := tunnelrpc.ExistingTunnelPolicy_balance if c.HAConnections <= 1 && c.LBPool == "" { policy = tunnelrpc.ExistingTunnelPolicy_disconnect @@ -81,7 +82,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str } } -func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAttempts uint8) *tunnelpogs.ConnectionOptions { +func (c *TunnelConfig) connectionOptions(originLocalAddr string, numPreviousAttempts uint8) *tunnelpogs.ConnectionOptions { // attempt to parse out origin IP, but don't fail since it's informational field host, _, _ := net.SplitHostPort(originLocalAddr) originIP := net.ParseIP(host) @@ -89,7 +90,7 @@ func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAtte return &tunnelpogs.ConnectionOptions{ Client: c.NamedTunnel.Client, OriginLocalIP: originIP, - ReplaceExisting: c.ConnectionConfig.ReplaceExisting, + ReplaceExisting: c.ReplaceExisting, CompressionQuality: uint8(c.MuxerConfig.CompressionSetting), NumPreviousAttempts: numPreviousAttempts, } @@ -106,11 +107,12 @@ func (c *TunnelConfig) SupportedFeatures() []string { func StartTunnelDaemon( ctx context.Context, config *TunnelConfig, + dynamiConfig *DynamicConfig, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal, graceShutdownC <-chan struct{}, ) error { - s, err := NewSupervisor(config, reconnectCh, graceShutdownC) + s, err := NewSupervisor(config, dynamiConfig, reconnectCh, graceShutdownC) if err != nil { return err } @@ -120,6 +122,7 @@ func StartTunnelDaemon( func ServeTunnelLoop( ctx context.Context, credentialManager *reconnectCredentialManager, + configManager *configManager, config *TunnelConfig, addr *allregions.EdgeAddr, connAwareLogger *ConnAwareLogger, @@ -155,6 +158,7 @@ func ServeTunnelLoop( ctx, connLog, credentialManager, + configManager, config, addr, connIndex, @@ -253,6 +257,7 @@ func ServeTunnel( ctx context.Context, connLog *ConnAwareLogger, credentialManager *reconnectCredentialManager, + configManager *configManager, config *TunnelConfig, addr *allregions.EdgeAddr, connIndex uint8, @@ -281,6 +286,7 @@ func ServeTunnel( ctx, connLog, credentialManager, + configManager, config, addr, connIndex, @@ -329,6 +335,7 @@ func serveTunnel( ctx context.Context, connLog *ConnAwareLogger, credentialManager *reconnectCredentialManager, + configManager *configManager, config *TunnelConfig, addr *allregions.EdgeAddr, connIndex uint8, @@ -339,7 +346,6 @@ func serveTunnel( protocol connection.Protocol, gracefulShutdownC <-chan struct{}, ) (err error, recoverable bool) { - connectedFuse := &connectedFuse{ fuse: fuse, backoff: backoff, @@ -351,14 +357,15 @@ func serveTunnel( connIndex, nil, gracefulShutdownC, - config.ConnectionConfig.GracePeriod, + config.GracePeriod, ) switch protocol { case connection.QUIC, connection.QUICWarp: - connOptions := config.ConnectionOptions(addr.UDP.String(), uint8(backoff.Retries())) + connOptions := config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries())) return ServeQUIC(ctx, addr.UDP, + configManager, config, connLog, connOptions, @@ -374,10 +381,11 @@ func serveTunnel( return err, true } - connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries())) + connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries())) if err := ServeHTTP2( ctx, connLog, + configManager, config, edgeConn, connOptions, @@ -400,6 +408,7 @@ func serveTunnel( ctx, connLog, credentialManager, + configManager, config, edgeConn, connIndex, @@ -426,6 +435,7 @@ func ServeH2mux( ctx context.Context, connLog *ConnAwareLogger, credentialManager *reconnectCredentialManager, + configManager *configManager, config *TunnelConfig, edgeConn net.Conn, connIndex uint8, @@ -437,7 +447,8 @@ func ServeH2mux( connLog.Logger().Debug().Msgf("Connecting via h2mux") // Returns error from parsing the origin URL or handshake errors handler, err, recoverable := connection.NewH2muxConnection( - config.ConnectionConfig, + configManager, + config.GracePeriod, config.MuxerConfig, edgeConn, connIndex, @@ -455,10 +466,10 @@ func ServeH2mux( errGroup.Go(func() error { if config.NamedTunnel != nil { - connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries())) + connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries())) return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse) } - registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID) + registrationOptions := config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID) return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse) }) @@ -472,6 +483,7 @@ func ServeH2mux( func ServeHTTP2( ctx context.Context, connLog *ConnAwareLogger, + configManager *configManager, config *TunnelConfig, tlsServerConn net.Conn, connOptions *tunnelpogs.ConnectionOptions, @@ -483,7 +495,7 @@ func ServeHTTP2( connLog.Logger().Debug().Msgf("Connecting via http2") h2conn := connection.NewHTTP2Connection( tlsServerConn, - config.ConnectionConfig, + configManager, connOptions, config.Observer, connIndex, @@ -511,6 +523,7 @@ func ServeHTTP2( func ServeQUIC( ctx context.Context, edgeAddr *net.UDPAddr, + configManager *configManager, config *TunnelConfig, connLogger *ConnAwareLogger, connOptions *tunnelpogs.ConnectionOptions, @@ -535,7 +548,7 @@ func ServeQUIC( quicConfig, edgeAddr, tlsConfig, - config.ConnectionConfig.OriginProxy, + configManager, connOptions, controlStreamHandler, connLogger.Logger()) diff --git a/origin/tunnel_test.go b/supervisor/tunnel_test.go similarity index 96% rename from origin/tunnel_test.go rename to supervisor/tunnel_test.go index 870a5049..2e646089 100644 --- a/origin/tunnel_test.go +++ b/supervisor/tunnel_test.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "testing" @@ -32,11 +32,7 @@ func TestWaitForBackoffFallback(t *testing.T) { } log := zerolog.Nop() resolveTTL := time.Duration(0) - namedTunnel := &connection.NamedTunnelConfig{ - Credentials: connection.Credentials{ - AccountTag: "test-account", - }, - } + namedTunnel := &connection.NamedTunnelProperties{} mockFetcher := dynamicMockFetcher{ protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}, } diff --git a/origin/tunnelsforha.go b/supervisor/tunnelsforha.go similarity index 98% rename from origin/tunnelsforha.go rename to supervisor/tunnelsforha.go index 61673737..80704e38 100644 --- a/origin/tunnelsforha.go +++ b/supervisor/tunnelsforha.go @@ -1,4 +1,4 @@ -package origin +package supervisor import ( "fmt"