TUN-5698: Make ingress rules and warp routing dynamically configurable

This commit is contained in:
cthuang 2022-02-11 10:49:06 +00:00
parent 0571210374
commit d68ff390ca
21 changed files with 978 additions and 175 deletions

View File

@ -31,6 +31,7 @@ import (
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/metrics" "github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/orchestration"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tlsconfig"
@ -353,7 +354,8 @@ func StartServer(
errC <- metrics.ServeMetrics(metricsListener, ctx.Done(), readinessServer, quickTunnelURL, log) errC <- metrics.ServeMetrics(metricsListener, ctx.Done(), readinessServer, quickTunnelURL, log)
}() }()
if err := dynamicConfig.Ingress.StartOrigins(&wg, log, ctx.Done(), errC); err != nil { orchestrator, err := orchestration.NewOrchestrator(ctx, dynamicConfig, tunnelConfig.Tags, tunnelConfig.Log)
if err != nil {
return err return err
} }
@ -369,7 +371,7 @@ func StartServer(
wg.Done() wg.Done()
log.Info().Msg("Tunnel server stopped") log.Info().Msg("Tunnel server stopped")
}() }()
errC <- supervisor.StartTunnelDaemon(ctx, tunnelConfig, dynamicConfig, connectedSignal, reconnectCh, graceShutdownC) errC <- supervisor.StartTunnelDaemon(ctx, tunnelConfig, orchestrator, connectedSignal, reconnectCh, graceShutdownC)
}() }()
if isUIEnabled { if isUIEnabled {

View File

@ -23,6 +23,7 @@ import (
"github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/orchestration"
"github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tlsconfig"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -153,7 +154,7 @@ func prepareTunnelConfig(
log, logTransport *zerolog.Logger, log, logTransport *zerolog.Logger,
observer *connection.Observer, observer *connection.Observer,
namedTunnel *connection.NamedTunnelProperties, namedTunnel *connection.NamedTunnelProperties,
) (*supervisor.TunnelConfig, *supervisor.DynamicConfig, error) { ) (*supervisor.TunnelConfig, *orchestration.Config, error) {
isNamedTunnel := namedTunnel != nil isNamedTunnel := namedTunnel != nil
configHostname := c.String("hostname") configHostname := c.String("hostname")
@ -292,7 +293,7 @@ func prepareTunnelConfig(
ProtocolSelector: protocolSelector, ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs, EdgeTLSConfigs: edgeTLSConfigs,
} }
dynamicConfig := &supervisor.DynamicConfig{ dynamicConfig := &orchestration.Config{
Ingress: &ingressRules, Ingress: &ingressRules,
WarpRoutingEnabled: warpRoutingEnabled, WarpRoutingEnabled: warpRoutingEnabled,
} }

View File

@ -25,9 +25,9 @@ const (
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
type ConfigManager interface { type Orchestrator interface {
Update(version int32, config []byte) *pogs.UpdateConfigurationResponse UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse
GetOriginProxy() OriginProxy GetOriginProxy() (OriginProxy, error)
} }
type NamedTunnelProperties struct { type NamedTunnelProperties struct {

View File

@ -6,14 +6,12 @@ import (
"io" "io"
"math/rand" "math/rand"
"net/http" "net/http"
"net/url"
"testing" "testing"
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/ingress"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
) )
@ -24,15 +22,10 @@ const (
) )
var ( var (
unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil) testOrchestrator = &mockOrchestrator{
testConfigManager = &mockConfigManager{
originProxy: &mockOriginProxy{}, originProxy: &mockOriginProxy{},
} }
log = zerolog.Nop() log = zerolog.Nop()
testOriginURL = &url.URL{
Scheme: "https",
Host: "connectiontest.argotunnel.com",
}
testLargeResp = make([]byte, largeFileSize) testLargeResp = make([]byte, largeFileSize)
) )
@ -44,18 +37,18 @@ type testRequest struct {
isProxyError bool isProxyError bool
} }
type mockConfigManager struct { type mockOrchestrator struct {
originProxy OriginProxy originProxy OriginProxy
} }
func (*mockConfigManager) Update(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { func (*mockOrchestrator) UpdateConfig(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
return &tunnelpogs.UpdateConfigurationResponse{ return &tunnelpogs.UpdateConfigurationResponse{
LastAppliedVersion: version, LastAppliedVersion: version,
} }
} }
func (mcr *mockConfigManager) GetOriginProxy() OriginProxy { func (mcr *mockOrchestrator) GetOriginProxy() (OriginProxy, error) {
return mcr.originProxy return mcr.originProxy, nil
} }
type mockOriginProxy struct{} type mockOriginProxy struct{}

View File

@ -22,10 +22,10 @@ const (
) )
type h2muxConnection struct { type h2muxConnection struct {
configManager ConfigManager orchestrator Orchestrator
gracePeriod time.Duration gracePeriod time.Duration
muxerConfig *MuxerConfig muxerConfig *MuxerConfig
muxer *h2mux.Muxer muxer *h2mux.Muxer
// connectionID is only used by metrics, and prometheus requires labels to be string // connectionID is only used by metrics, and prometheus requires labels to be string
connIndexStr string connIndexStr string
connIndex uint8 connIndex uint8
@ -61,7 +61,7 @@ func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, log *zerolog.Lo
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
func NewH2muxConnection( func NewH2muxConnection(
configManager ConfigManager, orchestrator Orchestrator,
gracePeriod time.Duration, gracePeriod time.Duration,
muxerConfig *MuxerConfig, muxerConfig *MuxerConfig,
edgeConn net.Conn, edgeConn net.Conn,
@ -70,7 +70,7 @@ func NewH2muxConnection(
gracefulShutdownC <-chan struct{}, gracefulShutdownC <-chan struct{},
) (*h2muxConnection, error, bool) { ) (*h2muxConnection, error, bool) {
h := &h2muxConnection{ h := &h2muxConnection{
configManager: configManager, orchestrator: orchestrator,
gracePeriod: gracePeriod, gracePeriod: gracePeriod,
muxerConfig: muxerConfig, muxerConfig: muxerConfig,
connIndexStr: uint8ToString(connIndex), connIndexStr: uint8ToString(connIndex),
@ -227,7 +227,13 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
sourceConnectionType = TypeWebsocket sourceConnectionType = TypeWebsocket
} }
err := h.configManager.GetOriginProxy().ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket) originProxy, err := h.orchestrator.GetOriginProxy()
if err != nil {
respWriter.WriteErrorResponse()
return err
}
err = originProxy.ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket)
if err != nil { if err != nil {
respWriter.WriteErrorResponse() respWriter.WriteErrorResponse()
} }

View File

@ -48,7 +48,7 @@ func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) {
}() }()
var connIndex = uint8(0) var connIndex = uint8(0)
testObserver := NewObserver(&log, &log, false) testObserver := NewObserver(&log, &log, false)
h2muxConn, err, _ := NewH2muxConnection(testConfigManager, testGracePeriod, testMuxerConfig, originConn, connIndex, testObserver, nil) h2muxConn, err, _ := NewH2muxConnection(testOrchestrator, testGracePeriod, testMuxerConfig, originConn, connIndex, testObserver, nil)
require.NoError(t, err) require.NoError(t, err)
return h2muxConn, <-edgeMuxChan return h2muxConn, <-edgeMuxChan
} }

View File

@ -30,12 +30,12 @@ var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
// HTTP2Connection represents a net.Conn that uses HTTP2 frames to proxy traffic from the edge to cloudflared on the // HTTP2Connection represents a net.Conn that uses HTTP2 frames to proxy traffic from the edge to cloudflared on the
// origin. // origin.
type HTTP2Connection struct { type HTTP2Connection struct {
conn net.Conn conn net.Conn
server *http2.Server server *http2.Server
configManager ConfigManager orchestrator Orchestrator
connOptions *tunnelpogs.ConnectionOptions connOptions *tunnelpogs.ConnectionOptions
observer *Observer observer *Observer
connIndex uint8 connIndex uint8
// newRPCClientFunc allows us to mock RPCs during testing // newRPCClientFunc allows us to mock RPCs during testing
newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient newRPCClientFunc func(context.Context, io.ReadWriteCloser, *zerolog.Logger) NamedTunnelRPCClient
@ -49,7 +49,7 @@ type HTTP2Connection struct {
// NewHTTP2Connection returns a new instance of HTTP2Connection. // NewHTTP2Connection returns a new instance of HTTP2Connection.
func NewHTTP2Connection( func NewHTTP2Connection(
conn net.Conn, conn net.Conn,
configManager ConfigManager, orchestrator Orchestrator,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
observer *Observer, observer *Observer,
connIndex uint8, connIndex uint8,
@ -61,7 +61,7 @@ func NewHTTP2Connection(
server: &http2.Server{ server: &http2.Server{
MaxConcurrentStreams: MaxConcurrentStreams, MaxConcurrentStreams: MaxConcurrentStreams,
}, },
configManager: configManager, orchestrator: orchestrator,
connOptions: connOptions, connOptions: connOptions,
observer: observer, observer: observer,
connIndex: connIndex, connIndex: connIndex,
@ -106,6 +106,12 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
originProxy, err := c.orchestrator.GetOriginProxy()
if err != nil {
c.observer.log.Error().Msg(err.Error())
return
}
switch connType { switch connType {
case TypeControlStream: case TypeControlStream:
if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions); err != nil { if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions); err != nil {
@ -116,7 +122,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case TypeWebsocket, TypeHTTP: case TypeWebsocket, TypeHTTP:
stripWebsocketUpgradeHeader(r) stripWebsocketUpgradeHeader(r)
if err := c.configManager.GetOriginProxy().ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil { if err := originProxy.ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil {
err := fmt.Errorf("Failed to proxy HTTP: %w", err) err := fmt.Errorf("Failed to proxy HTTP: %w", err)
c.log.Error().Err(err) c.log.Error().Err(err)
respWriter.WriteErrorResponse() respWriter.WriteErrorResponse()
@ -131,7 +137,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
rws := NewHTTPResponseReadWriterAcker(respWriter, r) rws := NewHTTPResponseReadWriterAcker(respWriter, r)
if err := c.configManager.GetOriginProxy().ProxyTCP(r.Context(), rws, &TCPRequest{ if err := originProxy.ProxyTCP(r.Context(), rws, &TCPRequest{
Dest: host, Dest: host,
CFRay: FindCfRayHeader(r), CFRay: FindCfRayHeader(r),
LBProbe: IsLBProbeRequest(r), LBProbe: IsLBProbeRequest(r),

View File

@ -44,7 +44,7 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
return NewHTTP2Connection( return NewHTTP2Connection(
cfdConn, cfdConn,
// OriginProxy is set in testConfigManager // OriginProxy is set in testConfigManager
testConfigManager, testOrchestrator,
&pogs.ConnectionOptions{}, &pogs.ConnectionOptions{},
obs, obs,
connIndex, connIndex,

View File

@ -36,7 +36,7 @@ const (
type QUICConnection struct { type QUICConnection struct {
session quic.Session session quic.Session
logger *zerolog.Logger logger *zerolog.Logger
configManager ConfigManager orchestrator Orchestrator
sessionManager datagramsession.Manager sessionManager datagramsession.Manager
controlStreamHandler ControlStreamHandler controlStreamHandler ControlStreamHandler
connOptions *tunnelpogs.ConnectionOptions connOptions *tunnelpogs.ConnectionOptions
@ -47,7 +47,7 @@ func NewQUICConnection(
quicConfig *quic.Config, quicConfig *quic.Config,
edgeAddr net.Addr, edgeAddr net.Addr,
tlsConfig *tls.Config, tlsConfig *tls.Config,
configManager ConfigManager, orchestrator Orchestrator,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler ControlStreamHandler, controlStreamHandler ControlStreamHandler,
logger *zerolog.Logger, logger *zerolog.Logger,
@ -66,7 +66,7 @@ func NewQUICConnection(
return &QUICConnection{ return &QUICConnection{
session: session, session: session,
configManager: configManager, orchestrator: orchestrator,
logger: logger, logger: logger,
sessionManager: sessionManager, sessionManager: sessionManager,
controlStreamHandler: controlStreamHandler, controlStreamHandler: controlStreamHandler,
@ -175,6 +175,10 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream)
return err return err
} }
originProxy, err := q.orchestrator.GetOriginProxy()
if err != nil {
return err
}
switch connectRequest.Type { switch connectRequest.Type {
case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket: case quicpogs.ConnectionTypeHTTP, quicpogs.ConnectionTypeWebsocket:
req, err := buildHTTPRequest(connectRequest, stream) req, err := buildHTTPRequest(connectRequest, stream)
@ -183,10 +187,10 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream)
} }
w := newHTTPResponseAdapter(stream) w := newHTTPResponseAdapter(stream)
return q.configManager.GetOriginProxy().ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket) return originProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket)
case quicpogs.ConnectionTypeTCP: case quicpogs.ConnectionTypeTCP:
rwa := &streamReadWriteAcker{stream} rwa := &streamReadWriteAcker{stream}
return q.configManager.GetOriginProxy().ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest}) return originProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest})
} }
return nil return nil
} }

View File

@ -632,7 +632,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection
testQUICConfig, testQUICConfig,
udpListenerAddr, udpListenerAddr,
tlsClientConfig, tlsClientConfig,
&mockConfigManager{originProxy: &mockOriginProxyWithRequest{}}, &mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}},
&tunnelpogs.ConnectionOptions{}, &tunnelpogs.ConnectionOptions{},
fakeControlStream{}, fakeControlStream{},
&log, &log,

View File

@ -7,7 +7,6 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"sync"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -145,13 +144,11 @@ func (ing Ingress) IsSingleRule() bool {
// StartOrigins will start any origin services managed by cloudflared, e.g. proxy servers or Hello World. // StartOrigins will start any origin services managed by cloudflared, e.g. proxy servers or Hello World.
func (ing Ingress) StartOrigins( func (ing Ingress) StartOrigins(
wg *sync.WaitGroup,
log *zerolog.Logger, log *zerolog.Logger,
shutdownC <-chan struct{}, shutdownC <-chan struct{},
errC chan error,
) error { ) error {
for _, rule := range ing.Rules { for _, rule := range ing.Rules {
if err := rule.Service.start(wg, log, shutdownC, errC, rule.Config); err != nil { if err := rule.Service.start(log, shutdownC, rule.Config); err != nil {
return errors.Wrapf(err, "Error starting local service %s", rule.Service) return errors.Wrapf(err, "Error starting local service %s", rule.Service)
} }
} }

View File

@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"sync"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -132,10 +131,8 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
httpService := &httpService{ httpService := &httpService{
url: originURL, url: originURL,
} }
var wg sync.WaitGroup
shutdownC := make(chan struct{}) shutdownC := make(chan struct{})
errC := make(chan error) require.NoError(t, httpService.start(testLogger, shutdownC, cfg))
require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg))
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
require.NoError(t, err) require.NoError(t, err)
@ -169,10 +166,8 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
httpService := &httpService{ httpService := &httpService{
url: originURL, url: originURL,
} }
var wg sync.WaitGroup
shutdownC := make(chan struct{}) shutdownC := make(chan struct{})
errC := make(chan error) require.NoError(t, httpService.start(testLogger, shutdownC, cfg))
require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg))
// Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header // Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
protos := []string{"https", "http", "dne"} protos := []string{"https", "http", "dne"}

View File

@ -8,7 +8,6 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -20,13 +19,18 @@ import (
"github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tlsconfig"
) )
const (
HelloWorldService = "Hello World test origin"
)
// OriginService is something a tunnel can proxy traffic to. // OriginService is something a tunnel can proxy traffic to.
type OriginService interface { type OriginService interface {
String() string String() string
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World. // Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
// If it's not managed by cloudflared, this is a no-op because the user is responsible for // If it's not managed by cloudflared, this is a no-op because the user is responsible for
// starting the origin service. // starting the origin service.
start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error // Implementor of services managed by cloudflared should terminate the service if shutdownC is closed
start(log *zerolog.Logger, shutdownC <-chan struct{}, cfg OriginRequestConfig) error
} }
// unixSocketPath is an OriginService representing a unix socket (which accepts HTTP) // unixSocketPath is an OriginService representing a unix socket (which accepts HTTP)
@ -39,7 +43,7 @@ func (o *unixSocketPath) String() string {
return "unix socket: " + o.path return "unix socket: " + o.path
} }
func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { func (o *unixSocketPath) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
transport, err := newHTTPTransport(o, cfg, log) transport, err := newHTTPTransport(o, cfg, log)
if err != nil { if err != nil {
return err return err
@ -54,7 +58,7 @@ type httpService struct {
transport *http.Transport transport *http.Transport
} }
func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { func (o *httpService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
transport, err := newHTTPTransport(o, cfg, log) transport, err := newHTTPTransport(o, cfg, log)
if err != nil { if err != nil {
return err return err
@ -78,7 +82,7 @@ func (o *rawTCPService) String() string {
return o.name return o.name
} }
func (o *rawTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { func (o *rawTCPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
return nil return nil
} }
@ -139,7 +143,7 @@ func (o *tcpOverWSService) String() string {
return o.dest return o.dest
} }
func (o *tcpOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { func (o *tcpOverWSService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
if cfg.ProxyType == socksProxy { if cfg.ProxyType == socksProxy {
o.streamHandler = socks.StreamHandler o.streamHandler = socks.StreamHandler
} else { } else {
@ -148,7 +152,7 @@ func (o *tcpOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdo
return nil return nil
} }
func (o *socksProxyOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { func (o *socksProxyOverWSService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
return nil return nil
} }
@ -164,18 +168,16 @@ type helloWorld struct {
} }
func (o *helloWorld) String() string { func (o *helloWorld) String() string {
return "Hello World test origin" return HelloWorldService
} }
// Start starts a HelloWorld server and stores its address in the Service receiver. // Start starts a HelloWorld server and stores its address in the Service receiver.
func (o *helloWorld) start( func (o *helloWorld) start(
wg *sync.WaitGroup,
log *zerolog.Logger, log *zerolog.Logger,
shutdownC <-chan struct{}, shutdownC <-chan struct{},
errC chan error,
cfg OriginRequestConfig, cfg OriginRequestConfig,
) error { ) error {
if err := o.httpService.start(wg, log, shutdownC, errC, cfg); err != nil { if err := o.httpService.start(log, shutdownC, cfg); err != nil {
return err return err
} }
@ -183,11 +185,7 @@ func (o *helloWorld) start(
if err != nil { if err != nil {
return errors.Wrap(err, "Cannot start Hello World Server") return errors.Wrap(err, "Cannot start Hello World Server")
} }
wg.Add(1) go hello.StartHelloWorldServer(log, helloListener, shutdownC)
go func() {
defer wg.Done()
_ = hello.StartHelloWorldServer(log, helloListener, shutdownC)
}()
o.server = helloListener o.server = helloListener
o.httpService.url = &url.URL{ o.httpService.url = &url.URL{
@ -218,10 +216,8 @@ func (o *statusCode) String() string {
} }
func (o *statusCode) start( func (o *statusCode) start(
wg *sync.WaitGroup,
log *zerolog.Logger, log *zerolog.Logger,
shutdownC <-chan struct{}, _ <-chan struct{},
errC chan error,
cfg OriginRequestConfig, cfg OriginRequestConfig,
) error { ) error {
return nil return nil
@ -296,6 +292,6 @@ func (mos MockOriginHTTPService) String() string {
return "MockOriginService" return "MockOriginService"
} }
func (mos MockOriginHTTPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { func (mos MockOriginHTTPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
return nil return nil
} }

15
orchestration/config.go Normal file
View File

@ -0,0 +1,15 @@
package orchestration
import (
"github.com/cloudflare/cloudflared/ingress"
)
type newConfig struct {
ingress.RemoteConfig
// Add more fields when we support other settings in tunnel orchestration
}
type Config struct {
Ingress *ingress.Ingress
WarpRoutingEnabled bool
}

View File

@ -0,0 +1,158 @@
package orchestration
import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/proxy"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// Orchestrator manages configurations so they can be updatable during runtime
// properties are static, so it can be read without lock
// currentVersion and config are read/write infrequently, so their access are synchronized with RWMutex
// access to proxy is synchronized with atmoic.Value, because it uses copy-on-write to provide scalable frequently
// read when update is infrequent
type Orchestrator struct {
currentVersion int32
// Used by UpdateConfig to make sure one update at a time
lock sync.RWMutex
// Underlying value is proxy.Proxy, can be read without the lock, but still needs the lock to update
proxy atomic.Value
config *Config
tags []tunnelpogs.Tag
log *zerolog.Logger
// orchestrator must not handle any more updates after shutdownC is closed
shutdownC <-chan struct{}
// Closing proxyShutdownC will close the previous proxy
proxyShutdownC chan<- struct{}
}
func NewOrchestrator(ctx context.Context, config *Config, tags []tunnelpogs.Tag, log *zerolog.Logger) (*Orchestrator, error) {
o := &Orchestrator{
// Lowest possible version, any remote configuration will have version higher than this
currentVersion: 0,
config: config,
tags: tags,
log: log,
shutdownC: ctx.Done(),
}
if err := o.updateIngress(*config.Ingress, config.WarpRoutingEnabled); err != nil {
return nil, err
}
go o.waitToCloseLastProxy()
return o, nil
}
// Update creates a new proxy with the new ingress rules
func (o *Orchestrator) UpdateConfig(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
o.lock.Lock()
defer o.lock.Unlock()
if o.currentVersion >= version {
o.log.Debug().
Int32("current_version", o.currentVersion).
Int32("received_version", version).
Msg("Current version is equal or newer than receivied version")
return &tunnelpogs.UpdateConfigurationResponse{
LastAppliedVersion: o.currentVersion,
}
}
var newConf newConfig
if err := json.Unmarshal(config, &newConf); err != nil {
o.log.Err(err).
Int32("version", version).
Str("config", string(config)).
Msgf("Failed to deserialize new configuration")
return &tunnelpogs.UpdateConfigurationResponse{
LastAppliedVersion: o.currentVersion,
Err: err,
}
}
if err := o.updateIngress(newConf.Ingress, newConf.WarpRouting.Enabled); err != nil {
o.log.Err(err).
Int32("version", version).
Str("config", string(config)).
Msgf("Failed to update ingress")
return &tunnelpogs.UpdateConfigurationResponse{
LastAppliedVersion: o.currentVersion,
Err: err,
}
}
o.currentVersion = version
o.log.Info().
Int32("version", version).
Str("config", string(config)).
Msg("Updated to new configuration")
return &tunnelpogs.UpdateConfigurationResponse{
LastAppliedVersion: o.currentVersion,
}
}
// The caller is responsible to make sure there is no concurrent access
func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRoutingEnabled bool) error {
select {
case <-o.shutdownC:
return fmt.Errorf("cloudflared already shutdown")
default:
}
// Start new proxy before closing the ones from last version.
// The upside is we don't need to restart proxy from last version, which can fail
// The downside is new version might have ingress rule that require previous version to be shutdown first
// The downside is minimized because none of the ingress.OriginService implementation have that requirement
proxyShutdownC := make(chan struct{})
if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil {
return errors.Wrap(err, "failed to start origin")
}
newProxy := proxy.NewOriginProxy(ingressRules, warpRoutingEnabled, o.tags, o.log)
o.proxy.Store(newProxy)
o.config.Ingress = &ingressRules
o.config.WarpRoutingEnabled = warpRoutingEnabled
// If proxyShutdownC is nil, there is no previous running proxy
if o.proxyShutdownC != nil {
close(o.proxyShutdownC)
}
o.proxyShutdownC = proxyShutdownC
return nil
}
// GetOriginProxy returns an interface to proxy to origin. It satisfies connection.ConfigManager interface
func (o *Orchestrator) GetOriginProxy() (connection.OriginProxy, error) {
val := o.proxy.Load()
if val == nil {
err := fmt.Errorf("origin proxy not configured")
o.log.Error().Msg(err.Error())
return nil, err
}
proxy, ok := val.(*proxy.Proxy)
if !ok {
err := fmt.Errorf("origin proxy has unexpected value %+v", val)
o.log.Error().Msg(err.Error())
return nil, err
}
return proxy, nil
}
func (o *Orchestrator) waitToCloseLastProxy() {
<-o.shutdownC
o.lock.Lock()
defer o.lock.Unlock()
if o.proxyShutdownC != nil {
close(o.proxyShutdownC)
o.proxyShutdownC = nil
}
}

View File

@ -0,0 +1,686 @@
package orchestration
import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gobwas/ws/wsutil"
gows "github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/proxy"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
var (
testLogger = zerolog.Logger{}
testTags = []tunnelpogs.Tag{
{
Name: "package",
Value: "orchestration",
},
{
Name: "purpose",
Value: "test",
},
}
)
// TestUpdateConfiguration tests that
// - configurations can be deserialized
// - proxy can be updated
// - last applied version and error are returned
// - configurations can be deserialized
// - receiving an old version is noop
func TestUpdateConfiguration(t *testing.T) {
initConfig := &Config{
Ingress: &ingress.Ingress{},
WarpRoutingEnabled: false,
}
orchestrator, err := NewOrchestrator(context.Background(), initConfig, testTags, &testLogger)
require.NoError(t, err)
initOriginProxy, err := orchestrator.GetOriginProxy()
require.NoError(t, err)
require.IsType(t, &proxy.Proxy{}, initOriginProxy)
configJSONV2 := []byte(`
{
"unknown_field": "not_deserialized",
"originRequest": {
"connectTimeout": 90000000000,
"noHappyEyeballs": true
},
"ingress": [
{
"hostname": "jira.tunnel.org",
"path": "^\/login",
"service": "http://192.16.19.1:443",
"originRequest": {
"noTLSVerify": true,
"connectTimeout": 10000000000
}
},
{
"hostname": "jira.tunnel.org",
"service": "http://172.32.20.6:80",
"originRequest": {
"noTLSVerify": true,
"connectTimeout": 30000000000
}
},
{
"service": "http_status:404"
}
],
"warp-routing": {
"enabled": true
}
}
`)
updateWithValidation(t, orchestrator, 2, configJSONV2)
configV2 := orchestrator.config
// Validate ingress rule 0
require.Equal(t, "jira.tunnel.org", configV2.Ingress.Rules[0].Hostname)
require.True(t, configV2.Ingress.Rules[0].Matches("jira.tunnel.org", "/login"))
require.True(t, configV2.Ingress.Rules[0].Matches("jira.tunnel.org", "/login/2fa"))
require.False(t, configV2.Ingress.Rules[0].Matches("jira.tunnel.org", "/users"))
require.Equal(t, "http://192.16.19.1:443", configV2.Ingress.Rules[0].Service.String())
require.Len(t, configV2.Ingress.Rules, 3)
// originRequest of this ingress rule overrides global default
require.Equal(t, time.Second*10, configV2.Ingress.Rules[0].Config.ConnectTimeout)
require.Equal(t, true, configV2.Ingress.Rules[0].Config.NoTLSVerify)
// Inherited from global default
require.Equal(t, true, configV2.Ingress.Rules[0].Config.NoHappyEyeballs)
// Validate ingress rule 1
require.Equal(t, "jira.tunnel.org", configV2.Ingress.Rules[1].Hostname)
require.True(t, configV2.Ingress.Rules[1].Matches("jira.tunnel.org", "/users"))
require.Equal(t, "http://172.32.20.6:80", configV2.Ingress.Rules[1].Service.String())
// originRequest of this ingress rule overrides global default
require.Equal(t, time.Second*30, configV2.Ingress.Rules[1].Config.ConnectTimeout)
require.Equal(t, true, configV2.Ingress.Rules[1].Config.NoTLSVerify)
// Inherited from global default
require.Equal(t, true, configV2.Ingress.Rules[1].Config.NoHappyEyeballs)
// Validate ingress rule 2, it's the catch-all rule
require.True(t, configV2.Ingress.Rules[2].Matches("blogs.tunnel.io", "/2022/02/10"))
// Inherited from global default
require.Equal(t, time.Second*90, configV2.Ingress.Rules[2].Config.ConnectTimeout)
require.Equal(t, false, configV2.Ingress.Rules[2].Config.NoTLSVerify)
require.Equal(t, true, configV2.Ingress.Rules[2].Config.NoHappyEyeballs)
require.True(t, configV2.WarpRoutingEnabled)
originProxyV2, err := orchestrator.GetOriginProxy()
require.NoError(t, err)
require.IsType(t, &proxy.Proxy{}, originProxyV2)
require.NotEqual(t, originProxyV2, initOriginProxy)
// Should not downgrade to an older version
resp := orchestrator.UpdateConfig(1, nil)
require.NoError(t, resp.Err)
require.Equal(t, int32(2), resp.LastAppliedVersion)
invalidJSON := []byte(`
{
"originRequest":
}
`)
resp = orchestrator.UpdateConfig(3, invalidJSON)
require.Error(t, resp.Err)
require.Equal(t, int32(2), resp.LastAppliedVersion)
originProxyV3, err := orchestrator.GetOriginProxy()
require.NoError(t, err)
require.Equal(t, originProxyV2, originProxyV3)
configJSONV10 := []byte(`
{
"ingress": [
{
"service": "hello-world"
}
],
"warp-routing": {
"enabled": false
}
}
`)
updateWithValidation(t, orchestrator, 10, configJSONV10)
configV10 := orchestrator.config
require.Len(t, configV10.Ingress.Rules, 1)
require.True(t, configV10.Ingress.Rules[0].Matches("blogs.tunnel.io", "/2022/02/10"))
require.Equal(t, ingress.HelloWorldService, configV10.Ingress.Rules[0].Service.String())
require.False(t, configV10.WarpRoutingEnabled)
originProxyV10, err := orchestrator.GetOriginProxy()
require.NoError(t, err)
require.IsType(t, &proxy.Proxy{}, originProxyV10)
require.NotEqual(t, originProxyV10, originProxyV2)
}
// TestConcurrentUpdateAndRead makes sure orchestrator can receive updates and return origin proxy concurrently
func TestConcurrentUpdateAndRead(t *testing.T) {
const (
concurrentRequests = 200
hostname = "public.tunnels.org"
expectedHost = "internal.tunnels.svc.cluster.local"
tcpBody = "testProxyTCP"
)
httpOrigin := httptest.NewServer(&validateHostHandler{
expectedHost: expectedHost,
body: t.Name(),
})
defer httpOrigin.Close()
tcpOrigin, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer tcpOrigin.Close()
var (
configJSONV1 = []byte(fmt.Sprintf(`
{
"originRequest": {
"connectTimeout": 90000000000,
"noHappyEyeballs": true
},
"ingress": [
{
"hostname": "%s",
"service": "%s",
"originRequest": {
"httpHostHeader": "%s",
"connectTimeout": 10000000000
}
},
{
"service": "http_status:404"
}
],
"warp-routing": {
"enabled": true
}
}
`, hostname, httpOrigin.URL, expectedHost))
configJSONV2 = []byte(`
{
"ingress": [
{
"service": "http_status:204"
}
],
"warp-routing": {
"enabled": false
}
}
`)
configJSONV3 = []byte(`
{
"ingress": [
{
"service": "http_status:418"
}
],
"warp-routing": {
"enabled": true
}
}
`)
// appliedV2 makes sure v3 is applied after v2
appliedV2 = make(chan struct{})
initConfig = &Config{
Ingress: &ingress.Ingress{},
WarpRoutingEnabled: false,
}
)
orchestrator, err := NewOrchestrator(context.Background(), initConfig, testTags, &testLogger)
require.NoError(t, err)
updateWithValidation(t, orchestrator, 1, configJSONV1)
var wg sync.WaitGroup
// tcpOrigin will be closed when the test exits. Only the handler routines are included in the wait group
go func() {
serveTCPOrigin(t, tcpOrigin, &wg)
}()
for i := 0; i < concurrentRequests; i++ {
originProxy, err := orchestrator.GetOriginProxy()
require.NoError(t, err)
wg.Add(1)
go func(i int, originProxy connection.OriginProxy) {
defer wg.Done()
resp, err := proxyHTTP(t, originProxy, hostname)
require.NoError(t, err)
var warpRoutingDisabled bool
// The response can be from initOrigin, http_status:204 or http_status:418
switch resp.StatusCode {
// v1 proxy, warp enabled
case 200:
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, t.Name(), string(body))
warpRoutingDisabled = false
// v2 proxy, warp disabled
case 204:
require.Greater(t, i, concurrentRequests/4)
warpRoutingDisabled = true
// v3 proxy, warp enabled
case 418:
require.Greater(t, i, concurrentRequests/2)
warpRoutingDisabled = false
}
// Once we have originProxy, it won't be changed by configuration updates.
// We can infer the version by the ProxyHTTP response code
pr, pw := io.Pipe()
// concurrentRespWriter makes sure ResponseRecorder is not read/write concurrently, and read waits for the first write
w := newRespReadWriteFlusher()
// Write TCP message and make sure it's echo back. This has to be done in a go routune since ProxyTCP doesn't
// return until the stream is closed.
if !warpRoutingDisabled {
wg.Add(1)
go func() {
defer wg.Done()
defer pw.Close()
tcpEyeball(t, pw, tcpBody, w)
}()
}
proxyTCP(t, originProxy, tcpOrigin.Addr().String(), w, pr, warpRoutingDisabled)
}(i, originProxy)
if i == concurrentRequests/4 {
wg.Add(1)
go func() {
defer wg.Done()
updateWithValidation(t, orchestrator, 2, configJSONV2)
close(appliedV2)
}()
}
if i == concurrentRequests/2 {
wg.Add(1)
go func() {
defer wg.Done()
<-appliedV2
updateWithValidation(t, orchestrator, 3, configJSONV3)
}()
}
}
wg.Wait()
}
func proxyHTTP(t *testing.T, originProxy connection.OriginProxy, hostname string) (*http.Response, error) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", hostname), nil)
require.NoError(t, err)
w := httptest.NewRecorder()
respWriter, err := connection.NewHTTP2RespWriter(req, w, connection.TypeHTTP)
require.NoError(t, err)
err = originProxy.ProxyHTTP(respWriter, req, false)
if err != nil {
return nil, err
}
return w.Result(), nil
}
func tcpEyeball(t *testing.T, reqWriter io.WriteCloser, body string, respReadWriter *respReadWriteFlusher) {
writeN, err := reqWriter.Write([]byte(body))
require.NoError(t, err)
readBuffer := make([]byte, writeN)
n, err := respReadWriter.Read(readBuffer)
require.NoError(t, err)
require.Equal(t, body, string(readBuffer[:n]))
require.Equal(t, writeN, n)
}
func proxyTCP(t *testing.T, originProxy connection.OriginProxy, originAddr string, w http.ResponseWriter, reqBody io.ReadCloser, expectErr bool) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originAddr), reqBody)
require.NoError(t, err)
respWriter, err := connection.NewHTTP2RespWriter(req, w, connection.TypeTCP)
require.NoError(t, err)
tcpReq := &connection.TCPRequest{
Dest: originAddr,
CFRay: "123",
LBProbe: false,
}
rws := connection.NewHTTPResponseReadWriterAcker(respWriter, req)
if expectErr {
require.Error(t, originProxy.ProxyTCP(context.Background(), rws, tcpReq))
return
}
require.NoError(t, originProxy.ProxyTCP(context.Background(), rws, tcpReq))
}
func serveTCPOrigin(t *testing.T, tcpOrigin net.Listener, wg *sync.WaitGroup) {
for {
conn, err := tcpOrigin.Accept()
if err != nil {
return
}
wg.Add(1)
go func() {
defer wg.Done()
defer conn.Close()
echoTCP(t, conn)
}()
}
}
func echoTCP(t *testing.T, conn net.Conn) {
readBuf := make([]byte, 1000)
readN, err := conn.Read(readBuf)
require.NoError(t, err)
writeN, err := conn.Write(readBuf[:readN])
require.NoError(t, err)
require.Equal(t, readN, writeN)
}
type validateHostHandler struct {
expectedHost string
body string
}
func (vhh *validateHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Host != vhh.expectedHost {
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(vhh.body))
}
func updateWithValidation(t *testing.T, orchestrator *Orchestrator, version int32, config []byte) {
resp := orchestrator.UpdateConfig(version, config)
require.NoError(t, resp.Err)
require.Equal(t, version, resp.LastAppliedVersion)
}
// TestClosePreviousProxies makes sure proxies started in the pervious configuration version are shutdown
func TestClosePreviousProxies(t *testing.T) {
var (
hostname = "hello.tunnel1.org"
configWithHelloWorld = []byte(fmt.Sprintf(`
{
"ingress": [
{
"hostname": "%s",
"service": "hello-world"
},
{
"service": "http_status:404"
}
],
"warp-routing": {
"enabled": true
}
}
`, hostname))
configTeapot = []byte(`
{
"ingress": [
{
"service": "http_status:418"
}
],
"warp-routing": {
"enabled": true
}
}
`)
initConfig = &Config{
Ingress: &ingress.Ingress{},
WarpRoutingEnabled: false,
}
)
ctx, cancel := context.WithCancel(context.Background())
orchestrator, err := NewOrchestrator(ctx, initConfig, testTags, &testLogger)
require.NoError(t, err)
updateWithValidation(t, orchestrator, 1, configWithHelloWorld)
originProxyV1, err := orchestrator.GetOriginProxy()
require.NoError(t, err)
resp, err := proxyHTTP(t, originProxyV1, hostname)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
updateWithValidation(t, orchestrator, 2, configTeapot)
originProxyV2, err := orchestrator.GetOriginProxy()
require.NoError(t, err)
resp, err = proxyHTTP(t, originProxyV2, hostname)
require.NoError(t, err)
require.Equal(t, http.StatusTeapot, resp.StatusCode)
// The hello-world server in config v1 should have been stopped
resp, err = proxyHTTP(t, originProxyV1, hostname)
require.Error(t, err)
require.Nil(t, resp)
// Apply the config with hello world server again, orchestrator should spin up another hello world server
updateWithValidation(t, orchestrator, 3, configWithHelloWorld)
originProxyV3, err := orchestrator.GetOriginProxy()
require.NoError(t, err)
require.NotEqual(t, originProxyV1, originProxyV3)
resp, err = proxyHTTP(t, originProxyV3, hostname)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
// cancel the context should terminate the last proxy
cancel()
// Wait for proxies to shutdown
time.Sleep(time.Millisecond * 10)
resp, err = proxyHTTP(t, originProxyV3, hostname)
require.Error(t, err)
require.Nil(t, resp)
}
// TestPersistentConnection makes sure updating the ingress doesn't intefere with existing connections
func TestPersistentConnection(t *testing.T) {
const (
hostname = "http://ws.tunnel.org"
)
msg := t.Name()
initConfig := &Config{
Ingress: &ingress.Ingress{},
WarpRoutingEnabled: false,
}
orchestrator, err := NewOrchestrator(context.Background(), initConfig, testTags, &testLogger)
require.NoError(t, err)
wsOrigin := httptest.NewServer(http.HandlerFunc(wsEcho))
defer wsOrigin.Close()
tcpOrigin, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer tcpOrigin.Close()
configWithWSAndWarp := []byte(fmt.Sprintf(`
{
"ingress": [
{
"service": "%s"
}
],
"warp-routing": {
"enabled": true
}
}
`, wsOrigin.URL))
updateWithValidation(t, orchestrator, 1, configWithWSAndWarp)
originProxy, err := orchestrator.GetOriginProxy()
require.NoError(t, err)
wsReqReader, wsReqWriter := io.Pipe()
wsRespReadWriter := newRespReadWriteFlusher()
tcpReqReader, tcpReqWriter := io.Pipe()
tcpRespReadWriter := newRespReadWriteFlusher()
var wg sync.WaitGroup
wg.Add(3)
// Start TCP origin
go func() {
defer wg.Done()
conn, err := tcpOrigin.Accept()
require.NoError(t, err)
defer conn.Close()
// Expect 3 TCP messages
for i := 0; i < 3; i++ {
echoTCP(t, conn)
}
}()
// Simulate cloudflared recieving a TCP connection
go func() {
defer wg.Done()
proxyTCP(t, originProxy, tcpOrigin.Addr().String(), tcpRespReadWriter, tcpReqReader, false)
}()
// Simulate cloudflared recieving a WS connection
go func() {
defer wg.Done()
req, err := http.NewRequest(http.MethodGet, hostname, wsReqReader)
require.NoError(t, err)
// ProxyHTTP will add Connection, Upgrade and Sec-Websocket-Version headers
req.Header.Add("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
respWriter, err := connection.NewHTTP2RespWriter(req, wsRespReadWriter, connection.TypeWebsocket)
require.NoError(t, err)
err = originProxy.ProxyHTTP(respWriter, req, true)
require.NoError(t, err)
}()
// Simulate eyeball WS and TCP connections
validateWsEcho(t, msg, wsReqWriter, wsRespReadWriter)
tcpEyeball(t, tcpReqWriter, msg, tcpRespReadWriter)
configNoWSAndWarp := []byte(`
{
"ingress": [
{
"service": "http_status:404"
}
],
"warp-routing": {
"enabled": false
}
}
`)
updateWithValidation(t, orchestrator, 2, configNoWSAndWarp)
// Make sure connection is still up
validateWsEcho(t, msg, wsReqWriter, wsRespReadWriter)
tcpEyeball(t, tcpReqWriter, msg, tcpRespReadWriter)
updateWithValidation(t, orchestrator, 3, configWithWSAndWarp)
// Make sure connection is still up
validateWsEcho(t, msg, wsReqWriter, wsRespReadWriter)
tcpEyeball(t, tcpReqWriter, msg, tcpRespReadWriter)
wsReqWriter.Close()
tcpReqWriter.Close()
wg.Wait()
}
func wsEcho(w http.ResponseWriter, r *http.Request) {
upgrader := gows.Upgrader{}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
fmt.Println("read message err", err)
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
fmt.Println("write message err", err)
break
}
}
}
func validateWsEcho(t *testing.T, msg string, reqWriter io.Writer, respReadWriter io.ReadWriter) {
err := wsutil.WriteClientText(reqWriter, []byte(msg))
require.NoError(t, err)
receivedMsg, err := wsutil.ReadServerText(respReadWriter)
require.NoError(t, err)
require.Equal(t, msg, string(receivedMsg))
}
type respReadWriteFlusher struct {
io.Reader
w io.Writer
headers http.Header
statusCode int
setStatusOnce sync.Once
hasStatus chan struct{}
}
func newRespReadWriteFlusher() *respReadWriteFlusher {
pr, pw := io.Pipe()
return &respReadWriteFlusher{
Reader: pr,
w: pw,
headers: make(http.Header),
hasStatus: make(chan struct{}),
}
}
func (rrw *respReadWriteFlusher) Write(buf []byte) (int, error) {
rrw.WriteHeader(http.StatusOK)
return rrw.w.Write(buf)
}
func (rrw *respReadWriteFlusher) Flush() {}
func (rrw *respReadWriteFlusher) Header() http.Header {
return rrw.headers
}
func (rrw *respReadWriteFlusher) WriteHeader(statusCode int) {
rrw.setStatusOnce.Do(func() {
rrw.statusCode = statusCode
close(rrw.hasStatus)
})
}

View File

@ -28,7 +28,7 @@ const (
// Proxy represents a means to Proxy between cloudflared and the origin services. // Proxy represents a means to Proxy between cloudflared and the origin services.
type Proxy struct { type Proxy struct {
ingressRules *ingress.Ingress ingressRules ingress.Ingress
warpRouting *ingress.WarpRoutingService warpRouting *ingress.WarpRoutingService
tags []tunnelpogs.Tag tags []tunnelpogs.Tag
log *zerolog.Logger log *zerolog.Logger
@ -37,18 +37,23 @@ type Proxy struct {
// NewOriginProxy returns a new instance of the Proxy struct. // NewOriginProxy returns a new instance of the Proxy struct.
func NewOriginProxy( func NewOriginProxy(
ingressRules *ingress.Ingress, ingressRules ingress.Ingress,
warpRouting *ingress.WarpRoutingService, warpRoutingEnabled bool,
tags []tunnelpogs.Tag, tags []tunnelpogs.Tag,
log *zerolog.Logger, log *zerolog.Logger,
) *Proxy { ) *Proxy {
return &Proxy{ proxy := &Proxy{
ingressRules: ingressRules, ingressRules: ingressRules,
warpRouting: warpRouting,
tags: tags, tags: tags,
log: log, log: log,
bufferPool: newBufferPool(512 * 1024), bufferPool: newBufferPool(512 * 1024),
} }
if warpRoutingEnabled {
proxy.warpRouting = ingress.NewWarpRoutingService()
log.Info().Msgf("Warp-routing is enabled")
}
return proxy
} }
// ProxyHTTP further depends on ingress rules to establish a connection with the origin service. This may be // ProxyHTTP further depends on ingress rules to establish a connection with the origin service. This may be
@ -139,7 +144,7 @@ func (p *Proxy) ProxyTCP(
return nil return nil
} }
func ruleField(ing *ingress.Ingress, ruleNum int) (ruleID string, srv string) { func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
srv = ing.Rules[ruleNum].Service.String() srv = ing.Rules[ruleNum].Service.String()
if ing.IsSingleRule() { if ing.IsSingleRule() {
return "", srv return "", srv

View File

@ -31,8 +31,7 @@ import (
) )
var ( var (
testTags = []tunnelpogs.Tag{tunnelpogs.Tag{Name: "Name", Value: "value"}} testTags = []tunnelpogs.Tag{tunnelpogs.Tag{Name: "Name", Value: "value"}}
unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil)
) )
type mockHTTPRespWriter struct { type mockHTTPRespWriter struct {
@ -131,17 +130,14 @@ func TestProxySingleOrigin(t *testing.T) {
ingressRule, err := ingress.NewSingleOrigin(cliCtx, allowURLFromArgs) ingressRule, err := ingress.NewSingleOrigin(cliCtx, allowURLFromArgs)
require.NoError(t, err) require.NoError(t, err)
var wg sync.WaitGroup require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
errC := make(chan error)
require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC))
proxy := NewOriginProxy(&ingressRule, unusedWarpRoutingService, testTags, &log) proxy := NewOriginProxy(ingressRule, false, testTags, &log)
t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyHTTP", testProxyHTTP(proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
t.Run("testProxySSE", testProxySSE(proxy)) t.Run("testProxySSE", testProxySSE(proxy))
t.Run("testProxySSEAllData", testProxySSEAllData(proxy)) t.Run("testProxySSEAllData", testProxySSEAllData(proxy))
cancel() cancel()
wg.Wait()
} }
func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) { func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) {
@ -341,11 +337,9 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
log := zerolog.Nop() log := zerolog.Nop()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
errC := make(chan error) require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
var wg sync.WaitGroup
require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC))
proxy := NewOriginProxy(&ingress, unusedWarpRoutingService, testTags, &log) proxy := NewOriginProxy(ingress, false, testTags, &log)
for _, test := range tests { for _, test := range tests {
responseWriter := newMockHTTPRespWriter() responseWriter := newMockHTTPRespWriter()
@ -363,7 +357,6 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
} }
} }
cancel() cancel()
wg.Wait()
} }
type mockAPI struct{} type mockAPI struct{}
@ -394,7 +387,7 @@ func TestProxyError(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
proxy := NewOriginProxy(&ing, unusedWarpRoutingService, testTags, &log) proxy := NewOriginProxy(ing, false, testTags, &log)
responseWriter := newMockHTTPRespWriter() responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
@ -634,10 +627,9 @@ func TestConnections(t *testing.T) {
test.args.originService(t, ln) test.args.originService(t, ln)
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String()) ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
var wg sync.WaitGroup ingressRule.StartOrigins(logger, ctx.Done())
errC := make(chan error) proxy := NewOriginProxy(ingressRule, true, testTags, logger)
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) proxy.warpRouting = test.args.warpRoutingService
proxy := NewOriginProxy(&ingressRule, test.args.warpRoutingService, testTags, logger)
dest := ln.Addr().String() dest := ln.Addr().String()
req, err := http.NewRequest( req, err := http.NewRequest(

View File

@ -1,55 +0,0 @@
package supervisor
import (
"sync"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/proxy"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
type configManager struct {
currentVersion int32
// Only used by UpdateConfig
updateLock sync.Mutex
// TODO: TUN-5698: Make proxy atomic.Value
proxy *proxy.Proxy
config *DynamicConfig
tags []tunnelpogs.Tag
log *zerolog.Logger
}
func newConfigManager(config *DynamicConfig, tags []tunnelpogs.Tag, log *zerolog.Logger) *configManager {
var warpRoutingService *ingress.WarpRoutingService
if config.WarpRoutingEnabled {
warpRoutingService = ingress.NewWarpRoutingService()
log.Info().Msgf("Warp-routing is enabled")
}
return &configManager{
// Lowest possible version, any remote configuration will have version higher than this
currentVersion: 0,
proxy: proxy.NewOriginProxy(config.Ingress, warpRoutingService, tags, log),
config: config,
log: log,
}
}
func (cm *configManager) Update(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
// TODO: TUN-5698: make ingress configurable
return &tunnelpogs.UpdateConfigurationResponse{
LastAppliedVersion: cm.currentVersion,
}
}
func (cm *configManager) GetOriginProxy() connection.OriginProxy {
return cm.proxy
}
type DynamicConfig struct {
Ingress *ingress.Ingress
WarpRoutingEnabled bool
}

View File

@ -13,6 +13,7 @@ import (
"github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/edgediscovery/allregions"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/orchestration"
"github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/retry"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -37,8 +38,8 @@ const (
// reconnects them if they disconnect. // reconnects them if they disconnect.
type Supervisor struct { type Supervisor struct {
cloudflaredUUID uuid.UUID cloudflaredUUID uuid.UUID
configManager *configManager
config *TunnelConfig config *TunnelConfig
orchestrator *orchestration.Orchestrator
edgeIPs *edgediscovery.Edge edgeIPs *edgediscovery.Edge
tunnelErrors chan tunnelError tunnelErrors chan tunnelError
tunnelsConnecting map[int]chan struct{} tunnelsConnecting map[int]chan struct{}
@ -65,7 +66,7 @@ type tunnelError struct {
err error err error
} }
func NewSupervisor(config *TunnelConfig, dynamiConfig *DynamicConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) { func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrator, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
cloudflaredUUID, err := uuid.NewRandom() cloudflaredUUID, err := uuid.NewRandom()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err) return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err)
@ -89,7 +90,7 @@ func NewSupervisor(config *TunnelConfig, dynamiConfig *DynamicConfig, reconnectC
return &Supervisor{ return &Supervisor{
cloudflaredUUID: cloudflaredUUID, cloudflaredUUID: cloudflaredUUID,
config: config, config: config,
configManager: newConfigManager(dynamiConfig, config.Tags, config.Log), orchestrator: orchestrator,
edgeIPs: edgeIPs, edgeIPs: edgeIPs,
tunnelErrors: make(chan tunnelError), tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{}, tunnelsConnecting: map[int]chan struct{}{},
@ -244,8 +245,8 @@ func (s *Supervisor) startFirstTunnel(
err = ServeTunnelLoop( err = ServeTunnelLoop(
ctx, ctx,
s.reconnectCredentialManager, s.reconnectCredentialManager,
s.configManager,
s.config, s.config,
s.orchestrator,
addr, addr,
s.log, s.log,
firstConnIndex, firstConnIndex,
@ -279,8 +280,8 @@ func (s *Supervisor) startFirstTunnel(
err = ServeTunnelLoop( err = ServeTunnelLoop(
ctx, ctx,
s.reconnectCredentialManager, s.reconnectCredentialManager,
s.configManager,
s.config, s.config,
s.orchestrator,
addr, addr,
s.log, s.log,
firstConnIndex, firstConnIndex,
@ -314,8 +315,8 @@ func (s *Supervisor) startTunnel(
err = ServeTunnelLoop( err = ServeTunnelLoop(
ctx, ctx,
s.reconnectCredentialManager, s.reconnectCredentialManager,
s.configManager,
s.config, s.config,
s.orchestrator,
addr, addr,
s.log, s.log,
uint8(index), uint8(index),

View File

@ -20,6 +20,7 @@ import (
"github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/edgediscovery/allregions"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/orchestration"
quicpogs "github.com/cloudflare/cloudflared/quic" quicpogs "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/retry"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
@ -107,12 +108,12 @@ func (c *TunnelConfig) SupportedFeatures() []string {
func StartTunnelDaemon( func StartTunnelDaemon(
ctx context.Context, ctx context.Context,
config *TunnelConfig, config *TunnelConfig,
dynamiConfig *DynamicConfig, orchestrator *orchestration.Orchestrator,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
graceShutdownC <-chan struct{}, graceShutdownC <-chan struct{},
) error { ) error {
s, err := NewSupervisor(config, dynamiConfig, reconnectCh, graceShutdownC) s, err := NewSupervisor(config, orchestrator, reconnectCh, graceShutdownC)
if err != nil { if err != nil {
return err return err
} }
@ -122,8 +123,8 @@ func StartTunnelDaemon(
func ServeTunnelLoop( func ServeTunnelLoop(
ctx context.Context, ctx context.Context,
credentialManager *reconnectCredentialManager, credentialManager *reconnectCredentialManager,
configManager *configManager,
config *TunnelConfig, config *TunnelConfig,
orchestrator *orchestration.Orchestrator,
addr *allregions.EdgeAddr, addr *allregions.EdgeAddr,
connAwareLogger *ConnAwareLogger, connAwareLogger *ConnAwareLogger,
connIndex uint8, connIndex uint8,
@ -158,8 +159,8 @@ func ServeTunnelLoop(
ctx, ctx,
connLog, connLog,
credentialManager, credentialManager,
configManager,
config, config,
orchestrator,
addr, addr,
connIndex, connIndex,
connectedFuse, connectedFuse,
@ -257,8 +258,8 @@ func ServeTunnel(
ctx context.Context, ctx context.Context,
connLog *ConnAwareLogger, connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager, credentialManager *reconnectCredentialManager,
configManager *configManager,
config *TunnelConfig, config *TunnelConfig,
orchestrator *orchestration.Orchestrator,
addr *allregions.EdgeAddr, addr *allregions.EdgeAddr,
connIndex uint8, connIndex uint8,
fuse *h2mux.BooleanFuse, fuse *h2mux.BooleanFuse,
@ -286,8 +287,8 @@ func ServeTunnel(
ctx, ctx,
connLog, connLog,
credentialManager, credentialManager,
configManager,
config, config,
orchestrator,
addr, addr,
connIndex, connIndex,
fuse, fuse,
@ -335,8 +336,8 @@ func serveTunnel(
ctx context.Context, ctx context.Context,
connLog *ConnAwareLogger, connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager, credentialManager *reconnectCredentialManager,
configManager *configManager,
config *TunnelConfig, config *TunnelConfig,
orchestrator *orchestration.Orchestrator,
addr *allregions.EdgeAddr, addr *allregions.EdgeAddr,
connIndex uint8, connIndex uint8,
fuse *h2mux.BooleanFuse, fuse *h2mux.BooleanFuse,
@ -365,8 +366,8 @@ func serveTunnel(
connOptions := config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries())) connOptions := config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
return ServeQUIC(ctx, return ServeQUIC(ctx,
addr.UDP, addr.UDP,
configManager,
config, config,
orchestrator,
connLog, connLog,
connOptions, connOptions,
controlStream, controlStream,
@ -385,8 +386,8 @@ func serveTunnel(
if err := ServeHTTP2( if err := ServeHTTP2(
ctx, ctx,
connLog, connLog,
configManager,
config, config,
orchestrator,
edgeConn, edgeConn,
connOptions, connOptions,
controlStream, controlStream,
@ -408,8 +409,8 @@ func serveTunnel(
ctx, ctx,
connLog, connLog,
credentialManager, credentialManager,
configManager,
config, config,
orchestrator,
edgeConn, edgeConn,
connIndex, connIndex,
connectedFuse, connectedFuse,
@ -435,8 +436,8 @@ func ServeH2mux(
ctx context.Context, ctx context.Context,
connLog *ConnAwareLogger, connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager, credentialManager *reconnectCredentialManager,
configManager *configManager,
config *TunnelConfig, config *TunnelConfig,
orchestrator *orchestration.Orchestrator,
edgeConn net.Conn, edgeConn net.Conn,
connIndex uint8, connIndex uint8,
connectedFuse *connectedFuse, connectedFuse *connectedFuse,
@ -447,7 +448,7 @@ func ServeH2mux(
connLog.Logger().Debug().Msgf("Connecting via h2mux") connLog.Logger().Debug().Msgf("Connecting via h2mux")
// Returns error from parsing the origin URL or handshake errors // Returns error from parsing the origin URL or handshake errors
handler, err, recoverable := connection.NewH2muxConnection( handler, err, recoverable := connection.NewH2muxConnection(
configManager, orchestrator,
config.GracePeriod, config.GracePeriod,
config.MuxerConfig, config.MuxerConfig,
edgeConn, edgeConn,
@ -483,8 +484,8 @@ func ServeH2mux(
func ServeHTTP2( func ServeHTTP2(
ctx context.Context, ctx context.Context,
connLog *ConnAwareLogger, connLog *ConnAwareLogger,
configManager *configManager,
config *TunnelConfig, config *TunnelConfig,
orchestrator *orchestration.Orchestrator,
tlsServerConn net.Conn, tlsServerConn net.Conn,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler connection.ControlStreamHandler, controlStreamHandler connection.ControlStreamHandler,
@ -495,7 +496,7 @@ func ServeHTTP2(
connLog.Logger().Debug().Msgf("Connecting via http2") connLog.Logger().Debug().Msgf("Connecting via http2")
h2conn := connection.NewHTTP2Connection( h2conn := connection.NewHTTP2Connection(
tlsServerConn, tlsServerConn,
configManager, orchestrator,
connOptions, connOptions,
config.Observer, config.Observer,
connIndex, connIndex,
@ -523,8 +524,8 @@ func ServeHTTP2(
func ServeQUIC( func ServeQUIC(
ctx context.Context, ctx context.Context,
edgeAddr *net.UDPAddr, edgeAddr *net.UDPAddr,
configManager *configManager,
config *TunnelConfig, config *TunnelConfig,
orchestrator *orchestration.Orchestrator,
connLogger *ConnAwareLogger, connLogger *ConnAwareLogger,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler connection.ControlStreamHandler, controlStreamHandler connection.ControlStreamHandler,
@ -548,7 +549,7 @@ func ServeQUIC(
quicConfig, quicConfig,
edgeAddr, edgeAddr,
tlsConfig, tlsConfig,
configManager, orchestrator,
connOptions, connOptions,
controlStreamHandler, controlStreamHandler,
connLogger.Logger()) connLogger.Logger())