diff --git a/cmd/cloudflared/config/configuration.go b/cmd/cloudflared/config/configuration.go index 035442be..af43996b 100644 --- a/cmd/cloudflared/config/configuration.go +++ b/cmd/cloudflared/config/configuration.go @@ -226,10 +226,15 @@ type OriginRequestConfig struct { type Configuration struct { TunnelID string `yaml:"tunnel"` Ingress []UnvalidatedIngressRule + WarpRouting WarpRoutingConfig `yaml:"warp-routing"` OriginRequest OriginRequestConfig `yaml:"originRequest"` sourceFile string } +type WarpRoutingConfig struct { + Enabled bool `yaml:"enabled"` +} + type configFileSettings struct { Configuration `yaml:",inline"` // older settings will be aggregated into the generic map, should be read via cli.Context diff --git a/cmd/cloudflared/config/configuration_test.go b/cmd/cloudflared/config/configuration_test.go index 37357539..355224fb 100644 --- a/cmd/cloudflared/config/configuration_test.go +++ b/cmd/cloudflared/config/configuration_test.go @@ -20,6 +20,9 @@ func TestConfigFileSettings(t *testing.T) { Path: "", Service: "https://localhost:8001", } + warpRouting = WarpRoutingConfig{ + Enabled: true, + } ) rawYAML := ` tunnel: config-file-test @@ -29,6 +32,8 @@ ingress: service: https://localhost:8000 - hostname: "*" service: https://localhost:8001 +warp-routing: + enabled: true retries: 5 grace-period: 30s percentage: 3.14 @@ -47,6 +52,7 @@ counters: assert.Equal(t, "config-file-test", config.TunnelID) assert.Equal(t, firstIngress, config.Ingress[0]) assert.Equal(t, secondIngress, config.Ingress[1]) + assert.Equal(t, warpRouting, config.WarpRouting) retries, err := config.Int("retries") assert.NoError(t, err) @@ -73,4 +79,5 @@ counters: assert.NoError(t, err) assert.Equal(t, 123, counters[0]) assert.Equal(t, 456, counters[1]) + } diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 9c97d350..5cbb7a0b 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -195,6 +195,7 @@ func prepareTunnelConfig( ingressRules ingress.Ingress classicTunnel *connection.ClassicTunnelConfig ) + cfg := config.GetConfiguration() if isNamedTunnel { clientUUID, err := uuid.NewRandom() if err != nil { @@ -206,7 +207,7 @@ func prepareTunnelConfig( Version: version, Arch: fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch), } - ingressRules, err = ingress.ParseIngress(config.GetConfiguration()) + ingressRules, err = ingress.ParseIngress(cfg) if err != nil && err != ingress.ErrNoIngressRules { return nil, ingress.Ingress{}, err } @@ -245,7 +246,12 @@ func prepareTunnelConfig( edgeTLSConfigs[p] = edgeTLSConfig } - originProxy := origin.NewOriginProxy(ingressRules, tags, log) + var warpRoutingService *ingress.WarpRoutingService + if isWarpRoutingEnabled(cfg.WarpRouting, isNamedTunnel, protocolSelector.Current()) { + warpRoutingService = ingress.NewWarpRoutingService() + } + + originProxy := origin.NewOriginProxy(ingressRules, warpRoutingService, tags, log) connectionConfig := &connection.Config{ OriginProxy: originProxy, GracePeriod: c.Duration("grace-period"), @@ -286,6 +292,10 @@ func prepareTunnelConfig( }, ingressRules, nil } +func isWarpRoutingEnabled(warpConfig config.WarpRoutingConfig, isNamedTunnel bool, protocol connection.Protocol) bool { + return warpConfig.Enabled && isNamedTunnel && protocol == connection.HTTP2 +} + func isRunningFromTerminal() bool { return terminal.IsTerminal(int(os.Stdout.Fd())) } diff --git a/ingress/ingress.go b/ingress/ingress.go index 3e2f8f4c..396d36fb 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -25,8 +25,9 @@ var ( ) const ( - ServiceBastion = "bastion" - ServiceTeamnet = "teamnet-proxy" + ServiceBridge = "bridge service" + ServiceBastion = "bastion" + ServiceWarpRouting = "warp-routing" ) // FindMatchingRule returns the index of the Ingress Rule which matches the given @@ -43,6 +44,7 @@ func (ing Ingress) FindMatchingRule(hostname, path string) (*Rule, int) { return &rule, i } } + i := len(ing.Rules) - 1 return &ing.Rules[i], i } @@ -89,13 +91,24 @@ func NewSingleOrigin(c *cli.Context, allowURLFromArgs bool) (Ingress, error) { return ing, err } +// WarpRoutingService starts a tcp stream between the origin and requests from +// warp clients. +type WarpRoutingService struct { + Proxy StreamBasedOriginProxy +} + +func NewWarpRoutingService() *WarpRoutingService { + warpRoutingService := newBridgeService(DefaultStreamHandler, ServiceWarpRouting) + return &WarpRoutingService{Proxy: warpRoutingService} +} + // Get a single origin service from the CLI/config. func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originService, error) { if c.IsSet("hello-world") { return new(helloWorld), nil } if c.IsSet(config.BastionFlag) { - return newBridgeService(nil), nil + return newBridgeService(nil, ServiceBastion), nil } if c.IsSet("url") { originURL, err := config.ValidateUrl(c, allowURLFromArgs) @@ -169,9 +182,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon // overwrite the localService.URL field when `start` is called. So, // leave the URL field empty for now. cfg.BastionMode = true - service = newBridgeService(nil) - } else if r.Service == ServiceTeamnet { - service = newBridgeService(DefaultStreamHandler) + service = newBridgeService(nil, ServiceBastion) } else { // Validate URL services u, err := url.Parse(r.Service) diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index 6a143b1b..911e24bf 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -3,6 +3,7 @@ package ingress import ( "flag" "fmt" + "net/http" "net/url" "regexp" "testing" @@ -315,7 +316,7 @@ ingress: want: []Rule{ { Hostname: "bastion.foo.com", - Service: newBridgeService(nil), + Service: newBridgeService(nil, ServiceBastion), Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), }, { @@ -335,7 +336,7 @@ ingress: want: []Rule{ { Hostname: "bastion.foo.com", - Service: newBridgeService(nil), + Service: newBridgeService(nil, ServiceBastion), Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}), }, { @@ -463,6 +464,7 @@ func TestFindMatchingRule(t *testing.T) { tests := []struct { host string path string + req *http.Request wantRuleIndex int }{ { @@ -497,9 +499,9 @@ func TestFindMatchingRule(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { _, ruleIndex := ingress.FindMatchingRule(test.host, test.path) - assert.Equal(t, test.wantRuleIndex, ruleIndex, fmt.Sprintf("Expect host=%s, path=%s to match rule %d, got %d", test.host, test.path, test.wantRuleIndex, i)) + assert.Equal(t, test.wantRuleIndex, ruleIndex, fmt.Sprintf("Expect host=%s, path=%s to match rule %d, got %d", test.host, test.path, test.wantRuleIndex, ruleIndex)) } } @@ -561,6 +563,7 @@ ingress: if err != nil { b.Error(err) } + for n := 0; n < b.N; n++ { ing.FindMatchingRule("tunnel1.example.com", "") ing.FindMatchingRule("tunnel2.example.com", "") diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index 16534eac..e2238a40 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -91,7 +91,7 @@ func TestBridgeServiceDestination(t *testing.T) { wantErr: true, }, } - s := newBridgeService(nil) + s := newBridgeService(nil, ServiceBastion) for _, test := range tests { r := &http.Request{ Header: test.header, diff --git a/ingress/origin_service.go b/ingress/origin_service.go index ba8892df..ccf1e008 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -78,20 +78,22 @@ func (o *httpService) String() string { // bridgeService is like a jump host, the destination is specified by the client type bridgeService struct { - client *tcpClient + client *tcpClient + serviceName string } // if streamHandler is nil, a default one is set. -func newBridgeService(streamHandler streamHandlerFunc) *bridgeService { +func newBridgeService(streamHandler streamHandlerFunc, serviceName string) *bridgeService { return &bridgeService{ client: &tcpClient{ streamHandler: streamHandler, }, + serviceName: serviceName, } } func (o *bridgeService) String() string { - return "bridge service" + return ServiceBridge + ":" + o.serviceName } func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { diff --git a/origin/proxy.go b/origin/proxy.go index aaafa9e6..42e771d0 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -24,14 +24,21 @@ const ( type proxy struct { ingressRules ingress.Ingress + warpRouting *ingress.WarpRoutingService tags []tunnelpogs.Tag log *zerolog.Logger bufferPool *buffer.Pool } -func NewOriginProxy(ingressRules ingress.Ingress, tags []tunnelpogs.Tag, log *zerolog.Logger) connection.OriginProxy { +func NewOriginProxy( + ingressRules ingress.Ingress, + warpRouting *ingress.WarpRoutingService, + tags []tunnelpogs.Tag, + log *zerolog.Logger) connection.OriginProxy { + return &proxy{ ingressRules: ingressRules, + warpRouting: warpRouting, tags: tags, log: log, bufferPool: buffer.NewPool(512 * 1024), @@ -46,6 +53,22 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn lbProbe := isLBProbeRequest(req) p.appendTagHeaders(req) + if sourceConnectionType == connection.TypeTCP { + if p.warpRouting == nil { + err := errors.New(`cloudflared received a request from Warp client, but your configuration has disabled ingress from Warp clients. To enable this, set "warp-routing:\n\t enabled: true" in your config.yaml`) + p.log.Error().Msg(err.Error()) + return err + } + resp, err := p.handleProxyConn(w, req, nil, p.warpRouting.Proxy) + if err != nil { + p.logRequestError(err, cfRay, ingress.ServiceWarpRouting) + w.WriteErrorResponse() + return err + } + p.logOriginResponse(resp, cfRay, lbProbe, ingress.ServiceWarpRouting) + return nil + } + rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path) p.logRequest(req, cfRay, lbProbe, ruleNum) @@ -66,13 +89,37 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn respHeader = websocket.NewResponseHeader(req) } - connClosedChan := make(chan struct{}) - err := p.proxyConnection(connClosedChan, w, req, rule) + if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { + req.Header.Set("Host", hostHeader) + req.Host = hostHeader + } + + connectionProxy, ok := rule.Service.(ingress.StreamBasedOriginProxy) + if !ok { + p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service) + return fmt.Errorf("Not a connection-oriented service") + } + resp, err := p.handleProxyConn(w, req, respHeader, connectionProxy) if err != nil { p.logErrorAndWriteResponse(w, err, cfRay, ruleNum) return err } + p.logOriginResponse(resp, cfRay, lbProbe, ruleNum) + return nil +} + +func (p *proxy) handleProxyConn( + w connection.ResponseWriter, + req *http.Request, + respHeader http.Header, + connectionProxy ingress.StreamBasedOriginProxy) (*http.Response, error) { + connClosedChan := make(chan struct{}) + err := p.proxyConnection(connClosedChan, w, req, connectionProxy) + if err != nil { + return nil, err + } + status := http.StatusSwitchingProtocols resp := &http.Response{ Status: http.StatusText(status), @@ -83,9 +130,8 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn w.WriteRespHeaders(http.StatusSwitchingProtocols, nil) <-connClosedChan + return resp, nil - p.logOriginResponse(resp, cfRay, lbProbe, ruleNum) - return nil } func (p *proxy) logErrorAndWriteResponse(w connection.ResponseWriter, err error, cfRay string, ruleNum int) { @@ -141,18 +187,8 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule * } func (p *proxy) proxyConnection(connClosedChan chan struct{}, - conn io.ReadWriter, req *http.Request, rule *ingress.Rule) error { - if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" { - req.Header.Set("Host", hostHeader) - req.Host = hostHeader - } - - connectionService, ok := rule.Service.(ingress.StreamBasedOriginProxy) - if !ok { - p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service) - return fmt.Errorf("Not a connection-oriented service") - } - originConn, err := connectionService.EstablishConnection(req) + conn io.ReadWriter, req *http.Request, connectionProxy ingress.StreamBasedOriginProxy) error { + originConn, err := connectionProxy.EstablishConnection(req) if err != nil { return err } @@ -190,7 +226,7 @@ func (p *proxy) appendTagHeaders(r *http.Request) { } } -func (p *proxy) logRequest(r *http.Request, cfRay string, lbProbe bool, ruleNum int) { +func (p *proxy) logRequest(r *http.Request, cfRay string, lbProbe bool, rule interface{}) { if cfRay != "" { p.log.Debug().Msgf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto) } else if lbProbe { @@ -199,7 +235,7 @@ func (p *proxy) logRequest(r *http.Request, cfRay string, lbProbe bool, ruleNum p.log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto) } p.log.Debug().Msgf("CF-RAY: %s Request Headers %+v", cfRay, r.Header) - p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %d", cfRay, ruleNum) + p.log.Debug().Msgf("CF-RAY: %s Serving with ingress rule %v", cfRay, rule) if contentLen := r.ContentLength; contentLen == -1 { p.log.Debug().Msgf("CF-RAY: %s Request Content length unknown", cfRay) @@ -208,14 +244,14 @@ func (p *proxy) logRequest(r *http.Request, cfRay string, lbProbe bool, ruleNum } } -func (p *proxy) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, ruleNum int) { +func (p *proxy) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, rule interface{}) { responseByCode.WithLabelValues(strconv.Itoa(r.StatusCode)).Inc() if cfRay != "" { - p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, ruleNum) + p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", cfRay, r.Status, rule) } else if lbProbe { p.log.Debug().Msgf("Response to Load Balancer health check %s", r.Status) } else { - p.log.Debug().Msgf("Status: %s served by ingress %d", r.Status, ruleNum) + p.log.Debug().Msgf("Status: %s served by ingress %v", r.Status, rule) } p.log.Debug().Msgf("CF-RAY: %s Response Headers %+v", cfRay, r.Header) @@ -226,14 +262,13 @@ func (p *proxy) logOriginResponse(r *http.Response, cfRay string, lbProbe bool, } } -func (p *proxy) logRequestError(err error, cfRay string, ruleNum int) { +func (p *proxy) logRequestError(err error, cfRay string, rule interface{}) { requestErrors.Inc() if cfRay != "" { - p.log.Error().Msgf("CF-RAY: %s Proxying to ingress %d error: %v", cfRay, ruleNum, err) + p.log.Error().Msgf("CF-RAY: %s Proxying to ingress %v error: %v", cfRay, rule, err) } else { - p.log.Error().Msgf("Proxying to ingress %d error: %v", ruleNum, err) + p.log.Error().Msgf("Proxying to ingress %v error: %v", rule, err) } - } func findCfRayHeader(req *http.Request) string { diff --git a/origin/proxy_test.go b/origin/proxy_test.go index fa31d15b..5078217d 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -5,7 +5,6 @@ import ( "context" "flag" "fmt" - "github.com/cloudflare/cloudflared/logger" "io" "net" "net/http" @@ -14,6 +13,8 @@ import ( "testing" "time" + "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/cmd/cloudflared/config" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/h2mux" @@ -30,7 +31,8 @@ import ( ) var ( - testTags = []tunnelpogs.Tag(nil) + testTags = []tunnelpogs.Tag(nil) + unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil) ) type mockHTTPRespWriter struct { @@ -129,7 +131,7 @@ func TestProxySingleOrigin(t *testing.T) { errC := make(chan error) require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC)) - proxy := NewOriginProxy(ingressRule, testTags, &log) + proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, &log) t.Run("testProxyHTTP", testProxyHTTP(t, proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(t, proxy)) t.Run("testProxySSE", testProxySSE(t, proxy)) @@ -262,7 +264,7 @@ func TestProxyMultipleOrigins(t *testing.T) { var wg sync.WaitGroup require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC)) - proxy := NewOriginProxy(ingress, testTags, &log) + proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log) tests := []struct { url string @@ -340,7 +342,7 @@ func TestProxyError(t *testing.T) { log := zerolog.Nop() - proxy := NewOriginProxy(ingress, testTags, &log) + proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log) respWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) @@ -372,7 +374,7 @@ func TestProxyBastionMode(t *testing.T) { ingressRule.StartOrigins(&wg, log, ctx.Done(), errC) - proxy := NewOriginProxy(ingressRule, testTags, log) + proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, log) t.Run("testBastionWebsocket", testBastionWebsocket(proxy)) cancel() @@ -429,9 +431,9 @@ func TestTCPStream(t *testing.T) { ingressConfig := &config.Configuration{ Ingress: []config.UnvalidatedIngressRule{ - config.UnvalidatedIngressRule{ + { Hostname: "*", - Service: ingress.ServiceTeamnet, + Service: "bastion", }, }, } @@ -442,11 +444,10 @@ func TestTCPStream(t *testing.T) { errC := make(chan error) ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) - proxy := NewOriginProxy(ingressRule, testTags, logger) + proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger) t.Run("testTCPStream", testTCPStreamProxy(proxy)) cancel() - wg.Wait() } type mockTCPRespWriter struct { diff --git a/vendor/github.com/rs/zerolog/go.mod b/vendor/github.com/rs/zerolog/go.mod index 340ed40e..acba1950 100644 --- a/vendor/github.com/rs/zerolog/go.mod +++ b/vendor/github.com/rs/zerolog/go.mod @@ -1,5 +1,7 @@ module github.com/rs/zerolog +go 1.15 + require ( github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e github.com/pkg/errors v0.8.1 diff --git a/websocket/connection.go b/websocket/connection.go index 9616764e..8098fcb1 100644 --- a/websocket/connection.go +++ b/websocket/connection.go @@ -2,10 +2,11 @@ package websocket import ( "context" - "github.com/rs/zerolog" "io" "time" + "github.com/rs/zerolog" + gobwas "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" "github.com/gorilla/websocket"