Cleanup
This commit is contained in:
parent
df3ef06169
commit
d4f86ac26d
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cloudflare/cloudflared/carrier"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
|
@ -58,13 +59,12 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
|||
|
||||
bastionReq := baseReq.Clone(context.Background())
|
||||
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
|
||||
u, err := url.Parse("https://place-holder1")
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
testCase string
|
||||
service *tcpOverWSService
|
||||
req *http.Request
|
||||
bastionMode bool
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
|
@ -75,23 +75,27 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
|||
{
|
||||
testCase: "bastion service",
|
||||
service: newBastionService(),
|
||||
bastionMode: true,
|
||||
req: bastionReq,
|
||||
},
|
||||
{
|
||||
testCase: "invalid bastion request",
|
||||
service: newBastionService(),
|
||||
bastionMode: true,
|
||||
req: baseReq,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
testCase: "bastion service",
|
||||
service: newBastionServiceWithDest(u),
|
||||
service: newBastionServiceWithDest(MustParseURL(t, "https://place-holder1")),
|
||||
req: bastionReq,
|
||||
bastionMode: true,
|
||||
},
|
||||
{
|
||||
testCase: "bastion service",
|
||||
service: newBastionServiceWithDest(u),
|
||||
req: bastionReq,
|
||||
service: newBastionServiceWithDest(MustParseURL(t, "https://place-holder1")),
|
||||
req: baseReq,
|
||||
bastionMode: true,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
@ -99,7 +103,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
|||
for _, test := range tests {
|
||||
t.Run(test.testCase, func(t *testing.T) {
|
||||
if test.expectErr {
|
||||
bastionHost, _ := carrier.ResolveBastionDest(test.req, false, "bastion")
|
||||
bastionHost, _ := carrier.ResolveBastionDest(test.req, test.bastionMode, config.BastionFlag)
|
||||
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
|
|
@ -59,7 +59,6 @@ func (r *Rule) Matches(hostname, path string) bool {
|
|||
} else {
|
||||
hostMatch = matchHost(r.Hostname, hostname)
|
||||
}
|
||||
|
||||
punycodeHostMatch := false
|
||||
if r.punycodeHostname != "" {
|
||||
punycodeHostMatch = matchHost(r.punycodeHostname, hostname)
|
||||
|
|
|
@ -99,10 +99,11 @@ func (p *Proxy) ProxyHTTP(
|
|||
}
|
||||
return err
|
||||
}
|
||||
// Handling for StreamBasedOriginProxy or BastionMode
|
||||
|
||||
// Check if config is for Bastion Mode and service is a stream based origin proxy, if so stream service in bastion mode
|
||||
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); ok || rule.Config.BastionMode {
|
||||
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); !ok && rule.Config.BastionMode {
|
||||
return fmt.Errorf("Unrecognized service: %s", rule.Service)
|
||||
return fmt.Errorf("Unsupported service to stream to in bastion mode: %s", rule.Service)
|
||||
}
|
||||
|
||||
dest, err := getDestFromRule(rule, req)
|
||||
|
@ -116,6 +117,7 @@ func (p *Proxy) ProxyHTTP(
|
|||
}
|
||||
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
|
||||
logger := logger.With().Str(logFieldDestAddr, dest).Logger()
|
||||
// We know that Bastion mode is supported by StreamBasedOriginProxy, hence use the same
|
||||
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, rule.Service.(ingress.StreamBasedOriginProxy), &logger); err != nil {
|
||||
logRequestError(&logger, err)
|
||||
return err
|
||||
|
|
Loading…
Reference in New Issue