diff --git a/websocket/websocket.go b/websocket/websocket.go index 7abe01b5..d600cb9d 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -147,56 +147,70 @@ func StartProxyServer(logger logger.Service, listener net.Listener, staticHost s ReadBufferSize: 1024, WriteBufferSize: 1024, } + h := handler{ + upgrader: upgrader, + logger: logger, + staticHost: staticHost, + streamHandler: streamHandler, + } - httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil} + httpServer := &http.Server{Addr: listener.Addr().String(), Handler: &h} go func() { <-shutdownC httpServer.Close() }() - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - // If remote is an empty string, get the destination from the client. - finalDestination := staticHost - if finalDestination == "" { - if jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader); jumpDestination == "" { - logger.Error("Did not receive final destination from client. The --destination flag is likely not set") - return - } else { - finalDestination = jumpDestination - } - } - - stream, err := net.Dial("tcp", finalDestination) - if err != nil { - logger.Errorf("Cannot connect to remote: %s", err) - return - } - defer stream.Close() - - if !websocket.IsWebSocketUpgrade(r) { - w.Write(nonWebSocketRequestPage()) - return - } - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - logger.Errorf("failed to upgrade: %s", err) - return - } - conn.SetReadDeadline(time.Now().Add(pongWait)) - conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) - done := make(chan struct{}) - go pinger(logger, conn, done) - defer func() { - done <- struct{}{} - conn.Close() - }() - - streamHandler(&Conn{conn}, stream, r.Header) - }) - return httpServer.Serve(listener) } +// HTTP handler for the websocket proxy. +type handler struct { + logger logger.Service + staticHost string + upgrader websocket.Upgrader + streamHandler func(wsConn *Conn, remoteConn net.Conn, requestHeaders http.Header) +} + +func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // If remote is an empty string, get the destination from the client. + finalDestination := h.staticHost + if finalDestination == "" { + if jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader); jumpDestination == "" { + h.logger.Error("Did not receive final destination from client. The --destination flag is likely not set") + return + } else { + finalDestination = jumpDestination + } + } + + stream, err := net.Dial("tcp", finalDestination) + if err != nil { + h.logger.Errorf("Cannot connect to remote: %s", err) + return + } + defer stream.Close() + + if !websocket.IsWebSocketUpgrade(r) { + w.Write(nonWebSocketRequestPage()) + return + } + conn, err := h.upgrader.Upgrade(w, r, nil) + if err != nil { + h.logger.Errorf("failed to upgrade: %s", err) + return + } + conn.SetReadDeadline(time.Now().Add(pongWait)) + conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + done := make(chan struct{}) + go pinger(h.logger, conn, done) + defer func() { + done <- struct{}{} + conn.Close() + }() + + h.streamHandler(&Conn{conn}, stream, r.Header) +} + // SendSSHPreamble sends the final SSH destination address to the cloudflared SSH proxy // The destination is preceded by its length // Not part of sshserver module to fix compilation for incompatible operating systems