From d4f86ac26d9fba19b2db0a1ed6d3abd7cc419c5a Mon Sep 17 00:00:00 2001 From: Shayon Mukherjee Date: Fri, 10 May 2024 16:16:18 -0400 Subject: [PATCH] Cleanup --- ingress/origin_proxy_test.go | 46 ++++++++++++++++++++---------------- ingress/rule.go | 1 - proxy/proxy.go | 6 +++-- 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index 9387cd2d..cc3f26ef 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "github.com/cloudflare/cloudflared/carrier" + "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/websocket" ) @@ -58,14 +59,13 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { bastionReq := baseReq.Clone(context.Background()) carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String()) - u, err := url.Parse("https://place-holder1") - require.NoError(t, err) tests := []struct { - testCase string - service *tcpOverWSService - req *http.Request - expectErr bool + testCase string + service *tcpOverWSService + req *http.Request + bastionMode bool + expectErr bool }{ { testCase: "specific TCP service", @@ -73,33 +73,37 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { req: baseReq, }, { - testCase: "bastion service", - service: newBastionService(), - req: bastionReq, + testCase: "bastion service", + service: newBastionService(), + bastionMode: true, + req: bastionReq, }, { - testCase: "invalid bastion request", - service: newBastionService(), - req: baseReq, - expectErr: true, + testCase: "invalid bastion request", + service: newBastionService(), + bastionMode: true, + req: baseReq, + expectErr: true, }, { - testCase: "bastion service", - service: newBastionServiceWithDest(u), - req: bastionReq, + testCase: "bastion service", + service: newBastionServiceWithDest(MustParseURL(t, "https://place-holder1")), + req: bastionReq, + bastionMode: true, }, { - testCase: "bastion service", - service: newBastionServiceWithDest(u), - req: bastionReq, - expectErr: true, + testCase: "bastion service", + service: newBastionServiceWithDest(MustParseURL(t, "https://place-holder1")), + req: baseReq, + bastionMode: true, + expectErr: true, }, } for _, test := range tests { t.Run(test.testCase, func(t *testing.T) { if test.expectErr { - bastionHost, _ := carrier.ResolveBastionDest(test.req, false, "bastion") + bastionHost, _ := carrier.ResolveBastionDest(test.req, test.bastionMode, config.BastionFlag) _, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger) assert.Error(t, err) } diff --git a/ingress/rule.go b/ingress/rule.go index 79cdf1d5..43c7ad5e 100644 --- a/ingress/rule.go +++ b/ingress/rule.go @@ -59,7 +59,6 @@ func (r *Rule) Matches(hostname, path string) bool { } else { hostMatch = matchHost(r.Hostname, hostname) } - punycodeHostMatch := false if r.punycodeHostname != "" { punycodeHostMatch = matchHost(r.punycodeHostname, hostname) diff --git a/proxy/proxy.go b/proxy/proxy.go index c0cf7052..8edd3a9a 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -99,10 +99,11 @@ func (p *Proxy) ProxyHTTP( } return err } - // Handling for StreamBasedOriginProxy or BastionMode + + // Check if config is for Bastion Mode and service is a stream based origin proxy, if so stream service in bastion mode if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); ok || rule.Config.BastionMode { if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); !ok && rule.Config.BastionMode { - return fmt.Errorf("Unrecognized service: %s", rule.Service) + return fmt.Errorf("Unsupported service to stream to in bastion mode: %s", rule.Service) } dest, err := getDestFromRule(rule, req) @@ -116,6 +117,7 @@ func (p *Proxy) ProxyHTTP( } rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req) logger := logger.With().Str(logFieldDestAddr, dest).Logger() + // We know that Bastion mode is supported by StreamBasedOriginProxy, hence use the same if err := p.proxyStream(tr.ToTracedContext(), rws, dest, rule.Service.(ingress.StreamBasedOriginProxy), &logger); err != nil { logRequestError(&logger, err) return err