From 5974fb4cfd1c5d6c3568af57b4ccfa4427b279b5 Mon Sep 17 00:00:00 2001 From: cthuang Date: Mon, 2 Nov 2020 11:21:34 +0000 Subject: [PATCH] TUN-3500: Integrate replace h2mux by http2 work with multiple origin support --- cmd/cloudflared/config/configuration.go | 4 +- cmd/cloudflared/tunnel/cmd.go | 7 +- cmd/cloudflared/tunnel/configuration.go | 40 ++- connection/h2mux.go | 5 +- connection/http2.go | 7 - hello/hello.go | 8 + ingress/ingress.go | 12 +- ingress/origin_service.go | 2 +- origin/h2mux.go | 229 ----------------- origin/http2.go | 320 ------------------------ origin/metrics.go | 9 + origin/proxy.go | 99 ++++---- origin/proxy_test.go | 170 ++++++++++--- origin/tunnel.go | 10 +- validation/validation.go | 19 +- validation/validation_test.go | 27 +- 16 files changed, 252 insertions(+), 716 deletions(-) delete mode 100644 origin/h2mux.go delete mode 100644 origin/http2.go diff --git a/cmd/cloudflared/config/configuration.go b/cmd/cloudflared/config/configuration.go index e8917857..a3b927d3 100644 --- a/cmd/cloudflared/config/configuration.go +++ b/cmd/cloudflared/config/configuration.go @@ -190,9 +190,9 @@ func ValidateUnixSocket(c *cli.Context) (string, error) { // ValidateUrl will validate url flag correctness. It can be either from --url or argument // Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument -func ValidateUrl(c *cli.Context, allowFromArgs bool) (*url.URL, error) { +func ValidateUrl(c *cli.Context, allowURLFromArgs bool) (*url.URL, error) { var url = c.String("url") - if allowFromArgs && c.NArg() > 0 { + if allowURLFromArgs && c.NArg() > 0 { if c.IsSet("url") { return nil, errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") } diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 2850451a..59a8db1b 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -367,12 +367,12 @@ func StartServer( return errors.Wrap(err, "error setting up transport logger") } - tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, log, transportLogger, namedTunnel, isUIEnabled) + tunnelConfig, ingressRules, err := prepareTunnelConfig(c, buildInfo, version, log, transportLogger, namedTunnel, isUIEnabled) if err != nil { return err } - tunnelConfig.IngressRules.StartOrigins(&wg, log, shutdownC, errC) + ingressRules.StartOrigins(&wg, log, shutdownC, errC) reconnectCh := make(chan origin.ReconnectSignal, 1) if c.IsSet("stdin-control") { @@ -391,8 +391,7 @@ func StartServer( version, hostname, metricsListener.Addr().String(), - // TODO (TUN-3461): Update UI to show multiple origin URLs - &tunnelConfig.IngressRules, + &ingressRules, tunnelConfig.HAConnections, ) logLevels, err := logger.ParseLevelString(c.String("loglevel")) diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index c44e0ecc..fb57abb7 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -1,6 +1,7 @@ package tunnel import ( + "crypto/tls" "fmt" "io/ioutil" "os" @@ -160,27 +161,27 @@ func prepareTunnelConfig( transportLogger logger.Service, namedTunnel *connection.NamedTunnelConfig, uiIsEnabled bool, -) (*origin.TunnelConfig, error) { +) (*origin.TunnelConfig, ingress.Ingress, error) { isNamedTunnel := namedTunnel != nil hostname, err := validation.ValidateHostname(c.String("hostname")) if err != nil { logger.Errorf("Invalid hostname: %s", err) - return nil, errors.Wrap(err, "Invalid hostname") + return nil, ingress.Ingress{}, errors.Wrap(err, "Invalid hostname") } isFreeTunnel := hostname == "" clientID := c.String("id") if !c.IsSet("id") { clientID, err = generateRandomClientID(logger) if err != nil { - return nil, err + return nil, ingress.Ingress{}, err } } tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) if err != nil { logger.Errorf("Tag parse failure: %s", err) - return nil, errors.Wrap(err, "Tag parse failure") + return nil, ingress.Ingress{}, errors.Wrap(err, "Tag parse failure") } tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) @@ -189,7 +190,7 @@ func prepareTunnelConfig( if !isFreeTunnel { originCert, err = getOriginCert(c, logger) if err != nil { - return nil, errors.Wrap(err, "Error getting origin cert") + return nil, ingress.Ingress{}, errors.Wrap(err, "Error getting origin cert") } } @@ -200,7 +201,7 @@ func prepareTunnelConfig( if isNamedTunnel { clientUUID, err := uuid.NewRandom() if err != nil { - return nil, errors.Wrap(err, "can't generate clientUUID") + return nil, ingress.Ingress{}, errors.Wrap(err, "can't generate clientUUID") } namedTunnel.Client = tunnelpogs.ClientInfo{ ClientID: clientUUID[:], @@ -210,10 +211,10 @@ func prepareTunnelConfig( } ingressRules, err = ingress.ParseIngress(config.GetConfiguration()) if err != nil && err != ingress.ErrNoIngressRules { - return nil, err + return nil, ingress.Ingress{}, err } if !ingressRules.IsEmpty() && c.IsSet("url") { - return nil, ingress.ErrURLIncompatibleWithIngress + return nil, ingress.Ingress{}, ingress.ErrURLIncompatibleWithIngress } } else { classicTunnel = &connection.ClassicTunnelConfig{ @@ -226,15 +227,15 @@ func prepareTunnelConfig( // Convert single-origin configuration into multi-origin configuration. if ingressRules.IsEmpty() { - ingressRules, err = ingress.NewSingleOrigin(c, compatibilityMode, logger) + ingressRules, err = ingress.NewSingleOrigin(c, !isNamedTunnel, logger) if err != nil { - return nil, err + return nil, ingress.Ingress{}, err } } protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, logger) if err != nil { - return nil, err + return nil, ingress.Ingress{}, err } logger.Infof("Initial protocol %s", protocolSelector.Current()) @@ -242,20 +243,12 @@ func prepareTunnelConfig( for _, p := range connection.ProtocolList { edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, p.ServerName()) if err != nil { - return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") + return nil, ingress.Ingress{}, errors.Wrap(err, "unable to create TLS config to connect with edge") } edgeTLSConfigs[p] = edgeTLSConfig } - proxyConfig := &origin.ProxyConfig{ - Client: httpTransport, - URL: originURL, - TLSConfig: httpTransport.TLSClientConfig, - HostHeader: c.String("http-host-header"), - NoChunkedEncoding: c.Bool("no-chunked-encoding"), - Tags: tags, - } - originClient := origin.NewClient(proxyConfig, logger) + originClient := origin.NewClient(ingressRules, tags, logger) connectionConfig := &connection.Config{ OriginClient: originClient, GracePeriod: c.Duration("grace-period"), @@ -275,7 +268,6 @@ func prepareTunnelConfig( return &origin.TunnelConfig{ ConnectionConfig: connectionConfig, - ProxyConfig: proxyConfig, BuildInfo: buildInfo, ClientID: clientID, EdgeAddrs: c.StringSlice("edge"), @@ -284,6 +276,7 @@ func prepareTunnelConfig( IsAutoupdated: c.Bool("is-autoupdated"), IsFreeTunnel: isFreeTunnel, LBPool: c.String("lb-pool"), + Tags: tags, Logger: logger, Observer: connection.NewObserver(transportLogger, tunnelEventChan), ReportedVersion: version, @@ -293,10 +286,9 @@ func prepareTunnelConfig( ClassicTunnel: classicTunnel, MuxerConfig: muxerConfig, TunnelEventChan: tunnelEventChan, - IngressRules: ingressRules, ProtocolSelector: protocolSelector, EdgeTLSConfigs: edgeTLSConfigs, - }, nil + }, ingressRules, nil } func isRunningFromTerminal() bool { diff --git a/connection/h2mux.go b/connection/h2mux.go index 35932fea..a85a44c8 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -22,7 +22,6 @@ const ( type h2muxConnection struct { config *Config muxerConfig *MuxerConfig - originURL string muxer *h2mux.Muxer // connectionID is only used by metrics, and prometheus requires labels to be string connIndexStr string @@ -54,7 +53,6 @@ func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, logger logger.S func NewH2muxConnection(ctx context.Context, config *Config, muxerConfig *MuxerConfig, - originURL string, edgeConn net.Conn, connIndex uint8, observer *Observer, @@ -62,7 +60,6 @@ func NewH2muxConnection(ctx context.Context, h := &h2muxConnection{ config: config, muxerConfig: muxerConfig, - originURL: originURL, connIndexStr: uint8ToString(connIndex), connIndex: connIndex, observer: observer, @@ -188,7 +185,7 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error { } func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) { - req, err := http.NewRequest("GET", h.originURL, h2mux.MuxedStreamReader{MuxedStream: stream}) + req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream}) if err != nil { return nil, errors.Wrap(err, "Unexpected error from http.NewRequest") } diff --git a/connection/http2.go b/connection/http2.go index b5724d10..7145caab 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -7,7 +7,6 @@ import ( "math" "net" "net/http" - "net/url" "strings" "sync" @@ -31,7 +30,6 @@ type HTTP2Connection struct { conn net.Conn server *http2.Server config *Config - originURL *url.URL namedTunnel *NamedTunnelConfig connOptions *tunnelpogs.ConnectionOptions observer *Observer @@ -44,7 +42,6 @@ type HTTP2Connection struct { func NewHTTP2Connection( conn net.Conn, config *Config, - originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, @@ -57,7 +54,6 @@ func NewHTTP2Connection( MaxConcurrentStreams: math.MaxUint32, }, config: config, - originURL: originURL, namedTunnel: namedTunnelConfig, connOptions: connOptions, observer: observer, @@ -83,9 +79,6 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.wg.Add(1) defer c.wg.Done() - r.URL.Scheme = c.originURL.Scheme - r.URL.Host = c.originURL.Host - respWriter := &http2RespWriter{ r: r.Body, w: w, diff --git a/hello/hello.go b/hello/hello.go index b78c8f7e..fcb41821 100644 --- a/hello/hello.go +++ b/hello/hello.go @@ -22,6 +22,7 @@ const ( UptimeRoute = "/uptime" WSRoute = "/ws" SSERoute = "/sse" + HealthRoute = "/_health" defaultSSEFreq = time.Second * 10 ) @@ -114,6 +115,7 @@ func StartHelloWorldServer(logger logger.Service, listener net.Listener, shutdow muxer.HandleFunc(UptimeRoute, uptimeHandler(time.Now())) muxer.HandleFunc(WSRoute, websocketHandler(logger, upgrader)) muxer.HandleFunc(SSERoute, sseHandler(logger)) + muxer.HandleFunc(HealthRoute, healthHandler()) muxer.HandleFunc("/", rootHandler(serverName)) httpServer := &http.Server{Addr: listener.Addr().String(), Handler: muxer} go func() { @@ -221,6 +223,12 @@ func sseHandler(logger logger.Service) http.HandlerFunc { } } +func healthHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + } +} + func rootHandler(serverName string) http.HandlerFunc { responseTemplate := template.Must(template.New("index").Parse(indexTemplate)) return func(w http.ResponseWriter, r *http.Request) { diff --git a/ingress/ingress.go b/ingress/ingress.go index 613b7c6a..7e189acb 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -63,9 +63,9 @@ type Ingress struct { // NewSingleOrigin constructs an Ingress set with only one rule, constructed from // legacy CLI parameters like --url or --no-chunked-encoding. -func NewSingleOrigin(c *cli.Context, compatibilityMode bool, logger logger.Service) (Ingress, error) { +func NewSingleOrigin(c *cli.Context, allowURLFromArgs bool, logger logger.Service) (Ingress, error) { - service, err := parseSingleOriginService(c, compatibilityMode) + service, err := parseSingleOriginService(c, allowURLFromArgs) if err != nil { return Ingress{}, err } @@ -85,19 +85,15 @@ func NewSingleOrigin(c *cli.Context, compatibilityMode bool, logger logger.Servi } // Get a single origin service from the CLI/config. -func parseSingleOriginService(c *cli.Context, compatibilityMode bool) (OriginService, error) { +func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) { if c.IsSet("hello-world") { return new(helloWorld), nil } if c.IsSet("url") { - originURLStr, err := config.ValidateUrl(c, compatibilityMode) + originURL, err := config.ValidateUrl(c, allowURLFromArgs) if err != nil { return nil, errors.Wrap(err, "Error validating origin URL") } - originURL, err := url.Parse(originURLStr) - if err != nil { - return nil, errors.Wrap(err, "couldn't parse origin URL") - } return &localService{URL: originURL, RootURL: originURL}, nil } if c.IsSet("unix-socket") { diff --git a/ingress/origin_service.go b/ingress/origin_service.go index 8e194bd9..0993d202 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -245,7 +245,7 @@ type statusCode struct { func newStatusCode(status int) statusCode { resp := &http.Response{ StatusCode: status, - Status: http.StatusText(status), + Status: fmt.Sprintf("%d %s", status, http.StatusText(status)), Body: new(NopReadCloser), } return statusCode{resp: resp} diff --git a/origin/h2mux.go b/origin/h2mux.go deleted file mode 100644 index 18c2bf32..00000000 --- a/origin/h2mux.go +++ /dev/null @@ -1,229 +0,0 @@ -package origin - -import ( - "bufio" - "context" - "io" - "net" - "net/http" - "strconv" - - "github.com/cloudflare/cloudflared/buffer" - "github.com/cloudflare/cloudflared/connection" - "github.com/cloudflare/cloudflared/h2mux" - "github.com/cloudflare/cloudflared/ingress" - "github.com/cloudflare/cloudflared/logger" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - "github.com/cloudflare/cloudflared/websocket" - "github.com/pkg/errors" -) - -type TunnelHandler struct { - ingressRules ingress.Ingress - muxer *h2mux.Muxer - tags []tunnelpogs.Tag - metrics *TunnelMetrics - // connectionID is only used by metrics, and prometheus requires labels to be string - connectionID string - logger logger.Service - - bufferPool *buffer.Pool -} - -// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error -func NewTunnelHandler(ctx context.Context, - config *TunnelConfig, - addr *net.TCPAddr, - connectionID uint8, - bufferPool *buffer.Pool, -) (*TunnelHandler, string, error) { - h := &TunnelHandler{ - ingressRules: config.IngressRules, - tags: config.Tags, - metrics: config.Metrics, - connectionID: uint8ToString(connectionID), - logger: config.Logger, - bufferPool: bufferPool, - } - - edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr) - if err != nil { - return nil, "", err - } - // Establish a muxed connection with the edge - // Client mux handshake with agent server - h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, config.muxerConfig(h), h.metrics.activeStreams) - if err != nil { - return nil, "", errors.Wrap(err, "h2mux handshake with edge error") - } - return h, edgeConn.LocalAddr().String(), nil -} - -func (h *TunnelHandler) AppendTagHeaders(r *http.Request) { - for _, tag := range h.tags { - r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value) - } -} - -func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { - h.metrics.incrementRequests(h.connectionID) - defer h.metrics.decrementConcurrentRequests(h.connectionID) - - req, rule, reqErr := h.createRequest(stream) - if reqErr != nil { - h.writeErrorResponse(stream, reqErr) - return reqErr - } - - cfRay := findCfRayHeader(req) - lbProbe := isLBProbeRequest(req) - h.logRequest(req, cfRay, lbProbe) - - var resp *http.Response - var respErr error - if websocket.IsWebSocketUpgrade(req) { - resp, respErr = serveWebsocket(&h2muxWebsocketResp{stream}, req, rule) - } else { - resp, respErr = h.serveHTTP(stream, req, rule) - } - if respErr != nil { - h.writeErrorResponse(stream, respErr) - return respErr - } - h.logResponseOk(resp, cfRay, lbProbe) - return nil -} - -func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, *ingress.Rule, error) { - req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream}) - if err != nil { - return nil, nil, errors.Wrap(err, "Unexpected error from http.NewRequest") - } - err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req) - if err != nil { - return nil, nil, errors.Wrap(err, "invalid request received") - } - rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path) - rule.Service.RewriteOriginURL(req.URL) - return req, rule, nil -} - -func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) { - // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate - if rule.Config.DisableChunkedEncoding { - req.TransferEncoding = []string{"gzip", "deflate"} - cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) - if err == nil { - req.ContentLength = int64(cLength) - } - } - - // Request origin to keep connection alive to improve performance - req.Header.Set("Connection", "keep-alive") - - if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { - req.Header.Set("Host", hostHeader) - req.Host = hostHeader - } - - response, err := h.httpClient.RoundTrip(req) - if err != nil { - return nil, errors.Wrap(err, "Error proxying request to origin") - } - defer response.Body.Close() - - headers := h2mux.H1ResponseToH2ResponseHeaders(response) - headers = append(headers, h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceOrigin)) - err = stream.WriteHeaders(headers) - if err != nil { - return nil, errors.Wrap(err, "Error writing response header") - } - if h.isEventStream(response) { - h.writeEventStream(stream, response.Body) - } else { - // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream - // compression generates dictionary on first write - buf := h.bufferPool.Get() - defer h.bufferPool.Put(buf) - io.CopyBuffer(stream, response.Body, buf) - } - return response, nil -} - -func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) { - reader := bufio.NewReader(responseBody) - for { - line, err := reader.ReadBytes('\n') - if err != nil { - break - } - stream.Write(line) - } -} - -func (h *TunnelHandler) isEventStream(response *http.Response) bool { - if response.Header.Get("content-type") == "text/event-stream" { - h.logger.Debug("Detected Server-Side Events from Origin") - return true - } - return false -} - -func (h *TunnelHandler) writeErrorResponse(stream *h2mux.MuxedStream, err error) { - h.logger.Errorf("HTTP request error: %s", err) - stream.WriteHeaders([]h2mux.Header{ - {Name: ":status", Value: "502"}, - h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceCloudflared), - }) - stream.Write([]byte("502 Bad Gateway")) - h.metrics.incrementResponses(h.connectionID, "502") -} - -func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool) { - logger := h.logger - if cfRay != "" { - logger.Debugf("CF-RAY: %s %s %s %s", cfRay, req.Method, req.URL, req.Proto) - } else if lbProbe { - logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, req.Method, req.URL, req.Proto) - } else { - logger.Infof("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, req.Method, req.URL, req.Proto) - } - logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, req.Header) - - if contentLen := req.ContentLength; contentLen == -1 { - logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay) - } else { - logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen) - } -} - -func (h *TunnelHandler) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { - h.metrics.incrementResponses(h.connectionID, "200") - logger := h.logger - if cfRay != "" { - logger.Debugf("CF-RAY: %s %s", cfRay, r.Status) - } else if lbProbe { - logger.Debugf("Response to Load Balancer health check %s", r.Status) - } else { - logger.Infof("%s", r.Status) - } - logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) - - if contentLen := r.ContentLength; contentLen == -1 { - logger.Debugf("CF-RAY: %s Response content length unknown", cfRay) - } else { - logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen) - } -} - -func (h *TunnelHandler) UpdateMetrics(connectionID string) { - h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics()) -} - -type h2muxWebsocketResp struct { - *h2mux.MuxedStream -} - -func (wr *h2muxWebsocketResp) WriteRespHeaders(resp *http.Response) error { - return wr.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp)) -} diff --git a/origin/http2.go b/origin/http2.go deleted file mode 100644 index df77fd2f..00000000 --- a/origin/http2.go +++ /dev/null @@ -1,320 +0,0 @@ -package origin - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "strconv" - "strings" - - "github.com/cloudflare/cloudflared/h2mux" - "github.com/cloudflare/cloudflared/ingress" - "github.com/cloudflare/cloudflared/logger" - "github.com/cloudflare/cloudflared/tunnelrpc" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - - "github.com/pkg/errors" - "golang.org/x/net/http2" - "zombiezen.com/go/capnproto2/rpc" -) - -const ( - internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade" - websocketUpgrade = "websocket" - controlPlaneUpgrade = "control-plane" -) - -type http2Server struct { - server *http2.Server - ingressRules ingress.Ingress - logger logger.Service - connIndexStr string - connIndex uint8 - config *TunnelConfig - localAddr net.Addr - shutdownChan chan struct{} - connectedFuse *h2mux.BooleanFuse -} - -func newHTTP2Server(config *TunnelConfig, connIndex uint8, localAddr net.Addr, connectedFuse *h2mux.BooleanFuse) (*http2Server, error) { - return &http2Server{ - server: &http2.Server{}, - ingressRules: config.IngressRules, - logger: config.Logger, - connIndexStr: uint8ToString(connIndex), - connIndex: connIndex, - config: config, - localAddr: localAddr, - shutdownChan: make(chan struct{}), - connectedFuse: connectedFuse, - }, nil -} - -func (c *http2Server) serve(ctx context.Context, conn net.Conn) { - go func() { - <-ctx.Done() - c.close(conn) - }() - c.server.ServeConn(conn, &http2.ServeConnOpts{ - Context: ctx, - Handler: c, - }) -} - -func (c *http2Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - c.config.Metrics.incrementRequests(c.connIndexStr) - defer c.config.Metrics.decrementConcurrentRequests(c.connIndexStr) - - cfRay := findCfRayHeader(r) - lbProbe := isLBProbeRequest(r) - c.logRequest(r, cfRay, lbProbe) - - rule, _ := c.ingressRules.FindMatchingRule(r.Host, r.URL.Path) - rule.Service.RewriteOriginURL(r.URL) - - var resp *http.Response - var err error - - if isControlPlaneUpgrade(r) { - stripWebsocketUpgradeHeader(r) - err = c.serveControlPlane(w, r) - } else if isWebsocketUpgrade(r) { - stripWebsocketUpgradeHeader(r) - var respBody BidirectionalStream - respBody, err = newHTTP2Stream(w, r) - if err == nil { - resp, err = serveWebsocket(respBody, r, rule) - } - } else { - resp, err = c.serveHTTP(w, r, rule) - } - - if err != nil { - c.writeErrorResponse(w, err) - return - } - if resp != nil { - resp.Body.Close() - } -} - -func (c *http2Server) serveHTTP(w http.ResponseWriter, r *http.Request, rule *ingress.Rule) (*http.Response, error) { - // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate - if rule.Config.DisableChunkedEncoding { - r.TransferEncoding = []string{"gzip", "deflate"} - cLength, err := strconv.Atoi(r.Header.Get("Content-Length")) - if err == nil { - r.ContentLength = int64(cLength) - } - } - - // Request origin to keep connection alive to improve performance - r.Header.Set("Connection", "keep-alive") - - if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { - r.Header.Set("Host", hostHeader) - r.Host = hostHeader - } - - resp, err := rule.HTTPTransport.RoundTrip(r) - if err != nil { - return nil, errors.Wrap(err, "Error proxying request to origin") - } - w.WriteHeader(resp.StatusCode) - _, err = io.Copy(w, resp.Body) - if err != nil { - return nil, errors.Wrap(err, "Copy response error") - } - return resp, nil -} - -func (c *http2Server) serveControlPlane(w http.ResponseWriter, r *http.Request) error { - stream, err := newHTTP2Stream(w, r) - if err != nil { - return err - } - - rpcTransport := tunnelrpc.NewTransportLogger(c.logger, rpc.StreamTransport(stream)) - rpcConn := rpc.NewConn( - rpcTransport, - tunnelrpc.ConnLog(c.logger), - ) - rpcClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(r.Context()), Conn: rpcConn} - - if err = c.registerConnection(r.Context(), rpcClient, 0); err != nil { - return err - } - c.connectedFuse.Fuse(true) - - <-c.shutdownChan - c.gracefulShutdown(rpcClient) - - // Closing the client will also close the connection - rpcClient.Close() - rpcTransport.Close() - close(c.shutdownChan) - return nil -} - -func (c *http2Server) registerConnection( - ctx context.Context, - rpcClient tunnelpogs.TunnelServer_PogsClient, - numPreviousAttempts uint8, -) error { - connDetail, err := rpcClient.RegisterConnection( - ctx, - c.config.NamedTunnel.Auth, - c.config.NamedTunnel.ID, - c.connIndex, - c.config.ConnectionOptions(c.localAddr.String(), numPreviousAttempts), - ) - if err != nil { - c.logger.Errorf("Cannot register connection, err: %v", err) - return err - } - c.logger.Infof("Connection %s registered with %s using ID %s", c.connIndexStr, connDetail.Location, connDetail.UUID) - return nil -} - -func (c *http2Server) gracefulShutdown(rpcClient tunnelpogs.TunnelServer_PogsClient) { - ctx, cancel := context.WithTimeout(context.Background(), c.config.GracePeriod) - defer cancel() - err := rpcClient.UnregisterConnection(ctx) - if err != nil { - c.logger.Errorf("Cannot unregister connection gracefully, err: %v", err) - return - } - c.logger.Info("Sent graceful shutdown signal") - - <-ctx.Done() -} - -func (c *http2Server) writeErrorResponse(w http.ResponseWriter, err error) { - c.logger.Errorf("HTTP request error: %s", err) - c.config.Metrics.incrementResponses(c.connIndexStr, "502") - jsonResponseMetaHeader, err := json.Marshal(h2mux.ResponseMetaHeader{Source: h2mux.ResponseSourceCloudflared}) - if err != nil { - panic(err) - } - w.Header().Set(h2mux.ResponseMetaHeaderField, string(jsonResponseMetaHeader)) - w.WriteHeader(http.StatusBadGateway) -} - -func (c *http2Server) logRequest(r *http.Request, cfRay string, lbProbe bool) { - logger := c.logger - if cfRay != "" { - logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) - } else if lbProbe { - logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto) - } else { - logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto) - } - logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, r.Header) - - if contentLen := r.ContentLength; contentLen == -1 { - logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay) - } else { - logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen) - } -} - -func (c *http2Server) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { - c.config.Metrics.incrementResponses(c.connIndexStr, "200") - logger := c.logger - if cfRay != "" { - logger.Debugf("CF-RAY: %s %s", cfRay, r.Status) - } else if lbProbe { - logger.Debugf("Response to Load Balancer health check %s", r.Status) - } else { - logger.Infof("%s", r.Status) - } - logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) - - if contentLen := r.ContentLength; contentLen == -1 { - logger.Debugf("CF-RAY: %s Response content length unknown", cfRay) - } else { - logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen) - } -} - -func (c *http2Server) close(conn net.Conn) { - // Send signal to control loop to start graceful shutdown - c.shutdownChan <- struct{}{} - // Wait for control loop to close channel - <-c.shutdownChan - conn.Close() -} - -type http2Stream struct { - r io.Reader - w http.ResponseWriter - flusher http.Flusher -} - -func newHTTP2Stream(w http.ResponseWriter, r *http.Request) (*http2Stream, error) { - flusher, ok := w.(http.Flusher) - if !ok { - return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher") - } - return &http2Stream{r: r.Body, w: w, flusher: flusher}, nil -} - -func (wr *http2Stream) WriteRespHeaders(resp *http.Response) error { - dest := wr.w.Header() - userHeaders := make(http.Header, len(resp.Header)) - for header, values := range resp.Header { - // Since these are http2 headers, they're required to be lowercase - h2name := strings.ToLower(header) - for _, v := range values { - if h2name == "content-length" { - // This header has meaning in HTTP/2 and will be used by the edge, - // so it should be sent as an HTTP/2 response header. - dest.Add(h2name, v) - // Since these are http2 headers, they're required to be lowercase - } else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) { - // User headers, on the other hand, must all be serialized so that - // HTTP/2 header validation won't be applied to HTTP/1 header values - userHeaders.Add(h2name, v) - } - } - } - - // Perform user header serialization and set them in the single header - dest.Set(h2mux.ResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders)) - // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1 - wr.w.WriteHeader(http.StatusOK) - wr.flusher.Flush() - return nil -} - -func (wr *http2Stream) Read(p []byte) (n int, err error) { - return wr.r.Read(p) -} - -func (wr *http2Stream) Write(p []byte) (n int, err error) { - n, err = wr.w.Write(p) - if err != nil { - return 0, err - } - wr.flusher.Flush() - return -} - -func (wr *http2Stream) Close() error { - return nil -} - -func isControlPlaneUpgrade(r *http.Request) bool { - return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == controlPlaneUpgrade -} - -func isWebsocketUpgrade(r *http.Request) bool { - return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == websocketUpgrade -} - -func stripWebsocketUpgradeHeader(r *http.Request) { - r.Header.Del(internalUpgradeHeader) -} diff --git a/origin/metrics.go b/origin/metrics.go index edf2cab6..b1064803 100644 --- a/origin/metrics.go +++ b/origin/metrics.go @@ -34,6 +34,14 @@ var ( }, []string{"status_code"}, ) + requestErrors = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: connection.MetricsNamespace, + Subsystem: connection.TunnelSubsystem, + Name: "request_errors", + Help: "Count of error proxying to origin", + }, + ) haConnections = prometheus.NewGauge( prometheus.GaugeOpts{ Namespace: connection.MetricsNamespace, @@ -49,6 +57,7 @@ func init() { totalRequests, concurrentRequests, responseByCode, + requestErrors, haConnections, ) } diff --git a/origin/proxy.go b/origin/proxy.go index 9c52ab24..589d601a 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -3,15 +3,15 @@ package origin import ( "bufio" "context" - "crypto/tls" + "fmt" "io" "net/http" - "net/url" "strconv" "strings" "github.com/cloudflare/cloudflared/buffer" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/websocket" @@ -23,28 +23,21 @@ const ( ) type client struct { - config *ProxyConfig - logger logger.Service - bufferPool *buffer.Pool + ingressRules ingress.Ingress + tags []tunnelpogs.Tag + logger logger.Service + bufferPool *buffer.Pool } -func NewClient(config *ProxyConfig, logger logger.Service) connection.OriginClient { +func NewClient(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, logger logger.Service) connection.OriginClient { return &client{ - config: config, - logger: logger, - bufferPool: buffer.NewPool(512 * 1024), + ingressRules: ingressRules, + tags: tags, + logger: logger, + bufferPool: buffer.NewPool(512 * 1024), } } -type ProxyConfig struct { - Client http.RoundTripper - URL *url.URL - TLSConfig *tls.Config - HostHeader string - NoChunkedEncoding bool - Tags []tunnelpogs.Tag -} - func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error { incrementRequests() defer decrementConcurrentRequests() @@ -53,29 +46,30 @@ func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsock lbProbe := isLBProbeRequest(req) c.appendTagHeaders(req) - c.logRequest(req, cfRay, lbProbe) + rule, ruleNum := c.ingressRules.FindMatchingRule(req.Host, req.URL.Path) + c.logRequest(req, cfRay, lbProbe, ruleNum) + var ( resp *http.Response err error ) if isWebsocket { - resp, err = c.proxyWebsocket(w, req) + resp, err = c.proxyWebsocket(w, req, rule) } else { - resp, err = c.proxyHTTP(w, req) + resp, err = c.proxyHTTP(w, req, rule) } if err != nil { - c.logger.Errorf("HTTP request error: %s", err) - responseByCode.WithLabelValues("502").Inc() + c.logRequestError(err, cfRay, ruleNum) w.WriteErrorResponse(err) return err } - c.logResponseOk(resp, cfRay, lbProbe) + c.logOriginResponse(resp, cfRay, lbProbe, ruleNum) return nil } -func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*http.Response, error) { +func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) { // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate - if c.config.NoChunkedEncoding { + if rule.Config.DisableChunkedEncoding { req.TransferEncoding = []string{"gzip", "deflate"} cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) if err == nil { @@ -86,9 +80,12 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*htt // Request origin to keep connection alive to improve performance req.Header.Set("Connection", "keep-alive") - c.setHostHeader(req) + if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { + req.Header.Set("Host", hostHeader) + req.Host = hostHeader + } - resp, err := c.config.Client.RoundTrip(req) + resp, err := rule.Service.RoundTrip(req) if err != nil { return nil, errors.Wrap(err, "Error proxying request to origin") } @@ -111,9 +108,17 @@ func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*htt return resp, nil } -func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request) (*http.Response, error) { - c.setHostHeader(req) - conn, resp, err := websocket.ClientConnect(req, c.config.TLSConfig) +func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule) (*http.Response, error) { + if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { + req.Header.Set("Host", hostHeader) + req.Host = hostHeader + } + + dialler, ok := rule.Service.(websocket.Dialler) + if !ok { + return nil, fmt.Errorf("Websockets aren't supported by the origin service '%s'", rule.Service) + } + conn, resp, err := websocket.ClientConnect(req, dialler) if err != nil { return nil, err } @@ -145,28 +150,22 @@ func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadC } } -func (c *client) setHostHeader(req *http.Request) { - if c.config.HostHeader != "" { - req.Header.Set("Host", c.config.HostHeader) - req.Host = c.config.HostHeader - } -} - func (c *client) appendTagHeaders(r *http.Request) { - for _, tag := range c.config.Tags { + for _, tag := range c.tags { r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value) } } -func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool) { +func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool, ruleNum int) { if cfRay != "" { c.logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) } else if lbProbe { c.logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto) } else { - c.logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto) + c.logger.Debugf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto) } c.logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, r.Header) + c.logger.Debugf("CF-RAY: %s Serving with ingress rule %d", cfRay, ruleNum) if contentLen := r.ContentLength; contentLen == -1 { c.logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay) @@ -175,14 +174,14 @@ func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool) { } } -func (c *client) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { - responseByCode.WithLabelValues("200").Inc() +func (c *client) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, ruleNum int) { + responseByCode.WithLabelValues(strconv.Itoa(r.StatusCode)).Inc() if cfRay != "" { - c.logger.Debugf("CF-RAY: %s %s", cfRay, r.Status) + c.logger.Infof("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, ruleNum) } else if lbProbe { c.logger.Debugf("Response to Load Balancer health check %s", r.Status) } else { - c.logger.Infof("%s", r.Status) + c.logger.Debugf("Status: %s served by ingress %d", r.Status, ruleNum) } c.logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) @@ -193,6 +192,16 @@ func (c *client) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { } } +func (c *client) logRequestError(err error, cfRay string, ruleNum int) { + requestErrors.Inc() + if cfRay != "" { + c.logger.Errorf("CF-RAY: %s Proxying to ingress %d error: %v", cfRay, ruleNum, err) + } else { + c.logger.Errorf("Proxying to ingress %d error: %v", ruleNum, err) + } + +} + func findCfRayHeader(req *http.Request) string { return req.Header.Get("Cf-Ray") } diff --git a/origin/proxy_test.go b/origin/proxy_test.go index 9a286a3d..09b02ea0 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -3,27 +3,32 @@ package origin import ( "bytes" "context" - "crypto/tls" - "crypto/x509" + "flag" "fmt" "io" "net/http" "net/http/httptest" - "net/url" "sync" "testing" "time" + "github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/hello" + "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" - "github.com/cloudflare/cloudflared/tlsconfig" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/urfave/cli/v2" "github.com/gobwas/ws/wsutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +var ( + testTags = []tunnelpogs.Tag(nil) +) + type mockHTTPRespWriter struct { *httptest.ResponseRecorder } @@ -99,49 +104,39 @@ func (w *mockSSERespWriter) ReadBytes() []byte { return <-w.writeNotification } -func TestProxy(t *testing.T) { +func TestProxySingleOrigin(t *testing.T) { logger, err := logger.New() require.NoError(t, err) - // let runtime pick an available port - listener, err := hello.CreateTLSListener("127.0.0.1:0") - require.NoError(t, err) - - originURL := &url.URL{ - Scheme: "https", - Host: listener.Addr().String(), - } - originCA := x509.NewCertPool() - helloCert, err := tlsconfig.GetHelloCertificateX509() - require.NoError(t, err) - originCA.AddCert(helloCert) - clientTLS := &tls.Config{ - RootCAs: originCA, - } - proxyConfig := &ProxyConfig{ - Client: &http.Transport{ - TLSClientConfig: clientTLS, - }, - URL: originURL, - TLSConfig: clientTLS, - } ctx, cancel := context.WithCancel(context.Background()) - go func() { - hello.StartHelloWorldServer(logger, listener, ctx.Done()) - }() + flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError) + flagSet.Bool("hello-world", true, "") - client := NewClient(proxyConfig, logger) - t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL)) - t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS)) - t.Run("testProxySSE", testProxySSE(t, client, originURL)) + cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil) + err = cliCtx.Set("hello-world", "true") + require.NoError(t, err) + + allowURLFromArgs := false + ingressRule, err := ingress.NewSingleOrigin(cliCtx, allowURLFromArgs, logger) + require.NoError(t, err) + + var wg sync.WaitGroup + errC := make(chan error) + ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) + + client := NewClient(ingressRule, testTags, logger) + t.Run("testProxyHTTP", testProxyHTTP(t, client)) + t.Run("testProxyWebsocket", testProxyWebsocket(t, client)) + t.Run("testProxySSE", testProxySSE(t, client)) cancel() + wg.Wait() } -func testProxyHTTP(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) { +func testProxyHTTP(t *testing.T, client connection.OriginClient) func(t *testing.T) { return func(t *testing.T) { respWriter := newMockHTTPRespWriter() - req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) + req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) require.NoError(t, err) err = client.Proxy(respWriter, req, false) @@ -151,11 +146,11 @@ func testProxyHTTP(t *testing.T, client connection.OriginClient, originURL *url. } } -func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL *url.URL, tlsConfig *tls.Config) func(t *testing.T) { +func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *testing.T) { return func(t *testing.T) { // WSRoute is a websocket echo handler ctx, cancel := context.WithCancel(context.Background()) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), nil) readPipe, writePipe := io.Pipe() respWriter := newMockWSRespWriter(readPipe) @@ -191,7 +186,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL } } -func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) { +func testProxySSE(t *testing.T, client connection.OriginClient) func(t *testing.T) { return func(t *testing.T) { var ( pushCount = 50 @@ -199,7 +194,7 @@ func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.U ) respWriter := newMockSSERespWriter() ctx, cancel := context.WithCancel(context.Background()) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s?freq=%s", originURL, hello.SSERoute, pushFreq), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s?freq=%s", hello.SSERoute, pushFreq), nil) require.NoError(t, err) var wg sync.WaitGroup @@ -225,3 +220,98 @@ func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.U wg.Wait() } } + +func TestProxyMultipleOrigins(t *testing.T) { + api := httptest.NewServer(mockAPI{}) + defer api.Close() + + unvalidatedIngress := []config.UnvalidatedIngressRule{ + { + Hostname: "api.example.com", + Service: api.URL, + }, + { + Hostname: "hello.example.com", + Service: "hello-world", + }, + { + Hostname: "health.example.com", + Path: "/health", + Service: "http_status:200", + }, + { + Hostname: "*", + Service: "http_status:404", + }, + } + + ingress, err := ingress.ParseIngress(&config.Configuration{ + TunnelID: t.Name(), + Ingress: unvalidatedIngress, + }) + require.NoError(t, err) + + logger, err := logger.New() + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + errC := make(chan error) + var wg sync.WaitGroup + ingress.StartOrigins(&wg, logger, ctx.Done(), errC) + + client := NewClient(ingress, testTags, logger) + + tests := []struct { + url string + expectedStatus int + expectedBody []byte + }{ + { + url: "http://api.example.com", + expectedStatus: http.StatusCreated, + expectedBody: []byte("Created"), + }, + { + url: fmt.Sprintf("http://hello.example.com%s", hello.HealthRoute), + expectedStatus: http.StatusOK, + expectedBody: []byte("ok"), + }, + { + url: "http://health.example.com/health", + expectedStatus: http.StatusOK, + }, + { + url: "http://health.example.com/", + expectedStatus: http.StatusNotFound, + }, + { + url: "http://not-found.example.com", + expectedStatus: http.StatusNotFound, + }, + } + + for _, test := range tests { + respWriter := newMockHTTPRespWriter() + req, err := http.NewRequest(http.MethodGet, test.url, nil) + require.NoError(t, err) + + err = client.Proxy(respWriter, req, false) + require.NoError(t, err) + + assert.Equal(t, test.expectedStatus, respWriter.Code) + if test.expectedBody != nil { + assert.Equal(t, test.expectedBody, respWriter.Body.Bytes()) + } else { + assert.Equal(t, 0, respWriter.Body.Len()) + } + } + cancel() + wg.Wait() +} + +type mockAPI struct{} + +func (ma mockAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + w.Write([]byte("Created")) +} diff --git a/origin/tunnel.go b/origin/tunnel.go index 75e01f33..211ab8e4 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -20,7 +20,6 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" - "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/tunnelrpc" @@ -47,7 +46,6 @@ const ( type TunnelConfig struct { ConnectionConfig *connection.Config - ProxyConfig *ProxyConfig BuildInfo *buildinfo.BuildInfo ClientID string CloseConnOnce *sync.Once // Used to close connectedSignal no more than once @@ -57,6 +55,7 @@ type TunnelConfig struct { IsAutoupdated bool IsFreeTunnel bool LBPool string + Tags []tunnelpogs.Tag Logger logger.Service Observer *connection.Observer ReportedVersion string @@ -67,7 +66,6 @@ type TunnelConfig struct { ClassicTunnel *connection.ClassicTunnelConfig MuxerConfig *connection.MuxerConfig TunnelEventChan chan ui.TunnelEvent - IngressRules ingress.Ingress ProtocolSelector connection.ProtocolSelector EdgeTLSConfigs map[connection.Protocol]*tls.Config } @@ -113,7 +111,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str OS: fmt.Sprintf("%s_%s", c.BuildInfo.GoOS, c.BuildInfo.GoArch), ExistingTunnelPolicy: policy, PoolName: c.LBPool, - Tags: c.ProxyConfig.Tags, + Tags: c.Tags, ConnectionID: connectionID, OriginLocalIP: OriginLocalIP, IsAutoupdated: c.IsAutoupdated, @@ -324,7 +322,7 @@ func ServeH2mux( ) (err error, recoverable bool) { config.Logger.Debugf("Connecting via h2mux") // Returns error from parsing the origin URL or handshake errors - handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer) + handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, edgeConn, connectionIndex, config.Observer) if err != nil { return err, recoverable } @@ -388,7 +386,7 @@ func ServeHTTP2( reconnectCh chan ReconnectSignal, ) (err error, recoverable bool) { config.Logger.Debugf("Connecting via http2") - server := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse) + server := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse) errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { diff --git a/validation/validation.go b/validation/validation.go index e22d206c..f8a4e5b8 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -166,7 +166,12 @@ func validateIP(scheme, host, port string) (string, error) { } // originURL shouldn't be a pointer, because this function might change the scheme -func ValidateHTTPService(originURL url.URL, hostname string, transport http.RoundTripper) error { +func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error { + parsedURL, err := url.Parse(originURL) + if err != nil { + return err + } + client := &http.Client{ Transport: transport, CheckRedirect: func(req *http.Request, via []*http.Request) error { @@ -175,7 +180,7 @@ func ValidateHTTPService(originURL url.URL, hostname string, transport http.Roun Timeout: validationTimeout, } - initialRequest, err := http.NewRequest("GET", originURL.String(), nil) + initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil) if err != nil { return err } @@ -187,10 +192,10 @@ func ValidateHTTPService(originURL url.URL, hostname string, transport http.Roun } // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck? - oldScheme := originURL.Scheme - originURL.Scheme = toggleProtocol(originURL.Scheme) + oldScheme := parsedURL.Scheme + parsedURL.Scheme = toggleProtocol(oldScheme) - secondRequest, err := http.NewRequest("GET", originURL.String(), nil) + secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil) if err != nil { return err } @@ -200,9 +205,9 @@ func ValidateHTTPService(originURL url.URL, hostname string, transport http.Roun resp.Body.Close() return errors.Errorf( "%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %v", - originURL.Host, + parsedURL.Host, oldScheme, - originURL.Scheme, + parsedURL.Scheme, initialErr, originURL, ) diff --git a/validation/validation_test.go b/validation/validation_test.go index 0745b085..d740aff7 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -123,7 +123,7 @@ func TestToggleProtocol(t *testing.T) { // Happy path 1: originURL is HTTP, and HTTP connections work func TestValidateHTTPService_HTTP2HTTP(t *testing.T) { - originURL := mustParse(t, "http://127.0.0.1/") + originURL := "http://127.0.0.1/" hostname := "example.com" assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { @@ -151,7 +151,7 @@ func TestValidateHTTPService_HTTP2HTTP(t *testing.T) { // Happy path 2: originURL is HTTPS, and HTTPS connections work func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) { - originURL := mustParse(t, "https://127.0.0.1:1234/") + originURL := "https://127.0.0.1:1234/" hostname := "example.com" assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { @@ -179,7 +179,7 @@ func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) { // Error path 1: originURL is HTTPS, but HTTP connections work func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) { - originURL := mustParse(t, "https://127.0.0.1:1234/") + originURL := "https://127.0.0.1:1234/" hostname := "example.com" assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { @@ -207,13 +207,10 @@ func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) { // Error path 2: originURL is HTTP, but HTTPS connections work func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { - originURLWithPort := url.URL{ - Scheme: "http", - Host: "127.0.0.1:1234", - } + originURL := "http://127.0.0.1:1234/" hostname := "example.com" - assert.Error(t, ValidateHTTPService(originURLWithPort, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { + assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Equal(t, req.Host, hostname) if req.URL.Scheme == "http" { return nil, assert.AnError @@ -224,7 +221,7 @@ func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) { panic("Shouldn't reach here") }))) - assert.Error(t, ValidateHTTPService(originURLWithPort, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { + assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) { assert.Equal(t, req.Host, hostname) if req.URL.Scheme == "http" { return nil, assert.AnError @@ -253,14 +250,12 @@ func TestValidateHTTPService_NoFollowRedirects(t *testing.T) { })) assert.NoError(t, err) defer redirectServer.Close() - redirectServerURL, err := url.Parse(redirectServer.URL) - assert.NoError(t, err) - assert.NoError(t, ValidateHTTPService(*redirectServerURL, hostname, redirectClient.Transport)) + assert.NoError(t, ValidateHTTPService(redirectServer.URL, hostname, redirectClient.Transport)) } // Ensure validation times out when origin URL is nonresponsive func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) { - originURL := mustParse(t, "http://127.0.0.1/") + originURL := "http://127.0.0.1/" hostname := "example.com" oldValidationTimeout := validationTimeout defer func() { @@ -376,9 +371,3 @@ func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *h return server, client, nil } - -func mustParse(t *testing.T, originURL string) url.URL { - parsedURL, err := url.Parse(originURL) - assert.NoError(t, err) - return *parsedURL -}