diff --git a/carrier/carrier.go b/carrier/carrier.go index 4c90c1e3..1e423b77 100644 --- a/carrier/carrier.go +++ b/carrier/carrier.go @@ -5,20 +5,25 @@ package carrier import ( "crypto/tls" + "fmt" "io" "net" "net/http" + "net/url" "os" "strings" "github.com/pkg/errors" "github.com/rs/zerolog" - "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/token" ) -const LogFieldOriginURL = "originURL" +const ( + LogFieldOriginURL = "originURL" + CFAccessTokenHeader = "Cf-Access-Token" + cfJumpDestinationHeader = "Cf-Access-Jump-Destination" +) type StartOptions struct { AppInfo *token.AppInfo @@ -32,15 +37,11 @@ type StartOptions struct { type Connection interface { // ServeStream is used to forward data from the client to the edge ServeStream(*StartOptions, io.ReadWriter) error - - // StartServer is used to listen for incoming connections from the edge to the origin - StartServer(net.Listener, string, <-chan struct{}) error } // StdinoutStream is empty struct for wrapping stdin/stdout // into a single ReadWriter -type StdinoutStream struct { -} +type StdinoutStream struct{} // Read will read from Stdin func (c *StdinoutStream) Read(p []byte) (int, error) { @@ -149,7 +150,7 @@ func BuildAccessRequest(options *StartOptions, log *zerolog.Logger) (*http.Reque if err != nil { return nil, err } - originRequest.Header.Set(h2mux.CFAccessTokenHeader, token) + originRequest.Header.Set(CFAccessTokenHeader, token) for k, v := range options.Headers { if len(v) >= 1 { @@ -159,3 +160,26 @@ func BuildAccessRequest(options *StartOptions, log *zerolog.Logger) (*http.Reque return originRequest, nil } + +func SetBastionDest(header http.Header, destination string) { + if destination != "" { + header.Set(cfJumpDestinationHeader, destination) + } +} + +func ResolveBastionDest(r *http.Request) (string, error) { + jumpDestination := r.Header.Get(cfJumpDestinationHeader) + if jumpDestination == "" { + return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side") + } + // Strip scheme and path set by client. Without a scheme + // Parsing a hostname and path without scheme might not return an error due to parsing ambiguities + if jumpURL, err := url.Parse(jumpDestination); err == nil && jumpURL.Host != "" { + return removePath(jumpURL.Host), nil + } + return removePath(jumpDestination), nil +} + +func removePath(dest string) string { + return strings.SplitN(dest, "/", 2)[0] +} diff --git a/carrier/carrier_test.go b/carrier/carrier_test.go index 7315873c..84738e7f 100644 --- a/carrier/carrier_test.go +++ b/carrier/carrier_test.go @@ -156,3 +156,99 @@ func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request { return req } + +func TestBastionDestination(t *testing.T) { + tests := []struct { + name string + header http.Header + expectedDest string + wantErr bool + }{ + { + name: "hostname destination", + header: http.Header{ + cfJumpDestinationHeader: []string{"localhost"}, + }, + expectedDest: "localhost", + }, + { + name: "hostname destination with port", + header: http.Header{ + cfJumpDestinationHeader: []string{"localhost:9000"}, + }, + expectedDest: "localhost:9000", + }, + { + name: "hostname destination with scheme and port", + header: http.Header{ + cfJumpDestinationHeader: []string{"ssh://localhost:9000"}, + }, + expectedDest: "localhost:9000", + }, + { + name: "full hostname url", + header: http.Header{ + cfJumpDestinationHeader: []string{"ssh://localhost:9000/metrics"}, + }, + expectedDest: "localhost:9000", + }, + { + name: "hostname destination with port and path", + header: http.Header{ + cfJumpDestinationHeader: []string{"localhost:9000/metrics"}, + }, + expectedDest: "localhost:9000", + }, + { + name: "ip destination", + header: http.Header{ + cfJumpDestinationHeader: []string{"127.0.0.1"}, + }, + expectedDest: "127.0.0.1", + }, + { + name: "ip destination with port", + header: http.Header{ + cfJumpDestinationHeader: []string{"127.0.0.1:9000"}, + }, + expectedDest: "127.0.0.1:9000", + }, + { + name: "ip destination with port and path", + header: http.Header{ + cfJumpDestinationHeader: []string{"127.0.0.1:9000/metrics"}, + }, + expectedDest: "127.0.0.1:9000", + }, + { + name: "ip destination with schem and port", + header: http.Header{ + cfJumpDestinationHeader: []string{"tcp://127.0.0.1:9000"}, + }, + expectedDest: "127.0.0.1:9000", + }, + { + name: "full ip url", + header: http.Header{ + cfJumpDestinationHeader: []string{"ssh://127.0.0.1:9000/metrics"}, + }, + expectedDest: "127.0.0.1:9000", + }, + { + name: "no destination", + wantErr: true, + }, + } + for _, test := range tests { + r := &http.Request{ + Header: test.header, + } + dest, err := ResolveBastionDest(r) + if test.wantErr { + assert.Error(t, err, "Test %s expects error", test.name) + } else { + assert.NoError(t, err, "Test %s expects no error, got error %v", test.name, err) + assert.Equal(t, test.expectedDest, dest, "Test %s expect dest %s, got %s", test.name, test.expectedDest, dest) + } + } +} diff --git a/carrier/websocket.go b/carrier/websocket.go index 99313451..6ae4252b 100644 --- a/carrier/websocket.go +++ b/carrier/websocket.go @@ -1,17 +1,13 @@ package carrier import ( - "fmt" "io" - "net" "net/http" "net/http/httputil" "github.com/gorilla/websocket" "github.com/rs/zerolog" - "github.com/cloudflare/cloudflared/ingress" - "github.com/cloudflare/cloudflared/socks" "github.com/cloudflare/cloudflared/token" cfwebsocket "github.com/cloudflare/cloudflared/websocket" ) @@ -23,20 +19,6 @@ type Websocket struct { isSocks bool } -type wsdialer struct { - conn *cfwebsocket.GorillaConn -} - -func (d *wsdialer) Dial(address string) (io.ReadWriteCloser, *socks.AddrSpec, error) { - local, ok := d.conn.LocalAddr().(*net.TCPAddr) - if !ok { - return nil, nil, fmt.Errorf("not a tcp connection") - } - - addr := socks.AddrSpec{IP: local.IP, Port: local.Port} - return d.conn, &addr, nil -} - // NewWSConnection returns a new connection object func NewWSConnection(log *zerolog.Logger) Connection { return &Websocket{ @@ -54,16 +36,10 @@ func (ws *Websocket) ServeStream(options *StartOptions, conn io.ReadWriter) erro } defer wsConn.Close() - ingress.Stream(wsConn, conn, ws.log) + cfwebsocket.Stream(wsConn, conn, ws.log) return nil } -// StartServer creates a Websocket server to listen for connections. -// This is used on the origin (tunnel) side to take data from the muxer and send it to the origin -func (ws *Websocket) StartServer(listener net.Listener, remote string, shutdownC <-chan struct{}) error { - return cfwebsocket.StartProxyServer(ws.log, listener, remote, shutdownC, ingress.DefaultStreamHandler) -} - // createWebsocketStream will create a WebSocket connection to stream data over // It also handles redirects from Access and will present that flow if // the token is not present on the request diff --git a/cmd/cloudflared/access/carrier.go b/cmd/cloudflared/access/carrier.go index fc169818..ead0dedf 100644 --- a/cmd/cloudflared/access/carrier.go +++ b/cmd/cloudflared/access/carrier.go @@ -6,19 +6,20 @@ import ( "net/http" "strings" - "github.com/cloudflare/cloudflared/carrier" - "github.com/cloudflare/cloudflared/config" - "github.com/cloudflare/cloudflared/h2mux" - "github.com/cloudflare/cloudflared/logger" - "github.com/cloudflare/cloudflared/validation" - "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/urfave/cli/v2" + + "github.com/cloudflare/cloudflared/carrier" + "github.com/cloudflare/cloudflared/config" + "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/validation" ) const ( - LogFieldHost = "host" + LogFieldHost = "host" + cfAccessClientIDHeader = "Cf-Access-Client-Id" + cfAccessClientSecretHeader = "Cf-Access-Client-Secret" ) // StartForwarder starts a client side websocket forward @@ -31,16 +32,14 @@ func StartForwarder(forwarder config.Forwarder, shutdown <-chan struct{}, log *z // get the headers from the config file and add to the request headers := make(http.Header) if forwarder.TokenClientID != "" { - headers.Set(h2mux.CFAccessClientIDHeader, forwarder.TokenClientID) + headers.Set(cfAccessClientIDHeader, forwarder.TokenClientID) } if forwarder.TokenSecret != "" { - headers.Set(h2mux.CFAccessClientSecretHeader, forwarder.TokenSecret) + headers.Set(cfAccessClientSecretHeader, forwarder.TokenSecret) } - if forwarder.Destination != "" { - headers.Add(h2mux.CFJumpDestinationHeader, forwarder.Destination) - } + carrier.SetBastionDest(headers, forwarder.Destination) options := &carrier.StartOptions{ OriginURL: forwarder.URL, @@ -72,16 +71,13 @@ func ssh(c *cli.Context) error { // get the headers from the cmdline and add them headers := buildRequestHeaders(c.StringSlice(sshHeaderFlag)) if c.IsSet(sshTokenIDFlag) { - headers.Set(h2mux.CFAccessClientIDHeader, c.String(sshTokenIDFlag)) + headers.Set(cfAccessClientIDHeader, c.String(sshTokenIDFlag)) } if c.IsSet(sshTokenSecretFlag) { - headers.Set(h2mux.CFAccessClientSecretHeader, c.String(sshTokenSecretFlag)) + headers.Set(cfAccessClientSecretHeader, c.String(sshTokenSecretFlag)) } - destination := c.String(sshDestinationFlag) - if destination != "" { - headers.Add(h2mux.CFJumpDestinationHeader, destination) - } + carrier.SetBastionDest(headers, c.String(sshDestinationFlag)) options := &carrier.StartOptions{ OriginURL: originURL, diff --git a/cmd/cloudflared/access/cmd.go b/cmd/cloudflared/access/cmd.go index 1cebbede..414e008e 100644 --- a/cmd/cloudflared/access/cmd.go +++ b/cmd/cloudflared/access/cmd.go @@ -19,7 +19,6 @@ import ( "github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" - "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/sshgen" "github.com/cloudflare/cloudflared/token" @@ -286,7 +285,7 @@ func curl(c *cli.Context) error { } cmdArgs = append(cmdArgs, "-H") - cmdArgs = append(cmdArgs, fmt.Sprintf("%s: %s", h2mux.CFAccessTokenHeader, tok)) + cmdArgs = append(cmdArgs, fmt.Sprintf("%s: %s", carrier.CFAccessTokenHeader, tok)) return run("curl", cmdArgs...) } @@ -472,10 +471,10 @@ func isFileThere(candidate string) bool { func verifyTokenAtEdge(appUrl *url.URL, appInfo *token.AppInfo, c *cli.Context, log *zerolog.Logger) error { headers := buildRequestHeaders(c.StringSlice(sshHeaderFlag)) if c.IsSet(sshTokenIDFlag) { - headers.Add(h2mux.CFAccessClientIDHeader, c.String(sshTokenIDFlag)) + headers.Add(cfAccessClientIDHeader, c.String(sshTokenIDFlag)) } if c.IsSet(sshTokenSecretFlag) { - headers.Add(h2mux.CFAccessClientSecretHeader, c.String(sshTokenSecretFlag)) + headers.Add(cfAccessClientSecretHeader, c.String(sshTokenSecretFlag)) } options := &carrier.StartOptions{AppInfo: appInfo, OriginURL: appUrl.String(), Headers: headers} diff --git a/connection/h2mux.go b/connection/h2mux.go index 4941b1d6..b97208d3 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -234,7 +234,7 @@ func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, if err != nil { return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") } - err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req) + err = H2RequestHeadersToH1Request(stream.Headers, req) if err != nil { return nil, errors.Wrap(err, "invalid request received") } @@ -246,15 +246,15 @@ type h2muxRespWriter struct { } func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error { - headers := h2mux.H1ResponseToH2ResponseHeaders(status, header) - headers = append(headers, h2mux.Header{Name: ResponseMetaHeaderField, Value: responseMetaHeaderOrigin}) + headers := H1ResponseToH2ResponseHeaders(status, header) + headers = append(headers, h2mux.Header{Name: ResponseMetaHeader, Value: responseMetaHeaderOrigin}) return rp.WriteHeaders(headers) } func (rp *h2muxRespWriter) WriteErrorResponse() { _ = rp.WriteHeaders([]h2mux.Header{ {Name: ":status", Value: "502"}, - {Name: ResponseMetaHeaderField, Value: responseMetaHeaderCfd}, + {Name: ResponseMetaHeader, Value: responseMetaHeaderCfd}, }) _, _ = rp.Write([]byte("502 Bad Gateway")) } diff --git a/connection/h2mux_test.go b/connection/h2mux_test.go index cdbd2a66..e6eab072 100644 --- a/connection/h2mux_test.go +++ b/connection/h2mux_test.go @@ -115,9 +115,9 @@ func TestServeStreamHTTP(t *testing.T) { require.True(t, hasHeader(stream, ":status", strconv.Itoa(test.expectedStatus))) if test.isProxyError { - assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderCfd)) + assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderCfd)) } else { - assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin)) + assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin)) body := make([]byte, len(test.expectedBody)) _, err = stream.Read(body) require.NoError(t, err) @@ -164,7 +164,7 @@ func TestServeStreamWS(t *testing.T) { require.NoError(t, err) require.True(t, hasHeader(stream, ":status", strconv.Itoa(http.StatusSwitchingProtocols))) - assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin)) + assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin)) data := []byte("test websocket") err = wsutil.WriteClientText(writePipe, data) @@ -268,7 +268,7 @@ func benchmarkServeStreamHTTPSimple(b *testing.B, test testRequest) { b.StopTimer() require.NoError(b, openstreamErr) - assert.True(b, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin)) + assert.True(b, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin)) require.True(b, hasHeader(stream, ":status", strconv.Itoa(http.StatusOK))) require.NoError(b, readBodyErr) require.Equal(b, test.expectedBody, body) diff --git a/connection/header.go b/connection/header.go index b9fcd003..c93c1073 100644 --- a/connection/header.go +++ b/connection/header.go @@ -1,21 +1,33 @@ package connection import ( + "encoding/base64" "fmt" "net/http" + "net/url" + "strconv" + "strings" + + "github.com/pkg/errors" "github.com/cloudflare/cloudflared/h2mux" ) -const ( - ResponseMetaHeaderField = "cf-cloudflared-response-meta" +var ( + // h2mux-style special headers + RequestUserHeaders = "cf-cloudflared-request-headers" + ResponseUserHeaders = "cf-cloudflared-response-headers" + ResponseMetaHeader = "cf-cloudflared-response-meta" + + // h2mux-style special headers + CanonicalResponseUserHeaders = http.CanonicalHeaderKey(ResponseUserHeaders) + CanonicalResponseMetaHeader = http.CanonicalHeaderKey(ResponseMetaHeader) ) var ( - canonicalResponseUserHeadersField = http.CanonicalHeaderKey(h2mux.ResponseUserHeadersField) - canonicalResponseMetaHeaderField = http.CanonicalHeaderKey(ResponseMetaHeaderField) - responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared") - responseMetaHeaderOrigin = mustInitRespMetaHeader("origin") + // pre-generate possible values for res + responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared") + responseMetaHeaderOrigin = mustInitRespMetaHeader("origin") ) type responseMetaHeader struct { @@ -29,3 +41,204 @@ func mustInitRespMetaHeader(src string) string { } return string(header) } + +var headerEncoding = base64.RawStdEncoding + +// note: all h2mux headers should be lower-case (http/2 style) +const () + +// H2RequestHeadersToH1Request converts the HTTP/2 headers coming from origintunneld +// to an HTTP/1 Request object destined for the local origin web service. +// This operation includes conversion of the pseudo-headers into their closest +// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3 +func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error { + for _, header := range h2 { + name := strings.ToLower(header.Name) + if !IsControlHeader(name) { + continue + } + + switch name { + case ":method": + h1.Method = header.Value + case ":scheme": + // noop - use the preexisting scheme from h1.URL + case ":authority": + // Otherwise the host header will be based on the origin URL + h1.Host = header.Value + case ":path": + // We don't want to be an "opinionated" proxy, so ideally we would use :path as-is. + // However, this HTTP/1 Request object belongs to the Go standard library, + // whose URL package makes some opinionated decisions about the encoding of + // URL characters: see the docs of https://godoc.org/net/url#URL, + // in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath, + // which is always used when computing url.URL.String(), whether we'd like it or not. + // + // Well, not *always*. We could circumvent this by using url.URL.Opaque. But + // that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque + // is treated differently when HTTP_PROXY is set! + // See https://github.com/golang/go/issues/5684#issuecomment-66080888 + // + // This means we are subject to the behavior of net/url's function `shouldEscape` + // (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101 + + if header.Value == "*" { + h1.URL.Path = "*" + continue + } + // Due to the behavior of validation.ValidateUrl, h1.URL may + // already have a partial value, with or without a trailing slash. + base := h1.URL.String() + base = strings.TrimRight(base, "/") + // But we know :path begins with '/', because we handled '*' above - see RFC7540 + requestURL, err := url.Parse(base + header.Value) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value)) + } + h1.URL = requestURL + case "content-length": + contentLength, err := strconv.ParseInt(header.Value, 10, 64) + if err != nil { + return fmt.Errorf("unparseable content length") + } + h1.ContentLength = contentLength + case RequestUserHeaders: + // Do not forward the serialized headers to the origin -- deserialize them, and ditch the serialized version + // Find and parse user headers serialized into a single one + userHeaders, err := DeserializeHeaders(header.Value) + if err != nil { + return errors.Wrap(err, "Unable to parse user headers") + } + for _, userHeader := range userHeaders { + h1.Header.Add(userHeader.Name, userHeader.Value) + } + default: + // All other control headers shall just be proxied transparently + h1.Header.Add(header.Name, header.Value) + } + } + + return nil +} + +func IsControlHeader(headerName string) bool { + return headerName == "content-length" || + headerName == "connection" || headerName == "upgrade" || // Websocket headers + strings.HasPrefix(headerName, ":") || + strings.HasPrefix(headerName, "cf-") +} + +// isWebsocketClientHeader returns true if the header name is required by the client to upgrade properly +func IsWebsocketClientHeader(headerName string) bool { + return headerName == "sec-websocket-accept" || + headerName == "connection" || + headerName == "upgrade" +} + +func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []h2mux.Header) { + h2 = []h2mux.Header{ + {Name: ":status", Value: strconv.Itoa(status)}, + } + userHeaders := make(http.Header, len(h1)) + for header, values := range h1 { + h2name := strings.ToLower(header) + if h2name == "content-length" { + // This header has meaning in HTTP/2 and will be used by the edge, + // so it should be sent as an HTTP/2 response header. + + // Since these are http2 headers, they're required to be lowercase + h2 = append(h2, h2mux.Header{Name: "content-length", Value: values[0]}) + } else if !IsControlHeader(h2name) || IsWebsocketClientHeader(h2name) { + // User headers, on the other hand, must all be serialized so that + // HTTP/2 header validation won't be applied to HTTP/1 header values + userHeaders[header] = values + } + } + + // Perform user header serialization and set them in the single header + h2 = append(h2, h2mux.Header{Name: ResponseUserHeaders, Value: SerializeHeaders(userHeaders)}) + return h2 +} + +// Serialize HTTP1.x headers by base64-encoding each header name and value, +// and then joining them in the format of [key:value;] +func SerializeHeaders(h1Headers http.Header) string { + // compute size of the fully serialized value and largest temp buffer we will need + serializedLen := 0 + maxTempLen := 0 + for headerName, headerValues := range h1Headers { + for _, headerValue := range headerValues { + nameLen := headerEncoding.EncodedLen(len(headerName)) + valueLen := headerEncoding.EncodedLen(len(headerValue)) + const delims = 2 + serializedLen += delims + nameLen + valueLen + if nameLen > maxTempLen { + maxTempLen = nameLen + } + if valueLen > maxTempLen { + maxTempLen = valueLen + } + } + } + var buf strings.Builder + buf.Grow(serializedLen) + + temp := make([]byte, maxTempLen) + writeB64 := func(s string) { + n := headerEncoding.EncodedLen(len(s)) + if n > len(temp) { + temp = make([]byte, n) + } + headerEncoding.Encode(temp[:n], []byte(s)) + buf.Write(temp[:n]) + } + + for headerName, headerValues := range h1Headers { + for _, headerValue := range headerValues { + if buf.Len() > 0 { + buf.WriteByte(';') + } + writeB64(headerName) + buf.WriteByte(':') + writeB64(headerValue) + } + } + + return buf.String() +} + +// Deserialize headers serialized by `SerializeHeader` +func DeserializeHeaders(serializedHeaders string) ([]h2mux.Header, error) { + const unableToDeserializeErr = "Unable to deserialize headers" + + var deserialized []h2mux.Header + for _, serializedPair := range strings.Split(serializedHeaders, ";") { + if len(serializedPair) == 0 { + continue + } + + serializedHeaderParts := strings.Split(serializedPair, ":") + if len(serializedHeaderParts) != 2 { + return nil, errors.New(unableToDeserializeErr) + } + + serializedName := serializedHeaderParts[0] + serializedValue := serializedHeaderParts[1] + deserializedName := make([]byte, headerEncoding.DecodedLen(len(serializedName))) + deserializedValue := make([]byte, headerEncoding.DecodedLen(len(serializedValue))) + + if _, err := headerEncoding.Decode(deserializedName, []byte(serializedName)); err != nil { + return nil, errors.Wrap(err, unableToDeserializeErr) + } + if _, err := headerEncoding.Decode(deserializedValue, []byte(serializedValue)); err != nil { + return nil, errors.Wrap(err, unableToDeserializeErr) + } + + deserialized = append(deserialized, h2mux.Header{ + Name: string(deserializedName), + Value: string(deserializedValue), + }) + } + + return deserialized, nil +} diff --git a/h2mux/header_test.go b/connection/header_test.go similarity index 92% rename from h2mux/header_test.go rename to connection/header_test.go index e9da95e0..964e7751 100644 --- a/h2mux/header_test.go +++ b/connection/header_test.go @@ -1,4 +1,4 @@ -package h2mux +package connection import ( "fmt" @@ -14,9 +14,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/cloudflare/cloudflared/h2mux" ) -type ByName []Header +type ByName []h2mux.Header func (a ByName) Len() int { return len(a) } func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] } @@ -37,16 +39,16 @@ func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) { "Mock header 2": {"Mock value 2"}, } - headersConversionErr := H2RequestHeadersToH1Request(createSerializedHeaders(RequestUserHeadersField, mockHeaders), request) + headersConversionErr := H2RequestHeadersToH1Request(createSerializedHeaders(RequestUserHeaders, mockHeaders), request) assert.True(t, reflect.DeepEqual(mockHeaders, request.Header)) assert.NoError(t, headersConversionErr) } -func createSerializedHeaders(headersField string, headers http.Header) []Header { - return []Header{{ - headersField, - SerializeHeaders(headers), +func createSerializedHeaders(headersField string, headers http.Header) []h2mux.Header { + return []h2mux.Header{{ + Name: headersField, + Value: SerializeHeaders(headers), }} } @@ -54,15 +56,16 @@ func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) assert.NoError(t, err) + emptyHeaders := make(http.Header) headersConversionErr := H2RequestHeadersToH1Request( - []Header{{ - RequestUserHeadersField, - SerializeHeaders(http.Header{}), + []h2mux.Header{{ + Name: RequestUserHeaders, + Value: SerializeHeaders(emptyHeaders), }}, request, ) - assert.True(t, reflect.DeepEqual(http.Header{}, request.Header)) + assert.True(t, reflect.DeepEqual(emptyHeaders, request.Header)) assert.NoError(t, headersConversionErr) } @@ -70,9 +73,9 @@ func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) assert.NoError(t, err) - mockRequestHeaders := []Header{ + mockRequestHeaders := []h2mux.Header{ {Name: ":path", Value: "//bad_path/"}, - {Name: RequestUserHeadersField, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, + {Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, } headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) @@ -90,9 +93,9 @@ func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) assert.NoError(t, err) - mockRequestHeaders := []Header{ + mockRequestHeaders := []h2mux.Header{ {Name: ":path", Value: "/?query=mock%20value"}, - {Name: RequestUserHeadersField, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, + {Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, } headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) @@ -110,9 +113,9 @@ func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) assert.NoError(t, err) - mockRequestHeaders := []Header{ + mockRequestHeaders := []h2mux.Header{ {Name: ":path", Value: "/mock%20path"}, - {Name: RequestUserHeadersField, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, + {Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, } headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) @@ -267,9 +270,9 @@ func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) { request, err := http.NewRequest(http.MethodGet, requestURL, nil) assert.NoError(t, err) - mockRequestHeaders := []Header{ + mockRequestHeaders := []h2mux.Header{ {Name: ":path", Value: testCase.path}, - {Name: RequestUserHeadersField, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, + {Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, } headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) @@ -337,12 +340,12 @@ func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) { const expectedMethod = "POST" const expectedHostname = "request.hostname.example.com" - h2 := []Header{ + h2 := []h2mux.Header{ {Name: ":method", Value: expectedMethod}, {Name: ":scheme", Value: testScheme}, {Name: ":authority", Value: expectedHostname}, {Name: ":path", Value: testPath}, - {Name: RequestUserHeadersField, Value: ""}, + {Name: RequestUserHeaders, Value: ""}, } h1, err := http.NewRequest("GET", testOrigin.url, nil) require.NoError(t, err) @@ -424,10 +427,10 @@ func randomHTTP2Path(t *testing.T, rand *rand.Rand) string { return result } -func stdlibHeaderToH2muxHeader(headers http.Header) (h2muxHeaders []Header) { +func stdlibHeaderToH2muxHeader(headers http.Header) (h2muxHeaders []h2mux.Header) { for name, values := range headers { for _, value := range values { - h2muxHeaders = append(h2muxHeaders, Header{name, value}) + h2muxHeaders = append(h2muxHeaders, h2mux.Header{Name: name, Value: value}) } } @@ -515,14 +518,14 @@ func TestParseHeaders(t *testing.T) { "Mock-Header-Three": {"3"}, } - mockHeaders := []Header{ + mockHeaders := []h2mux.Header{ {Name: "One", Value: "1"}, // will be dropped {Name: "Cf-Two", Value: "cf-value-1"}, {Name: "Cf-Two", Value: "cf-value-2"}, - {Name: RequestUserHeadersField, Value: SerializeHeaders(mockUserHeadersToSerialize)}, + {Name: RequestUserHeaders, Value: SerializeHeaders(mockUserHeadersToSerialize)}, } - expectedHeaders := []Header{ + expectedHeaders := []h2mux.Header{ {Name: "Cf-Two", Value: "cf-value-1"}, {Name: "Cf-Two", Value: "cf-value-2"}, {Name: "Mock-Header-One", Value: "1"}, @@ -583,7 +586,7 @@ func TestH1ResponseToH2ResponseHeaders(t *testing.T) { serializedHeadersIndex := -1 for i, header := range headers { - if header.Name == ResponseUserHeadersField { + if header.Name == ResponseUserHeaders { serializedHeadersIndex = i break } @@ -593,7 +596,7 @@ func TestH1ResponseToH2ResponseHeaders(t *testing.T) { headers[:serializedHeadersIndex], headers[serializedHeadersIndex+1:]..., ) - expectedControlHeaders := []Header{ + expectedControlHeaders := []h2mux.Header{ {Name: ":status", Value: "200"}, {Name: "content-length", Value: "123"}, } @@ -601,7 +604,7 @@ func TestH1ResponseToH2ResponseHeaders(t *testing.T) { assert.ElementsMatch(t, expectedControlHeaders, actualControlHeaders) actualUserHeaders, err := DeserializeHeaders(headers[serializedHeadersIndex].Value) - expectedUserHeaders := []Header{ + expectedUserHeaders := []h2mux.Header{ {Name: "User-header-one", Value: ""}, {Name: "User-header-two", Value: "1"}, {Name: "User-header-two", Value: "2"}, @@ -630,7 +633,7 @@ func TestHeaderSize(t *testing.T) { } for _, header := range serializedHeaders { - if header.Name != ResponseUserHeadersField { + if header.Name != ResponseUserHeaders { continue } diff --git a/connection/http2.go b/connection/http2.go index 02898088..54f3c667 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -13,15 +13,15 @@ import ( "github.com/rs/zerolog" "golang.org/x/net/http2" - "github.com/cloudflare/cloudflared/h2mux" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) +// note: these constants are exported so we can reuse them in the edge-side code const ( - internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade" - tcpStreamHeader = "Cf-Cloudflared-Proxy-Src" - websocketUpgrade = "websocket" - controlStreamUpgrade = "control-stream" + InternalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade" + InternalTCPProxySrcHeader = "Cf-Cloudflared-Proxy-Src" + WebsocketUpgrade = "websocket" + ControlStreamUpgrade = "control-stream" ) var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed") @@ -178,25 +178,23 @@ func newHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) ( func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error { dest := rp.w.Header() userHeaders := make(http.Header, len(header)) - for header, values := range header { + for name, values := range header { // Since these are http2 headers, they're required to be lowercase - h2name := strings.ToLower(header) - for _, v := range values { - if h2name == "content-length" { - // This header has meaning in HTTP/2 and will be used by the edge, - // so it should be sent as an HTTP/2 response header. - dest.Add(h2name, v) - // Since these are http2 headers, they're required to be lowercase - } else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) { - // User headers, on the other hand, must all be serialized so that - // HTTP/2 header validation won't be applied to HTTP/1 header values - userHeaders.Add(h2name, v) - } + h2name := strings.ToLower(name) + if h2name == "content-length" { + // This header has meaning in HTTP/2 and will be used by the edge, + // so it should be sent as an HTTP/2 response header. + dest[name] = values + // Since these are http2 headers, they're required to be lowercase + } else if !IsControlHeader(h2name) || IsWebsocketClientHeader(h2name) { + // User headers, on the other hand, must all be serialized so that + // HTTP/2 header validation won't be applied to HTTP/1 header values + userHeaders[name] = values } } // Perform user header serialization and set them in the single header - dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders)) + dest.Set(CanonicalResponseUserHeaders, SerializeHeaders(userHeaders)) rp.setResponseMetaHeader(responseMetaHeaderOrigin) // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1 if status == http.StatusSwitchingProtocols { @@ -218,7 +216,7 @@ func (rp *http2RespWriter) WriteErrorResponse() { } func (rp *http2RespWriter) setResponseMetaHeader(value string) { - rp.w.Header().Set(canonicalResponseMetaHeaderField, value) + rp.w.Header().Set(CanonicalResponseMetaHeader, value) } func (rp *http2RespWriter) Read(p []byte) (n int, err error) { @@ -258,18 +256,18 @@ func determineHTTP2Type(r *http.Request) Type { } func isControlStreamUpgrade(r *http.Request) bool { - return r.Header.Get(internalUpgradeHeader) == controlStreamUpgrade + return r.Header.Get(InternalUpgradeHeader) == ControlStreamUpgrade } func isWebsocketUpgrade(r *http.Request) bool { - return r.Header.Get(internalUpgradeHeader) == websocketUpgrade + return r.Header.Get(InternalUpgradeHeader) == WebsocketUpgrade } // IsTCPStream discerns if the connection request needs a tcp stream proxy. func IsTCPStream(r *http.Request) bool { - return r.Header.Get(tcpStreamHeader) != "" + return r.Header.Get(InternalTCPProxySrcHeader) != "" } func stripWebsocketUpgradeHeader(r *http.Request) { - r.Header.Del(internalUpgradeHeader) + r.Header.Del(InternalUpgradeHeader) } diff --git a/connection/http2_test.go b/connection/http2_test.go index f0fe171c..c1b2ba4d 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -103,9 +103,9 @@ func TestServeHTTP(t *testing.T) { require.Equal(t, test.expectedBody, respBody) } if test.isProxyError { - require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeaderField)) + require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader)) } else { - require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeaderField)) + require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader)) } } cancel() @@ -191,7 +191,7 @@ func TestServeWS(t *testing.T) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe) require.NoError(t, err) - req.Header.Set(internalUpgradeHeader, websocketUpgrade) + req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade) wg.Add(1) go func() { @@ -211,7 +211,7 @@ func TestServeWS(t *testing.T) { resp := respWriter.Result() // http2RespWriter should rewrite status 101 to 200 require.Equal(t, http.StatusOK, resp.StatusCode) - require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeaderField)) + require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader)) wg.Wait() } @@ -235,7 +235,7 @@ func TestServeControlStream(t *testing.T) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil) require.NoError(t, err) - req.Header.Set(internalUpgradeHeader, controlStreamUpgrade) + req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade) edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn) require.NoError(t, err) @@ -274,7 +274,7 @@ func TestFailRegistration(t *testing.T) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil) require.NoError(t, err) - req.Header.Set(internalUpgradeHeader, controlStreamUpgrade) + req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade) edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn) require.NoError(t, err) @@ -310,7 +310,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil) require.NoError(t, err) - req.Header.Set(internalUpgradeHeader, controlStreamUpgrade) + req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade) edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn) require.NoError(t, err) diff --git a/h2mux/header.go b/h2mux/header.go deleted file mode 100644 index 848c93b0..00000000 --- a/h2mux/header.go +++ /dev/null @@ -1,234 +0,0 @@ -package h2mux - -import ( - "encoding/base64" - "fmt" - "net/http" - "net/url" - "strconv" - "strings" - - "github.com/pkg/errors" -) - -type Header struct { - Name, Value string -} - -var headerEncoding = base64.RawStdEncoding - -const ( - RequestUserHeadersField = "cf-cloudflared-request-headers" - ResponseUserHeadersField = "cf-cloudflared-response-headers" - - CFAccessTokenHeader = "cf-access-token" - CFJumpDestinationHeader = "CF-Access-Jump-Destination" - CFAccessClientIDHeader = "CF-Access-Client-Id" - CFAccessClientSecretHeader = "CF-Access-Client-Secret" -) - -// H2RequestHeadersToH1Request converts the HTTP/2 headers coming from origintunneld -// to an HTTP/1 Request object destined for the local origin web service. -// This operation includes conversion of the pseudo-headers into their closest -// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3 -func H2RequestHeadersToH1Request(h2 []Header, h1 *http.Request) error { - for _, header := range h2 { - name := strings.ToLower(header.Name) - if !IsControlHeader(name) { - continue - } - - switch name { - case ":method": - h1.Method = header.Value - case ":scheme": - // noop - use the preexisting scheme from h1.URL - case ":authority": - // Otherwise the host header will be based on the origin URL - h1.Host = header.Value - case ":path": - // We don't want to be an "opinionated" proxy, so ideally we would use :path as-is. - // However, this HTTP/1 Request object belongs to the Go standard library, - // whose URL package makes some opinionated decisions about the encoding of - // URL characters: see the docs of https://godoc.org/net/url#URL, - // in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath, - // which is always used when computing url.URL.String(), whether we'd like it or not. - // - // Well, not *always*. We could circumvent this by using url.URL.Opaque. But - // that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque - // is treated differently when HTTP_PROXY is set! - // See https://github.com/golang/go/issues/5684#issuecomment-66080888 - // - // This means we are subject to the behavior of net/url's function `shouldEscape` - // (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101 - - if header.Value == "*" { - h1.URL.Path = "*" - continue - } - // Due to the behavior of validation.ValidateUrl, h1.URL may - // already have a partial value, with or without a trailing slash. - base := h1.URL.String() - base = strings.TrimRight(base, "/") - // But we know :path begins with '/', because we handled '*' above - see RFC7540 - requestURL, err := url.Parse(base + header.Value) - if err != nil { - return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value)) - } - h1.URL = requestURL - case "content-length": - contentLength, err := strconv.ParseInt(header.Value, 10, 64) - if err != nil { - return fmt.Errorf("unparseable content length") - } - h1.ContentLength = contentLength - case RequestUserHeadersField: - // Do not forward the serialized headers to the origin -- deserialize them, and ditch the serialized version - // Find and parse user headers serialized into a single one - userHeaders, err := ParseUserHeaders(RequestUserHeadersField, h2) - if err != nil { - return errors.Wrap(err, "Unable to parse user headers") - } - for _, userHeader := range userHeaders { - h1.Header.Add(http.CanonicalHeaderKey(userHeader.Name), userHeader.Value) - } - default: - // All other control headers shall just be proxied transparently - h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value) - } - } - - return nil -} - -func ParseUserHeaders(headerNameToParseFrom string, headers []Header) ([]Header, error) { - for _, header := range headers { - if header.Name == headerNameToParseFrom { - return DeserializeHeaders(header.Value) - } - } - - return nil, fmt.Errorf("%v header not found", RequestUserHeadersField) -} - -func IsControlHeader(headerName string) bool { - return headerName == "content-length" || - headerName == "connection" || headerName == "upgrade" || // Websocket headers - strings.HasPrefix(headerName, ":") || - strings.HasPrefix(headerName, "cf-") -} - -// isWebsocketClientHeader returns true if the header name is required by the client to upgrade properly -func IsWebsocketClientHeader(headerName string) bool { - return headerName == "sec-websocket-accept" || - headerName == "connection" || - headerName == "upgrade" -} - -func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []Header) { - h2 = []Header{ - {Name: ":status", Value: strconv.Itoa(status)}, - } - userHeaders := make(http.Header, len(h1)) - for header, values := range h1 { - h2name := strings.ToLower(header) - if h2name == "content-length" { - // This header has meaning in HTTP/2 and will be used by the edge, - // so it should be sent as an HTTP/2 response header. - - // Since these are http2 headers, they're required to be lowercase - h2 = append(h2, Header{Name: "content-length", Value: values[0]}) - } else if !IsControlHeader(h2name) || IsWebsocketClientHeader(h2name) { - // User headers, on the other hand, must all be serialized so that - // HTTP/2 header validation won't be applied to HTTP/1 header values - userHeaders[header] = values - } - } - - // Perform user header serialization and set them in the single header - h2 = append(h2, Header{ResponseUserHeadersField, SerializeHeaders(userHeaders)}) - return h2 -} - -// Serialize HTTP1.x headers by base64-encoding each header name and value, -// and then joining them in the format of [key:value;] -func SerializeHeaders(h1Headers http.Header) string { - // compute size of the fully serialized value and largest temp buffer we will need - serializedLen := 0 - maxTempLen := 0 - for headerName, headerValues := range h1Headers { - for _, headerValue := range headerValues { - nameLen := headerEncoding.EncodedLen(len(headerName)) - valueLen := headerEncoding.EncodedLen(len(headerValue)) - const delims = 2 - serializedLen += delims + nameLen + valueLen - if nameLen > maxTempLen { - maxTempLen = nameLen - } - if valueLen > maxTempLen { - maxTempLen = valueLen - } - } - } - var buf strings.Builder - buf.Grow(serializedLen) - - temp := make([]byte, maxTempLen) - writeB64 := func(s string) { - n := headerEncoding.EncodedLen(len(s)) - if n > len(temp) { - temp = make([]byte, n) - } - headerEncoding.Encode(temp[:n], []byte(s)) - buf.Write(temp[:n]) - } - - for headerName, headerValues := range h1Headers { - for _, headerValue := range headerValues { - if buf.Len() > 0 { - buf.WriteByte(';') - } - writeB64(headerName) - buf.WriteByte(':') - writeB64(headerValue) - } - } - - return buf.String() -} - -// Deserialize headers serialized by `SerializeHeader` -func DeserializeHeaders(serializedHeaders string) ([]Header, error) { - const unableToDeserializeErr = "Unable to deserialize headers" - - var deserialized []Header - for _, serializedPair := range strings.Split(serializedHeaders, ";") { - if len(serializedPair) == 0 { - continue - } - - serializedHeaderParts := strings.Split(serializedPair, ":") - if len(serializedHeaderParts) != 2 { - return nil, errors.New(unableToDeserializeErr) - } - - serializedName := serializedHeaderParts[0] - serializedValue := serializedHeaderParts[1] - deserializedName := make([]byte, headerEncoding.DecodedLen(len(serializedName))) - deserializedValue := make([]byte, headerEncoding.DecodedLen(len(serializedValue))) - - if _, err := headerEncoding.Decode(deserializedName, []byte(serializedName)); err != nil { - return nil, errors.Wrap(err, unableToDeserializeErr) - } - if _, err := headerEncoding.Decode(deserializedValue, []byte(serializedValue)); err != nil { - return nil, errors.Wrap(err, unableToDeserializeErr) - } - - deserialized = append(deserialized, Header{ - Name: string(deserializedName), - Value: string(deserializedValue), - }) - } - - return deserialized, nil -} diff --git a/h2mux/muxedstream.go b/h2mux/muxedstream.go index 66b4507e..2e75735f 100644 --- a/h2mux/muxedstream.go +++ b/h2mux/muxedstream.go @@ -23,6 +23,10 @@ type MuxedStreamDataSignaller interface { Signal(ID uint32) } +type Header struct { + Name, Value string +} + // MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data. type MuxedStream struct { streamID uint32 @@ -74,8 +78,6 @@ type MuxedStream struct { sentEOF bool // true if the peer sent us an EOF receivedEOF bool - // If valid, tunnelHostname is used to identify which origin service is the intended recipient of the request - tunnelHostname TunnelHostname // Compression-related fields receivedUseDict bool method string @@ -252,10 +254,6 @@ func (s *MuxedStream) IsRPCStream() bool { return true } -func (s *MuxedStream) TunnelHostname() TunnelHostname { - return s.tunnelHostname -} - // Block until a value is sent on writeBufferHasSpace. // Must be called while holding writeLock func (s *MuxedStream) awaitWriteBufferHasSpace() { diff --git a/h2mux/muxreader.go b/h2mux/muxreader.go index 1f3fab8b..2716508f 100644 --- a/h2mux/muxreader.go +++ b/h2mux/muxreader.go @@ -12,10 +12,6 @@ import ( "golang.org/x/net/http2" ) -const ( - CloudflaredProxyTunnelHostnameHeader = "cf-cloudflared-proxy-tunnel-hostname" -) - type MuxReader struct { // f is used to read HTTP2 frames. f *http2.Framer @@ -252,8 +248,6 @@ func (r *MuxReader) receiveHeaderData(frame *http2.MetaHeadersFrame) error { if r.dictionaries.write != nil { continue } - case CloudflaredProxyTunnelHostnameHeader: - stream.tunnelHostname = TunnelHostname(header.Value) } headers = append(headers, Header{Name: header.Name, Value: header.Value}) } diff --git a/h2mux/muxreader_test.go b/h2mux/muxreader_test.go index 09ff4a4d..10ae7ff8 100644 --- a/h2mux/muxreader_test.go +++ b/h2mux/muxreader_test.go @@ -21,10 +21,6 @@ var ( Name: ":path", Value: "/api/tunnels", } - tunnelHostnameHeader = Header{ - Name: CloudflaredProxyTunnelHostnameHeader, - Value: "tunnel.example.com", - } respStatusHeader = Header{ Name: ":status", Value: "200", @@ -42,15 +38,6 @@ func (mosh *mockOriginStreamHandler) ServeStream(stream *MuxedStream) error { return nil } -func getCloudflaredProxyTunnelHostnameHeader(stream *MuxedStream) string { - for _, header := range stream.Headers { - if header.Name == CloudflaredProxyTunnelHostnameHeader { - return header.Value - } - } - return "" -} - func assertOpenStreamSucceed(t *testing.T, stream *MuxedStream, err error) { assert.NoError(t, err) assert.Len(t, stream.Headers, 1) @@ -72,13 +59,11 @@ func TestMissingHeaders(t *testing.T) { }, } - // Request doesn't contain CloudflaredProxyTunnelHostnameHeader stream, err := muxPair.EdgeMux.OpenStream(ctx, reqHeaders, nil) assertOpenStreamSucceed(t, stream, err) assert.Empty(t, originHandler.stream.method) assert.Empty(t, originHandler.stream.path) - assert.False(t, originHandler.stream.TunnelHostname().IsSet()) } func TestReceiveHeaderData(t *testing.T) { @@ -90,18 +75,14 @@ func TestReceiveHeaderData(t *testing.T) { methodHeader, schemeHeader, pathHeader, - tunnelHostnameHeader, } ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - reqHeaders = append(reqHeaders, tunnelHostnameHeader) stream, err := muxPair.EdgeMux.OpenStream(ctx, reqHeaders, nil) assertOpenStreamSucceed(t, stream, err) assert.Equal(t, methodHeader.Value, originHandler.stream.method) assert.Equal(t, pathHeader.Value, originHandler.stream.path) - assert.True(t, originHandler.stream.TunnelHostname().IsSet()) - assert.Equal(t, tunnelHostnameHeader.Value, originHandler.stream.TunnelHostname().String()) } diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index 63151f33..3ed17186 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -25,34 +25,10 @@ type OriginConnection interface { type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger) -// Stream copies copy data to & from provided io.ReadWriters. -func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) { - proxyDone := make(chan struct{}, 2) - - go func() { - _, err := io.Copy(conn, backendConn) - if err != nil { - log.Debug().Msgf("conn to backendConn copy: %v", err) - } - proxyDone <- struct{}{} - }() - - go func() { - _, err := io.Copy(backendConn, conn) - if err != nil { - log.Debug().Msgf("backendConn to conn copy: %v", err) - } - proxyDone <- struct{}{} - }() - - // If one side is done, we are done. - <-proxyDone -} - // DefaultStreamHandler is an implementation of streamHandlerFunc that // performs a two way io.Copy between originConn and remoteConn. func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger) { - Stream(originConn, remoteConn, log) + websocket.Stream(originConn, remoteConn, log) } // tcpConnection is an OriginConnection that directly streams to raw TCP. @@ -61,7 +37,7 @@ type tcpConnection struct { } func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { - Stream(tunnelConn, tc.conn, log) + websocket.Stream(tunnelConn, tc.conn, log) } func (tc *tcpConnection) Close() { @@ -89,7 +65,7 @@ type wsConnection struct { } func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { - Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log) + websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log) } func (wsc *wsConnection) Close() { diff --git a/ingress/origin_connection_test.go b/ingress/origin_connection_test.go index 3c853d21..78a2a151 100644 --- a/ingress/origin_connection_test.go +++ b/ingress/origin_connection_test.go @@ -22,6 +22,7 @@ import ( "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/socks" + "github.com/cloudflare/cloudflared/websocket" ) const ( @@ -157,7 +158,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) { require.NoError(t, err) defer wsForwarderInConn.Close() - Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger) + websocket.Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger) return nil }) diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index 2c01170c..08ef4bd9 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -4,12 +4,10 @@ import ( "fmt" "net" "net/http" - "net/url" - "strings" "github.com/pkg/errors" - "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/websocket" ) @@ -106,7 +104,7 @@ func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnectio var err error dest := o.dest if o.isBastion { - dest, err = o.bastionDest(r) + dest, err = carrier.ResolveBastionDest(r) if err != nil { return nil, nil, err } @@ -130,23 +128,6 @@ func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnectio } -func (o *tcpOverWSService) bastionDest(r *http.Request) (string, error) { - jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader) - if jumpDestination == "" { - return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side") - } - // Strip scheme and path set by client. Without a scheme - // Parsing a hostname and path without scheme might not return an error due to parsing ambiguities - if jumpURL, err := url.Parse(jumpDestination); err == nil && jumpURL.Host != "" { - return removePath(jumpURL.Host), nil - } - return removePath(jumpDestination), nil -} - -func removePath(dest string) string { - return strings.SplitN(dest, "/", 2)[0] -} - func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { originConn := o.conn resp := &http.Response{ diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index e9153a1c..22266383 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -14,7 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/websocket" ) @@ -126,7 +126,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") bastionReq := baseReq.Clone(context.Background()) - bastionReq.Header.Set(h2mux.CFJumpDestinationHeader, originListener.Addr().String()) + carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String()) expectHeader := http.Header{ "Connection": {"Upgrade"}, @@ -135,19 +135,23 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { } tests := []struct { + testCase string service *tcpOverWSService req *http.Request expectErr bool }{ { - service: newTCPOverWSService(originURL), - req: baseReq, + testCase: "specific TCP service", + service: newTCPOverWSService(originURL), + req: baseReq, }, { - service: newBastionService(), - req: bastionReq, + testCase: "bastion service", + service: newBastionService(), + req: bastionReq, }, { + testCase: "invalid bastion request", service: newBastionService(), req: baseReq, expectErr: true, @@ -155,13 +159,15 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { } for _, test := range tests { - if test.expectErr { - _, resp, err := test.service.EstablishConnection(test.req) - assert.Error(t, err) - assert.Nil(t, resp) - } else { - assertEstablishConnectionResponse(t, test.service, test.req, expectHeader) - } + t.Run(test.testCase, func(t *testing.T) { + if test.expectErr { + _, resp, err := test.service.EstablishConnection(test.req) + assert.Error(t, err) + assert.Nil(t, resp) + } else { + assertEstablishConnectionResponse(t, test.service, test.req, expectHeader) + } + }) } originListener.Close() @@ -175,104 +181,6 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) { } } -func TestBastionDestination(t *testing.T) { - canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader) - tests := []struct { - name string - header http.Header - expectedDest string - wantErr bool - }{ - { - name: "hostname destination", - header: http.Header{ - canonicalJumpDestHeader: []string{"localhost"}, - }, - expectedDest: "localhost", - }, - { - name: "hostname destination with port", - header: http.Header{ - canonicalJumpDestHeader: []string{"localhost:9000"}, - }, - expectedDest: "localhost:9000", - }, - { - name: "hostname destination with scheme and port", - header: http.Header{ - canonicalJumpDestHeader: []string{"ssh://localhost:9000"}, - }, - expectedDest: "localhost:9000", - }, - { - name: "full hostname url", - header: http.Header{ - canonicalJumpDestHeader: []string{"ssh://localhost:9000/metrics"}, - }, - expectedDest: "localhost:9000", - }, - { - name: "hostname destination with port and path", - header: http.Header{ - canonicalJumpDestHeader: []string{"localhost:9000/metrics"}, - }, - expectedDest: "localhost:9000", - }, - { - name: "ip destination", - header: http.Header{ - canonicalJumpDestHeader: []string{"127.0.0.1"}, - }, - expectedDest: "127.0.0.1", - }, - { - name: "ip destination with port", - header: http.Header{ - canonicalJumpDestHeader: []string{"127.0.0.1:9000"}, - }, - expectedDest: "127.0.0.1:9000", - }, - { - name: "ip destination with port and path", - header: http.Header{ - canonicalJumpDestHeader: []string{"127.0.0.1:9000/metrics"}, - }, - expectedDest: "127.0.0.1:9000", - }, - { - name: "ip destination with schem and port", - header: http.Header{ - canonicalJumpDestHeader: []string{"tcp://127.0.0.1:9000"}, - }, - expectedDest: "127.0.0.1:9000", - }, - { - name: "full ip url", - header: http.Header{ - canonicalJumpDestHeader: []string{"ssh://127.0.0.1:9000/metrics"}, - }, - expectedDest: "127.0.0.1:9000", - }, - { - name: "no destination", - wantErr: true, - }, - } - s := newBastionService() - for _, test := range tests { - r := &http.Request{ - Header: test.header, - } - dest, err := s.bastionDest(r) - if test.wantErr { - assert.Error(t, err, "Test %s expects error", test.name) - } else { - assert.NoError(t, err, "Test %s expects no error, got error %v", test.name, err) - assert.Equal(t, test.expectedDest, dest, "Test %s expect dest %s, got %s", test.name, test.expectedDest, dest) - } - } -} - func TestHTTPServiceHostHeaderOverride(t *testing.T) { cfg := OriginRequestConfig{ HTTPHostHeader: t.Name(), diff --git a/origin/reconnect.go b/origin/reconnect.go index b832c372..c93c686a 100644 --- a/origin/reconnect.go +++ b/origin/reconnect.go @@ -9,6 +9,7 @@ import ( "github.com/prometheus/client_golang/prometheus" + "github.com/cloudflare/cloudflared/retry" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) @@ -103,7 +104,7 @@ func (cm *reconnectCredentialManager) SetConnDigest(connID uint8, digest []byte) func (cm *reconnectCredentialManager) RefreshAuth( ctx context.Context, - backoff *BackoffHandler, + backoff *retry.BackoffHandler, authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error), ) (retryTimer <-chan time.Time, err error) { authOutcome, err := authenticate(ctx, backoff.Retries()) @@ -121,11 +122,11 @@ func (cm *reconnectCredentialManager) RefreshAuth( case tunnelpogs.AuthSuccess: cm.SetReconnectToken(outcome.JWT()) cm.authSuccess.Inc() - return timeAfter(outcome.RefreshAfter()), nil + return retry.Clock.After(outcome.RefreshAfter()), nil case tunnelpogs.AuthUnknown: duration := outcome.RefreshAfter() cm.authFail.WithLabelValues(outcome.Error()).Inc() - return timeAfter(duration), nil + return retry.Clock.After(duration), nil case tunnelpogs.AuthFail: cm.authFail.WithLabelValues(outcome.Error()).Inc() return nil, outcome diff --git a/origin/reconnect_test.go b/origin/reconnect_test.go index 9b9dfa8c..fb2a1df9 100644 --- a/origin/reconnect_test.go +++ b/origin/reconnect_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/cloudflare/cloudflared/retry" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) @@ -17,11 +18,11 @@ func TestRefreshAuthBackoff(t *testing.T) { rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4) var wait time.Duration - timeAfter = func(d time.Duration) <-chan time.Time { + retry.Clock.After = func(d time.Duration) <-chan time.Time { wait = d return time.After(d) } - backoff := &BackoffHandler{MaxRetries: 3} + backoff := &retry.BackoffHandler{MaxRetries: 3} auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) { return nil, fmt.Errorf("authentication failure") } @@ -45,7 +46,7 @@ func TestRefreshAuthBackoff(t *testing.T) { // The backoff timer should have been reset. To confirm this, make timeNow // return a value after the backoff timer's grace period - timeNow = func() time.Time { + retry.Clock.Now = func() time.Time { expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries) return time.Now().Add(expectedGracePeriod * 2) } @@ -57,12 +58,12 @@ func TestRefreshAuthSuccess(t *testing.T) { rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4) var wait time.Duration - timeAfter = func(d time.Duration) <-chan time.Time { + retry.Clock.After = func(d time.Duration) <-chan time.Time { wait = d return time.After(d) } - backoff := &BackoffHandler{MaxRetries: 3} + backoff := &retry.BackoffHandler{MaxRetries: 3} auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) { return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil } @@ -81,12 +82,12 @@ func TestRefreshAuthUnknown(t *testing.T) { rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4) var wait time.Duration - timeAfter = func(d time.Duration) <-chan time.Time { + retry.Clock.After = func(d time.Duration) <-chan time.Time { wait = d return time.After(d) } - backoff := &BackoffHandler{MaxRetries: 3} + backoff := &retry.BackoffHandler{MaxRetries: 3} auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) { return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil } @@ -104,7 +105,7 @@ func TestRefreshAuthUnknown(t *testing.T) { func TestRefreshAuthFail(t *testing.T) { rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4) - backoff := &BackoffHandler{MaxRetries: 3} + backoff := &retry.BackoffHandler{MaxRetries: 3} auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) { return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil } diff --git a/origin/supervisor.go b/origin/supervisor.go index b8c459aa..09888514 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -13,6 +13,7 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) @@ -112,10 +113,10 @@ func (s *Supervisor) Run( var tunnelsWaiting []int tunnelsActive := s.config.HAConnections - backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true} + backoff := retry.BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true} var backoffTimer <-chan time.Time - refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true} + refreshAuthBackoff := &retry.BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true} var refreshAuthBackoffTimer <-chan time.Time if s.useReconnectToken { diff --git a/origin/tunnel.go b/origin/tunnel.go index 4e221449..40e195b6 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -18,6 +18,7 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/tunnelrpc" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -138,7 +139,7 @@ func ServeTunnelLoop( connLog := config.Log.With().Uint8(connection.LogFieldConnIndex, connIndex).Logger() protocolFallback := &protocolFallback{ - BackoffHandler{MaxRetries: config.Retries}, + retry.BackoffHandler{MaxRetries: config.Retries}, config.ProtocolSelector.Current(), false, } @@ -195,18 +196,18 @@ func ServeTunnelLoop( // protocolFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches // max retries type protocolFallback struct { - BackoffHandler + retry.BackoffHandler protocol connection.Protocol inFallback bool } func (pf *protocolFallback) reset() { - pf.resetNow() + pf.ResetNow() pf.inFallback = false } func (pf *protocolFallback) fallback(fallback connection.Protocol) { - pf.resetNow() + pf.ResetNow() pf.protocol = fallback pf.inFallback = true } @@ -281,7 +282,7 @@ func ServeTunnel( } if protocol == connection.HTTP2 { - connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries)) + connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries())) err = ServeHTTP2( ctx, connLog, @@ -382,7 +383,7 @@ func ServeH2mux( errGroup.Go(func() error { if config.NamedTunnel != nil { - connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries)) + connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries())) return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse) } registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID) diff --git a/origin/tunnel_test.go b/origin/tunnel_test.go index 9be42fa8..63b547da 100644 --- a/origin/tunnel_test.go +++ b/origin/tunnel_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/retry" ) type dynamicMockFetcher struct { @@ -26,7 +27,7 @@ func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher { func TestWaitForBackoffFallback(t *testing.T) { maxRetries := uint(3) - backoff := BackoffHandler{ + backoff := retry.BackoffHandler{ MaxRetries: maxRetries, BaseTime: time.Millisecond * 10, } diff --git a/origin/backoffhandler.go b/retry/backoffhandler.go similarity index 87% rename from origin/backoffhandler.go rename to retry/backoffhandler.go index 96b3326c..8c09db55 100644 --- a/origin/backoffhandler.go +++ b/retry/backoffhandler.go @@ -1,4 +1,4 @@ -package origin +package retry import ( "context" @@ -7,10 +7,15 @@ import ( ) // Redeclare time functions so they can be overridden in tests. -var ( - timeNow = time.Now - timeAfter = time.After -) +type clock struct { + Now func() time.Time + After func(d time.Duration) <-chan time.Time +} + +var Clock = clock{ + Now: time.Now, + After: time.After, +} // BackoffHandler manages exponential backoff and limits the maximum number of retries. // The base time period is 1 second, doubling with each retry. @@ -39,7 +44,7 @@ func (b BackoffHandler) GetMaxBackoffDuration(ctx context.Context) (time.Duratio return time.Duration(0), false default: } - if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) { + if !b.resetDeadline.IsZero() && Clock.Now().After(b.resetDeadline) { // b.retries would be set to 0 at this point return time.Second, true } @@ -53,7 +58,7 @@ func (b BackoffHandler) GetMaxBackoffDuration(ctx context.Context) (time.Duratio // BackoffTimer returns a channel that sends the current time when the exponential backoff timeout expires. // Returns nil if the maximum number of retries have been used. func (b *BackoffHandler) BackoffTimer() <-chan time.Time { - if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) { + if !b.resetDeadline.IsZero() && Clock.Now().After(b.resetDeadline) { b.retries = 0 b.resetDeadline = time.Time{} } @@ -66,7 +71,7 @@ func (b *BackoffHandler) BackoffTimer() <-chan time.Time { } maxTimeToWait := time.Duration(b.GetBaseTime() * 1 << (b.retries)) timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds())) - return timeAfter(timeToWait) + return Clock.After(timeToWait) } // Backoff is used to wait according to exponential backoff. Returns false if the @@ -89,7 +94,7 @@ func (b *BackoffHandler) Backoff(ctx context.Context) bool { func (b *BackoffHandler) SetGracePeriod() { maxTimeToWait := b.GetBaseTime() * 2 << (b.retries + 1) timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds())) - b.resetDeadline = timeNow().Add(timeToWait) + b.resetDeadline = Clock.Now().Add(timeToWait) } func (b BackoffHandler) GetBaseTime() time.Duration { @@ -108,6 +113,6 @@ func (b *BackoffHandler) ReachedMaxRetries() bool { return b.retries == b.MaxRetries } -func (b *BackoffHandler) resetNow() { +func (b *BackoffHandler) ResetNow() { b.resetDeadline = time.Now() } diff --git a/origin/backoffhandler_test.go b/retry/backoffhandler_test.go similarity index 92% rename from origin/backoffhandler_test.go rename to retry/backoffhandler_test.go index 8add3f0b..8a4d1fe3 100644 --- a/origin/backoffhandler_test.go +++ b/retry/backoffhandler_test.go @@ -1,4 +1,4 @@ -package origin +package retry import ( "context" @@ -14,7 +14,7 @@ func immediateTimeAfter(time.Duration) <-chan time.Time { func TestBackoffRetries(t *testing.T) { // make backoff return immediately - timeAfter = immediateTimeAfter + Clock.After = immediateTimeAfter ctx := context.Background() backoff := BackoffHandler{MaxRetries: 3} if !backoff.Backoff(ctx) { @@ -33,7 +33,7 @@ func TestBackoffRetries(t *testing.T) { func TestBackoffCancel(t *testing.T) { // prevent backoff from returning normally - timeAfter = func(time.Duration) <-chan time.Time { return make(chan time.Time) } + Clock.After = func(time.Duration) <-chan time.Time { return make(chan time.Time) } ctx, cancelFunc := context.WithCancel(context.Background()) backoff := BackoffHandler{MaxRetries: 3} cancelFunc() @@ -47,10 +47,10 @@ func TestBackoffCancel(t *testing.T) { func TestBackoffGracePeriod(t *testing.T) { currentTime := time.Now() - // make timeNow return whatever we like - timeNow = func() time.Time { return currentTime } + // make Clock.Now return whatever we like + Clock.Now = func() time.Time { return currentTime } // make backoff return immediately - timeAfter = immediateTimeAfter + Clock.After = immediateTimeAfter ctx := context.Background() backoff := BackoffHandler{MaxRetries: 1} if !backoff.Backoff(ctx) { @@ -71,7 +71,7 @@ func TestBackoffGracePeriod(t *testing.T) { func TestGetMaxBackoffDurationRetries(t *testing.T) { // make backoff return immediately - timeAfter = immediateTimeAfter + Clock.After = immediateTimeAfter ctx := context.Background() backoff := BackoffHandler{MaxRetries: 3} if _, ok := backoff.GetMaxBackoffDuration(ctx); !ok { @@ -96,7 +96,7 @@ func TestGetMaxBackoffDurationRetries(t *testing.T) { func TestGetMaxBackoffDuration(t *testing.T) { // make backoff return immediately - timeAfter = immediateTimeAfter + Clock.After = immediateTimeAfter ctx := context.Background() backoff := BackoffHandler{MaxRetries: 3} if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*2 { @@ -118,7 +118,7 @@ func TestGetMaxBackoffDuration(t *testing.T) { func TestBackoffRetryForever(t *testing.T) { // make backoff return immediately - timeAfter = immediateTimeAfter + Clock.After = immediateTimeAfter ctx := context.Background() backoff := BackoffHandler{MaxRetries: 3, RetryForever: true} if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*2 { diff --git a/token/token.go b/token/token.go index ab04973c..f443fd70 100644 --- a/token/token.go +++ b/token/token.go @@ -18,7 +18,7 @@ import ( "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/config" - "github.com/cloudflare/cloudflared/origin" + "github.com/cloudflare/cloudflared/retry" ) const ( @@ -36,7 +36,7 @@ type AppInfo struct { type lock struct { lockFilePath string - backoff *origin.BackoffHandler + backoff *retry.BackoffHandler sigHandler *signalHandler } @@ -94,7 +94,7 @@ func newLock(path string) *lock { lockPath := path + ".lock" return &lock{ lockFilePath: lockPath, - backoff: &origin.BackoffHandler{MaxRetries: 7}, + backoff: &retry.BackoffHandler{MaxRetries: 7}, sigHandler: &signalHandler{ signals: []os.Signal{syscall.SIGINT, syscall.SIGTERM}, }, diff --git a/websocket/websocket.go b/websocket/websocket.go index 74289b40..6a619b70 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -4,15 +4,11 @@ import ( "crypto/sha1" "encoding/base64" "io" - "net" "net/http" "net/url" - "time" "github.com/gorilla/websocket" "github.com/rs/zerolog" - - "github.com/cloudflare/cloudflared/h2mux" ) var stripWebsocketHeaders = []string{ @@ -47,80 +43,6 @@ func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Con return conn, response, nil } -// StartProxyServer will start a websocket server that will decode -// the websocket data and write the resulting data to the provided -func StartProxyServer( - log *zerolog.Logger, - listener net.Listener, - staticHost string, - shutdownC <-chan struct{}, - streamHandler func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger), -) error { - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - } - h := handler{ - upgrader: upgrader, - log: log, - staticHost: staticHost, - streamHandler: streamHandler, - } - - httpServer := &http.Server{Addr: listener.Addr().String(), Handler: &h} - go func() { - <-shutdownC - _ = httpServer.Close() - }() - - return httpServer.Serve(listener) -} - -// HTTP handler for the websocket proxy. -type handler struct { - log *zerolog.Logger - staticHost string - upgrader websocket.Upgrader - streamHandler func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger) -} - -func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // If remote is an empty string, get the destination from the client. - finalDestination := h.staticHost - if finalDestination == "" { - if jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader); jumpDestination == "" { - h.log.Error().Msg("Did not receive final destination from client. The --destination flag is likely not set") - return - } else { - finalDestination = jumpDestination - } - } - - stream, err := net.Dial("tcp", finalDestination) - if err != nil { - h.log.Err(err).Msg("Cannot connect to remote") - return - } - defer stream.Close() - - if !websocket.IsWebSocketUpgrade(r) { - _, _ = w.Write(nonWebSocketRequestPage()) - return - } - conn, err := h.upgrader.Upgrade(w, r, nil) - if err != nil { - h.log.Err(err).Msg("failed to upgrade") - return - } - _ = conn.SetReadDeadline(time.Now().Add(pongWait)) - conn.SetPongHandler(func(string) error { _ = conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) - gorillaConn := &GorillaConn{Conn: conn, log: h.log} - go gorillaConn.pinger(r.Context()) - defer conn.Close() - - h.streamHandler(gorillaConn, stream, h.log) -} - // NewResponseHeader returns headers needed to return to origin for completing handshake func NewResponseHeader(req *http.Request) http.Header { header := http.Header{} @@ -174,3 +96,27 @@ func ChangeRequestScheme(reqURL *url.URL) string { return reqURL.Scheme } } + +// Stream copies copy data to & from provided io.ReadWriters. +func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) { + proxyDone := make(chan struct{}, 2) + + go func() { + _, err := io.Copy(conn, backendConn) + if err != nil { + log.Debug().Msgf("conn to backendConn copy: %v", err) + } + proxyDone <- struct{}{} + }() + + go func() { + _, err := io.Copy(backendConn, conn) + if err != nil { + log.Debug().Msgf("backendConn to conn copy: %v", err) + } + proxyDone <- struct{}{} + }() + + // If one side is done, we are done. + <-proxyDone +} diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go index 639be802..e738a106 100644 --- a/websocket/websocket_test.go +++ b/websocket/websocket_test.go @@ -151,41 +151,3 @@ func TestWebsocketWrapper(t *testing.T) { require.Equal(t, n, 2) require.Equal(t, "bc", string(buf[:n])) } - -// func TestStartProxyServer(t *testing.T) { -// var wg sync.WaitGroup -// remoteAddress := "localhost:1113" -// listenerAddress := "localhost:1112" -// message := "Good morning Austin! Time for another sunny day in the great state of Texas." -// logger := zerolog.Nop() -// shutdownC := make(chan struct{}) - -// listener, err := net.Listen("tcp", listenerAddress) -// assert.NoError(t, err) -// defer listener.Close() - -// remoteListener, err := net.Listen("tcp", remoteAddress) -// assert.NoError(t, err) -// defer remoteListener.Close() - -// wg.Add(1) -// go func() { -// defer wg.Done() -// conn, err := remoteListener.Accept() -// assert.NoError(t, err) -// buf := make([]byte, len(message)) -// conn.Read(buf) -// assert.Equal(t, string(buf), message) -// }() - -// go func() { -// StartProxyServer(logger, listener, remoteAddress, shutdownC) -// }() - -// req := testRequest(t, fmt.Sprintf("http://%s/", listenerAddress), nil) -// conn, _, err := ClientConnect(req, nil) -// assert.NoError(t, err) -// err = conn.WriteMessage(1, []byte(message)) -// assert.NoError(t, err) -// wg.Wait() -// }