diff --git a/carrier/websocket.go b/carrier/websocket.go index 0ae203d2..e05dcb8c 100644 --- a/carrier/websocket.go +++ b/carrier/websocket.go @@ -23,7 +23,7 @@ type Websocket struct { } type wsdialer struct { - conn *cfwebsocket.Conn + conn *cfwebsocket.GorillaConn } func (d *wsdialer) Dial(address string) (io.ReadWriteCloser, *socks.AddrSpec, error) { @@ -75,7 +75,7 @@ func (ws *Websocket) StartServer(listener net.Listener, remote string, shutdownC // createWebsocketStream will create a WebSocket connection to stream data over // It also handles redirects from Access and will present that flow if // the token is not present on the request -func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebsocket.Conn, error) { +func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebsocket.GorillaConn, error) { req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil) if err != nil { return nil, err @@ -97,7 +97,7 @@ func createWebsocketStream(options *StartOptions, log *zerolog.Logger) (*cfwebso return nil, err } - return &cfwebsocket.Conn{Conn: wsConn}, nil + return &cfwebsocket.GorillaConn{Conn: wsConn}, nil } // createAccessAuthenticatedStream will try load a token from storage and make diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 26f07dce..9c97d350 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -245,9 +245,9 @@ func prepareTunnelConfig( edgeTLSConfigs[p] = edgeTLSConfig } - originClient := origin.NewClient(ingressRules, tags, log) + originProxy := origin.NewOriginProxy(ingressRules, tags, log) connectionConfig := &connection.Config{ - OriginClient: originClient, + OriginProxy: originProxy, GracePeriod: c.Duration("grace-period"), ReplaceExisting: c.Bool("force"), } diff --git a/connection/connection.go b/connection/connection.go index 053f4985..98e86baa 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -14,7 +14,7 @@ import ( const LogFieldConnIndex = "connIndex" type Config struct { - OriginClient OriginClient + OriginProxy OriginProxy GracePeriod time.Duration ReplaceExisting bool } @@ -50,12 +50,12 @@ func (c *ClassicTunnelConfig) IsTrialZone() bool { return c.Hostname == "" } -type OriginClient interface { +type OriginProxy interface { Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error } type ResponseWriter interface { - WriteRespHeaders(*http.Response) error + WriteRespHeaders(status int, header http.Header) error WriteErrorResponse() io.ReadWriter } diff --git a/connection/connection_test.go b/connection/connection_test.go index 01f03999..7fe02d17 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -19,8 +19,8 @@ const ( var ( testConfig = &Config{ - OriginClient: &mockOriginClient{}, - GracePeriod: time.Millisecond * 100, + OriginProxy: &mockOriginProxy{}, + GracePeriod: time.Millisecond * 100, } log = zerolog.Nop() testOriginURL = &url.URL{ @@ -38,10 +38,10 @@ type testRequest struct { isProxyError bool } -type mockOriginClient struct { +type mockOriginProxy struct { } -func (moc *mockOriginClient) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error { +func (moc *mockOriginProxy) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error { if isWebsocket { return wsEndpoint(w, r) } @@ -74,7 +74,7 @@ func wsEndpoint(w ResponseWriter, r *http.Request) error { resp := &http.Response{ StatusCode: http.StatusSwitchingProtocols, } - _ = w.WriteRespHeaders(resp) + _ = w.WriteRespHeaders(resp.StatusCode, resp.Header) clientReader := nowriter{r.Body} go func() { for { @@ -95,7 +95,7 @@ func originRespEndpoint(w ResponseWriter, status int, data []byte) { resp := &http.Response{ StatusCode: status, } - _ = w.WriteRespHeaders(resp) + _ = w.WriteRespHeaders(resp.StatusCode, resp.Header) _, _ = w.Write(data) } diff --git a/connection/h2mux.go b/connection/h2mux.go index 5d9ec068..31921847 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -216,7 +216,7 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error { return reqErr } - err := h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req)) + err := h.config.OriginProxy.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req)) if err != nil { respWriter.WriteErrorResponse() return err @@ -240,8 +240,8 @@ type h2muxRespWriter struct { *h2mux.MuxedStream } -func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error { - headers := h2mux.H1ResponseToH2ResponseHeaders(resp) +func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error { + headers := h2mux.H1ResponseToH2ResponseHeaders(status, header) headers = append(headers, h2mux.Header{Name: ResponseMetaHeaderField, Value: responseMetaHeaderOrigin}) return rp.WriteHeaders(headers) } diff --git a/connection/http2.go b/connection/http2.go index dbadd555..d938dd53 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -115,9 +115,9 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { } else if isWebsocketUpgrade(r) { respWriter.shouldFlush = true stripWebsocketUpgradeHeader(r) - err = c.config.OriginClient.Proxy(respWriter, r, true) + err = c.config.OriginProxy.Proxy(respWriter, r, true) } else { - err = c.config.OriginClient.Proxy(respWriter, r, false) + err = c.config.OriginProxy.Proxy(respWriter, r, false) } if err != nil { @@ -161,10 +161,10 @@ type http2RespWriter struct { shouldFlush bool } -func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error { +func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error { dest := rp.w.Header() - userHeaders := make(http.Header, len(resp.Header)) - for header, values := range resp.Header { + userHeaders := make(http.Header, len(header)) + for header, values := range header { // Since these are http2 headers, they're required to be lowercase h2name := strings.ToLower(header) for _, v := range values { @@ -184,13 +184,12 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error { // Perform user header serialization and set them in the single header dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders)) rp.setResponseMetaHeader(responseMetaHeaderOrigin) - status := resp.StatusCode // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1 if status == http.StatusSwitchingProtocols { status = http.StatusOK } rp.w.WriteHeader(status) - if IsServerSentEvent(resp.Header) { + if IsServerSentEvent(header) { rp.shouldFlush = true } if rp.shouldFlush { diff --git a/h2mux/header.go b/h2mux/header.go index f27f6b26..848c93b0 100644 --- a/h2mux/header.go +++ b/h2mux/header.go @@ -125,12 +125,12 @@ func IsWebsocketClientHeader(headerName string) bool { headerName == "upgrade" } -func H1ResponseToH2ResponseHeaders(h1 *http.Response) (h2 []Header) { +func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []Header) { h2 = []Header{ - {Name: ":status", Value: strconv.Itoa(h1.StatusCode)}, + {Name: ":status", Value: strconv.Itoa(status)}, } - userHeaders := make(http.Header, len(h1.Header)) - for header, values := range h1.Header { + userHeaders := make(http.Header, len(h1)) + for header, values := range h1 { h2name := strings.ToLower(header) if h2name == "content-length" { // This header has meaning in HTTP/2 and will be used by the edge, diff --git a/h2mux/header_test.go b/h2mux/header_test.go index 0d3f5b06..e9da95e0 100644 --- a/h2mux/header_test.go +++ b/h2mux/header_test.go @@ -579,7 +579,7 @@ func TestH1ResponseToH2ResponseHeaders(t *testing.T) { Header: mockHeaders, } - headers := H1ResponseToH2ResponseHeaders(&mockResponse) + headers := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header) serializedHeadersIndex := -1 for i, header := range headers { @@ -622,7 +622,7 @@ func TestHeaderSize(t *testing.T) { Header: largeHeaders, } - serializedHeaders := H1ResponseToH2ResponseHeaders(&mockResponse) + serializedHeaders := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header) request, err := http.NewRequest(http.MethodGet, "https://example.com/", nil) assert.NoError(t, err) for _, header := range serializedHeaders { @@ -669,6 +669,6 @@ func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - _ = H1ResponseToH2ResponseHeaders(h1resp) + _ = H1ResponseToH2ResponseHeaders(h1resp.StatusCode, h1resp.Header) } } diff --git a/ingress/ingress.go b/ingress/ingress.go index eeaa2184..e3edd6a2 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -85,16 +85,24 @@ func NewSingleOrigin(c *cli.Context, allowURLFromArgs bool) (Ingress, error) { } // Get a single origin service from the CLI/config. -func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) { +func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originService, error) { if c.IsSet("hello-world") { return new(helloWorld), nil } - if c.IsSet("url") || c.IsSet(config.BastionFlag) { + if c.IsSet(config.BastionFlag) { + return newBridgeService(), nil + } + if c.IsSet("url") { originURL, err := config.ValidateUrl(c, allowURLFromArgs) if err != nil { return nil, errors.Wrap(err, "Error validating origin URL") } - return &localService{URL: originURL, RootURL: originURL}, nil + if isHTTPService(originURL) { + return &httpService{ + url: originURL, + }, nil + } + return newSingleTCPService(originURL), nil } if c.IsSet("unix-socket") { path, err := config.ValidateUnixSocket(c) @@ -104,7 +112,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginServ return &unixSocketPath{path: path}, nil } u, err := url.Parse("http://localhost:8080") - return &localService{URL: u, RootURL: u}, err + return &httpService{url: u}, err } // IsEmpty checks if there are any ingress rules. @@ -136,7 +144,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon rules := make([]Rule, len(ingress)) for i, r := range ingress { cfg := setConfig(defaults, r.OriginRequest) - var service OriginService + var service originService if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) { // No validation necessary for unix socket filepath services @@ -156,7 +164,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon // overwrite the localService.URL field when `start` is called. So, // leave the URL field empty for now. cfg.BastionMode = true - service = new(localService) + service = newBridgeService() } else { // Validate URL services u, err := url.Parse(r.Service) @@ -171,8 +179,11 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon if u.Path != "" { return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path", r.Service) } - serviceURL := localService{URL: u} - service = &serviceURL + if isHTTPService(u) { + service = &httpService{url: u} + } else { + service = newSingleTCPService(u) + } } if err := validateHostname(r, i, len(ingress)); err != nil { @@ -241,3 +252,7 @@ func ParseIngress(conf *config.Configuration) (Ingress, error) { } return validate(conf.Ingress, originRequestFromYAML(conf.OriginRequest)) } + +func isHTTPService(url *url.URL) bool { + return url.Scheme == "http" || url.Scheme == "https" || url.Scheme == "ws" || url.Scheme == "wss" +} diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index 411f22b4..64ad3315 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -61,12 +61,12 @@ ingress: want: []Rule{ { Hostname: "tunnel1.example.com", - Service: &localService{URL: localhost8000}, + Service: &httpService{url: localhost8000}, Config: defaultConfig, }, { Hostname: "*", - Service: &localService{URL: localhost8001}, + Service: &httpService{url: localhost8001}, Config: defaultConfig, }, }, @@ -82,7 +82,22 @@ extraKey: extraValue want: []Rule{ { Hostname: "*", - Service: &localService{URL: localhost8000}, + Service: &httpService{url: localhost8000}, + Config: defaultConfig, + }, + }, + }, + { + name: "ws service", + args: args{rawYAML: ` +ingress: + - hostname: "*" + service: wss://localhost:8000 +`}, + want: []Rule{ + { + Hostname: "*", + Service: &httpService{url: MustParseURL(t, "wss://localhost:8000")}, Config: defaultConfig, }, }, @@ -95,7 +110,7 @@ ingress: `}, want: []Rule{ { - Service: &localService{URL: localhost8000}, + Service: &httpService{url: localhost8000}, Config: defaultConfig, }, }, @@ -209,6 +224,85 @@ ingress: }, }, }, + { + name: "TCP services", + args: args{rawYAML: ` +ingress: +- hostname: tcp.foo.com + service: tcp://127.0.0.1 +- hostname: tcp2.foo.com + service: tcp://localhost:8000 +- service: http_status:404 +`}, + want: []Rule{ + { + Hostname: "tcp.foo.com", + Service: newSingleTCPService(MustParseURL(t, "tcp://127.0.0.1:7864")), + Config: defaultConfig, + }, + { + Hostname: "tcp2.foo.com", + Service: newSingleTCPService(MustParseURL(t, "tcp://localhost:8000")), + Config: defaultConfig, + }, + { + Service: &fourOhFour, + Config: defaultConfig, + }, + }, + }, + { + name: "SSH services", + args: args{rawYAML: ` +ingress: +- service: ssh://127.0.0.1 +`}, + want: []Rule{ + { + Service: newSingleTCPService(MustParseURL(t, "ssh://127.0.0.1:22")), + Config: defaultConfig, + }, + }, + }, + { + name: "RDP services", + args: args{rawYAML: ` +ingress: +- service: rdp://127.0.0.1 +`}, + want: []Rule{ + { + Service: newSingleTCPService(MustParseURL(t, "rdp://127.0.0.1:3389")), + Config: defaultConfig, + }, + }, + }, + { + name: "SMB services", + args: args{rawYAML: ` +ingress: +- service: smb://127.0.0.1 +`}, + want: []Rule{ + { + Service: newSingleTCPService(MustParseURL(t, "smb://127.0.0.1:445")), + Config: defaultConfig, + }, + }, + }, + { + name: "Other TCP services", + args: args{rawYAML: ` +ingress: +- service: ftp://127.0.0.1 +`}, + want: []Rule{ + { + Service: newSingleTCPService(MustParseURL(t, "ftp://127.0.0.1")), + Config: defaultConfig, + }, + }, + }, { name: "URL isn't necessary if using bastion", args: args{rawYAML: ` @@ -221,7 +315,7 @@ ingress: want: []Rule{ { Hostname: "bastion.foo.com", - Service: &localService{}, + Service: newBridgeService(), Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), }, { @@ -241,7 +335,7 @@ ingress: want: []Rule{ { Hostname: "bastion.foo.com", - Service: &localService{}, + Service: newBridgeService(), Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), }, { @@ -409,6 +503,37 @@ func TestFindMatchingRule(t *testing.T) { } } +func TestIsHTTPService(t *testing.T) { + tests := []struct { + url *url.URL + isHTTP bool + }{ + { + url: MustParseURL(t, "http://localhost"), + isHTTP: true, + }, + { + url: MustParseURL(t, "https://127.0.0.1:8000"), + isHTTP: true, + }, + { + url: MustParseURL(t, "ws://localhost"), + isHTTP: true, + }, + { + url: MustParseURL(t, "wss://localhost:8000"), + isHTTP: true, + }, + { + url: MustParseURL(t, "tcp://localhost:9000"), + isHTTP: false, + }, + } + for _, test := range tests { + assert.Equal(t, test.isHTTP, isHTTPService(test.url)) + } +} + func mustParsePath(t *testing.T, path string) *regexp.Regexp { regexp, err := regexp.Compile(path) assert.NoError(t, err) diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go new file mode 100644 index 00000000..9e2b9671 --- /dev/null +++ b/ingress/origin_connection.go @@ -0,0 +1,62 @@ +package ingress + +import ( + "io" + "net" + "net/http" + + "github.com/cloudflare/cloudflared/websocket" + gws "github.com/gorilla/websocket" +) + +// OriginConnection is a way to stream to a service running on the user's origin. +// Different concrete implementations will stream different protocols as long as they are io.ReadWriters. +type OriginConnection interface { + // Stream should generally be implemented as a bidirectional io.Copy. + Stream(tunnelConn io.ReadWriter) + Close() +} + +// tcpConnection is an OriginConnection that directly streams to raw TCP. +type tcpConnection struct { + conn net.Conn + streamHandler func(tunnelConn io.ReadWriter, originConn net.Conn) +} + +func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter) { + tc.streamHandler(tunnelConn, tc.conn) +} + +func (tc *tcpConnection) Close() { + tc.conn.Close() +} + +// wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets. +// TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does. +type wsConnection struct { + wsConn *gws.Conn + resp *http.Response +} + +func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter) { + websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn()) +} + +func (wsc *wsConnection) Close() { + wsc.resp.Body.Close() + wsc.wsConn.Close() +} + +func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) { + d := &gws.Dialer{ + TLSClientConfig: transport.TLSClientConfig, + } + wsConn, resp, err := websocket.ClientConnect(r, d) + if err != nil { + return nil, err + } + return &wsConnection{ + wsConn, + resp, + }, nil +} diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go new file mode 100644 index 00000000..ed02ce72 --- /dev/null +++ b/ingress/origin_proxy.go @@ -0,0 +1,100 @@ +package ingress + +import ( + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + + "github.com/cloudflare/cloudflared/h2mux" +) + +// HTTPOriginProxy can be implemented by origin services that want to proxy http requests. +type HTTPOriginProxy interface { + // RoundTrip is how cloudflared proxies eyeball requests to the actual origin services + http.RoundTripper +} + +// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level. +type StreamBasedOriginProxy interface { + EstablishConnection(r *http.Request) (OriginConnection, error) +} + +func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { + return o.transport.RoundTrip(req) +} + +// TODO: TUN-3636: establish connection to origins over UDS +func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, error) { + return nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections") +} + +func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { + // Rewrite the request URL so that it goes to the origin service. + req.URL.Host = o.url.Host + req.URL.Scheme = o.url.Scheme + return o.transport.RoundTrip(req) +} + +func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { + // Rewrite the request URL so that it goes to the Hello World server. + req.URL.Host = o.server.Addr().String() + req.URL.Scheme = "https" + return o.transport.RoundTrip(req) +} + +func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, error) { + req.URL.Host = o.server.Addr().String() + req.URL.Scheme = "wss" + return newWSConnection(o.transport, req) +} + +func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { + return o.resp, nil +} + +func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, error) { + dest, err := o.destination(r) + if err != nil { + return nil, err + } + return o.client.connect(r, dest) +} + +func (o *bridgeService) destination(r *http.Request) (string, error) { + jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader) + if jumpDestination == "" { + return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side") + } + // Strip scheme and path set by client. Without a scheme + // Parsing a hostname and path without scheme might not return an error due to parsing ambiguities + if jumpURL, err := url.Parse(jumpDestination); err == nil && jumpURL.Host != "" { + return removePath(jumpURL.Host), nil + } + return removePath(jumpDestination), nil +} + +func removePath(dest string) string { + return strings.SplitN(dest, "/", 2)[0] +} + +func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, error) { + return o.client.connect(r, o.dest) +} + +type tcpClient struct { + streamHandler func(originConn io.ReadWriter, remoteConn net.Conn) +} + +func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + return &tcpConnection{ + conn: conn, + streamHandler: c.streamHandler, + }, nil +} diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go new file mode 100644 index 00000000..13b8ff63 --- /dev/null +++ b/ingress/origin_proxy_test.go @@ -0,0 +1,107 @@ +package ingress + +import ( + "net/http" + "testing" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/stretchr/testify/assert" +) + +func TestBridgeServiceDestination(t *testing.T) { + canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader) + tests := []struct { + name string + header http.Header + expectedDest string + wantErr bool + }{ + { + name: "hostname destination", + header: http.Header{ + canonicalJumpDestHeader: []string{"localhost"}, + }, + expectedDest: "localhost", + }, + { + name: "hostname destination with port", + header: http.Header{ + canonicalJumpDestHeader: []string{"localhost:9000"}, + }, + expectedDest: "localhost:9000", + }, + { + name: "hostname destination with scheme and port", + header: http.Header{ + canonicalJumpDestHeader: []string{"ssh://localhost:9000"}, + }, + expectedDest: "localhost:9000", + }, + { + name: "full hostname url", + header: http.Header{ + canonicalJumpDestHeader: []string{"ssh://localhost:9000/metrics"}, + }, + expectedDest: "localhost:9000", + }, + { + name: "hostname destination with port and path", + header: http.Header{ + canonicalJumpDestHeader: []string{"localhost:9000/metrics"}, + }, + expectedDest: "localhost:9000", + }, + { + name: "ip destination", + header: http.Header{ + canonicalJumpDestHeader: []string{"127.0.0.1"}, + }, + expectedDest: "127.0.0.1", + }, + { + name: "ip destination with port", + header: http.Header{ + canonicalJumpDestHeader: []string{"127.0.0.1:9000"}, + }, + expectedDest: "127.0.0.1:9000", + }, + { + name: "ip destination with port and path", + header: http.Header{ + canonicalJumpDestHeader: []string{"127.0.0.1:9000/metrics"}, + }, + expectedDest: "127.0.0.1:9000", + }, + { + name: "ip destination with schem and port", + header: http.Header{ + canonicalJumpDestHeader: []string{"tcp://127.0.0.1:9000"}, + }, + expectedDest: "127.0.0.1:9000", + }, + { + name: "full ip url", + header: http.Header{ + canonicalJumpDestHeader: []string{"ssh://127.0.0.1:9000/metrics"}, + }, + expectedDest: "127.0.0.1:9000", + }, + { + name: "no destination", + wantErr: true, + }, + } + s := newBridgeService() + for _, test := range tests { + r := &http.Request{ + Header: test.header, + } + dest, err := s.destination(r) + if test.wantErr { + assert.Error(t, err, "Test %s expects error", test.name) + } else { + assert.NoError(t, err, "Test %s expects no error, got error %v", test.name, err) + assert.Equal(t, test.expectedDest, dest, "Test %s expect dest %s, got %s", test.name, test.expectedDest, dest) + } + } +} diff --git a/ingress/origin_service.go b/ingress/origin_service.go index ba3beb7d..d8f0a015 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -8,7 +8,6 @@ import ( "net" "net/http" "net/url" - "strconv" "sync" "time" @@ -21,10 +20,8 @@ import ( "github.com/rs/zerolog" ) -// OriginService is something a tunnel can proxy traffic to. -type OriginService interface { - // RoundTrip is how cloudflared proxies eyeball requests to the actual origin services - http.RoundTripper +// originService is something a tunnel can proxy traffic to. +type originService interface { String() string // Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World. // If it's not managed by cloudflared, this is a no-op because the user is responsible for @@ -51,10 +48,6 @@ func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdown return nil } -func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { - return o.transport.RoundTrip(req) -} - func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) { d := &gws.Dialer{ NetDial: o.transport.Dial, @@ -65,130 +58,87 @@ func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, return d.Dial(reqURL.String(), headers) } -// localService is an OriginService listening on a TCP/IP address the user's origin can route to. -type localService struct { - // The URL for the user's origin service - RootURL *url.URL - // The URL that cloudflared should send requests to. - // If this origin requires starting a proxy, this is the proxy's address, - // and that proxy points to RootURL. Otherwise, this is equal to RootURL. - URL *url.URL +type httpService struct { + url *url.URL transport *http.Transport } -func (o *localService) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) { - d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig} - // Rewrite the request URL so that it goes to the origin service. - reqURL.Host = o.URL.Host - reqURL.Scheme = websocket.ChangeRequestScheme(o.URL) - return d.Dial(reqURL.String(), headers) -} - -func (o *localService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { transport, err := newHTTPTransport(o, cfg, log) if err != nil { return err } o.transport = transport - - // Start a proxy if one is needed - if staticHost := o.staticHost(); originRequiresProxy(staticHost, cfg) { - if err := o.startProxy(staticHost, wg, log, shutdownC, errC, cfg); err != nil { - return err - } - } - return nil } -func (o *localService) startProxy(staticHost string, wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (o *httpService) String() string { + return o.url.String() +} - // Start a listener for the proxy - proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort))) - listener, err := net.Listen("tcp", proxyAddress) - if err != nil { - log.Error().Msgf("Cannot start Websocket Proxy Server: %s", err) - return errors.Wrap(err, "Cannot start Websocket Proxy Server") +// bridgeService is like a jump host, the destination is specified by the client +type bridgeService struct { + client *tcpClient +} + +func newBridgeService() *bridgeService { + return &bridgeService{ + client: &tcpClient{}, } +} - // Start the proxy itself - wg.Add(1) - go func() { - defer wg.Done() - streamHandler := websocket.DefaultStreamHandler - // This origin's config specifies what type of proxy to start. - switch cfg.ProxyType { - case socksProxy: - log.Info().Msg("SOCKS5 server started") - streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) { - dialer := socks.NewConnDialer(remoteConn) - requestHandler := socks.NewRequestHandler(dialer) - socksServer := socks.NewConnectionHandler(requestHandler) +func (o *bridgeService) String() string { + return "bridge service" +} - _ = socksServer.Serve(wsConn) - } - case "": - log.Debug().Msg("Not starting any websocket proxy") - default: - log.Error().Msgf("%s isn't a valid proxy (valid options are {%s})", cfg.ProxyType, socksProxy) - } - - errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler) - }() - - // Modify this origin, so that it no longer points at the origin service directly. - // Instead, it points at the proxy to the origin service. - newURL, err := url.Parse("http://" + listener.Addr().String()) - if err != nil { - return err +func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { + if cfg.ProxyType == socksProxy { + o.client.streamHandler = socks.StreamHandler + } else { + o.client.streamHandler = websocket.DefaultStreamHandler } - o.URL = newURL return nil } -func (o *localService) String() string { - if o.isBastion() { - return "Bastion" - } - return o.URL.String() +type singleTCPService struct { + dest string + client *tcpClient } -func (o *localService) isBastion() bool { - return o.URL == nil -} - -func (o *localService) RoundTrip(req *http.Request) (*http.Response, error) { - // Rewrite the request URL so that it goes to the origin service. - req.URL.Host = o.URL.Host - req.URL.Scheme = o.URL.Scheme - return o.transport.RoundTrip(req) -} - -func (o *localService) staticHost() string { - - if o.URL == nil { - return "" - } - - addPortIfMissing := func(uri *url.URL, port int) string { - if uri.Port() != "" { - return uri.Host - } - return fmt.Sprintf("%s:%d", uri.Hostname(), port) - } - - switch o.URL.Scheme { +func newSingleTCPService(url *url.URL) *singleTCPService { + switch url.Scheme { case "ssh": - return addPortIfMissing(o.URL, 22) + addPortIfMissing(url, 22) case "rdp": - return addPortIfMissing(o.URL, 3389) + addPortIfMissing(url, 3389) case "smb": - return addPortIfMissing(o.URL, 445) + addPortIfMissing(url, 445) case "tcp": - return addPortIfMissing(o.URL, 7864) // just a random port since there isn't a default in this case + addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case } - return "" + return &singleTCPService{ + dest: url.Host, + client: &tcpClient{}, + } +} +func addPortIfMissing(uri *url.URL, port int) { + if uri.Port() == "" { + uri.Host = fmt.Sprintf("%s:%d", uri.Hostname(), port) + } +} + +func (o *singleTCPService) String() string { + return o.dest +} + +func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { + if cfg.ProxyType == socksProxy { + o.client.streamHandler = socks.StreamHandler + } else { + o.client.streamHandler = websocket.DefaultStreamHandler + } + return nil } // HelloWorld is an OriginService for the built-in Hello World server. @@ -228,26 +178,6 @@ func (o *helloWorld) start( return nil } -func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { - // Rewrite the request URL so that it goes to the Hello World server. - req.URL.Host = o.server.Addr().String() - req.URL.Scheme = "https" - return o.transport.RoundTrip(req) -} - -func (o *helloWorld) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) { - d := &gws.Dialer{ - TLSClientConfig: o.transport.TLSClientConfig, - } - reqURL.Host = o.server.Addr().String() - reqURL.Scheme = "wss" - return d.Dial(reqURL.String(), headers) -} - -func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool { - return staticHost != "" || cfg.BastionMode -} - // statusCode is an OriginService that just responds with a given HTTP status. // Typical use-case is "user wants the catch-all rule to just respond 404". type statusCode struct { @@ -277,10 +207,6 @@ func (o *statusCode) start( return nil } -func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { - return o.resp, nil -} - type NopReadCloser struct{} // Read always returns EOF to signal end of input @@ -292,7 +218,7 @@ func (nrc *NopReadCloser) Close() error { return nil } -func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) { +func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) { originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log) if err != nil { return nil, errors.Wrap(err, "Error loading cert pool") @@ -337,19 +263,19 @@ func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerol return &httpTransport, nil } -// MockOriginService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior. -type MockOriginService struct { +// MockOriginHTTPService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior. +type MockOriginHTTPService struct { Transport http.RoundTripper } -func (mos MockOriginService) RoundTrip(req *http.Request) (*http.Response, error) { +func (mos MockOriginHTTPService) RoundTrip(req *http.Request) (*http.Response, error) { return mos.Transport.RoundTrip(req) } -func (mos MockOriginService) String() string { +func (mos MockOriginHTTPService) String() string { return "MockOriginService" } -func (mos MockOriginService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { +func (mos MockOriginHTTPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { return nil } diff --git a/ingress/rule.go b/ingress/rule.go index e91b4139..c9548fc3 100644 --- a/ingress/rule.go +++ b/ingress/rule.go @@ -17,7 +17,7 @@ type Rule struct { // A (probably local) address. Requests for a hostname which matches this // rule's hostname pattern will be proxied to the service running on this // address. - Service OriginService + Service originService // Configure the request cloudflared sends to this specific origin. Config OriginRequestConfig diff --git a/ingress/rule_test.go b/ingress/rule_test.go index f5bfcd92..e17d0857 100644 --- a/ingress/rule_test.go +++ b/ingress/rule_test.go @@ -14,7 +14,7 @@ func Test_rule_matches(t *testing.T) { type fields struct { Hostname string Path *regexp.Regexp - Service OriginService + Service originService } type args struct { requestURL *url.URL diff --git a/origin/cloudflared.log b/origin/cloudflared.log new file mode 100644 index 00000000..e69de29b diff --git a/origin/proxy.go b/origin/proxy.go index 247aa003..852355ee 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "io" - "net" "net/http" "strconv" "strings" @@ -15,7 +14,6 @@ import ( "github.com/cloudflare/cloudflared/ingress" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/websocket" - "github.com/pkg/errors" "github.com/rs/zerolog" ) @@ -24,15 +22,15 @@ const ( TagHeaderNamePrefix = "Cf-Warp-Tag-" ) -type client struct { +type proxy struct { ingressRules ingress.Ingress tags []tunnelpogs.Tag log *zerolog.Logger bufferPool *buffer.Pool } -func NewClient(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, log *zerolog.Logger) connection.OriginClient { - return &client{ +func NewOriginProxy(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, log *zerolog.Logger) connection.OriginProxy { + return &proxy{ ingressRules: ingressRules, tags: tags, log: log, @@ -40,36 +38,55 @@ func NewClient(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, log *zerolog } } -func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error { +func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error { incrementRequests() defer decrementConcurrentRequests() cfRay := findCfRayHeader(req) lbProbe := isLBProbeRequest(req) - c.appendTagHeaders(req) - rule, ruleNum := c.ingressRules.FindMatchingRule(req.Host, req.URL.Path) - c.logRequest(req, cfRay, lbProbe, ruleNum) + p.appendTagHeaders(req) + rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path) + p.logRequest(req, cfRay, lbProbe, ruleNum) var ( resp *http.Response err error ) + if isWebsocket { - resp, err = c.proxyWebsocket(w, req, rule) + go websocket.NewConn(w, p.log).Pinger(req.Context()) + + connClosedChan := make(chan struct{}) + err = p.proxyConnection(connClosedChan, w, req, rule) + if err == nil { + respHeader := websocket.NewResponseHeader(req) + status := http.StatusSwitchingProtocols + resp = &http.Response{ + Status: http.StatusText(status), + StatusCode: status, + Header: respHeader, + ContentLength: -1, + } + + w.WriteRespHeaders(http.StatusSwitchingProtocols, respHeader) + <-connClosedChan + } } else { - resp, err = c.proxyHTTP(w, req, rule) + resp, err = p.proxyHTTP(w, req, rule) } if err != nil { - c.logRequestError(err, cfRay, ruleNum) + p.logRequestError(err, cfRay, ruleNum) w.WriteErrorResponse() return err } - c.logOriginResponse(resp, cfRay, lbProbe, ruleNum) + + p.logOriginResponse(resp, cfRay, lbProbe, ruleNum) + return nil } -func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) { +func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) { // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate if rule.Config.DisableChunkedEncoding { req.TransferEncoding = []string{"gzip", "deflate"} @@ -87,73 +104,69 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule req.Host = hostHeader } - resp, err := rule.Service.RoundTrip(req) + httpService, ok := rule.Service.(ingress.HTTPOriginProxy) + if !ok { + p.log.Error().Msgf("%s is not a http service", rule.Service) + return nil, fmt.Errorf("Not a http service") + } + + resp, err := httpService.RoundTrip(req) if err != nil { return nil, errors.Wrap(err, "Error proxying request to origin") } defer resp.Body.Close() - err = w.WriteRespHeaders(resp) + err = w.WriteRespHeaders(resp.StatusCode, resp.Header) if err != nil { return nil, errors.Wrap(err, "Error writing response header") } if connection.IsServerSentEvent(resp.Header) { - c.log.Debug().Msg("Detected Server-Side Events from Origin") - c.writeEventStream(w, resp.Body) + p.log.Debug().Msg("Detected Server-Side Events from Origin") + p.writeEventStream(w, resp.Body) } else { // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream // compression generates dictionary on first write - buf := c.bufferPool.Get() - defer c.bufferPool.Put(buf) + buf := p.bufferPool.Get() + defer p.bufferPool.Put(buf) _, _ = io.CopyBuffer(w, resp.Body, buf) } return resp, nil } -func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) { +func (p *proxy) proxyConnection(connClosedChan chan struct{}, + conn io.ReadWriter, req *http.Request, rule *ingress.Rule) error { if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { req.Header.Set("Host", hostHeader) req.Host = hostHeader } - dialler, ok := rule.Service.(websocket.Dialler) + connectionService, ok := rule.Service.(ingress.StreamBasedOriginProxy) if !ok { - return nil, fmt.Errorf("Websockets aren't supported by the origin service '%s'", rule.Service) + p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service) + return fmt.Errorf("Not a connection-oriented service") } - conn, resp, err := websocket.ClientConnect(req, dialler) + originConn, err := connectionService.EstablishConnection(req) if err != nil { - return nil, err + return err } serveCtx, cancel := context.WithCancel(req.Context()) - connClosedChan := make(chan struct{}) go func() { // serveCtx is done if req is cancelled, or streamWebsocket returns <-serveCtx.Done() - _ = conn.Close() + originConn.Close() close(connClosedChan) }() - // Copy to/from stream to the undelying connection. Use the underlying - // connection because cloudflared doesn't operate on the message themselves - err = c.streamWebsocket(w, conn.UnderlyingConn(), resp) - cancel() + go func() { + originConn.Stream(conn) + cancel() + }() - // We need to make sure conn is closed before returning, otherwise we might write to conn after Proxy returns - <-connClosedChan - return resp, err -} - -func (c *client) streamWebsocket(w connection.ResponseWriter, conn net.Conn, resp *http.Response) error { - err := w.WriteRespHeaders(resp) - if err != nil { - return errors.Wrap(err, "Error writing websocket response header") - } - websocket.Stream(conn, w) return nil } -func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { +func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { reader := bufio.NewReader(respBody) for { line, err := reader.ReadBytes('\n') @@ -164,54 +177,54 @@ func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadC } } -func (c *client) appendTagHeaders(r *http.Request) { - for _, tag := range c.tags { +func (p *proxy) appendTagHeaders(r *http.Request) { + for _, tag := range p.tags { r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value) } } -func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool, ruleNum int) { +func (p *proxy) logRequest(r *http.Request, cfRay string, lbProbe bool, ruleNum int) { if cfRay != "" { - c.log.Debug().Msgf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) + p.log.Debug().Msgf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) } else if lbProbe { - c.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto) + p.log.Debug().Msgf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto) } else { - c.log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto) + p.log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto) } - c.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", cfRay, r.Header) - c.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %d", cfRay, ruleNum) + p.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", cfRay, r.Header) + p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %d", cfRay, ruleNum) if contentLen := r.ContentLength; contentLen == -1 { - c.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", cfRay) + p.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", cfRay) } else { - c.log.Debug().Msgf("CF-RAY: %s Request content length %d", cfRay, contentLen) + p.log.Debug().Msgf("CF-RAY: %s Request content length %d", cfRay, contentLen) } } -func (c *client) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, ruleNum int) { +func (p *proxy) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, ruleNum int) { responseByCode.WithLabelValues(strconv.Itoa(r.StatusCode)).Inc() if cfRay != "" { - c.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, ruleNum) + p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, ruleNum) } else if lbProbe { - c.log.Debug().Msgf("Response to Load Balancer health check %s", r.Status) + p.log.Debug().Msgf("Response to Load Balancer health check %s", r.Status) } else { - c.log.Debug().Msgf("Status: %s served by ingress %d", r.Status, ruleNum) + p.log.Debug().Msgf("Status: %s served by ingress %d", r.Status, ruleNum) } - c.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) + p.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) if contentLen := r.ContentLength; contentLen == -1 { - c.log.Debug().Msgf("CF-RAY: %s Response content length unknown", cfRay) + p.log.Debug().Msgf("CF-RAY: %s Response content length unknown", cfRay) } else { - c.log.Debug().Msgf("CF-RAY: %s Response content length %d", cfRay, contentLen) + p.log.Debug().Msgf("CF-RAY: %s Response content length %d", cfRay, contentLen) } } -func (c *client) logRequestError(err error, cfRay string, ruleNum int) { +func (p *proxy) logRequestError(err error, cfRay string, ruleNum int) { requestErrors.Inc() if cfRay != "" { - c.log.Error().Msgf("CF-RAY: %s Proxying to ingress %d error: %v", cfRay, ruleNum, err) + p.log.Error().Msgf("CF-RAY: %s Proxying to ingress %d error: %v", cfRay, ruleNum, err) } else { - c.log.Error().Msgf("Proxying to ingress %d error: %v", ruleNum, err) + p.log.Error().Msgf("Proxying to ingress %d error: %v", ruleNum, err) } } diff --git a/origin/proxy_test.go b/origin/proxy_test.go index 193764f0..f1590255 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -5,7 +5,9 @@ import ( "context" "flag" "fmt" + "github.com/cloudflare/cloudflared/logger" "io" + "net" "net/http" "net/http/httptest" "sync" @@ -14,9 +16,11 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/ingress" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/websocket" "github.com/urfave/cli/v2" "github.com/gobwas/ws/wsutil" @@ -39,9 +43,9 @@ func newMockHTTPRespWriter() *mockHTTPRespWriter { } } -func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error { - w.WriteHeader(resp.StatusCode) - for header, val := range resp.Header { +func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) error { + w.WriteHeader(status) + for header, val := range header { w.Header()[header] = val } return nil @@ -125,28 +129,28 @@ func TestProxySingleOrigin(t *testing.T) { errC := make(chan error) require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC)) - client := NewClient(ingressRule, testTags, &log) - t.Run("testProxyHTTP", testProxyHTTP(t, client)) - t.Run("testProxyWebsocket", testProxyWebsocket(t, client)) - t.Run("testProxySSE", testProxySSE(t, client)) + proxy := NewOriginProxy(ingressRule, testTags, &log) + t.Run("testProxyHTTP", testProxyHTTP(t, proxy)) + t.Run("testProxyWebsocket", testProxyWebsocket(t, proxy)) + t.Run("testProxySSE", testProxySSE(t, proxy)) cancel() wg.Wait() } -func testProxyHTTP(t *testing.T, client connection.OriginClient) func(t *testing.T) { +func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) { return func(t *testing.T) { respWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) require.NoError(t, err) - err = client.Proxy(respWriter, req, false) + err = proxy.Proxy(respWriter, req, false) require.NoError(t, err) assert.Equal(t, http.StatusOK, respWriter.Code) } } -func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *testing.T) { +func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) { return func(t *testing.T) { // WSRoute is a websocket echo handler ctx, cancel := context.WithCancel(context.Background()) @@ -159,7 +163,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *te wg.Add(1) go func() { defer wg.Done() - err = client.Proxy(respWriter, req, true) + err = proxy.Proxy(respWriter, req, true) require.NoError(t, err) require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code) @@ -169,7 +173,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *te err = wsutil.WriteClientText(writePipe, msg) require.NoError(t, err) - // ReadServerText reads next data message from rw, considering that caller represents client side. + // ReadServerText reads next data message from rw, considering that caller represents proxy side. returnedMsg, err := wsutil.ReadServerText(respWriter.respBody()) require.NoError(t, err) require.Equal(t, msg, returnedMsg) @@ -186,7 +190,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *te } } -func testProxySSE(t *testing.T, client connection.OriginClient) func(t *testing.T) { +func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) { return func(t *testing.T) { var ( pushCount = 50 @@ -201,7 +205,7 @@ func testProxySSE(t *testing.T, client connection.OriginClient) func(t *testing. wg.Add(1) go func() { defer wg.Done() - err = client.Proxy(respWriter, req, false) + err = proxy.Proxy(respWriter, req, false) require.NoError(t, err) require.Equal(t, http.StatusOK, respWriter.Code) @@ -258,7 +262,7 @@ func TestProxyMultipleOrigins(t *testing.T) { var wg sync.WaitGroup require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC)) - client := NewClient(ingress, testTags, &log) + proxy := NewOriginProxy(ingress, testTags, &log) tests := []struct { url string @@ -294,7 +298,7 @@ func TestProxyMultipleOrigins(t *testing.T) { req, err := http.NewRequest(http.MethodGet, test.url, nil) require.NoError(t, err) - err = client.Proxy(respWriter, req, false) + err = proxy.Proxy(respWriter, req, false) require.NoError(t, err) assert.Equal(t, test.expectedStatus, respWriter.Code) @@ -327,7 +331,7 @@ func TestProxyError(t *testing.T) { { Hostname: "*", Path: nil, - Service: ingress.MockOriginService{ + Service: ingress.MockOriginHTTPService{ Transport: errorOriginTransport{}, }, }, @@ -336,14 +340,85 @@ func TestProxyError(t *testing.T) { log := zerolog.Nop() - client := NewClient(ingress, testTags, &log) + proxy := NewOriginProxy(ingress, testTags, &log) respWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) assert.NoError(t, err) - err = client.Proxy(respWriter, req, false) + err = proxy.Proxy(respWriter, req, false) assert.Error(t, err) assert.Equal(t, http.StatusBadGateway, respWriter.Code) assert.Equal(t, "http response error", respWriter.Body.String()) } + +func TestProxyBastionMode(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError) + flagSet.Bool("bastion", true, "") + + cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil) + err := cliCtx.Set(config.BastionFlag, "true") + require.NoError(t, err) + + allowURLFromArgs := false + ingressRule, err := ingress.NewSingleOrigin(cliCtx, allowURLFromArgs) + require.NoError(t, err) + + var wg sync.WaitGroup + errC := make(chan error) + + log := logger.Create(nil) + + ingressRule.StartOrigins(&wg, log, ctx.Done(), errC) + + proxy := NewOriginProxy(ingressRule, testTags, log) + + t.Run("testBastionWebsocket", testBastionWebsocket(proxy)) + cancel() + +} + +func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) { + return func(t *testing.T) { + // WSRoute is a websocket echo handler + ctx, cancel := context.WithCancel(context.Background()) + readPipe, _ := io.Pipe() + respWriter := newMockWSRespWriter(readPipe) + + var wg sync.WaitGroup + msgFromConn := []byte("data from websocket proxy") + ln, err := net.Listen("tcp", "127.0.0.1:0") + wg.Add(1) + go func() { + defer wg.Done() + defer ln.Close() + server, err := ln.Accept() + require.NoError(t, err) + conn := websocket.NewConn(server, nil) + conn.Write(msgFromConn) + }() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://dummy", nil) + req.Header.Set(h2mux.CFJumpDestinationHeader, ln.Addr().String()) + + wg.Add(1) + go func() { + defer wg.Done() + err = proxy.Proxy(respWriter, req, true) + require.NoError(t, err) + + require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code) + }() + + // ReadServerText reads next data message from rw, considering that caller represents proxy side. + returnedMsg, err := wsutil.ReadServerText(respWriter.respBody()) + if err != io.EOF { + require.NoError(t, err) + require.Equal(t, msgFromConn, returnedMsg) + } + + cancel() + wg.Wait() + } +} diff --git a/socks/request_handler.go b/socks/request_handler.go index fdb5a9fd..f94e0c1d 100644 --- a/socks/request_handler.go +++ b/socks/request_handler.go @@ -3,6 +3,7 @@ package socks import ( "fmt" "io" + "net" "strings" ) @@ -104,3 +105,11 @@ func (h *StandardRequestHandler) handleAssociate(conn io.ReadWriter, req *Reques } return nil } + +func StreamHandler(tunnelConn io.ReadWriter, originConn net.Conn) { + dialer := NewConnDialer(originConn) + requestHandler := NewRequestHandler(dialer) + socksServer := NewConnectionHandler(requestHandler) + + socksServer.Serve(tunnelConn) +} diff --git a/websocket/connection.go b/websocket/connection.go new file mode 100644 index 00000000..9616764e --- /dev/null +++ b/websocket/connection.go @@ -0,0 +1,118 @@ +package websocket + +import ( + "context" + "github.com/rs/zerolog" + "io" + "time" + + gobwas "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" + "github.com/gorilla/websocket" +) + +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 +) + +// GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter +// This is still used by access carrier +type GorillaConn struct { + *websocket.Conn + log *zerolog.Logger +} + +// Read will read messages from the websocket connection +func (c *GorillaConn) Read(p []byte) (int, error) { + _, message, err := c.Conn.ReadMessage() + if err != nil { + return 0, err + } + + return copy(p, message), nil + +} + +// Write will write messages to the websocket connection +func (c *GorillaConn) Write(p []byte) (int, error) { + if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil { + return 0, err + } + + return len(p), nil +} + +// pinger simulates the websocket connection to keep it alive +func (c *GorillaConn) pinger(ctx context.Context) { + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil { + c.log.Debug().Msgf("failed to send ping message: %s", err) + } + case <-ctx.Done(): + return + } + } +} + +type Conn struct { + rw io.ReadWriter + log *zerolog.Logger +} + +func NewConn(rw io.ReadWriter, log *zerolog.Logger) *Conn { + return &Conn{ + rw: rw, + log: log, + } +} + +// Read will read messages from the websocket connection +func (c *Conn) Read(reader []byte) (int, error) { + data, err := wsutil.ReadClientBinary(c.rw) + if err != nil { + return 0, err + } + return copy(reader, data), nil +} + +// Write will write messages to the websocket connection +func (c *Conn) Write(p []byte) (int, error) { + if err := wsutil.WriteServerBinary(c.rw, p); err != nil { + return 0, err + } + + return len(p), nil +} + +func (c *Conn) Pinger(ctx context.Context) { + pongMessge := wsutil.Message{ + OpCode: gobwas.OpPong, + Payload: []byte{}, + } + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil { + c.log.Err(err).Msgf("failed to write ping message") + } + if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil { + c.log.Err(err).Msgf("failed to write pong message") + } + case <-ctx.Done(): + return + } + } +} diff --git a/websocket/websocket.go b/websocket/websocket.go index 582db8f9..9807b6fe 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -2,7 +2,6 @@ package websocket import ( "crypto/sha1" - "crypto/tls" "encoding/base64" "io" "net" @@ -16,17 +15,6 @@ import ( "github.com/rs/zerolog" ) -const ( - // Time allowed to write a message to the peer. - writeWait = 10 * time.Second - - // Time allowed to read the next pong message from the peer. - pongWait = 60 * time.Second - - // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = (pongWait * 9) / 10 -) - var stripWebsocketHeaders = []string{ "Upgrade", "Connection", @@ -35,70 +23,28 @@ var stripWebsocketHeaders = []string{ "Sec-Websocket-Extensions", } -// Conn is a wrapper around the standard gorilla websocket -// but implements a ReadWriter -type Conn struct { - *websocket.Conn -} - -// Read will read messages from the websocket connection -func (c *Conn) Read(p []byte) (int, error) { - _, message, err := c.Conn.ReadMessage() - if err != nil { - return 0, err - } - - return copy(p, message), nil - -} - -// Write will write messages to the websocket connection -func (c *Conn) Write(p []byte) (int, error) { - if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil { - return 0, err - } - - return len(p), nil -} - // IsWebSocketUpgrade checks to see if the request is a WebSocket connection. func IsWebSocketUpgrade(req *http.Request) bool { return websocket.IsWebSocketUpgrade(req) } -// Dialler is something that can proxy websocket requests. -type Dialler interface { - Dial(url *url.URL, headers http.Header) (*websocket.Conn, *http.Response, error) -} - -type defaultDialler struct { - tlsConfig *tls.Config -} - -func (dd *defaultDialler) Dial(url *url.URL, header http.Header) (*websocket.Conn, *http.Response, error) { - d := &websocket.Dialer{ - TLSClientConfig: dd.tlsConfig, - Proxy: http.ProxyFromEnvironment, - } - return d.Dial(url.String(), header) -} - // ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing // the connection. The response body may not contain the entire response and does // not need to be closed by the application. -func ClientConnect(req *http.Request, dialler Dialler) (*websocket.Conn, *http.Response, error) { +func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) { req.URL.Scheme = ChangeRequestScheme(req.URL) wsHeaders := websocketHeaders(req) - if dialler == nil { - dialler = new(defaultDialler) + dialler = &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + } } - conn, response, err := dialler.Dial(req.URL, wsHeaders) + conn, response, err := dialler.Dial(req.URL.String(), wsHeaders) if err != nil { return nil, response, err } response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req)) - return conn, response, err + return conn, response, nil } // Stream copies copy data to & from provided io.ReadWriters. @@ -121,8 +67,8 @@ func Stream(conn, backendConn io.ReadWriter) { // DefaultStreamHandler is provided to the the standard websocket to origin stream // This exist to allow SOCKS to deframe data before it gets to the origin -func DefaultStreamHandler(wsConn *Conn, remoteConn net.Conn, _ http.Header) { - Stream(wsConn, remoteConn) +func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn) { + Stream(originConn, remoteConn) } // StartProxyServer will start a websocket server that will decode @@ -132,7 +78,7 @@ func StartProxyServer( listener net.Listener, staticHost string, shutdownC <-chan struct{}, - streamHandler func(wsConn *Conn, remoteConn net.Conn, requestHeaders http.Header), + streamHandler func(originConn io.ReadWriter, remoteConn net.Conn), ) error { upgrader := websocket.Upgrader{ ReadBufferSize: 1024, @@ -159,7 +105,7 @@ type handler struct { log *zerolog.Logger staticHost string upgrader websocket.Upgrader - streamHandler func(wsConn *Conn, remoteConn net.Conn, requestHeaders http.Header) + streamHandler func(originConn io.ReadWriter, remoteConn net.Conn) } func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -192,14 +138,20 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } _ = conn.SetReadDeadline(time.Now().Add(pongWait)) conn.SetPongHandler(func(string) error { _ = conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) - done := make(chan struct{}) - go pinger(h.log, conn, done) - defer func() { - done <- struct{}{} - _ = conn.Close() - }() + gorillaConn := &GorillaConn{conn, h.log} + go gorillaConn.pinger(r.Context()) + defer conn.Close() - h.streamHandler(&Conn{conn}, stream, r.Header) + h.streamHandler(gorillaConn, stream) +} + +// NewResponseHeader returns headers needed to return to origin for completing handshake +func NewResponseHeader(req *http.Request) http.Header { + header := http.Header{} + header.Add("Connection", "Upgrade") + header.Add("Sec-Websocket-Accept", generateAcceptKey(req)) + header.Add("Upgrade", "websocket") + return header } // the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key, @@ -246,19 +198,3 @@ func ChangeRequestScheme(reqURL *url.URL) string { return reqURL.Scheme } } - -// pinger simulates the websocket connection to keep it alive -func pinger(logger *zerolog.Logger, ws *websocket.Conn, done chan struct{}) { - ticker := time.NewTicker(pingPeriod) - defer ticker.Stop() - for { - select { - case <-ticker.C: - if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil { - logger.Debug().Msgf("failed to send ping message: %s", err) - } - case <-done: - return - } - } -} diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go index 5b57b6c2..179098d6 100644 --- a/websocket/websocket_test.go +++ b/websocket/websocket_test.go @@ -11,7 +11,7 @@ import ( "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/tlsconfig" - + gws "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "golang.org/x/net/websocket" ) @@ -78,7 +78,7 @@ func TestServe(t *testing.T) { tlsConfig := websocketClientTLSConfig(t) assert.NotNil(t, tlsConfig) - d := defaultDialler{tlsConfig: tlsConfig} + d := gws.Dialer{TLSClientConfig: tlsConfig} conn, resp, err := ClientConnect(req, &d) assert.NoError(t, err) assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))