diff --git a/ingress/middleware/jwtvalidator.go b/ingress/middleware/jwtvalidator.go index 667bc1ce..8ee9b789 100644 --- a/ingress/middleware/jwtvalidator.go +++ b/ingress/middleware/jwtvalidator.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/coreos/go-oidc/v3/oidc" - "github.com/pkg/errors" ) const ( @@ -14,7 +13,6 @@ const ( ) var ( - ErrNoAccessToken = errors.New("no access token provided in request") cloudflareAccessCertsURL = "https://%s.cloudflareaccess.com" ) @@ -39,28 +37,43 @@ func NewJWTValidator(teamName string, certsURL string, audTags []string) *JWTVal verifier := oidc.NewVerifier(certsURL, keySet, config) return &JWTValidator{ IDTokenVerifier: verifier, + audTags: audTags, } } -func (v *JWTValidator) Handle(ctx context.Context, r *http.Request) error { +func (v *JWTValidator) Name() string { + return "AccessJWTValidator" +} + +func (v *JWTValidator) Handle(ctx context.Context, r *http.Request) (*HandleResult, error) { accessJWT := r.Header.Get(headerKeyAccessJWTAssertion) if accessJWT == "" { - return ErrNoAccessToken + // log the exact error message here. the message is specific to the handler implementation logic, we don't gain anything + // in passing it upstream. and each handler impl know what logging level to use for each. + return &HandleResult{ + ShouldFilterRequest: true, + StatusCode: http.StatusForbidden, + Reason: "no access token in request", + }, nil } token, err := v.IDTokenVerifier.Verify(ctx, accessJWT) if err != nil { - return fmt.Errorf("Invalid token: %w", err) + return nil, err } - // We want atleast one audTag to match + // We want at least one audTag to match for _, jwtAudTag := range token.Audience { for _, acceptedAudTag := range v.audTags { if acceptedAudTag == jwtAudTag { - return nil + return &HandleResult{ShouldFilterRequest: false}, nil } } } - return fmt.Errorf("Invalid token: %w", err) + return &HandleResult{ + ShouldFilterRequest: true, + StatusCode: http.StatusForbidden, + Reason: fmt.Sprintf("Invalid token in jwt: %v", token.Audience), + }, nil } diff --git a/ingress/middleware/jwtvalidator_test.go b/ingress/middleware/jwtvalidator_test.go deleted file mode 100644 index c12f6011..00000000 --- a/ingress/middleware/jwtvalidator_test.go +++ /dev/null @@ -1,6 +0,0 @@ -package middleware - -import "testing" - -func TestJWTValidatorHandle(t *testing.T) { -} diff --git a/ingress/middleware/middleware.go b/ingress/middleware/middleware.go index 7888dc31..c3b02612 100644 --- a/ingress/middleware/middleware.go +++ b/ingress/middleware/middleware.go @@ -5,6 +5,15 @@ import ( "net/http" ) -type Handler interface { - Handle(ctx context.Context, r *http.Request) error +type HandleResult struct { + // Tells that the request didn't passed the handler and should be filtered + ShouldFilterRequest bool + // The status code to return in case ShouldFilterRequest is true. + StatusCode int + Reason string +} + +type Handler interface { + Name() string + Handle(ctx context.Context, r *http.Request) (result *HandleResult, err error) } diff --git a/ingress/middleware/verifier.go b/ingress/middleware/verifier.go deleted file mode 100644 index 7888dc31..00000000 --- a/ingress/middleware/verifier.go +++ /dev/null @@ -1,10 +0,0 @@ -package middleware - -import ( - "context" - "net/http" -) - -type Handler interface { - Handle(ctx context.Context, r *http.Request) error -} diff --git a/proxy/proxy.go b/proxy/proxy.go index 6fe4efa4..c769e19e 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -60,6 +60,21 @@ func NewOriginProxy( return proxy } +func (p *Proxy) applyIngressMiddleware(rule *ingress.Rule, r *http.Request, w connection.ResponseWriter) (error, bool) { + for _, handler := range rule.Handlers { + result, err := handler.Handle(r.Context(), r) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("error while processing middleware handler %s", handler.Name())), false + } + + if result.ShouldFilterRequest { + w.WriteRespHeaders(result.StatusCode, nil) + return fmt.Errorf("request filtered by middleware handler (%s) due to: %s", handler.Name(), result.Reason), true + } + } + return nil, true +} + // ProxyHTTP further depends on ingress rules to establish a connection with the origin service. This may be // a simple roundtrip or a tcp/websocket dial depending on ingres rule setup. func (p *Proxy) ProxyHTTP( @@ -86,6 +101,13 @@ func (p *Proxy) ProxyHTTP( p.logRequest(req, logFields) ruleSpan.SetAttributes(attribute.Int("rule-num", ruleNum)) ruleSpan.End() + if err, applied := p.applyIngressMiddleware(rule, req, w); err != nil { + if applied { + p.log.Error().Msg(err.Error()) + return nil + } + return err + } switch originProxy := rule.Service.(type) { case ingress.HTTPOriginProxy: