diff --git a/carrier/carrier.go b/carrier/carrier.go new file mode 100644 index 00000000..82799d90 --- /dev/null +++ b/carrier/carrier.go @@ -0,0 +1,131 @@ +//Package carrier provides a WebSocket proxy to carry or proxy a connection +//from the local client to the edge. See it as a wrapper around any protocol +//that it packages up in a WebSocket connection to the edge. +package carrier + +import ( + "errors" + "io" + "net" + "net/http" + "os" + "strings" + + "github.com/cloudflare/cloudflared/cmd/cloudflared/access" + "github.com/cloudflare/cloudflared/websocket" + "github.com/sirupsen/logrus" +) + +// StdinoutStream is empty struct for wrapping stdin/stdout +// into a single ReadWriter +type StdinoutStream struct { +} + +// Read will read from Stdin +func (c *StdinoutStream) Read(p []byte) (int, error) { + return os.Stdin.Read(p) + +} + +// Write will write to Stdout +func (c *StdinoutStream) Write(p []byte) (int, error) { + return os.Stdout.Write(p) +} + +// StartClient will copy the data from stdin/stdout over a WebSocket connection +// to the edge (originURL) +func StartClient(logger *logrus.Logger, originURL string, stream io.ReadWriter) error { + return serveStream(logger, originURL, stream) +} + +// StartServer will setup a server on a specified port and copy data over a WebSocket connection +// to the edge (originURL) +func StartServer(logger *logrus.Logger, address, originURL string, shutdownC <-chan struct{}) error { + listener, err := net.Listen("tcp", address) + if err != nil { + logger.WithError(err).Error("failed to start forwarding server") + return err + } + defer listener.Close() + for { + select { + case <-shutdownC: + return nil + default: + conn, err := listener.Accept() + if err != nil { + return err + } + go serveConnection(logger, conn, originURL) + } + } +} + +// serveConnection handles connections for the StartServer call +func serveConnection(logger *logrus.Logger, c net.Conn, originURL string) { + defer c.Close() + serveStream(logger, originURL, c) +} + +// serveStream will serve the data over the WebSocket stream +func serveStream(logger *logrus.Logger, originURL string, conn io.ReadWriter) error { + wsConn, err := createWebsocketStream(originURL) + if err != nil { + logger.WithError(err).Error("failed to create websocket stream") + return err + } + defer wsConn.Close() + + websocket.Stream(wsConn, conn) + + return nil +} + +// createWebsocketStream will create a WebSocket connection to stream data over +// It also handles redirects from Access and will present that flow if +// the token is not present on the request +func createWebsocketStream(originURL string) (*websocket.Conn, error) { + req, err := http.NewRequest(http.MethodGet, originURL, nil) + if err != nil { + return nil, err + } + wsConn, resp, err := websocket.ClientConnect(req, nil) + if err != nil && resp != nil && resp.StatusCode > 300 { + location, err := resp.Location() + if err != nil { + return nil, err + } + if !strings.Contains(location.String(), "cdn-cgi/access/login") { + return nil, errors.New("not an Access redirect") + } + req, err := buildAccessRequest(originURL) + if err != nil { + return nil, err + } + + wsConn, _, err = websocket.ClientConnect(req, nil) + if err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + + return &websocket.Conn{Conn: wsConn}, nil +} + +// buildAccessRequest builds an HTTP request with the Access token set +func buildAccessRequest(originURL string) (*http.Request, error) { + req, err := http.NewRequest(http.MethodGet, originURL, nil) + if err != nil { + return nil, err + } + + token, err := access.FetchToken(req.URL) + if err != nil { + return nil, err + } + req.Header.Set("cf-access-token", token) + + return req, nil +} diff --git a/carrier/carrier_test.go b/carrier/carrier_test.go new file mode 100644 index 00000000..6e51b9e2 --- /dev/null +++ b/carrier/carrier_test.go @@ -0,0 +1,117 @@ +package carrier + +import ( + "bytes" + "io" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + + ws "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +const ( + // example in Sec-Websocket-Key in rfc6455 + testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ==" +) + +type testStreamer struct { + buf *bytes.Buffer + l sync.RWMutex +} + +func newTestStream() *testStreamer { + return &testStreamer{buf: new(bytes.Buffer)} +} + +func (s *testStreamer) Read(p []byte) (int, error) { + s.l.RLock() + defer s.l.RUnlock() + return s.buf.Read(p) + +} + +func (s *testStreamer) Write(p []byte) (int, error) { + s.l.Lock() + defer s.l.Unlock() + return s.buf.Write(p) +} + +func TestStartClient(t *testing.T) { + message := "Good morning Austin! Time for another sunny day in the great state of Texas." + logger := logrus.New() + ts := newTestWebSocketServer() + defer ts.Close() + + buf := newTestStream() + err := StartClient(logger, "http://"+ts.Listener.Addr().String(), buf) + assert.NoError(t, err) + buf.Write([]byte(message)) + + readBuffer := make([]byte, len(message)) + buf.Read(readBuffer) + assert.Equal(t, message, string(readBuffer)) +} + +func TestStartServer(t *testing.T) { + listenerAddress := "localhost:1117" + message := "Good morning Austin! Time for another sunny day in the great state of Texas." + logger := logrus.New() + shutdownC := make(chan struct{}) + ts := newTestWebSocketServer() + defer ts.Close() + + go func() { + StartServer(logger, listenerAddress, "http://"+ts.Listener.Addr().String(), shutdownC) + }() + + conn, err := net.Dial("tcp", listenerAddress) + assert.NoError(t, err) + conn.Write([]byte(message)) + + readBuffer := make([]byte, len(message)) + conn.Read(readBuffer) + assert.Equal(t, string(readBuffer), message) +} + +func newTestWebSocketServer() *httptest.Server { + upgrader := ws.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, _ := upgrader.Upgrade(w, r, nil) + defer conn.Close() + for { + mt, message, err := conn.ReadMessage() + if err != nil { + break + } + + if err := conn.WriteMessage(mt, []byte(message)); err != nil { + break + } + } + })) +} + +func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request { + req, err := http.NewRequest("GET", url, stream) + if err != nil { + t.Fatalf("testRequestHeader error") + } + + req.Header.Add("Connection", "Upgrade") + req.Header.Add("Upgrade", "WebSocket") + req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey) + req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol") + req.Header.Add("Sec-Websocket-Version", "13") + req.Header.Add("User-Agent", "curl/7.59.0") + + return req +} diff --git a/cmd/cloudflared/access/cmd.go b/cmd/cloudflared/access/cmd.go index 2661d116..073982ff 100644 --- a/cmd/cloudflared/access/cmd.go +++ b/cmd/cloudflared/access/cmd.go @@ -86,10 +86,12 @@ func login(c *cli.Context) error { logger.Errorf("Please provide the url of the Access application\n") return err } - if _, err := fetchToken(c, appURL); err != nil { + token, err := FetchToken(appURL) + if err != nil { logger.Errorf("Failed to fetch token: %s\n", err) return err } + fmt.Fprintf(os.Stdout, "Successfully fetched your token:\n\n%s\n\n", string(token)) return nil } @@ -115,7 +117,7 @@ func curl(c *cli.Context) error { logger.Warn("You don't have an Access token set. Please run access token to fetch one.") return shell.Run("curl", cmdArgs...) } - token, err = fetchToken(c, appURL) + token, err = FetchToken(appURL) if err != nil { logger.Error("Failed to refresh token: ", err) return err diff --git a/cmd/cloudflared/access/token.go b/cmd/cloudflared/access/token.go index 943fe402..3fd7c20c 100644 --- a/cmd/cloudflared/access/token.go +++ b/cmd/cloudflared/access/token.go @@ -15,15 +15,13 @@ import ( "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/oidc" homedir "github.com/mitchellh/go-homedir" - cli "gopkg.in/urfave/cli.v2" ) var logger = log.CreateLogger() -// fetchToken will either load a stored token or generate a new one -func fetchToken(c *cli.Context, appURL *url.URL) (string, error) { +// FetchToken will either load a stored token or generate a new one +func FetchToken(appURL *url.URL) (string, error) { if token, err := getTokenIfExists(appURL); token != "" && err == nil { - fmt.Fprintf(os.Stdout, "You have an existing token:\n\n%s\n\n", token) return token, nil } @@ -36,12 +34,11 @@ func fetchToken(c *cli.Context, appURL *url.URL) (string, error) { // we want to send to the transfer service. the key is token and the value // is blank (basically just the id generated in the transfer service) const resourceName, key, value = "token", "token", "" - token, err := transfer.Run(c, appURL, resourceName, key, value, path, true) + token, err := transfer.Run(appURL, resourceName, key, value, path, true) if err != nil { return "", err } - fmt.Fprintf(os.Stdout, "Successfully fetched your token:\n\n%s\n\n", string(token)) return string(token), nil } diff --git a/cmd/cloudflared/transfer/transfer.go b/cmd/cloudflared/transfer/transfer.go index 0b786189..6560c5ff 100644 --- a/cmd/cloudflared/transfer/transfer.go +++ b/cmd/cloudflared/transfer/transfer.go @@ -16,7 +16,6 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/encrypter" "github.com/cloudflare/cloudflared/cmd/cloudflared/shell" "github.com/cloudflare/cloudflared/log" - cli "gopkg.in/urfave/cli.v2" ) const ( @@ -32,7 +31,7 @@ var logger = log.CreateLogger() // The "dance" we refer to is building a HTTP request, opening that in a browser waiting for // the user to complete an action, while it long polls in the background waiting for an // action to be completed to download the resource. -func Run(c *cli.Context, transferURL *url.URL, resourceName, key, value, path string, shouldEncrypt bool) ([]byte, error) { +func Run(transferURL *url.URL, resourceName, key, value, path string, shouldEncrypt bool) ([]byte, error) { encrypterClient, err := encrypter.New("cloudflared_priv.pem", "cloudflared_pub.pem") if err != nil { return nil, err @@ -49,16 +48,10 @@ func Run(c *cli.Context, transferURL *url.URL, resourceName, key, value, path st fmt.Fprintf(os.Stdout, "A browser window should have opened at the following URL:\n\n%s\n\nIf the browser failed to open, open it yourself and visit the URL above.\n", requestURL) } - // for local debugging - baseURL := baseStoreURL - if c.IsSet("url") { - baseURL = c.String("url") - } - var resourceData []byte if shouldEncrypt { - buf, key, err := transferRequest(baseURL + filepath.Join("transfer", encrypterClient.PublicKey())) + buf, key, err := transferRequest(baseStoreURL + filepath.Join("transfer", encrypterClient.PublicKey())) if err != nil { return nil, err } @@ -74,7 +67,7 @@ func Run(c *cli.Context, transferURL *url.URL, resourceName, key, value, path st resourceData = decrypted } else { - buf, _, err := transferRequest(baseURL + filepath.Join(encrypterClient.PublicKey())) + buf, _, err := transferRequest(baseStoreURL + filepath.Join(encrypterClient.PublicKey())) if err != nil { return nil, err } diff --git a/cmd/cloudflared/tunnel/carrier.go b/cmd/cloudflared/tunnel/carrier.go new file mode 100644 index 00000000..f0ecc903 --- /dev/null +++ b/cmd/cloudflared/tunnel/carrier.go @@ -0,0 +1,38 @@ +package tunnel + +import ( + "net/url" + + "github.com/cloudflare/cloudflared/carrier" + "github.com/cloudflare/cloudflared/validation" + "github.com/pkg/errors" + cli "gopkg.in/urfave/cli.v2" +) + +// ssh will start a WS proxy server for server mode +// or copy from stdin/stdout for client mode +// useful for proxying other protocols (like ssh) over websockets +// (which you can put Access in front of) +func ssh(c *cli.Context) error { + hostname, err := validation.ValidateHostname(c.String("hostname")) + if err != nil { + logger.WithError(err).Error("Invalid hostname") + return errors.Wrap(err, "invalid hostname") + } + + if c.NArg() > 0 || c.IsSet("url") { + localForwarder, err := validateUrl(c) + if err != nil { + logger.WithError(err).Error("Error validating origin URL") + return errors.Wrap(err, "error validating origin URL") + } + forwarder, err := url.Parse(localForwarder) + if err != nil { + logger.WithError(err).Error("Error validating origin URL") + return errors.Wrap(err, "error validating origin URL") + } + return carrier.StartServer(logger, forwarder.Host, "https://"+hostname, shutdownC) + } + + return carrier.StartClient(logger, "https://"+hostname, &carrier.StdinoutStream{}) +} diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 02e5ada8..eeb88270 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -3,6 +3,7 @@ package tunnel import ( "fmt" "io/ioutil" + "net" "os" "runtime/trace" "sync" @@ -19,6 +20,7 @@ import ( "github.com/cloudflare/cloudflared/metrics" "github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/tunneldns" + "github.com/cloudflare/cloudflared/websocket" "github.com/coreos/go-systemd/daemon" "github.com/facebookgo/grace/gracenet" "github.com/pkg/errors" @@ -137,6 +139,21 @@ func Commands() []*cli.Command { }, Hidden: true, }, + { + Name: "ssh", + Action: ssh, + Usage: `ssh -o ProxyCommand="cloudflared tunnel ssh --hostname %h" ssh.warptunnels.org`, + ArgsUsage: "[origin-url]", + Description: `The ssh subcommand wraps sends data over a WebSocket proxy to the Cloudflare edge.`, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "hostname", + }, + &cli.StringFlag{ + Name: "url", + }, + }, + }, } var subcommands []*cli.Command @@ -308,6 +325,20 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan c.Set("url", "https://"+helloListener.Addr().String()) } + if c.IsSet("ws-proxy-server") { + listener, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + logger.WithError(err).Error("Cannot start Websocket Proxy Server") + return errors.Wrap(err, "Cannot start Websocket Proxy Server") + } + wg.Add(1) + go func() { + defer wg.Done() + errC <- websocket.StartProxyServer(logger, listener, c.String("remote"), shutdownC) + }() + c.Set("url", "http://"+listener.Addr().String()) + } + tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, protoLogger) if err != nil { return err @@ -447,6 +478,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag { EnvVars: []string{"TUNNEL_URL"}, Hidden: shouldHide, }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "remote", + Value: "localhost:22", + Usage: "Connect to the local server over tcp at `remote`.", + EnvVars: []string{"TUNNEL_REMOTE"}, + Hidden: shouldHide, + }), altsrc.NewStringFlag(&cli.StringFlag{ Name: "hostname", Usage: "Set a hostname on a Cloudflare zone to route traffic through this tunnel.", @@ -549,6 +587,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag { EnvVars: []string{"TUNNEL_HELLO_WORLD"}, Hidden: shouldHide, }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: "ws-proxy-server", + Value: false, + Usage: "Run WS proxy Server", + EnvVars: []string{"TUNNEL_WS_PROXY"}, + Hidden: shouldHide, + }), altsrc.NewStringFlag(&cli.StringFlag{ Name: "pidfile", Usage: "Write the application's PID to this file after first successful connection.", diff --git a/cmd/cloudflared/tunnel/login.go b/cmd/cloudflared/tunnel/login.go index d2d4ec5a..375d39cb 100644 --- a/cmd/cloudflared/tunnel/login.go +++ b/cmd/cloudflared/tunnel/login.go @@ -33,7 +33,7 @@ func login(c *cli.Context) error { return err } - _, err = transfer.Run(c, loginURL, "cert", "callback", callbackStoreURL, path, false) + _, err = transfer.Run(loginURL, "cert", "callback", callbackStoreURL, path, false) if err != nil { fmt.Fprintf(os.Stderr, "Failed to write the certificate due to the following error:\n%v\n\nYour browser will download the certificate instead. You will have to manually\ncopy it to the following path:\n\n%s\n", err, path) return err diff --git a/websocket/websocket.go b/websocket/websocket.go index 371269b2..64959121 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -9,11 +9,24 @@ import ( "io" "net" "net/http" + "time" "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" ) -var stripWebsocketHeaders = []string { +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 +) + +var stripWebsocketHeaders = []string{ "Upgrade", "Connection", "Sec-Websocket-Key", @@ -21,6 +34,32 @@ var stripWebsocketHeaders = []string { "Sec-Websocket-Extensions", } +// Conn is a wrapper around the standard gorilla websocket +// but implements a ReadWriter +type Conn struct { + *websocket.Conn +} + +// Read will read messages from the websocket connection +func (c *Conn) Read(p []byte) (int, error) { + _, message, err := c.Conn.ReadMessage() + if err != nil { + return 0, err + } + + return copy(p, message), nil + +} + +// Write will write messages to the websocket connection +func (c *Conn) Write(p []byte) (int, error) { + if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil { + return 0, err + } + + return len(p), nil +} + // IsWebSocketUpgrade checks to see if the request is a WebSocket connection. func IsWebSocketUpgrade(req *http.Request) bool { return websocket.IsWebSocketUpgrade(req) @@ -36,7 +75,7 @@ func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.C d := &websocket.Dialer{TLSClientConfig: tlsClientConfig} conn, response, err := d.Dial(req.URL.String(), wsHeaders) if err != nil { - return nil, nil, err + return nil, response, err } response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req)) return conn, response, err @@ -74,16 +113,58 @@ func Stream(conn, backendConn io.ReadWriter) { <-proxyDone } +// StartProxyServer will start a websocket server that will decode +// the websocket data and write the resulting data to the provided +// address +func StartProxyServer(logger *logrus.Logger, listener net.Listener, remote string, shutdownC <-chan struct{}) error { + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil} + go func() { + <-shutdownC + httpServer.Close() + }() + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + stream, err := net.Dial("tcp", remote) + if err != nil { + logger.WithError(err).Error("Cannot connect to remote.") + return + } + defer stream.Close() + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + logger.WithError(err).Error("failed to upgrade") + return + } + conn.SetReadDeadline(time.Now().Add(pongWait)) + conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + done := make(chan struct{}) + go pinger(logger, conn, done) + defer func() { + <-done + conn.Close() + }() + Stream(&Conn{conn}, stream) + }) + + return httpServer.Serve(listener) +} + // the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key, // Sec-WebSocket-Version and Sec-Websocket-Extensions headers. // https://github.com/gorilla/websocket/blob/master/client.go#L189-L194. func websocketHeaders(req *http.Request) http.Header { wsHeaders := make(http.Header) for key, val := range req.Header { - wsHeaders[key] = val + wsHeaders[key] = val } // Assume the header keys are in canonical format. - for _, header := range stripWebsocketHeaders { + for _, header := range stripWebsocketHeaders { wsHeaders.Del(header) } return wsHeaders @@ -115,3 +196,19 @@ func changeRequestScheme(req *http.Request) string { return req.URL.Scheme } } + +// pinger simulates the websocket connection to keep it alive +func pinger(logger *logrus.Logger, ws *websocket.Conn, done chan struct{}) { + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil { + logger.WithError(err).Debug("failed to send ping message") + } + case <-done: + return + } + } +} diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go index 45eacd95..8383a422 100644 --- a/websocket/websocket_test.go +++ b/websocket/websocket_test.go @@ -1,100 +1,138 @@ package websocket import ( - "crypto/tls" - "io" - "math/rand" - "net/http" - "testing" + "crypto/tls" + "io" + "math/rand" + "net/http" + "testing" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" - "golang.org/x/net/websocket" + "golang.org/x/net/websocket" - "github.com/cloudflare/cloudflared/hello" - "github.com/cloudflare/cloudflared/tlsconfig" + "github.com/cloudflare/cloudflared/hello" + "github.com/cloudflare/cloudflared/tlsconfig" ) const ( - // example in Sec-Websocket-Key in rfc6455 - testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ==" - // example Sec-Websocket-Accept in rfc6455 - testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" + // example in Sec-Websocket-Key in rfc6455 + testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ==" + // example Sec-Websocket-Accept in rfc6455 + testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" ) func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request { - req, err := http.NewRequest("GET", url, stream) - if err != nil { - t.Fatalf("testRequestHeader error") - } + req, err := http.NewRequest("GET", url, stream) + if err != nil { + t.Fatalf("testRequestHeader error") + } - req.Header.Add("Connection", "Upgrade") - req.Header.Add("Upgrade", "WebSocket") - req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey) - req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol") - req.Header.Add("Sec-Websocket-Version", "13") - req.Header.Add("User-Agent", "curl/7.59.0") + req.Header.Add("Connection", "Upgrade") + req.Header.Add("Upgrade", "WebSocket") + req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey) + req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol") + req.Header.Add("Sec-Websocket-Version", "13") + req.Header.Add("User-Agent", "curl/7.59.0") - return req + return req } func websocketClientTLSConfig(t *testing.T) *tls.Config { - certPool, err := tlsconfig.LoadOriginCertPool(nil) - assert.NoError(t, err) - assert.NotNil(t, certPool) - return &tls.Config{RootCAs: certPool} + certPool, err := tlsconfig.LoadOriginCertPool(nil) + assert.NoError(t, err) + assert.NotNil(t, certPool) + return &tls.Config{RootCAs: certPool} } func TestWebsocketHeaders(t *testing.T) { - req := testRequest(t, "http://example.com", nil) - wsHeaders := websocketHeaders(req) - for _, header := range stripWebsocketHeaders { - assert.Empty(t, wsHeaders[header]) - } - assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent")) + req := testRequest(t, "http://example.com", nil) + wsHeaders := websocketHeaders(req) + for _, header := range stripWebsocketHeaders { + assert.Empty(t, wsHeaders[header]) + } + assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent")) } func TestGenerateAcceptKey(t *testing.T) { - req := testRequest(t, "http://example.com", nil) - assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(req)) + req := testRequest(t, "http://example.com", nil) + assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(req)) } func TestServe(t *testing.T) { - logger := logrus.New() - shutdownC := make(chan struct{}) - errC := make(chan error) - listener, err := hello.CreateTLSListener("localhost:1111") - assert.NoError(t, err) - defer listener.Close() + logger := logrus.New() + shutdownC := make(chan struct{}) + errC := make(chan error) + listener, err := hello.CreateTLSListener("localhost:1111") + assert.NoError(t, err) + defer listener.Close() - go func() { - errC <- hello.StartHelloWorldServer(logger, listener, shutdownC) - }() + go func() { + errC <- hello.StartHelloWorldServer(logger, listener, shutdownC) + }() - req := testRequest(t, "https://localhost:1111/ws", nil) + req := testRequest(t, "https://localhost:1111/ws", nil) - tlsConfig := websocketClientTLSConfig(t) - assert.NotNil(t, tlsConfig) - conn, resp, err := ClientConnect(req, tlsConfig) - assert.NoError(t, err) - assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept")) + tlsConfig := websocketClientTLSConfig(t) + assert.NotNil(t, tlsConfig) + conn, resp, err := ClientConnect(req, tlsConfig) + assert.NoError(t, err) + assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept")) - for i := 0; i < 1000; i++ { - messageSize := rand.Int() % 2048 + 1 - clientMessage := make([]byte, messageSize) - // rand.Read always returns len(clientMessage) and a nil error - rand.Read(clientMessage) - err = conn.WriteMessage(websocket.BinaryFrame, clientMessage) - assert.NoError(t, err) + for i := 0; i < 1000; i++ { + messageSize := rand.Int()%2048 + 1 + clientMessage := make([]byte, messageSize) + // rand.Read always returns len(clientMessage) and a nil error + rand.Read(clientMessage) + err = conn.WriteMessage(websocket.BinaryFrame, clientMessage) + assert.NoError(t, err) - messageType, message, err := conn.ReadMessage() - assert.NoError(t, err) - assert.Equal(t, websocket.BinaryFrame, messageType) - assert.Equal(t, clientMessage, message) - } + messageType, message, err := conn.ReadMessage() + assert.NoError(t, err) + assert.Equal(t, websocket.BinaryFrame, messageType) + assert.Equal(t, clientMessage, message) + } - conn.Close() - close(shutdownC) - <-errC + conn.Close() + close(shutdownC) + <-errC } + +// func TestStartProxyServer(t *testing.T) { +// var wg sync.WaitGroup +// remoteAddress := "localhost:1113" +// listenerAddress := "localhost:1112" +// message := "Good morning Austin! Time for another sunny day in the great state of Texas." +// logger := logrus.New() +// shutdownC := make(chan struct{}) + +// listener, err := net.Listen("tcp", listenerAddress) +// assert.NoError(t, err) +// defer listener.Close() + +// remoteListener, err := net.Listen("tcp", remoteAddress) +// assert.NoError(t, err) +// defer remoteListener.Close() + +// wg.Add(1) +// go func() { +// defer wg.Done() +// conn, err := remoteListener.Accept() +// assert.NoError(t, err) +// buf := make([]byte, len(message)) +// conn.Read(buf) +// assert.Equal(t, string(buf), message) +// }() + +// go func() { +// StartProxyServer(logger, listener, remoteAddress, shutdownC) +// }() + +// req := testRequest(t, fmt.Sprintf("http://%s/", listenerAddress), nil) +// conn, _, err := ClientConnect(req, nil) +// assert.NoError(t, err) +// err = conn.WriteMessage(1, []byte(message)) +// assert.NoError(t, err) +// wg.Wait() +// }