From 8f25704a901fa8cb76622326ea8c5b9529322a70 Mon Sep 17 00:00:00 2001 From: Austin Cherry Date: Wed, 26 Jun 2019 10:48:45 -0500 Subject: [PATCH] AUTH-1736: Better handling of token revocation We removed all token validation from cloudflared and now rely on the edge to do the validation. This is better because the edge is the only thing that fully knows about token revocation. So if a user logs out or the application revokes all it's tokens cloudflared will now handle that process instead of barfing on it. When we go to fetch a token we will check for the existence of a lock file. If the lock file exists, we stop and poll every half second to see if the lock is still there. Once the lock file is removed, it will restart the function to (hopefully) go pick up the valid token that was just created. --- carrier/carrier.go | 108 ++++++++++++++++++++++++++------- carrier/carrier_test.go | 25 ++++++++ cmd/cloudflared/token/token.go | 95 +++++++++++++++++++++++++---- 3 files changed, 193 insertions(+), 35 deletions(-) diff --git a/carrier/carrier.go b/carrier/carrier.go index 71541ab4..e5741169 100644 --- a/carrier/carrier.go +++ b/carrier/carrier.go @@ -4,7 +4,6 @@ package carrier import ( - "errors" "io" "net" "net/http" @@ -12,7 +11,8 @@ import ( "strings" "github.com/cloudflare/cloudflared/cmd/cloudflared/token" - "github.com/cloudflare/cloudflared/websocket" + cloudflaredWebsocket "github.com/cloudflare/cloudflared/websocket" + "github.com/gorilla/websocket" "github.com/sirupsen/logrus" ) @@ -37,6 +37,13 @@ func (c *StdinoutStream) Write(p []byte) (int, error) { return os.Stdout.Write(p) } +// Helper to allow defering the response close with a check that the resp is not nil +func closeRespBody(resp *http.Response) { + if resp != nil { + resp.Body.Close() + } +} + // StartClient will copy the data from stdin/stdout over a WebSocket connection // to the edge (originURL) func StartClient(logger *logrus.Logger, stream io.ReadWriter, options *StartOptions) error { @@ -90,7 +97,7 @@ func serveStream(logger *logrus.Logger, conn io.ReadWriter, options *StartOption } defer wsConn.Close() - websocket.Stream(wsConn, conn) + cloudflaredWebsocket.Stream(wsConn, conn) return nil } @@ -98,28 +105,17 @@ func serveStream(logger *logrus.Logger, conn io.ReadWriter, options *StartOption // createWebsocketStream will create a WebSocket connection to stream data over // It also handles redirects from Access and will present that flow if // the token is not present on the request -func createWebsocketStream(options *StartOptions) (*websocket.Conn, error) { +func createWebsocketStream(options *StartOptions) (*cloudflaredWebsocket.Conn, error) { req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil) if err != nil { return nil, err } req.Header = options.Headers - wsConn, resp, err := websocket.ClientConnect(req, nil) - if err != nil && resp != nil && resp.StatusCode > 300 { - location, err := resp.Location() - if err != nil { - return nil, err - } - if !strings.Contains(location.String(), "cdn-cgi/access/login") { - return nil, errors.New("not an Access redirect") - } - req, err := buildAccessRequest(options.OriginURL) - if err != nil { - return nil, err - } - - wsConn, _, err = websocket.ClientConnect(req, nil) + wsConn, resp, err := cloudflaredWebsocket.ClientConnect(req, nil) + defer closeRespBody(resp) + if err != nil && isAccessResponse(resp) { + wsConn, err = createAccessAuthenticatedStream(options) if err != nil { return nil, err } @@ -127,12 +123,72 @@ func createWebsocketStream(options *StartOptions) (*websocket.Conn, error) { return nil, err } - return &websocket.Conn{Conn: wsConn}, nil + return &cloudflaredWebsocket.Conn{Conn: wsConn}, nil +} + +// isAccessResponse checks the http Response to see if the url location +// contains the Access structure. +func isAccessResponse(resp *http.Response) bool { + if resp == nil || resp.StatusCode <= 300 { + return false + } + + location, err := resp.Location() + if err != nil || location == nil { + return false + } + if strings.HasPrefix(location.Path, "/cdn-cgi/access/login") { + return true + } + + return false +} + +// createAccessAuthenticatedStream will try load a token from storage and make +// a connection with the token set on the request. If it still get redirect, +// this probably means the token in storage is invalid (expired/revoked). If that +// happens it deletes the token and runs the connection again, so the user can +// login again and generate a new one. +func createAccessAuthenticatedStream(options *StartOptions) (*websocket.Conn, error) { + wsConn, resp, err := createAccessWebSocketStream(options) + defer closeRespBody(resp) + if err == nil { + return wsConn, nil + } + + if !isAccessResponse(resp) { + return nil, err + } + + // Access Token is invalid for some reason. Go through regen flow + originReq, err := http.NewRequest(http.MethodGet, options.OriginURL, nil) + if err != nil { + return nil, err + } + if err := token.RemoveTokenIfExists(originReq.URL); err != nil { + return nil, err + } + wsConn, resp, err = createAccessWebSocketStream(options) + defer closeRespBody(resp) + if err != nil { + return nil, err + } + + return wsConn, nil +} + +// createAccessWebSocketStream builds an Access request and makes a connection +func createAccessWebSocketStream(options *StartOptions) (*websocket.Conn, *http.Response, error) { + req, err := buildAccessRequest(options) + if err != nil { + return nil, nil, err + } + return cloudflaredWebsocket.ClientConnect(req, nil) } // buildAccessRequest builds an HTTP request with the Access token set -func buildAccessRequest(originURL string) (*http.Request, error) { - req, err := http.NewRequest(http.MethodGet, originURL, nil) +func buildAccessRequest(options *StartOptions) (*http.Request, error) { + req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil) if err != nil { return nil, err } @@ -144,11 +200,17 @@ func buildAccessRequest(originURL string) (*http.Request, error) { // We need to create a new request as FetchToken will modify req (boo mutable) // as it has to follow redirect on the API and such, so here we init a new one - originRequest, err := http.NewRequest(http.MethodGet, originURL, nil) + originRequest, err := http.NewRequest(http.MethodGet, options.OriginURL, nil) if err != nil { return nil, err } originRequest.Header.Set("cf-access-token", token) + for k, v := range options.Headers { + if len(v) >= 1 { + originRequest.Header.Set(k, v[0]) + } + } + return originRequest, nil } diff --git a/carrier/carrier_test.go b/carrier/carrier_test.go index e4c3d520..5360f092 100644 --- a/carrier/carrier_test.go +++ b/carrier/carrier_test.go @@ -91,6 +91,31 @@ func TestStartServer(t *testing.T) { assert.Equal(t, string(readBuffer), message) } +func TestIsAccessResponse(t *testing.T) { + validLocationHeader := http.Header{} + validLocationHeader.Add("location", "https://test.cloudflareaccess.com/cdn-cgi/access/login/blahblah") + invalidLocationHeader := http.Header{} + invalidLocationHeader.Add("location", "https://google.com") + testCases := []struct { + Description string + In *http.Response + ExpectedOut bool + }{ + {"nil response", nil, false}, + {"redirect with no location", &http.Response{StatusCode: http.StatusPermanentRedirect}, false}, + {"200 ok", &http.Response{StatusCode: http.StatusOK}, false}, + {"redirect with location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: validLocationHeader}, true}, + {"redirect with invalid location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: invalidLocationHeader}, false}, + } + + for i, tc := range testCases { + if isAccessResponse(tc.In) != tc.ExpectedOut { + t.Fatalf("Failed case %d -- %s", i, tc.Description) + } + } + +} + func newTestWebSocketServer() *httptest.Server { upgrader := ws.Upgrader{ ReadBufferSize: 1024, diff --git a/cmd/cloudflared/token/token.go b/cmd/cloudflared/token/token.go index a7655e4f..8a17076f 100644 --- a/cmd/cloudflared/token/token.go +++ b/cmd/cloudflared/token/token.go @@ -1,15 +1,18 @@ package token import ( + "context" + "fmt" "io/ioutil" "net/url" - "time" + "os" + "github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/path" "github.com/cloudflare/cloudflared/cmd/cloudflared/transfer" "github.com/cloudflare/cloudflared/log" + "github.com/cloudflare/cloudflared/origin" "github.com/coreos/go-oidc/jose" - "github.com/coreos/go-oidc/oidc" ) const ( @@ -18,6 +21,58 @@ const ( var logger = log.CreateLogger() +type lock struct { + lockFilePath string + backoff *origin.BackoffHandler +} + +func errDeleteTokenFailed(lockFilePath string) error { + return fmt.Errorf("failed to acquire a new Access token. Please try to delete %s", lockFilePath) +} + +// newLock will get a new file lock +func newLock(path string) *lock { + lockPath := path + ".lock" + return &lock{ + lockFilePath: lockPath, + backoff: &origin.BackoffHandler{MaxRetries: 7}, + } +} + +func (l *lock) Acquire() error { + // Check for a path.lock file + // if the lock file exists; start polling + // if not, create the lock file and go through the normal flow. + // See AUTH-1736 for the reason why we do all this + for isTokenLocked(l.lockFilePath) { + if l.backoff.Backoff(context.Background()) { + continue + } else { + return errDeleteTokenFailed(l.lockFilePath) + } + } + + // Create a lock file so other processes won't also try to get the token at + // the same time + if err := ioutil.WriteFile(l.lockFilePath, []byte{}, 0600); err != nil { + return err + } + return nil +} + +func (l *lock) Release() error { + if err := os.Remove(l.lockFilePath); err != nil && !os.IsNotExist(err) { + return errDeleteTokenFailed(l.lockFilePath) + } + return nil +} + +// isTokenLocked checks to see if there is another process attempting to get the token already +func isTokenLocked(lockFilePath string) bool { + exists, err := config.FileExists(lockFilePath) + return exists && err == nil +} + // FetchToken will either load a stored token or generate a new one func FetchToken(appURL *url.URL) (string, error) { if token, err := GetTokenIfExists(appURL); token != "" && err == nil { @@ -29,6 +84,18 @@ func FetchToken(appURL *url.URL) (string, error) { return "", err } + lock := newLock(path) + err = lock.Acquire() + if err != nil { + return "", err + } + defer lock.Release() + + // check to see if another process has gotten a token while we waited for the lock + if token, err := GetTokenIfExists(appURL); token != "" && err == nil { + return token, nil + } + // this weird parameter is the resource name (token) and the key/value // we want to send to the transfer service. the key is token and the value // is blank (basically just the id generated in the transfer service) @@ -55,14 +122,18 @@ func GetTokenIfExists(url *url.URL) (string, error) { return "", err } - claims, err := token.Claims() - if err != nil { - return "", err - } - ident, err := oidc.IdentityFromClaims(claims) - // AUTH-1404, reauth if the token is about to expire within 15 minutes - if err == nil && ident.ExpiresAt.After(time.Now().Add(time.Minute*15)) { - return token.Encode(), nil - } - return "", err + return token.Encode(), nil +} + +// RemoveTokenIfExists removes the a token from local storage if it exists +func RemoveTokenIfExists(url *url.URL) error { + path, err := path.GenerateFilePathFromURL(url, keyName) + if err != nil { + return err + } + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return err + } + + return nil }