package middleware

import (
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"encoding/json"
	"fmt"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/go-jose/go-jose/v4"
	"github.com/go-jose/go-jose/v4/jwt"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

var (
	issuer = fmt.Sprintf(cloudflareAccessCertsURL, "testteam")
)

type accessTokenClaims struct {
	Email string `json:"email"`
	Type  string `json:"type"`
	jwt.Claims
}

func TestJWTValidator(t *testing.T) {
	req := httptest.NewRequest("GET", "http://example.com", nil)

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)
	issued := time.Now()
	claims := accessTokenClaims{
		Email: "test@example.com",
		Type:  "app",
		Claims: jwt.Claims{
			Issuer:   issuer,
			Subject:  "ee239b7a-e3e6-4173-972a-8fbe9d99c04f",
			Audience: []string{""},
			Expiry:   jwt.NewNumericDate(issued.Add(time.Hour)),
			IssuedAt: jwt.NewNumericDate(issued),
		},
	}
	token := signToken(t, claims, key)
	req.Header.Add(headerKeyAccessJWTAssertion, token)

	keySet := oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.Public()}}
	config := &oidc.Config{
		SkipClientIDCheck:    true,
		SupportedSigningAlgs: []string{string(jose.ES256)},
	}
	verifier := oidc.NewVerifier(issuer, &keySet, config)

	tests := []struct {
		name    string
		audTags []string
		aud     jwt.Audience
		error   bool
	}{
		{
			name: "valid",
			audTags: []string{
				"0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
				"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
			},
			aud:   jwt.Audience{"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba"},
			error: false,
		},
		{
			name: "invalid no match",
			audTags: []string{
				"0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
				"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
			},
			aud:   jwt.Audience{"09dc377143841843ecca28b196bdb1ec1675af38c8b7b60c7def5876c8877157"},
			error: true,
		},
		{
			name:    "invalid empty check",
			audTags: []string{},
			aud:     jwt.Audience{"09dc377143841843ecca28b196bdb1ec1675af38c8b7b60c7def5876c8877157"},
			error:   true,
		},
		{
			name: "invalid absent aud",
			audTags: []string{
				"0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
				"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
			},
			aud:   jwt.Audience{""},
			error: true,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			validator := JWTValidator{
				IDTokenVerifier: verifier,
				audTags:         test.audTags,
			}
			claims.Audience = test.aud
			token := signToken(t, claims, key)
			req.Header.Set(headerKeyAccessJWTAssertion, token)

			result, err := validator.Handle(context.Background(), req)
			assert.NoError(t, err)
			assert.Equal(t, test.error, result.ShouldFilterRequest)
		})
	}
}

func signToken(t *testing.T, token accessTokenClaims, key *ecdsa.PrivateKey) string {
	signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, &jose.SignerOptions{})
	require.NoError(t, err)
	payload, err := json.Marshal(token)
	require.NoError(t, err)
	jws, err := signer.Sign(payload)
	require.NoError(t, err)
	jwt, err := jws.CompactSerialize()
	require.NoError(t, err)
	return jwt
}