From 8d7b2575ba6277fb736f12aa63393831166d8c4b Mon Sep 17 00:00:00 2001 From: cthuang Date: Fri, 11 Sep 2020 23:02:34 +0100 Subject: [PATCH] TUN-3400: Use Go HTTP2 library as transport to connect with the edge --- cmd/cloudflared/tunnel/configuration.go | 5 +- connection/dial.go | 3 +- origin/connection.go | 14 +++ origin/server.go | 160 ++++++++++++++++++++++++ origin/tunnel.go | 120 ++++++++++++++++-- tlsconfig/certreloader.go | 13 +- validation/validation.go | 46 +++---- 7 files changed, 324 insertions(+), 37 deletions(-) create mode 100644 origin/connection.go create mode 100644 origin/server.go diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index bb033835..c5e4978c 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -156,7 +156,8 @@ func prepareTunnelConfig( transportLogger logger.Service, namedTunnel *origin.NamedTunnelConfig, ) (*origin.TunnelConfig, error) { - compatibilityMode := namedTunnel == nil + isNamedTunnel := namedTunnel != nil + compatibilityMode := !isNamedTunnel hostname, err := validation.ValidateHostname(c.String("hostname")) if err != nil { @@ -219,7 +220,7 @@ func prepareTunnelConfig( } } - toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c) + toEdgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, isNamedTunnel) if err != nil { logger.Errorf("unable to create TLS config to connect with edge: %s", err) return nil, errors.Wrap(err, "unable to create TLS config to connect with edge") diff --git a/connection/dial.go b/connection/dial.go index 9363700c..4651afd9 100644 --- a/connection/dial.go +++ b/connection/dial.go @@ -9,7 +9,7 @@ import ( "github.com/pkg/errors" ) -// DialEdge makes a TLS connection to a Cloudflare edge node +// DialEdgeWithH2Mux makes a TLS connection to a Cloudflare edge node func DialEdge( ctx context.Context, timeout time.Duration, @@ -25,6 +25,7 @@ func DialEdge( if err != nil { return nil, newDialError(err, "DialContext error") } + tlsEdgeConn := tls.Client(edgeConn, tlsConfig) tlsEdgeConn.SetDeadline(time.Now().Add(timeout)) diff --git a/origin/connection.go b/origin/connection.go new file mode 100644 index 00000000..11a930d7 --- /dev/null +++ b/origin/connection.go @@ -0,0 +1,14 @@ +package origin + +import ( + "net" +) + +// persistentTCPConn is a wrapper around net.Conn that is noop when Close is called +type persistentConn struct { + net.Conn +} + +func (pc *persistentConn) Close() error { + return nil +} diff --git a/origin/server.go b/origin/server.go new file mode 100644 index 00000000..a5b1e89d --- /dev/null +++ b/origin/server.go @@ -0,0 +1,160 @@ +package origin + +import ( + "context" + "encoding/json" + "io" + "net" + "net/http" + "net/url" + "strings" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/logger" + "golang.org/x/net/http2" +) + +type cfdServer struct { + httpServer *http2.Server + originClient http.RoundTripper + logger logger.Service + originURL *url.URL + connectionIndex string + config *TunnelConfig +} + +func (c *cfdServer) serve(ctx context.Context, conn net.Conn) { + go func() { + <-ctx.Done() + conn.Close() + }() + c.httpServer.ServeConn(conn, &http2.ServeConnOpts{ + Context: ctx, + Handler: c, + }) +} + +func (c *cfdServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.config.Metrics.incrementRequests(c.connectionIndex) + defer c.config.Metrics.decrementConcurrentRequests(c.connectionIndex) + + cfRay := findCfRayHeader(r) + lbProbe := isLBProbeRequest(r) + c.logRequest(r, cfRay, lbProbe) + + r.URL = c.originURL + // TODO: TUN-3406 support websocket, event stream and WSGI servers. + var resp *http.Response + var err error + if strings.ToLower(r.Header.Get("Cf-Int-Argo-Tunnel-Upgrade")) == "websocket" { + resp, err = serveWebsocket(newWebsocketBody(w, r, c.logger), r, c.config.HTTPHostHeader, c.config.ClientTlsConfig) + } else { + resp, err = c.originClient.RoundTrip(r) + } + if err != nil { + c.writeErrorResponse(w, err) + return + } + defer resp.Body.Close() + + w.WriteHeader(resp.StatusCode) + _, err = io.Copy(w, resp.Body) + if err != nil { + c.logger.Errorf("Copy response error, err: %v", err) + w.WriteHeader(http.StatusBadGateway) + return + } +} + +func (c *cfdServer) writeErrorResponse(w http.ResponseWriter, err error) { + c.logger.Errorf("HTTP request error: %s", err) + c.config.Metrics.incrementResponses(c.connectionIndex, "502") + jsonResponseMetaHeader, err := json.Marshal(h2mux.ResponseMetaHeader{Source: h2mux.ResponseSourceCloudflared}) + if err != nil { + panic(err) + } + w.Header().Set(h2mux.ResponseMetaHeaderField, string(jsonResponseMetaHeader)) + w.WriteHeader(http.StatusBadGateway) +} + +func (c *cfdServer) logRequest(r *http.Request, cfRay string, lbProbe bool) { + logger := c.logger + if cfRay != "" { + logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) + } else if lbProbe { + logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto) + } else { + logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto) + } + logger.Infof("CF-RAY: %s Request Headers %+v", cfRay, r.Header) + + if contentLen := r.ContentLength; contentLen == -1 { + logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay) + } else { + logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen) + } +} + +func (c *cfdServer) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { + c.config.Metrics.incrementResponses(c.connectionIndex, "200") + logger := c.logger + if cfRay != "" { + logger.Debugf("CF-RAY: %s %s", cfRay, r.Status) + } else if lbProbe { + logger.Debugf("Response to Load Balancer health check %s", r.Status) + } else { + logger.Infof("%s", r.Status) + } + logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) + + if contentLen := r.ContentLength; contentLen == -1 { + logger.Debugf("CF-RAY: %s Response content length unknown", cfRay) + } else { + logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen) + } +} + +type WebsocketResp interface { + WriteRespHeaders(*http.Response) error + io.ReadWriter +} + +type http2WebsocketResp struct { + pr *io.PipeReader + w http.ResponseWriter +} + +func newWebsocketBody(w http.ResponseWriter, r *http.Request, logger logger.Service) *http2WebsocketResp { + pr, pw := io.Pipe() + go func() { + n, err := io.Copy(pw, r.Body) + logger.Errorf("websocket body copy ended, err: %v, bytes: %d", err, n) + }() + return &http2WebsocketResp{pr: pr, w: w} +} + +func (wr *http2WebsocketResp) WriteRespHeaders(resp *http.Response) error { + dest := wr.w.Header() + for name, values := range resp.Header { + for _, v := range values { + dest.Add(name, v) + } + } + return nil +} + +func (wr *http2WebsocketResp) Read(p []byte) (n int, err error) { + return wr.pr.Read(p) +} + +func (wr *http2WebsocketResp) Write(p []byte) (n int, err error) { + return wr.w.Write(p) +} + +type h2muxWebsocketResp struct { + *h2mux.MuxedStream +} + +func (wr *h2muxWebsocketResp) WriteRespHeaders(resp *http.Response) error { + return wr.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp)) +} diff --git a/origin/tunnel.go b/origin/tunnel.go index f15cd3b0..f94937d5 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -17,7 +17,9 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" + "golang.org/x/net/http2" "golang.org/x/sync/errgroup" + "zombiezen.com/go/capnproto2/rpc" "github.com/cloudflare/cloudflared/buffer" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" @@ -30,6 +32,7 @@ import ( "github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/validation" "github.com/cloudflare/cloudflared/websocket" ) @@ -304,7 +307,14 @@ func ServeTunnel( connectionTag := uint8ToString(connectionIndex) if config.NamedTunnel != nil && config.NamedTunnel.Protocol == http2Protocol { - return ServeNamedTunnel(ctx, config, connectionIndex, addr, connectedFuse, reconnectCh) + tlsConn, err := RegisterConnection(ctx, config, connectionIndex, uint8(backoff.retries), addr) + if err != nil { + logger.Errorf("Register connectio error: %+v", err) + return err, true + } + connectedFuse.Fuse(true) + backoff.SetGracePeriod() + return serveNamedTunnel(ctx, config, tlsConn, connectionIndex, reconnectCh) } // Returns error from parsing the origin URL or handshake errors @@ -332,10 +342,6 @@ func ServeTunnel( } }() - if config.NamedTunnel != nil { - return RegisterConnection(ctx, handler.muxer, config, connectionIndex, originLocalAddr, uint8(backoff.retries)) - } - if config.UseReconnectToken && connectedFuse.Value() { err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager) if err == nil { @@ -426,7 +432,55 @@ func ServeTunnel( return nil, true } -func RegisterConnection( +func serveNamedTunnel( + ctx context.Context, + config *TunnelConfig, + tlsConn net.Conn, + connectionIndex uint8, + reconnectCh chan ReconnectSignal, +) (err error, recoverable bool) { + originURLStr, err := validation.ValidateUrl(config.OriginUrl) + if err != nil { + return fmt.Errorf("unable to parse origin URL %#v", config.OriginUrl), false + } + originURL, err := url.Parse(originURLStr) + if err != nil { + return fmt.Errorf("unable to parse origin URL %#v", originURLStr), false + } + + originClient := config.HTTPTransport + if originClient == nil { + originClient = http.DefaultTransport + } + + errGroup, serveCtx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + cfdServer := &cfdServer{ + httpServer: &http2.Server{}, + originClient: originClient, + logger: config.Logger, + originURL: originURL, + connectionIndex: uint8ToString(connectionIndex), + config: config, + } + cfdServer.serve(serveCtx, tlsConn) + return fmt.Errorf("Connection with edge closed") + }) + + errGroup.Go(func() error { + select { + case reconnect := <-reconnectCh: + return &reconnect + case <-serveCtx.Done(): + return nil + } + }) + + err = errGroup.Wait() + return err, true +} + +func RegisterConnectionWithH2Mux( ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, @@ -470,6 +524,52 @@ func RegisterConnection( return nil } +func RegisterConnection( + ctx context.Context, + config *TunnelConfig, + connectionID uint8, + numPreviousAttempts uint8, + addr *net.TCPAddr, +) (net.Conn, error) { + originCert, err := tls.X509KeyPair(config.OriginCert, config.OriginCert) + if err != nil { + return nil, err + } + tlsConfig := config.TlsConfig + tlsConfig.Certificates = []tls.Certificate{originCert} + tlsServerConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr) + if err != nil { + return nil, err + } + + rpcTransport := tunnelrpc.NewTransportLogger(config.Logger, rpc.StreamTransport(&persistentConn{tlsServerConn})) + rpcConn := rpc.NewConn( + rpcTransport, + tunnelrpc.ConnLog(config.Logger), + ) + rpcClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx), Conn: rpcConn} + connDetail, err := rpcClient.RegisterConnection( + ctx, + config.NamedTunnel.Auth, + config.NamedTunnel.ID, + connectionID, + config.ConnectionOptions(tlsServerConn.LocalAddr().String(), numPreviousAttempts), + ) + if err != nil { + return nil, err + } + config.Logger.Infof("Connection %d registered with %s using ID %s", connectionID, connDetail.Location, connDetail.UUID) + rpcTransport.Close() + // Closing the client will also close the connection + rpcClient.Close() + + flushMessage := make([]byte, 8) + buf := make([]byte, len(flushMessage)) + tlsServerConn.Write(buf) + + return tlsServerConn, nil +} + func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError { if retryable, ok := err.(*tunnelpogs.RetryableError); ok { return &serverRegisterTunnelError{ @@ -698,7 +798,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { var resp *http.Response var respErr error if websocket.IsWebSocketUpgrade(req) { - resp, respErr = h.serveWebsocket(stream, req, rule) + resp, respErr = serveWebsocket(&h2muxWebsocketResp{stream}, req, rule) } else { resp, respErr = h.serveHTTP(stream, req, rule) } @@ -725,7 +825,7 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, return req, rule, nil } -func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) { +func serveWebsocket(wsResp WebsocketResp, req *http.Request, rule *ingress.Rule) (*http.Response, error) { if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { req.Header.Set("Host", hostHeader) req.Host = hostHeader @@ -740,13 +840,13 @@ func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Requ return nil, err } defer conn.Close() - err = stream.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(response)) + err = wsResp.WriteRespHeaders(response) if err != nil { return nil, errors.Wrap(err, "Error writing response header") } // Copy to/from stream to the undelying connection. Use the underlying // connection because cloudflared doesn't operate on the message themselves - websocket.Stream(conn.UnderlyingConn(), stream) + websocket.Stream(conn.UnderlyingConn(), wsResp) return response, nil } diff --git a/tlsconfig/certreloader.go b/tlsconfig/certreloader.go index 041392bc..357d009d 100644 --- a/tlsconfig/certreloader.go +++ b/tlsconfig/certreloader.go @@ -18,7 +18,10 @@ const ( OriginCAPoolFlag = "origin-ca-pool" CaCertFlag = "cacert" - edgeTLSServerName = "cftunnel.com" + // edgeH2muxTLSServerName is the server name to establish h2mux connection with edge + edgeH2muxTLSServerName = "cftunnel.com" + // edgeH2TLSServerName is the server name to establish http2 connection with edge + edgeH2TLSServerName = "h2.cftunnel.com" ) // CertReloader can load and reload a TLS certificate from a particular filepath. @@ -120,13 +123,17 @@ func LoadCustomOriginCA(originCAFilename string) (*x509.CertPool, error) { return certPool, nil } -func CreateTunnelConfig(c *cli.Context) (*tls.Config, error) { +func CreateTunnelConfig(c *cli.Context, isNamedTunnel bool) (*tls.Config, error) { var rootCAs []string if c.String(CaCertFlag) != "" { rootCAs = append(rootCAs, c.String(CaCertFlag)) } - userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: edgeTLSServerName} + serverName := edgeH2muxTLSServerName + if isNamedTunnel { + serverName = edgeH2TLSServerName + } + userConfig := &TLSParameters{RootCAs: rootCAs, ServerName: serverName} tlsConfig, err := GetConfig(userConfig) if err != nil { return nil, err diff --git a/validation/validation.go b/validation/validation.go index 21ba940d..3007aea1 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -176,28 +176,32 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round return err } initialRequest.Host = hostname - _, initialErr := client.Do(initialRequest) - if initialErr != nil { - // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck? - oldScheme := parsedURL.Scheme - parsedURL.Scheme = toggleProtocol(parsedURL.Scheme) + resp, initialErr := client.Do(initialRequest) + if initialErr == nil { + resp.Body.Close() + return nil + } - secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil) - if err != nil { - return err - } - secondRequest.Host = hostname - _, secondErr := client.Do(secondRequest) - if secondErr == nil { // Worked this time--advise the user to switch protocols - return errors.Errorf( - "%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s", - parsedURL.Host, - oldScheme, - parsedURL.Scheme, - initialErr, - parsedURL, - ) - } + // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck? + oldScheme := parsedURL.Scheme + parsedURL.Scheme = toggleProtocol(parsedURL.Scheme) + + secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil) + if err != nil { + return err + } + secondRequest.Host = hostname + resp, secondErr := client.Do(secondRequest) + if secondErr == nil { // Worked this time--advise the user to switch protocols + resp.Body.Close() + return errors.Errorf( + "%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s", + parsedURL.Host, + oldScheme, + parsedURL.Scheme, + initialErr, + parsedURL, + ) } return initialErr