127 lines
3.3 KiB
Go
127 lines
3.3 KiB
Go
|
package middleware
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto"
|
||
|
"crypto/ecdsa"
|
||
|
"crypto/elliptic"
|
||
|
"crypto/rand"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"net/http/httptest"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||
|
"github.com/go-jose/go-jose/v4"
|
||
|
"github.com/go-jose/go-jose/v4/jwt"
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
issuer = fmt.Sprintf(cloudflareAccessCertsURL, "testteam")
|
||
|
)
|
||
|
|
||
|
type accessTokenClaims struct {
|
||
|
Email string `json:"email"`
|
||
|
Type string `json:"type"`
|
||
|
jwt.Claims
|
||
|
}
|
||
|
|
||
|
func TestJWTValidator(t *testing.T) {
|
||
|
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||
|
|
||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||
|
require.NoError(t, err)
|
||
|
issued := time.Now()
|
||
|
claims := accessTokenClaims{
|
||
|
Email: "test@example.com",
|
||
|
Type: "app",
|
||
|
Claims: jwt.Claims{
|
||
|
Issuer: issuer,
|
||
|
Subject: "ee239b7a-e3e6-4173-972a-8fbe9d99c04f",
|
||
|
Audience: []string{""},
|
||
|
Expiry: jwt.NewNumericDate(issued.Add(time.Hour)),
|
||
|
IssuedAt: jwt.NewNumericDate(issued),
|
||
|
},
|
||
|
}
|
||
|
token := signToken(t, claims, key)
|
||
|
req.Header.Add(headerKeyAccessJWTAssertion, token)
|
||
|
|
||
|
keySet := oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.Public()}}
|
||
|
config := &oidc.Config{
|
||
|
SkipClientIDCheck: true,
|
||
|
SupportedSigningAlgs: []string{string(jose.ES256)},
|
||
|
}
|
||
|
verifier := oidc.NewVerifier(issuer, &keySet, config)
|
||
|
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
audTags []string
|
||
|
aud jwt.Audience
|
||
|
error bool
|
||
|
}{
|
||
|
{
|
||
|
name: "valid",
|
||
|
audTags: []string{
|
||
|
"0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
|
||
|
"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
|
||
|
},
|
||
|
aud: jwt.Audience{"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba"},
|
||
|
error: false,
|
||
|
},
|
||
|
{
|
||
|
name: "invalid no match",
|
||
|
audTags: []string{
|
||
|
"0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
|
||
|
"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
|
||
|
},
|
||
|
aud: jwt.Audience{"09dc377143841843ecca28b196bdb1ec1675af38c8b7b60c7def5876c8877157"},
|
||
|
error: true,
|
||
|
},
|
||
|
{
|
||
|
name: "invalid empty check",
|
||
|
audTags: []string{},
|
||
|
aud: jwt.Audience{"09dc377143841843ecca28b196bdb1ec1675af38c8b7b60c7def5876c8877157"},
|
||
|
error: true,
|
||
|
},
|
||
|
{
|
||
|
name: "invalid absent aud",
|
||
|
audTags: []string{
|
||
|
"0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
|
||
|
"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
|
||
|
},
|
||
|
aud: jwt.Audience{""},
|
||
|
error: true,
|
||
|
},
|
||
|
}
|
||
|
for _, test := range tests {
|
||
|
t.Run(test.name, func(t *testing.T) {
|
||
|
validator := JWTValidator{
|
||
|
IDTokenVerifier: verifier,
|
||
|
audTags: test.audTags,
|
||
|
}
|
||
|
claims.Audience = test.aud
|
||
|
token := signToken(t, claims, key)
|
||
|
req.Header.Set(headerKeyAccessJWTAssertion, token)
|
||
|
|
||
|
result, err := validator.Handle(context.Background(), req)
|
||
|
assert.NoError(t, err)
|
||
|
assert.Equal(t, test.error, result.ShouldFilterRequest)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func signToken(t *testing.T, token accessTokenClaims, key *ecdsa.PrivateKey) string {
|
||
|
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, &jose.SignerOptions{})
|
||
|
require.NoError(t, err)
|
||
|
payload, err := json.Marshal(token)
|
||
|
require.NoError(t, err)
|
||
|
jws, err := signer.Sign(payload)
|
||
|
require.NoError(t, err)
|
||
|
jwt, err := jws.CompactSerialize()
|
||
|
require.NoError(t, err)
|
||
|
return jwt
|
||
|
}
|