diff --git a/connection/connection.go b/connection/connection.go index b2f090a8..e68afb64 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -1,6 +1,7 @@ package connection import ( + "context" "fmt" "io" "net/http" @@ -11,9 +12,15 @@ import ( "github.com/google/uuid" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/websocket" ) -const LogFieldConnIndex = "connIndex" +const ( + lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" + LogFieldConnIndex = "connIndex" +) + +var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) type Config struct { OriginProxy OriginProxy @@ -87,9 +94,64 @@ func (t Type) String() string { } } +// OriginProxy is how data flows from cloudflared to the origin services running behind it. type OriginProxy interface { - // If Proxy returns an error, the caller is responsible for writing the error status to ResponseWriter - Proxy(w ResponseWriter, req *http.Request, sourceConnectionType Type) error + ProxyHTTP(w ResponseWriter, req *http.Request, isWebsocket bool) error + ProxyTCP(ctx context.Context, rwa ReadWriteAcker, req *TCPRequest) error +} + +// TCPRequest defines the input format needed to perform a TCP proxy. +type TCPRequest struct { + Dest string + CFRay string + LBProbe bool +} + +// ReadWriteAcker is a readwriter with the ability to Acknowledge to the downstream (edge) that the origin has +// accepted the connection. +type ReadWriteAcker interface { + io.ReadWriter + AckConnection() error +} + +// HTTPResponseReadWriteAcker is an HTTP implementation of ReadWriteAcker. +type HTTPResponseReadWriteAcker struct { + r io.Reader + w ResponseWriter + req *http.Request +} + +// NewHTTPResponseReadWriterAcker returns a new instance of HTTPResponseReadWriteAcker. +func NewHTTPResponseReadWriterAcker(w ResponseWriter, req *http.Request) *HTTPResponseReadWriteAcker { + return &HTTPResponseReadWriteAcker{ + r: req.Body, + w: w, + req: req, + } +} + +func (h *HTTPResponseReadWriteAcker) Read(p []byte) (int, error) { + return h.r.Read(p) +} + +func (h *HTTPResponseReadWriteAcker) Write(p []byte) (int, error) { + return h.w.Write(p) +} + +// AckConnection acks an HTTP connection by sending a switch protocols status code that enables the caller to +// upgrade to streams. +func (h *HTTPResponseReadWriteAcker) AckConnection() error { + resp := &http.Response{ + Status: switchingProtocolText, + StatusCode: http.StatusSwitchingProtocols, + ContentLength: -1, + } + + if secWebsocketKey := h.req.Header.Get("Sec-WebSocket-Key"); secWebsocketKey != "" { + resp.Header = websocket.NewResponseHeader(h.req) + } + + return h.w.WriteRespHeaders(resp.StatusCode, resp.Header) } type ResponseWriter interface { @@ -112,3 +174,11 @@ func IsServerSentEvent(headers http.Header) bool { func uint8ToString(input uint8) string { return strconv.FormatUint(uint64(input), 10) } + +func FindCfRayHeader(req *http.Request) string { + return req.Header.Get("Cf-Ray") +} + +func IsLBProbeRequest(req *http.Request) bool { + return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix) +} diff --git a/connection/connection_test.go b/connection/connection_test.go index 0fa3b29f..e8e477ea 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -1,6 +1,7 @@ package connection import ( + "context" "fmt" "io" "net/http" @@ -11,6 +12,8 @@ import ( "github.com/gobwas/ws/wsutil" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + + "github.com/cloudflare/cloudflared/ingress" ) const ( @@ -18,7 +21,8 @@ const ( ) var ( - testConfig = &Config{ + unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil) + testConfig = &Config{ OriginProxy: &mockOriginProxy{}, GracePeriod: time.Millisecond * 100, } @@ -38,14 +42,17 @@ type testRequest struct { isProxyError bool } -type mockOriginProxy struct { -} +type mockOriginProxy struct{} -func (moc *mockOriginProxy) Proxy(w ResponseWriter, r *http.Request, sourceConnectionType Type) error { - if sourceConnectionType == TypeWebsocket { - return wsEndpoint(w, r) +func (moc *mockOriginProxy) ProxyHTTP( + w ResponseWriter, + req *http.Request, + isWebsocket bool, +) error { + if isWebsocket { + return wsEndpoint(w, req) } - switch r.URL.Path { + switch req.URL.Path { case "/ok": originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK))) case "/large_file": @@ -60,6 +67,15 @@ func (moc *mockOriginProxy) Proxy(w ResponseWriter, r *http.Request, sourceConne originRespEndpoint(w, http.StatusNotFound, []byte("page not found")) } return nil + +} + +func (moc *mockOriginProxy) ProxyTCP( + ctx context.Context, + rwa ReadWriteAcker, + r *TCPRequest, +) error { + return nil } type nowriter struct { diff --git a/connection/h2mux.go b/connection/h2mux.go index cd91a67f..1e7c652b 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -33,6 +33,8 @@ type h2muxConnection struct { gracefulShutdownC <-chan struct{} stoppedGracefully bool + log *zerolog.Logger + // newRPCClientFunc allows us to mock RPCs during testing newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient } @@ -222,12 +224,11 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error { sourceConnectionType = TypeWebsocket } - err := h.config.OriginProxy.Proxy(respWriter, req, sourceConnectionType) + err := h.config.OriginProxy.ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket) if err != nil { respWriter.WriteErrorResponse() - return err } - return nil + return err } func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) { diff --git a/connection/http2.go b/connection/http2.go index ea81fdbf..e748a283 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -10,6 +10,7 @@ import ( "strings" "sync" + "github.com/pkg/errors" "github.com/rs/zerolog" "golang.org/x/net/http2" @@ -26,7 +27,9 @@ const ( var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed") -type http2Connection struct { +// HTTP2Connection represents a net.Conn that uses HTTP2 frames to proxy traffic from the edge to cloudflared on the +// origin. +type HTTP2Connection struct { conn net.Conn server *http2.Server config *Config @@ -38,6 +41,7 @@ type http2Connection struct { // newRPCClientFunc allows us to mock RPCs during testing newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient + log *zerolog.Logger activeRequestsWG sync.WaitGroup connectedFuse ConnectedFuse gracefulShutdownC <-chan struct{} @@ -45,6 +49,7 @@ type http2Connection struct { controlStreamErr error // result of running control stream handler } +// NewHTTP2Connection returns a new instance of HTTP2Connection. func NewHTTP2Connection( conn net.Conn, config *Config, @@ -53,9 +58,10 @@ func NewHTTP2Connection( observer *Observer, connIndex uint8, connectedFuse ConnectedFuse, + log *zerolog.Logger, gracefulShutdownC <-chan struct{}, -) *http2Connection { - return &http2Connection{ +) *HTTP2Connection { + return &HTTP2Connection{ conn: conn, server: &http2.Server{ MaxConcurrentStreams: math.MaxUint32, @@ -68,11 +74,13 @@ func NewHTTP2Connection( connIndex: connIndex, newRPCClientFunc: newRegistrationRPCClient, connectedFuse: connectedFuse, + log: log, gracefulShutdownC: gracefulShutdownC, } } -func (c *http2Connection) Serve(ctx context.Context) error { +// Serve serves an HTTP2 server that the edge can talk to. +func (c *HTTP2Connection) Serve(ctx context.Context) error { go func() { <-ctx.Done() c.close() @@ -93,7 +101,7 @@ func (c *http2Connection) Serve(ctx context.Context) error { } } -func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.activeRequestsWG.Add(1) defer c.activeRequestsWG.Done() @@ -106,23 +114,47 @@ func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - var proxyErr error switch connType { case TypeControlStream: - proxyErr = c.serveControlStream(r.Context(), respWriter) - c.controlStreamErr = proxyErr - case TypeWebsocket: + if err := c.serveControlStream(r.Context(), respWriter); err != nil { + c.controlStreamErr = err + c.log.Error().Err(err) + respWriter.WriteErrorResponse() + } + + case TypeWebsocket, TypeHTTP: stripWebsocketUpgradeHeader(r) - proxyErr = c.config.OriginProxy.Proxy(respWriter, r, TypeWebsocket) + if err := c.config.OriginProxy.ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil { + err := fmt.Errorf("Failed to proxy HTTP: %w", err) + c.log.Error().Err(err) + respWriter.WriteErrorResponse() + } + + case TypeTCP: + host, err := getRequestHost(r) + if err != nil { + err := fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %w`, err) + c.log.Error().Err(err) + respWriter.WriteErrorResponse() + } + + rws := NewHTTPResponseReadWriterAcker(respWriter, r) + if err := c.config.OriginProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ + Dest: host, + CFRay: FindCfRayHeader(r), + LBProbe: IsLBProbeRequest(r), + }); err != nil { + respWriter.WriteErrorResponse() + } + default: - proxyErr = c.config.OriginProxy.Proxy(respWriter, r, connType) - } - if proxyErr != nil { + err := fmt.Errorf("Received unknown connection type: %s", connType) + c.log.Error().Err(err) respWriter.WriteErrorResponse() } } -func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error { +func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error { rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer.log) defer rpcClient.Close() @@ -145,7 +177,7 @@ func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *ht return nil } -func (c *http2Connection) close() { +func (c *HTTP2Connection) close() { // Wait for all serve HTTP handlers to return c.activeRequestsWG.Wait() c.conn.Close() @@ -287,3 +319,14 @@ func IsTCPStream(r *http.Request) bool { func stripWebsocketUpgradeHeader(r *http.Request) { r.Header.Del(InternalUpgradeHeader) } + +// getRequestHost returns the host of the http.Request. +func getRequestHost(r *http.Request) (string, error) { + if r.Host != "" { + return r.Host, nil + } + if r.URL != nil { + return r.URL.Host, nil + } + return "", errors.New("host not set in incoming request") +} diff --git a/connection/http2_test.go b/connection/http2_test.go index c1b2ba4d..dc1b2c70 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -26,9 +26,10 @@ var ( testTransport = http2.Transport{} ) -func newTestHTTP2Connection() (*http2Connection, net.Conn) { +func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { edgeConn, originConn := net.Pipe() var connIndex = uint8(0) + log := zerolog.Nop() return NewHTTP2Connection( originConn, testConfig, @@ -37,6 +38,7 @@ func newTestHTTP2Connection() (*http2Connection, net.Conn) { NewObserver(&log, &log, false), connIndex, mockConnectedFuse{}, + &log, nil, ), edgeConn } diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index c9da44f5..08f25a38 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -1,7 +1,6 @@ package ingress import ( - "fmt" "net" "net/http" @@ -9,7 +8,6 @@ import ( ) var ( - switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) errUnsupportedConnectionType = errors.New("internal error: unsupported connection type") ) diff --git a/origin/proxy.go b/origin/proxy.go index 8f984a94..13fe470a 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -7,7 +7,6 @@ import ( "io" "net/http" "strconv" - "strings" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -20,13 +19,15 @@ import ( ) const ( + // TagHeaderNamePrefix indicates a Cloudflared Warp Tag prefix that gets appended for warp traffic stream headers. TagHeaderNamePrefix = "Cf-Warp-Tag-" LogFieldCFRay = "cfRay" LogFieldRule = "ingressRule" LogFieldOriginService = "originService" ) -type proxy struct { +// Proxy represents a means to Proxy between cloudflared and the origin services. +type Proxy struct { ingressRules ingress.Ingress warpRouting *ingress.WarpRoutingService tags []tunnelpogs.Tag @@ -34,15 +35,14 @@ type proxy struct { bufferPool *bufferPool } -var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) - +// NewOriginProxy returns a new instance of the Proxy struct. func NewOriginProxy( ingressRules ingress.Ingress, warpRouting *ingress.WarpRoutingService, tags []tunnelpogs.Tag, - log *zerolog.Logger) connection.OriginProxy { - - return &proxy{ + log *zerolog.Logger, +) *Proxy { + return &Proxy{ ingressRules: ingressRules, warpRouting: warpRouting, tags: tags, @@ -51,41 +51,18 @@ func NewOriginProxy( } } -// Caller is responsible for writing any error to ResponseWriter -func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConnectionType connection.Type) error { +// ProxyHTTP further depends on ingress rules to establish a connection with the origin service. This may be +// a simple roundtrip or a tcp/websocket dial depending on ingres rule setup. +func (p *Proxy) ProxyHTTP( + w connection.ResponseWriter, + req *http.Request, + isWebsocket bool, +) error { incrementRequests() defer decrementConcurrentRequests() - cfRay := findCfRayHeader(req) - lbProbe := isLBProbeRequest(req) - - serveCtx, cancel := context.WithCancel(req.Context()) - defer cancel() - - p.appendTagHeaders(req) - if sourceConnectionType == connection.TypeTCP { - if p.warpRouting == nil { - err := errors.New(`cloudflared received a request from WARP client, but your configuration has disabled ingress from WARP clients. To enable this, set "warp-routing:\n\t enabled: true" in your config.yaml`) - p.log.Error().Msg(err.Error()) - return err - } - logFields := logFields{ - cfRay: cfRay, - lbProbe: lbProbe, - rule: ingress.ServiceWarpRouting, - } - - host, err := getRequestHost(req) - if err != nil { - err = fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %v`, err) - return err - } - if err := p.proxyStreamRequest(serveCtx, w, host, req, p.warpRouting.Proxy, logFields); err != nil { - p.logRequestError(err, cfRay, "", ingress.ServiceWarpRouting) - return err - } - return nil - } + cfRay := connection.FindCfRayHeader(req) + lbProbe := connection.IsLBProbeRequest(req) rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path) logFields := logFields{ @@ -97,8 +74,14 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn switch originProxy := rule.Service.(type) { case ingress.HTTPOriginProxy: - if err := p.proxyHTTPRequest(w, req, originProxy, sourceConnectionType == connection.TypeWebsocket, - rule.Config.DisableChunkedEncoding, logFields); err != nil { + if err := p.proxyHTTPRequest( + w, + req, + originProxy, + isWebsocket, + rule.Config.DisableChunkedEncoding, + logFields, + ); err != nil { rule, srv := ruleField(p.ingressRules, ruleNum) p.logRequestError(err, cfRay, rule, srv) return err @@ -110,7 +93,9 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn if err != nil { return err } - if err := p.proxyStreamRequest(serveCtx, w, dest, req, originProxy, logFields); err != nil { + + rws := connection.NewHTTPResponseReadWriterAcker(w, req) + if err := p.proxyStream(req.Context(), rws, dest, originProxy, logFields); err != nil { rule, srv := ruleField(p.ingressRules, ruleNum) p.logRequestError(err, cfRay, rule, srv) return err @@ -121,24 +106,36 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn } } -func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) { - switch rule.Service.String() { - case ingress.ServiceBastion: - return carrier.ResolveBastionDest(req) - default: - return rule.Service.String(), nil - } -} +// ProxyTCP proxies to a TCP connection between the origin service and cloudflared. +func (p *Proxy) ProxyTCP( + ctx context.Context, + rwa connection.ReadWriteAcker, + req *connection.TCPRequest, +) error { + incrementRequests() + defer decrementConcurrentRequests() -// getRequestHost returns the host of the http.Request. -func getRequestHost(r *http.Request) (string, error) { - if r.Host != "" { - return r.Host, nil + if p.warpRouting == nil { + err := errors.New(`cloudflared received a request from WARP client, but your configuration has disabled ingress from WARP clients. To enable this, set "warp-routing:\n\t enabled: true" in your config.yaml`) + p.log.Error().Msg(err.Error()) + return err } - if r.URL != nil { - return r.URL.Host, nil + + serveCtx, cancel := context.WithCancel(ctx) + defer cancel() + + logFields := logFields{ + cfRay: req.CFRay, + lbProbe: req.LBProbe, + rule: ingress.ServiceWarpRouting, } - return "", errors.New("host not set in incoming request") + + if err := p.proxyStream(serveCtx, rwa, req.Dest, p.warpRouting.Proxy, logFields); err != nil { + p.logRequestError(err, req.CFRay, "", ingress.ServiceWarpRouting) + return err + } + + return nil } func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) { @@ -149,13 +146,15 @@ func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) { return fmt.Sprintf("%d", ruleNum), srv } -func (p *proxy) proxyHTTPRequest( +// ProxyHTTPRequest proxies requests of underlying type http and websocket to the origin service. +func (p *Proxy) proxyHTTPRequest( w connection.ResponseWriter, req *http.Request, httpService ingress.HTTPOriginProxy, isWebsocket bool, disableChunkedEncoding bool, - fields logFields) error { + fields logFields, +) error { roundTripReq := req if isWebsocket { roundTripReq = req.Clone(req.Context()) @@ -214,17 +213,17 @@ func (p *proxy) proxyHTTPRequest( defer p.bufferPool.Put(buf) _, _ = io.CopyBuffer(w, resp.Body, buf) } + p.logOriginResponse(resp, fields) return nil } -// proxyStreamRequest first establish a connection with origin, then it writes the status code and headers, and finally it streams data between -// eyeball and origin. -func (p *proxy) proxyStreamRequest( - serveCtx context.Context, - w connection.ResponseWriter, +// proxyStream proxies type TCP and other underlying types if the connection is defined as a stream oriented +// ingress rule. +func (p *Proxy) proxyStream( + ctx context.Context, + rwa connection.ReadWriteAcker, dest string, - req *http.Request, connectionProxy ingress.StreamBasedOriginProxy, fields logFields, ) error { @@ -233,21 +232,11 @@ func (p *proxy) proxyStreamRequest( return err } - resp := &http.Response{ - Status: switchingProtocolText, - StatusCode: http.StatusSwitchingProtocols, - ContentLength: -1, - } - - if secWebsocketKey := req.Header.Get("Sec-WebSocket-Key"); secWebsocketKey != "" { - resp.Header = websocket.NewResponseHeader(req) - } - - if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil { + if err := rwa.AckConnection(); err != nil { return err } - streamCtx, cancel := context.WithCancel(serveCtx) + streamCtx, cancel := context.WithCancel(ctx) defer cancel() go func() { @@ -256,12 +245,7 @@ func (p *proxy) proxyStreamRequest( originConn.Close() }() - eyeballStream := &bidirectionalStream{ - writer: w, - reader: req.Body, - } - originConn.Stream(serveCtx, eyeballStream, p.log) - p.logOriginResponse(resp, fields) + originConn.Stream(ctx, rwa, p.log) return nil } @@ -278,7 +262,7 @@ func (wr *bidirectionalStream) Write(p []byte) (n int, err error) { return wr.writer.Write(p) } -func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { +func (p *Proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { reader := bufio.NewReader(respBody) for { line, err := reader.ReadBytes('\n') @@ -289,7 +273,7 @@ func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCl } } -func (p *proxy) appendTagHeaders(r *http.Request) { +func (p *Proxy) appendTagHeaders(r *http.Request) { for _, tag := range p.tags { r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value) } @@ -301,7 +285,7 @@ type logFields struct { rule interface{} } -func (p *proxy) logRequest(r *http.Request, fields logFields) { +func (p *Proxy) logRequest(r *http.Request, fields logFields) { if fields.cfRay != "" { p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto) } else if fields.lbProbe { @@ -324,7 +308,7 @@ func (p *proxy) logRequest(r *http.Request, fields logFields) { } } -func (p *proxy) logOriginResponse(resp *http.Response, fields logFields) { +func (p *Proxy) logOriginResponse(resp *http.Response, fields logFields) { responseByCode.WithLabelValues(strconv.Itoa(resp.StatusCode)).Inc() if fields.cfRay != "" { p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", fields.cfRay, resp.Status, fields.rule) @@ -342,7 +326,7 @@ func (p *proxy) logOriginResponse(resp *http.Response, fields logFields) { } } -func (p *proxy) logRequestError(err error, cfRay string, rule, service string) { +func (p *Proxy) logRequestError(err error, cfRay string, rule, service string) { requestErrors.Inc() log := p.log.Error().Err(err) if cfRay != "" { @@ -357,10 +341,11 @@ func (p *proxy) logRequestError(err error, cfRay string, rule, service string) { log.Msg("") } -func findCfRayHeader(req *http.Request) string { - return req.Header.Get("Cf-Ray") -} - -func isLBProbeRequest(req *http.Request) bool { - return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix) +func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) { + switch rule.Service.String() { + case ingress.ServiceBastion: + return carrier.ResolveBastionDest(req) + default: + return rule.Service.String(), nil + } } diff --git a/origin/proxy_test.go b/origin/proxy_test.go index 6b785647..5ba89a22 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -46,6 +46,10 @@ func newMockHTTPRespWriter() *mockHTTPRespWriter { } } +func (w *mockHTTPRespWriter) WriteResponse() error { + return nil +} + func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) error { w.WriteHeader(status) for header, val := range header { @@ -146,7 +150,7 @@ func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) { req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) require.NoError(t, err) - err = proxy.Proxy(responseWriter, req, connection.TypeHTTP) + err = proxy.ProxyHTTP(responseWriter, req, false) require.NoError(t, err) assert.Equal(t, http.StatusOK, responseWriter.Code) @@ -170,7 +174,7 @@ func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) { errGroup, ctx := errgroup.WithContext(ctx) errGroup.Go(func() error { - err = proxy.Proxy(responseWriter, req, connection.TypeWebsocket) + err = proxy.ProxyHTTP(responseWriter, req, true) require.NoError(t, err) require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code) @@ -231,7 +235,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - err = proxy.Proxy(responseWriter, req, connection.TypeHTTP) + err = proxy.ProxyHTTP(responseWriter, req, false) require.NoError(t, err) require.Equal(t, http.StatusOK, responseWriter.Code) @@ -330,7 +334,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat req, err := http.NewRequest(http.MethodGet, test.url, nil) require.NoError(t, err) - err = proxy.Proxy(responseWriter, req, connection.TypeHTTP) + err = proxy.ProxyHTTP(responseWriter, req, false) require.NoError(t, err) assert.Equal(t, test.expectedStatus, responseWriter.Code) @@ -358,7 +362,7 @@ func (errorOriginTransport) RoundTrip(*http.Request) (*http.Response, error) { } func TestProxyError(t *testing.T) { - ingress := ingress.Ingress{ + ing := ingress.Ingress{ Rules: []ingress.Rule{ { Hostname: "*", @@ -372,13 +376,13 @@ func TestProxyError(t *testing.T) { log := zerolog.Nop() - proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log) + proxy := NewOriginProxy(ing, unusedWarpRoutingService, testTags, &log) responseWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) assert.NoError(t, err) - assert.Error(t, proxy.Proxy(responseWriter, req, connection.TypeHTTP)) + assert.Error(t, proxy.ProxyHTTP(responseWriter, req, false)) } type replayer struct { @@ -617,6 +621,7 @@ func TestConnections(t *testing.T) { ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) proxy := NewOriginProxy(ingressRule, test.args.warpRoutingService, testTags, logger) + dest := ln.Addr().String() req, err := http.NewRequest( http.MethodGet, test.args.ingressServiceScheme+ln.Addr().String(), @@ -634,8 +639,12 @@ func TestConnections(t *testing.T) { replayer.Write(resp) }() } - - err = proxy.Proxy(respWriter, req, test.args.connectionType) + if test.args.connectionType == connection.TypeTCP { + rws := connection.NewHTTPResponseReadWriterAcker(respWriter, req) + err = proxy.ProxyTCP(ctx, rws, &connection.TCPRequest{Dest: dest}) + } else { + err = proxy.ProxyHTTP(respWriter, req, test.args.connectionType == connection.TypeWebsocket) + } cancel() assert.Equal(t, test.want.err, err != nil) @@ -829,6 +838,10 @@ func newTCPRespWriter(w io.Writer) *mockTCPRespWriter { } } +func (m *mockTCPRespWriter) Read(p []byte) (n int, err error) { + return len(p), nil +} + func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) { return m.w.Write(p) } diff --git a/origin/tunnel.go b/origin/tunnel.go index 40e195b6..83a74302 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -26,7 +26,6 @@ import ( const ( dialTimeout = 15 * time.Second - lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" FeatureSerializedHeaders = "serialized_headers" FeatureQuickReconnects = "quick_reconnects" ) @@ -417,6 +416,7 @@ func ServeHTTP2( config.Observer, connIndex, connectedFuse, + config.Log, gracefulShutdownC, )