feat: Add custom HTTP header management via CLI flags and config files

This commit is contained in:
absar0920 2025-08-18 17:31:11 +05:00
parent 50104548cf
commit e4930d09d2
6 changed files with 377 additions and 0 deletions

View File

@ -1056,6 +1056,19 @@ func configureProxyFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide,
Value: false,
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "header",
Aliases: []string{"H"},
Usage: "Add custom header when forwarding to origin (format: 'Name: Value')",
EnvVars: []string{"TUNNEL_HEADERS"},
Hidden: shouldHide,
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "remove-header",
Usage: "Remove header when forwarding to origin",
EnvVars: []string{"TUNNEL_REMOVE_HEADERS"},
Hidden: shouldHide,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.ManagementHostname,
Usage: "Management hostname to signify incoming management requests",

View File

@ -231,6 +231,10 @@ type OriginRequestConfig struct {
Http2Origin *bool `yaml:"http2Origin" json:"http2Origin,omitempty"`
// Access holds all access related configs
Access *AccessConfig `yaml:"access" json:"access,omitempty"`
// Custom headers to add/modify when forwarding to origin
Headers map[string]string `yaml:"headers" json:"headers,omitempty"`
// Headers to remove when forwarding to origin
RemoveHeaders []string `yaml:"removeHeaders" json:"removeHeaders,omitempty"`
}
type AccessConfig struct {

View File

@ -2,6 +2,7 @@ package ingress
import (
"encoding/json"
"strings"
"time"
"github.com/urfave/cli/v2"
@ -137,6 +138,16 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
var proxyPort uint
var proxyType string
var http2Origin bool
var cliHeaders map[string]string
var cliRemoveHeaders []string
if c.IsSet("header") {
cliHeaders = parseHeadersFromCLI(c)
}
if c.IsSet("remove-header") {
cliRemoveHeaders = parseRemoveHeadersFromCLI(c)
}
if flag := ProxyConnectTimeoutFlag; c.IsSet(flag) {
connectTimeout = config.CustomDuration{Duration: c.Duration(flag)}
}
@ -209,6 +220,8 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
ProxyPort: proxyPort,
ProxyType: proxyType,
Http2Origin: http2Origin,
Headers: cliHeaders,
RemoveHeaders: cliRemoveHeaders,
}
}
@ -333,6 +346,12 @@ type OriginRequestConfig struct {
// Access holds all access related configs
Access config.AccessConfig `yaml:"access" json:"access,omitempty"`
// Custom headers to add/modify when forwarding to origin
Headers map[string]string `yaml:"headers" json:"headers,omitempty"`
// Headers to remove when forwarding to origin
RemoveHeaders []string `yaml:"removeHeaders" json:"removeHeaders,omitempty"`
}
func (defaults *OriginRequestConfig) setConnectTimeout(overrides config.OriginRequestConfig) {
@ -456,6 +475,26 @@ func (defaults *OriginRequestConfig) setAccess(overrides config.OriginRequestCon
}
}
func (defaults *OriginRequestConfig) setHeaders(overrides config.OriginRequestConfig) {
if overrides.Headers != nil {
if defaults.Headers == nil {
defaults.Headers = make(map[string]string)
}
for k, v := range overrides.Headers {
defaults.Headers[k] = v
}
}
}
func (defaults *OriginRequestConfig) setRemoveHeaders(overrides config.OriginRequestConfig) {
if overrides.RemoveHeaders != nil {
if defaults.RemoveHeaders == nil {
defaults.RemoveHeaders = make([]string, 0)
}
defaults.RemoveHeaders = append(defaults.RemoveHeaders, overrides.RemoveHeaders...)
}
}
// SetConfig gets config for the requests that cloudflared sends to origins.
// Each field has a setter method which sets a value for the field by trying to find:
// 1. The user config for this rule
@ -485,6 +524,8 @@ func setConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfi
cfg.setIPRules(overrides)
cfg.setHttp2Origin(overrides)
cfg.setAccess(overrides)
cfg.setHeaders(overrides)
cfg.setRemoveHeaders(overrides)
return cfg
}
@ -540,6 +581,8 @@ func ConvertToRawOriginConfig(c OriginRequestConfig) config.OriginRequestConfig
IPRules: convertToRawIPRules(c.IPRules),
Http2Origin: defaultBoolToNil(c.Http2Origin),
Access: access,
Headers: c.Headers,
RemoveHeaders: c.RemoveHeaders,
}
}
@ -583,3 +626,58 @@ func zeroUIntToNil(v uint) *uint {
return &v
}
func parseHeadersFromCLI(c *cli.Context) map[string]string {
headers := make(map[string]string)
if c.IsSet("header") {
headerFlags := c.StringSlice("header")
for _, headerFlag := range headerFlags {
if name, value, valid := parseHeaderFlag(headerFlag); valid {
headers[name] = value
}
}
}
return headers
}
func parseRemoveHeadersFromCLI(c *cli.Context) []string {
if c.IsSet("remove-header") {
return c.StringSlice("remove-header")
}
return nil
}
func parseHeaderFlag(headerFlag string) (name, value string, valid bool) {
parts := strings.SplitN(headerFlag, ":", 2)
if len(parts) != 2 {
return "", "", false
}
name = strings.TrimSpace(parts[0])
value = strings.TrimSpace(parts[1])
if name == "" || value == "" || !isValidHeaderName(name) {
return "", "", false
}
return name, value, true
}
func isValidHeaderName(name string) bool {
if name == "" || strings.Contains(name, ":") {
return false
}
if strings.ContainsAny(name, " \t\r\n") {
return false
}
if strings.TrimSpace(name) == "" {
return false
}
if len(name) > 256 {
return false
}
return true
}

