feat: Add custom HTTP header management via CLI flags and config files
This commit is contained in:
parent
50104548cf
commit
e4930d09d2
|
|
@ -1056,6 +1056,19 @@ func configureProxyFlags(shouldHide bool) []cli.Flag {
|
||||||
Hidden: shouldHide,
|
Hidden: shouldHide,
|
||||||
Value: false,
|
Value: false,
|
||||||
}),
|
}),
|
||||||
|
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
|
||||||
|
Name: "header",
|
||||||
|
Aliases: []string{"H"},
|
||||||
|
Usage: "Add custom header when forwarding to origin (format: 'Name: Value')",
|
||||||
|
EnvVars: []string{"TUNNEL_HEADERS"},
|
||||||
|
Hidden: shouldHide,
|
||||||
|
}),
|
||||||
|
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
|
||||||
|
Name: "remove-header",
|
||||||
|
Usage: "Remove header when forwarding to origin",
|
||||||
|
EnvVars: []string{"TUNNEL_REMOVE_HEADERS"},
|
||||||
|
Hidden: shouldHide,
|
||||||
|
}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: cfdflags.ManagementHostname,
|
Name: cfdflags.ManagementHostname,
|
||||||
Usage: "Management hostname to signify incoming management requests",
|
Usage: "Management hostname to signify incoming management requests",
|
||||||
|
|
|
||||||
|
|
@ -231,6 +231,10 @@ type OriginRequestConfig struct {
|
||||||
Http2Origin *bool `yaml:"http2Origin" json:"http2Origin,omitempty"`
|
Http2Origin *bool `yaml:"http2Origin" json:"http2Origin,omitempty"`
|
||||||
// Access holds all access related configs
|
// Access holds all access related configs
|
||||||
Access *AccessConfig `yaml:"access" json:"access,omitempty"`
|
Access *AccessConfig `yaml:"access" json:"access,omitempty"`
|
||||||
|
// Custom headers to add/modify when forwarding to origin
|
||||||
|
Headers map[string]string `yaml:"headers" json:"headers,omitempty"`
|
||||||
|
// Headers to remove when forwarding to origin
|
||||||
|
RemoveHeaders []string `yaml:"removeHeaders" json:"removeHeaders,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccessConfig struct {
|
type AccessConfig struct {
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ package ingress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
|
@ -137,6 +138,16 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
|
||||||
var proxyPort uint
|
var proxyPort uint
|
||||||
var proxyType string
|
var proxyType string
|
||||||
var http2Origin bool
|
var http2Origin bool
|
||||||
|
|
||||||
|
var cliHeaders map[string]string
|
||||||
|
var cliRemoveHeaders []string
|
||||||
|
if c.IsSet("header") {
|
||||||
|
cliHeaders = parseHeadersFromCLI(c)
|
||||||
|
}
|
||||||
|
if c.IsSet("remove-header") {
|
||||||
|
cliRemoveHeaders = parseRemoveHeadersFromCLI(c)
|
||||||
|
}
|
||||||
|
|
||||||
if flag := ProxyConnectTimeoutFlag; c.IsSet(flag) {
|
if flag := ProxyConnectTimeoutFlag; c.IsSet(flag) {
|
||||||
connectTimeout = config.CustomDuration{Duration: c.Duration(flag)}
|
connectTimeout = config.CustomDuration{Duration: c.Duration(flag)}
|
||||||
}
|
}
|
||||||
|
|
@ -209,6 +220,8 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
|
||||||
ProxyPort: proxyPort,
|
ProxyPort: proxyPort,
|
||||||
ProxyType: proxyType,
|
ProxyType: proxyType,
|
||||||
Http2Origin: http2Origin,
|
Http2Origin: http2Origin,
|
||||||
|
Headers: cliHeaders,
|
||||||
|
RemoveHeaders: cliRemoveHeaders,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -333,6 +346,12 @@ type OriginRequestConfig struct {
|
||||||
|
|
||||||
// Access holds all access related configs
|
// Access holds all access related configs
|
||||||
Access config.AccessConfig `yaml:"access" json:"access,omitempty"`
|
Access config.AccessConfig `yaml:"access" json:"access,omitempty"`
|
||||||
|
|
||||||
|
// Custom headers to add/modify when forwarding to origin
|
||||||
|
Headers map[string]string `yaml:"headers" json:"headers,omitempty"`
|
||||||
|
|
||||||
|
// Headers to remove when forwarding to origin
|
||||||
|
RemoveHeaders []string `yaml:"removeHeaders" json:"removeHeaders,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (defaults *OriginRequestConfig) setConnectTimeout(overrides config.OriginRequestConfig) {
|
func (defaults *OriginRequestConfig) setConnectTimeout(overrides config.OriginRequestConfig) {
|
||||||
|
|
@ -456,6 +475,26 @@ func (defaults *OriginRequestConfig) setAccess(overrides config.OriginRequestCon
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (defaults *OriginRequestConfig) setHeaders(overrides config.OriginRequestConfig) {
|
||||||
|
if overrides.Headers != nil {
|
||||||
|
if defaults.Headers == nil {
|
||||||
|
defaults.Headers = make(map[string]string)
|
||||||
|
}
|
||||||
|
for k, v := range overrides.Headers {
|
||||||
|
defaults.Headers[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (defaults *OriginRequestConfig) setRemoveHeaders(overrides config.OriginRequestConfig) {
|
||||||
|
if overrides.RemoveHeaders != nil {
|
||||||
|
if defaults.RemoveHeaders == nil {
|
||||||
|
defaults.RemoveHeaders = make([]string, 0)
|
||||||
|
}
|
||||||
|
defaults.RemoveHeaders = append(defaults.RemoveHeaders, overrides.RemoveHeaders...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SetConfig gets config for the requests that cloudflared sends to origins.
|
// SetConfig gets config for the requests that cloudflared sends to origins.
|
||||||
// Each field has a setter method which sets a value for the field by trying to find:
|
// Each field has a setter method which sets a value for the field by trying to find:
|
||||||
// 1. The user config for this rule
|
// 1. The user config for this rule
|
||||||
|
|
@ -485,6 +524,8 @@ func setConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfi
|
||||||
cfg.setIPRules(overrides)
|
cfg.setIPRules(overrides)
|
||||||
cfg.setHttp2Origin(overrides)
|
cfg.setHttp2Origin(overrides)
|
||||||
cfg.setAccess(overrides)
|
cfg.setAccess(overrides)
|
||||||
|
cfg.setHeaders(overrides)
|
||||||
|
cfg.setRemoveHeaders(overrides)
|
||||||
|
|
||||||
return cfg
|
return cfg
|
||||||
}
|
}
|
||||||
|
|
@ -540,6 +581,8 @@ func ConvertToRawOriginConfig(c OriginRequestConfig) config.OriginRequestConfig
|
||||||
IPRules: convertToRawIPRules(c.IPRules),
|
IPRules: convertToRawIPRules(c.IPRules),
|
||||||
Http2Origin: defaultBoolToNil(c.Http2Origin),
|
Http2Origin: defaultBoolToNil(c.Http2Origin),
|
||||||
Access: access,
|
Access: access,
|
||||||
|
Headers: c.Headers,
|
||||||
|
RemoveHeaders: c.RemoveHeaders,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -583,3 +626,58 @@ func zeroUIntToNil(v uint) *uint {
|
||||||
|
|
||||||
return &v
|
return &v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseHeadersFromCLI(c *cli.Context) map[string]string {
|
||||||
|
headers := make(map[string]string)
|
||||||
|
|
||||||
|
if c.IsSet("header") {
|
||||||
|
headerFlags := c.StringSlice("header")
|
||||||
|
for _, headerFlag := range headerFlags {
|
||||||
|
if name, value, valid := parseHeaderFlag(headerFlag); valid {
|
||||||
|
headers[name] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRemoveHeadersFromCLI(c *cli.Context) []string {
|
||||||
|
if c.IsSet("remove-header") {
|
||||||
|
return c.StringSlice("remove-header")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHeaderFlag(headerFlag string) (name, value string, valid bool) {
|
||||||
|
parts := strings.SplitN(headerFlag, ":", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
name = strings.TrimSpace(parts[0])
|
||||||
|
value = strings.TrimSpace(parts[1])
|
||||||
|
|
||||||
|
if name == "" || value == "" || !isValidHeaderName(name) {
|
||||||
|
return "", "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
return name, value, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidHeaderName(name string) bool {
|
||||||
|
if name == "" || strings.Contains(name, ":") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.ContainsAny(name, " \t\r\n") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(name) == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(name) > 256 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package ingress
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"flag"
|
"flag"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -12,6 +13,7 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/config"
|
"github.com/cloudflare/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/ipaccess"
|
"github.com/cloudflare/cloudflared/ipaccess"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Ensure that the nullable config from `config` package and the
|
// Ensure that the nullable config from `config` package and the
|
||||||
|
|
@ -415,6 +417,187 @@ func TestDefaultConfigFromCLI(t *testing.T) {
|
||||||
require.Equal(t, expected, actual)
|
require.Equal(t, expected, actual)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOriginRequestConfigHeaders(t *testing.T) {
|
||||||
|
config := OriginRequestConfig{
|
||||||
|
Headers: map[string]string{
|
||||||
|
"X-Custom-Header": "custom-value",
|
||||||
|
"Authorization": "Bearer token123",
|
||||||
|
},
|
||||||
|
RemoveHeaders: []string{"X-Unwanted", "Server"},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Contains(t, string(jsonData), "X-Custom-Header")
|
||||||
|
assert.Contains(t, string(jsonData), "custom-value")
|
||||||
|
assert.Contains(t, string(jsonData), "X-Unwanted")
|
||||||
|
|
||||||
|
var unmarshaled OriginRequestConfig
|
||||||
|
err = json.Unmarshal(jsonData, &unmarshaled)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "custom-value", unmarshaled.Headers["X-Custom-Header"])
|
||||||
|
assert.Equal(t, "Bearer token123", unmarshaled.Headers["Authorization"])
|
||||||
|
assert.Contains(t, unmarshaled.RemoveHeaders, "X-Unwanted")
|
||||||
|
assert.Contains(t, unmarshaled.RemoveHeaders, "Server")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseHeaderFlag(t *testing.T) {
|
||||||
|
name, value, valid := parseHeaderFlag("X-Custom-Header: custom-value")
|
||||||
|
assert.True(t, valid)
|
||||||
|
assert.Equal(t, "X-Custom-Header", name)
|
||||||
|
assert.Equal(t, "custom-value", value)
|
||||||
|
|
||||||
|
name, value, valid = parseHeaderFlag(" Authorization : Bearer token ")
|
||||||
|
assert.True(t, valid)
|
||||||
|
assert.Equal(t, "Authorization", name)
|
||||||
|
assert.Equal(t, "Bearer token", value)
|
||||||
|
|
||||||
|
name, value, valid = parseHeaderFlag("X-Header: ")
|
||||||
|
assert.False(t, valid)
|
||||||
|
|
||||||
|
name, value, valid = parseHeaderFlag(" : value")
|
||||||
|
assert.False(t, valid)
|
||||||
|
|
||||||
|
_, _, valid = parseHeaderFlag("invalid-format")
|
||||||
|
assert.False(t, valid)
|
||||||
|
|
||||||
|
_, _, valid = parseHeaderFlag(": value-only")
|
||||||
|
assert.False(t, valid)
|
||||||
|
|
||||||
|
_, _, valid = parseHeaderFlag("name-only:")
|
||||||
|
assert.False(t, valid)
|
||||||
|
|
||||||
|
_, _, valid = parseHeaderFlag("")
|
||||||
|
assert.False(t, valid)
|
||||||
|
|
||||||
|
name, value, valid = parseHeaderFlag("X-Special: value with @#$%^&*()")
|
||||||
|
assert.True(t, valid)
|
||||||
|
assert.Equal(t, "X-Special", name)
|
||||||
|
assert.Equal(t, "value with @#$%^&*()", value)
|
||||||
|
|
||||||
|
name, value, valid = parseHeaderFlag("X-URL: https://example.com:8080/path")
|
||||||
|
assert.True(t, valid)
|
||||||
|
assert.Equal(t, "X-URL", name)
|
||||||
|
assert.Equal(t, "https://example.com:8080/path", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidHeaderName(t *testing.T) {
|
||||||
|
assert.True(t, isValidHeaderName("X-Custom-Header"))
|
||||||
|
assert.True(t, isValidHeaderName("Authorization"))
|
||||||
|
assert.True(t, isValidHeaderName("Content-Type"))
|
||||||
|
assert.True(t, isValidHeaderName("X-API-Key"))
|
||||||
|
assert.True(t, isValidHeaderName("User-Agent"))
|
||||||
|
|
||||||
|
assert.False(t, isValidHeaderName(""))
|
||||||
|
assert.False(t, isValidHeaderName(" "))
|
||||||
|
assert.False(t, isValidHeaderName("\t"))
|
||||||
|
assert.False(t, isValidHeaderName("\n"))
|
||||||
|
assert.False(t, isValidHeaderName("\r"))
|
||||||
|
|
||||||
|
assert.False(t, isValidHeaderName("Header With Space"))
|
||||||
|
assert.False(t, isValidHeaderName("Header\tWith\tTab"))
|
||||||
|
assert.False(t, isValidHeaderName("Header\nWith\nNewline"))
|
||||||
|
assert.False(t, isValidHeaderName("Header\rWith\rCarriageReturn"))
|
||||||
|
|
||||||
|
assert.False(t, isValidHeaderName(":Header"))
|
||||||
|
assert.False(t, isValidHeaderName("Header:"))
|
||||||
|
assert.False(t, isValidHeaderName("Header::Value"))
|
||||||
|
|
||||||
|
longHeader := strings.Repeat("A", 257)
|
||||||
|
assert.False(t, isValidHeaderName(longHeader))
|
||||||
|
|
||||||
|
boundaryHeader := strings.Repeat("A", 256)
|
||||||
|
assert.True(t, isValidHeaderName(boundaryHeader))
|
||||||
|
|
||||||
|
assert.True(t, isValidHeaderName("X"))
|
||||||
|
assert.True(t, isValidHeaderName("a"))
|
||||||
|
assert.True(t, isValidHeaderName("1"))
|
||||||
|
|
||||||
|
assert.True(t, isValidHeaderName("X-Header"))
|
||||||
|
assert.True(t, isValidHeaderName("X_Header"))
|
||||||
|
assert.True(t, isValidHeaderName("X.Header"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseHeadersFromCLI(t *testing.T) {
|
||||||
|
app := cli.NewApp()
|
||||||
|
app.Flags = []cli.Flag{
|
||||||
|
&cli.StringSliceFlag{
|
||||||
|
Name: "header",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app.Action = func(c *cli.Context) error {
|
||||||
|
headers := parseHeadersFromCLI(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 3, len(headers))
|
||||||
|
assert.Equal(t, "test-value", headers["X-Test-Header"])
|
||||||
|
assert.Equal(t, "static-key-123", headers["X-API-Key"])
|
||||||
|
assert.Equal(t, "Bearer token", headers["Authorization"])
|
||||||
|
|
||||||
|
assert.NotContains(t, headers, "Invalid-Header")
|
||||||
|
assert.NotContains(t, headers, "X-Empty")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := app.Run([]string{"app", "--header", "X-Test-Header: test-value", "--header", "X-API-Key: static-key-123", "--header", "Authorization: Bearer token", "--header", "Invalid-Header", "--header", "X-Empty: "})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRemoveHeadersFromCLI(t *testing.T) {
|
||||||
|
app := cli.NewApp()
|
||||||
|
app.Flags = []cli.Flag{
|
||||||
|
&cli.StringSliceFlag{
|
||||||
|
Name: "remove-header",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app.Action = func(c *cli.Context) error {
|
||||||
|
removeHeaders := parseRemoveHeadersFromCLI(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 3, len(removeHeaders))
|
||||||
|
assert.Contains(t, removeHeaders, "X-Unwanted")
|
||||||
|
assert.Contains(t, removeHeaders, "Server")
|
||||||
|
assert.Contains(t, removeHeaders, "User-Agent")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := app.Run([]string{"app", "--remove-header", "X-Unwanted", "--remove-header", "Server", "--remove-header", "User-Agent"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseHeadersFromCLINotSet(t *testing.T) {
|
||||||
|
app := cli.NewApp()
|
||||||
|
|
||||||
|
app.Action = func(c *cli.Context) error {
|
||||||
|
headers := parseHeadersFromCLI(c)
|
||||||
|
|
||||||
|
assert.Equal(t, 0, len(headers))
|
||||||
|
assert.NotNil(t, headers)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := app.Run([]string{"app"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRemoveHeadersFromCLINotSet(t *testing.T) {
|
||||||
|
app := cli.NewApp()
|
||||||
|
|
||||||
|
app.Action = func(c *cli.Context) error {
|
||||||
|
removeHeaders := parseRemoveHeadersFromCLI(c)
|
||||||
|
|
||||||
|
assert.Nil(t, removeHeaders)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := app.Run([]string{"app"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func newIPRule(t *testing.T, prefix string, ports []int, allow bool) ipaccess.Rule {
|
func newIPRule(t *testing.T, prefix string, ports []int, allow bool) ipaccess.Rule {
|
||||||
rule, err := ipaccess.NewRuleByCIDR(&prefix, ports, allow)
|
rule, err := ipaccess.NewRuleByCIDR(&prefix, ports, allow)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
|
||||||
|
|
@ -217,6 +217,10 @@ func (p *Proxy) proxyHTTPRequest(
|
||||||
roundTripReq.Header.Set("User-Agent", "")
|
roundTripReq.Header.Set("User-Agent", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rule, ok := p.findRuleForRequest(roundTripReq); ok {
|
||||||
|
p.applyCustomHeaders(roundTripReq, rule)
|
||||||
|
}
|
||||||
|
|
||||||
_, ttfbSpan := tr.Tracer().Start(tr.Context(), "ttfb_origin")
|
_, ttfbSpan := tr.Tracer().Start(tr.Context(), "ttfb_origin")
|
||||||
resp, err := httpService.RoundTrip(roundTripReq)
|
resp, err := httpService.RoundTrip(roundTripReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -394,3 +398,26 @@ func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
|
||||||
return rule.Service.String(), nil
|
return rule.Service.String(), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) applyCustomHeaders(req *http.Request, rule *ingress.Rule) {
|
||||||
|
if rule == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Config.Headers != nil {
|
||||||
|
for name, value := range rule.Config.Headers {
|
||||||
|
req.Header.Set(name, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Config.RemoveHeaders != nil {
|
||||||
|
for _, name := range rule.Config.RemoveHeaders {
|
||||||
|
req.Header.Del(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) findRuleForRequest(req *http.Request) (*ingress.Rule, bool) {
|
||||||
|
rule, _ := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
|
||||||
|
return rule, rule != nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1014,3 +1014,55 @@ func runEchoWSService(t *testing.T, l net.Listener) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyCustomHeaders(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||||
|
|
||||||
|
rule := &ingress.Rule{
|
||||||
|
Config: ingress.OriginRequestConfig{
|
||||||
|
Headers: map[string]string{
|
||||||
|
"X-Custom-Header": "custom-value",
|
||||||
|
"Authorization": "Bearer token123",
|
||||||
|
},
|
||||||
|
RemoveHeaders: []string{"X-Unwanted", "Server"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("X-Unwanted", "remove-me")
|
||||||
|
req.Header.Set("Server", "nginx")
|
||||||
|
req.Header.Set("Keep-Me", "keep-this")
|
||||||
|
|
||||||
|
proxy := &Proxy{}
|
||||||
|
proxy.applyCustomHeaders(req, rule)
|
||||||
|
|
||||||
|
assert.Equal(t, "custom-value", req.Header.Get("X-Custom-Header"))
|
||||||
|
assert.Equal(t, "Bearer token123", req.Header.Get("Authorization"))
|
||||||
|
assert.Empty(t, req.Header.Get("X-Unwanted"))
|
||||||
|
assert.Empty(t, req.Header.Get("Server"))
|
||||||
|
assert.Equal(t, "keep-this", req.Header.Get("Keep-Me"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCustomHeadersNoConfig(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||||
|
|
||||||
|
rule := &ingress.Rule{
|
||||||
|
Config: ingress.OriginRequestConfig{},
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("X-Original", "original-value")
|
||||||
|
|
||||||
|
proxy := &Proxy{}
|
||||||
|
|
||||||
|
proxy.applyCustomHeaders(req, rule)
|
||||||
|
|
||||||
|
assert.Equal(t, "original-value", req.Header.Get("X-Original"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCustomHeadersNilRule(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||||
|
|
||||||
|
proxy := &Proxy{}
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
proxy.applyCustomHeaders(req, nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue