This commit is contained in:
Shayon Mukherjee 2024-05-10 16:16:18 -04:00
parent df3ef06169
commit d4f86ac26d
3 changed files with 29 additions and 24 deletions

View File

@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
) )
@ -58,13 +59,12 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
bastionReq := baseReq.Clone(context.Background()) bastionReq := baseReq.Clone(context.Background())
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String()) carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
u, err := url.Parse("https://place-holder1")
require.NoError(t, err)
tests := []struct { tests := []struct {
testCase string testCase string
service *tcpOverWSService service *tcpOverWSService
req *http.Request req *http.Request
bastionMode bool
expectErr bool expectErr bool
}{ }{
{ {
@ -75,23 +75,27 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
{ {
testCase: "bastion service", testCase: "bastion service",
service: newBastionService(), service: newBastionService(),
bastionMode: true,
req: bastionReq, req: bastionReq,
}, },
{ {
testCase: "invalid bastion request", testCase: "invalid bastion request",
service: newBastionService(), service: newBastionService(),
bastionMode: true,
req: baseReq, req: baseReq,
expectErr: true, expectErr: true,
}, },
{ {
testCase: "bastion service", testCase: "bastion service",
service: newBastionServiceWithDest(u), service: newBastionServiceWithDest(MustParseURL(t, "https://place-holder1")),
req: bastionReq, req: bastionReq,
bastionMode: true,
}, },
{ {
testCase: "bastion service", testCase: "bastion service",
service: newBastionServiceWithDest(u), service: newBastionServiceWithDest(MustParseURL(t, "https://place-holder1")),
req: bastionReq, req: baseReq,
bastionMode: true,
expectErr: true, expectErr: true,
}, },
} }
@ -99,7 +103,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.testCase, func(t *testing.T) { t.Run(test.testCase, func(t *testing.T) {
if test.expectErr { if test.expectErr {
bastionHost, _ := carrier.ResolveBastionDest(test.req, false, "bastion") bastionHost, _ := carrier.ResolveBastionDest(test.req, test.bastionMode, config.BastionFlag)
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger) _, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
assert.Error(t, err) assert.Error(t, err)
} }

View File

@ -59,7 +59,6 @@ func (r *Rule) Matches(hostname, path string) bool {
} else { } else {
hostMatch = matchHost(r.Hostname, hostname) hostMatch = matchHost(r.Hostname, hostname)
} }
punycodeHostMatch := false punycodeHostMatch := false
if r.punycodeHostname != "" { if r.punycodeHostname != "" {
punycodeHostMatch = matchHost(r.punycodeHostname, hostname) punycodeHostMatch = matchHost(r.punycodeHostname, hostname)

View File

@ -99,10 +99,11 @@ func (p *Proxy) ProxyHTTP(
} }
return err return err
} }
// Handling for StreamBasedOriginProxy or BastionMode
// Check if config is for Bastion Mode and service is a stream based origin proxy, if so stream service in bastion mode
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); ok || rule.Config.BastionMode { if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); ok || rule.Config.BastionMode {
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); !ok && rule.Config.BastionMode { if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); !ok && rule.Config.BastionMode {
return fmt.Errorf("Unrecognized service: %s", rule.Service) return fmt.Errorf("Unsupported service to stream to in bastion mode: %s", rule.Service)
} }
dest, err := getDestFromRule(rule, req) dest, err := getDestFromRule(rule, req)
@ -116,6 +117,7 @@ func (p *Proxy) ProxyHTTP(
} }
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req) rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
logger := logger.With().Str(logFieldDestAddr, dest).Logger() logger := logger.With().Str(logFieldDestAddr, dest).Logger()
// We know that Bastion mode is supported by StreamBasedOriginProxy, hence use the same
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, rule.Service.(ingress.StreamBasedOriginProxy), &logger); err != nil { if err := p.proxyStream(tr.ToTracedContext(), rws, dest, rule.Service.(ingress.StreamBasedOriginProxy), &logger); err != nil {
logRequestError(&logger, err) logRequestError(&logger, err)
return err return err