From cb39f26f27d179e8c5dec9b03159225f0d7bdbca Mon Sep 17 00:00:00 2001 From: cthuang Date: Mon, 21 Sep 2020 09:45:51 +0100 Subject: [PATCH] TUN-3406: Proxy websocket requests over Go http2 --- h2mux/header.go | 4 +-- origin/server.go | 80 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 63 insertions(+), 21 deletions(-) diff --git a/h2mux/header.go b/h2mux/header.go index 13aa3d57..d752b8e7 100644 --- a/h2mux/header.go +++ b/h2mux/header.go @@ -124,7 +124,7 @@ func IsControlHeader(headerName string) bool { } // isWebsocketClientHeader returns true if the header name is required by the client to upgrade properly -func isWebsocketClientHeader(headerName string) bool { +func IsWebsocketClientHeader(headerName string) bool { return headerName == "sec-websocket-accept" || headerName == "connection" || headerName == "upgrade" @@ -143,7 +143,7 @@ func H1ResponseToH2ResponseHeaders(h1 *http.Response) (h2 []Header) { // Since these are http2 headers, they're required to be lowercase h2 = append(h2, Header{Name: "content-length", Value: values[0]}) - } else if !IsControlHeader(h2name) || isWebsocketClientHeader(h2name) { + } else if !IsControlHeader(h2name) || IsWebsocketClientHeader(h2name) { // User headers, on the other hand, must all be serialized so that // HTTP/2 header validation won't be applied to HTTP/1 header values userHeaders[header] = values diff --git a/origin/server.go b/origin/server.go index a5b1e89d..4748da26 100644 --- a/origin/server.go +++ b/origin/server.go @@ -3,6 +3,7 @@ package origin import ( "context" "encoding/json" + "fmt" "io" "net" "net/http" @@ -11,6 +12,7 @@ import ( "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/logger" + "github.com/pkg/errors" "golang.org/x/net/http2" ) @@ -43,27 +45,40 @@ func (c *cfdServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.logRequest(r, cfRay, lbProbe) r.URL = c.originURL + c.logger.Infof("URL %v", r.URL) // TODO: TUN-3406 support websocket, event stream and WSGI servers. var resp *http.Response var err error - if strings.ToLower(r.Header.Get("Cf-Int-Argo-Tunnel-Upgrade")) == "websocket" { - resp, err = serveWebsocket(newWebsocketBody(w, r, c.logger), r, c.config.HTTPHostHeader, c.config.ClientTlsConfig) + + if isWebsocketUpgrade(r) { + var respBody WebsocketResp + respBody, err = newWebsocketBody(w, r) + if err == nil { + resp, err = serveWebsocket(respBody, r, c.config.HTTPHostHeader, c.config.ClientTlsConfig) + } } else { - resp, err = c.originClient.RoundTrip(r) + resp, err = c.serveHTTP(w, r) } + if err != nil { c.writeErrorResponse(w, err) return } defer resp.Body.Close() +} + +func (c *cfdServer) serveHTTP(w http.ResponseWriter, r *http.Request) (*http.Response, error) { + resp, err := c.originClient.RoundTrip(r) + if err != nil { + return nil, err + } w.WriteHeader(resp.StatusCode) _, err = io.Copy(w, resp.Body) if err != nil { - c.logger.Errorf("Copy response error, err: %v", err) - w.WriteHeader(http.StatusBadGateway) - return + return nil, errors.Wrap(err, "Copy response error") } + return resp, nil } func (c *cfdServer) writeErrorResponse(w http.ResponseWriter, err error) { @@ -120,35 +135,58 @@ type WebsocketResp interface { } type http2WebsocketResp struct { - pr *io.PipeReader - w http.ResponseWriter + r io.Reader + w http.ResponseWriter + flusher http.Flusher } -func newWebsocketBody(w http.ResponseWriter, r *http.Request, logger logger.Service) *http2WebsocketResp { - pr, pw := io.Pipe() - go func() { - n, err := io.Copy(pw, r.Body) - logger.Errorf("websocket body copy ended, err: %v, bytes: %d", err, n) - }() - return &http2WebsocketResp{pr: pr, w: w} +func newWebsocketBody(w http.ResponseWriter, r *http.Request) (*http2WebsocketResp, error) { + flusher, ok := w.(http.Flusher) + if !ok { + return nil, fmt.Errorf("ResponseWriter doesn't implement http.Flusher") + } + return &http2WebsocketResp{r: r.Body, w: w, flusher: flusher}, nil } func (wr *http2WebsocketResp) WriteRespHeaders(resp *http.Response) error { dest := wr.w.Header() - for name, values := range resp.Header { + userHeaders := make(http.Header, len(resp.Header)) + for header, values := range resp.Header { + // Since these are http2 headers, they're required to be lowercase + h2name := strings.ToLower(header) for _, v := range values { - dest.Add(name, v) + if h2name == "content-length" { + // This header has meaning in HTTP/2 and will be used by the edge, + // so it should be sent as an HTTP/2 response header. + dest.Add(h2name, v) + // Since these are http2 headers, they're required to be lowercase + } else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) { + // User headers, on the other hand, must all be serialized so that + // HTTP/2 header validation won't be applied to HTTP/1 header values + userHeaders.Add(h2name, v) + } } } + + // Perform user header serialization and set them in the single header + dest.Set(h2mux.ResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders)) + // HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1 + wr.w.WriteHeader(http.StatusOK) + wr.flusher.Flush() return nil } func (wr *http2WebsocketResp) Read(p []byte) (n int, err error) { - return wr.pr.Read(p) + return wr.r.Read(p) } func (wr *http2WebsocketResp) Write(p []byte) (n int, err error) { - return wr.w.Write(p) + n, err = wr.w.Write(p) + if err != nil { + return 0, err + } + wr.flusher.Flush() + return } type h2muxWebsocketResp struct { @@ -158,3 +196,7 @@ type h2muxWebsocketResp struct { func (wr *h2muxWebsocketResp) WriteRespHeaders(resp *http.Response) error { return wr.WriteHeaders(h2mux.H1ResponseToH2ResponseHeaders(resp)) } + +func isWebsocketUpgrade(r *http.Request) bool { + return strings.ToLower(r.Header.Get("Cf-Int-Tunnel-Upgrade")) == "websocket" +}