TUN-5749: Refactor cloudflared to pave way for reconfigurable ingress

- Split origin into supervisor and proxy packages
- Create configManager to handle dynamic config
This commit is contained in:
cthuang 2022-02-07 09:42:07 +00:00
parent ff4cfeda0c
commit e22422aafb
33 changed files with 317 additions and 220 deletions

View File

@ -31,8 +31,8 @@ import (
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns"
)
@ -223,7 +223,7 @@ func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) {
func StartServer(
c *cli.Context,
info *cliutil.BuildInfo,
namedTunnel *connection.NamedTunnelConfig,
namedTunnel *connection.NamedTunnelProperties,
log *zerolog.Logger,
isUIEnabled bool,
) error {
@ -333,7 +333,7 @@ func StartServer(
observer.SendURL(quickTunnelURL)
}
tunnelConfig, ingressRules, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel)
tunnelConfig, dynamicConfig, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel)
if err != nil {
log.Err(err).Msg("Couldn't start tunnel")
return err
@ -353,11 +353,11 @@ func StartServer(
errC <- metrics.ServeMetrics(metricsListener, ctx.Done(), readinessServer, quickTunnelURL, log)
}()
if err := ingressRules.StartOrigins(&wg, log, ctx.Done(), errC); err != nil {
if err := dynamicConfig.Ingress.StartOrigins(&wg, log, ctx.Done(), errC); err != nil {
return err
}
reconnectCh := make(chan origin.ReconnectSignal, 1)
reconnectCh := make(chan supervisor.ReconnectSignal, 1)
if c.IsSet("stdin-control") {
log.Info().Msg("Enabling control through stdin")
go stdinControl(reconnectCh, log)
@ -369,7 +369,7 @@ func StartServer(
wg.Done()
log.Info().Msg("Tunnel server stopped")
}()
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, reconnectCh, graceShutdownC)
errC <- supervisor.StartTunnelDaemon(ctx, tunnelConfig, dynamicConfig, connectedSignal, reconnectCh, graceShutdownC)
}()
if isUIEnabled {
@ -377,7 +377,7 @@ func StartServer(
info.Version(),
hostname,
metricsListener.Addr().String(),
&ingressRules,
dynamicConfig.Ingress,
tunnelConfig.HAConnections,
)
app := tunnelUI.Launch(ctx, log, logTransport)
@ -998,7 +998,7 @@ func configureProxyDNSFlags(shouldHide bool) []cli.Flag {
}
}
func stdinControl(reconnectCh chan origin.ReconnectSignal, log *zerolog.Logger) {
func stdinControl(reconnectCh chan supervisor.ReconnectSignal, log *zerolog.Logger) {
for {
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
@ -1009,7 +1009,7 @@ func stdinControl(reconnectCh chan origin.ReconnectSignal, log *zerolog.Logger)
case "":
break
case "reconnect":
var reconnect origin.ReconnectSignal
var reconnect supervisor.ReconnectSignal
if len(parts) > 1 {
var err error
if reconnect.Delay, err = time.ParseDuration(parts[1]); err != nil {

View File

@ -23,7 +23,7 @@ import (
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
@ -87,7 +87,7 @@ func logClientOptions(c *cli.Context, log *zerolog.Logger) {
}
}
func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelConfig) bool {
func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelProperties) bool {
return c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world") && namedTunnel == nil)
}
@ -152,44 +152,44 @@ func prepareTunnelConfig(
info *cliutil.BuildInfo,
log, logTransport *zerolog.Logger,
observer *connection.Observer,
namedTunnel *connection.NamedTunnelConfig,
) (*origin.TunnelConfig, ingress.Ingress, error) {
namedTunnel *connection.NamedTunnelProperties,
) (*supervisor.TunnelConfig, *supervisor.DynamicConfig, error) {
isNamedTunnel := namedTunnel != nil
configHostname := c.String("hostname")
hostname, err := validation.ValidateHostname(configHostname)
if err != nil {
log.Err(err).Str(LogFieldHostname, configHostname).Msg("Invalid hostname")
return nil, ingress.Ingress{}, errors.Wrap(err, "Invalid hostname")
return nil, nil, errors.Wrap(err, "Invalid hostname")
}
clientID := c.String("id")
if !c.IsSet("id") {
clientID, err = generateRandomClientID(log)
if err != nil {
return nil, ingress.Ingress{}, err
return nil, nil, err
}
}
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil {
log.Err(err).Msg("Tag parse failure")
return nil, ingress.Ingress{}, errors.Wrap(err, "Tag parse failure")
return nil, nil, errors.Wrap(err, "Tag parse failure")
}
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
var (
ingressRules ingress.Ingress
classicTunnel *connection.ClassicTunnelConfig
classicTunnel *connection.ClassicTunnelProperties
)
cfg := config.GetConfiguration()
if isNamedTunnel {
clientUUID, err := uuid.NewRandom()
if err != nil {
return nil, ingress.Ingress{}, errors.Wrap(err, "can't generate connector UUID")
return nil, nil, errors.Wrap(err, "can't generate connector UUID")
}
log.Info().Msgf("Generated Connector ID: %s", clientUUID)
features := append(c.StringSlice("features"), origin.FeatureSerializedHeaders)
features := append(c.StringSlice("features"), supervisor.FeatureSerializedHeaders)
namedTunnel.Client = tunnelpogs.ClientInfo{
ClientID: clientUUID[:],
Features: dedup(features),
@ -198,10 +198,10 @@ func prepareTunnelConfig(
}
ingressRules, err = ingress.ParseIngress(cfg)
if err != nil && err != ingress.ErrNoIngressRules {
return nil, ingress.Ingress{}, err
return nil, nil, err
}
if !ingressRules.IsEmpty() && c.IsSet("url") {
return nil, ingress.Ingress{}, ingress.ErrURLIncompatibleWithIngress
return nil, nil, ingress.ErrURLIncompatibleWithIngress
}
} else {
@ -212,10 +212,10 @@ func prepareTunnelConfig(
originCert, err := getOriginCert(originCertPath, &originCertLog)
if err != nil {
return nil, ingress.Ingress{}, errors.Wrap(err, "Error getting origin cert")
return nil, nil, errors.Wrap(err, "Error getting origin cert")
}
classicTunnel = &connection.ClassicTunnelConfig{
classicTunnel = &connection.ClassicTunnelProperties{
Hostname: hostname,
OriginCert: originCert,
// turn off use of reconnect token and auth refresh when using named tunnels
@ -227,20 +227,14 @@ func prepareTunnelConfig(
if ingressRules.IsEmpty() {
ingressRules, err = ingress.NewSingleOrigin(c, !isNamedTunnel)
if err != nil {
return nil, ingress.Ingress{}, err
return nil, nil, err
}
}
var warpRoutingService *ingress.WarpRoutingService
warpRoutingEnabled := isWarpRoutingEnabled(cfg.WarpRouting, isNamedTunnel)
if warpRoutingEnabled {
warpRoutingService = ingress.NewWarpRoutingService()
log.Info().Msgf("Warp-routing is enabled")
}
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, origin.ResolveTTL, log)
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, supervisor.ResolveTTL, log)
if err != nil {
return nil, ingress.Ingress{}, err
return nil, nil, err
}
log.Info().Msgf("Initial protocol %s", protocolSelector.Current())
@ -248,11 +242,11 @@ func prepareTunnelConfig(
for _, p := range connection.ProtocolList {
tlsSettings := p.TLSSettings()
if tlsSettings == nil {
return nil, ingress.Ingress{}, fmt.Errorf("%s has unknown TLS settings", p)
return nil, nil, fmt.Errorf("%s has unknown TLS settings", p)
}
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, tlsSettings.ServerName)
if err != nil {
return nil, ingress.Ingress{}, errors.Wrap(err, "unable to create TLS config to connect with edge")
return nil, nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
}
if len(tlsSettings.NextProtos) > 0 {
edgeTLSConfig.NextProtos = tlsSettings.NextProtos
@ -260,15 +254,9 @@ func prepareTunnelConfig(
edgeTLSConfigs[p] = edgeTLSConfig
}
originProxy := origin.NewOriginProxy(ingressRules, warpRoutingService, tags, log)
gracePeriod, err := gracePeriod(c)
if err != nil {
return nil, ingress.Ingress{}, err
}
connectionConfig := &connection.Config{
OriginProxy: originProxy,
GracePeriod: gracePeriod,
ReplaceExisting: c.Bool("force"),
return nil, nil, err
}
muxerConfig := &connection.MuxerConfig{
HeartbeatInterval: c.Duration("heartbeat-interval"),
@ -279,14 +267,15 @@ func prepareTunnelConfig(
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
}
return &origin.TunnelConfig{
ConnectionConfig: connectionConfig,
tunnelConfig := &supervisor.TunnelConfig{
GracePeriod: gracePeriod,
ReplaceExisting: c.Bool("force"),
OSArch: info.OSArch(),
ClientID: clientID,
EdgeAddrs: c.StringSlice("edge"),
Region: c.String("region"),
HAConnections: c.Int("ha-connections"),
IncidentLookup: origin.NewIncidentLookup(),
IncidentLookup: supervisor.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"),
LBPool: c.String("lb-pool"),
Tags: tags,
@ -302,7 +291,12 @@ func prepareTunnelConfig(
MuxerConfig: muxerConfig,
ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs,
}, ingressRules, nil
}
dynamicConfig := &supervisor.DynamicConfig{
Ingress: &ingressRules,
WarpRoutingEnabled: warpRoutingEnabled,
}
return tunnelConfig, dynamicConfig, nil
}
func gracePeriod(c *cli.Context) (time.Duration, error) {

View File

@ -77,7 +77,7 @@ func RunQuickTunnel(sc *subcommandContext) error {
return StartServer(
sc.c,
buildInfo,
&connection.NamedTunnelConfig{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname},
&connection.NamedTunnelProperties{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname},
sc.log,
sc.isUIEnabled,
)

View File

@ -304,7 +304,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
return StartServer(
sc.c,
buildInfo,
&connection.NamedTunnelConfig{Credentials: credentials},
&connection.NamedTunnelProperties{Credentials: credentials},
sc.log,
sc.isUIEnabled,
)

View File

@ -25,13 +25,12 @@ const (
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
type Config struct {
OriginProxy OriginProxy
GracePeriod time.Duration
ReplaceExisting bool
type ConfigManager interface {
Update(version int32, config []byte) *pogs.UpdateConfigurationResponse
GetOriginProxy() OriginProxy
}
type NamedTunnelConfig struct {
type NamedTunnelProperties struct {
Credentials Credentials
Client pogs.ClientInfo
QuickTunnelUrl string
@ -52,7 +51,7 @@ func (c *Credentials) Auth() pogs.TunnelAuth {
}
}
type ClassicTunnelConfig struct {
type ClassicTunnelProperties struct {
Hostname string
OriginCert []byte
// feature-flag to use new edge reconnect tokens

View File

@ -14,18 +14,19 @@ import (
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/ingress"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
)
const (
largeFileSize = 2 * 1024 * 1024
testGracePeriod = time.Millisecond * 100
)
var (
unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil)
testConfig = &Config{
OriginProxy: &mockOriginProxy{},
GracePeriod: time.Millisecond * 100,
testConfigManager = &mockConfigManager{
originProxy: &mockOriginProxy{},
}
log = zerolog.Nop()
testOriginURL = &url.URL{
@ -43,6 +44,20 @@ type testRequest struct {
isProxyError bool
}
type mockConfigManager struct {
originProxy OriginProxy
}
func (*mockConfigManager) Update(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
return &tunnelpogs.UpdateConfigurationResponse{
LastAppliedVersion: version,
}
}
func (mcr *mockConfigManager) GetOriginProxy() OriginProxy {
return mcr.originProxy
}
type mockOriginProxy struct{}
func (moc *mockOriginProxy) ProxyHTTP(

View File

@ -17,7 +17,7 @@ type controlStream struct {
observer *Observer
connectedFuse ConnectedFuse
namedTunnelConfig *NamedTunnelConfig
namedTunnelProperties *NamedTunnelProperties
connIndex uint8
newRPCClientFunc RPCClientFunc
@ -39,7 +39,7 @@ type ControlStreamHandler interface {
func NewControlStream(
observer *Observer,
connectedFuse ConnectedFuse,
namedTunnelConfig *NamedTunnelConfig,
namedTunnelConfig *NamedTunnelProperties,
connIndex uint8,
newRPCClientFunc RPCClientFunc,
gracefulShutdownC <-chan struct{},
@ -51,7 +51,7 @@ func NewControlStream(
return &controlStream{
observer: observer,
connectedFuse: connectedFuse,
namedTunnelConfig: namedTunnelConfig,
namedTunnelProperties: namedTunnelConfig,
newRPCClientFunc: newRPCClientFunc,
connIndex: connIndex,
gracefulShutdownC: gracefulShutdownC,
@ -66,7 +66,7 @@ func (c *controlStream) ServeControlStream(
) error {
rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log)
if err := rpcClient.RegisterConnection(ctx, c.namedTunnelConfig, connOptions, c.connIndex, c.observer); err != nil {
if err := rpcClient.RegisterConnection(ctx, c.namedTunnelProperties, connOptions, c.connIndex, c.observer); err != nil {
rpcClient.Close()
return err
}

View File

@ -22,7 +22,8 @@ const (
)
type h2muxConnection struct {
config *Config
configManager ConfigManager
gracePeriod time.Duration
muxerConfig *MuxerConfig
muxer *h2mux.Muxer
// connectionID is only used by metrics, and prometheus requires labels to be string
@ -60,7 +61,8 @@ func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, log *zerolog.Lo
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
func NewH2muxConnection(
config *Config,
configManager ConfigManager,
gracePeriod time.Duration,
muxerConfig *MuxerConfig,
edgeConn net.Conn,
connIndex uint8,
@ -68,7 +70,8 @@ func NewH2muxConnection(
gracefulShutdownC <-chan struct{},
) (*h2muxConnection, error, bool) {
h := &h2muxConnection{
config: config,
configManager: configManager,
gracePeriod: gracePeriod,
muxerConfig: muxerConfig,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
@ -88,7 +91,7 @@ func NewH2muxConnection(
return h, nil, false
}
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelProperties, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return h.serveMuxer(serveCtx)
@ -117,7 +120,7 @@ func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *Nam
return err
}
func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error {
func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelProperties, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error {
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return h.serveMuxer(serveCtx)
@ -224,7 +227,7 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
sourceConnectionType = TypeWebsocket
}
err := h.config.OriginProxy.ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket)
err := h.configManager.GetOriginProxy().ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket)
if err != nil {
respWriter.WriteErrorResponse()
}

View File

@ -48,7 +48,7 @@ func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) {
}()
var connIndex = uint8(0)
testObserver := NewObserver(&log, &log, false)
h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver, nil)
h2muxConn, err, _ := NewH2muxConnection(testConfigManager, testGracePeriod, testMuxerConfig, originConn, connIndex, testObserver, nil)
require.NoError(t, err)
return h2muxConn, <-edgeMuxChan
}

View File

@ -32,7 +32,7 @@ var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
type HTTP2Connection struct {
conn net.Conn
server *http2.Server
config *Config
configManager ConfigManager
connOptions *tunnelpogs.ConnectionOptions
observer *Observer
connIndex uint8
@ -49,7 +49,7 @@ type HTTP2Connection struct {
// NewHTTP2Connection returns a new instance of HTTP2Connection.
func NewHTTP2Connection(
conn net.Conn,
config *Config,
configManager ConfigManager,
connOptions *tunnelpogs.ConnectionOptions,
observer *Observer,
connIndex uint8,
@ -61,7 +61,7 @@ func NewHTTP2Connection(
server: &http2.Server{
MaxConcurrentStreams: MaxConcurrentStreams,
},
config: config,
configManager: configManager,
connOptions: connOptions,
observer: observer,
connIndex: connIndex,
@ -116,7 +116,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case TypeWebsocket, TypeHTTP:
stripWebsocketUpgradeHeader(r)
if err := c.config.OriginProxy.ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil {
if err := c.configManager.GetOriginProxy().ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil {
err := fmt.Errorf("Failed to proxy HTTP: %w", err)
c.log.Error().Err(err)
respWriter.WriteErrorResponse()
@ -131,7 +131,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
rws := NewHTTPResponseReadWriterAcker(respWriter, r)
if err := c.config.OriginProxy.ProxyTCP(r.Context(), rws, &TCPRequest{
if err := c.configManager.GetOriginProxy().ProxyTCP(r.Context(), rws, &TCPRequest{
Dest: host,
CFRay: FindCfRayHeader(r),
LBProbe: IsLBProbeRequest(r),

View File

@ -35,7 +35,7 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
&NamedTunnelProperties{},
connIndex,
nil,
nil,
@ -43,8 +43,8 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
)
return NewHTTP2Connection(
cfdConn,
// OriginProxy is set in testConfig
testConfig,
// OriginProxy is set in testConfigManager
testConfigManager,
&pogs.ConnectionOptions{},
obs,
connIndex,
@ -132,7 +132,7 @@ type mockNamedTunnelRPCClient struct {
func (mc mockNamedTunnelRPCClient) RegisterConnection(
c context.Context,
config *NamedTunnelConfig,
properties *NamedTunnelProperties,
options *tunnelpogs.ConnectionOptions,
connIndex uint8,
observer *Observer,
@ -313,7 +313,7 @@ func TestServeControlStream(t *testing.T) {
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
&NamedTunnelProperties{},
1,
rpcClientFactory.newMockRPCClient,
nil,
@ -363,7 +363,7 @@ func TestFailRegistration(t *testing.T) {
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
&NamedTunnelProperties{},
http2Conn.connIndex,
rpcClientFactory.newMockRPCClient,
nil,
@ -409,7 +409,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
controlStream := NewControlStream(
obs,
mockConnectedFuse{},
&NamedTunnelConfig{},
&NamedTunnelProperties{},
http2Conn.connIndex,
rpcClientFactory.newMockRPCClient,
shutdownC,

View File

@ -195,7 +195,7 @@ type PercentageFetcher func() (edgediscovery.ProtocolPercents, error)
func NewProtocolSelector(
protocolFlag string,
warpRoutingEnabled bool,
namedTunnel *NamedTunnelConfig,
namedTunnel *NamedTunnelProperties,
fetchFunc PercentageFetcher,
ttl time.Duration,
log *zerolog.Logger,

View File

@ -16,7 +16,7 @@ const (
)
var (
testNamedTunnelConfig = &NamedTunnelConfig{
testNamedTunnelProperties = &NamedTunnelProperties{
Credentials: Credentials{
AccountTag: "testAccountTag",
},
@ -51,7 +51,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback bool
expectedFallback Protocol
warpRoutingEnabled bool
namedTunnelConfig *NamedTunnelConfig
namedTunnelConfig *NamedTunnelProperties
fetchFunc PercentageFetcher
wantErr bool
}{
@ -66,35 +66,35 @@ func TestNewProtocolSelector(t *testing.T) {
protocol: "h2mux",
expectedProtocol: H2mux,
fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil },
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "named tunnel over http2",
protocol: "http2",
expectedProtocol: HTTP2,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "named tunnel http2 disabled still gets http2 because it is manually picked",
protocol: "http2",
expectedProtocol: HTTP2,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "named tunnel quic disabled still gets quic because it is manually picked",
protocol: "quic",
expectedProtocol: QUIC,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "named tunnel quic and http2 disabled",
protocol: "auto",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "named tunnel quic disabled",
@ -104,21 +104,21 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "named tunnel auto all http2 disabled",
protocol: "auto",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "named tunnel auto to h2mux",
protocol: "auto",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "named tunnel auto to http2",
@ -127,7 +127,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "named tunnel auto to quic",
@ -136,7 +136,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: true,
expectedFallback: HTTP2,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "warp routing requesting h2mux",
@ -145,7 +145,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1",
@ -154,7 +154,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "warp routing http2",
@ -163,7 +163,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "warp routing quic",
@ -173,7 +173,7 @@ func TestNewProtocolSelector(t *testing.T) {
expectedFallback: HTTP2Warp,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "warp routing auto",
@ -182,7 +182,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
name: "warp routing auto- quic",
@ -192,7 +192,7 @@ func TestNewProtocolSelector(t *testing.T) {
expectedFallback: HTTP2Warp,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
},
{
// None named tunnel can only use h2mux, so specifying an unknown protocol is not an error
@ -204,14 +204,14 @@ func TestNewProtocolSelector(t *testing.T) {
name: "named tunnel unknown protocol",
protocol: "unknown",
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
wantErr: true,
},
{
name: "named tunnel fetch error",
protocol: "auto",
fetchFunc: mockFetcher(true),
namedTunnelConfig: testNamedTunnelConfig,
namedTunnelConfig: testNamedTunnelProperties,
expectedProtocol: HTTP2,
wantErr: false,
},
@ -237,7 +237,7 @@ func TestNewProtocolSelector(t *testing.T) {
func TestAutoProtocolSelectorRefresh(t *testing.T) {
fetcher := dynamicMockFetcher{}
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log)
assert.NoError(t, err)
assert.Equal(t, H2mux, selector.Current())
@ -267,7 +267,7 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) {
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
fetcher := dynamicMockFetcher{}
// Since the user chooses http2 on purpose, we always stick to it.
selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log)
assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current())
@ -297,7 +297,7 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
func TestProtocolSelectorRefreshTTL(t *testing.T) {
fetcher := dynamicMockFetcher{}
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log)
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), time.Hour, &log)
assert.NoError(t, err)
assert.Equal(t, QUIC, selector.Current())

View File

@ -36,7 +36,7 @@ const (
type QUICConnection struct {
session quic.Session
logger *zerolog.Logger
httpProxy OriginProxy
configManager ConfigManager
sessionManager datagramsession.Manager
controlStreamHandler ControlStreamHandler
connOptions *tunnelpogs.ConnectionOptions
@ -47,7 +47,7 @@ func NewQUICConnection(
quicConfig *quic.Config,
edgeAddr net.Addr,
tlsConfig *tls.Config,
httpProxy OriginProxy,
configManager ConfigManager,
connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler ControlStreamHandler,
logger *zerolog.Logger,
@ -66,7 +66,7 @@ func NewQUICConnection(
return &QUICConnection{
session: session,
httpProxy: httpProxy,
configManager: configManager,
logger: logger,
sessionManager: sessionManager,
controlStreamHandler: controlStreamHandler,
@ -183,10 +183,10 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream)
}
w := newHTTPResponseAdapter(stream)
return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket)
return q.configManager.GetOriginProxy().ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket)
case quicpogs.ConnectionTypeTCP:
rwa := &streamReadWriteAcker{stream}
return q.httpProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest})
return q.configManager.GetOriginProxy().ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest})
}
return nil
}

View File

@ -627,13 +627,12 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection
NextProtos: []string{"argotunnel"},
}
// Start a mock httpProxy
originProxy := &mockOriginProxyWithRequest{}
log := zerolog.New(os.Stdout)
qc, err := NewQUICConnection(
testQUICConfig,
udpListenerAddr,
tlsClientConfig,
originProxy,
&mockConfigManager{originProxy: &mockOriginProxyWithRequest{}},
&tunnelpogs.ConnectionOptions{},
fakeControlStream{},
&log,

View File

@ -37,7 +37,7 @@ func NewTunnelServerClient(
}
}
func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) {
func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelProperties, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) {
authResp, err := tsc.client.Authenticate(ctx, classicTunnel.OriginCert, classicTunnel.Hostname, registrationOptions)
if err != nil {
return nil, err
@ -54,7 +54,7 @@ func (tsc *tunnelServerClient) Close() {
type NamedTunnelRPCClient interface {
RegisterConnection(
c context.Context,
config *NamedTunnelConfig,
config *NamedTunnelProperties,
options *tunnelpogs.ConnectionOptions,
connIndex uint8,
observer *Observer,
@ -86,15 +86,15 @@ func newRegistrationRPCClient(
func (rsc *registrationServerClient) RegisterConnection(
ctx context.Context,
config *NamedTunnelConfig,
properties *NamedTunnelProperties,
options *tunnelpogs.ConnectionOptions,
connIndex uint8,
observer *Observer,
) error {
conn, err := rsc.client.RegisterConnection(
ctx,
config.Credentials.Auth(),
config.Credentials.TunnelID,
properties.Credentials.Auth(),
properties.Credentials.TunnelID,
connIndex,
options,
)
@ -137,7 +137,7 @@ const (
authenticate rpcName = " authenticate"
)
func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelProperties, registrationOptions *tunnelpogs.RegistrationOptions) error {
h.observer.sendRegisteringEvent(registrationOptions.ConnectionID)
stream, err := h.newRPCStream(ctx, register)
@ -174,7 +174,7 @@ type CredentialManager interface {
func (h *h2muxConnection) processRegistrationSuccess(
registration *tunnelpogs.TunnelRegistration,
name rpcName,
credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig,
credentialManager CredentialManager, classicTunnel *ClassicTunnelProperties,
) error {
for _, logLine := range registration.LogLines {
h.observer.log.Info().Msg(logLine)
@ -205,7 +205,7 @@ func (h *h2muxConnection) processRegisterTunnelError(err tunnelpogs.TunnelRegist
}
}
func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelProperties, registrationOptions *tunnelpogs.RegistrationOptions) error {
token, err := credentialManager.ReconnectToken()
if err != nil {
return err
@ -264,7 +264,7 @@ func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelSe
func (h *h2muxConnection) registerNamedTunnel(
ctx context.Context,
namedTunnel *NamedTunnelConfig,
namedTunnel *NamedTunnelProperties,
connOptions *tunnelpogs.ConnectionOptions,
) error {
stream, err := h.newRPCStream(ctx, register)
@ -283,7 +283,7 @@ func (h *h2muxConnection) registerNamedTunnel(
func (h *h2muxConnection) unregister(isNamedTunnel bool) {
h.observer.sendUnregisteringEvent(h.connIndex)
unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod)
unregisterCtx, cancel := context.WithTimeout(context.Background(), h.gracePeriod)
defer cancel()
stream, err := h.newRPCStream(unregisterCtx, unregister)
@ -296,13 +296,13 @@ func (h *h2muxConnection) unregister(isNamedTunnel bool) {
rpcClient := h.newRPCClientFunc(unregisterCtx, stream, h.observer.log)
defer rpcClient.Close()
rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod)
rpcClient.GracefulShutdown(unregisterCtx, h.gracePeriod)
} else {
rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer.log)
defer rpcClient.Close()
// gracePeriod is encoded in int64 using capnproto
_ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds())
_ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.gracePeriod.Nanoseconds())
}
h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unregistered tunnel connection")

View File

@ -1,4 +1,4 @@
package origin
package proxy
import (
"github.com/prometheus/client_golang/prometheus"
@ -43,14 +43,6 @@ var (
Help: "Count of error proxying to origin",
},
)
haConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: connection.MetricsNamespace,
Subsystem: connection.TunnelSubsystem,
Name: "ha_connections",
Help: "Number of active ha connections",
},
)
)
func init() {
@ -59,7 +51,6 @@ func init() {
concurrentRequests,
responseByCode,
requestErrors,
haConnections,
)
}

View File

@ -1,4 +1,4 @@
package origin
package proxy
import (
"sync"

View File

@ -1,4 +1,4 @@
package origin
package proxy
import (
"bufio"
@ -28,7 +28,7 @@ const (
// Proxy represents a means to Proxy between cloudflared and the origin services.
type Proxy struct {
ingressRules ingress.Ingress
ingressRules *ingress.Ingress
warpRouting *ingress.WarpRoutingService
tags []tunnelpogs.Tag
log *zerolog.Logger
@ -37,7 +37,7 @@ type Proxy struct {
// NewOriginProxy returns a new instance of the Proxy struct.
func NewOriginProxy(
ingressRules ingress.Ingress,
ingressRules *ingress.Ingress,
warpRouting *ingress.WarpRoutingService,
tags []tunnelpogs.Tag,
log *zerolog.Logger,
@ -139,7 +139,7 @@ func (p *Proxy) ProxyTCP(
return nil
}
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
func ruleField(ing *ingress.Ingress, ruleNum int) (ruleID string, srv string) {
srv = ing.Rules[ruleNum].Service.String()
if ing.IsSingleRule() {
return "", srv

View File

@ -1,7 +1,7 @@
//go:build !windows
// +build !windows
package origin
package proxy
import (
"io/ioutil"

View File

@ -1,4 +1,4 @@
package origin
package proxy
import (
"bytes"
@ -135,7 +135,7 @@ func TestProxySingleOrigin(t *testing.T) {
errC := make(chan error)
require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC))
proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, &log)
proxy := NewOriginProxy(&ingressRule, unusedWarpRoutingService, testTags, &log)
t.Run("testProxyHTTP", testProxyHTTP(proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
t.Run("testProxySSE", testProxySSE(proxy))
@ -345,7 +345,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
var wg sync.WaitGroup
require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC))
proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log)
proxy := NewOriginProxy(&ingress, unusedWarpRoutingService, testTags, &log)
for _, test := range tests {
responseWriter := newMockHTTPRespWriter()
@ -394,7 +394,7 @@ func TestProxyError(t *testing.T) {
log := zerolog.Nop()
proxy := NewOriginProxy(ing, unusedWarpRoutingService, testTags, &log)
proxy := NewOriginProxy(&ing, unusedWarpRoutingService, testTags, &log)
responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
@ -637,7 +637,7 @@ func TestConnections(t *testing.T) {
var wg sync.WaitGroup
errC := make(chan error)
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
proxy := NewOriginProxy(ingressRule, test.args.warpRoutingService, testTags, logger)
proxy := NewOriginProxy(&ingressRule, test.args.warpRoutingService, testTags, logger)
dest := ln.Addr().String()
req, err := http.NewRequest(

View File

@ -1,4 +1,4 @@
package origin
package supervisor
import (
"encoding/json"

View File

@ -1,4 +1,4 @@
package origin
package supervisor
import (
"testing"

View File

@ -0,0 +1,55 @@
package supervisor
import (
"sync"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/proxy"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
type configManager struct {
currentVersion int32
// Only used by UpdateConfig
updateLock sync.Mutex
// TODO: TUN-5698: Make proxy atomic.Value
proxy *proxy.Proxy
config *DynamicConfig
tags []tunnelpogs.Tag
log *zerolog.Logger
}
func newConfigManager(config *DynamicConfig, tags []tunnelpogs.Tag, log *zerolog.Logger) *configManager {
var warpRoutingService *ingress.WarpRoutingService
if config.WarpRoutingEnabled {
warpRoutingService = ingress.NewWarpRoutingService()
log.Info().Msgf("Warp-routing is enabled")
}
return &configManager{
// Lowest possible version, any remote configuration will have version higher than this
currentVersion: 0,
proxy: proxy.NewOriginProxy(config.Ingress, warpRoutingService, tags, log),
config: config,
log: log,
}
}
func (cm *configManager) Update(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
// TODO: TUN-5698: make ingress configurable
return &tunnelpogs.UpdateConfigurationResponse{
LastAppliedVersion: cm.currentVersion,
}
}
func (cm *configManager) GetOriginProxy() connection.OriginProxy {
return cm.proxy
}
type DynamicConfig struct {
Ingress *ingress.Ingress
WarpRoutingEnabled bool
}

View File

@ -1,4 +1,4 @@
package origin
package supervisor
import (
"github.com/rs/zerolog"

View File

@ -1,4 +1,4 @@
package origin
package supervisor
import (
"time"

27
supervisor/metrics.go Normal file
View File

@ -0,0 +1,27 @@
package supervisor
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/cloudflare/cloudflared/connection"
)
// Metrics uses connection.MetricsNamespace(aka cloudflared) as namespace and connection.TunnelSubsystem
// (tunnel) as subsystem to keep them consistent with the previous qualifier.
var (
haConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: connection.MetricsNamespace,
Subsystem: connection.TunnelSubsystem,
Name: "ha_connections",
Help: "Number of active ha connections",
},
)
)
func init() {
prometheus.MustRegister(
haConnections,
)
}

View File

@ -1,4 +1,4 @@
package origin
package supervisor
import (
"context"

View File

@ -1,4 +1,4 @@
package origin
package supervisor
import (
"context"

View File

@ -1,4 +1,4 @@
package origin
package supervisor
import (
"context"
@ -37,6 +37,7 @@ const (
// reconnects them if they disconnect.
type Supervisor struct {
cloudflaredUUID uuid.UUID
configManager *configManager
config *TunnelConfig
edgeIPs *edgediscovery.Edge
tunnelErrors chan tunnelError
@ -64,7 +65,7 @@ type tunnelError struct {
err error
}
func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
func NewSupervisor(config *TunnelConfig, dynamiConfig *DynamicConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
cloudflaredUUID, err := uuid.NewRandom()
if err != nil {
return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err)
@ -88,6 +89,7 @@ func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, grace
return &Supervisor{
cloudflaredUUID: cloudflaredUUID,
config: config,
configManager: newConfigManager(dynamiConfig, config.Tags, config.Log),
edgeIPs: edgeIPs,
tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{},
@ -242,6 +244,7 @@ func (s *Supervisor) startFirstTunnel(
err = ServeTunnelLoop(
ctx,
s.reconnectCredentialManager,
s.configManager,
s.config,
addr,
s.log,
@ -276,6 +279,7 @@ func (s *Supervisor) startFirstTunnel(
err = ServeTunnelLoop(
ctx,
s.reconnectCredentialManager,
s.configManager,
s.config,
addr,
s.log,
@ -310,6 +314,7 @@ func (s *Supervisor) startTunnel(
err = ServeTunnelLoop(
ctx,
s.reconnectCredentialManager,
s.configManager,
s.config,
addr,
s.log,
@ -380,7 +385,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
defer rpcClient.Close()
const arbitraryConnectionID = uint8(0)
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
registrationOptions := s.config.registrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions)
}

View File

@ -1,4 +1,4 @@
package origin
package supervisor
import (
"context"
@ -34,7 +34,8 @@ const (
)
type TunnelConfig struct {
ConnectionConfig *connection.Config
GracePeriod time.Duration
ReplaceExisting bool
OSArch string
ClientID string
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
@ -52,14 +53,14 @@ type TunnelConfig struct {
Retries uint
RunFromTerminal bool
NamedTunnel *connection.NamedTunnelConfig
ClassicTunnel *connection.ClassicTunnelConfig
NamedTunnel *connection.NamedTunnelProperties
ClassicTunnel *connection.ClassicTunnelProperties
MuxerConfig *connection.MuxerConfig
ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config
}
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
func (c *TunnelConfig) registrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
policy := tunnelrpc.ExistingTunnelPolicy_balance
if c.HAConnections <= 1 && c.LBPool == "" {
policy = tunnelrpc.ExistingTunnelPolicy_disconnect
@ -81,7 +82,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
}
}
func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAttempts uint8) *tunnelpogs.ConnectionOptions {
func (c *TunnelConfig) connectionOptions(originLocalAddr string, numPreviousAttempts uint8) *tunnelpogs.ConnectionOptions {
// attempt to parse out origin IP, but don't fail since it's informational field
host, _, _ := net.SplitHostPort(originLocalAddr)
originIP := net.ParseIP(host)
@ -89,7 +90,7 @@ func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAtte
return &tunnelpogs.ConnectionOptions{
Client: c.NamedTunnel.Client,
OriginLocalIP: originIP,
ReplaceExisting: c.ConnectionConfig.ReplaceExisting,
ReplaceExisting: c.ReplaceExisting,
CompressionQuality: uint8(c.MuxerConfig.CompressionSetting),
NumPreviousAttempts: numPreviousAttempts,
}
@ -106,11 +107,12 @@ func (c *TunnelConfig) SupportedFeatures() []string {
func StartTunnelDaemon(
ctx context.Context,
config *TunnelConfig,
dynamiConfig *DynamicConfig,
connectedSignal *signal.Signal,
reconnectCh chan ReconnectSignal,
graceShutdownC <-chan struct{},
) error {
s, err := NewSupervisor(config, reconnectCh, graceShutdownC)
s, err := NewSupervisor(config, dynamiConfig, reconnectCh, graceShutdownC)
if err != nil {
return err
}
@ -120,6 +122,7 @@ func StartTunnelDaemon(
func ServeTunnelLoop(
ctx context.Context,
credentialManager *reconnectCredentialManager,
configManager *configManager,
config *TunnelConfig,
addr *allregions.EdgeAddr,
connAwareLogger *ConnAwareLogger,
@ -155,6 +158,7 @@ func ServeTunnelLoop(
ctx,
connLog,
credentialManager,
configManager,
config,
addr,
connIndex,
@ -253,6 +257,7 @@ func ServeTunnel(
ctx context.Context,
connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager,
configManager *configManager,
config *TunnelConfig,
addr *allregions.EdgeAddr,
connIndex uint8,
@ -281,6 +286,7 @@ func ServeTunnel(
ctx,
connLog,
credentialManager,
configManager,
config,
addr,
connIndex,
@ -329,6 +335,7 @@ func serveTunnel(
ctx context.Context,
connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager,
configManager *configManager,
config *TunnelConfig,
addr *allregions.EdgeAddr,
connIndex uint8,
@ -339,7 +346,6 @@ func serveTunnel(
protocol connection.Protocol,
gracefulShutdownC <-chan struct{},
) (err error, recoverable bool) {
connectedFuse := &connectedFuse{
fuse: fuse,
backoff: backoff,
@ -351,14 +357,15 @@ func serveTunnel(
connIndex,
nil,
gracefulShutdownC,
config.ConnectionConfig.GracePeriod,
config.GracePeriod,
)
switch protocol {
case connection.QUIC, connection.QUICWarp:
connOptions := config.ConnectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
connOptions := config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
return ServeQUIC(ctx,
addr.UDP,
configManager,
config,
connLog,
connOptions,
@ -374,10 +381,11 @@ func serveTunnel(
return err, true
}
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
if err := ServeHTTP2(
ctx,
connLog,
configManager,
config,
edgeConn,
connOptions,
@ -400,6 +408,7 @@ func serveTunnel(
ctx,
connLog,
credentialManager,
configManager,
config,
edgeConn,
connIndex,
@ -426,6 +435,7 @@ func ServeH2mux(
ctx context.Context,
connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager,
configManager *configManager,
config *TunnelConfig,
edgeConn net.Conn,
connIndex uint8,
@ -437,7 +447,8 @@ func ServeH2mux(
connLog.Logger().Debug().Msgf("Connecting via h2mux")
// Returns error from parsing the origin URL or handshake errors
handler, err, recoverable := connection.NewH2muxConnection(
config.ConnectionConfig,
configManager,
config.GracePeriod,
config.MuxerConfig,
edgeConn,
connIndex,
@ -455,10 +466,10 @@ func ServeH2mux(
errGroup.Go(func() error {
if config.NamedTunnel != nil {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries()))
connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries()))
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse)
}
registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
registrationOptions := config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
})
@ -472,6 +483,7 @@ func ServeH2mux(
func ServeHTTP2(
ctx context.Context,
connLog *ConnAwareLogger,
configManager *configManager,
config *TunnelConfig,
tlsServerConn net.Conn,
connOptions *tunnelpogs.ConnectionOptions,
@ -483,7 +495,7 @@ func ServeHTTP2(
connLog.Logger().Debug().Msgf("Connecting via http2")
h2conn := connection.NewHTTP2Connection(
tlsServerConn,
config.ConnectionConfig,
configManager,
connOptions,
config.Observer,
connIndex,
@ -511,6 +523,7 @@ func ServeHTTP2(
func ServeQUIC(
ctx context.Context,
edgeAddr *net.UDPAddr,
configManager *configManager,
config *TunnelConfig,
connLogger *ConnAwareLogger,
connOptions *tunnelpogs.ConnectionOptions,
@ -535,7 +548,7 @@ func ServeQUIC(
quicConfig,
edgeAddr,
tlsConfig,
config.ConnectionConfig.OriginProxy,
configManager,
connOptions,
controlStreamHandler,
connLogger.Logger())

View File

@ -1,4 +1,4 @@
package origin
package supervisor
import (
"testing"
@ -32,11 +32,7 @@ func TestWaitForBackoffFallback(t *testing.T) {
}
log := zerolog.Nop()
resolveTTL := time.Duration(0)
namedTunnel := &connection.NamedTunnelConfig{
Credentials: connection.Credentials{
AccountTag: "test-account",
},
}
namedTunnel := &connection.NamedTunnelProperties{}
mockFetcher := dynamicMockFetcher{
protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}},
}

View File

@ -1,4 +1,4 @@
package origin
package supervisor
import (
"fmt"