diff --git a/ingress/ingress.go b/ingress/ingress.go index f8b22df4..0529e382 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -103,7 +103,7 @@ func NewWarpRoutingService() *WarpRoutingService { } // Get a single origin service from the CLI/config. -func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originService, error) { +func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) { if c.IsSet("hello-world") { return new(helloWorld), nil } @@ -167,7 +167,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon rules := make([]Rule, len(ingress)) for i, r := range ingress { cfg := setConfig(defaults, r.OriginRequest) - var service originService + var service OriginService if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) { // No validation necessary for unix socket filepath services diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index 6af8e1cd..c9da44f5 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -6,9 +6,6 @@ import ( "net/http" "github.com/pkg/errors" - - "github.com/cloudflare/cloudflared/carrier" - "github.com/cloudflare/cloudflared/websocket" ) var ( @@ -24,7 +21,7 @@ type HTTPOriginProxy interface { // StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP. type StreamBasedOriginProxy interface { - EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) + EstablishConnection(dest string) (OriginConnection, error) } func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { @@ -54,73 +51,36 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { return o.resp, nil } -func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { - dest, err := getRequestHost(r) - if err != nil { - return nil, nil, err - } +func (o *rawTCPService) EstablishConnection(dest string) (OriginConnection, error) { conn, err := net.Dial("tcp", dest) if err != nil { - return nil, nil, err + return nil, err } originConn := &tcpConnection{ conn: conn, } - resp := &http.Response{ - Status: switchingProtocolText, - StatusCode: http.StatusSwitchingProtocols, - ContentLength: -1, - } - return originConn, resp, nil + return originConn, nil } -// getRequestHost returns the host of the http.Request. -func getRequestHost(r *http.Request) (string, error) { - if r.Host != "" { - return r.Host, nil - } - if r.URL != nil { - return r.URL.Host, nil - } - return "", errors.New("host not found") -} - -func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { +func (o *tcpOverWSService) EstablishConnection(dest string) (OriginConnection, error) { var err error - dest := o.dest - if o.isBastion { - dest, err = carrier.ResolveBastionDest(r) - if err != nil { - return nil, nil, err - } + if !o.isBastion { + dest = o.dest } conn, err := net.Dial("tcp", dest) if err != nil { - return nil, nil, err + return nil, err } originConn := &tcpOverWSConnection{ conn: conn, streamHandler: o.streamHandler, } - resp := &http.Response{ - Status: switchingProtocolText, - StatusCode: http.StatusSwitchingProtocols, - Header: websocket.NewResponseHeader(r), - ContentLength: -1, - } - return originConn, resp, nil + return originConn, nil } -func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { - originConn := o.conn - resp := &http.Response{ - Status: switchingProtocolText, - StatusCode: http.StatusSwitchingProtocols, - Header: websocket.NewResponseHeader(r), - ContentLength: -1, - } - return originConn, resp, nil +func (o *socksProxyOverWSService) EstablishConnection(dest string) (OriginConnection, error) { + return o.conn, nil } diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index 1e54c2fe..9be788d0 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -17,20 +17,6 @@ import ( "github.com/cloudflare/cloudflared/websocket" ) -// TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns -// the expected response -func assertEstablishConnectionResponse(t *testing.T, - originProxy StreamBasedOriginProxy, - req *http.Request, - expectHeader http.Header, -) { - _, resp, err := originProxy.EstablishConnection(req) - assert.NoError(t, err) - assert.Equal(t, switchingProtocolText, resp.Status) - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) - assert.Equal(t, expectHeader, resp.Header) -} - func TestRawTCPServiceEstablishConnection(t *testing.T) { originListener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -43,8 +29,6 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) { req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil) require.NoError(t, err) - assertEstablishConnectionResponse(t, rawTCPService, req, nil) - originListener.Close() <-listenerClosed @@ -52,9 +36,8 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) { require.NoError(t, err) // Origin not listening for new connection, should return an error - _, resp, err := rawTCPService.EstablishConnection(req) + _, err = rawTCPService.EstablishConnection(req.URL.String()) require.Error(t, err) - require.Nil(t, resp) } func TestTCPOverWSServiceEstablishConnection(t *testing.T) { @@ -76,12 +59,6 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { bastionReq := baseReq.Clone(context.Background()) carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String()) - expectHeader := http.Header{ - "Connection": {"Upgrade"}, - "Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, - "Upgrade": {"websocket"}, - } - tests := []struct { testCase string service *tcpOverWSService @@ -109,11 +86,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { for _, test := range tests { t.Run(test.testCase, func(t *testing.T) { if test.expectErr { - _, resp, err := test.service.EstablishConnection(test.req) + bastionHost, _ := carrier.ResolveBastionDest(test.req) + _, err := test.service.EstablishConnection(bastionHost) assert.Error(t, err) - assert.Nil(t, resp) - } else { - assertEstablishConnectionResponse(t, test.service, test.req, expectHeader) } }) } @@ -123,9 +98,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} { // Origin not listening for new connection, should return an error - _, resp, err := service.EstablishConnection(bastionReq) + bastionHost, _ := carrier.ResolveBastionDest(bastionReq) + _, err := service.EstablishConnection(bastionHost) assert.Error(t, err) - assert.Nil(t, resp) } } diff --git a/ingress/origin_service.go b/ingress/origin_service.go index a915ec72..fc636c86 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -20,8 +20,8 @@ import ( "github.com/cloudflare/cloudflared/tlsconfig" ) -// originService is something a tunnel can proxy traffic to. -type originService interface { +// OriginService is something a tunnel can proxy traffic to. +type OriginService interface { String() string // Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World. // If it's not managed by cloudflared, this is a no-op because the user is responsible for @@ -238,7 +238,7 @@ func (nrc *NopReadCloser) Close() error { return nil } -func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) { +func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) { originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log) if err != nil { return nil, errors.Wrap(err, "Error loading cert pool") diff --git a/ingress/rule.go b/ingress/rule.go index c9548fc3..e91b4139 100644 --- a/ingress/rule.go +++ b/ingress/rule.go @@ -17,7 +17,7 @@ type Rule struct { // A (probably local) address. Requests for a hostname which matches this // rule's hostname pattern will be proxied to the service running on this // address. - Service originService + Service OriginService // Configure the request cloudflared sends to this specific origin. Config OriginRequestConfig diff --git a/ingress/rule_test.go b/ingress/rule_test.go index e17d0857..f5bfcd92 100644 --- a/ingress/rule_test.go +++ b/ingress/rule_test.go @@ -14,7 +14,7 @@ func Test_rule_matches(t *testing.T) { type fields struct { Hostname string Path *regexp.Regexp - Service originService + Service OriginService } type args struct { requestURL *url.URL diff --git a/origin/proxy.go b/origin/proxy.go index a10452d3..8f984a94 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" "github.com/rs/zerolog" + "github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/ingress" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -33,6 +34,8 @@ type proxy struct { bufferPool *bufferPool } +var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) + func NewOriginProxy( ingressRules ingress.Ingress, warpRouting *ingress.WarpRoutingService, @@ -71,7 +74,13 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn lbProbe: lbProbe, rule: ingress.ServiceWarpRouting, } - if err := p.proxyStreamRequest(serveCtx, w, req, p.warpRouting.Proxy, logFields); err != nil { + + host, err := getRequestHost(req) + if err != nil { + err = fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %v`, err) + return err + } + if err := p.proxyStreamRequest(serveCtx, w, host, req, p.warpRouting.Proxy, logFields); err != nil { p.logRequestError(err, cfRay, "", ingress.ServiceWarpRouting) return err } @@ -97,7 +106,11 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn return nil case ingress.StreamBasedOriginProxy: - if err := p.proxyStreamRequest(serveCtx, w, req, originProxy, logFields); err != nil { + dest, err := getDestFromRule(rule, req) + if err != nil { + return err + } + if err := p.proxyStreamRequest(serveCtx, w, dest, req, originProxy, logFields); err != nil { rule, srv := ruleField(p.ingressRules, ruleNum) p.logRequestError(err, cfRay, rule, srv) return err @@ -105,10 +118,29 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn return nil default: return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy) - } } +func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) { + switch rule.Service.String() { + case ingress.ServiceBastion: + return carrier.ResolveBastionDest(req) + default: + return rule.Service.String(), nil + } +} + +// getRequestHost returns the host of the http.Request. +func getRequestHost(r *http.Request) (string, error) { + if r.Host != "" { + return r.Host, nil + } + if r.URL != nil { + return r.URL.Host, nil + } + return "", errors.New("host not set in incoming request") +} + func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) { srv = ing.Rules[ruleNum].Service.String() if ing.IsSingleRule() { @@ -191,16 +223,24 @@ func (p *proxy) proxyHTTPRequest( func (p *proxy) proxyStreamRequest( serveCtx context.Context, w connection.ResponseWriter, + dest string, req *http.Request, connectionProxy ingress.StreamBasedOriginProxy, fields logFields, ) error { - originConn, resp, err := connectionProxy.EstablishConnection(req) + originConn, err := connectionProxy.EstablishConnection(dest) if err != nil { return err } - if resp.Body != nil { - defer resp.Body.Close() + + resp := &http.Response{ + Status: switchingProtocolText, + StatusCode: http.StatusSwitchingProtocols, + ContentLength: -1, + } + + if secWebsocketKey := req.Header.Get("Sec-WebSocket-Key"); secWebsocketKey != "" { + resp.Header = websocket.NewResponseHeader(req) } if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {