TUN-3492: Refactor OriginService, shrink its interface
This commit is contained in:
parent
18c359cb86
commit
d01770107e
|
@ -395,7 +395,7 @@ func StartServer(
|
||||||
hostname,
|
hostname,
|
||||||
metricsListener.Addr().String(),
|
metricsListener.Addr().String(),
|
||||||
// TODO (TUN-3461): Update UI to show multiple origin URLs
|
// TODO (TUN-3461): Update UI to show multiple origin URLs
|
||||||
tunnelConfig.IngressRules.CatchAll().Service.Address(),
|
tunnelConfig.IngressRules.CatchAll().Service.String(),
|
||||||
tunnelConfig.HAConnections,
|
tunnelConfig.HAConnections,
|
||||||
)
|
)
|
||||||
logLevels, err := logger.ParseLevelString(c.String("loglevel"))
|
logLevels, err := logger.ParseLevelString(c.String("loglevel"))
|
||||||
|
|
|
@ -202,7 +202,7 @@ func prepareTunnelConfig(
|
||||||
Version: version,
|
Version: version,
|
||||||
Arch: fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch),
|
Arch: fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch),
|
||||||
}
|
}
|
||||||
ingressRules, err = ingress.ParseIngress(config.GetConfiguration(), logger)
|
ingressRules, err = ingress.ParseIngress(config.GetConfiguration())
|
||||||
if err != nil && err != ingress.ErrNoIngressRules {
|
if err != nil && err != ingress.ErrNoIngressRules {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -71,7 +71,7 @@ func buildTestURLCommand() *cli.Command {
|
||||||
func validateIngressCommand(c *cli.Context) error {
|
func validateIngressCommand(c *cli.Context) error {
|
||||||
conf := config.GetConfiguration()
|
conf := config.GetConfiguration()
|
||||||
fmt.Println("Validating rules from", conf.Source())
|
fmt.Println("Validating rules from", conf.Source())
|
||||||
if _, err := ingress.ParseIngressDryRun(conf); err != nil {
|
if _, err := ingress.ParseIngress(conf); err != nil {
|
||||||
return errors.Wrap(err, "Validation failed")
|
return errors.Wrap(err, "Validation failed")
|
||||||
}
|
}
|
||||||
if c.IsSet("url") {
|
if c.IsSet("url") {
|
||||||
|
@ -98,7 +98,7 @@ func testURLCommand(c *cli.Context) error {
|
||||||
|
|
||||||
conf := config.GetConfiguration()
|
conf := config.GetConfiguration()
|
||||||
fmt.Println("Using rules from", conf.Source())
|
fmt.Println("Using rules from", conf.Source())
|
||||||
ing, err := ingress.ParseIngressDryRun(conf)
|
ing, err := ingress.ParseIngress(conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Validation failed")
|
return errors.Wrap(err, "Validation failed")
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,24 +1,17 @@
|
||||||
package ingress
|
package ingress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
|
||||||
"github.com/cloudflare/cloudflared/validation"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -28,82 +21,6 @@ var (
|
||||||
ErrURLIncompatibleWithIngress = errors.New("You can't set the --url flag (or $TUNNEL_URL) when using multiple-origin ingress rules")
|
ErrURLIncompatibleWithIngress = errors.New("You can't set the --url flag (or $TUNNEL_URL) when using multiple-origin ingress rules")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Finalize the rules by adding missing struct fields and validating each origin.
|
|
||||||
func (ing *Ingress) setHTTPTransport(logger logger.Service) error {
|
|
||||||
for ruleNumber, rule := range ing.Rules {
|
|
||||||
cfg := rule.Config
|
|
||||||
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, nil)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "Error loading cert pool")
|
|
||||||
}
|
|
||||||
|
|
||||||
httpTransport := &http.Transport{
|
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
MaxIdleConns: cfg.KeepAliveConnections,
|
|
||||||
MaxIdleConnsPerHost: cfg.KeepAliveConnections,
|
|
||||||
IdleConnTimeout: cfg.KeepAliveTimeout,
|
|
||||||
TLSHandshakeTimeout: cfg.TLSTimeout,
|
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
|
||||||
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify},
|
|
||||||
}
|
|
||||||
if _, isHelloWorld := rule.Service.(*HelloWorld); !isHelloWorld && cfg.OriginServerName != "" {
|
|
||||||
httpTransport.TLSClientConfig.ServerName = cfg.OriginServerName
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer := &net.Dialer{
|
|
||||||
Timeout: cfg.ConnectTimeout,
|
|
||||||
KeepAlive: cfg.TCPKeepAlive,
|
|
||||||
}
|
|
||||||
if cfg.NoHappyEyeballs {
|
|
||||||
dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialContext depends on which kind of origin is being used.
|
|
||||||
dialContext := dialer.DialContext
|
|
||||||
switch service := rule.Service.(type) {
|
|
||||||
|
|
||||||
// If this origin is a unix socket, enforce network type "unix".
|
|
||||||
case UnixSocketPath:
|
|
||||||
httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
|
||||||
return dialContext(ctx, "unix", service.Address())
|
|
||||||
}
|
|
||||||
// Otherwise, use the regular network config.
|
|
||||||
default:
|
|
||||||
httpTransport.DialContext = dialContext
|
|
||||||
}
|
|
||||||
|
|
||||||
ing.Rules[ruleNumber].HTTPTransport = httpTransport
|
|
||||||
ing.Rules[ruleNumber].ClientTLSConfig = httpTransport.TLSClientConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate each origin
|
|
||||||
for _, rule := range ing.Rules {
|
|
||||||
// If tunnel running in bastion mode, a connection to origin will not exist until initiated by the client.
|
|
||||||
if rule.Config.BastionMode {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unix sockets don't have validation
|
|
||||||
if _, ok := rule.Service.(UnixSocketPath); ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
switch service := rule.Service.(type) {
|
|
||||||
|
|
||||||
case UnixSocketPath:
|
|
||||||
continue
|
|
||||||
|
|
||||||
case *HelloWorld:
|
|
||||||
continue
|
|
||||||
|
|
||||||
default:
|
|
||||||
if err := validation.ValidateHTTPService(service.Address(), rule.Hostname, rule.HTTPTransport); err != nil {
|
|
||||||
logger.Errorf("unable to connect to the origin: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindMatchingRule returns the index of the Ingress Rule which matches the given
|
// FindMatchingRule returns the index of the Ingress Rule which matches the given
|
||||||
// hostname and path. This function assumes the last rule matches everything,
|
// hostname and path. This function assumes the last rule matches everything,
|
||||||
// which is the case if the rules were instantiated via the ingress#Validate method
|
// which is the case if the rules were instantiated via the ingress#Validate method
|
||||||
|
@ -154,14 +71,13 @@ func NewSingleOrigin(c *cli.Context, compatibilityMode bool, logger logger.Servi
|
||||||
},
|
},
|
||||||
defaults: originRequestFromSingeRule(c),
|
defaults: originRequestFromSingeRule(c),
|
||||||
}
|
}
|
||||||
err = ing.setHTTPTransport(logger)
|
|
||||||
return ing, err
|
return ing, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a single origin service from the CLI/config.
|
// Get a single origin service from the CLI/config.
|
||||||
func parseSingleOriginService(c *cli.Context, compatibilityMode bool) (OriginService, error) {
|
func parseSingleOriginService(c *cli.Context, compatibilityMode bool) (OriginService, error) {
|
||||||
if c.IsSet("hello-world") {
|
if c.IsSet("hello-world") {
|
||||||
return new(HelloWorld), nil
|
return new(helloWorld), nil
|
||||||
}
|
}
|
||||||
if c.IsSet("url") {
|
if c.IsSet("url") {
|
||||||
originURLStr, err := config.ValidateUrl(c, compatibilityMode)
|
originURLStr, err := config.ValidateUrl(c, compatibilityMode)
|
||||||
|
@ -172,14 +88,14 @@ func parseSingleOriginService(c *cli.Context, compatibilityMode bool) (OriginSer
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "couldn't parse origin URL")
|
return nil, errors.Wrap(err, "couldn't parse origin URL")
|
||||||
}
|
}
|
||||||
return &URL{URL: originURL, RootURL: originURL}, nil
|
return &localService{URL: originURL, RootURL: originURL}, nil
|
||||||
}
|
}
|
||||||
if c.IsSet("unix-socket") {
|
if c.IsSet("unix-socket") {
|
||||||
unixSocket, err := config.ValidateUnixSocket(c)
|
path, err := config.ValidateUnixSocket(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Error validating --unix-socket")
|
return nil, errors.Wrap(err, "Error validating --unix-socket")
|
||||||
}
|
}
|
||||||
return UnixSocketPath(unixSocket), nil
|
return &unixSocketPath{path: path}, nil
|
||||||
}
|
}
|
||||||
return nil, errors.New("You must either set ingress rules in your config file, or use --url or use --unix-socket")
|
return nil, errors.New("You must either set ingress rules in your config file, or use --url or use --unix-socket")
|
||||||
}
|
}
|
||||||
|
@ -192,7 +108,7 @@ func (ing Ingress) IsEmpty() bool {
|
||||||
// StartOrigins will start any origin services managed by cloudflared, e.g. proxy servers or Hello World.
|
// StartOrigins will start any origin services managed by cloudflared, e.g. proxy servers or Hello World.
|
||||||
func (ing Ingress) StartOrigins(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error) error {
|
func (ing Ingress) StartOrigins(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error) error {
|
||||||
for _, rule := range ing.Rules {
|
for _, rule := range ing.Rules {
|
||||||
if err := rule.Service.Start(wg, log, shutdownC, errC, rule.Config); err != nil {
|
if err := rule.Service.start(wg, log, shutdownC, errC, rule.Config); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -209,11 +125,12 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||||
for i, r := range ingress {
|
for i, r := range ingress {
|
||||||
var service OriginService
|
var service OriginService
|
||||||
|
|
||||||
if strings.HasPrefix(r.Service, "unix:") {
|
if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) {
|
||||||
// No validation necessary for unix socket filepath services
|
// No validation necessary for unix socket filepath services
|
||||||
service = UnixSocketPath(strings.TrimPrefix(r.Service, "unix:"))
|
path := strings.TrimPrefix(r.Service, prefix)
|
||||||
|
service = &unixSocketPath{path: path}
|
||||||
} else if r.Service == "hello_world" || r.Service == "hello-world" || r.Service == "helloworld" {
|
} else if r.Service == "hello_world" || r.Service == "hello-world" || r.Service == "helloworld" {
|
||||||
service = new(HelloWorld)
|
service = new(helloWorld)
|
||||||
} else {
|
} else {
|
||||||
// Validate URL services
|
// Validate URL services
|
||||||
u, err := url.Parse(r.Service)
|
u, err := url.Parse(r.Service)
|
||||||
|
@ -228,7 +145,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||||
if u.Path != "" {
|
if u.Path != "" {
|
||||||
return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path", r.Service)
|
return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path", r.Service)
|
||||||
}
|
}
|
||||||
serviceURL := URL{URL: u}
|
serviceURL := localService{URL: u}
|
||||||
service = &serviceURL
|
service = &serviceURL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -262,7 +179,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||||
Hostname: r.Hostname,
|
Hostname: r.Hostname,
|
||||||
Service: service,
|
Service: service,
|
||||||
Path: pathRegex,
|
Path: pathRegex,
|
||||||
Config: SetConfig(defaults, r.OriginRequest),
|
Config: setConfig(defaults, r.OriginRequest),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Ingress{Rules: rules, defaults: defaults}, nil
|
return Ingress{Rules: rules, defaults: defaults}, nil
|
||||||
|
@ -279,20 +196,10 @@ func (e errRuleShouldNotBeCatchAll) Error() string {
|
||||||
"will never be triggered.", e.i+1, e.hostname)
|
"will never be triggered.", e.i+1, e.hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseIngress parses, validates and initializes HTTP transports to each origin.
|
// ParseIngress parses ingress rules, but does not send HTTP requests to the origins.
|
||||||
func ParseIngress(conf *config.Configuration, logger logger.Service) (Ingress, error) {
|
func ParseIngress(conf *config.Configuration) (Ingress, error) {
|
||||||
ing, err := ParseIngressDryRun(conf)
|
|
||||||
if err != nil {
|
|
||||||
return Ingress{}, err
|
|
||||||
}
|
|
||||||
err = ing.setHTTPTransport(logger)
|
|
||||||
return ing, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseIngressDryRun parses ingress rules, but does not send HTTP requests to the origins.
|
|
||||||
func ParseIngressDryRun(conf *config.Configuration) (Ingress, error) {
|
|
||||||
if len(conf.Ingress) == 0 {
|
if len(conf.Ingress) == 0 {
|
||||||
return Ingress{}, ErrNoIngressRules
|
return Ingress{}, ErrNoIngressRules
|
||||||
}
|
}
|
||||||
return validate(conf.Ingress, OriginRequestFromYAML(conf.OriginRequest))
|
return validate(conf.Ingress, originRequestFromYAML(conf.OriginRequest))
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
|
|
||||||
|
@ -16,16 +15,16 @@ func TestParseUnixSocket(t *testing.T) {
|
||||||
ingress:
|
ingress:
|
||||||
- service: unix:/tmp/echo.sock
|
- service: unix:/tmp/echo.sock
|
||||||
`
|
`
|
||||||
ing, err := ParseIngressDryRun(MustReadIngress(rawYAML))
|
ing, err := ParseIngress(MustReadIngress(rawYAML))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, ok := ing.Rules[0].Service.(UnixSocketPath)
|
_, ok := ing.Rules[0].Service.(*unixSocketPath)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_parseIngress(t *testing.T) {
|
func Test_parseIngress(t *testing.T) {
|
||||||
localhost8000 := MustParseURL(t, "https://localhost:8000")
|
localhost8000 := MustParseURL(t, "https://localhost:8000")
|
||||||
localhost8001 := MustParseURL(t, "https://localhost:8001")
|
localhost8001 := MustParseURL(t, "https://localhost:8001")
|
||||||
defaultConfig := SetConfig(OriginRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{})
|
defaultConfig := setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{})
|
||||||
require.Equal(t, defaultKeepAliveConnections, defaultConfig.KeepAliveConnections)
|
require.Equal(t, defaultKeepAliveConnections, defaultConfig.KeepAliveConnections)
|
||||||
type args struct {
|
type args struct {
|
||||||
rawYAML string
|
rawYAML string
|
||||||
|
@ -53,12 +52,12 @@ ingress:
|
||||||
want: []Rule{
|
want: []Rule{
|
||||||
{
|
{
|
||||||
Hostname: "tunnel1.example.com",
|
Hostname: "tunnel1.example.com",
|
||||||
Service: &URL{URL: localhost8000},
|
Service: &localService{URL: localhost8000},
|
||||||
Config: defaultConfig,
|
Config: defaultConfig,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Hostname: "*",
|
Hostname: "*",
|
||||||
Service: &URL{URL: localhost8001},
|
Service: &localService{URL: localhost8001},
|
||||||
Config: defaultConfig,
|
Config: defaultConfig,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -74,7 +73,7 @@ extraKey: extraValue
|
||||||
want: []Rule{
|
want: []Rule{
|
||||||
{
|
{
|
||||||
Hostname: "*",
|
Hostname: "*",
|
||||||
Service: &URL{URL: localhost8000},
|
Service: &localService{URL: localhost8000},
|
||||||
Config: defaultConfig,
|
Config: defaultConfig,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -87,7 +86,7 @@ ingress:
|
||||||
`},
|
`},
|
||||||
want: []Rule{
|
want: []Rule{
|
||||||
{
|
{
|
||||||
Service: &URL{URL: localhost8000},
|
Service: &localService{URL: localhost8000},
|
||||||
Config: defaultConfig,
|
Config: defaultConfig,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -165,15 +164,37 @@ ingress:
|
||||||
`},
|
`},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid HTTP status",
|
||||||
|
args: args{rawYAML: `
|
||||||
|
ingress:
|
||||||
|
- service: http_status:asdf
|
||||||
|
`},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid hello world service",
|
||||||
|
args: args{rawYAML: `
|
||||||
|
ingress:
|
||||||
|
- service: hello_world
|
||||||
|
`},
|
||||||
|
want: []Rule{
|
||||||
|
{
|
||||||
|
Hostname: "",
|
||||||
|
Service: new(helloWorld),
|
||||||
|
Config: defaultConfig,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := ParseIngressDryRun(MustReadIngress(tt.args.rawYAML))
|
got, err := ParseIngress(MustReadIngress(tt.args.rawYAML))
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("ParseIngressDryRun() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("ParseIngress() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
assert.Equal(t, tt.want, got.Rules)
|
require.Equal(t, tt.want, got.Rules)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -195,7 +216,7 @@ ingress:
|
||||||
service: https://localhost:8002
|
service: https://localhost:8002
|
||||||
`
|
`
|
||||||
|
|
||||||
ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML))
|
ing, err := ParseIngress(MustReadIngress(rulesYAML))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Error(err)
|
b.Error(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -116,7 +116,7 @@ func originRequestFromSingeRule(c *cli.Context) OriginRequestConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func OriginRequestFromYAML(y config.OriginRequestConfig) OriginRequestConfig {
|
func originRequestFromYAML(y config.OriginRequestConfig) OriginRequestConfig {
|
||||||
out := OriginRequestConfig{
|
out := OriginRequestConfig{
|
||||||
ConnectTimeout: defaultConnectTimeout,
|
ConnectTimeout: defaultConnectTimeout,
|
||||||
TLSTimeout: defaultTLSTimeout,
|
TLSTimeout: defaultTLSTimeout,
|
||||||
|
@ -310,7 +310,7 @@ func (defaults *OriginRequestConfig) setProxyType(overrides config.OriginRequest
|
||||||
// 3. Defaults chosen by the cloudflared team
|
// 3. Defaults chosen by the cloudflared team
|
||||||
// 4. Golang zero values for that type
|
// 4. Golang zero values for that type
|
||||||
// If an earlier option isn't set, it will try the next option down.
|
// If an earlier option isn't set, it will try the next option down.
|
||||||
func SetConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfig) OriginRequestConfig {
|
func setConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfig) OriginRequestConfig {
|
||||||
cfg := defaults
|
cfg := defaults
|
||||||
cfg.setConnectTimeout(overrides)
|
cfg.setConnectTimeout(overrides)
|
||||||
cfg.setTLSTimeout(overrides)
|
cfg.setTLSTimeout(overrides)
|
||||||
|
|
|
@ -71,7 +71,7 @@ ingress:
|
||||||
proxyPort: 200
|
proxyPort: 200
|
||||||
proxyType: ""
|
proxyType: ""
|
||||||
`
|
`
|
||||||
ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML))
|
ing, err := ParseIngress(MustReadIngress(rulesYAML))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
@ -144,7 +144,7 @@ ingress:
|
||||||
proxyPort: 200
|
proxyPort: 200
|
||||||
proxyType: ""
|
proxyType: ""
|
||||||
`
|
`
|
||||||
ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML))
|
ing, err := ParseIngress(MustReadIngress(rulesYAML))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,72 +1,103 @@
|
||||||
package ingress
|
package ingress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/hello"
|
"github.com/cloudflare/cloudflared/hello"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/socks"
|
"github.com/cloudflare/cloudflared/socks"
|
||||||
|
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
|
gws "github.com/gorilla/websocket"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OriginService is something a tunnel can proxy traffic to.
|
// OriginService is something a tunnel can proxy traffic to.
|
||||||
type OriginService interface {
|
type OriginService interface {
|
||||||
Address() string
|
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
|
||||||
|
http.RoundTripper
|
||||||
|
String() string
|
||||||
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
|
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
|
||||||
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
|
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
|
||||||
// starting the origin service.
|
// starting the origin service.
|
||||||
Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error
|
start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error
|
||||||
String() string
|
|
||||||
// RewriteOriginURL modifies the HTTP request from cloudflared to the origin, so that it apply
|
|
||||||
// this particular type of origin service's specific routing logic.
|
|
||||||
RewriteOriginURL(*url.URL)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnixSocketPath is an OriginService representing a unix socket (which accepts HTTP)
|
// unixSocketPath is an OriginService representing a unix socket (which accepts HTTP)
|
||||||
type UnixSocketPath string
|
type unixSocketPath struct {
|
||||||
|
path string
|
||||||
func (o UnixSocketPath) Address() string {
|
transport *http.Transport
|
||||||
return string(o)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o UnixSocketPath) String() string {
|
func (o *unixSocketPath) String() string {
|
||||||
return "unix socket: " + string(o)
|
return "unix socket: " + o.path
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o UnixSocketPath) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
func (o *unixSocketPath) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||||
|
transport, err := newHTTPTransport(o, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
o.transport = transport
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o UnixSocketPath) RewriteOriginURL(u *url.URL) {
|
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
// No changes necessary because the origin request URL isn't used.
|
return o.transport.RoundTrip(req)
|
||||||
// Instead, HTTPTransport's dial is already configured to address the unix socket.
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// URL is an OriginService listening on a TCP address
|
func (o *unixSocketPath) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||||
type URL struct {
|
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
|
||||||
|
return d.Dial(url, headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// localService is an OriginService listening on a TCP/IP address the user's origin can route to.
|
||||||
|
type localService struct {
|
||||||
// The URL for the user's origin service
|
// The URL for the user's origin service
|
||||||
RootURL *url.URL
|
RootURL *url.URL
|
||||||
// The URL that cloudflared should send requests to.
|
// The URL that cloudflared should send requests to.
|
||||||
// If this origin requires starting a proxy, this is the proxy's address,
|
// If this origin requires starting a proxy, this is the proxy's address,
|
||||||
// and that proxy points to RootURL. Otherwise, this is equal to RootURL.
|
// and that proxy points to RootURL. Otherwise, this is equal to RootURL.
|
||||||
URL *url.URL
|
URL *url.URL
|
||||||
|
transport *http.Transport
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *URL) Address() string {
|
func (o *localService) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||||
|
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
|
||||||
|
return d.Dial(url, headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *localService) address() string {
|
||||||
return o.URL.String()
|
return o.URL.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *URL) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
func (o *localService) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||||
staticHost := o.staticHost()
|
transport, err := newHTTPTransport(o, cfg)
|
||||||
if !originRequiresProxy(staticHost, cfg) {
|
if err != nil {
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
o.transport = transport
|
||||||
|
|
||||||
|
// Start a proxy if one is needed
|
||||||
|
staticHost := o.staticHost()
|
||||||
|
if originRequiresProxy(staticHost, cfg) {
|
||||||
|
if err := o.startProxy(staticHost, wg, log, shutdownC, errC, cfg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *localService) startProxy(staticHost string, wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||||
|
|
||||||
// Start a listener for the proxy
|
// Start a listener for the proxy
|
||||||
proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort)))
|
proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort)))
|
||||||
|
@ -111,16 +142,18 @@ func (o *URL) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan str
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *URL) String() string {
|
func (o *localService) String() string {
|
||||||
return o.Address()
|
return o.address()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *URL) RewriteOriginURL(u *url.URL) {
|
func (o *localService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
u.Host = o.URL.Host
|
// Rewrite the request URL so that it goes to the origin service.
|
||||||
u.Scheme = o.URL.Scheme
|
req.URL.Host = o.URL.Host
|
||||||
|
req.URL.Scheme = o.URL.Scheme
|
||||||
|
return o.transport.RoundTrip(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *URL) staticHost() string {
|
func (o *localService) staticHost() string {
|
||||||
|
|
||||||
addPortIfMissing := func(uri *url.URL, port int) string {
|
addPortIfMissing := func(uri *url.URL, port int) string {
|
||||||
if uri.Port() != "" {
|
if uri.Port() != "" {
|
||||||
|
@ -143,21 +176,24 @@ func (o *URL) staticHost() string {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HelloWorld is the built-in Hello World service. Used for testing and experimenting with cloudflared.
|
// HelloWorld is an OriginService for the built-in Hello World server.
|
||||||
type HelloWorld struct {
|
// Users only use this for testing and experimenting with cloudflared.
|
||||||
server net.Listener
|
type helloWorld struct {
|
||||||
|
server net.Listener
|
||||||
|
transport *http.Transport
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *HelloWorld) Address() string {
|
func (o *helloWorld) String() string {
|
||||||
return o.server.Addr().String()
|
return "Hello World test origin"
|
||||||
}
|
|
||||||
|
|
||||||
func (o *HelloWorld) String() string {
|
|
||||||
return "Hello World static HTML service"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts a HelloWorld server and stores its address in the Service receiver.
|
// Start starts a HelloWorld server and stores its address in the Service receiver.
|
||||||
func (o *HelloWorld) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
func (o *helloWorld) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||||
|
transport, err := newHTTPTransport(o, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
o.transport = transport
|
||||||
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
|
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "Cannot start Hello World Server")
|
return errors.Wrap(err, "Cannot start Hello World Server")
|
||||||
|
@ -171,11 +207,63 @@ func (o *HelloWorld) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-c
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *HelloWorld) RewriteOriginURL(u *url.URL) {
|
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
u.Host = o.Address()
|
// Rewrite the request URL so that it goes to the Hello World server.
|
||||||
u.Scheme = "https"
|
req.URL.Host = o.server.Addr().String()
|
||||||
|
req.URL.Scheme = "https"
|
||||||
|
return o.transport.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *helloWorld) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||||
|
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
|
||||||
|
return d.Dial(url, headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {
|
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {
|
||||||
return staticHost != "" || cfg.BastionMode
|
return staticHost != "" || cfg.BastionMode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newHTTPTransport(service OriginService, cfg OriginRequestConfig) (*http.Transport, error) {
|
||||||
|
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error loading cert pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
httpTransport := http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
MaxIdleConns: cfg.KeepAliveConnections,
|
||||||
|
MaxIdleConnsPerHost: cfg.KeepAliveConnections,
|
||||||
|
IdleConnTimeout: cfg.KeepAliveTimeout,
|
||||||
|
TLSHandshakeTimeout: cfg.TLSTimeout,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify},
|
||||||
|
}
|
||||||
|
if _, isHelloWorld := service.(*helloWorld); !isHelloWorld && cfg.OriginServerName != "" {
|
||||||
|
httpTransport.TLSClientConfig.ServerName = cfg.OriginServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
Timeout: cfg.ConnectTimeout,
|
||||||
|
KeepAlive: cfg.TCPKeepAlive,
|
||||||
|
}
|
||||||
|
if cfg.NoHappyEyeballs {
|
||||||
|
dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialContext depends on which kind of origin is being used.
|
||||||
|
dialContext := dialer.DialContext
|
||||||
|
switch service := service.(type) {
|
||||||
|
|
||||||
|
// If this origin is a unix socket, enforce network type "unix".
|
||||||
|
case *unixSocketPath:
|
||||||
|
httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
|
return dialContext(ctx, "unix", service.path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, use the regular network config.
|
||||||
|
default:
|
||||||
|
httpTransport.DialContext = dialContext
|
||||||
|
}
|
||||||
|
|
||||||
|
return &httpTransport, nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
package ingress
|
package ingress
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"net/http"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
@ -23,11 +21,6 @@ type Rule struct {
|
||||||
|
|
||||||
// Configure the request cloudflared sends to this specific origin.
|
// Configure the request cloudflared sends to this specific origin.
|
||||||
Config OriginRequestConfig
|
Config OriginRequestConfig
|
||||||
|
|
||||||
// Configures TLS for the cloudflared -> origin request
|
|
||||||
ClientTLSConfig *tls.Config
|
|
||||||
// Configures HTTP for the cloudflared -> origin request
|
|
||||||
HTTPTransport http.RoundTripper
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MultiLineString is for outputting rules in a human-friendly way when Cloudflared
|
// MultiLineString is for outputting rules in a human-friendly way when Cloudflared
|
||||||
|
|
|
@ -698,7 +698,6 @@ func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request,
|
||||||
}
|
}
|
||||||
h.AppendTagHeaders(req)
|
h.AppendTagHeaders(req)
|
||||||
rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
|
rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
|
||||||
rule.Service.RewriteOriginURL(req.URL)
|
|
||||||
return req, rule, nil
|
return req, rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -708,7 +707,11 @@ func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Requ
|
||||||
req.Host = hostHeader
|
req.Host = hostHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, response, err := websocket.ClientConnect(req, rule.ClientTLSConfig)
|
dialler, ok := rule.Service.(websocket.Dialler)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("Websockets aren't supported by the origin service '%s'", rule.Service)
|
||||||
|
}
|
||||||
|
conn, response, err := websocket.ClientConnect(req, dialler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -742,7 +745,7 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request,
|
||||||
req.Host = hostHeader
|
req.Host = hostHeader
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := rule.HTTPTransport.RoundTrip(req)
|
response, err := rule.Service.RoundTrip(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "Error proxying request to origin")
|
return nil, errors.Wrap(err, "Error proxying request to origin")
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,15 +69,31 @@ func IsWebSocketUpgrade(req *http.Request) bool {
|
||||||
return websocket.IsWebSocketUpgrade(req)
|
return websocket.IsWebSocketUpgrade(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dialler is something that can proxy websocket requests.
|
||||||
|
type Dialler interface {
|
||||||
|
Dial(url string, headers http.Header) (*websocket.Conn, *http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type defaultDialler struct {
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dd *defaultDialler) Dial(url string, header http.Header) (*websocket.Conn, *http.Response, error) {
|
||||||
|
d := &websocket.Dialer{TLSClientConfig: dd.tlsConfig}
|
||||||
|
return d.Dial(url, header)
|
||||||
|
}
|
||||||
|
|
||||||
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
|
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
|
||||||
// the connection. The response body may not contain the entire response and does
|
// the connection. The response body may not contain the entire response and does
|
||||||
// not need to be closed by the application.
|
// not need to be closed by the application.
|
||||||
func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.Conn, *http.Response, error) {
|
func ClientConnect(req *http.Request, dialler Dialler) (*websocket.Conn, *http.Response, error) {
|
||||||
req.URL.Scheme = changeRequestScheme(req)
|
req.URL.Scheme = changeRequestScheme(req)
|
||||||
wsHeaders := websocketHeaders(req)
|
wsHeaders := websocketHeaders(req)
|
||||||
|
|
||||||
d := &websocket.Dialer{TLSClientConfig: tlsClientConfig}
|
if dialler == nil {
|
||||||
conn, response, err := d.Dial(req.URL.String(), wsHeaders)
|
dialler = new(defaultDialler)
|
||||||
|
}
|
||||||
|
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, response, err
|
return nil, response, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,7 +77,8 @@ func TestServe(t *testing.T) {
|
||||||
|
|
||||||
tlsConfig := websocketClientTLSConfig(t)
|
tlsConfig := websocketClientTLSConfig(t)
|
||||||
assert.NotNil(t, tlsConfig)
|
assert.NotNil(t, tlsConfig)
|
||||||
conn, resp, err := ClientConnect(req, tlsConfig)
|
d := defaultDialler{tlsConfig: tlsConfig}
|
||||||
|
conn, resp, err := ClientConnect(req, &d)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
|
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue