feat: auto tls sni

Signed-off-by: Steven Kreitzer <skre@skre.me>
This commit is contained in:
Steven Kreitzer 2024-01-18 09:19:11 -06:00 committed by Devin
parent a665d3245a
commit b5be8a6fa4
5 changed files with 51 additions and 7 deletions

View File

@ -205,6 +205,8 @@ type OriginRequestConfig struct {
HTTPHostHeader *string `yaml:"httpHostHeader" json:"httpHostHeader,omitempty"`
// Hostname on the origin server certificate.
OriginServerName *string `yaml:"originServerName" json:"originServerName,omitempty"`
// Auto configure the Hostname on the origin server certificate.
MatchSNIToHost *bool `yaml:"matchSNItoHost" json:"matchSNItoHost,omitempty"`
// Path to the CA for the certificate of your origin.
// This option should be used only if your certificate is not signed by Cloudflare.
CAPool *string `yaml:"caPool" json:"caPool,omitempty"`

View File

@ -32,6 +32,7 @@ const (
ProxyKeepAliveTimeoutFlag = "proxy-keepalive-timeout"
HTTPHostHeaderFlag = "http-host-header"
OriginServerNameFlag = "origin-server-name"
MatchSNIToHostFlag = "match-sni-to-host"
NoTLSVerifyFlag = "no-tls-verify"
NoChunkedEncodingFlag = "no-chunked-encoding"
ProxyAddressFlag = "proxy-address"
@ -118,6 +119,7 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
var keepAliveTimeout = defaultKeepAliveTimeout
var httpHostHeader string
var originServerName string
var matchSNItoHost bool
var caPool string
var noTLSVerify bool
var disableChunkedEncoding bool
@ -150,6 +152,9 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
if flag := OriginServerNameFlag; c.IsSet(flag) {
originServerName = c.String(flag)
}
if flag := MatchSNIToHostFlag; c.IsSet(flag) {
matchSNItoHost = c.Bool(flag)
}
if flag := tlsconfig.OriginCAPoolFlag; c.IsSet(flag) {
caPool = c.String(flag)
}
@ -185,6 +190,7 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
KeepAliveTimeout: keepAliveTimeout,
HTTPHostHeader: httpHostHeader,
OriginServerName: originServerName,
MatchSNIToHost: matchSNItoHost,
CAPool: caPool,
NoTLSVerify: noTLSVerify,
DisableChunkedEncoding: disableChunkedEncoding,
@ -229,6 +235,9 @@ func originRequestFromConfig(c config.OriginRequestConfig) OriginRequestConfig {
if c.OriginServerName != nil {
out.OriginServerName = *c.OriginServerName
}
if c.MatchSNIToHost != nil {
out.MatchSNIToHost = *c.MatchSNIToHost
}
if c.CAPool != nil {
out.CAPool = *c.CAPool
}
@ -287,6 +296,8 @@ type OriginRequestConfig struct {
HTTPHostHeader string `yaml:"httpHostHeader" json:"httpHostHeader"`
// Hostname on the origin server certificate.
OriginServerName string `yaml:"originServerName" json:"originServerName"`
// Auto configure the Hostname on the origin server certificate.
MatchSNIToHost bool `yaml:"matchSNItoHost" json:"matchSNItoHost"`
// Path to the CA for the certificate of your origin.
// This option should be used only if your certificate is not signed by Cloudflare.
CAPool string `yaml:"caPool" json:"caPool"`
@ -362,6 +373,12 @@ func (defaults *OriginRequestConfig) setOriginServerName(overrides config.Origin
}
}
func (defaults *OriginRequestConfig) setMatchSNIToHost(overrides config.OriginRequestConfig) {
if val := overrides.MatchSNIToHost; val != nil {
defaults.MatchSNIToHost = *val
}
}
func (defaults *OriginRequestConfig) setCAPool(overrides config.OriginRequestConfig) {
if val := overrides.CAPool; val != nil {
defaults.CAPool = *val
@ -447,6 +464,7 @@ func setConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfi
cfg.setTCPKeepAlive(overrides)
cfg.setHTTPHostHeader(overrides)
cfg.setOriginServerName(overrides)
cfg.setMatchSNIToHost(overrides)
cfg.setCAPool(overrides)
cfg.setNoTLSVerify(overrides)
cfg.setDisableChunkedEncoding(overrides)
@ -501,6 +519,7 @@ func ConvertToRawOriginConfig(c OriginRequestConfig) config.OriginRequestConfig
KeepAliveTimeout: keepAliveTimeout,
HTTPHostHeader: emptyStringToNil(c.HTTPHostHeader),
OriginServerName: emptyStringToNil(c.OriginServerName),
MatchSNIToHost: defaultBoolToNil(c.MatchSNIToHost),
CAPool: emptyStringToNil(c.CAPool),
NoTLSVerify: defaultBoolToNil(c.NoTLSVerify),
DisableChunkedEncoding: defaultBoolToNil(c.DisableChunkedEncoding),

View File

@ -2,7 +2,9 @@ package ingress
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"github.com/rs/zerolog"
@ -48,9 +50,28 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("X-Forwarded-Host", req.Host)
req.Host = o.hostHeader
}
if o.matchSNIToHost {
o.SetOriginServerName(req)
}
return o.transport.RoundTrip(req)
}
func (o *httpService) SetOriginServerName(req *http.Request) {
o.transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := o.transport.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
return tls.Client(conn, &tls.Config{
RootCAs: o.transport.TLSClientConfig.RootCAs,
InsecureSkipVerify: o.transport.TLSClientConfig.InsecureSkipVerify,
ServerName: req.Host,
}), nil
}
}
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
if o.defaultResp {
o.log.Warn().Msgf(ErrNoIngressRulesCLI.Error())

View File

@ -68,9 +68,10 @@ func (o unixSocketPath) MarshalJSON() ([]byte, error) {
}
type httpService struct {
url *url.URL
hostHeader string
transport *http.Transport
url *url.URL
hostHeader string
transport *http.Transport
matchSNIToHost bool
}
func (o *httpService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
@ -80,6 +81,7 @@ func (o *httpService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRe
}
o.hostHeader = cfg.HTTPHostHeader
o.transport = transport
o.matchSNIToHost = cfg.MatchSNIToHost
return nil
}

View File

@ -204,25 +204,25 @@ func TestMarshalJSON(t *testing.T) {
{
name: "Nil",
path: nil,
expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
want: true,
},
{
name: "Nil regex",
path: &Regexp{Regexp: nil},
expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
want: true,
},
{
name: "Empty",
path: &Regexp{Regexp: regexp.MustCompile("")},
expected: `{"hostname":"example.com","path":"","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
expected: `{"hostname":"example.com","path":"","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
want: true,
},
{
name: "Basic",
path: &Regexp{Regexp: regexp.MustCompile("/echo")},
expected: `{"hostname":"example.com","path":"/echo","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
expected: `{"hostname":"example.com","path":"/echo","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
want: true,
},
}