View File

@ -3,6 +3,7 @@ package ingress
import (
"encoding/json"
"flag"
"strings"
"testing"
"time"
@ -12,6 +13,7 @@ import (
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/ipaccess"
"github.com/stretchr/testify/assert"
)
// Ensure that the nullable config from `config` package and the
@ -415,6 +417,187 @@ func TestDefaultConfigFromCLI(t *testing.T) {
require.Equal(t, expected, actual)
}
func TestOriginRequestConfigHeaders(t *testing.T) {
config := OriginRequestConfig{
Headers: map[string]string{
"X-Custom-Header": "custom-value",
"Authorization": "Bearer token123",
},
RemoveHeaders: []string{"X-Unwanted", "Server"},
}
jsonData, err := json.Marshal(config)
assert.NoError(t, err)
assert.Contains(t, string(jsonData), "X-Custom-Header")
assert.Contains(t, string(jsonData), "custom-value")
assert.Contains(t, string(jsonData), "X-Unwanted")
var unmarshaled OriginRequestConfig
err = json.Unmarshal(jsonData, &unmarshaled)
assert.NoError(t, err)
assert.Equal(t, "custom-value", unmarshaled.Headers["X-Custom-Header"])
assert.Equal(t, "Bearer token123", unmarshaled.Headers["Authorization"])
assert.Contains(t, unmarshaled.RemoveHeaders, "X-Unwanted")
assert.Contains(t, unmarshaled.RemoveHeaders, "Server")
}
func TestParseHeaderFlag(t *testing.T) {
name, value, valid := parseHeaderFlag("X-Custom-Header: custom-value")
assert.True(t, valid)
assert.Equal(t, "X-Custom-Header", name)
assert.Equal(t, "custom-value", value)
name, value, valid = parseHeaderFlag(" Authorization : Bearer token ")
assert.True(t, valid)
assert.Equal(t, "Authorization", name)
assert.Equal(t, "Bearer token", value)
name, value, valid = parseHeaderFlag("X-Header: ")
assert.False(t, valid)
name, value, valid = parseHeaderFlag(" : value")
assert.False(t, valid)
_, _, valid = parseHeaderFlag("invalid-format")
assert.False(t, valid)
_, _, valid = parseHeaderFlag(": value-only")
assert.False(t, valid)
_, _, valid = parseHeaderFlag("name-only:")
assert.False(t, valid)
_, _, valid = parseHeaderFlag("")
assert.False(t, valid)
name, value, valid = parseHeaderFlag("X-Special: value with @#$%^&*()")
assert.True(t, valid)
assert.Equal(t, "X-Special", name)
assert.Equal(t, "value with @#$%^&*()", value)
name, value, valid = parseHeaderFlag("X-URL: https://example.com:8080/path")
assert.True(t, valid)
assert.Equal(t, "X-URL", name)
assert.Equal(t, "https://example.com:8080/path", value)
}
func TestIsValidHeaderName(t *testing.T) {
assert.True(t, isValidHeaderName("X-Custom-Header"))
assert.True(t, isValidHeaderName("Authorization"))
assert.True(t, isValidHeaderName("Content-Type"))
assert.True(t, isValidHeaderName("X-API-Key"))
assert.True(t, isValidHeaderName("User-Agent"))
assert.False(t, isValidHeaderName(""))
assert.False(t, isValidHeaderName(" "))
assert.False(t, isValidHeaderName("\t"))
assert.False(t, isValidHeaderName("\n"))
assert.False(t, isValidHeaderName("\r"))
assert.False(t, isValidHeaderName("Header With Space"))
assert.False(t, isValidHeaderName("Header\tWith\tTab"))
assert.False(t, isValidHeaderName("Header\nWith\nNewline"))
assert.False(t, isValidHeaderName("Header\rWith\rCarriageReturn"))
assert.False(t, isValidHeaderName(":Header"))
assert.False(t, isValidHeaderName("Header:"))
assert.False(t, isValidHeaderName("Header::Value"))
longHeader := strings.Repeat("A", 257)
assert.False(t, isValidHeaderName(longHeader))
boundaryHeader := strings.Repeat("A", 256)
assert.True(t, isValidHeaderName(boundaryHeader))
assert.True(t, isValidHeaderName("X"))
assert.True(t, isValidHeaderName("a"))
assert.True(t, isValidHeaderName("1"))
assert.True(t, isValidHeaderName("X-Header"))
assert.True(t, isValidHeaderName("X_Header"))
assert.True(t, isValidHeaderName("X.Header"))
}
func TestParseHeadersFromCLI(t *testing.T) {
app := cli.NewApp()
app.Flags = []cli.Flag{
&cli.StringSliceFlag{
Name: "header",
},
}
app.Action = func(c *cli.Context) error {
headers := parseHeadersFromCLI(c)
assert.Equal(t, 3, len(headers))
assert.Equal(t, "test-value", headers["X-Test-Header"])
assert.Equal(t, "static-key-123", headers["X-API-Key"])
assert.Equal(t, "Bearer token", headers["Authorization"])
assert.NotContains(t, headers, "Invalid-Header")
assert.NotContains(t, headers, "X-Empty")
return nil
}
err := app.Run([]string{"app", "--header", "X-Test-Header: test-value", "--header", "X-API-Key: static-key-123", "--header", "Authorization: Bearer token", "--header", "Invalid-Header", "--header", "X-Empty: "})
assert.NoError(t, err)
}
func TestParseRemoveHeadersFromCLI(t *testing.T) {
app := cli.NewApp()
app.Flags = []cli.Flag{
&cli.StringSliceFlag{
Name: "remove-header",
},
}
app.Action = func(c *cli.Context) error {
removeHeaders := parseRemoveHeadersFromCLI(c)
assert.Equal(t, 3, len(removeHeaders))
assert.Contains(t, removeHeaders, "X-Unwanted")
assert.Contains(t, removeHeaders, "Server")
assert.Contains(t, removeHeaders, "User-Agent")
return nil
}
err := app.Run([]string{"app", "--remove-header", "X-Unwanted", "--remove-header", "Server", "--remove-header", "User-Agent"})
assert.NoError(t, err)
}
func TestParseHeadersFromCLINotSet(t *testing.T) {
app := cli.NewApp()
app.Action = func(c *cli.Context) error {
headers := parseHeadersFromCLI(c)
assert.Equal(t, 0, len(headers))
assert.NotNil(t, headers)
return nil
}
err := app.Run([]string{"app"})
assert.NoError(t, err)
}
func TestParseRemoveHeadersFromCLINotSet(t *testing.T) {
app := cli.NewApp()
app.Action = func(c *cli.Context) error {
removeHeaders := parseRemoveHeadersFromCLI(c)
assert.Nil(t, removeHeaders)
return nil
}
err := app.Run([]string{"app"})
assert.NoError(t, err)
}
func newIPRule(t *testing.T, prefix string, ports []int, allow bool) ipaccess.Rule {
rule, err := ipaccess.NewRuleByCIDR(&prefix, ports, allow)
require.NoError(t, err)

View File

@ -217,6 +217,10 @@ func (p *Proxy) proxyHTTPRequest(
roundTripReq.Header.Set("User-Agent", "")
}
if rule, ok := p.findRuleForRequest(roundTripReq); ok {
p.applyCustomHeaders(roundTripReq, rule)
}
_, ttfbSpan := tr.Tracer().Start(tr.Context(), "ttfb_origin")
resp, err := httpService.RoundTrip(roundTripReq)
if err != nil {
@ -394,3 +398,26 @@ func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
return rule.Service.String(), nil
}
}
func (p *Proxy) applyCustomHeaders(req *http.Request, rule *ingress.Rule) {
if rule == nil {
return
}
if rule.Config.Headers != nil {
for name, value := range rule.Config.Headers {
req.Header.Set(name, value)
}
}
if rule.Config.RemoveHeaders != nil {
for _, name := range rule.Config.RemoveHeaders {
req.Header.Del(name)
}
}
}
func (p *Proxy) findRuleForRequest(req *http.Request) (*ingress.Rule, bool) {
rule, _ := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
return rule, rule != nil
}

View File

@ -1014,3 +1014,55 @@ func runEchoWSService(t *testing.T, l net.Listener) {
}
}()
}
func TestApplyCustomHeaders(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
rule := &ingress.Rule{
Config: ingress.OriginRequestConfig{
Headers: map[string]string{
"X-Custom-Header": "custom-value",
"Authorization": "Bearer token123",
},
RemoveHeaders: []string{"X-Unwanted", "Server"},
},
}
req.Header.Set("X-Unwanted", "remove-me")
req.Header.Set("Server", "nginx")
req.Header.Set("Keep-Me", "keep-this")
proxy := &Proxy{}
proxy.applyCustomHeaders(req, rule)
assert.Equal(t, "custom-value", req.Header.Get("X-Custom-Header"))
assert.Equal(t, "Bearer token123", req.Header.Get("Authorization"))
assert.Empty(t, req.Header.Get("X-Unwanted"))
assert.Empty(t, req.Header.Get("Server"))
assert.Equal(t, "keep-this", req.Header.Get("Keep-Me"))
}
func TestApplyCustomHeadersNoConfig(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
rule := &ingress.Rule{
Config: ingress.OriginRequestConfig{},
}
req.Header.Set("X-Original", "original-value")
proxy := &Proxy{}
proxy.applyCustomHeaders(req, rule)
assert.Equal(t, "original-value", req.Header.Get("X-Original"))
}
func TestApplyCustomHeadersNilRule(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/test", nil)
proxy := &Proxy{}
assert.NotPanics(t, func() {
proxy.applyCustomHeaders(req, nil)
})
}