diff --git a/carrier/carrier.go b/carrier/carrier.go index 20af6c70..71541ab4 100644 --- a/carrier/carrier.go +++ b/carrier/carrier.go @@ -12,15 +12,13 @@ import ( "strings" "github.com/cloudflare/cloudflared/cmd/cloudflared/token" - "github.com/cloudflare/cloudflared/sshgen" "github.com/cloudflare/cloudflared/websocket" "github.com/sirupsen/logrus" ) type StartOptions struct { - OriginURL string - Headers http.Header - ShouldGenCert bool + OriginURL string + Headers http.Header } // StdinoutStream is empty struct for wrapping stdin/stdout @@ -116,17 +114,11 @@ func createWebsocketStream(options *StartOptions) (*websocket.Conn, error) { if !strings.Contains(location.String(), "cdn-cgi/access/login") { return nil, errors.New("not an Access redirect") } - req, token, err := buildAccessRequest(options.OriginURL) + req, err := buildAccessRequest(options.OriginURL) if err != nil { return nil, err } - if options.ShouldGenCert { - if err := sshgen.GenerateShortLivedCertificate(req.URL, token); err != nil { - return nil, err - } - } - wsConn, _, err = websocket.ClientConnect(req, nil) if err != nil { return nil, err @@ -139,24 +131,24 @@ func createWebsocketStream(options *StartOptions) (*websocket.Conn, error) { } // buildAccessRequest builds an HTTP request with the Access token set -func buildAccessRequest(originURL string) (*http.Request, string, error) { +func buildAccessRequest(originURL string) (*http.Request, error) { req, err := http.NewRequest(http.MethodGet, originURL, nil) if err != nil { - return nil, "", err + return nil, err } token, err := token.FetchToken(req.URL) if err != nil { - return nil, "", err + return nil, err } // We need to create a new request as FetchToken will modify req (boo mutable) // as it has to follow redirect on the API and such, so here we init a new one originRequest, err := http.NewRequest(http.MethodGet, originURL, nil) if err != nil { - return nil, "", err + return nil, err } originRequest.Header.Set("cf-access-token", token) - return originRequest, token, nil + return originRequest, nil } diff --git a/carrier/carrier_test.go b/carrier/carrier_test.go index 60dfed4e..e4c3d520 100644 --- a/carrier/carrier_test.go +++ b/carrier/carrier_test.go @@ -49,9 +49,8 @@ func TestStartClient(t *testing.T) { buf := newTestStream() options := &StartOptions{ - OriginURL: "http://" + ts.Listener.Addr().String(), - Headers: nil, - ShouldGenCert: false, + OriginURL: "http://" + ts.Listener.Addr().String(), + Headers: nil, } err := StartClient(logger, buf, options) assert.NoError(t, err) @@ -73,9 +72,8 @@ func TestStartServer(t *testing.T) { ts := newTestWebSocketServer() defer ts.Close() options := &StartOptions{ - OriginURL: "http://" + ts.Listener.Addr().String(), - Headers: nil, - ShouldGenCert: false, + OriginURL: "http://" + ts.Listener.Addr().String(), + Headers: nil, } go func() { diff --git a/cmd/cloudflared/access/carrier.go b/cmd/cloudflared/access/carrier.go index a6dc3ae1..0dfb0b87 100644 --- a/cmd/cloudflared/access/carrier.go +++ b/cmd/cloudflared/access/carrier.go @@ -34,12 +34,9 @@ func ssh(c *cli.Context) error { headers.Add("CF-Access-Client-Secret", c.String(sshTokenSecretFlag)) } - genCertBool := c.Bool(sshGenCertFlag) - options := &carrier.StartOptions{ - OriginURL: originURL, - Headers: headers, - ShouldGenCert: genCertBool, + OriginURL: originURL, + Headers: headers, } if c.NArg() > 0 || c.IsSet(sshURLFlag) { diff --git a/cmd/cloudflared/access/cmd.go b/cmd/cloudflared/access/cmd.go index b1babc46..081b7548 100644 --- a/cmd/cloudflared/access/cmd.go +++ b/cmd/cloudflared/access/cmd.go @@ -3,12 +3,15 @@ package access import ( "errors" "fmt" + "html/template" "net/url" "os" "strings" "github.com/cloudflare/cloudflared/cmd/cloudflared/shell" "github.com/cloudflare/cloudflared/cmd/cloudflared/token" + "github.com/cloudflare/cloudflared/sshgen" + "github.com/cloudflare/cloudflared/validation" "golang.org/x/net/idna" "github.com/cloudflare/cloudflared/log" @@ -22,7 +25,23 @@ const ( sshHeaderFlag = "header" sshTokenIDFlag = "service-token-id" sshTokenSecretFlag = "service-token-secret" - sshGenCertFlag = "gen-cert" + sshGenCertFlag = "short-lived-cert" + sshConfigTemplate = ` +Add this configuration block to your {{.Home}}/.ssh/config: + +Host {{.Hostname}} +{{- if .ShortLivedCerts}} + ProxyCommand bash -c '{{.Cloudflared}} access ssh-gen --hostname %h; ssh -tt cfpipe >&2 <&1' + +Host cfpipe-{{.Hostname}} + HostName {{.Hostname}} + ProxyCommand {{.Cloudflared}} access ssh --hostname %h + IdentityFile ~/.cloudflared/{{.Hostname}}.me-cf_key + CertificateFile ~/.cloudflared/{{.Hostname}}-cf_key-cert.pub +{{- else}} + ProxyCommand {{.Cloudflared}} access ssh --hostname %h +{{end}} +` ) const sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b@sentry.io/189878" @@ -124,6 +143,18 @@ func Commands() []*cli.Command { Aliases: []string{"secret"}, Usage: "specify an Access service token secret you wish to use.", }, + }, + }, + { + Name: "ssh-config", + Action: sshConfig, + Usage: "", + Description: `Prints an example configuration ~/.ssh/config`, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: sshHostnameFlag, + Usage: "specify the hostname of your application.", + }, &cli.BoolFlag{ Name: sshGenCertFlag, Usage: "specify if you wish to generate short lived certs.", @@ -131,10 +162,16 @@ func Commands() []*cli.Command { }, }, { - Name: "ssh-config", - Action: sshConfig, - Usage: "ssh-config", - Description: `Prints an example configuration ~/.ssh/config`, + Name: "ssh-gen", + Action: sshGen, + Usage: "", + Description: `Generates a short lived certificate for given hostname`, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: sshHostnameFlag, + Usage: "specify the hostname of your application.", + }, + }, }, }, }, @@ -218,8 +255,49 @@ func generateToken(c *cli.Context) error { // sshConfig prints an example SSH config to stdout func sshConfig(c *cli.Context) error { - outputMessage := "Add this configuration block to your %s/.ssh/config:\n\nHost [your hostname]\n\tProxyCommand %s access ssh --hostname %%h\n" - logger.Printf(outputMessage, os.Getenv("HOME"), cloudflaredPath()) + genCertBool := c.Bool(sshGenCertFlag) + hostname := c.String(sshHostnameFlag) + if hostname == "" { + hostname = "[your hostname]" + } + + type config struct { + Home string + ShortLivedCerts bool + Hostname string + Cloudflared string + } + + t := template.Must(template.New("sshConfig").Parse(sshConfigTemplate)) + return t.Execute(os.Stdout, config{Home: os.Getenv("HOME"), ShortLivedCerts: genCertBool, Hostname: hostname, Cloudflared: cloudflaredPath()}) +} + +// sshGen generates a short lived certificate for provided hostname +func sshGen(c *cli.Context) error { + // get the hostname from the cmdline and error out if its not provided + rawHostName := c.String(sshHostnameFlag) + hostname, err := validation.ValidateHostname(rawHostName) + if err != nil || rawHostName == "" { + return cli.ShowCommandHelp(c, "ssh-gen") + } + + originURL, err := url.Parse("https://" + hostname) + if err != nil { + return err + } + + // this fetchToken function mutates the appURL param. We should refactor that + fetchTokenURL := &url.URL{} + *fetchTokenURL = *originURL + token, err := token.FetchToken(fetchTokenURL) + if err != nil { + return err + } + + if err := sshgen.GenerateShortLivedCertificate(originURL, token); err != nil { + return err + } + return nil }