diff --git a/hello/hello.go b/hello/hello.go index 8c064b40..a2b12798 100644 --- a/hello/hello.go +++ b/hello/hello.go @@ -150,8 +150,15 @@ func uptimeHandler(startTime time.Time) http.HandlerFunc { // This handler will echo message func websocketHandler(logger logger.Service, upgrader websocket.Upgrader) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + // This addresses the issue of r.Host includes port but origin header doesn't + host, _, err := net.SplitHostPort(r.Host) + if err == nil { + r.Host = host + } + conn, err := upgrader.Upgrade(w, r, nil) if err != nil { + logger.Errorf("failed to upgrade to websocket connection, error: %s", err) return } defer conn.Close() diff --git a/ingress/origin_service.go b/ingress/origin_service.go index 5d7cec32..8e194bd9 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -55,9 +55,14 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { return o.transport.RoundTrip(req) } -func (o *unixSocketPath) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) { - d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig} - return d.Dial(url, headers) +func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) { + d := &gws.Dialer{ + NetDial: o.transport.Dial, + NetDialContext: o.transport.DialContext, + TLSClientConfig: o.transport.TLSClientConfig, + } + reqURL.Scheme = websocket.ChangeRequestScheme(reqURL) + return d.Dial(reqURL.String(), headers) } // localService is an OriginService listening on a TCP/IP address the user's origin can route to. @@ -71,9 +76,12 @@ type localService struct { transport *http.Transport } -func (o *localService) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) { +func (o *localService) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) { d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig} - return d.Dial(url, headers) + // Rewrite the request URL so that it goes to the origin service. + reqURL.Host = o.URL.Host + reqURL.Scheme = websocket.ChangeRequestScheme(o.URL) + return d.Dial(reqURL.String(), headers) } func (o *localService) address() string { @@ -215,9 +223,13 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { return o.transport.RoundTrip(req) } -func (o *helloWorld) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) { - d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig} - return d.Dial(url, headers) +func (o *helloWorld) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) { + d := &gws.Dialer{ + TLSClientConfig: o.transport.TLSClientConfig, + } + reqURL.Host = o.server.Addr().String() + reqURL.Scheme = "wss" + return d.Dial(reqURL.String(), headers) } func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool { diff --git a/websocket/websocket.go b/websocket/websocket.go index 26367113..7abe01b5 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -11,6 +11,7 @@ import ( "io" "net" "net/http" + "net/url" "time" "github.com/cloudflare/cloudflared/h2mux" @@ -71,29 +72,29 @@ func IsWebSocketUpgrade(req *http.Request) bool { // Dialler is something that can proxy websocket requests. type Dialler interface { - Dial(url string, headers http.Header) (*websocket.Conn, *http.Response, error) + Dial(url *url.URL, headers http.Header) (*websocket.Conn, *http.Response, error) } type defaultDialler struct { tlsConfig *tls.Config } -func (dd *defaultDialler) Dial(url string, header http.Header) (*websocket.Conn, *http.Response, error) { +func (dd *defaultDialler) Dial(url *url.URL, header http.Header) (*websocket.Conn, *http.Response, error) { d := &websocket.Dialer{TLSClientConfig: dd.tlsConfig} - return d.Dial(url, header) + return d.Dial(url.String(), header) } // ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing // the connection. The response body may not contain the entire response and does // not need to be closed by the application. func ClientConnect(req *http.Request, dialler Dialler) (*websocket.Conn, *http.Response, error) { - req.URL.Scheme = changeRequestScheme(req) + req.URL.Scheme = ChangeRequestScheme(req.URL) wsHeaders := websocketHeaders(req) if dialler == nil { dialler = new(defaultDialler) } - conn, response, err := dialler.Dial(req.URL.String(), wsHeaders) + conn, response, err := dialler.Dial(req.URL, wsHeaders) if err != nil { return nil, response, err } @@ -252,16 +253,18 @@ func generateAcceptKey(req *http.Request) string { return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") } -// changeRequestScheme is needed as the gorilla websocket library requires the ws scheme. +// ChangeRequestScheme is needed as the gorilla websocket library requires the ws scheme. // (even though it changes it back to http/https, but ¯\_(ツ)_/¯.) -func changeRequestScheme(req *http.Request) string { - switch req.URL.Scheme { +func ChangeRequestScheme(reqURL *url.URL) string { + switch reqURL.Scheme { case "https": return "wss" case "http": return "ws" + case "": + return "ws" default: - return req.URL.Scheme + return reqURL.Scheme } }