TUN-3506: OriginService needs to set request host and scheme for websocket requests

This commit is contained in:
cthuang 2020-11-05 13:52:46 +00:00
parent be9a558867
commit 61c814bd79
3 changed files with 39 additions and 17 deletions

View File

@ -150,8 +150,15 @@ func uptimeHandler(startTime time.Time) http.HandlerFunc {
// This handler will echo message // This handler will echo message
func websocketHandler(logger logger.Service, upgrader websocket.Upgrader) http.HandlerFunc { func websocketHandler(logger logger.Service, upgrader websocket.Upgrader) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
// This addresses the issue of r.Host includes port but origin header doesn't
host, _, err := net.SplitHostPort(r.Host)
if err == nil {
r.Host = host
}
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
logger.Errorf("failed to upgrade to websocket connection, error: %s", err)
return return
} }
defer conn.Close() defer conn.Close()

View File

@ -55,9 +55,14 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req) return o.transport.RoundTrip(req)
} }
func (o *unixSocketPath) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) { func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig} d := &gws.Dialer{
return d.Dial(url, headers) NetDial: o.transport.Dial,
NetDialContext: o.transport.DialContext,
TLSClientConfig: o.transport.TLSClientConfig,
}
reqURL.Scheme = websocket.ChangeRequestScheme(reqURL)
return d.Dial(reqURL.String(), headers)
} }
// localService is an OriginService listening on a TCP/IP address the user's origin can route to. // localService is an OriginService listening on a TCP/IP address the user's origin can route to.
@ -71,9 +76,12 @@ type localService struct {
transport *http.Transport transport *http.Transport
} }
func (o *localService) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) { func (o *localService) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig} d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
return d.Dial(url, headers) // Rewrite the request URL so that it goes to the origin service.
reqURL.Host = o.URL.Host
reqURL.Scheme = websocket.ChangeRequestScheme(o.URL)
return d.Dial(reqURL.String(), headers)
} }
func (o *localService) address() string { func (o *localService) address() string {
@ -215,9 +223,13 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req) return o.transport.RoundTrip(req)
} }
func (o *helloWorld) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) { func (o *helloWorld) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig} d := &gws.Dialer{
return d.Dial(url, headers) TLSClientConfig: o.transport.TLSClientConfig,
}
reqURL.Host = o.server.Addr().String()
reqURL.Scheme = "wss"
return d.Dial(reqURL.String(), headers)
} }
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool { func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {

View File

@ -11,6 +11,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url"
"time" "time"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
@ -71,29 +72,29 @@ func IsWebSocketUpgrade(req *http.Request) bool {
// Dialler is something that can proxy websocket requests. // Dialler is something that can proxy websocket requests.
type Dialler interface { type Dialler interface {
Dial(url string, headers http.Header) (*websocket.Conn, *http.Response, error) Dial(url *url.URL, headers http.Header) (*websocket.Conn, *http.Response, error)
} }
type defaultDialler struct { type defaultDialler struct {
tlsConfig *tls.Config tlsConfig *tls.Config
} }
func (dd *defaultDialler) Dial(url string, header http.Header) (*websocket.Conn, *http.Response, error) { func (dd *defaultDialler) Dial(url *url.URL, header http.Header) (*websocket.Conn, *http.Response, error) {
d := &websocket.Dialer{TLSClientConfig: dd.tlsConfig} d := &websocket.Dialer{TLSClientConfig: dd.tlsConfig}
return d.Dial(url, header) return d.Dial(url.String(), header)
} }
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing // ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
// the connection. The response body may not contain the entire response and does // the connection. The response body may not contain the entire response and does
// not need to be closed by the application. // not need to be closed by the application.
func ClientConnect(req *http.Request, dialler Dialler) (*websocket.Conn, *http.Response, error) { func ClientConnect(req *http.Request, dialler Dialler) (*websocket.Conn, *http.Response, error) {
req.URL.Scheme = changeRequestScheme(req) req.URL.Scheme = ChangeRequestScheme(req.URL)
wsHeaders := websocketHeaders(req) wsHeaders := websocketHeaders(req)
if dialler == nil { if dialler == nil {
dialler = new(defaultDialler) dialler = new(defaultDialler)
} }
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders) conn, response, err := dialler.Dial(req.URL, wsHeaders)
if err != nil { if err != nil {
return nil, response, err return nil, response, err
} }
@ -252,16 +253,18 @@ func generateAcceptKey(req *http.Request) string {
return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
} }
// changeRequestScheme is needed as the gorilla websocket library requires the ws scheme. // ChangeRequestScheme is needed as the gorilla websocket library requires the ws scheme.
// (even though it changes it back to http/https, but ¯\_(ツ)_/¯.) // (even though it changes it back to http/https, but ¯\_(ツ)_/¯.)
func changeRequestScheme(req *http.Request) string { func ChangeRequestScheme(reqURL *url.URL) string {
switch req.URL.Scheme { switch reqURL.Scheme {
case "https": case "https":
return "wss" return "wss"
case "http": case "http":
return "ws" return "ws"
case "":
return "ws"
default: default:
return req.URL.Scheme return reqURL.Scheme
} }
} }