From 1d5cc45ac7d3afcf6dc1f7f3ffa4bbee128f8109 Mon Sep 17 00:00:00 2001 From: Michael Borkenstein Date: Thu, 19 Sep 2019 13:47:08 -0500 Subject: [PATCH] AUTH-2055: Verifies token at edge on access login --- carrier/carrier.go | 14 +++---- carrier/carrier_test.go | 8 ++-- cmd/cloudflared/access/cmd.go | 72 +++++++++++++++++++++++++++++++++-- 3 files changed, 79 insertions(+), 15 deletions(-) diff --git a/carrier/carrier.go b/carrier/carrier.go index e5741169..17cdda8a 100644 --- a/carrier/carrier.go +++ b/carrier/carrier.go @@ -114,7 +114,7 @@ func createWebsocketStream(options *StartOptions) (*cloudflaredWebsocket.Conn, e wsConn, resp, err := cloudflaredWebsocket.ClientConnect(req, nil) defer closeRespBody(resp) - if err != nil && isAccessResponse(resp) { + if err != nil && IsAccessResponse(resp) { wsConn, err = createAccessAuthenticatedStream(options) if err != nil { return nil, err @@ -126,10 +126,10 @@ func createWebsocketStream(options *StartOptions) (*cloudflaredWebsocket.Conn, e return &cloudflaredWebsocket.Conn{Conn: wsConn}, nil } -// isAccessResponse checks the http Response to see if the url location +// IsAccessResponse checks the http Response to see if the url location // contains the Access structure. -func isAccessResponse(resp *http.Response) bool { - if resp == nil || resp.StatusCode <= 300 { +func IsAccessResponse(resp *http.Response) bool { + if resp == nil || resp.StatusCode != http.StatusFound { return false } @@ -156,7 +156,7 @@ func createAccessAuthenticatedStream(options *StartOptions) (*websocket.Conn, er return wsConn, nil } - if !isAccessResponse(resp) { + if !IsAccessResponse(resp) { return nil, err } @@ -179,7 +179,7 @@ func createAccessAuthenticatedStream(options *StartOptions) (*websocket.Conn, er // createAccessWebSocketStream builds an Access request and makes a connection func createAccessWebSocketStream(options *StartOptions) (*websocket.Conn, *http.Response, error) { - req, err := buildAccessRequest(options) + req, err := BuildAccessRequest(options) if err != nil { return nil, nil, err } @@ -187,7 +187,7 @@ func createAccessWebSocketStream(options *StartOptions) (*websocket.Conn, *http. } // buildAccessRequest builds an HTTP request with the Access token set -func buildAccessRequest(options *StartOptions) (*http.Request, error) { +func BuildAccessRequest(options *StartOptions) (*http.Request, error) { req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil) if err != nil { return nil, err diff --git a/carrier/carrier_test.go b/carrier/carrier_test.go index 5360f092..379faadf 100644 --- a/carrier/carrier_test.go +++ b/carrier/carrier_test.go @@ -102,14 +102,14 @@ func TestIsAccessResponse(t *testing.T) { ExpectedOut bool }{ {"nil response", nil, false}, - {"redirect with no location", &http.Response{StatusCode: http.StatusPermanentRedirect}, false}, + {"redirect with no location", &http.Response{StatusCode: http.StatusFound}, false}, {"200 ok", &http.Response{StatusCode: http.StatusOK}, false}, - {"redirect with location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: validLocationHeader}, true}, - {"redirect with invalid location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: invalidLocationHeader}, false}, + {"redirect with location", &http.Response{StatusCode: http.StatusFound, Header: validLocationHeader}, true}, + {"redirect with invalid location", &http.Response{StatusCode: http.StatusFound, Header: invalidLocationHeader}, false}, } for i, tc := range testCases { - if isAccessResponse(tc.In) != tc.ExpectedOut { + if IsAccessResponse(tc.In) != tc.ExpectedOut { t.Fatalf("Failed case %d -- %s", i, tc.Description) } } diff --git a/cmd/cloudflared/access/cmd.go b/cmd/cloudflared/access/cmd.go index dde26b3f..b7751d0e 100644 --- a/cmd/cloudflared/access/cmd.go +++ b/cmd/cloudflared/access/cmd.go @@ -1,17 +1,20 @@ package access import ( - "errors" "fmt" + "net/http" "net/url" "os" "strings" "text/template" + "time" + "github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/cmd/cloudflared/shell" "github.com/cloudflare/cloudflared/cmd/cloudflared/token" "github.com/cloudflare/cloudflared/sshgen" "github.com/cloudflare/cloudflared/validation" + "github.com/pkg/errors" "golang.org/x/net/idna" "github.com/cloudflare/cloudflared/log" @@ -188,9 +191,14 @@ func login(c *cli.Context) error { logger.Errorf("Please provide the url of the Access application\n") return err } - token, err := token.FetchToken(appURL) - if err != nil { - logger.Errorf("Failed to fetch token: %s\n", err) + if err := verifyTokenAtEdge(appURL, c); err != nil { + logger.WithError(err).Error("Could not verify token") + return err + } + + token, err := token.GetTokenIfExists(appURL) + if err != nil || token == "" { + fmt.Fprintln(os.Stderr, "Unable to find token for provided application.") return err } fmt.Fprintf(os.Stdout, "Successfully fetched your token:\n\n%s\n\n", string(token)) @@ -372,3 +380,59 @@ func isFileThere(candidate string) bool { } return true } + +// verifyTokenAtEdge checks for a token on disk, or generates a new one. +// Then makes a request to to the origin with the token to ensure it is valid. +// Returns nil if token is valid. +func verifyTokenAtEdge(appUrl *url.URL, c *cli.Context) error { + headers := buildRequestHeaders(c.StringSlice(sshHeaderFlag)) + if c.IsSet(sshTokenIDFlag) { + headers.Add("CF-Access-Client-Id", c.String(sshTokenIDFlag)) + } + if c.IsSet(sshTokenSecretFlag) { + headers.Add("CF-Access-Client-Secret", c.String(sshTokenSecretFlag)) + } + options := &carrier.StartOptions{OriginURL: appUrl.String(), Headers: headers} + + if valid, err := isTokenValid(options); err != nil { + return err + } else if valid { + return nil + } + + if err := token.RemoveTokenIfExists(appUrl); err != nil { + return err + } + + if valid, err := isTokenValid(options); err != nil { + return err + } else if !valid { + return errors.New("failed to verify token") + } + + return nil +} + +// isTokenValid makes a request to the origin and returns true if the response was not a 302. +func isTokenValid(options *carrier.StartOptions) (bool, error) { + req, err := carrier.BuildAccessRequest(options) + if err != nil { + return false, errors.Wrap(err, "Could not create access request") + } + + // Do not follow redirects + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + Timeout: time.Second * 5, + } + resp, err := client.Do(req) + if err != nil { + return false, err + } + defer resp.Body.Close() + + // A redirect to login means the token was invalid. + return !carrier.IsAccessResponse(resp), nil +}