package middleware

import (
	"context"
	"fmt"
	"net/http"

	"github.com/coreos/go-oidc/v3/oidc"
)

// headerKeyAccessJWTAssertion names the request header that carries the
// Cloudflare Access JWT checked by JWTValidator.Handle.
const headerKeyAccessJWTAssertion = "Cf-Access-Jwt-Assertion"

// cloudflareAccessCertsURL is a fmt format string taking a team name; it
// produces the default Access base URL that NewJWTValidator uses when no
// certs URL is configured. NOTE(review): it is a var rather than a const,
// presumably so it can be overridden (e.g. in tests) — confirm before changing.
var cloudflareAccessCertsURL = "https://%s.cloudflareaccess.com"

// JWTValidator is a request middleware that validates the Cloudflare Access JWT
// sent in the Cf-Access-Jwt-Assertion header. The token is first verified by the
// embedded OIDC verifier, and then at least one of its audience values must
// appear in audTags. It exposes Name and Handle (presumably to satisfy a
// middleware handler interface declared elsewhere in this package — confirm).
type JWTValidator struct {
	// IDTokenVerifier verifies the raw JWT against the remote signing keys and
	// the expected issuer; NewJWTValidator configures it to skip the client ID check.
	*oidc.IDTokenVerifier
	// audTags holds the accepted audience tags; a token passes Handle if any of
	// its audiences equals any entry here.
	audTags []string
}

// NewJWTValidator builds a JWTValidator for Cloudflare Access tokens.
//
// certsURL is the Access base URL. When it is empty, it defaults to
// https://<teamName>.cloudflareaccess.com. The same base URL serves two
// purposes:
//   - Signing keys are fetched from <certsURL>/cdn-cgi/access/certs.
//   - Tokens must carry certsURL as their issuer.
//
// Trailing slashes are stripped from certsURL first. Otherwise a value such as
// "https://team.cloudflareaccess.com/" would produce a "//cdn-cgi" key path and
// an issuer that never matches. audTags lists the audience tags that Handle
// accepts.
func NewJWTValidator(teamName string, certsURL string, audTags []string) *JWTValidator {
	if certsURL == "" {
		certsURL = fmt.Sprintf(cloudflareAccessCertsURL, teamName)
	}
	// The oidc verifier compares the token issuer to this string verbatim, so
	// normalise it before using it as both the issuer and the JWKS base.
	for len(certsURL) > 0 && certsURL[len(certsURL)-1] == '/' {
		certsURL = certsURL[:len(certsURL)-1]
	}
	certsEndpoint := fmt.Sprintf("%s/cdn-cgi/access/certs", certsURL)

	config := &oidc.Config{
		// Audience is checked against audTags in Handle instead of a single
		// OIDC client ID.
		SkipClientIDCheck: true,
	}

	// The key set is built once here and reused across requests, so it is
	// given a background context rather than any request-scoped one.
	keySet := oidc.NewRemoteKeySet(context.Background(), certsEndpoint)
	verifier := oidc.NewVerifier(certsURL, keySet, config)
	return &JWTValidator{
		IDTokenVerifier: verifier,
		audTags:         audTags,
	}
}

// Name reports the fixed identifier of this middleware handler.
func (*JWTValidator) Name() string {
	const handlerName = "AccessJWTValidator"
	return handlerName
}

// Handle validates the Cloudflare Access JWT on r.
//
// Outcomes:
//   - Header missing: returns a filtering result with http.StatusForbidden.
//   - Token rejected by the OIDC verifier: returns a nil result and the
//     verification error, wrapped with %w so errors.Is/errors.As still see the
//     underlying go-oidc error.
//   - Token verifies but none of its audiences is in v.audTags: returns a
//     filtering result with http.StatusForbidden.
//   - Otherwise: returns a non-filtering result.
func (v *JWTValidator) Handle(ctx context.Context, r *http.Request) (*HandleResult, error) {
	accessJWT := r.Header.Get(headerKeyAccessJWTAssertion)
	if accessJWT == "" {
		// Reported via the result (with a 403 and a reason) rather than as an
		// error, so the caller can reject the request directly.
		return &HandleResult{
			ShouldFilterRequest: true,
			StatusCode:          http.StatusForbidden,
			Reason:              "no access token in request",
		}, nil
	}

	token, err := v.IDTokenVerifier.Verify(ctx, accessJWT)
	if err != nil {
		return nil, fmt.Errorf("verifying access JWT: %w", err)
	}

	// Accept the request if at least one audience on the token is an accepted
	// tag. With an empty audTags every token is rejected.
	for _, jwtAudTag := range token.Audience {
		for _, acceptedAudTag := range v.audTags {
			if acceptedAudTag == jwtAudTag {
				return &HandleResult{ShouldFilterRequest: false}, nil
			}
		}
	}

	// The token itself verified; only its audience failed to match.
	return &HandleResult{
		ShouldFilterRequest: true,
		StatusCode:          http.StatusForbidden,
		Reason:              fmt.Sprintf("Invalid audience in jwt: %v", token.Audience),
	}, nil
}