2022-09-21 14:17:44 +00:00
|
|
|
package middleware
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"fmt"
|
|
|
|
"net/http"
|
|
|
|
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
|
|
)
|
|
|
|
|
|
|
|
// headerKeyAccessJWTAssertion is the request header from which Handle reads
// the Cloudflare Access JWT.
const headerKeyAccessJWTAssertion = "Cf-Access-Jwt-Assertion"

// cloudflareAccessCertsURL is a fmt format string: substituting the Access
// team name yields the default issuer URL, which NewJWTValidator also uses as
// the base of the signing-key endpoint.
var cloudflareAccessCertsURL = "https://%s.cloudflareaccess.com"
|
|
|
|
|
|
|
|
// JWTValidator is an implementation of Verifier that validates access based JWT tokens.
|
|
|
|
type JWTValidator struct {
|
|
|
|
*oidc.IDTokenVerifier
|
|
|
|
audTags []string
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewJWTValidator(teamName string, certsURL string, audTags []string) *JWTValidator {
|
|
|
|
if certsURL == "" {
|
|
|
|
certsURL = fmt.Sprintf(cloudflareAccessCertsURL, teamName)
|
|
|
|
}
|
|
|
|
certsEndpoint := fmt.Sprintf("%s/cdn-cgi/access/certs", certsURL)
|
|
|
|
|
|
|
|
config := &oidc.Config{
|
|
|
|
SkipClientIDCheck: true,
|
|
|
|
}
|
|
|
|
|
|
|
|
ctx := context.Background()
|
|
|
|
keySet := oidc.NewRemoteKeySet(ctx, certsEndpoint)
|
|
|
|
verifier := oidc.NewVerifier(certsURL, keySet, config)
|
|
|
|
return &JWTValidator{
|
|
|
|
IDTokenVerifier: verifier,
|
2022-09-22 14:11:59 +00:00
|
|
|
audTags: audTags,
|
2022-09-21 14:17:44 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-09-22 14:11:59 +00:00
|
|
|
func (v *JWTValidator) Name() string {
|
|
|
|
return "AccessJWTValidator"
|
|
|
|
}
|
|
|
|
|
|
|
|
func (v *JWTValidator) Handle(ctx context.Context, r *http.Request) (*HandleResult, error) {
|
2022-09-22 13:04:47 +00:00
|
|
|
accessJWT := r.Header.Get(headerKeyAccessJWTAssertion)
|
2022-09-21 14:17:44 +00:00
|
|
|
if accessJWT == "" {
|
2022-09-22 14:11:59 +00:00
|
|
|
// log the exact error message here. the message is specific to the handler implementation logic, we don't gain anything
|
|
|
|
// in passing it upstream. and each handler impl know what logging level to use for each.
|
|
|
|
return &HandleResult{
|
|
|
|
ShouldFilterRequest: true,
|
|
|
|
StatusCode: http.StatusForbidden,
|
|
|
|
Reason: "no access token in request",
|
|
|
|
}, nil
|
2022-09-21 14:17:44 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
token, err := v.IDTokenVerifier.Verify(ctx, accessJWT)
|
|
|
|
if err != nil {
|
2022-09-22 14:11:59 +00:00
|
|
|
return nil, err
|
2022-09-21 14:17:44 +00:00
|
|
|
}
|
|
|
|
|
2022-09-22 14:11:59 +00:00
|
|
|
// We want at least one audTag to match
|
2022-09-21 14:17:44 +00:00
|
|
|
for _, jwtAudTag := range token.Audience {
|
|
|
|
for _, acceptedAudTag := range v.audTags {
|
|
|
|
if acceptedAudTag == jwtAudTag {
|
2022-09-22 14:11:59 +00:00
|
|
|
return &HandleResult{ShouldFilterRequest: false}, nil
|
2022-09-21 14:17:44 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-09-22 14:11:59 +00:00
|
|
|
return &HandleResult{
|
|
|
|
ShouldFilterRequest: true,
|
|
|
|
StatusCode: http.StatusForbidden,
|
|
|
|
Reason: fmt.Sprintf("Invalid token in jwt: %v", token.Audience),
|
|
|
|
}, nil
|
2022-09-21 14:17:44 +00:00
|
|
|
}
|