From d68ff390cab1767ad55d001e62792c0f7bbd4931 Mon Sep 17 00:00:00 2001 From: cthuang Date: Fri, 11 Feb 2022 10:49:06 +0000 Subject: [PATCH] TUN-5698: Make ingress rules and warp routing dynamically configurable --- cmd/cloudflared/tunnel/cmd.go | 6 +- cmd/cloudflared/tunnel/configuration.go | 5 +- connection/connection.go | 6 +- connection/connection_test.go | 17 +- connection/h2mux.go | 20 +- connection/h2mux_test.go | 2 +- connection/http2.go | 26 +- connection/http2_test.go | 2 +- connection/quic.go | 14 +- connection/quic_test.go | 2 +- ingress/ingress.go | 5 +- ingress/origin_proxy_test.go | 9 +- ingress/origin_service.go | 36 +- orchestration/config.go | 15 + orchestration/orchestrator.go | 158 ++++++ orchestration/orchestrator_test.go | 686 ++++++++++++++++++++++++ proxy/proxy.go | 17 +- proxy/proxy_test.go | 26 +- supervisor/configmanager.go | 55 -- supervisor/supervisor.go | 13 +- supervisor/tunnel.go | 33 +- 21 files changed, 978 insertions(+), 175 deletions(-) create mode 100644 orchestration/config.go create mode 100644 orchestration/orchestrator.go create mode 100644 orchestration/orchestrator_test.go delete mode 100644 supervisor/configmanager.go diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 728bb8b7..2a692f73 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -31,6 +31,7 @@ import ( "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/metrics" + "github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/tlsconfig" @@ -353,7 +354,8 @@ func StartServer( errC <- metrics.ServeMetrics(metricsListener, ctx.Done(), readinessServer, quickTunnelURL, log) }() - if err := dynamicConfig.Ingress.StartOrigins(&wg, log, ctx.Done(), errC); err != nil { + orchestrator, err := orchestration.NewOrchestrator(ctx, dynamicConfig, tunnelConfig.Tags, tunnelConfig.Log) + if err != nil { return err } @@ -369,7 +371,7 @@ func StartServer( wg.Done() log.Info().Msg("Tunnel server stopped") }() - errC <- supervisor.StartTunnelDaemon(ctx, tunnelConfig, dynamicConfig, connectedSignal, reconnectCh, graceShutdownC) + errC <- supervisor.StartTunnelDaemon(ctx, tunnelConfig, orchestrator, connectedSignal, reconnectCh, graceShutdownC) }() if isUIEnabled { diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 61e65a8f..27ce90e0 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -23,6 +23,7 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/tlsconfig" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -153,7 +154,7 @@ func prepareTunnelConfig( log, logTransport *zerolog.Logger, observer *connection.Observer, namedTunnel *connection.NamedTunnelProperties, -) (*supervisor.TunnelConfig, *supervisor.DynamicConfig, error) { +) (*supervisor.TunnelConfig, *orchestration.Config, error) { isNamedTunnel := namedTunnel != nil configHostname := c.String("hostname") @@ -292,7 +293,7 @@ func prepareTunnelConfig( ProtocolSelector: protocolSelector, EdgeTLSConfigs: edgeTLSConfigs, } - dynamicConfig := &supervisor.DynamicConfig{ + dynamicConfig := &orchestration.Config{ Ingress: &ingressRules, WarpRoutingEnabled: warpRoutingEnabled, } diff --git a/connection/connection.go b/connection/connection.go index 07649983..007011d0 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -25,9 +25,9 @@ const ( var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) -type ConfigManager interface { - Update(version int32, config []byte) *pogs.UpdateConfigurationResponse - GetOriginProxy() OriginProxy +type Orchestrator interface { + UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse + GetOriginProxy() (OriginProxy, error) } type NamedTunnelProperties struct { diff --git a/connection/connection_test.go b/connection/connection_test.go index 3b83269e..9e43fee2 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -6,14 +6,12 @@ import ( "io" "math/rand" "net/http" - "net/url" "testing" "time" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" - "github.com/cloudflare/cloudflared/ingress" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/websocket" ) @@ -24,15 +22,10 @@ const ( ) var ( - unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil) - testConfigManager = &mockConfigManager{ + testOrchestrator = &mockOrchestrator{ originProxy: &mockOriginProxy{}, } log = zerolog.Nop() - testOriginURL = &url.URL{ - Scheme: "https", - Host: "connectiontest.argotunnel.com", - } testLargeResp = make([]byte, largeFileSize) ) @@ -44,18 +37,18 @@ type testRequest struct { isProxyError bool } -type mockConfigManager struct { +type mockOrchestrator struct { originProxy OriginProxy } -func (*mockConfigManager) Update(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { +func (*mockOrchestrator) UpdateConfig(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { return &tunnelpogs.UpdateConfigurationResponse{ LastAppliedVersion: version, } } -func (mcr *mockConfigManager) GetOriginProxy() OriginProxy { - return mcr.originProxy +func (mcr *mockOrchestrator) GetOriginProxy() (OriginProxy, error) { + return mcr.originProxy, nil } type mockOriginProxy struct{} diff --git a/connection/h2mux.go b/connection/h2mux.go index 8401c81a..1c7276ac 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -22,10 +22,10 @@ const ( ) type h2muxConnection struct { - configManager ConfigManager - gracePeriod time.Duration - muxerConfig *MuxerConfig - muxer *h2mux.Muxer + orchestrator Orchestrator + gracePeriod time.Duration + muxerConfig *MuxerConfig + muxer *h2mux.Muxer // connectionID is only used by metrics, and prometheus requires labels to be string connIndexStr string connIndex uint8 @@ -61,7 +61,7 @@ func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, log *zerolog.Lo // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error func NewH2muxConnection( - configManager ConfigManager, + orchestrator Orchestrator, gracePeriod time.Duration, muxerConfig *MuxerConfig, edgeConn net.Conn, @@ -70,7 +70,7 @@ func NewH2muxConnection( gracefulShutdownC <-chan struct{}, ) (*h2muxConnection, error, bool) { h := &h2muxConnection{ - configManager: configManager, + orchestrator: orchestrator, gracePeriod: gracePeriod, muxerConfig: muxerConfig, connIndexStr: uint8ToString(connIndex), @@ -227,7 +227,13 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error { sourceConnectionType = TypeWebsocket } - err := h.configManager.GetOriginProxy().ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket) + originProxy, err := h.orchestrator.GetOriginProxy() + if err != nil { + respWriter.WriteErrorResponse() + return err + } + + err = originProxy.ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket) if err != nil { respWriter.WriteErrorResponse() } diff --git a/connection/h2mux_test.go b/connection/h2mux_test.go index 35ab447c..787cfd17 100644 --- a/connection/h2mux_test.go +++ b/connection/h2mux_test.go @@ -48,7 +48,7 @@ func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) { }() var connIndex = uint8(0) testObserver := NewObserver(&log, &log, false) - h2muxConn, err, _ := NewH2muxConnection(testConfigManager, testGracePeriod, testMuxerConfig, originConn, connIndex, testObserver, nil) + h2muxConn, err, _ := NewH2muxConnection(testOrchestrator, testGracePeriod, testMuxerConfig, originConn, connIndex, testObserver, nil) require.NoError(t, err) return h2muxConn, <-edgeMuxChan } diff --git a/connection/http2.go b/connection/http2.go index 1e1f517a..d1e78c1f 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -30,12 +30,12 @@ var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed") // HTTP2Connection represents a net.Conn that uses HTTP2 frames to proxy traffic from the edge to cloudflared on the // origin. type HTTP2Connection struct { - conn net.Conn - server *http2.Server - configManager ConfigManager - connOptions *tunnelpogs.ConnectionOptions - observer *Observer - connIndex uint8 + conn net.Conn + server *http2.Server + orchestrator Orchestrator + connOptions *tunnelpogs.ConnectionOptions + observer *Observer + connIndex uint8 // newRPCClientFunc allows us to mock RPCs during testing newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient @@ -49,7 +49,7 @@ type HTTP2Connection struct { // NewHTTP2Connection returns a new instance of HTTP2Connection. func NewHTTP2Connection( conn net.Conn, - configManager ConfigManager, + orchestrator Orchestrator, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, @@ -61,7 +61,7 @@ func NewHTTP2Connection( server: &http2.Server{ MaxConcurrentStreams: MaxConcurrentStreams, }, - configManager: configManager, + orchestrator: orchestrator, connOptions: connOptions, observer: observer, connIndex: connIndex, @@ -106,6 +106,12 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + originProxy, err := c.orchestrator.GetOriginProxy() + if err != nil { + c.observer.log.Error().Msg(err.Error()) + return + } + switch connType { case TypeControlStream: if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions); err != nil { @@ -116,7 +122,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { case TypeWebsocket, TypeHTTP: stripWebsocketUpgradeHeader(r) - if err := c.configManager.GetOriginProxy().ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil { + if err := originProxy.ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil { err := fmt.Errorf("Failed to proxy HTTP: %w", err) c.log.Error().Err(err) respWriter.WriteErrorResponse() @@ -131,7 +137,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { } rws := NewHTTPResponseReadWriterAcker(respWriter, r) - if err := c.configManager.GetOriginProxy().ProxyTCP(r.Context(), rws, &TCPRequest{ + if err := originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ Dest: host, CFRay: FindCfRayHeader(r), LBProbe: IsLBProbeRequest(r), diff --git a/connection/http2_test.go b/connection/http2_test.go index c405d1ee..c067229c 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -44,7 +44,7 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { return NewHTTP2Connection( cfdConn, // OriginProxy is set in testConfigManager - testConfigManager, + testOrchestrator, &pogs.ConnectionOptions{}, obs, connIndex, diff --git a/connection/quic.go b/connection/quic.go index d0d38daa..5c4c893d 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -36,7 +36,7 @@ const ( type QUICConnection struct { session quic.Session logger *zerolog.Logger - configManager ConfigManager + orchestrator Orchestrator sessionManager datagramsession.Manager controlStreamHandler ControlStreamHandler connOptions *tunnelpogs.ConnectionOptions @@ -47,7 +47,7 @@ func NewQUICConnection( quicConfig *quic.Config, edgeAddr net.Addr, tlsConfig *tls.Config, - configManager ConfigManager, + orchestrator Orchestrator, connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler ControlStreamHandler, logger *zerolog.Logger, @@ -66,7 +66,7 @@ func NewQUICConnection( return &QUICConnection{ session: session, - configManager: configManager, + orchestrator: orchestrator, logger: logger, sessionManager: sessionManager, controlStreamHandler: controlStreamHandler, @@ -175,6 +175,10 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) return err } + originProxy, err := q.orchestrator.GetOriginProxy() + if err != nil { + return err + } switch connectRequest.Type { case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket: req, err := buildHTTPRequest(connectRequest, stream) @@ -183,10 +187,10 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) } w := newHTTPResponseAdapter(stream) - return q.configManager.GetOriginProxy().ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket) + return originProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket) case quicpogs.ConnectionTypeTCP: rwa := &streamReadWriteAcker{stream} - return q.configManager.GetOriginProxy().ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest}) + return originProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest}) } return nil } diff --git a/connection/quic_test.go b/connection/quic_test.go index 4eecc2a1..9763ae33 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -632,7 +632,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection testQUICConfig, udpListenerAddr, tlsClientConfig, - &mockConfigManager{originProxy: &mockOriginProxyWithRequest{}}, + &mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}}, &tunnelpogs.ConnectionOptions{}, fakeControlStream{}, &log, diff --git a/ingress/ingress.go b/ingress/ingress.go index f2ab2791..5e5f9655 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -7,7 +7,6 @@ import ( "regexp" "strconv" "strings" - "sync" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -145,13 +144,11 @@ func (ing Ingress) IsSingleRule() bool { // StartOrigins will start any origin services managed by cloudflared, e.g. proxy servers or Hello World. func (ing Ingress) StartOrigins( - wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, - errC chan error, ) error { for _, rule := range ing.Rules { - if err := rule.Service.start(wg, log, shutdownC, errC, rule.Config); err != nil { + if err := rule.Service.start(log, shutdownC, rule.Config); err != nil { return errors.Wrapf(err, "Error starting local service %s", rule.Service) } } diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index 5716bba1..cc244aee 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -8,7 +8,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "sync" "testing" "github.com/stretchr/testify/assert" @@ -132,10 +131,8 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) { httpService := &httpService{ url: originURL, } - var wg sync.WaitGroup shutdownC := make(chan struct{}) - errC := make(chan error) - require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg)) + require.NoError(t, httpService.start(testLogger, shutdownC, cfg)) req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) require.NoError(t, err) @@ -169,10 +166,8 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) { httpService := &httpService{ url: originURL, } - var wg sync.WaitGroup shutdownC := make(chan struct{}) - errC := make(chan error) - require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg)) + require.NoError(t, httpService.start(testLogger, shutdownC, cfg)) // Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header protos := []string{"https", "http", "dne"} diff --git a/ingress/origin_service.go b/ingress/origin_service.go index fc636c86..116b77f0 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -8,7 +8,6 @@ import ( "net" "net/http" "net/url" - "sync" "time" "github.com/pkg/errors" @@ -20,13 +19,18 @@ import ( "github.com/cloudflare/cloudflared/tlsconfig" ) +const ( + HelloWorldService = "Hello World test origin" +) + // OriginService is something a tunnel can proxy traffic to. type OriginService interface { String() string // Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World. // If it's not managed by cloudflared, this is a no-op because the user is responsible for // starting the origin service. - start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error + // Implementor of services managed by cloudflared should terminate the service if shutdownC is closed + start(log *zerolog.Logger, shutdownC <-chan struct{}, cfg OriginRequestConfig) error } // unixSocketPath is an OriginService representing a unix socket (which accepts HTTP) @@ -39,7 +43,7 @@ func (o *unixSocketPath) String() string { return "unix socket: " + o.path } -func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *unixSocketPath) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { transport, err := newHTTPTransport(o, cfg, log) if err != nil { return err @@ -54,7 +58,7 @@ type httpService struct { transport *http.Transport } -func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *httpService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { transport, err := newHTTPTransport(o, cfg, log) if err != nil { return err @@ -78,7 +82,7 @@ func (o *rawTCPService) String() string { return o.name } -func (o *rawTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *rawTCPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { return nil } @@ -139,7 +143,7 @@ func (o *tcpOverWSService) String() string { return o.dest } -func (o *tcpOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *tcpOverWSService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { if cfg.ProxyType == socksProxy { o.streamHandler = socks.StreamHandler } else { @@ -148,7 +152,7 @@ func (o *tcpOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdo return nil } -func (o *socksProxyOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *socksProxyOverWSService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { return nil } @@ -164,18 +168,16 @@ type helloWorld struct { } func (o *helloWorld) String() string { - return "Hello World test origin" + return HelloWorldService } // Start starts a HelloWorld server and stores its address in the Service receiver. func (o *helloWorld) start( - wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, - errC chan error, cfg OriginRequestConfig, ) error { - if err := o.httpService.start(wg, log, shutdownC, errC, cfg); err != nil { + if err := o.httpService.start(log, shutdownC, cfg); err != nil { return err } @@ -183,11 +185,7 @@ func (o *helloWorld) start( if err != nil { return errors.Wrap(err, "Cannot start Hello World Server") } - wg.Add(1) - go func() { - defer wg.Done() - _ = hello.StartHelloWorldServer(log, helloListener, shutdownC) - }() + go hello.StartHelloWorldServer(log, helloListener, shutdownC) o.server = helloListener o.httpService.url = &url.URL{ @@ -218,10 +216,8 @@ func (o *statusCode) String() string { } func (o *statusCode) start( - wg *sync.WaitGroup, log *zerolog.Logger, - shutdownC <-chan struct{}, - errC chan error, + _ <-chan struct{}, cfg OriginRequestConfig, ) error { return nil @@ -296,6 +292,6 @@ func (mos MockOriginHTTPService) String() string { return "MockOriginService" } -func (mos MockOriginHTTPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (mos MockOriginHTTPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { return nil } diff --git a/orchestration/config.go b/orchestration/config.go new file mode 100644 index 00000000..dff7e701 --- /dev/null +++ b/orchestration/config.go @@ -0,0 +1,15 @@ +package orchestration + +import ( + "github.com/cloudflare/cloudflared/ingress" +) + +type newConfig struct { + ingress.RemoteConfig + // Add more fields when we support other settings in tunnel orchestration +} + +type Config struct { + Ingress *ingress.Ingress + WarpRoutingEnabled bool +} diff --git a/orchestration/orchestrator.go b/orchestration/orchestrator.go new file mode 100644 index 00000000..d072e966 --- /dev/null +++ b/orchestration/orchestrator.go @@ -0,0 +1,158 @@ +package orchestration + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "sync/atomic" + + "github.com/pkg/errors" + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/proxy" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" +) + +// Orchestrator manages configurations so they can be updatable during runtime +// properties are static, so it can be read without lock +// currentVersion and config are read/write infrequently, so their access are synchronized with RWMutex +// access to proxy is synchronized with atmoic.Value, because it uses copy-on-write to provide scalable frequently +// read when update is infrequent +type Orchestrator struct { + currentVersion int32 + // Used by UpdateConfig to make sure one update at a time + lock sync.RWMutex + // Underlying value is proxy.Proxy, can be read without the lock, but still needs the lock to update + proxy atomic.Value + config *Config + tags []tunnelpogs.Tag + log *zerolog.Logger + + // orchestrator must not handle any more updates after shutdownC is closed + shutdownC <-chan struct{} + // Closing proxyShutdownC will close the previous proxy + proxyShutdownC chan<- struct{} +} + +func NewOrchestrator(ctx context.Context, config *Config, tags []tunnelpogs.Tag, log *zerolog.Logger) (*Orchestrator, error) { + o := &Orchestrator{ + // Lowest possible version, any remote configuration will have version higher than this + currentVersion: 0, + config: config, + tags: tags, + log: log, + shutdownC: ctx.Done(), + } + if err := o.updateIngress(*config.Ingress, config.WarpRoutingEnabled); err != nil { + return nil, err + } + go o.waitToCloseLastProxy() + return o, nil +} + +// Update creates a new proxy with the new ingress rules +func (o *Orchestrator) UpdateConfig(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { + o.lock.Lock() + defer o.lock.Unlock() + + if o.currentVersion >= version { + o.log.Debug(). + Int32("current_version", o.currentVersion). + Int32("received_version", version). + Msg("Current version is equal or newer than receivied version") + return &tunnelpogs.UpdateConfigurationResponse{ + LastAppliedVersion: o.currentVersion, + } + } + var newConf newConfig + if err := json.Unmarshal(config, &newConf); err != nil { + o.log.Err(err). + Int32("version", version). + Str("config", string(config)). + Msgf("Failed to deserialize new configuration") + return &tunnelpogs.UpdateConfigurationResponse{ + LastAppliedVersion: o.currentVersion, + Err: err, + } + } + + if err := o.updateIngress(newConf.Ingress, newConf.WarpRouting.Enabled); err != nil { + o.log.Err(err). + Int32("version", version). + Str("config", string(config)). + Msgf("Failed to update ingress") + return &tunnelpogs.UpdateConfigurationResponse{ + LastAppliedVersion: o.currentVersion, + Err: err, + } + } + o.currentVersion = version + + o.log.Info(). + Int32("version", version). + Str("config", string(config)). + Msg("Updated to new configuration") + return &tunnelpogs.UpdateConfigurationResponse{ + LastAppliedVersion: o.currentVersion, + } +} + +// The caller is responsible to make sure there is no concurrent access +func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRoutingEnabled bool) error { + select { + case <-o.shutdownC: + return fmt.Errorf("cloudflared already shutdown") + default: + } + + // Start new proxy before closing the ones from last version. + // The upside is we don't need to restart proxy from last version, which can fail + // The downside is new version might have ingress rule that require previous version to be shutdown first + // The downside is minimized because none of the ingress.OriginService implementation have that requirement + proxyShutdownC := make(chan struct{}) + if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil { + return errors.Wrap(err, "failed to start origin") + } + newProxy := proxy.NewOriginProxy(ingressRules, warpRoutingEnabled, o.tags, o.log) + o.proxy.Store(newProxy) + o.config.Ingress = &ingressRules + o.config.WarpRoutingEnabled = warpRoutingEnabled + + // If proxyShutdownC is nil, there is no previous running proxy + if o.proxyShutdownC != nil { + close(o.proxyShutdownC) + } + o.proxyShutdownC = proxyShutdownC + return nil +} + +// GetOriginProxy returns an interface to proxy to origin. It satisfies connection.ConfigManager interface +func (o *Orchestrator) GetOriginProxy() (connection.OriginProxy, error) { + val := o.proxy.Load() + if val == nil { + err := fmt.Errorf("origin proxy not configured") + o.log.Error().Msg(err.Error()) + return nil, err + } + proxy, ok := val.(*proxy.Proxy) + if !ok { + err := fmt.Errorf("origin proxy has unexpected value %+v", val) + o.log.Error().Msg(err.Error()) + return nil, err + } + return proxy, nil +} + +func (o *Orchestrator) waitToCloseLastProxy() { + <-o.shutdownC + o.lock.Lock() + defer o.lock.Unlock() + + if o.proxyShutdownC != nil { + close(o.proxyShutdownC) + o.proxyShutdownC = nil + } +} diff --git a/orchestration/orchestrator_test.go b/orchestration/orchestrator_test.go new file mode 100644 index 00000000..b4b19224 --- /dev/null +++ b/orchestration/orchestrator_test.go @@ -0,0 +1,686 @@ +package orchestration + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/gobwas/ws/wsutil" + gows "github.com/gorilla/websocket" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/proxy" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" +) + +var ( + testLogger = zerolog.Logger{} + testTags = []tunnelpogs.Tag{ + { + Name: "package", + Value: "orchestration", + }, + { + Name: "purpose", + Value: "test", + }, + } +) + +// TestUpdateConfiguration tests that +// - configurations can be deserialized +// - proxy can be updated +// - last applied version and error are returned +// - configurations can be deserialized +// - receiving an old version is noop +func TestUpdateConfiguration(t *testing.T) { + initConfig := &Config{ + Ingress: &ingress.Ingress{}, + WarpRoutingEnabled: false, + } + orchestrator, err := NewOrchestrator(context.Background(), initConfig, testTags, &testLogger) + require.NoError(t, err) + initOriginProxy, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + require.IsType(t, &proxy.Proxy{}, initOriginProxy) + + configJSONV2 := []byte(` +{ + "unknown_field": "not_deserialized", + "originRequest": { + "connectTimeout": 90000000000, + "noHappyEyeballs": true + }, + "ingress": [ + { + "hostname": "jira.tunnel.org", + "path": "^\/login", + "service": "http://192.16.19.1:443", + "originRequest": { + "noTLSVerify": true, + "connectTimeout": 10000000000 + } + }, + { + "hostname": "jira.tunnel.org", + "service": "http://172.32.20.6:80", + "originRequest": { + "noTLSVerify": true, + "connectTimeout": 30000000000 + } + }, + { + "service": "http_status:404" + } + ], + "warp-routing": { + "enabled": true + } +} +`) + + updateWithValidation(t, orchestrator, 2, configJSONV2) + configV2 := orchestrator.config + // Validate ingress rule 0 + require.Equal(t, "jira.tunnel.org", configV2.Ingress.Rules[0].Hostname) + require.True(t, configV2.Ingress.Rules[0].Matches("jira.tunnel.org", "/login")) + require.True(t, configV2.Ingress.Rules[0].Matches("jira.tunnel.org", "/login/2fa")) + require.False(t, configV2.Ingress.Rules[0].Matches("jira.tunnel.org", "/users")) + require.Equal(t, "http://192.16.19.1:443", configV2.Ingress.Rules[0].Service.String()) + require.Len(t, configV2.Ingress.Rules, 3) + // originRequest of this ingress rule overrides global default + require.Equal(t, time.Second*10, configV2.Ingress.Rules[0].Config.ConnectTimeout) + require.Equal(t, true, configV2.Ingress.Rules[0].Config.NoTLSVerify) + // Inherited from global default + require.Equal(t, true, configV2.Ingress.Rules[0].Config.NoHappyEyeballs) + // Validate ingress rule 1 + require.Equal(t, "jira.tunnel.org", configV2.Ingress.Rules[1].Hostname) + require.True(t, configV2.Ingress.Rules[1].Matches("jira.tunnel.org", "/users")) + require.Equal(t, "http://172.32.20.6:80", configV2.Ingress.Rules[1].Service.String()) + // originRequest of this ingress rule overrides global default + require.Equal(t, time.Second*30, configV2.Ingress.Rules[1].Config.ConnectTimeout) + require.Equal(t, true, configV2.Ingress.Rules[1].Config.NoTLSVerify) + // Inherited from global default + require.Equal(t, true, configV2.Ingress.Rules[1].Config.NoHappyEyeballs) + // Validate ingress rule 2, it's the catch-all rule + require.True(t, configV2.Ingress.Rules[2].Matches("blogs.tunnel.io", "/2022/02/10")) + // Inherited from global default + require.Equal(t, time.Second*90, configV2.Ingress.Rules[2].Config.ConnectTimeout) + require.Equal(t, false, configV2.Ingress.Rules[2].Config.NoTLSVerify) + require.Equal(t, true, configV2.Ingress.Rules[2].Config.NoHappyEyeballs) + require.True(t, configV2.WarpRoutingEnabled) + + originProxyV2, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + require.IsType(t, &proxy.Proxy{}, originProxyV2) + require.NotEqual(t, originProxyV2, initOriginProxy) + + // Should not downgrade to an older version + resp := orchestrator.UpdateConfig(1, nil) + require.NoError(t, resp.Err) + require.Equal(t, int32(2), resp.LastAppliedVersion) + + invalidJSON := []byte(` +{ + "originRequest": +} + +`) + + resp = orchestrator.UpdateConfig(3, invalidJSON) + require.Error(t, resp.Err) + require.Equal(t, int32(2), resp.LastAppliedVersion) + originProxyV3, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + require.Equal(t, originProxyV2, originProxyV3) + + configJSONV10 := []byte(` +{ + "ingress": [ + { + "service": "hello-world" + } + ], + "warp-routing": { + "enabled": false + } +} +`) + updateWithValidation(t, orchestrator, 10, configJSONV10) + configV10 := orchestrator.config + require.Len(t, configV10.Ingress.Rules, 1) + require.True(t, configV10.Ingress.Rules[0].Matches("blogs.tunnel.io", "/2022/02/10")) + require.Equal(t, ingress.HelloWorldService, configV10.Ingress.Rules[0].Service.String()) + require.False(t, configV10.WarpRoutingEnabled) + + originProxyV10, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + require.IsType(t, &proxy.Proxy{}, originProxyV10) + require.NotEqual(t, originProxyV10, originProxyV2) +} + +// TestConcurrentUpdateAndRead makes sure orchestrator can receive updates and return origin proxy concurrently +func TestConcurrentUpdateAndRead(t *testing.T) { + const ( + concurrentRequests = 200 + hostname = "public.tunnels.org" + expectedHost = "internal.tunnels.svc.cluster.local" + tcpBody = "testProxyTCP" + ) + + httpOrigin := httptest.NewServer(&validateHostHandler{ + expectedHost: expectedHost, + body: t.Name(), + }) + defer httpOrigin.Close() + + tcpOrigin, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer tcpOrigin.Close() + + var ( + configJSONV1 = []byte(fmt.Sprintf(` +{ + "originRequest": { + "connectTimeout": 90000000000, + "noHappyEyeballs": true + }, + "ingress": [ + { + "hostname": "%s", + "service": "%s", + "originRequest": { + "httpHostHeader": "%s", + "connectTimeout": 10000000000 + } + }, + { + "service": "http_status:404" + } + ], + "warp-routing": { + "enabled": true + } +} +`, hostname, httpOrigin.URL, expectedHost)) + configJSONV2 = []byte(` +{ + "ingress": [ + { + "service": "http_status:204" + } + ], + "warp-routing": { + "enabled": false + } +} +`) + + configJSONV3 = []byte(` +{ + "ingress": [ + { + "service": "http_status:418" + } + ], + "warp-routing": { + "enabled": true + } +} +`) + + // appliedV2 makes sure v3 is applied after v2 + appliedV2 = make(chan struct{}) + + initConfig = &Config{ + Ingress: &ingress.Ingress{}, + WarpRoutingEnabled: false, + } + ) + + orchestrator, err := NewOrchestrator(context.Background(), initConfig, testTags, &testLogger) + require.NoError(t, err) + + updateWithValidation(t, orchestrator, 1, configJSONV1) + + var wg sync.WaitGroup + // tcpOrigin will be closed when the test exits. Only the handler routines are included in the wait group + go func() { + serveTCPOrigin(t, tcpOrigin, &wg) + }() + for i := 0; i < concurrentRequests; i++ { + originProxy, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + wg.Add(1) + go func(i int, originProxy connection.OriginProxy) { + defer wg.Done() + resp, err := proxyHTTP(t, originProxy, hostname) + require.NoError(t, err) + + var warpRoutingDisabled bool + // The response can be from initOrigin, http_status:204 or http_status:418 + switch resp.StatusCode { + // v1 proxy, warp enabled + case 200: + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, t.Name(), string(body)) + warpRoutingDisabled = false + // v2 proxy, warp disabled + case 204: + require.Greater(t, i, concurrentRequests/4) + warpRoutingDisabled = true + // v3 proxy, warp enabled + case 418: + require.Greater(t, i, concurrentRequests/2) + warpRoutingDisabled = false + } + + // Once we have originProxy, it won't be changed by configuration updates. + // We can infer the version by the ProxyHTTP response code + pr, pw := io.Pipe() + // concurrentRespWriter makes sure ResponseRecorder is not read/write concurrently, and read waits for the first write + w := newRespReadWriteFlusher() + + // Write TCP message and make sure it's echo back. This has to be done in a go routune since ProxyTCP doesn't + // return until the stream is closed. + if !warpRoutingDisabled { + wg.Add(1) + go func() { + defer wg.Done() + defer pw.Close() + tcpEyeball(t, pw, tcpBody, w) + }() + } + proxyTCP(t, originProxy, tcpOrigin.Addr().String(), w, pr, warpRoutingDisabled) + }(i, originProxy) + + if i == concurrentRequests/4 { + wg.Add(1) + go func() { + defer wg.Done() + updateWithValidation(t, orchestrator, 2, configJSONV2) + close(appliedV2) + }() + } + + if i == concurrentRequests/2 { + wg.Add(1) + go func() { + defer wg.Done() + <-appliedV2 + updateWithValidation(t, orchestrator, 3, configJSONV3) + }() + } + } + + wg.Wait() +} + +func proxyHTTP(t *testing.T, originProxy connection.OriginProxy, hostname string) (*http.Response, error) { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", hostname), nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + respWriter, err := connection.NewHTTP2RespWriter(req, w, connection.TypeHTTP) + require.NoError(t, err) + + err = originProxy.ProxyHTTP(respWriter, req, false) + if err != nil { + return nil, err + } + + return w.Result(), nil +} + +func tcpEyeball(t *testing.T, reqWriter io.WriteCloser, body string, respReadWriter *respReadWriteFlusher) { + writeN, err := reqWriter.Write([]byte(body)) + require.NoError(t, err) + + readBuffer := make([]byte, writeN) + n, err := respReadWriter.Read(readBuffer) + require.NoError(t, err) + require.Equal(t, body, string(readBuffer[:n])) + require.Equal(t, writeN, n) +} + +func proxyTCP(t *testing.T, originProxy connection.OriginProxy, originAddr string, w http.ResponseWriter, reqBody io.ReadCloser, expectErr bool) { + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originAddr), reqBody) + require.NoError(t, err) + + respWriter, err := connection.NewHTTP2RespWriter(req, w, connection.TypeTCP) + require.NoError(t, err) + + tcpReq := &connection.TCPRequest{ + Dest: originAddr, + CFRay: "123", + LBProbe: false, + } + rws := connection.NewHTTPResponseReadWriterAcker(respWriter, req) + if expectErr { + require.Error(t, originProxy.ProxyTCP(context.Background(), rws, tcpReq)) + return + } + + require.NoError(t, originProxy.ProxyTCP(context.Background(), rws, tcpReq)) +} + +func serveTCPOrigin(t *testing.T, tcpOrigin net.Listener, wg *sync.WaitGroup) { + for { + conn, err := tcpOrigin.Accept() + if err != nil { + return + } + wg.Add(1) + go func() { + defer wg.Done() + defer conn.Close() + + echoTCP(t, conn) + }() + } +} + +func echoTCP(t *testing.T, conn net.Conn) { + readBuf := make([]byte, 1000) + readN, err := conn.Read(readBuf) + require.NoError(t, err) + + writeN, err := conn.Write(readBuf[:readN]) + require.NoError(t, err) + require.Equal(t, readN, writeN) +} + +type validateHostHandler struct { + expectedHost string + body string +} + +func (vhh *validateHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Host != vhh.expectedHost { + w.WriteHeader(http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(vhh.body)) +} + +func updateWithValidation(t *testing.T, orchestrator *Orchestrator, version int32, config []byte) { + resp := orchestrator.UpdateConfig(version, config) + require.NoError(t, resp.Err) + require.Equal(t, version, resp.LastAppliedVersion) +} + +// TestClosePreviousProxies makes sure proxies started in the pervious configuration version are shutdown +func TestClosePreviousProxies(t *testing.T) { + var ( + hostname = "hello.tunnel1.org" + configWithHelloWorld = []byte(fmt.Sprintf(` +{ + "ingress": [ + { + "hostname": "%s", + "service": "hello-world" + }, + { + "service": "http_status:404" + } + ], + "warp-routing": { + "enabled": true + } +} +`, hostname)) + + configTeapot = []byte(` +{ + "ingress": [ + { + "service": "http_status:418" + } + ], + "warp-routing": { + "enabled": true + } +} +`) + initConfig = &Config{ + Ingress: &ingress.Ingress{}, + WarpRoutingEnabled: false, + } + ) + + ctx, cancel := context.WithCancel(context.Background()) + orchestrator, err := NewOrchestrator(ctx, initConfig, testTags, &testLogger) + require.NoError(t, err) + + updateWithValidation(t, orchestrator, 1, configWithHelloWorld) + + originProxyV1, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + resp, err := proxyHTTP(t, originProxyV1, hostname) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + updateWithValidation(t, orchestrator, 2, configTeapot) + + originProxyV2, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + resp, err = proxyHTTP(t, originProxyV2, hostname) + require.NoError(t, err) + require.Equal(t, http.StatusTeapot, resp.StatusCode) + + // The hello-world server in config v1 should have been stopped + resp, err = proxyHTTP(t, originProxyV1, hostname) + require.Error(t, err) + require.Nil(t, resp) + + // Apply the config with hello world server again, orchestrator should spin up another hello world server + updateWithValidation(t, orchestrator, 3, configWithHelloWorld) + + originProxyV3, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + require.NotEqual(t, originProxyV1, originProxyV3) + + resp, err = proxyHTTP(t, originProxyV3, hostname) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // cancel the context should terminate the last proxy + cancel() + // Wait for proxies to shutdown + time.Sleep(time.Millisecond * 10) + + resp, err = proxyHTTP(t, originProxyV3, hostname) + require.Error(t, err) + require.Nil(t, resp) +} + +// TestPersistentConnection makes sure updating the ingress doesn't intefere with existing connections +func TestPersistentConnection(t *testing.T) { + const ( + hostname = "http://ws.tunnel.org" + ) + msg := t.Name() + initConfig := &Config{ + Ingress: &ingress.Ingress{}, + WarpRoutingEnabled: false, + } + orchestrator, err := NewOrchestrator(context.Background(), initConfig, testTags, &testLogger) + require.NoError(t, err) + + wsOrigin := httptest.NewServer(http.HandlerFunc(wsEcho)) + defer wsOrigin.Close() + + tcpOrigin, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer tcpOrigin.Close() + + configWithWSAndWarp := []byte(fmt.Sprintf(` +{ + "ingress": [ + { + "service": "%s" + } + ], + "warp-routing": { + "enabled": true + } +} +`, wsOrigin.URL)) + + updateWithValidation(t, orchestrator, 1, configWithWSAndWarp) + + originProxy, err := orchestrator.GetOriginProxy() + require.NoError(t, err) + + wsReqReader, wsReqWriter := io.Pipe() + wsRespReadWriter := newRespReadWriteFlusher() + + tcpReqReader, tcpReqWriter := io.Pipe() + tcpRespReadWriter := newRespReadWriteFlusher() + + var wg sync.WaitGroup + wg.Add(3) + // Start TCP origin + go func() { + defer wg.Done() + conn, err := tcpOrigin.Accept() + require.NoError(t, err) + defer conn.Close() + + // Expect 3 TCP messages + for i := 0; i < 3; i++ { + echoTCP(t, conn) + } + }() + // Simulate cloudflared recieving a TCP connection + go func() { + defer wg.Done() + proxyTCP(t, originProxy, tcpOrigin.Addr().String(), tcpRespReadWriter, tcpReqReader, false) + }() + // Simulate cloudflared recieving a WS connection + go func() { + defer wg.Done() + + req, err := http.NewRequest(http.MethodGet, hostname, wsReqReader) + require.NoError(t, err) + // ProxyHTTP will add Connection, Upgrade and Sec-Websocket-Version headers + req.Header.Add("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") + + respWriter, err := connection.NewHTTP2RespWriter(req, wsRespReadWriter, connection.TypeWebsocket) + require.NoError(t, err) + + err = originProxy.ProxyHTTP(respWriter, req, true) + require.NoError(t, err) + }() + + // Simulate eyeball WS and TCP connections + validateWsEcho(t, msg, wsReqWriter, wsRespReadWriter) + tcpEyeball(t, tcpReqWriter, msg, tcpRespReadWriter) + + configNoWSAndWarp := []byte(` +{ + "ingress": [ + { + "service": "http_status:404" + } + ], + "warp-routing": { + "enabled": false + } +} +`) + + updateWithValidation(t, orchestrator, 2, configNoWSAndWarp) + // Make sure connection is still up + validateWsEcho(t, msg, wsReqWriter, wsRespReadWriter) + tcpEyeball(t, tcpReqWriter, msg, tcpRespReadWriter) + + updateWithValidation(t, orchestrator, 3, configWithWSAndWarp) + // Make sure connection is still up + validateWsEcho(t, msg, wsReqWriter, wsRespReadWriter) + tcpEyeball(t, tcpReqWriter, msg, tcpRespReadWriter) + + wsReqWriter.Close() + tcpReqWriter.Close() + wg.Wait() +} + +func wsEcho(w http.ResponseWriter, r *http.Request) { + upgrader := gows.Upgrader{} + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + for { + mt, message, err := conn.ReadMessage() + if err != nil { + fmt.Println("read message err", err) + break + } + err = conn.WriteMessage(mt, message) + if err != nil { + fmt.Println("write message err", err) + break + } + } +} + +func validateWsEcho(t *testing.T, msg string, reqWriter io.Writer, respReadWriter io.ReadWriter) { + err := wsutil.WriteClientText(reqWriter, []byte(msg)) + require.NoError(t, err) + + receivedMsg, err := wsutil.ReadServerText(respReadWriter) + require.NoError(t, err) + require.Equal(t, msg, string(receivedMsg)) +} + +type respReadWriteFlusher struct { + io.Reader + w io.Writer + headers http.Header + statusCode int + setStatusOnce sync.Once + hasStatus chan struct{} +} + +func newRespReadWriteFlusher() *respReadWriteFlusher { + pr, pw := io.Pipe() + return &respReadWriteFlusher{ + Reader: pr, + w: pw, + headers: make(http.Header), + hasStatus: make(chan struct{}), + } +} + +func (rrw *respReadWriteFlusher) Write(buf []byte) (int, error) { + rrw.WriteHeader(http.StatusOK) + return rrw.w.Write(buf) +} + +func (rrw *respReadWriteFlusher) Flush() {} + +func (rrw *respReadWriteFlusher) Header() http.Header { + return rrw.headers +} + +func (rrw *respReadWriteFlusher) WriteHeader(statusCode int) { + rrw.setStatusOnce.Do(func() { + rrw.statusCode = statusCode + close(rrw.hasStatus) + }) +} diff --git a/proxy/proxy.go b/proxy/proxy.go index b849967d..5ec8743b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -28,7 +28,7 @@ const ( // Proxy represents a means to Proxy between cloudflared and the origin services. type Proxy struct { - ingressRules *ingress.Ingress + ingressRules ingress.Ingress warpRouting *ingress.WarpRoutingService tags []tunnelpogs.Tag log *zerolog.Logger @@ -37,18 +37,23 @@ type Proxy struct { // NewOriginProxy returns a new instance of the Proxy struct. func NewOriginProxy( - ingressRules *ingress.Ingress, - warpRouting *ingress.WarpRoutingService, + ingressRules ingress.Ingress, + warpRoutingEnabled bool, tags []tunnelpogs.Tag, log *zerolog.Logger, ) *Proxy { - return &Proxy{ + proxy := &Proxy{ ingressRules: ingressRules, - warpRouting: warpRouting, tags: tags, log: log, bufferPool: newBufferPool(512 * 1024), } + if warpRoutingEnabled { + proxy.warpRouting = ingress.NewWarpRoutingService() + log.Info().Msgf("Warp-routing is enabled") + } + + return proxy } // ProxyHTTP further depends on ingress rules to establish a connection with the origin service. This may be @@ -139,7 +144,7 @@ func (p *Proxy) ProxyTCP( return nil } -func ruleField(ing *ingress.Ingress, ruleNum int) (ruleID string, srv string) { +func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) { srv = ing.Rules[ruleNum].Service.String() if ing.IsSingleRule() { return "", srv diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 8ccc7624..db8747f7 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -31,8 +31,7 @@ import ( ) var ( - testTags = []tunnelpogs.Tag{tunnelpogs.Tag{Name: "Name", Value: "value"}} - unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil) + testTags = []tunnelpogs.Tag{tunnelpogs.Tag{Name: "Name", Value: "value"}} ) type mockHTTPRespWriter struct { @@ -131,17 +130,14 @@ func TestProxySingleOrigin(t *testing.T) { ingressRule, err := ingress.NewSingleOrigin(cliCtx, allowURLFromArgs) require.NoError(t, err) - var wg sync.WaitGroup - errC := make(chan error) - require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC)) + require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(&ingressRule, unusedWarpRoutingService, testTags, &log) + proxy := NewOriginProxy(ingressRule, false, testTags, &log) t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxySSE", testProxySSE(proxy)) t.Run("testProxySSEAllData", testProxySSEAllData(proxy)) cancel() - wg.Wait() } func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) { @@ -341,11 +337,9 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat log := zerolog.Nop() ctx, cancel := context.WithCancel(context.Background()) - errC := make(chan error) - var wg sync.WaitGroup - require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC)) + require.NoError(t, ingress.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(&ingress, unusedWarpRoutingService, testTags, &log) + proxy := NewOriginProxy(ingress, false, testTags, &log) for _, test := range tests { responseWriter := newMockHTTPRespWriter() @@ -363,7 +357,6 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat } } cancel() - wg.Wait() } type mockAPI struct{} @@ -394,7 +387,7 @@ func TestProxyError(t *testing.T) { log := zerolog.Nop() - proxy := NewOriginProxy(&ing, unusedWarpRoutingService, testTags, &log) + proxy := NewOriginProxy(ing, false, testTags, &log) responseWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) @@ -634,10 +627,9 @@ func TestConnections(t *testing.T) { test.args.originService(t, ln) ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String()) - var wg sync.WaitGroup - errC := make(chan error) - ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) - proxy := NewOriginProxy(&ingressRule, test.args.warpRoutingService, testTags, logger) + ingressRule.StartOrigins(logger, ctx.Done()) + proxy := NewOriginProxy(ingressRule, true, testTags, logger) + proxy.warpRouting = test.args.warpRoutingService dest := ln.Addr().String() req, err := http.NewRequest( diff --git a/supervisor/configmanager.go b/supervisor/configmanager.go deleted file mode 100644 index b1067c83..00000000 --- a/supervisor/configmanager.go +++ /dev/null @@ -1,55 +0,0 @@ -package supervisor - -import ( - "sync" - - "github.com/rs/zerolog" - - "github.com/cloudflare/cloudflared/connection" - "github.com/cloudflare/cloudflared/ingress" - "github.com/cloudflare/cloudflared/proxy" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" -) - -type configManager struct { - currentVersion int32 - // Only used by UpdateConfig - updateLock sync.Mutex - // TODO: TUN-5698: Make proxy atomic.Value - proxy *proxy.Proxy - config *DynamicConfig - tags []tunnelpogs.Tag - log *zerolog.Logger -} - -func newConfigManager(config *DynamicConfig, tags []tunnelpogs.Tag, log *zerolog.Logger) *configManager { - var warpRoutingService *ingress.WarpRoutingService - if config.WarpRoutingEnabled { - warpRoutingService = ingress.NewWarpRoutingService() - log.Info().Msgf("Warp-routing is enabled") - } - - return &configManager{ - // Lowest possible version, any remote configuration will have version higher than this - currentVersion: 0, - proxy: proxy.NewOriginProxy(config.Ingress, warpRoutingService, tags, log), - config: config, - log: log, - } -} - -func (cm *configManager) Update(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { - // TODO: TUN-5698: make ingress configurable - return &tunnelpogs.UpdateConfigurationResponse{ - LastAppliedVersion: cm.currentVersion, - } -} - -func (cm *configManager) GetOriginProxy() connection.OriginProxy { - return cm.proxy -} - -type DynamicConfig struct { - Ingress *ingress.Ingress - WarpRoutingEnabled bool -} diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index a87384a8..f1661bf3 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -13,6 +13,7 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -37,8 +38,8 @@ const ( // reconnects them if they disconnect. type Supervisor struct { cloudflaredUUID uuid.UUID - configManager *configManager config *TunnelConfig + orchestrator *orchestration.Orchestrator edgeIPs *edgediscovery.Edge tunnelErrors chan tunnelError tunnelsConnecting map[int]chan struct{} @@ -65,7 +66,7 @@ type tunnelError struct { err error } -func NewSupervisor(config *TunnelConfig, dynamiConfig *DynamicConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) { +func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrator, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) { cloudflaredUUID, err := uuid.NewRandom() if err != nil { return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err) @@ -89,7 +90,7 @@ func NewSupervisor(config *TunnelConfig, dynamiConfig *DynamicConfig, reconnectC return &Supervisor{ cloudflaredUUID: cloudflaredUUID, config: config, - configManager: newConfigManager(dynamiConfig, config.Tags, config.Log), + orchestrator: orchestrator, edgeIPs: edgeIPs, tunnelErrors: make(chan tunnelError), tunnelsConnecting: map[int]chan struct{}{}, @@ -244,8 +245,8 @@ func (s *Supervisor) startFirstTunnel( err = ServeTunnelLoop( ctx, s.reconnectCredentialManager, - s.configManager, s.config, + s.orchestrator, addr, s.log, firstConnIndex, @@ -279,8 +280,8 @@ func (s *Supervisor) startFirstTunnel( err = ServeTunnelLoop( ctx, s.reconnectCredentialManager, - s.configManager, s.config, + s.orchestrator, addr, s.log, firstConnIndex, @@ -314,8 +315,8 @@ func (s *Supervisor) startTunnel( err = ServeTunnelLoop( ctx, s.reconnectCredentialManager, - s.configManager, s.config, + s.orchestrator, addr, s.log, uint8(index), diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 64047863..bca5e081 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -20,6 +20,7 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/orchestration" quicpogs "github.com/cloudflare/cloudflared/quic" "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" @@ -107,12 +108,12 @@ func (c *TunnelConfig) SupportedFeatures() []string { func StartTunnelDaemon( ctx context.Context, config *TunnelConfig, - dynamiConfig *DynamicConfig, + orchestrator *orchestration.Orchestrator, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal, graceShutdownC <-chan struct{}, ) error { - s, err := NewSupervisor(config, dynamiConfig, reconnectCh, graceShutdownC) + s, err := NewSupervisor(config, orchestrator, reconnectCh, graceShutdownC) if err != nil { return err } @@ -122,8 +123,8 @@ func StartTunnelDaemon( func ServeTunnelLoop( ctx context.Context, credentialManager *reconnectCredentialManager, - configManager *configManager, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, addr *allregions.EdgeAddr, connAwareLogger *ConnAwareLogger, connIndex uint8, @@ -158,8 +159,8 @@ func ServeTunnelLoop( ctx, connLog, credentialManager, - configManager, config, + orchestrator, addr, connIndex, connectedFuse, @@ -257,8 +258,8 @@ func ServeTunnel( ctx context.Context, connLog *ConnAwareLogger, credentialManager *reconnectCredentialManager, - configManager *configManager, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, addr *allregions.EdgeAddr, connIndex uint8, fuse *h2mux.BooleanFuse, @@ -286,8 +287,8 @@ func ServeTunnel( ctx, connLog, credentialManager, - configManager, config, + orchestrator, addr, connIndex, fuse, @@ -335,8 +336,8 @@ func serveTunnel( ctx context.Context, connLog *ConnAwareLogger, credentialManager *reconnectCredentialManager, - configManager *configManager, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, addr *allregions.EdgeAddr, connIndex uint8, fuse *h2mux.BooleanFuse, @@ -365,8 +366,8 @@ func serveTunnel( connOptions := config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries())) return ServeQUIC(ctx, addr.UDP, - configManager, config, + orchestrator, connLog, connOptions, controlStream, @@ -385,8 +386,8 @@ func serveTunnel( if err := ServeHTTP2( ctx, connLog, - configManager, config, + orchestrator, edgeConn, connOptions, controlStream, @@ -408,8 +409,8 @@ func serveTunnel( ctx, connLog, credentialManager, - configManager, config, + orchestrator, edgeConn, connIndex, connectedFuse, @@ -435,8 +436,8 @@ func ServeH2mux( ctx context.Context, connLog *ConnAwareLogger, credentialManager *reconnectCredentialManager, - configManager *configManager, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, edgeConn net.Conn, connIndex uint8, connectedFuse *connectedFuse, @@ -447,7 +448,7 @@ func ServeH2mux( connLog.Logger().Debug().Msgf("Connecting via h2mux") // Returns error from parsing the origin URL or handshake errors handler, err, recoverable := connection.NewH2muxConnection( - configManager, + orchestrator, config.GracePeriod, config.MuxerConfig, edgeConn, @@ -483,8 +484,8 @@ func ServeH2mux( func ServeHTTP2( ctx context.Context, connLog *ConnAwareLogger, - configManager *configManager, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, tlsServerConn net.Conn, connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler connection.ControlStreamHandler, @@ -495,7 +496,7 @@ func ServeHTTP2( connLog.Logger().Debug().Msgf("Connecting via http2") h2conn := connection.NewHTTP2Connection( tlsServerConn, - configManager, + orchestrator, connOptions, config.Observer, connIndex, @@ -523,8 +524,8 @@ func ServeHTTP2( func ServeQUIC( ctx context.Context, edgeAddr *net.UDPAddr, - configManager *configManager, config *TunnelConfig, + orchestrator *orchestration.Orchestrator, connLogger *ConnAwareLogger, connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler connection.ControlStreamHandler, @@ -548,7 +549,7 @@ func ServeQUIC( quicConfig, edgeAddr, tlsConfig, - configManager, + orchestrator, connOptions, controlStreamHandler, connLogger.Logger())