This commit is contained in:
Shayon Mukherjee 2024-05-10 16:16:18 -04:00
parent df3ef06169
commit d4f86ac26d
3 changed files with 29 additions and 24 deletions

View File

@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/websocket"
)
@ -58,14 +59,13 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
bastionReq := baseReq.Clone(context.Background())
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
u, err := url.Parse("https://place-holder1")
require.NoError(t, err)
tests := []struct {
testCase string
service *tcpOverWSService
req *http.Request
expectErr bool
testCase string
service *tcpOverWSService
req *http.Request
bastionMode bool
expectErr bool
}{
{
testCase: "specific TCP service",
@ -73,33 +73,37 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
req: baseReq,
},
{
testCase: "bastion service",
service: newBastionService(),
req: bastionReq,
testCase: "bastion service",
service: newBastionService(),
bastionMode: true,
req: bastionReq,
},
{
testCase: "invalid bastion request",
service: newBastionService(),
req: baseReq,
expectErr: true,
testCase: "invalid bastion request",
service: newBastionService(),
bastionMode: true,
req: baseReq,
expectErr: true,
},
{
testCase: "bastion service",
service: newBastionServiceWithDest(u),
req: bastionReq,
testCase: "bastion service",
service: newBastionServiceWithDest(MustParseURL(t, "https://place-holder1")),
req: bastionReq,
bastionMode: true,
},
{
testCase: "bastion service",
service: newBastionServiceWithDest(u),
req: bastionReq,
expectErr: true,
testCase: "bastion service",
service: newBastionServiceWithDest(MustParseURL(t, "https://place-holder1")),
req: baseReq,
bastionMode: true,
expectErr: true,
},
}
for _, test := range tests {
t.Run(test.testCase, func(t *testing.T) {
if test.expectErr {
bastionHost, _ := carrier.ResolveBastionDest(test.req, false, "bastion")
bastionHost, _ := carrier.ResolveBastionDest(test.req, test.bastionMode, config.BastionFlag)
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
assert.Error(t, err)
}

View File

@ -59,7 +59,6 @@ func (r *Rule) Matches(hostname, path string) bool {
} else {
hostMatch = matchHost(r.Hostname, hostname)
}
punycodeHostMatch := false
if r.punycodeHostname != "" {
punycodeHostMatch = matchHost(r.punycodeHostname, hostname)

View File

@ -99,10 +99,11 @@ func (p *Proxy) ProxyHTTP(
}
return err
}
// Handling for StreamBasedOriginProxy or BastionMode
// Check if config is for Bastion Mode and service is a stream based origin proxy, if so stream service in bastion mode
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); ok || rule.Config.BastionMode {
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); !ok && rule.Config.BastionMode {
return fmt.Errorf("Unrecognized service: %s", rule.Service)
return fmt.Errorf("Unsupported service to stream to in bastion mode: %s", rule.Service)
}
dest, err := getDestFromRule(rule, req)
@ -116,6 +117,7 @@ func (p *Proxy) ProxyHTTP(
}
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
logger := logger.With().Str(logFieldDestAddr, dest).Logger()
// We know that Bastion mode is supported by StreamBasedOriginProxy, hence use the same
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, rule.Service.(ingress.StreamBasedOriginProxy), &logger); err != nil {
logRequestError(&logger, err)
return err