From 200f9a37867fe656a209f755f6e5e33df1ea27b3 Mon Sep 17 00:00:00 2001 From: Austin Cherry Date: Fri, 1 Feb 2019 16:43:59 -0600 Subject: [PATCH] AUTH-1503: Added RDP support --- cmd/cloudflared/access/cmd.go | 1 + cmd/cloudflared/tunnel/cmd.go | 27 ++++++++++++++++++++++----- cmd/cloudflared/tunnel/cmd_test.go | 15 +++++++++++++++ 3 files changed, 38 insertions(+), 5 deletions(-) create mode 100644 cmd/cloudflared/tunnel/cmd_test.go diff --git a/cmd/cloudflared/access/cmd.go b/cmd/cloudflared/access/cmd.go index d567ca49..74b9fedb 100644 --- a/cmd/cloudflared/access/cmd.go +++ b/cmd/cloudflared/access/cmd.go @@ -86,6 +86,7 @@ func Commands() []*cli.Command { { Name: "ssh", Action: ssh, + Aliases: []string{"rdp"}, Usage: "", ArgsUsage: "", Description: `The ssh subcommand sends data over a proxy to the Cloudflare edge.`, diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 61ce6c83..84aa82e1 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -300,11 +300,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan c.Set("url", "https://"+helloListener.Addr().String()) } - if uri, err := url.Parse(c.String("url")); err == nil && uri.Scheme == "ssh" { - host := uri.Host - if uri.Port() == "" { // default to 22 - host = uri.Hostname() + ":22" - } + if host := hostnameFromURI(c.String("url")); host != "" { listener, err := net.Listen("tcp", "127.0.0.1:") if err != nil { logger.WithError(err).Error("Cannot start Websocket Proxy Server") @@ -393,6 +389,27 @@ func writePidFile(waitForSignal chan struct{}, pidFile string) { fmt.Fprintf(file, "%d", os.Getpid()) } +func hostnameFromURI(uri string) string { + u, err := url.Parse(uri) + if err != nil { + return "" + } + switch u.Scheme { + case "ssh": + return addPortIfMissing(u, 22) + case "rdp": + return addPortIfMissing(u, 3389) + } + return "" +} + +func addPortIfMissing(uri *url.URL, port int) string { + if uri.Port() != "" { + return uri.Host + } + return fmt.Sprintf("%s:%d", uri.Hostname(), port) +} + func tunnelFlags(shouldHide bool) []cli.Flag { return []cli.Flag{ &cli.StringFlag{ diff --git a/cmd/cloudflared/tunnel/cmd_test.go b/cmd/cloudflared/tunnel/cmd_test.go new file mode 100644 index 00000000..4d400284 --- /dev/null +++ b/cmd/cloudflared/tunnel/cmd_test.go @@ -0,0 +1,15 @@ +package tunnel + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TesthostnameFromURI(t *testing.T) { + assert.Equal(t, "ssh://awesome.warptunnels.horse:22", hostnameFromURI("ssh://awesome.warptunnels.horse:22")) + assert.Equal(t, "ssh://awesome.warptunnels.horse:22", hostnameFromURI("ssh://awesome.warptunnels.horse")) + assert.Equal(t, "rdp://localhost:3389", hostnameFromURI("rdp://localhost")) + assert.Equal(t, "", hostnameFromURI("trash")) + assert.Equal(t, "", hostnameFromURI("https://awesomesauce.com")) +}