diff --git a/carrier/carrier.go b/carrier/carrier.go index 71541ab4..e5741169 100644 --- a/carrier/carrier.go +++ b/carrier/carrier.go @@ -4,7 +4,6 @@ package carrier import ( - "errors" "io" "net" "net/http" @@ -12,7 +11,8 @@ import ( "strings" "github.com/cloudflare/cloudflared/cmd/cloudflared/token" - "github.com/cloudflare/cloudflared/websocket" + cloudflaredWebsocket "github.com/cloudflare/cloudflared/websocket" + "github.com/gorilla/websocket" "github.com/sirupsen/logrus" ) @@ -37,6 +37,13 @@ func (c *StdinoutStream) Write(p []byte) (int, error) { return os.Stdout.Write(p) } +// Helper to allow defering the response close with a check that the resp is not nil +func closeRespBody(resp *http.Response) { + if resp != nil { + resp.Body.Close() + } +} + // StartClient will copy the data from stdin/stdout over a WebSocket connection // to the edge (originURL) func StartClient(logger *logrus.Logger, stream io.ReadWriter, options *StartOptions) error { @@ -90,7 +97,7 @@ func serveStream(logger *logrus.Logger, conn io.ReadWriter, options *StartOption } defer wsConn.Close() - websocket.Stream(wsConn, conn) + cloudflaredWebsocket.Stream(wsConn, conn) return nil } @@ -98,28 +105,17 @@ func serveStream(logger *logrus.Logger, conn io.ReadWriter, options *StartOption // createWebsocketStream will create a WebSocket connection to stream data over // It also handles redirects from Access and will present that flow if // the token is not present on the request -func createWebsocketStream(options *StartOptions) (*websocket.Conn, error) { +func createWebsocketStream(options *StartOptions) (*cloudflaredWebsocket.Conn, error) { req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil) if err != nil { return nil, err } req.Header = options.Headers - wsConn, resp, err := websocket.ClientConnect(req, nil) - if err != nil && resp != nil && resp.StatusCode > 300 { - location, err := resp.Location() - if err != nil { - return nil, err - } - if !strings.Contains(location.String(), "cdn-cgi/access/login") { - return nil, errors.New("not an Access redirect") - } - req, err := buildAccessRequest(options.OriginURL) - if err != nil { - return nil, err - } - - wsConn, _, err = websocket.ClientConnect(req, nil) + wsConn, resp, err := cloudflaredWebsocket.ClientConnect(req, nil) + defer closeRespBody(resp) + if err != nil && isAccessResponse(resp) { + wsConn, err = createAccessAuthenticatedStream(options) if err != nil { return nil, err } @@ -127,12 +123,72 @@ func createWebsocketStream(options *StartOptions) (*websocket.Conn, error) { return nil, err } - return &websocket.Conn{Conn: wsConn}, nil + return &cloudflaredWebsocket.Conn{Conn: wsConn}, nil +} + +// isAccessResponse checks the http Response to see if the url location +// contains the Access structure. +func isAccessResponse(resp *http.Response) bool { + if resp == nil || resp.StatusCode <= 300 { + return false + } + + location, err := resp.Location() + if err != nil || location == nil { + return false + } + if strings.HasPrefix(location.Path, "/cdn-cgi/access/login") { + return true + } + + return false +} + +// createAccessAuthenticatedStream will try load a token from storage and make +// a connection with the token set on the request. If it still get redirect, +// this probably means the token in storage is invalid (expired/revoked). If that +// happens it deletes the token and runs the connection again, so the user can +// login again and generate a new one. +func createAccessAuthenticatedStream(options *StartOptions) (*websocket.Conn, error) { + wsConn, resp, err := createAccessWebSocketStream(options) + defer closeRespBody(resp) + if err == nil { + return wsConn, nil + } + + if !isAccessResponse(resp) { + return nil, err + } + + // Access Token is invalid for some reason. Go through regen flow + originReq, err := http.NewRequest(http.MethodGet, options.OriginURL, nil) + if err != nil { + return nil, err + } + if err := token.RemoveTokenIfExists(originReq.URL); err != nil { + return nil, err + } + wsConn, resp, err = createAccessWebSocketStream(options) + defer closeRespBody(resp) + if err != nil { + return nil, err + } + + return wsConn, nil +} + +// createAccessWebSocketStream builds an Access request and makes a connection +func createAccessWebSocketStream(options *StartOptions) (*websocket.Conn, *http.Response, error) { + req, err := buildAccessRequest(options) + if err != nil { + return nil, nil, err + } + return cloudflaredWebsocket.ClientConnect(req, nil) } // buildAccessRequest builds an HTTP request with the Access token set -func buildAccessRequest(originURL string) (*http.Request, error) { - req, err := http.NewRequest(http.MethodGet, originURL, nil) +func buildAccessRequest(options *StartOptions) (*http.Request, error) { + req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil) if err != nil { return nil, err } @@ -144,11 +200,17 @@ func buildAccessRequest(originURL string) (*http.Request, error) { // We need to create a new request as FetchToken will modify req (boo mutable) // as it has to follow redirect on the API and such, so here we init a new one - originRequest, err := http.NewRequest(http.MethodGet, originURL, nil) + originRequest, err := http.NewRequest(http.MethodGet, options.OriginURL, nil) if err != nil { return nil, err } originRequest.Header.Set("cf-access-token", token) + for k, v := range options.Headers { + if len(v) >= 1 { + originRequest.Header.Set(k, v[0]) + } + } + return originRequest, nil } diff --git a/carrier/carrier_test.go b/carrier/carrier_test.go index e4c3d520..5360f092 100644 --- a/carrier/carrier_test.go +++ b/carrier/carrier_test.go @@ -91,6 +91,31 @@ func TestStartServer(t *testing.T) { assert.Equal(t, string(readBuffer), message) } +func TestIsAccessResponse(t *testing.T) { + validLocationHeader := http.Header{} + validLocationHeader.Add("location", "https://test.cloudflareaccess.com/cdn-cgi/access/login/blahblah") + invalidLocationHeader := http.Header{} + invalidLocationHeader.Add("location", "https://google.com") + testCases := []struct { + Description string + In *http.Response + ExpectedOut bool + }{ + {"nil response", nil, false}, + {"redirect with no location", &http.Response{StatusCode: http.StatusPermanentRedirect}, false}, + {"200 ok", &http.Response{StatusCode: http.StatusOK}, false}, + {"redirect with location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: validLocationHeader}, true}, + {"redirect with invalid location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: invalidLocationHeader}, false}, + } + + for i, tc := range testCases { + if isAccessResponse(tc.In) != tc.ExpectedOut { + t.Fatalf("Failed case %d -- %s", i, tc.Description) + } + } + +} + func newTestWebSocketServer() *httptest.Server { upgrader := ws.Upgrader{ ReadBufferSize: 1024, diff --git a/cmd/cloudflared/token/token.go b/cmd/cloudflared/token/token.go index a7655e4f..8a17076f 100644 --- a/cmd/cloudflared/token/token.go +++ b/cmd/cloudflared/token/token.go @@ -1,15 +1,18 @@ package token import ( + "context" + "fmt" "io/ioutil" "net/url" - "time" + "os" + "github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/cmd/cloudflared/path" "github.com/cloudflare/cloudflared/cmd/cloudflared/transfer" "github.com/cloudflare/cloudflared/log" + "github.com/cloudflare/cloudflared/origin" "github.com/coreos/go-oidc/jose" - "github.com/coreos/go-oidc/oidc" ) const ( @@ -18,6 +21,58 @@ const ( var logger = log.CreateLogger() +type lock struct { + lockFilePath string + backoff *origin.BackoffHandler +} + +func errDeleteTokenFailed(lockFilePath string) error { + return fmt.Errorf("failed to acquire a new Access token. Please try to delete %s", lockFilePath) +} + +// newLock will get a new file lock +func newLock(path string) *lock { + lockPath := path + ".lock" + return &lock{ + lockFilePath: lockPath, + backoff: &origin.BackoffHandler{MaxRetries: 7}, + } +} + +func (l *lock) Acquire() error { + // Check for a path.lock file + // if the lock file exists; start polling + // if not, create the lock file and go through the normal flow. + // See AUTH-1736 for the reason why we do all this + for isTokenLocked(l.lockFilePath) { + if l.backoff.Backoff(context.Background()) { + continue + } else { + return errDeleteTokenFailed(l.lockFilePath) + } + } + + // Create a lock file so other processes won't also try to get the token at + // the same time + if err := ioutil.WriteFile(l.lockFilePath, []byte{}, 0600); err != nil { + return err + } + return nil +} + +func (l *lock) Release() error { + if err := os.Remove(l.lockFilePath); err != nil && !os.IsNotExist(err) { + return errDeleteTokenFailed(l.lockFilePath) + } + return nil +} + +// isTokenLocked checks to see if there is another process attempting to get the token already +func isTokenLocked(lockFilePath string) bool { + exists, err := config.FileExists(lockFilePath) + return exists && err == nil +} + // FetchToken will either load a stored token or generate a new one func FetchToken(appURL *url.URL) (string, error) { if token, err := GetTokenIfExists(appURL); token != "" && err == nil { @@ -29,6 +84,18 @@ func FetchToken(appURL *url.URL) (string, error) { return "", err } + lock := newLock(path) + err = lock.Acquire() + if err != nil { + return "", err + } + defer lock.Release() + + // check to see if another process has gotten a token while we waited for the lock + if token, err := GetTokenIfExists(appURL); token != "" && err == nil { + return token, nil + } + // this weird parameter is the resource name (token) and the key/value // we want to send to the transfer service. the key is token and the value // is blank (basically just the id generated in the transfer service) @@ -55,14 +122,18 @@ func GetTokenIfExists(url *url.URL) (string, error) { return "", err } - claims, err := token.Claims() - if err != nil { - return "", err - } - ident, err := oidc.IdentityFromClaims(claims) - // AUTH-1404, reauth if the token is about to expire within 15 minutes - if err == nil && ident.ExpiresAt.After(time.Now().Add(time.Minute*15)) { - return token.Encode(), nil - } - return "", err + return token.Encode(), nil +} + +// RemoveTokenIfExists removes the a token from local storage if it exists +func RemoveTokenIfExists(url *url.URL) error { + path, err := path.GenerateFilePathFromURL(url, keyName) + if err != nil { + return err + } + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return err + } + + return nil }