diff --git a/carrier/carrier_test.go b/carrier/carrier_test.go index bedaad07..aa399080 100644 --- a/carrier/carrier_test.go +++ b/carrier/carrier_test.go @@ -158,7 +158,6 @@ func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request { } func TestBastionDestination(t *testing.T) { - tests := []struct { name string header http.Header diff --git a/ingress/ingress.go b/ingress/ingress.go index cc034351..6bdfcb79 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -54,7 +54,7 @@ func (ing Ingress) FindMatchingRule(hostname, path string, cfJumpDestinationHead for i, rule := range ing.Rules { // If bastion mode is turned on and request is made as bastion, attempt // to match a rule where jump destination header matches the hostname - if rule.Config.BastionMode && len(cfJumpDestinationHeader) > 0 { + if matchBastionDest(rule, cfJumpDestinationHeader) { jumpDestinationUri, err := url.Parse(cfJumpDestinationHeader) if err == nil { derivedHostName = jumpDestinationUri.Hostname() @@ -69,6 +69,10 @@ func (ing Ingress) FindMatchingRule(hostname, path string, cfJumpDestinationHead return &ing.Rules[i], i } +func matchBastionDest(rule Rule, cfJumpDestinationHeader string) bool { + return rule.Config.BastionMode && len(cfJumpDestinationHeader) > 0 && rule.Service != nil && rule.Service.String() != config.BastionFlag +} + func matchHost(ruleHost, reqHost string) bool { if ruleHost == reqHost { return true diff --git a/proxy/proxy.go b/proxy/proxy.go index 8edd3a9a..ac13b2fb 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -90,61 +90,62 @@ func (p *Proxy) ProxyHTTP( rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path, req.Header.Get(carrier.CFJumpDestinationHeader)) ruleSpan.SetAttributes(attribute.Int("rule-num", ruleNum)) ruleSpan.End() + logger := newHTTPLogger(p.log, tr.ConnIndex, req, ruleNum, rule.Service.String()) logHTTPRequest(&logger, req) + if err, applied := p.applyIngressMiddleware(rule, req, w); err != nil { if applied { logRequestError(&logger, err) - return nil } return err } - // Check if config is for Bastion Mode and service is a stream based origin proxy, if so stream service in bastion mode - if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); ok || rule.Config.BastionMode { - if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); !ok && rule.Config.BastionMode { - return fmt.Errorf("Unsupported service to stream to in bastion mode: %s", rule.Service) - } - - dest, err := getDestFromRule(rule, req) - if err != nil { - return err - } - - flusher, ok := w.(http.Flusher) - if !ok { - return fmt.Errorf("response writer is not a flusher") - } - rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req) - logger := logger.With().Str(logFieldDestAddr, dest).Logger() - // We know that Bastion mode is supported by StreamBasedOriginProxy, hence use the same - if err := p.proxyStream(tr.ToTracedContext(), rws, dest, rule.Service.(ingress.StreamBasedOriginProxy), &logger); err != nil { - logRequestError(&logger, err) - return err - } - return nil + if _, isStreamBased := rule.Service.(ingress.StreamBasedOriginProxy); isStreamBased || rule.Config.BastionMode { + return p.handleStreamBasedService(rule, req, w, tr, &logger) } + return p.handleHTTPBasedService(rule, req, w, tr, isWebsocket, &logger) +} + +func (p *Proxy) handleStreamBasedService(rule *ingress.Rule, req *http.Request, w connection.ResponseWriter, tr *tracing.TracedHTTPRequest, logger *zerolog.Logger) error { + // If in bastion mode, we need to resolve the destination from the request, so service like http_status:404 + // won't work since it doesn't have EstablishConnection method to resolve the destination + if _, ok := rule.Service.(ingress.StreamBasedOriginProxy); !ok && rule.Config.BastionMode { + return fmt.Errorf("Unsupported service to stream to in bastion mode: %s", rule.Service) + } + + dest, err := getDestFromRule(rule, req) + if err != nil { + return err + } + + flusher, ok := w.(http.Flusher) + if !ok { + return fmt.Errorf("response writer is not a flusher") + } + + rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req) + if err := p.proxyStream(tr.ToTracedContext(), rws, dest, rule.Service.(ingress.StreamBasedOriginProxy), logger); err != nil { + logRequestError(logger, err) + return err + } + return nil +} + +func (p *Proxy) handleHTTPBasedService(rule *ingress.Rule, req *http.Request, w connection.ResponseWriter, tr *tracing.TracedHTTPRequest, isWebsocket bool, logger *zerolog.Logger) error { switch originProxy := rule.Service.(type) { case ingress.HTTPOriginProxy: - if err := p.proxyHTTPRequest( - w, - tr, - originProxy, - isWebsocket, - rule.Config.DisableChunkedEncoding, - &logger, - ); err != nil { - logRequestError(&logger, err) + if err := p.proxyHTTPRequest(w, tr, originProxy, isWebsocket, rule.Config.DisableChunkedEncoding, logger); err != nil { + logRequestError(logger, err) return err } - return nil case ingress.HTTPLocalProxy: p.proxyLocalRequest(originProxy, w, req, isWebsocket) - return nil default: return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy) } + return nil } // ProxyTCP proxies to a TCP connection between the origin service and cloudflared.