diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 6bce68f6..36257cc7 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -358,7 +358,7 @@ func prepareTunnelConfig( } orchestratorConfig := &orchestration.Config{ Ingress: &ingressRules, - WarpRoutingEnabled: warpRoutingEnabled, + WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting), ConfigurationFlags: parseConfigFlags(c), } return tunnelConfig, orchestratorConfig, nil diff --git a/config/configuration.go b/config/configuration.go index 49395404..fdbf0aef 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -244,7 +244,9 @@ type Configuration struct { } type WarpRoutingConfig struct { - Enabled bool `yaml:"enabled" json:"enabled"` + Enabled bool `yaml:"enabled" json:"enabled"` + ConnectTimeout *CustomDuration `yaml:"connectTimeout" json:"connectTimeout,omitempty"` + TCPKeepAlive *CustomDuration `yaml:"tcpKeepAlive" json:"tcpKeepAlive,omitempty"` } type configFileSettings struct { diff --git a/config/configuration_test.go b/config/configuration_test.go index dff785c8..96adb74b 100644 --- a/config/configuration_test.go +++ b/config/configuration_test.go @@ -23,7 +23,9 @@ func TestConfigFileSettings(t *testing.T) { Service: "https://localhost:8001", } warpRouting = WarpRoutingConfig{ - Enabled: true, + Enabled: true, + ConnectTimeout: &CustomDuration{Duration: 2 * time.Second}, + TCPKeepAlive: &CustomDuration{Duration: 10 * time.Second}, } ) rawYAML := ` @@ -48,6 +50,9 @@ ingress: service: https://localhost:8001 warp-routing: enabled: true + connectTimeout: 2s + tcpKeepAlive: 10s + retries: 5 grace-period: 30s percentage: 3.14 diff --git a/ingress/config.go b/ingress/config.go index bc2a9f6b..350edb0f 100644 --- a/ingress/config.go +++ b/ingress/config.go @@ -12,10 +12,11 @@ import ( ) var ( - defaultConnectTimeout = config.CustomDuration{Duration: 30 * time.Second} - defaultTLSTimeout = config.CustomDuration{Duration: 10 * time.Second} - defaultTCPKeepAlive = config.CustomDuration{Duration: 30 * time.Second} - defaultKeepAliveTimeout = config.CustomDuration{Duration: 90 * time.Second} + defaultHTTPConnectTimeout = config.CustomDuration{Duration: 30 * time.Second} + defaultWarpRoutingConnectTimeout = config.CustomDuration{Duration: 5 * time.Second} + defaultTLSTimeout = config.CustomDuration{Duration: 10 * time.Second} + defaultTCPKeepAlive = config.CustomDuration{Duration: 30 * time.Second} + defaultKeepAliveTimeout = config.CustomDuration{Duration: 90 * time.Second} ) const ( @@ -41,10 +42,44 @@ const ( socksProxy = "socks" ) +type WarpRoutingConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + ConnectTimeout config.CustomDuration `yaml:"connectTimeout" json:"connectTimeout,omitempty"` + TCPKeepAlive config.CustomDuration `yaml:"tcpKeepAlive" json:"tcpKeepAlive,omitempty"` +} + +func NewWarpRoutingConfig(raw *config.WarpRoutingConfig) WarpRoutingConfig { + cfg := WarpRoutingConfig{ + Enabled: raw.Enabled, + ConnectTimeout: defaultWarpRoutingConnectTimeout, + TCPKeepAlive: defaultTCPKeepAlive, + } + if raw.ConnectTimeout != nil { + cfg.ConnectTimeout = *raw.ConnectTimeout + } + if raw.TCPKeepAlive != nil { + cfg.TCPKeepAlive = *raw.TCPKeepAlive + } + return cfg +} + +func (c *WarpRoutingConfig) RawConfig() config.WarpRoutingConfig { + raw := config.WarpRoutingConfig{ + Enabled: c.Enabled, + } + if c.ConnectTimeout.Duration != defaultWarpRoutingConnectTimeout.Duration { + raw.ConnectTimeout = &c.ConnectTimeout + } + if c.TCPKeepAlive.Duration != defaultTCPKeepAlive.Duration { + raw.TCPKeepAlive = &c.TCPKeepAlive + } + return raw +} + // RemoteConfig models ingress settings that can be managed remotely, for example through the dashboard. type RemoteConfig struct { Ingress Ingress - WarpRouting config.WarpRoutingConfig + WarpRouting WarpRoutingConfig } type RemoteConfigJSON struct { @@ -72,18 +107,18 @@ func (rc *RemoteConfig) UnmarshalJSON(b []byte) error { } rc.Ingress = ingress - rc.WarpRouting = rawConfig.WarpRouting + rc.WarpRouting = NewWarpRoutingConfig(&rawConfig.WarpRouting) return nil } func originRequestFromSingeRule(c *cli.Context) OriginRequestConfig { - var connectTimeout config.CustomDuration = defaultConnectTimeout - var tlsTimeout config.CustomDuration = defaultTLSTimeout - var tcpKeepAlive config.CustomDuration = defaultTCPKeepAlive + var connectTimeout = defaultHTTPConnectTimeout + var tlsTimeout = defaultTLSTimeout + var tcpKeepAlive = defaultTCPKeepAlive var noHappyEyeballs bool - var keepAliveConnections int = defaultKeepAliveConnections - var keepAliveTimeout config.CustomDuration = defaultKeepAliveTimeout + var keepAliveConnections = defaultKeepAliveConnections + var keepAliveTimeout = defaultKeepAliveTimeout var httpHostHeader string var originServerName string var caPool string @@ -160,7 +195,7 @@ func originRequestFromSingeRule(c *cli.Context) OriginRequestConfig { func originRequestFromConfig(c config.OriginRequestConfig) OriginRequestConfig { out := OriginRequestConfig{ - ConnectTimeout: defaultConnectTimeout, + ConnectTimeout: defaultHTTPConnectTimeout, TLSTimeout: defaultTLSTimeout, TCPKeepAlive: defaultTCPKeepAlive, KeepAliveConnections: defaultKeepAliveConnections, @@ -404,7 +439,7 @@ func ConvertToRawOriginConfig(c OriginRequestConfig) config.OriginRequestConfig var keepAliveTimeout *config.CustomDuration var proxyAddress *string - if c.ConnectTimeout != defaultConnectTimeout { + if c.ConnectTimeout != defaultHTTPConnectTimeout { connectTimeout = &c.ConnectTimeout } if c.TLSTimeout != defaultTLSTimeout { diff --git a/ingress/config_test.go b/ingress/config_test.go index 93c971fc..2b9f8a3b 100644 --- a/ingress/config_test.go +++ b/ingress/config_test.go @@ -274,7 +274,7 @@ func TestOriginRequestConfigDefaults(t *testing.T) { // Rule 0 didn't override anything, so it inherits the cloudflared defaults actual0 := ing.Rules[0].Config expected0 := OriginRequestConfig{ - ConnectTimeout: defaultConnectTimeout, + ConnectTimeout: defaultHTTPConnectTimeout, TLSTimeout: defaultTLSTimeout, TCPKeepAlive: defaultTCPKeepAlive, KeepAliveConnections: defaultKeepAliveConnections, @@ -404,7 +404,7 @@ func TestDefaultConfigFromCLI(t *testing.T) { c := cli.NewContext(nil, set, nil) expected := OriginRequestConfig{ - ConnectTimeout: defaultConnectTimeout, + ConnectTimeout: defaultHTTPConnectTimeout, TLSTimeout: defaultTLSTimeout, TCPKeepAlive: defaultTCPKeepAlive, KeepAliveConnections: defaultKeepAliveConnections, diff --git a/ingress/ingress.go b/ingress/ingress.go index 8d2ec5ba..05a90a8b 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -97,8 +97,16 @@ type WarpRoutingService struct { Proxy StreamBasedOriginProxy } -func NewWarpRoutingService() *WarpRoutingService { - return &WarpRoutingService{Proxy: &rawTCPService{name: ServiceWarpRouting}} +func NewWarpRoutingService(config WarpRoutingConfig) *WarpRoutingService { + svc := &rawTCPService{ + name: ServiceWarpRouting, + dialer: net.Dialer{ + Timeout: config.ConnectTimeout.Duration, + KeepAlive: config.TCPKeepAlive.Duration, + }, + } + + return &WarpRoutingService{Proxy: svc} } // Get a single origin service from the CLI/config. diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index 90371dfa..83bcd4fc 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -1,26 +1,20 @@ package ingress import ( + "context" "fmt" - "net" "net/http" - - "github.com/pkg/errors" -) - -var ( - errUnsupportedConnectionType = errors.New("internal error: unsupported connection type") ) // HTTPOriginProxy can be implemented by origin services that want to proxy http requests. type HTTPOriginProxy interface { - // RoundTrip is how cloudflared proxies eyeball requests to the actual origin services + // RoundTripper is how cloudflared proxies eyeball requests to the actual origin services http.RoundTripper } // StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP. type StreamBasedOriginProxy interface { - EstablishConnection(dest string) (OriginConnection, error) + EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) } func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { @@ -59,8 +53,8 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { return resp, nil } -func (o *rawTCPService) EstablishConnection(dest string) (OriginConnection, error) { - conn, err := net.Dial("tcp", dest) +func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) { + conn, err := o.dialer.DialContext(ctx, "tcp", dest) if err != nil { return nil, err } @@ -71,13 +65,13 @@ func (o *rawTCPService) EstablishConnection(dest string) (OriginConnection, erro return originConn, nil } -func (o *tcpOverWSService) EstablishConnection(dest string) (OriginConnection, error) { +func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) { var err error if !o.isBastion { dest = o.dest } - conn, err := net.Dial("tcp", dest) + conn, err := o.dialer.DialContext(ctx, "tcp", dest) if err != nil { return nil, err } @@ -89,6 +83,6 @@ func (o *tcpOverWSService) EstablishConnection(dest string) (OriginConnection, e } -func (o *socksProxyOverWSService) EstablishConnection(dest string) (OriginConnection, error) { +func (o *socksProxyOverWSService) EstablishConnection(_ctx context.Context, _dest string) (OriginConnection, error) { return o.conn, nil } diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index cc244aee..427b2a65 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -36,7 +36,7 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) { require.NoError(t, err) // Origin not listening for new connection, should return an error - _, err = rawTCPService.EstablishConnection(req.URL.String()) + _, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String()) require.Error(t, err) } @@ -87,7 +87,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { t.Run(test.testCase, func(t *testing.T) { if test.expectErr { bastionHost, _ := carrier.ResolveBastionDest(test.req) - _, err := test.service.EstablishConnection(bastionHost) + _, err := test.service.EstablishConnection(context.Background(), bastionHost) assert.Error(t, err) } }) @@ -99,7 +99,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} { // Origin not listening for new connection, should return an error bastionHost, _ := carrier.ResolveBastionDest(bastionReq) - _, err := service.EstablishConnection(bastionHost) + _, err := service.EstablishConnection(context.Background(), bastionHost) assert.Error(t, err) } } diff --git a/ingress/origin_service.go b/ingress/origin_service.go index 38ceeda5..6f348bee 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -91,7 +91,8 @@ func (o httpService) MarshalJSON() ([]byte, error) { // rawTCPService dials TCP to the destination specified by the client // It's used by warp routing type rawTCPService struct { - name string + name string + dialer net.Dialer } func (o *rawTCPService) String() string { @@ -113,6 +114,7 @@ type tcpOverWSService struct { dest string isBastion bool streamHandler streamHandlerFunc + dialer net.Dialer } type socksProxyOverWSService struct { @@ -176,6 +178,8 @@ func (o *tcpOverWSService) start(log *zerolog.Logger, _ <-chan struct{}, cfg Ori } else { o.streamHandler = DefaultStreamHandler } + o.dialer.Timeout = cfg.ConnectTimeout.Duration + o.dialer.KeepAlive = cfg.TCPKeepAlive.Duration return nil } diff --git a/orchestration/config.go b/orchestration/config.go index cfdbd939..26904b57 100644 --- a/orchestration/config.go +++ b/orchestration/config.go @@ -19,8 +19,8 @@ type newLocalConfig struct { // Config is the original config as read and parsed by cloudflared. type Config struct { - Ingress *ingress.Ingress - WarpRoutingEnabled bool + Ingress *ingress.Ingress + WarpRouting ingress.WarpRoutingConfig // Extra settings used to configure this instance but that are not eligible for remotely management // ie. (--protocol, --loglevel, ...) @@ -37,7 +37,7 @@ func (rc *newLocalConfig) MarshalJSON() ([]byte, error) { // UI doesn't support top level configs, so we reconcile to individual ingress configs. GlobalOriginRequest: nil, IngressRules: convertToUnvalidatedIngressRules(rc.RemoteConfig.Ingress), - WarpRouting: rc.RemoteConfig.WarpRouting, + WarpRouting: rc.RemoteConfig.WarpRouting.RawConfig(), }, } diff --git a/orchestration/config_test.go b/orchestration/config_test.go index 53d9a23f..04c15359 100644 --- a/orchestration/config_test.go +++ b/orchestration/config_test.go @@ -3,16 +3,17 @@ package orchestration import ( "encoding/json" "testing" + "time" "github.com/stretchr/testify/require" + "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/ingress" ) // TestNewLocalConfig_MarshalJSON tests that we are able to converte a compiled and validated config back // into an "unvalidated" format which is compatible with Remote Managed configurations. func TestNewLocalConfig_MarshalJSON(t *testing.T) { - rawConfig := []byte(` { "originRequest": { @@ -57,7 +58,11 @@ func TestNewLocalConfig_MarshalJSON(t *testing.T) { ] } } - ] + ], + "warp-routing": { + "enabled": true, + "connectTimeout": 1 + } } `) @@ -73,10 +78,18 @@ func TestNewLocalConfig_MarshalJSON(t *testing.T) { jsonSerde, err := json.Marshal(c) require.NoError(t, err) - var config ingress.RemoteConfig - err = json.Unmarshal(jsonSerde, &config) + var remoteConfig ingress.RemoteConfig + err = json.Unmarshal(jsonSerde, &remoteConfig) require.NoError(t, err) - require.Equal(t, config.WarpRouting.Enabled, false) - require.Equal(t, config.Ingress.Rules, expectedConfig.Ingress.Rules) + require.Equal(t, remoteConfig.WarpRouting, ingress.WarpRoutingConfig{ + Enabled: true, + ConnectTimeout: config.CustomDuration{ + Duration: time.Second, + }, + TCPKeepAlive: config.CustomDuration{ + Duration: 30 * time.Second, // default value is 30 seconds + }, + }) + require.Equal(t, remoteConfig.Ingress.Rules, expectedConfig.Ingress.Rules) } diff --git a/orchestration/orchestrator.go b/orchestration/orchestrator.go index 60cd1f10..0659411a 100644 --- a/orchestration/orchestrator.go +++ b/orchestration/orchestrator.go @@ -47,7 +47,7 @@ func NewOrchestrator(ctx context.Context, config *Config, tags []tunnelpogs.Tag, log: log, shutdownC: ctx.Done(), } - if err := o.updateIngress(*config.Ingress, config.WarpRoutingEnabled); err != nil { + if err := o.updateIngress(*config.Ingress, config.WarpRouting); err != nil { return nil, err } go o.waitToCloseLastProxy() @@ -80,7 +80,7 @@ func (o *Orchestrator) UpdateConfig(version int32, config []byte) *tunnelpogs.Up } } - if err := o.updateIngress(newConf.Ingress, newConf.WarpRouting.Enabled); err != nil { + if err := o.updateIngress(newConf.Ingress, newConf.WarpRouting); err != nil { o.log.Err(err). Int32("version", version). Str("config", string(config)). @@ -103,7 +103,7 @@ func (o *Orchestrator) UpdateConfig(version int32, config []byte) *tunnelpogs.Up } // The caller is responsible to make sure there is no concurrent access -func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRoutingEnabled bool) error { +func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting ingress.WarpRoutingConfig) error { select { case <-o.shutdownC: return fmt.Errorf("cloudflared already shutdown") @@ -118,10 +118,10 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRoutingEn if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil { return errors.Wrap(err, "failed to start origin") } - newProxy := proxy.NewOriginProxy(ingressRules, warpRoutingEnabled, o.tags, o.log) + newProxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.log) o.proxy.Store(newProxy) o.config.Ingress = &ingressRules - o.config.WarpRoutingEnabled = warpRoutingEnabled + o.config.WarpRouting = warpRouting // If proxyShutdownC is nil, there is no previous running proxy if o.proxyShutdownC != nil { @@ -139,7 +139,7 @@ func (o *Orchestrator) GetConfigJSON() ([]byte, error) { c := &newLocalConfig{ RemoteConfig: ingress.RemoteConfig{ Ingress: *o.config.Ingress, - WarpRouting: config.WarpRoutingConfig{Enabled: o.config.WarpRoutingEnabled}, + WarpRouting: o.config.WarpRouting, }, ConfigurationFlags: o.config.ConfigurationFlags, } @@ -166,7 +166,7 @@ func (o *Orchestrator) GetVersionedConfigJSON() ([]byte, error) { OriginRequest ingress.OriginRequestConfig `json:"originRequest"` }{ Ingress: o.config.Ingress.Rules, - WarpRouting: config.WarpRoutingConfig{Enabled: o.config.WarpRoutingEnabled}, + WarpRouting: o.config.WarpRouting.RawConfig(), OriginRequest: o.config.Ingress.Defaults, }, } diff --git a/orchestration/orchestrator_test.go b/orchestration/orchestrator_test.go index 85f0f83a..dfac28a0 100644 --- a/orchestration/orchestrator_test.go +++ b/orchestration/orchestrator_test.go @@ -48,8 +48,7 @@ var ( // - receiving an old version is noop func TestUpdateConfiguration(t *testing.T) { initConfig := &Config{ - Ingress: &ingress.Ingress{}, - WarpRoutingEnabled: false, + Ingress: &ingress.Ingress{}, } orchestrator, err := NewOrchestrator(context.Background(), initConfig, testTags, &testLogger) require.NoError(t, err) @@ -87,7 +86,8 @@ func TestUpdateConfiguration(t *testing.T) { } ], "warp-routing": { - "enabled": true + "enabled": true, + "connectTimeout": 10 } } `) @@ -121,7 +121,8 @@ func TestUpdateConfiguration(t *testing.T) { require.Equal(t, config.CustomDuration{Duration: time.Second * 90}, configV2.Ingress.Rules[2].Config.ConnectTimeout) require.Equal(t, false, configV2.Ingress.Rules[2].Config.NoTLSVerify) require.Equal(t, true, configV2.Ingress.Rules[2].Config.NoHappyEyeballs) - require.True(t, configV2.WarpRoutingEnabled) + require.True(t, configV2.WarpRouting.Enabled) + require.Equal(t, configV2.WarpRouting.ConnectTimeout.Duration, 10*time.Second) originProxyV2, err := orchestrator.GetOriginProxy() require.NoError(t, err) @@ -164,7 +165,7 @@ func TestUpdateConfiguration(t *testing.T) { require.Len(t, configV10.Ingress.Rules, 1) require.True(t, configV10.Ingress.Rules[0].Matches("blogs.tunnel.io", "/2022/02/10")) require.Equal(t, ingress.HelloWorldService, configV10.Ingress.Rules[0].Service.String()) - require.False(t, configV10.WarpRoutingEnabled) + require.False(t, configV10.WarpRouting.Enabled) originProxyV10, err := orchestrator.GetOriginProxy() require.NoError(t, err) @@ -246,8 +247,7 @@ func TestConcurrentUpdateAndRead(t *testing.T) { appliedV2 = make(chan struct{}) initConfig = &Config{ - Ingress: &ingress.Ingress{}, - WarpRoutingEnabled: false, + Ingress: &ingress.Ingress{}, } ) @@ -476,8 +476,7 @@ func TestClosePreviousProxies(t *testing.T) { } `) initConfig = &Config{ - Ingress: &ingress.Ingress{}, - WarpRoutingEnabled: false, + Ingress: &ingress.Ingress{}, } ) @@ -534,8 +533,7 @@ func TestPersistentConnection(t *testing.T) { ) msg := t.Name() initConfig := &Config{ - Ingress: &ingress.Ingress{}, - WarpRoutingEnabled: false, + Ingress: &ingress.Ingress{}, } orchestrator, err := NewOrchestrator(context.Background(), initConfig, testTags, &testLogger) require.NoError(t, err) @@ -645,8 +643,7 @@ func TestPersistentConnection(t *testing.T) { func TestSerializeLocalConfig(t *testing.T) { c := &newLocalConfig{ RemoteConfig: ingress.RemoteConfig{ - Ingress: ingress.Ingress{}, - WarpRouting: config.WarpRoutingConfig{}, + Ingress: ingress.Ingress{}, }, ConfigurationFlags: map[string]string{"a": "b"}, } diff --git a/proxy/proxy.go b/proxy/proxy.go index 5db9ab4c..e31dc75b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -42,7 +42,7 @@ type Proxy struct { // NewOriginProxy returns a new instance of the Proxy struct. func NewOriginProxy( ingressRules ingress.Ingress, - warpRoutingEnabled bool, + warpRouting ingress.WarpRoutingConfig, tags []tunnelpogs.Tag, log *zerolog.Logger, ) *Proxy { @@ -51,8 +51,8 @@ func NewOriginProxy( tags: tags, log: log, } - if warpRoutingEnabled { - proxy.warpRouting = ingress.NewWarpRoutingService() + if warpRouting.Enabled { + proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting) log.Info().Msgf("Warp-routing is enabled") } @@ -108,7 +108,7 @@ func (p *Proxy) ProxyHTTP( } rws := connection.NewHTTPResponseReadWriterAcker(w, req) - if err := p.proxyStream(req.Context(), rws, dest, originProxy, logFields); err != nil { + if err := p.proxyStream(req.Context(), rws, dest, originProxy); err != nil { rule, srv := ruleField(p.ingressRules, ruleNum) p.logRequestError(err, cfRay, "", rule, srv) return err @@ -137,15 +137,9 @@ func (p *Proxy) ProxyTCP( serveCtx, cancel := context.WithCancel(ctx) defer cancel() - logFields := logFields{ - cfRay: req.CFRay, - lbProbe: req.LBProbe, - rule: ingress.ServiceWarpRouting, - flowID: req.FlowID, - } - p.log.Debug().Str(LogFieldFlowID, req.FlowID).Msg("tcp proxy stream started") - if err := p.proxyStream(serveCtx, rwa, req.Dest, p.warpRouting.Proxy, logFields); err != nil { + + if err := p.proxyStream(serveCtx, rwa, req.Dest, p.warpRouting.Proxy); err != nil { p.logRequestError(err, req.CFRay, req.FlowID, "", ingress.ServiceWarpRouting) return err } @@ -255,9 +249,8 @@ func (p *Proxy) proxyStream( rwa connection.ReadWriteAcker, dest string, connectionProxy ingress.StreamBasedOriginProxy, - fields logFields, ) error { - originConn, err := connectionProxy.EstablishConnection(dest) + originConn, err := connectionProxy.EstablishConnection(ctx, dest) if err != nil { return err } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 9f0461ce..0e0019f1 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -32,7 +32,12 @@ import ( ) var ( - testTags = []tunnelpogs.Tag{tunnelpogs.Tag{Name: "Name", Value: "value"}} + testTags = []tunnelpogs.Tag{tunnelpogs.Tag{Name: "Name", Value: "value"}} + noWarpRouting = ingress.WarpRoutingConfig{} + testWarpRouting = ingress.WarpRoutingConfig{ + Enabled: true, + ConnectTimeout: config.CustomDuration{Duration: time.Second}, + } ) type mockHTTPRespWriter struct { @@ -138,7 +143,7 @@ func TestProxySingleOrigin(t *testing.T) { require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(ingressRule, false, testTags, &log) + proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, &log) t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxySSE", testProxySSE(proxy)) @@ -345,7 +350,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat ctx, cancel := context.WithCancel(context.Background()) require.NoError(t, ingress.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(ingress, false, testTags, &log) + proxy := NewOriginProxy(ingress, noWarpRouting, testTags, &log) for _, test := range tests { responseWriter := newMockHTTPRespWriter() @@ -393,7 +398,7 @@ func TestProxyError(t *testing.T) { log := zerolog.Nop() - proxy := NewOriginProxy(ing, false, testTags, &log) + proxy := NewOriginProxy(ing, noWarpRouting, testTags, &log) responseWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) @@ -509,7 +514,7 @@ func TestConnections(t *testing.T) { originService: runEchoTCPService, eyeballResponseWriter: newTCPRespWriter(replayer), eyeballRequestBody: newTCPRequestBody([]byte("test2")), - warpRoutingService: ingress.NewWarpRoutingService(), + warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting), connectionType: connection.TypeTCP, requestHeaders: map[string][]string{ "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, @@ -526,7 +531,7 @@ func TestConnections(t *testing.T) { originService: runEchoWSService, // eyeballResponseWriter gets set after roundtrip dial. eyeballRequestBody: newPipedWSRequestBody([]byte("test3")), - warpRoutingService: ingress.NewWarpRoutingService(), + warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting), requestHeaders: map[string][]string{ "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, }, @@ -652,7 +657,7 @@ func TestConnections(t *testing.T) { ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String()) ingressRule.StartOrigins(logger, ctx.Done()) - proxy := NewOriginProxy(ingressRule, true, testTags, logger) + proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, logger) proxy.warpRouting = test.args.warpRoutingService dest := ln.Addr().String()