Cleanup
This commit is contained in:
parent
df3ef06169
commit
d4f86ac26d
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/carrier"
|
"github.com/cloudflare/cloudflared/carrier"
|
||||||
|
"github.com/cloudflare/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -58,14 +59,13 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
|
|
||||||
bastionReq := baseReq.Clone(context.Background())
|
bastionReq := baseReq.Clone(context.Background())
|
||||||
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
|
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
|
||||||
u, err := url.Parse("https://place-holder1")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
testCase string
|
testCase string
|
||||||
service *tcpOverWSService
|
service *tcpOverWSService
|
||||||
req *http.Request
|
req *http.Request
|
||||||
expectErr bool
|
bastionMode bool
|
||||||
|
expectErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
testCase: "specific TCP service",
|
testCase: "specific TCP service",
|
||||||
|
@ -73,33 +73,37 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||||
req: baseReq,
|
req: baseReq,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
testCase: "bastion service",
|
testCase: "bastion service",
|
||||||
service: newBastionService(),
|
service: newBastionService(),
|
||||||
req: bastionReq,
|
bastionMode: true,
|
||||||
|
req: bastionReq,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
testCase: "invalid bastion request",
|
testCase: "invalid bastion request",
|
||||||
service: newBastionService(),
|
service: newBastionService(),
|
||||||
req: baseReq,
|
bastionMode: true,
|
||||||
expectErr: true,
|
req: baseReq,
|
||||||
|
expectErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
testCase: "bastion service",
|
testCase: "bastion service",
|
||||||
service: newBastionServiceWithDest(u),
|
service: newBastionServiceWithDest(MustParseURL(t, "https://place-holder1")),
|
||||||
req: bastionReq,
|
req: bastionReq,
|
||||||
|
bastionMode: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
testCase: "bastion service",
|
testCase: "bastion service",
|
||||||
service: newBastionServiceWithDest(u),
|
service: newBastionServiceWithDest(MustParseURL(t, "https://place-holder1")),
|
||||||
req: bastionReq,
|
req: baseReq,
|
||||||
expectErr: true,
|
bastionMode: true,
|
||||||
|
expectErr: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.testCase, func(t *testing.T) {
|
t.Run(test.testCase, func(t *testing.T) {
|
||||||
if test.expectErr {
|
if test.expectErr {
|
||||||
bastionHost, _ := carrier.ResolveBastionDest(test.req, false, "bastion")
|
bastionHost, _ := carrier.ResolveBastionDest(test.req, test.bastionMode, config.BastionFlag)
|
||||||
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
|
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,7 +59,6 @@ func (r *Rule) Matches(hostname, path string) bool {
|
||||||
} else {
|
} else {
|
||||||
hostMatch = matchHost(r.Hostname, hostname)
|
hostMatch = matchHost(r.Hostname, hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
punycodeHostMatch := false
|
punycodeHostMatch := false
|
||||||
if r.punycodeHostname != "" {
|
if r.punycodeHostname != "" {
|
||||||
punycodeHostMatch = matchHost(r.punycodeHostname, hostname)
|
punycodeHostMatch = matchHost(r.punycodeHostname, hostname)
|
||||||
|
|
|
@ -99,10 +99,11 @@ func (p *Proxy) ProxyHTTP(
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Handling for StreamBasedOriginProxy or BastionMode
|
|
||||||
|
// Check if config is for Bastion Mode and service is a stream based origin proxy, if so stream service in bastion mode
|
||||||
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); ok || rule.Config.BastionMode {
|
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); ok || rule.Config.BastionMode {
|
||||||
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); !ok && rule.Config.BastionMode {
|
if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); !ok && rule.Config.BastionMode {
|
||||||
return fmt.Errorf("Unrecognized service: %s", rule.Service)
|
return fmt.Errorf("Unsupported service to stream to in bastion mode: %s", rule.Service)
|
||||||
}
|
}
|
||||||
|
|
||||||
dest, err := getDestFromRule(rule, req)
|
dest, err := getDestFromRule(rule, req)
|
||||||
|
@ -116,6 +117,7 @@ func (p *Proxy) ProxyHTTP(
|
||||||
}
|
}
|
||||||
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
|
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
|
||||||
logger := logger.With().Str(logFieldDestAddr, dest).Logger()
|
logger := logger.With().Str(logFieldDestAddr, dest).Logger()
|
||||||
|
// We know that Bastion mode is supported by StreamBasedOriginProxy, hence use the same
|
||||||
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, rule.Service.(ingress.StreamBasedOriginProxy), &logger); err != nil {
|
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, rule.Service.(ingress.StreamBasedOriginProxy), &logger); err != nil {
|
||||||
logRequestError(&logger, err)
|
logRequestError(&logger, err)
|
||||||
return err
|
return err
|
||||||
|
|
Loading…
Reference in New Issue