From 92defa26d40e319dde3c0639b780b989a083a791 Mon Sep 17 00:00:00 2001 From: Austin Cherry Date: Thu, 7 Feb 2019 10:56:33 -0600 Subject: [PATCH] AUTH-1511: Add custom headers for ssh command --- carrier/carrier.go | 19 ++++++++++--------- carrier/carrier_test.go | 4 ++-- cmd/cloudflared/access/carrier.go | 18 ++++++++++++++++-- cmd/cloudflared/access/carrier_test.go | 18 ++++++++++++++++++ cmd/cloudflared/access/cmd.go | 4 ++++ 5 files changed, 50 insertions(+), 13 deletions(-) create mode 100644 cmd/cloudflared/access/carrier_test.go diff --git a/carrier/carrier.go b/carrier/carrier.go index 5b765011..5b308d1f 100644 --- a/carrier/carrier.go +++ b/carrier/carrier.go @@ -34,13 +34,13 @@ func (c *StdinoutStream) Write(p []byte) (int, error) { // StartClient will copy the data from stdin/stdout over a WebSocket connection // to the edge (originURL) -func StartClient(logger *logrus.Logger, originURL string, stream io.ReadWriter) error { - return serveStream(logger, originURL, stream) +func StartClient(logger *logrus.Logger, originURL string, stream io.ReadWriter, headers http.Header) error { + return serveStream(logger, originURL, stream, headers) } // StartServer will setup a server on a specified port and copy data over a WebSocket connection // to the edge (originURL) -func StartServer(logger *logrus.Logger, address, originURL string, shutdownC <-chan struct{}) error { +func StartServer(logger *logrus.Logger, address, originURL string, shutdownC <-chan struct{}, headers http.Header) error { listener, err := net.Listen("tcp", address) if err != nil { logger.WithError(err).Error("failed to start forwarding server") @@ -56,20 +56,20 @@ func StartServer(logger *logrus.Logger, address, originURL string, shutdownC <-c if err != nil { return err } - go serveConnection(logger, conn, originURL) + go serveConnection(logger, conn, originURL, headers) } } } // serveConnection handles connections for the StartServer call -func serveConnection(logger *logrus.Logger, c net.Conn, originURL string) { +func serveConnection(logger *logrus.Logger, c net.Conn, originURL string, headers http.Header) { defer c.Close() - serveStream(logger, originURL, c) + serveStream(logger, originURL, c, headers) } // serveStream will serve the data over the WebSocket stream -func serveStream(logger *logrus.Logger, originURL string, conn io.ReadWriter) error { - wsConn, err := createWebsocketStream(originURL) +func serveStream(logger *logrus.Logger, originURL string, conn io.ReadWriter, headers http.Header) error { + wsConn, err := createWebsocketStream(originURL, headers) if err != nil { logger.WithError(err).Errorf("failed to connect to %s\n", originURL) return err @@ -84,11 +84,12 @@ func serveStream(logger *logrus.Logger, originURL string, conn io.ReadWriter) er // createWebsocketStream will create a WebSocket connection to stream data over // It also handles redirects from Access and will present that flow if // the token is not present on the request -func createWebsocketStream(originURL string) (*websocket.Conn, error) { +func createWebsocketStream(originURL string, headers http.Header) (*websocket.Conn, error) { req, err := http.NewRequest(http.MethodGet, originURL, nil) if err != nil { return nil, err } + req.Header = headers wsConn, resp, err := websocket.ClientConnect(req, nil) if err != nil && resp != nil && resp.StatusCode > 300 { diff --git a/carrier/carrier_test.go b/carrier/carrier_test.go index 4ccc76d2..0828e45d 100644 --- a/carrier/carrier_test.go +++ b/carrier/carrier_test.go @@ -48,7 +48,7 @@ func TestStartClient(t *testing.T) { defer ts.Close() buf := newTestStream() - err := StartClient(logger, "http://"+ts.Listener.Addr().String(), buf) + err := StartClient(logger, "http://"+ts.Listener.Addr().String(), buf, nil) assert.NoError(t, err) buf.Write([]byte(message)) @@ -66,7 +66,7 @@ func TestStartServer(t *testing.T) { defer ts.Close() go func() { - err := StartServer(logger, listenerAddress, "http://"+ts.Listener.Addr().String(), shutdownC) + err := StartServer(logger, listenerAddress, "http://"+ts.Listener.Addr().String(), shutdownC, nil) if err != nil { t.Fatalf("Error starting server: %v", err) } diff --git a/cmd/cloudflared/access/carrier.go b/cmd/cloudflared/access/carrier.go index 02d8118e..3ab88bc9 100644 --- a/cmd/cloudflared/access/carrier.go +++ b/cmd/cloudflared/access/carrier.go @@ -1,7 +1,9 @@ package access import ( + "net/http" "net/url" + "strings" "github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" @@ -19,6 +21,7 @@ func ssh(c *cli.Context) error { if err != nil || c.String("hostname") == "" { return cli.ShowCommandHelp(c, "ssh") } + headers := buildRequestHeaders(c.StringSlice("header")) if c.NArg() > 0 || c.IsSet("url") { localForwarder, err := config.ValidateUrl(c) @@ -31,8 +34,19 @@ func ssh(c *cli.Context) error { logger.WithError(err).Error("Error validating origin URL") return errors.Wrap(err, "error validating origin URL") } - return carrier.StartServer(logger, forwarder.Host, "https://"+hostname, shutdownC) + return carrier.StartServer(logger, forwarder.Host, "https://"+hostname, shutdownC, headers) } - return carrier.StartClient(logger, "https://"+hostname, &carrier.StdinoutStream{}) + return carrier.StartClient(logger, "https://"+hostname, &carrier.StdinoutStream{}, headers) +} + +func buildRequestHeaders(values []string) http.Header { + headers := make(http.Header) + for _, valuePair := range values { + split := strings.Split(valuePair, ":") + if len(split) > 1 { + headers.Add(strings.TrimSpace(split[0]), strings.TrimSpace(split[1])) + } + } + return headers } diff --git a/cmd/cloudflared/access/carrier_test.go b/cmd/cloudflared/access/carrier_test.go new file mode 100644 index 00000000..2bca1fba --- /dev/null +++ b/cmd/cloudflared/access/carrier_test.go @@ -0,0 +1,18 @@ +package access + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBuildRequestHeaders(t *testing.T) { + headers := make(http.Header) + headers.Add("client", "value") + headers.Add("secret", "safe-value") + + values := buildRequestHeaders([]string{"client: value", "secret: safe-value", "trash"}) + assert.Equal(t, headers.Get("client"), values.Get("client")) + assert.Equal(t, headers.Get("secret"), values.Get("secret")) +} diff --git a/cmd/cloudflared/access/cmd.go b/cmd/cloudflared/access/cmd.go index eefb56e2..d567ca49 100644 --- a/cmd/cloudflared/access/cmd.go +++ b/cmd/cloudflared/access/cmd.go @@ -96,6 +96,10 @@ func Commands() []*cli.Command { &cli.StringFlag{ Name: "url", }, + &cli.StringSliceFlag{ + Name: "header", + Aliases: []string{"H"}, + }, }, }, {