diff --git a/origin/h2mux.go b/origin/h2mux.go new file mode 100644 index 00000000..18c2bf32 --- /dev/null +++ b/origin/h2mux.go @@ -0,0 +1,229 @@ +package origin + +import ( + "bufio" + "context" + "io" + "net" + "net/http" + "strconv" + + "github.com/cloudflare/cloudflared/buffer" + "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/logger" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/websocket" + "github.com/pkg/errors" +) + +type TunnelHandler struct { + ingressRules ingress.Ingress + muxer *h2mux.Muxer + tags []tunnelpogs.Tag + metrics *TunnelMetrics + // connectionID is only used by metrics, and prometheus requires labels to be string + connectionID string + logger logger.Service + + bufferPool *buffer.Pool +} + +// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error +func NewTunnelHandler(ctx context.Context, + config *TunnelConfig, + addr *net.TCPAddr, + connectionID uint8, + bufferPool *buffer.Pool, +) (*TunnelHandler, string, error) { + h := &TunnelHandler{ + ingressRules: config.IngressRules, + tags: config.Tags, + metrics: config.Metrics, + connectionID: uint8ToString(connectionID), + logger: config.Logger, + bufferPool: bufferPool, + } + + edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr) + if err != nil { + return nil, "", err + } + // Establish a muxed connection with the edge + // Client mux handshake with agent server + h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, config.muxerConfig(h), h.metrics.activeStreams) + if err != nil { + return nil, "", errors.Wrap(err, "h2mux handshake with edge error") + } + return h, edgeConn.LocalAddr().String(), nil +} + +func (h *TunnelHandler) AppendTagHeaders(r *http.Request) { + for _, tag := range h.tags { + r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value) + } +} + +func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { + h.metrics.incrementRequests(h.connectionID) + defer h.metrics.decrementConcurrentRequests(h.connectionID) + + req, rule, reqErr := h.createRequest(stream) + if reqErr != nil { + h.writeErrorResponse(stream, reqErr) + return reqErr + } + + cfRay := findCfRayHeader(req) + lbProbe := isLBProbeRequest(req) + h.logRequest(req, cfRay, lbProbe) + + var resp *http.Response + var respErr error + if websocket.IsWebSocketUpgrade(req) { + resp, respErr = serveWebsocket(&h2muxWebsocketResp{stream}, req, rule) + } else { + resp, respErr = h.serveHTTP(stream, req, rule) + } + if respErr != nil { + h.writeErrorResponse(stream, respErr) + return respErr + } + h.logResponseOk(resp, cfRay, lbProbe) + return nil +} + +func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, *ingress.Rule, error) { + req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream}) + if err != nil { + return nil, nil, errors.Wrap(err, "Unexpected error from http.NewRequest") + } + err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req) + if err != nil { + return nil, nil, errors.Wrap(err, "invalid request received") + } + rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path) + rule.Service.RewriteOriginURL(req.URL) + return req, rule, nil +} + +func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) { + // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate + if rule.Config.DisableChunkedEncoding { + req.TransferEncoding = []string{"gzip", "deflate"} + cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) + if err == nil { + req.ContentLength = int64(cLength) + } + } + + // Request origin to keep connection alive to improve performance + req.Header.Set("Connection", "keep-alive") + + if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { + req.Header.Set("Host", hostHeader) + req.Host = hostHeader + } + + response, err := h.httpClient.RoundTrip(req) + if err != nil { + return nil, errors.Wrap(err, "Error proxying request to origin") + } + defer response.Body.Close() + + headers := h2mux.H1ResponseToH2ResponseHeaders(response) + headers = append(headers, h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceOrigin)) + err = stream.WriteHeaders(headers) + if err != nil { + return nil, errors.Wrap(err, "Error writing response header") + } + if h.isEventStream(response) { + h.writeEventStream(stream, response.Body) + } else { + // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream + // compression generates dictionary on first write + buf := h.bufferPool.Get() + defer h.bufferPool.Put(buf) + io.CopyBuffer(stream, response.Body, buf) + } + return response, nil +} + +func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) { + reader := bufio.NewReader(responseBody) + for { + line, err := reader.ReadBytes('\n') + if err != nil { + break + } + stream.Write(line) + } +} + +func (h *TunnelHandler) isEventStream(response *http.Response) bool { + if response.Header.Get("content-type") == "text/event-stream" { + h.logger.Debug("Detected Server-Side Events from Origin") + return true + } + return false +} + +func (h *TunnelHandler) writeErrorResponse(stream *h2mux.MuxedStream, err error) { + h.logger.Errorf("HTTP request error: %s", err) + stream.WriteHeaders([]h2mux.Header{ + {Name: ":status", Value: "502"}, + h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceCloudflared), + }) + stream.Write([]byte("502 Bad Gateway")) + h.metrics.incrementResponses(h.connectionID, "502") +} + +func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool) { + logger := h.logger + if cfRay != "" { + logger.Debugf("CF-RAY: %s %s %s %s", cfRay, req.Method, req.URL, req.Proto) + } else if lbProbe { + logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, req.Method, req.URL, req.Proto) + } else { + logger.Infof("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, req.Method, req.URL, req.Proto) + } + logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, req.Header) + + if contentLen := req.ContentLength; contentLen == -1 { + logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay) + } else { + logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen) + } +} + +func (h *TunnelHandler) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { + h.metrics.incrementResponses(h.connectionID, "200") + logger := h.logger + if cfRay != "" { + logger.Debugf("CF-RAY: %s %s", cfRay, r.Status) + } else if lbProbe { + logger.Debugf("Response to Load Balancer health check %s", r.Status) + } else { + logger.Infof("%s", r.Status) + } + logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) + + if contentLen := r.ContentLength; contentLen == -1 { + logger.Debugf("CF-RAY: %s Response content length unknown", cfRay) + } else { + logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen) + } +} + +func (h *TunnelHandler) UpdateMetrics(connectionID string) { + h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics()) +} + +type h2muxWebsocketResp struct { + *h2mux.MuxedStream +} + +func (wr *h2muxWebsocketResp) WriteRespHeaders(resp *http.Response) error { + return wr.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp)) +} diff --git a/origin/http2.go b/origin/http2.go new file mode 100644 index 00000000..df77fd2f --- /dev/null +++ b/origin/http2.go @@ -0,0 +1,320 @@ +package origin + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "strconv" + "strings" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/tunnelrpc" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + + "github.com/pkg/errors" + "golang.org/x/net/http2" + "zombiezen.com/go/capnproto2/rpc" +) + +const ( + internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade" + websocketUpgrade = "websocket" + controlPlaneUpgrade = "control-plane" +) + +type http2Server struct { + server *http2.Server + ingressRules ingress.Ingress + logger logger.Service + connIndexStr string + connIndex uint8 + config *TunnelConfig + localAddr net.Addr + shutdownChan chan struct{} + connectedFuse *h2mux.BooleanFuse +} + +func newHTTP2Server(config *TunnelConfig, connIndex uint8, localAddr net.Addr, connectedFuse *h2mux.BooleanFuse) (*http2Server, error) { + return &http2Server{ + server: &http2.Server{}, + ingressRules: config.IngressRules, + logger: config.Logger, + connIndexStr: uint8ToString(connIndex), + connIndex: connIndex, + config: config, + localAddr: localAddr, + shutdownChan: make(chan struct{}), + connectedFuse: connectedFuse, + }, nil +} + +func (c *http2Server) serve(ctx context.Context, conn net.Conn) { + go func() { + <-ctx.Done() + c.close(conn) + }() + c.server.ServeConn(conn, &http2.ServeConnOpts{ + Context: ctx, + Handler: c, + }) +} + +func (c *http2Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.config.Metrics.incrementRequests(c.connIndexStr) + defer c.config.Metrics.decrementConcurrentRequests(c.connIndexStr) + + cfRay := findCfRayHeader(r) + lbProbe := isLBProbeRequest(r) + c.logRequest(r, cfRay, lbProbe) + + rule, _ := c.ingressRules.FindMatchingRule(r.Host, r.URL.Path) + rule.Service.RewriteOriginURL(r.URL) + + var resp *http.Response + var err error + + if isControlPlaneUpgrade(r) { + stripWebsocketUpgradeHeader(r) + err = c.serveControlPlane(w, r) + } else if isWebsocketUpgrade(r) { + stripWebsocketUpgradeHeader(r) + var respBody BidirectionalStream + respBody, err = newHTTP2Stream(w, r) + if err == nil { + resp, err = serveWebsocket(respBody, r, rule) + } + } else { + resp, err = c.serveHTTP(w, r, rule) + } + + if err != nil { + c.writeErrorResponse(w, err) + return + } + if resp != nil { + resp.Body.Close() + } +} + +func (c *http2Server) serveHTTP(w http.ResponseWriter, r *http.Request, rule *ingress.Rule) (*http.Response, error) { + // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate + if rule.Config.DisableChunkedEncoding { + r.TransferEncoding = []string{"gzip", "deflate"} + cLength, err := strconv.Atoi(r.Header.Get("Content-Length")) + if err == nil { + r.ContentLength = int64(cLength) + } + } + + // Request origin to keep connection alive to improve performance + r.Header.Set("Connection", "keep-alive") + + if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { + r.Header.Set("Host", hostHeader) + r.Host = hostHeader + } + + resp, err := rule.HTTPTransport.RoundTrip(r) + if err != nil { + return nil, errors.Wrap(err, "Error proxying request to origin") + } + w.WriteHeader(resp.StatusCode) + _, err = io.Copy(w, resp.Body) + if err != nil { + return nil, errors.Wrap(err, "Copy response error") + } + return resp, nil +} + +func (c *http2Server) serveControlPlane(w http.ResponseWriter, r *http.Request) error { + stream, err := newHTTP2Stream(w, r) + if err != nil { + return err + } + + rpcTransport := tunnelrpc.NewTransportLogger(c.logger, rpc.StreamTransport(stream)) + rpcConn := rpc.NewConn( + rpcTransport, + tunnelrpc.ConnLog(c.logger), + ) + rpcClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(r.Context()), Conn: rpcConn} + + if err = c.registerConnection(r.Context(), rpcClient, 0); err != nil { + return err + } + c.connectedFuse.Fuse(true) + + <-c.shutdownChan + c.gracefulShutdown(rpcClient) + + // Closing the client will also close the connection + rpcClient.Close() + rpcTransport.Close() + close(c.shutdownChan) + return nil +} + +func (c *http2Server) registerConnection( + ctx context.Context, + rpcClient tunnelpogs.TunnelServer_PogsClient, + numPreviousAttempts uint8, +) error { + connDetail, err := rpcClient.RegisterConnection( + ctx, + c.config.NamedTunnel.Auth, + c.config.NamedTunnel.ID, + c.connIndex, + c.config.ConnectionOptions(c.localAddr.String(), numPreviousAttempts), + ) + if err != nil { + c.logger.Errorf("Cannot register connection, err: %v", err) + return err + } + c.logger.Infof("Connection %s registered with %s using ID %s", c.connIndexStr, connDetail.Location, connDetail.UUID) + return nil +} + +func (c *http2Server) gracefulShutdown(rpcClient tunnelpogs.TunnelServer_PogsClient) { + ctx, cancel := context.WithTimeout(context.Background(), c.config.GracePeriod) + defer cancel() + err := rpcClient.UnregisterConnection(ctx) + if err != nil { + c.logger.Errorf("Cannot unregister connection gracefully, err: %v", err) + return + } + c.logger.Info("Sent graceful shutdown signal") + + <-ctx.Done() +} + +func (c *http2Server) writeErrorResponse(w http.ResponseWriter, err error) { + c.logger.Errorf("HTTP request error: %s", err) + c.config.Metrics.incrementResponses(c.connIndexStr, "502") + jsonResponseMetaHeader, err := json.Marshal(h2mux.ResponseMetaHeader{Source: h2mux.ResponseSourceCloudflared}) + if err != nil { + panic(err) + } + w.Header().Set(h2mux.ResponseMetaHeaderField, string(jsonResponseMetaHeader)) + w.WriteHeader(http.StatusBadGateway) +} + +func (c *http2Server) logRequest(r *http.Request, cfRay string, lbProbe bool) { + logger := c.logger + if cfRay != "" { + logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) + } else if lbProbe { + logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto) + } else { + logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto) + } + logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, r.Header) + + if contentLen := r.ContentLength; contentLen == -1 { + logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay) + } else { + logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen) + } +} + +func (c *http2Server) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { + c.config.Metrics.incrementResponses(c.connIndexStr, "200") + logger := c.logger + if cfRay != "" { + logger.Debugf("CF-RAY: %s %s", cfRay, r.Status) + } else if lbProbe { + logger.Debugf("Response to Load Balancer health check %s", r.Status) + } else { + logger.Infof("%s", r.Status) + } + logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) + + if contentLen := r.ContentLength; contentLen == -1 { + logger.Debugf("CF-RAY: %s Response content length unknown", cfRay) + } else { + logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen) + } +} + +func (c *http2Server) close(conn net.Conn) { + // Send signal to control loop to start graceful shutdown + c.shutdownChan <- struct{}{} + // Wait for control loop to close channel + <-c.shutdownChan + conn.Close() +} + +type http2Stream struct { + r io.Reader + w http.ResponseWriter + flusher http.Flusher +} + +func newHTTP2Stream(w http.ResponseWriter, r *http.Request) (*http2Stream, error) { + flusher, ok := w.(http.Flusher) + if !ok { + return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher") + } + return &http2Stream{r: r.Body, w: w, flusher: flusher}, nil +} + +func (wr *http2Stream) WriteRespHeaders(resp *http.Response) error { + dest := wr.w.Header() + userHeaders := make(http.Header, len(resp.Header)) + for header, values := range resp.Header { + // Since these are http2 headers, they're required to be lowercase + h2name := strings.ToLower(header) + for _, v := range values { + if h2name == "content-length" { + // This header has meaning in HTTP/2 and will be used by the edge, + // so it should be sent as an HTTP/2 response header. + dest.Add(h2name, v) + // Since these are http2 headers, they're required to be lowercase + } else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) { + // User headers, on the other hand, must all be serialized so that + // HTTP/2 header validation won't be applied to HTTP/1 header values + userHeaders.Add(h2name, v) + } + } + } + + // Perform user header serialization and set them in the single header + dest.Set(h2mux.ResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders)) + // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1 + wr.w.WriteHeader(http.StatusOK) + wr.flusher.Flush() + return nil +} + +func (wr *http2Stream) Read(p []byte) (n int, err error) { + return wr.r.Read(p) +} + +func (wr *http2Stream) Write(p []byte) (n int, err error) { + n, err = wr.w.Write(p) + if err != nil { + return 0, err + } + wr.flusher.Flush() + return +} + +func (wr *http2Stream) Close() error { + return nil +} + +func isControlPlaneUpgrade(r *http.Request) bool { + return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == controlPlaneUpgrade +} + +func isWebsocketUpgrade(r *http.Request) bool { + return strings.ToLower(r.Header.Get(internalUpgradeHeader)) == websocketUpgrade +} + +func stripWebsocketUpgradeHeader(r *http.Request) { + r.Header.Del(internalUpgradeHeader) +} diff --git a/origin/server.go b/origin/server.go deleted file mode 100644 index 4748da26..00000000 --- a/origin/server.go +++ /dev/null @@ -1,202 +0,0 @@ -package origin - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "net/url" - "strings" - - "github.com/cloudflare/cloudflared/h2mux" - "github.com/cloudflare/cloudflared/logger" - "github.com/pkg/errors" - "golang.org/x/net/http2" -) - -type cfdServer struct { - httpServer *http2.Server - originClient http.RoundTripper - logger logger.Service - originURL *url.URL - connectionIndex string - config *TunnelConfig -} - -func (c *cfdServer) serve(ctx context.Context, conn net.Conn) { - go func() { - <-ctx.Done() - conn.Close() - }() - c.httpServer.ServeConn(conn, &http2.ServeConnOpts{ - Context: ctx, - Handler: c, - }) -} - -func (c *cfdServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - c.config.Metrics.incrementRequests(c.connectionIndex) - defer c.config.Metrics.decrementConcurrentRequests(c.connectionIndex) - - cfRay := findCfRayHeader(r) - lbProbe := isLBProbeRequest(r) - c.logRequest(r, cfRay, lbProbe) - - r.URL = c.originURL - c.logger.Infof("URL %v", r.URL) - // TODO: TUN-3406 support websocket, event stream and WSGI servers. - var resp *http.Response - var err error - - if isWebsocketUpgrade(r) { - var respBody WebsocketResp - respBody, err = newWebsocketBody(w, r) - if err == nil { - resp, err = serveWebsocket(respBody, r, c.config.HTTPHostHeader, c.config.ClientTlsConfig) - } - } else { - resp, err = c.serveHTTP(w, r) - } - - if err != nil { - c.writeErrorResponse(w, err) - return - } - defer resp.Body.Close() - -} - -func (c *cfdServer) serveHTTP(w http.ResponseWriter, r *http.Request) (*http.Response, error) { - resp, err := c.originClient.RoundTrip(r) - if err != nil { - return nil, err - } - w.WriteHeader(resp.StatusCode) - _, err = io.Copy(w, resp.Body) - if err != nil { - return nil, errors.Wrap(err, "Copy response error") - } - return resp, nil -} - -func (c *cfdServer) writeErrorResponse(w http.ResponseWriter, err error) { - c.logger.Errorf("HTTP request error: %s", err) - c.config.Metrics.incrementResponses(c.connectionIndex, "502") - jsonResponseMetaHeader, err := json.Marshal(h2mux.ResponseMetaHeader{Source: h2mux.ResponseSourceCloudflared}) - if err != nil { - panic(err) - } - w.Header().Set(h2mux.ResponseMetaHeaderField, string(jsonResponseMetaHeader)) - w.WriteHeader(http.StatusBadGateway) -} - -func (c *cfdServer) logRequest(r *http.Request, cfRay string, lbProbe bool) { - logger := c.logger - if cfRay != "" { - logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) - } else if lbProbe { - logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto) - } else { - logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto) - } - logger.Infof("CF-RAY: %s Request Headers %+v", cfRay, r.Header) - - if contentLen := r.ContentLength; contentLen == -1 { - logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay) - } else { - logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen) - } -} - -func (c *cfdServer) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { - c.config.Metrics.incrementResponses(c.connectionIndex, "200") - logger := c.logger - if cfRay != "" { - logger.Debugf("CF-RAY: %s %s", cfRay, r.Status) - } else if lbProbe { - logger.Debugf("Response to Load Balancer health check %s", r.Status) - } else { - logger.Infof("%s", r.Status) - } - logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) - - if contentLen := r.ContentLength; contentLen == -1 { - logger.Debugf("CF-RAY: %s Response content length unknown", cfRay) - } else { - logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen) - } -} - -type WebsocketResp interface { - WriteRespHeaders(*http.Response) error - io.ReadWriter -} - -type http2WebsocketResp struct { - r io.Reader - w http.ResponseWriter - flusher http.Flusher -} - -func newWebsocketBody(w http.ResponseWriter, r *http.Request) (*http2WebsocketResp, error) { - flusher, ok := w.(http.Flusher) - if !ok { - return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher") - } - return &http2WebsocketResp{r: r.Body, w: w, flusher: flusher}, nil -} - -func (wr *http2WebsocketResp) WriteRespHeaders(resp *http.Response) error { - dest := wr.w.Header() - userHeaders := make(http.Header, len(resp.Header)) - for header, values := range resp.Header { - // Since these are http2 headers, they're required to be lowercase - h2name := strings.ToLower(header) - for _, v := range values { - if h2name == "content-length" { - // This header has meaning in HTTP/2 and will be used by the edge, - // so it should be sent as an HTTP/2 response header. - dest.Add(h2name, v) - // Since these are http2 headers, they're required to be lowercase - } else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) { - // User headers, on the other hand, must all be serialized so that - // HTTP/2 header validation won't be applied to HTTP/1 header values - userHeaders.Add(h2name, v) - } - } - } - - // Perform user header serialization and set them in the single header - dest.Set(h2mux.ResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders)) - // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1 - wr.w.WriteHeader(http.StatusOK) - wr.flusher.Flush() - return nil -} - -func (wr *http2WebsocketResp) Read(p []byte) (n int, err error) { - return wr.r.Read(p) -} - -func (wr *http2WebsocketResp) Write(p []byte) (n int, err error) { - n, err = wr.w.Write(p) - if err != nil { - return 0, err - } - wr.flusher.Flush() - return -} - -type h2muxWebsocketResp struct { - *h2mux.MuxedStream -} - -func (wr *h2muxWebsocketResp) WriteRespHeaders(resp *http.Response) error { - return wr.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp)) -} - -func isWebsocketUpgrade(r *http.Request) bool { - return strings.ToLower(r.Header.Get("Cf-Int-Tunnel-Upgrade")) == "websocket" -} diff --git a/origin/tunnel.go b/origin/tunnel.go index f94937d5..7bc36263 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -1,11 +1,9 @@ package origin import ( - "bufio" "context" "crypto/tls" "fmt" - "io" "net" "net/http" "net/url" @@ -17,9 +15,7 @@ import ( "github.com/google/uuid" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" - "golang.org/x/net/http2" "golang.org/x/sync/errgroup" - "zombiezen.com/go/capnproto2/rpc" "github.com/cloudflare/cloudflared/buffer" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" @@ -32,7 +28,6 @@ import ( "github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - "github.com/cloudflare/cloudflared/validation" "github.com/cloudflare/cloudflared/websocket" ) @@ -262,7 +257,6 @@ func ServeTunnelLoop(ctx context.Context, if config.TunnelEventChan != nil { config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting} } - config.Logger.Infof("Retrying connection %d in %s seconds", connectionIndex, duration) backoff.Backoff(ctx) continue @@ -307,14 +301,7 @@ func ServeTunnel( connectionTag := uint8ToString(connectionIndex) if config.NamedTunnel != nil && config.NamedTunnel.Protocol == http2Protocol { - tlsConn, err := RegisterConnection(ctx, config, connectionIndex, uint8(backoff.retries), addr) - if err != nil { - logger.Errorf("Register connectio error: %+v", err) - return err, true - } - connectedFuse.Fuse(true) - backoff.SetGracePeriod() - return serveNamedTunnel(ctx, config, tlsConn, connectionIndex, reconnectCh) + return ServeNamedTunnel(ctx, config, connectionIndex, addr, connectedFuse, reconnectCh) } // Returns error from parsing the origin URL or handshake errors @@ -432,54 +419,6 @@ func ServeTunnel( return nil, true } -func serveNamedTunnel( - ctx context.Context, - config *TunnelConfig, - tlsConn net.Conn, - connectionIndex uint8, - reconnectCh chan ReconnectSignal, -) (err error, recoverable bool) { - originURLStr, err := validation.ValidateUrl(config.OriginUrl) - if err != nil { - return fmt.Errorf("unable to parse origin URL %#v", config.OriginUrl), false - } - originURL, err := url.Parse(originURLStr) - if err != nil { - return fmt.Errorf("unable to parse origin URL %#v", originURLStr), false - } - - originClient := config.HTTPTransport - if originClient == nil { - originClient = http.DefaultTransport - } - - errGroup, serveCtx := errgroup.WithContext(ctx) - errGroup.Go(func() error { - cfdServer := &cfdServer{ - httpServer: &http2.Server{}, - originClient: originClient, - logger: config.Logger, - originURL: originURL, - connectionIndex: uint8ToString(connectionIndex), - config: config, - } - cfdServer.serve(serveCtx, tlsConn) - return fmt.Errorf("Connection with edge closed") - }) - - errGroup.Go(func() error { - select { - case reconnect := <-reconnectCh: - return &reconnect - case <-serveCtx.Done(): - return nil - } - }) - - err = errGroup.Wait() - return err, true -} - func RegisterConnectionWithH2Mux( ctx context.Context, muxer *h2mux.Muxer, @@ -524,50 +463,44 @@ func RegisterConnectionWithH2Mux( return nil } -func RegisterConnection( +func ServeNamedTunnel( ctx context.Context, config *TunnelConfig, - connectionID uint8, - numPreviousAttempts uint8, + connIndex uint8, addr *net.TCPAddr, -) (net.Conn, error) { - originCert, err := tls.X509KeyPair(config.OriginCert, config.OriginCert) - if err != nil { - return nil, err - } - tlsConfig := config.TlsConfig - tlsConfig.Certificates = []tls.Certificate{originCert} + connectedFuse *h2mux.BooleanFuse, + reconnectCh chan ReconnectSignal, +) (err error, recoverable bool) { tlsServerConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr) if err != nil { - return nil, err + return err, true } - rpcTransport := tunnelrpc.NewTransportLogger(config.Logger, rpc.StreamTransport(&persistentConn{tlsServerConn})) - rpcConn := rpc.NewConn( - rpcTransport, - tunnelrpc.ConnLog(config.Logger), - ) - rpcClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx), Conn: rpcConn} - connDetail, err := rpcClient.RegisterConnection( - ctx, - config.NamedTunnel.Auth, - config.NamedTunnel.ID, - connectionID, - config.ConnectionOptions(tlsServerConn.LocalAddr().String(), numPreviousAttempts), - ) + cfdServer, err := newHTTP2Server(config, connIndex, tlsServerConn.LocalAddr(), connectedFuse) if err != nil { - return nil, err + return err, false } - config.Logger.Infof("Connection %d registered with %s using ID %s", connectionID, connDetail.Location, connDetail.UUID) - rpcTransport.Close() - // Closing the client will also close the connection - rpcClient.Close() - flushMessage := make([]byte, 8) - buf := make([]byte, len(flushMessage)) - tlsServerConn.Write(buf) + errGroup, serveCtx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + cfdServer.serve(serveCtx, tlsServerConn) + return fmt.Errorf("Connection with edge closed") + }) - return tlsServerConn, nil + errGroup.Go(func() error { + select { + case reconnect := <-reconnectCh: + return &reconnect + case <-serveCtx.Done(): + return nil + } + }) + + err = errGroup.Wait() + if err != nil { + return err, true + } + return nil, false } func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError { @@ -733,98 +666,6 @@ func LogServerInfo( metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName) } -type TunnelHandler struct { - ingressRules ingress.Ingress - muxer *h2mux.Muxer - tags []tunnelpogs.Tag - metrics *TunnelMetrics - // connectionID is only used by metrics, and prometheus requires labels to be string - connectionID string - logger logger.Service - - bufferPool *buffer.Pool -} - -// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error -func NewTunnelHandler(ctx context.Context, - config *TunnelConfig, - addr *net.TCPAddr, - connectionID uint8, - bufferPool *buffer.Pool, -) (*TunnelHandler, string, error) { - - h := &TunnelHandler{ - ingressRules: config.IngressRules, - tags: config.Tags, - metrics: config.Metrics, - connectionID: uint8ToString(connectionID), - logger: config.Logger, - bufferPool: bufferPool, - } - - edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr) - if err != nil { - return nil, "", err - } - // Establish a muxed connection with the edge - // Client mux handshake with agent server - h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, config.muxerConfig(h), h.metrics.activeStreams) - if err != nil { - return nil, "", errors.Wrap(err, "h2mux handshake with edge error") - } - return h, edgeConn.LocalAddr().String(), nil -} - -func (h *TunnelHandler) AppendTagHeaders(r *http.Request) { - for _, tag := range h.tags { - r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value) - } -} - -func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { - h.metrics.incrementRequests(h.connectionID) - defer h.metrics.decrementConcurrentRequests(h.connectionID) - - req, rule, reqErr := h.createRequest(stream) - if reqErr != nil { - h.writeErrorResponse(stream, reqErr) - return reqErr - } - - cfRay := findCfRayHeader(req) - lbProbe := isLBProbeRequest(req) - h.logRequest(req, cfRay, lbProbe) - - var resp *http.Response - var respErr error - if websocket.IsWebSocketUpgrade(req) { - resp, respErr = serveWebsocket(&h2muxWebsocketResp{stream}, req, rule) - } else { - resp, respErr = h.serveHTTP(stream, req, rule) - } - if respErr != nil { - h.writeErrorResponse(stream, respErr) - return respErr - } - h.logResponseOk(resp, cfRay, lbProbe) - return nil -} - -func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, *ingress.Rule, error) { - req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream}) - if err != nil { - return nil, nil, errors.Wrap(err, "Unexpected error from http.NewRequest") - } - err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req) - if err != nil { - return nil, nil, errors.Wrap(err, "invalid request received") - } - h.AppendTagHeaders(req) - // For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map. - rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path) - return req, rule, nil -} - func serveWebsocket(wsResp WebsocketResp, req *http.Request, rule *ingress.Rule) (*http.Response, error) { if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { req.Header.Set("Host", hostHeader) @@ -851,118 +692,6 @@ func serveWebsocket(wsResp WebsocketResp, req *http.Request, rule *ingress.Rule) return response, nil } -func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) { - // Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate - if rule.Config.DisableChunkedEncoding { - req.TransferEncoding = []string{"gzip", "deflate"} - cLength, err := strconv.Atoi(req.Header.Get("Content-Length")) - if err == nil { - req.ContentLength = int64(cLength) - } - } - - // Request origin to keep connection alive to improve performance - req.Header.Set("Connection", "keep-alive") - - if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { - req.Header.Set("Host", hostHeader) - req.Host = hostHeader - } - - response, err := rule.Service.RoundTrip(req) - if err != nil { - return nil, errors.Wrap(err, "Error proxying request to origin") - } - defer response.Body.Close() - - headers := h2mux.H1ResponseToH2ResponseHeaders(response) - headers = append(headers, h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceOrigin)) - err = stream.WriteHeaders(headers) - if err != nil { - return nil, errors.Wrap(err, "Error writing response header") - } - if h.isEventStream(response) { - h.writeEventStream(stream, response.Body) - } else { - // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream - // compression generates dictionary on first write - buf := h.bufferPool.Get() - defer h.bufferPool.Put(buf) - io.CopyBuffer(stream, response.Body, buf) - } - return response, nil -} - -func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) { - reader := bufio.NewReader(responseBody) - for { - line, err := reader.ReadBytes('\n') - if err != nil { - break - } - stream.Write(line) - } -} - -func (h *TunnelHandler) isEventStream(response *http.Response) bool { - if response.Header.Get("content-type") == "text/event-stream" { - h.logger.Debug("Detected Server-Side Events from Origin") - return true - } - return false -} - -func (h *TunnelHandler) writeErrorResponse(stream *h2mux.MuxedStream, err error) { - h.logger.Errorf("HTTP request error: %s", err) - stream.WriteHeaders([]h2mux.Header{ - {Name: ":status", Value: "502"}, - h2mux.CreateResponseMetaHeader(h2mux.ResponseMetaHeaderField, h2mux.ResponseSourceCloudflared), - }) - stream.Write([]byte("502 Bad Gateway")) - h.metrics.incrementResponses(h.connectionID, "502") -} - -func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool) { - logger := h.logger - if cfRay != "" { - logger.Debugf("CF-RAY: %s %s %s %s", cfRay, req.Method, req.URL, req.Proto) - } else if lbProbe { - logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, req.Method, req.URL, req.Proto) - } else { - logger.Infof("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, req.Method, req.URL, req.Proto) - } - logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, req.Header) - - if contentLen := req.ContentLength; contentLen == -1 { - logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay) - } else { - logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen) - } -} - -func (h *TunnelHandler) logResponseOk(r *http.Response, cfRay string, lbProbe bool) { - h.metrics.incrementResponses(h.connectionID, "200") - logger := h.logger - if cfRay != "" { - logger.Debugf("CF-RAY: %s %s", cfRay, r.Status) - } else if lbProbe { - logger.Debugf("Response to Load Balancer health check %s", r.Status) - } else { - logger.Infof("%s", r.Status) - } - logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) - - if contentLen := r.ContentLength; contentLen == -1 { - logger.Debugf("CF-RAY: %s Response content length unknown", cfRay) - } else { - logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen) - } -} - -func (h *TunnelHandler) UpdateMetrics(connectionID string) { - h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics()) -} - func uint8ToString(input uint8) string { return strconv.FormatUint(uint64(input), 10) }