AUTH-2055: Verifies token at edge on access login
This commit is contained in:
parent
a412f629c2
commit
1d5cc45ac7
|
@ -114,7 +114,7 @@ func createWebsocketStream(options *StartOptions) (*cloudflaredWebsocket.Conn, e
|
||||||
|
|
||||||
wsConn, resp, err := cloudflaredWebsocket.ClientConnect(req, nil)
|
wsConn, resp, err := cloudflaredWebsocket.ClientConnect(req, nil)
|
||||||
defer closeRespBody(resp)
|
defer closeRespBody(resp)
|
||||||
if err != nil && isAccessResponse(resp) {
|
if err != nil && IsAccessResponse(resp) {
|
||||||
wsConn, err = createAccessAuthenticatedStream(options)
|
wsConn, err = createAccessAuthenticatedStream(options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -126,10 +126,10 @@ func createWebsocketStream(options *StartOptions) (*cloudflaredWebsocket.Conn, e
|
||||||
return &cloudflaredWebsocket.Conn{Conn: wsConn}, nil
|
return &cloudflaredWebsocket.Conn{Conn: wsConn}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// isAccessResponse checks the http Response to see if the url location
|
// IsAccessResponse checks the http Response to see if the url location
|
||||||
// contains the Access structure.
|
// contains the Access structure.
|
||||||
func isAccessResponse(resp *http.Response) bool {
|
func IsAccessResponse(resp *http.Response) bool {
|
||||||
if resp == nil || resp.StatusCode <= 300 {
|
if resp == nil || resp.StatusCode != http.StatusFound {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -156,7 +156,7 @@ func createAccessAuthenticatedStream(options *StartOptions) (*websocket.Conn, er
|
||||||
return wsConn, nil
|
return wsConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isAccessResponse(resp) {
|
if !IsAccessResponse(resp) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,7 +179,7 @@ func createAccessAuthenticatedStream(options *StartOptions) (*websocket.Conn, er
|
||||||
|
|
||||||
// createAccessWebSocketStream builds an Access request and makes a connection
|
// createAccessWebSocketStream builds an Access request and makes a connection
|
||||||
func createAccessWebSocketStream(options *StartOptions) (*websocket.Conn, *http.Response, error) {
|
func createAccessWebSocketStream(options *StartOptions) (*websocket.Conn, *http.Response, error) {
|
||||||
req, err := buildAccessRequest(options)
|
req, err := BuildAccessRequest(options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -187,7 +187,7 @@ func createAccessWebSocketStream(options *StartOptions) (*websocket.Conn, *http.
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildAccessRequest builds an HTTP request with the Access token set
|
// buildAccessRequest builds an HTTP request with the Access token set
|
||||||
func buildAccessRequest(options *StartOptions) (*http.Request, error) {
|
func BuildAccessRequest(options *StartOptions) (*http.Request, error) {
|
||||||
req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
|
req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -102,14 +102,14 @@ func TestIsAccessResponse(t *testing.T) {
|
||||||
ExpectedOut bool
|
ExpectedOut bool
|
||||||
}{
|
}{
|
||||||
{"nil response", nil, false},
|
{"nil response", nil, false},
|
||||||
{"redirect with no location", &http.Response{StatusCode: http.StatusPermanentRedirect}, false},
|
{"redirect with no location", &http.Response{StatusCode: http.StatusFound}, false},
|
||||||
{"200 ok", &http.Response{StatusCode: http.StatusOK}, false},
|
{"200 ok", &http.Response{StatusCode: http.StatusOK}, false},
|
||||||
{"redirect with location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: validLocationHeader}, true},
|
{"redirect with location", &http.Response{StatusCode: http.StatusFound, Header: validLocationHeader}, true},
|
||||||
{"redirect with invalid location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: invalidLocationHeader}, false},
|
{"redirect with invalid location", &http.Response{StatusCode: http.StatusFound, Header: invalidLocationHeader}, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
if isAccessResponse(tc.In) != tc.ExpectedOut {
|
if IsAccessResponse(tc.In) != tc.ExpectedOut {
|
||||||
t.Fatalf("Failed case %d -- %s", i, tc.Description)
|
t.Fatalf("Failed case %d -- %s", i, tc.Description)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,17 +1,20 @@
|
||||||
package access
|
package access
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/carrier"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/shell"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/shell"
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/token"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/token"
|
||||||
"github.com/cloudflare/cloudflared/sshgen"
|
"github.com/cloudflare/cloudflared/sshgen"
|
||||||
"github.com/cloudflare/cloudflared/validation"
|
"github.com/cloudflare/cloudflared/validation"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"golang.org/x/net/idna"
|
"golang.org/x/net/idna"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/log"
|
"github.com/cloudflare/cloudflared/log"
|
||||||
|
@ -188,9 +191,14 @@ func login(c *cli.Context) error {
|
||||||
logger.Errorf("Please provide the url of the Access application\n")
|
logger.Errorf("Please provide the url of the Access application\n")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
token, err := token.FetchToken(appURL)
|
if err := verifyTokenAtEdge(appURL, c); err != nil {
|
||||||
if err != nil {
|
logger.WithError(err).Error("Could not verify token")
|
||||||
logger.Errorf("Failed to fetch token: %s\n", err)
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := token.GetTokenIfExists(appURL)
|
||||||
|
if err != nil || token == "" {
|
||||||
|
fmt.Fprintln(os.Stderr, "Unable to find token for provided application.")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stdout, "Successfully fetched your token:\n\n%s\n\n", string(token))
|
fmt.Fprintf(os.Stdout, "Successfully fetched your token:\n\n%s\n\n", string(token))
|
||||||
|
@ -372,3 +380,59 @@ func isFileThere(candidate string) bool {
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// verifyTokenAtEdge checks for a token on disk, or generates a new one.
|
||||||
|
// Then makes a request to to the origin with the token to ensure it is valid.
|
||||||
|
// Returns nil if token is valid.
|
||||||
|
func verifyTokenAtEdge(appUrl *url.URL, c *cli.Context) error {
|
||||||
|
headers := buildRequestHeaders(c.StringSlice(sshHeaderFlag))
|
||||||
|
if c.IsSet(sshTokenIDFlag) {
|
||||||
|
headers.Add("CF-Access-Client-Id", c.String(sshTokenIDFlag))
|
||||||
|
}
|
||||||
|
if c.IsSet(sshTokenSecretFlag) {
|
||||||
|
headers.Add("CF-Access-Client-Secret", c.String(sshTokenSecretFlag))
|
||||||
|
}
|
||||||
|
options := &carrier.StartOptions{OriginURL: appUrl.String(), Headers: headers}
|
||||||
|
|
||||||
|
if valid, err := isTokenValid(options); err != nil {
|
||||||
|
return err
|
||||||
|
} else if valid {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := token.RemoveTokenIfExists(appUrl); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if valid, err := isTokenValid(options); err != nil {
|
||||||
|
return err
|
||||||
|
} else if !valid {
|
||||||
|
return errors.New("failed to verify token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isTokenValid makes a request to the origin and returns true if the response was not a 302.
|
||||||
|
func isTokenValid(options *carrier.StartOptions) (bool, error) {
|
||||||
|
req, err := carrier.BuildAccessRequest(options)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "Could not create access request")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do not follow redirects
|
||||||
|
client := &http.Client{
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
Timeout: time.Second * 5,
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// A redirect to login means the token was invalid.
|
||||||
|
return !carrier.IsAccessResponse(resp), nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue