TUN-5749: Refactor cloudflared to pave way for reconfigurable ingress

- Split origin into supervisor and proxy packages
- Create configManager to handle dynamic config
This commit is contained in:
cthuang 2022-02-07 09:42:07 +00:00
parent ff4cfeda0c
commit e22422aafb
33 changed files with 317 additions and 220 deletions

View File

@ -31,8 +31,8 @@ import (
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/metrics" "github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns" "github.com/cloudflare/cloudflared/tunneldns"
) )
@ -223,7 +223,7 @@ func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) {
func StartServer( func StartServer(
c *cli.Context, c *cli.Context,
info *cliutil.BuildInfo, info *cliutil.BuildInfo,
namedTunnel *connection.NamedTunnelConfig, namedTunnel *connection.NamedTunnelProperties,
log *zerolog.Logger, log *zerolog.Logger,
isUIEnabled bool, isUIEnabled bool,
) error { ) error {
@ -333,7 +333,7 @@ func StartServer(
observer.SendURL(quickTunnelURL) observer.SendURL(quickTunnelURL)
} }
tunnelConfig, ingressRules, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel) tunnelConfig, dynamicConfig, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel)
if err != nil { if err != nil {
log.Err(err).Msg("Couldn't start tunnel") log.Err(err).Msg("Couldn't start tunnel")
return err return err
@ -353,11 +353,11 @@ func StartServer(
errC <- metrics.ServeMetrics(metricsListener, ctx.Done(), readinessServer, quickTunnelURL, log) errC <- metrics.ServeMetrics(metricsListener, ctx.Done(), readinessServer, quickTunnelURL, log)
}() }()
if err := ingressRules.StartOrigins(&wg, log, ctx.Done(), errC); err != nil { if err := dynamicConfig.Ingress.StartOrigins(&wg, log, ctx.Done(), errC); err != nil {
return err return err
} }
reconnectCh := make(chan origin.ReconnectSignal, 1) reconnectCh := make(chan supervisor.ReconnectSignal, 1)
if c.IsSet("stdin-control") { if c.IsSet("stdin-control") {
log.Info().Msg("Enabling control through stdin") log.Info().Msg("Enabling control through stdin")
go stdinControl(reconnectCh, log) go stdinControl(reconnectCh, log)
@ -369,7 +369,7 @@ func StartServer(
wg.Done() wg.Done()
log.Info().Msg("Tunnel server stopped") log.Info().Msg("Tunnel server stopped")
}() }()
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, reconnectCh, graceShutdownC) errC <- supervisor.StartTunnelDaemon(ctx, tunnelConfig, dynamicConfig, connectedSignal, reconnectCh, graceShutdownC)
}() }()
if isUIEnabled { if isUIEnabled {
@ -377,7 +377,7 @@ func StartServer(
info.Version(), info.Version(),
hostname, hostname,
metricsListener.Addr().String(), metricsListener.Addr().String(),
&ingressRules, dynamicConfig.Ingress,
tunnelConfig.HAConnections, tunnelConfig.HAConnections,
) )
app := tunnelUI.Launch(ctx, log, logTransport) app := tunnelUI.Launch(ctx, log, logTransport)
@ -998,7 +998,7 @@ func configureProxyDNSFlags(shouldHide bool) []cli.Flag {
} }
} }
func stdinControl(reconnectCh chan origin.ReconnectSignal, log *zerolog.Logger) { func stdinControl(reconnectCh chan supervisor.ReconnectSignal, log *zerolog.Logger) {
for { for {
scanner := bufio.NewScanner(os.Stdin) scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() { for scanner.Scan() {
@ -1009,7 +1009,7 @@ func stdinControl(reconnectCh chan origin.ReconnectSignal, log *zerolog.Logger)
case "": case "":
break break
case "reconnect": case "reconnect":
var reconnect origin.ReconnectSignal var reconnect supervisor.ReconnectSignal
if len(parts) > 1 { if len(parts) > 1 {
var err error var err error
if reconnect.Delay, err = time.ParseDuration(parts[1]); err != nil { if reconnect.Delay, err = time.ParseDuration(parts[1]); err != nil {

View File

@ -23,7 +23,7 @@ import (
"github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tlsconfig"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation" "github.com/cloudflare/cloudflared/validation"
@ -87,7 +87,7 @@ func logClientOptions(c *cli.Context, log *zerolog.Logger) {
} }
} }
func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelConfig) bool { func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelProperties) bool {
return c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world") && namedTunnel == nil) return c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world") && namedTunnel == nil)
} }
@ -152,44 +152,44 @@ func prepareTunnelConfig(
info *cliutil.BuildInfo, info *cliutil.BuildInfo,
log, logTransport *zerolog.Logger, log, logTransport *zerolog.Logger,
observer *connection.Observer, observer *connection.Observer,
namedTunnel *connection.NamedTunnelConfig, namedTunnel *connection.NamedTunnelProperties,
) (*origin.TunnelConfig, ingress.Ingress, error) { ) (*supervisor.TunnelConfig, *supervisor.DynamicConfig, error) {
isNamedTunnel := namedTunnel != nil isNamedTunnel := namedTunnel != nil
configHostname := c.String("hostname") configHostname := c.String("hostname")
hostname, err := validation.ValidateHostname(configHostname) hostname, err := validation.ValidateHostname(configHostname)
if err != nil { if err != nil {
log.Err(err).Str(LogFieldHostname, configHostname).Msg("Invalid hostname") log.Err(err).Str(LogFieldHostname, configHostname).Msg("Invalid hostname")
return nil, ingress.Ingress{}, errors.Wrap(err, "Invalid hostname") return nil, nil, errors.Wrap(err, "Invalid hostname")
} }
clientID := c.String("id") clientID := c.String("id")
if !c.IsSet("id") { if !c.IsSet("id") {
clientID, err = generateRandomClientID(log) clientID, err = generateRandomClientID(log)
if err != nil { if err != nil {
return nil, ingress.Ingress{}, err return nil, nil, err
} }
} }
tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil { if err != nil {
log.Err(err).Msg("Tag parse failure") log.Err(err).Msg("Tag parse failure")
return nil, ingress.Ingress{}, errors.Wrap(err, "Tag parse failure") return nil, nil, errors.Wrap(err, "Tag parse failure")
} }
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
var ( var (
ingressRules ingress.Ingress ingressRules ingress.Ingress
classicTunnel *connection.ClassicTunnelConfig classicTunnel *connection.ClassicTunnelProperties
) )
cfg := config.GetConfiguration() cfg := config.GetConfiguration()
if isNamedTunnel { if isNamedTunnel {
clientUUID, err := uuid.NewRandom() clientUUID, err := uuid.NewRandom()
if err != nil { if err != nil {
return nil, ingress.Ingress{}, errors.Wrap(err, "can't generate connector UUID") return nil, nil, errors.Wrap(err, "can't generate connector UUID")
} }
log.Info().Msgf("Generated Connector ID: %s", clientUUID) log.Info().Msgf("Generated Connector ID: %s", clientUUID)
features := append(c.StringSlice("features"), origin.FeatureSerializedHeaders) features := append(c.StringSlice("features"), supervisor.FeatureSerializedHeaders)
namedTunnel.Client = tunnelpogs.ClientInfo{ namedTunnel.Client = tunnelpogs.ClientInfo{
ClientID: clientUUID[:], ClientID: clientUUID[:],
Features: dedup(features), Features: dedup(features),
@ -198,10 +198,10 @@ func prepareTunnelConfig(
} }
ingressRules, err = ingress.ParseIngress(cfg) ingressRules, err = ingress.ParseIngress(cfg)
if err != nil && err != ingress.ErrNoIngressRules { if err != nil && err != ingress.ErrNoIngressRules {
return nil, ingress.Ingress{}, err return nil, nil, err
} }
if !ingressRules.IsEmpty() && c.IsSet("url") { if !ingressRules.IsEmpty() && c.IsSet("url") {
return nil, ingress.Ingress{}, ingress.ErrURLIncompatibleWithIngress return nil, nil, ingress.ErrURLIncompatibleWithIngress
} }
} else { } else {
@ -212,10 +212,10 @@ func prepareTunnelConfig(
originCert, err := getOriginCert(originCertPath, &originCertLog) originCert, err := getOriginCert(originCertPath, &originCertLog)
if err != nil { if err != nil {
return nil, ingress.Ingress{}, errors.Wrap(err, "Error getting origin cert") return nil, nil, errors.Wrap(err, "Error getting origin cert")
} }
classicTunnel = &connection.ClassicTunnelConfig{ classicTunnel = &connection.ClassicTunnelProperties{
Hostname: hostname, Hostname: hostname,
OriginCert: originCert, OriginCert: originCert,
// turn off use of reconnect token and auth refresh when using named tunnels // turn off use of reconnect token and auth refresh when using named tunnels
@ -227,20 +227,14 @@ func prepareTunnelConfig(
if ingressRules.IsEmpty() { if ingressRules.IsEmpty() {
ingressRules, err = ingress.NewSingleOrigin(c, !isNamedTunnel) ingressRules, err = ingress.NewSingleOrigin(c, !isNamedTunnel)
if err != nil { if err != nil {
return nil, ingress.Ingress{}, err return nil, nil, err
} }
} }
var warpRoutingService *ingress.WarpRoutingService
warpRoutingEnabled := isWarpRoutingEnabled(cfg.WarpRouting, isNamedTunnel) warpRoutingEnabled := isWarpRoutingEnabled(cfg.WarpRouting, isNamedTunnel)
if warpRoutingEnabled { protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, supervisor.ResolveTTL, log)
warpRoutingService = ingress.NewWarpRoutingService()
log.Info().Msgf("Warp-routing is enabled")
}
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, origin.ResolveTTL, log)
if err != nil { if err != nil {
return nil, ingress.Ingress{}, err return nil, nil, err
} }
log.Info().Msgf("Initial protocol %s", protocolSelector.Current()) log.Info().Msgf("Initial protocol %s", protocolSelector.Current())
@ -248,11 +242,11 @@ func prepareTunnelConfig(
for _, p := range connection.ProtocolList { for _, p := range connection.ProtocolList {
tlsSettings := p.TLSSettings() tlsSettings := p.TLSSettings()
if tlsSettings == nil { if tlsSettings == nil {
return nil, ingress.Ingress{}, fmt.Errorf("%s has unknown TLS settings", p) return nil, nil, fmt.Errorf("%s has unknown TLS settings", p)
} }
edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, tlsSettings.ServerName) edgeTLSConfig, err := tlsconfig.CreateTunnelConfig(c, tlsSettings.ServerName)
if err != nil { if err != nil {
return nil, ingress.Ingress{}, errors.Wrap(err, "unable to create TLS config to connect with edge") return nil, nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
} }
if len(tlsSettings.NextProtos) > 0 { if len(tlsSettings.NextProtos) > 0 {
edgeTLSConfig.NextProtos = tlsSettings.NextProtos edgeTLSConfig.NextProtos = tlsSettings.NextProtos
@ -260,15 +254,9 @@ func prepareTunnelConfig(
edgeTLSConfigs[p] = edgeTLSConfig edgeTLSConfigs[p] = edgeTLSConfig
} }
originProxy := origin.NewOriginProxy(ingressRules, warpRoutingService, tags, log)
gracePeriod, err := gracePeriod(c) gracePeriod, err := gracePeriod(c)
if err != nil { if err != nil {
return nil, ingress.Ingress{}, err return nil, nil, err
}
connectionConfig := &connection.Config{
OriginProxy: originProxy,
GracePeriod: gracePeriod,
ReplaceExisting: c.Bool("force"),
} }
muxerConfig := &connection.MuxerConfig{ muxerConfig := &connection.MuxerConfig{
HeartbeatInterval: c.Duration("heartbeat-interval"), HeartbeatInterval: c.Duration("heartbeat-interval"),
@ -279,14 +267,15 @@ func prepareTunnelConfig(
MetricsUpdateFreq: c.Duration("metrics-update-freq"), MetricsUpdateFreq: c.Duration("metrics-update-freq"),
} }
return &origin.TunnelConfig{ tunnelConfig := &supervisor.TunnelConfig{
ConnectionConfig: connectionConfig, GracePeriod: gracePeriod,
ReplaceExisting: c.Bool("force"),
OSArch: info.OSArch(), OSArch: info.OSArch(),
ClientID: clientID, ClientID: clientID,
EdgeAddrs: c.StringSlice("edge"), EdgeAddrs: c.StringSlice("edge"),
Region: c.String("region"), Region: c.String("region"),
HAConnections: c.Int("ha-connections"), HAConnections: c.Int("ha-connections"),
IncidentLookup: origin.NewIncidentLookup(), IncidentLookup: supervisor.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"), IsAutoupdated: c.Bool("is-autoupdated"),
LBPool: c.String("lb-pool"), LBPool: c.String("lb-pool"),
Tags: tags, Tags: tags,
@ -302,7 +291,12 @@ func prepareTunnelConfig(
MuxerConfig: muxerConfig, MuxerConfig: muxerConfig,
ProtocolSelector: protocolSelector, ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs, EdgeTLSConfigs: edgeTLSConfigs,
}, ingressRules, nil }
dynamicConfig := &supervisor.DynamicConfig{
Ingress: &ingressRules,
WarpRoutingEnabled: warpRoutingEnabled,
}
return tunnelConfig, dynamicConfig, nil
} }
func gracePeriod(c *cli.Context) (time.Duration, error) { func gracePeriod(c *cli.Context) (time.Duration, error) {

View File

@ -77,7 +77,7 @@ func RunQuickTunnel(sc *subcommandContext) error {
return StartServer( return StartServer(
sc.c, sc.c,
buildInfo, buildInfo,
&connection.NamedTunnelConfig{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname}, &connection.NamedTunnelProperties{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname},
sc.log, sc.log,
sc.isUIEnabled, sc.isUIEnabled,
) )

View File

@ -304,7 +304,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
return StartServer( return StartServer(
sc.c, sc.c,
buildInfo, buildInfo,
&connection.NamedTunnelConfig{Credentials: credentials}, &connection.NamedTunnelProperties{Credentials: credentials},
sc.log, sc.log,
sc.isUIEnabled, sc.isUIEnabled,
) )

View File

@ -25,13 +25,12 @@ const (
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
type Config struct { type ConfigManager interface {
OriginProxy OriginProxy Update(version int32, config []byte) *pogs.UpdateConfigurationResponse
GracePeriod time.Duration GetOriginProxy() OriginProxy
ReplaceExisting bool
} }
type NamedTunnelConfig struct { type NamedTunnelProperties struct {
Credentials Credentials Credentials Credentials
Client pogs.ClientInfo Client pogs.ClientInfo
QuickTunnelUrl string QuickTunnelUrl string
@ -52,7 +51,7 @@ func (c *Credentials) Auth() pogs.TunnelAuth {
} }
} }
type ClassicTunnelConfig struct { type ClassicTunnelProperties struct {
Hostname string Hostname string
OriginCert []byte OriginCert []byte
// feature-flag to use new edge reconnect tokens // feature-flag to use new edge reconnect tokens

View File

@ -14,18 +14,19 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
) )
const ( const (
largeFileSize = 2 * 1024 * 1024 largeFileSize = 2 * 1024 * 1024
testGracePeriod = time.Millisecond * 100
) )
var ( var (
unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil) unusedWarpRoutingService = (*ingress.WarpRoutingService)(nil)
testConfig = &Config{ testConfigManager = &mockConfigManager{
OriginProxy: &mockOriginProxy{}, originProxy: &mockOriginProxy{},
GracePeriod: time.Millisecond * 100,
} }
log = zerolog.Nop() log = zerolog.Nop()
testOriginURL = &url.URL{ testOriginURL = &url.URL{
@ -43,6 +44,20 @@ type testRequest struct {
isProxyError bool isProxyError bool
} }
type mockConfigManager struct {
originProxy OriginProxy
}
func (*mockConfigManager) Update(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
return &tunnelpogs.UpdateConfigurationResponse{
LastAppliedVersion: version,
}
}
func (mcr *mockConfigManager) GetOriginProxy() OriginProxy {
return mcr.originProxy
}
type mockOriginProxy struct{} type mockOriginProxy struct{}
func (moc *mockOriginProxy) ProxyHTTP( func (moc *mockOriginProxy) ProxyHTTP(

View File

@ -17,7 +17,7 @@ type controlStream struct {
observer *Observer observer *Observer
connectedFuse ConnectedFuse connectedFuse ConnectedFuse
namedTunnelConfig *NamedTunnelConfig namedTunnelProperties *NamedTunnelProperties
connIndex uint8 connIndex uint8
newRPCClientFunc RPCClientFunc newRPCClientFunc RPCClientFunc
@ -39,7 +39,7 @@ type ControlStreamHandler interface {
func NewControlStream( func NewControlStream(
observer *Observer, observer *Observer,
connectedFuse ConnectedFuse, connectedFuse ConnectedFuse,
namedTunnelConfig *NamedTunnelConfig, namedTunnelConfig *NamedTunnelProperties,
connIndex uint8, connIndex uint8,
newRPCClientFunc RPCClientFunc, newRPCClientFunc RPCClientFunc,
gracefulShutdownC <-chan struct{}, gracefulShutdownC <-chan struct{},
@ -51,7 +51,7 @@ func NewControlStream(
return &controlStream{ return &controlStream{
observer: observer, observer: observer,
connectedFuse: connectedFuse, connectedFuse: connectedFuse,
namedTunnelConfig: namedTunnelConfig, namedTunnelProperties: namedTunnelConfig,
newRPCClientFunc: newRPCClientFunc, newRPCClientFunc: newRPCClientFunc,
connIndex: connIndex, connIndex: connIndex,
gracefulShutdownC: gracefulShutdownC, gracefulShutdownC: gracefulShutdownC,
@ -66,7 +66,7 @@ func (c *controlStream) ServeControlStream(
) error { ) error {
rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log) rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log)
if err := rpcClient.RegisterConnection(ctx, c.namedTunnelConfig, connOptions, c.connIndex, c.observer); err != nil { if err := rpcClient.RegisterConnection(ctx, c.namedTunnelProperties, connOptions, c.connIndex, c.observer); err != nil {
rpcClient.Close() rpcClient.Close()
return err return err
} }

View File

@ -22,7 +22,8 @@ const (
) )
type h2muxConnection struct { type h2muxConnection struct {
config *Config configManager ConfigManager
gracePeriod time.Duration
muxerConfig *MuxerConfig muxerConfig *MuxerConfig
muxer *h2mux.Muxer muxer *h2mux.Muxer
// connectionID is only used by metrics, and prometheus requires labels to be string // connectionID is only used by metrics, and prometheus requires labels to be string
@ -60,7 +61,8 @@ func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, log *zerolog.Lo
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
func NewH2muxConnection( func NewH2muxConnection(
config *Config, configManager ConfigManager,
gracePeriod time.Duration,
muxerConfig *MuxerConfig, muxerConfig *MuxerConfig,
edgeConn net.Conn, edgeConn net.Conn,
connIndex uint8, connIndex uint8,
@ -68,7 +70,8 @@ func NewH2muxConnection(
gracefulShutdownC <-chan struct{}, gracefulShutdownC <-chan struct{},
) (*h2muxConnection, error, bool) { ) (*h2muxConnection, error, bool) {
h := &h2muxConnection{ h := &h2muxConnection{
config: config, configManager: configManager,
gracePeriod: gracePeriod,
muxerConfig: muxerConfig, muxerConfig: muxerConfig,
connIndexStr: uint8ToString(connIndex), connIndexStr: uint8ToString(connIndex),
connIndex: connIndex, connIndex: connIndex,
@ -88,7 +91,7 @@ func NewH2muxConnection(
return h, nil, false return h, nil, false
} }
func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error { func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *NamedTunnelProperties, connOptions *tunnelpogs.ConnectionOptions, connectedFuse ConnectedFuse) error {
errGroup, serveCtx := errgroup.WithContext(ctx) errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
return h.serveMuxer(serveCtx) return h.serveMuxer(serveCtx)
@ -117,7 +120,7 @@ func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *Nam
return err return err
} }
func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error { func (h *h2muxConnection) ServeClassicTunnel(ctx context.Context, classicTunnel *ClassicTunnelProperties, credentialManager CredentialManager, registrationOptions *tunnelpogs.RegistrationOptions, connectedFuse ConnectedFuse) error {
errGroup, serveCtx := errgroup.WithContext(ctx) errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
return h.serveMuxer(serveCtx) return h.serveMuxer(serveCtx)
@ -224,7 +227,7 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
sourceConnectionType = TypeWebsocket sourceConnectionType = TypeWebsocket
} }
err := h.config.OriginProxy.ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket) err := h.configManager.GetOriginProxy().ProxyHTTP(respWriter, req, sourceConnectionType == TypeWebsocket)
if err != nil { if err != nil {
respWriter.WriteErrorResponse() respWriter.WriteErrorResponse()
} }

View File

@ -48,7 +48,7 @@ func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) {
}() }()
var connIndex = uint8(0) var connIndex = uint8(0)
testObserver := NewObserver(&log, &log, false) testObserver := NewObserver(&log, &log, false)
h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver, nil) h2muxConn, err, _ := NewH2muxConnection(testConfigManager, testGracePeriod, testMuxerConfig, originConn, connIndex, testObserver, nil)
require.NoError(t, err) require.NoError(t, err)
return h2muxConn, <-edgeMuxChan return h2muxConn, <-edgeMuxChan
} }

View File

@ -32,7 +32,7 @@ var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
type HTTP2Connection struct { type HTTP2Connection struct {
conn net.Conn conn net.Conn
server *http2.Server server *http2.Server
config *Config configManager ConfigManager
connOptions *tunnelpogs.ConnectionOptions connOptions *tunnelpogs.ConnectionOptions
observer *Observer observer *Observer
connIndex uint8 connIndex uint8
@ -49,7 +49,7 @@ type HTTP2Connection struct {
// NewHTTP2Connection returns a new instance of HTTP2Connection. // NewHTTP2Connection returns a new instance of HTTP2Connection.
func NewHTTP2Connection( func NewHTTP2Connection(
conn net.Conn, conn net.Conn,
config *Config, configManager ConfigManager,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
observer *Observer, observer *Observer,
connIndex uint8, connIndex uint8,
@ -61,7 +61,7 @@ func NewHTTP2Connection(
server: &http2.Server{ server: &http2.Server{
MaxConcurrentStreams: MaxConcurrentStreams, MaxConcurrentStreams: MaxConcurrentStreams,
}, },
config: config, configManager: configManager,
connOptions: connOptions, connOptions: connOptions,
observer: observer, observer: observer,
connIndex: connIndex, connIndex: connIndex,
@ -116,7 +116,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case TypeWebsocket, TypeHTTP: case TypeWebsocket, TypeHTTP:
stripWebsocketUpgradeHeader(r) stripWebsocketUpgradeHeader(r)
if err := c.config.OriginProxy.ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil { if err := c.configManager.GetOriginProxy().ProxyHTTP(respWriter, r, connType == TypeWebsocket); err != nil {
err := fmt.Errorf("Failed to proxy HTTP: %w", err) err := fmt.Errorf("Failed to proxy HTTP: %w", err)
c.log.Error().Err(err) c.log.Error().Err(err)
respWriter.WriteErrorResponse() respWriter.WriteErrorResponse()
@ -131,7 +131,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
rws := NewHTTPResponseReadWriterAcker(respWriter, r) rws := NewHTTPResponseReadWriterAcker(respWriter, r)
if err := c.config.OriginProxy.ProxyTCP(r.Context(), rws, &TCPRequest{ if err := c.configManager.GetOriginProxy().ProxyTCP(r.Context(), rws, &TCPRequest{
Dest: host, Dest: host,
CFRay: FindCfRayHeader(r), CFRay: FindCfRayHeader(r),
LBProbe: IsLBProbeRequest(r), LBProbe: IsLBProbeRequest(r),

View File

@ -35,7 +35,7 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
controlStream := NewControlStream( controlStream := NewControlStream(
obs, obs,
mockConnectedFuse{}, mockConnectedFuse{},
&NamedTunnelConfig{}, &NamedTunnelProperties{},
connIndex, connIndex,
nil, nil,
nil, nil,
@ -43,8 +43,8 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
) )
return NewHTTP2Connection( return NewHTTP2Connection(
cfdConn, cfdConn,
// OriginProxy is set in testConfig // OriginProxy is set in testConfigManager
testConfig, testConfigManager,
&pogs.ConnectionOptions{}, &pogs.ConnectionOptions{},
obs, obs,
connIndex, connIndex,
@ -132,7 +132,7 @@ type mockNamedTunnelRPCClient struct {
func (mc mockNamedTunnelRPCClient) RegisterConnection( func (mc mockNamedTunnelRPCClient) RegisterConnection(
c context.Context, c context.Context,
config *NamedTunnelConfig, properties *NamedTunnelProperties,
options *tunnelpogs.ConnectionOptions, options *tunnelpogs.ConnectionOptions,
connIndex uint8, connIndex uint8,
observer *Observer, observer *Observer,
@ -313,7 +313,7 @@ func TestServeControlStream(t *testing.T) {
controlStream := NewControlStream( controlStream := NewControlStream(
obs, obs,
mockConnectedFuse{}, mockConnectedFuse{},
&NamedTunnelConfig{}, &NamedTunnelProperties{},
1, 1,
rpcClientFactory.newMockRPCClient, rpcClientFactory.newMockRPCClient,
nil, nil,
@ -363,7 +363,7 @@ func TestFailRegistration(t *testing.T) {
controlStream := NewControlStream( controlStream := NewControlStream(
obs, obs,
mockConnectedFuse{}, mockConnectedFuse{},
&NamedTunnelConfig{}, &NamedTunnelProperties{},
http2Conn.connIndex, http2Conn.connIndex,
rpcClientFactory.newMockRPCClient, rpcClientFactory.newMockRPCClient,
nil, nil,
@ -409,7 +409,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
controlStream := NewControlStream( controlStream := NewControlStream(
obs, obs,
mockConnectedFuse{}, mockConnectedFuse{},
&NamedTunnelConfig{}, &NamedTunnelProperties{},
http2Conn.connIndex, http2Conn.connIndex,
rpcClientFactory.newMockRPCClient, rpcClientFactory.newMockRPCClient,
shutdownC, shutdownC,

View File

@ -195,7 +195,7 @@ type PercentageFetcher func() (edgediscovery.ProtocolPercents, error)
func NewProtocolSelector( func NewProtocolSelector(
protocolFlag string, protocolFlag string,
warpRoutingEnabled bool, warpRoutingEnabled bool,
namedTunnel *NamedTunnelConfig, namedTunnel *NamedTunnelProperties,
fetchFunc PercentageFetcher, fetchFunc PercentageFetcher,
ttl time.Duration, ttl time.Duration,
log *zerolog.Logger, log *zerolog.Logger,

View File

@ -16,7 +16,7 @@ const (
) )
var ( var (
testNamedTunnelConfig = &NamedTunnelConfig{ testNamedTunnelProperties = &NamedTunnelProperties{
Credentials: Credentials{ Credentials: Credentials{
AccountTag: "testAccountTag", AccountTag: "testAccountTag",
}, },
@ -51,7 +51,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback bool hasFallback bool
expectedFallback Protocol expectedFallback Protocol
warpRoutingEnabled bool warpRoutingEnabled bool
namedTunnelConfig *NamedTunnelConfig namedTunnelConfig *NamedTunnelProperties
fetchFunc PercentageFetcher fetchFunc PercentageFetcher
wantErr bool wantErr bool
}{ }{
@ -66,35 +66,35 @@ func TestNewProtocolSelector(t *testing.T) {
protocol: "h2mux", protocol: "h2mux",
expectedProtocol: H2mux, expectedProtocol: H2mux,
fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil }, fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil },
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "named tunnel over http2", name: "named tunnel over http2",
protocol: "http2", protocol: "http2",
expectedProtocol: HTTP2, expectedProtocol: HTTP2,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "named tunnel http2 disabled still gets http2 because it is manually picked", name: "named tunnel http2 disabled still gets http2 because it is manually picked",
protocol: "http2", protocol: "http2",
expectedProtocol: HTTP2, expectedProtocol: HTTP2,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "named tunnel quic disabled still gets quic because it is manually picked", name: "named tunnel quic disabled still gets quic because it is manually picked",
protocol: "quic", protocol: "quic",
expectedProtocol: QUIC, expectedProtocol: QUIC,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "named tunnel quic and http2 disabled", name: "named tunnel quic and http2 disabled",
protocol: "auto", protocol: "auto",
expectedProtocol: H2mux, expectedProtocol: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "named tunnel quic disabled", name: "named tunnel quic disabled",
@ -104,21 +104,21 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: true, hasFallback: true,
expectedFallback: H2mux, expectedFallback: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "named tunnel auto all http2 disabled", name: "named tunnel auto all http2 disabled",
protocol: "auto", protocol: "auto",
expectedProtocol: H2mux, expectedProtocol: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "named tunnel auto to h2mux", name: "named tunnel auto to h2mux",
protocol: "auto", protocol: "auto",
expectedProtocol: H2mux, expectedProtocol: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "named tunnel auto to http2", name: "named tunnel auto to http2",
@ -127,7 +127,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: true, hasFallback: true,
expectedFallback: H2mux, expectedFallback: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "named tunnel auto to quic", name: "named tunnel auto to quic",
@ -136,7 +136,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: true, hasFallback: true,
expectedFallback: HTTP2, expectedFallback: HTTP2,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "warp routing requesting h2mux", name: "warp routing requesting h2mux",
@ -145,7 +145,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: false, hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
warpRoutingEnabled: true, warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1", name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1",
@ -154,7 +154,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: false, hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
warpRoutingEnabled: true, warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "warp routing http2", name: "warp routing http2",
@ -163,7 +163,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: false, hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
warpRoutingEnabled: true, warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "warp routing quic", name: "warp routing quic",
@ -173,7 +173,7 @@ func TestNewProtocolSelector(t *testing.T) {
expectedFallback: HTTP2Warp, expectedFallback: HTTP2Warp,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
warpRoutingEnabled: true, warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "warp routing auto", name: "warp routing auto",
@ -182,7 +182,7 @@ func TestNewProtocolSelector(t *testing.T) {
hasFallback: false, hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
warpRoutingEnabled: true, warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
name: "warp routing auto- quic", name: "warp routing auto- quic",
@ -192,7 +192,7 @@ func TestNewProtocolSelector(t *testing.T) {
expectedFallback: HTTP2Warp, expectedFallback: HTTP2Warp,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
warpRoutingEnabled: true, warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
}, },
{ {
// None named tunnel can only use h2mux, so specifying an unknown protocol is not an error // None named tunnel can only use h2mux, so specifying an unknown protocol is not an error
@ -204,14 +204,14 @@ func TestNewProtocolSelector(t *testing.T) {
name: "named tunnel unknown protocol", name: "named tunnel unknown protocol",
protocol: "unknown", protocol: "unknown",
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
wantErr: true, wantErr: true,
}, },
{ {
name: "named tunnel fetch error", name: "named tunnel fetch error",
protocol: "auto", protocol: "auto",
fetchFunc: mockFetcher(true), fetchFunc: mockFetcher(true),
namedTunnelConfig: testNamedTunnelConfig, namedTunnelConfig: testNamedTunnelProperties,
expectedProtocol: HTTP2, expectedProtocol: HTTP2,
wantErr: false, wantErr: false,
}, },
@ -237,7 +237,7 @@ func TestNewProtocolSelector(t *testing.T) {
func TestAutoProtocolSelectorRefresh(t *testing.T) { func TestAutoProtocolSelectorRefresh(t *testing.T) {
fetcher := dynamicMockFetcher{} fetcher := dynamicMockFetcher{}
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log) selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, H2mux, selector.Current()) assert.Equal(t, H2mux, selector.Current())
@ -267,7 +267,7 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) {
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
fetcher := dynamicMockFetcher{} fetcher := dynamicMockFetcher{}
// Since the user chooses http2 on purpose, we always stick to it. // Since the user chooses http2 on purpose, we always stick to it.
selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log) selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current()) assert.Equal(t, HTTP2, selector.Current())
@ -297,7 +297,7 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
func TestProtocolSelectorRefreshTTL(t *testing.T) { func TestProtocolSelectorRefreshTTL(t *testing.T) {
fetcher := dynamicMockFetcher{} fetcher := dynamicMockFetcher{}
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log) selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), time.Hour, &log)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, QUIC, selector.Current()) assert.Equal(t, QUIC, selector.Current())

View File

@ -36,7 +36,7 @@ const (
type QUICConnection struct { type QUICConnection struct {
session quic.Session session quic.Session
logger *zerolog.Logger logger *zerolog.Logger
httpProxy OriginProxy configManager ConfigManager
sessionManager datagramsession.Manager sessionManager datagramsession.Manager
controlStreamHandler ControlStreamHandler controlStreamHandler ControlStreamHandler
connOptions *tunnelpogs.ConnectionOptions connOptions *tunnelpogs.ConnectionOptions
@ -47,7 +47,7 @@ func NewQUICConnection(
quicConfig *quic.Config, quicConfig *quic.Config,
edgeAddr net.Addr, edgeAddr net.Addr,
tlsConfig *tls.Config, tlsConfig *tls.Config,
httpProxy OriginProxy, configManager ConfigManager,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler ControlStreamHandler, controlStreamHandler ControlStreamHandler,
logger *zerolog.Logger, logger *zerolog.Logger,
@ -66,7 +66,7 @@ func NewQUICConnection(
return &QUICConnection{ return &QUICConnection{
session: session, session: session,
httpProxy: httpProxy, configManager: configManager,
logger: logger, logger: logger,
sessionManager: sessionManager, sessionManager: sessionManager,
controlStreamHandler: controlStreamHandler, controlStreamHandler: controlStreamHandler,
@ -183,10 +183,10 @@ func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream)
} }
w := newHTTPResponseAdapter(stream) w := newHTTPResponseAdapter(stream)
return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket) return q.configManager.GetOriginProxy().ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket)
case quicpogs.ConnectionTypeTCP: case quicpogs.ConnectionTypeTCP:
rwa := &streamReadWriteAcker{stream} rwa := &streamReadWriteAcker{stream}
return q.httpProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest}) return q.configManager.GetOriginProxy().ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest})
} }
return nil return nil
} }

View File

@ -627,13 +627,12 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection
NextProtos: []string{"argotunnel"}, NextProtos: []string{"argotunnel"},
} }
// Start a mock httpProxy // Start a mock httpProxy
originProxy := &mockOriginProxyWithRequest{}
log := zerolog.New(os.Stdout) log := zerolog.New(os.Stdout)
qc, err := NewQUICConnection( qc, err := NewQUICConnection(
testQUICConfig, testQUICConfig,
udpListenerAddr, udpListenerAddr,
tlsClientConfig, tlsClientConfig,
originProxy, &mockConfigManager{originProxy: &mockOriginProxyWithRequest{}},
&tunnelpogs.ConnectionOptions{}, &tunnelpogs.ConnectionOptions{},
fakeControlStream{}, fakeControlStream{},
&log, &log,

View File

@ -37,7 +37,7 @@ func NewTunnelServerClient(
} }
} }
func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) { func (tsc *tunnelServerClient) Authenticate(ctx context.Context, classicTunnel *ClassicTunnelProperties, registrationOptions *tunnelpogs.RegistrationOptions) (tunnelpogs.AuthOutcome, error) {
authResp, err := tsc.client.Authenticate(ctx, classicTunnel.OriginCert, classicTunnel.Hostname, registrationOptions) authResp, err := tsc.client.Authenticate(ctx, classicTunnel.OriginCert, classicTunnel.Hostname, registrationOptions)
if err != nil { if err != nil {
return nil, err return nil, err
@ -54,7 +54,7 @@ func (tsc *tunnelServerClient) Close() {
type NamedTunnelRPCClient interface { type NamedTunnelRPCClient interface {
RegisterConnection( RegisterConnection(
c context.Context, c context.Context,
config *NamedTunnelConfig, config *NamedTunnelProperties,
options *tunnelpogs.ConnectionOptions, options *tunnelpogs.ConnectionOptions,
connIndex uint8, connIndex uint8,
observer *Observer, observer *Observer,
@ -86,15 +86,15 @@ func newRegistrationRPCClient(
func (rsc *registrationServerClient) RegisterConnection( func (rsc *registrationServerClient) RegisterConnection(
ctx context.Context, ctx context.Context,
config *NamedTunnelConfig, properties *NamedTunnelProperties,
options *tunnelpogs.ConnectionOptions, options *tunnelpogs.ConnectionOptions,
connIndex uint8, connIndex uint8,
observer *Observer, observer *Observer,
) error { ) error {
conn, err := rsc.client.RegisterConnection( conn, err := rsc.client.RegisterConnection(
ctx, ctx,
config.Credentials.Auth(), properties.Credentials.Auth(),
config.Credentials.TunnelID, properties.Credentials.TunnelID,
connIndex, connIndex,
options, options,
) )
@ -137,7 +137,7 @@ const (
authenticate rpcName = " authenticate" authenticate rpcName = " authenticate"
) )
func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error { func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelProperties, registrationOptions *tunnelpogs.RegistrationOptions) error {
h.observer.sendRegisteringEvent(registrationOptions.ConnectionID) h.observer.sendRegisteringEvent(registrationOptions.ConnectionID)
stream, err := h.newRPCStream(ctx, register) stream, err := h.newRPCStream(ctx, register)
@ -174,7 +174,7 @@ type CredentialManager interface {
func (h *h2muxConnection) processRegistrationSuccess( func (h *h2muxConnection) processRegistrationSuccess(
registration *tunnelpogs.TunnelRegistration, registration *tunnelpogs.TunnelRegistration,
name rpcName, name rpcName,
credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, credentialManager CredentialManager, classicTunnel *ClassicTunnelProperties,
) error { ) error {
for _, logLine := range registration.LogLines { for _, logLine := range registration.LogLines {
h.observer.log.Info().Msg(logLine) h.observer.log.Info().Msg(logLine)
@ -205,7 +205,7 @@ func (h *h2muxConnection) processRegisterTunnelError(err tunnelpogs.TunnelRegist
} }
} }
func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error { func (h *h2muxConnection) reconnectTunnel(ctx context.Context, credentialManager CredentialManager, classicTunnel *ClassicTunnelProperties, registrationOptions *tunnelpogs.RegistrationOptions) error {
token, err := credentialManager.ReconnectToken() token, err := credentialManager.ReconnectToken()
if err != nil { if err != nil {
return err return err
@ -264,7 +264,7 @@ func (h *h2muxConnection) logServerInfo(ctx context.Context, rpcClient *tunnelSe
func (h *h2muxConnection) registerNamedTunnel( func (h *h2muxConnection) registerNamedTunnel(
ctx context.Context, ctx context.Context,
namedTunnel *NamedTunnelConfig, namedTunnel *NamedTunnelProperties,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
) error { ) error {
stream, err := h.newRPCStream(ctx, register) stream, err := h.newRPCStream(ctx, register)
@ -283,7 +283,7 @@ func (h *h2muxConnection) registerNamedTunnel(
func (h *h2muxConnection) unregister(isNamedTunnel bool) { func (h *h2muxConnection) unregister(isNamedTunnel bool) {
h.observer.sendUnregisteringEvent(h.connIndex) h.observer.sendUnregisteringEvent(h.connIndex)
unregisterCtx, cancel := context.WithTimeout(context.Background(), h.config.GracePeriod) unregisterCtx, cancel := context.WithTimeout(context.Background(), h.gracePeriod)
defer cancel() defer cancel()
stream, err := h.newRPCStream(unregisterCtx, unregister) stream, err := h.newRPCStream(unregisterCtx, unregister)
@ -296,13 +296,13 @@ func (h *h2muxConnection) unregister(isNamedTunnel bool) {
rpcClient := h.newRPCClientFunc(unregisterCtx, stream, h.observer.log) rpcClient := h.newRPCClientFunc(unregisterCtx, stream, h.observer.log)
defer rpcClient.Close() defer rpcClient.Close()
rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod) rpcClient.GracefulShutdown(unregisterCtx, h.gracePeriod)
} else { } else {
rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer.log) rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer.log)
defer rpcClient.Close() defer rpcClient.Close()
// gracePeriod is encoded in int64 using capnproto // gracePeriod is encoded in int64 using capnproto
_ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.config.GracePeriod.Nanoseconds()) _ = rpcClient.client.UnregisterTunnel(unregisterCtx, h.gracePeriod.Nanoseconds())
} }
h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unregistered tunnel connection") h.observer.log.Info().Uint8(LogFieldConnIndex, h.connIndex).Msg("Unregistered tunnel connection")

View File

@ -1,4 +1,4 @@
package origin package proxy
import ( import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -43,14 +43,6 @@ var (
Help: "Count of error proxying to origin", Help: "Count of error proxying to origin",
}, },
) )
haConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: connection.MetricsNamespace,
Subsystem: connection.TunnelSubsystem,
Name: "ha_connections",
Help: "Number of active ha connections",
},
)
) )
func init() { func init() {
@ -59,7 +51,6 @@ func init() {
concurrentRequests, concurrentRequests,
responseByCode, responseByCode,
requestErrors, requestErrors,
haConnections,
) )
} }

View File

@ -1,4 +1,4 @@
package origin package proxy
import ( import (
"sync" "sync"

View File

@ -1,4 +1,4 @@
package origin package proxy
import ( import (
"bufio" "bufio"
@ -28,7 +28,7 @@ const (
// Proxy represents a means to Proxy between cloudflared and the origin services. // Proxy represents a means to Proxy between cloudflared and the origin services.
type Proxy struct { type Proxy struct {
ingressRules ingress.Ingress ingressRules *ingress.Ingress
warpRouting *ingress.WarpRoutingService warpRouting *ingress.WarpRoutingService
tags []tunnelpogs.Tag tags []tunnelpogs.Tag
log *zerolog.Logger log *zerolog.Logger
@ -37,7 +37,7 @@ type Proxy struct {
// NewOriginProxy returns a new instance of the Proxy struct. // NewOriginProxy returns a new instance of the Proxy struct.
func NewOriginProxy( func NewOriginProxy(
ingressRules ingress.Ingress, ingressRules *ingress.Ingress,
warpRouting *ingress.WarpRoutingService, warpRouting *ingress.WarpRoutingService,
tags []tunnelpogs.Tag, tags []tunnelpogs.Tag,
log *zerolog.Logger, log *zerolog.Logger,
@ -139,7 +139,7 @@ func (p *Proxy) ProxyTCP(
return nil return nil
} }
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) { func ruleField(ing *ingress.Ingress, ruleNum int) (ruleID string, srv string) {
srv = ing.Rules[ruleNum].Service.String() srv = ing.Rules[ruleNum].Service.String()
if ing.IsSingleRule() { if ing.IsSingleRule() {
return "", srv return "", srv

View File

@ -1,7 +1,7 @@
//go:build !windows //go:build !windows
// +build !windows // +build !windows
package origin package proxy
import ( import (
"io/ioutil" "io/ioutil"

View File

@ -1,4 +1,4 @@
package origin package proxy
import ( import (
"bytes" "bytes"
@ -135,7 +135,7 @@ func TestProxySingleOrigin(t *testing.T) {
errC := make(chan error) errC := make(chan error)
require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC)) require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC))
proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, &log) proxy := NewOriginProxy(&ingressRule, unusedWarpRoutingService, testTags, &log)
t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyHTTP", testProxyHTTP(proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
t.Run("testProxySSE", testProxySSE(proxy)) t.Run("testProxySSE", testProxySSE(proxy))
@ -345,7 +345,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
var wg sync.WaitGroup var wg sync.WaitGroup
require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC)) require.NoError(t, ingress.StartOrigins(&wg, &log, ctx.Done(), errC))
proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log) proxy := NewOriginProxy(&ingress, unusedWarpRoutingService, testTags, &log)
for _, test := range tests { for _, test := range tests {
responseWriter := newMockHTTPRespWriter() responseWriter := newMockHTTPRespWriter()
@ -394,7 +394,7 @@ func TestProxyError(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
proxy := NewOriginProxy(ing, unusedWarpRoutingService, testTags, &log) proxy := NewOriginProxy(&ing, unusedWarpRoutingService, testTags, &log)
responseWriter := newMockHTTPRespWriter() responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
@ -637,7 +637,7 @@ func TestConnections(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
errC := make(chan error) errC := make(chan error)
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
proxy := NewOriginProxy(ingressRule, test.args.warpRoutingService, testTags, logger) proxy := NewOriginProxy(&ingressRule, test.args.warpRoutingService, testTags, logger)
dest := ln.Addr().String() dest := ln.Addr().String()
req, err := http.NewRequest( req, err := http.NewRequest(

View File

@ -1,4 +1,4 @@
package origin package supervisor
import ( import (
"encoding/json" "encoding/json"

View File

@ -1,4 +1,4 @@
package origin package supervisor
import ( import (
"testing" "testing"

View File

@ -0,0 +1,55 @@
package supervisor
import (
"sync"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/proxy"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
type configManager struct {
currentVersion int32
// Only used by UpdateConfig
updateLock sync.Mutex
// TODO: TUN-5698: Make proxy atomic.Value
proxy *proxy.Proxy
config *DynamicConfig
tags []tunnelpogs.Tag
log *zerolog.Logger
}
func newConfigManager(config *DynamicConfig, tags []tunnelpogs.Tag, log *zerolog.Logger) *configManager {
var warpRoutingService *ingress.WarpRoutingService
if config.WarpRoutingEnabled {
warpRoutingService = ingress.NewWarpRoutingService()
log.Info().Msgf("Warp-routing is enabled")
}
return &configManager{
// Lowest possible version, any remote configuration will have version higher than this
currentVersion: 0,
proxy: proxy.NewOriginProxy(config.Ingress, warpRoutingService, tags, log),
config: config,
log: log,
}
}
func (cm *configManager) Update(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
// TODO: TUN-5698: make ingress configurable
return &tunnelpogs.UpdateConfigurationResponse{
LastAppliedVersion: cm.currentVersion,
}
}
func (cm *configManager) GetOriginProxy() connection.OriginProxy {
return cm.proxy
}
type DynamicConfig struct {
Ingress *ingress.Ingress
WarpRoutingEnabled bool
}

View File

@ -1,4 +1,4 @@
package origin package supervisor
import ( import (
"github.com/rs/zerolog" "github.com/rs/zerolog"

View File

@ -1,4 +1,4 @@
package origin package supervisor
import ( import (
"time" "time"

27
supervisor/metrics.go Normal file
View File

@ -0,0 +1,27 @@
package supervisor
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/cloudflare/cloudflared/connection"
)
// Metrics uses connection.MetricsNamespace(aka cloudflared) as namespace and connection.TunnelSubsystem
// (tunnel) as subsystem to keep them consistent with the previous qualifier.
var (
haConnections = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: connection.MetricsNamespace,
Subsystem: connection.TunnelSubsystem,
Name: "ha_connections",
Help: "Number of active ha connections",
},
)
)
func init() {
prometheus.MustRegister(
haConnections,
)
}

View File

@ -1,4 +1,4 @@
package origin package supervisor
import ( import (
"context" "context"

View File

@ -1,4 +1,4 @@
package origin package supervisor
import ( import (
"context" "context"

View File

@ -1,4 +1,4 @@
package origin package supervisor
import ( import (
"context" "context"
@ -37,6 +37,7 @@ const (
// reconnects them if they disconnect. // reconnects them if they disconnect.
type Supervisor struct { type Supervisor struct {
cloudflaredUUID uuid.UUID cloudflaredUUID uuid.UUID
configManager *configManager
config *TunnelConfig config *TunnelConfig
edgeIPs *edgediscovery.Edge edgeIPs *edgediscovery.Edge
tunnelErrors chan tunnelError tunnelErrors chan tunnelError
@ -64,7 +65,7 @@ type tunnelError struct {
err error err error
} }
func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) { func NewSupervisor(config *TunnelConfig, dynamiConfig *DynamicConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
cloudflaredUUID, err := uuid.NewRandom() cloudflaredUUID, err := uuid.NewRandom()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err) return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err)
@ -88,6 +89,7 @@ func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, grace
return &Supervisor{ return &Supervisor{
cloudflaredUUID: cloudflaredUUID, cloudflaredUUID: cloudflaredUUID,
config: config, config: config,
configManager: newConfigManager(dynamiConfig, config.Tags, config.Log),
edgeIPs: edgeIPs, edgeIPs: edgeIPs,
tunnelErrors: make(chan tunnelError), tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{}, tunnelsConnecting: map[int]chan struct{}{},
@ -242,6 +244,7 @@ func (s *Supervisor) startFirstTunnel(
err = ServeTunnelLoop( err = ServeTunnelLoop(
ctx, ctx,
s.reconnectCredentialManager, s.reconnectCredentialManager,
s.configManager,
s.config, s.config,
addr, addr,
s.log, s.log,
@ -276,6 +279,7 @@ func (s *Supervisor) startFirstTunnel(
err = ServeTunnelLoop( err = ServeTunnelLoop(
ctx, ctx,
s.reconnectCredentialManager, s.reconnectCredentialManager,
s.configManager,
s.config, s.config,
addr, addr,
s.log, s.log,
@ -310,6 +314,7 @@ func (s *Supervisor) startTunnel(
err = ServeTunnelLoop( err = ServeTunnelLoop(
ctx, ctx,
s.reconnectCredentialManager, s.reconnectCredentialManager,
s.configManager,
s.config, s.config,
addr, addr,
s.log, s.log,
@ -380,7 +385,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
defer rpcClient.Close() defer rpcClient.Close()
const arbitraryConnectionID = uint8(0) const arbitraryConnectionID = uint8(0)
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID) registrationOptions := s.config.registrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts) registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions) return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions)
} }

View File

@ -1,4 +1,4 @@
package origin package supervisor
import ( import (
"context" "context"
@ -34,7 +34,8 @@ const (
) )
type TunnelConfig struct { type TunnelConfig struct {
ConnectionConfig *connection.Config GracePeriod time.Duration
ReplaceExisting bool
OSArch string OSArch string
ClientID string ClientID string
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
@ -52,14 +53,14 @@ type TunnelConfig struct {
Retries uint Retries uint
RunFromTerminal bool RunFromTerminal bool
NamedTunnel *connection.NamedTunnelConfig NamedTunnel *connection.NamedTunnelProperties
ClassicTunnel *connection.ClassicTunnelConfig ClassicTunnel *connection.ClassicTunnelProperties
MuxerConfig *connection.MuxerConfig MuxerConfig *connection.MuxerConfig
ProtocolSelector connection.ProtocolSelector ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config EdgeTLSConfigs map[connection.Protocol]*tls.Config
} }
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { func (c *TunnelConfig) registrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
policy := tunnelrpc.ExistingTunnelPolicy_balance policy := tunnelrpc.ExistingTunnelPolicy_balance
if c.HAConnections <= 1 && c.LBPool == "" { if c.HAConnections <= 1 && c.LBPool == "" {
policy = tunnelrpc.ExistingTunnelPolicy_disconnect policy = tunnelrpc.ExistingTunnelPolicy_disconnect
@ -81,7 +82,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
} }
} }
func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAttempts uint8) *tunnelpogs.ConnectionOptions { func (c *TunnelConfig) connectionOptions(originLocalAddr string, numPreviousAttempts uint8) *tunnelpogs.ConnectionOptions {
// attempt to parse out origin IP, but don't fail since it's informational field // attempt to parse out origin IP, but don't fail since it's informational field
host, _, _ := net.SplitHostPort(originLocalAddr) host, _, _ := net.SplitHostPort(originLocalAddr)
originIP := net.ParseIP(host) originIP := net.ParseIP(host)
@ -89,7 +90,7 @@ func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAtte
return &tunnelpogs.ConnectionOptions{ return &tunnelpogs.ConnectionOptions{
Client: c.NamedTunnel.Client, Client: c.NamedTunnel.Client,
OriginLocalIP: originIP, OriginLocalIP: originIP,
ReplaceExisting: c.ConnectionConfig.ReplaceExisting, ReplaceExisting: c.ReplaceExisting,
CompressionQuality: uint8(c.MuxerConfig.CompressionSetting), CompressionQuality: uint8(c.MuxerConfig.CompressionSetting),
NumPreviousAttempts: numPreviousAttempts, NumPreviousAttempts: numPreviousAttempts,
} }
@ -106,11 +107,12 @@ func (c *TunnelConfig) SupportedFeatures() []string {
func StartTunnelDaemon( func StartTunnelDaemon(
ctx context.Context, ctx context.Context,
config *TunnelConfig, config *TunnelConfig,
dynamiConfig *DynamicConfig,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
reconnectCh chan ReconnectSignal, reconnectCh chan ReconnectSignal,
graceShutdownC <-chan struct{}, graceShutdownC <-chan struct{},
) error { ) error {
s, err := NewSupervisor(config, reconnectCh, graceShutdownC) s, err := NewSupervisor(config, dynamiConfig, reconnectCh, graceShutdownC)
if err != nil { if err != nil {
return err return err
} }
@ -120,6 +122,7 @@ func StartTunnelDaemon(
func ServeTunnelLoop( func ServeTunnelLoop(
ctx context.Context, ctx context.Context,
credentialManager *reconnectCredentialManager, credentialManager *reconnectCredentialManager,
configManager *configManager,
config *TunnelConfig, config *TunnelConfig,
addr *allregions.EdgeAddr, addr *allregions.EdgeAddr,
connAwareLogger *ConnAwareLogger, connAwareLogger *ConnAwareLogger,
@ -155,6 +158,7 @@ func ServeTunnelLoop(
ctx, ctx,
connLog, connLog,
credentialManager, credentialManager,
configManager,
config, config,
addr, addr,
connIndex, connIndex,
@ -253,6 +257,7 @@ func ServeTunnel(
ctx context.Context, ctx context.Context,
connLog *ConnAwareLogger, connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager, credentialManager *reconnectCredentialManager,
configManager *configManager,
config *TunnelConfig, config *TunnelConfig,
addr *allregions.EdgeAddr, addr *allregions.EdgeAddr,
connIndex uint8, connIndex uint8,
@ -281,6 +286,7 @@ func ServeTunnel(
ctx, ctx,
connLog, connLog,
credentialManager, credentialManager,
configManager,
config, config,
addr, addr,
connIndex, connIndex,
@ -329,6 +335,7 @@ func serveTunnel(
ctx context.Context, ctx context.Context,
connLog *ConnAwareLogger, connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager, credentialManager *reconnectCredentialManager,
configManager *configManager,
config *TunnelConfig, config *TunnelConfig,
addr *allregions.EdgeAddr, addr *allregions.EdgeAddr,
connIndex uint8, connIndex uint8,
@ -339,7 +346,6 @@ func serveTunnel(
protocol connection.Protocol, protocol connection.Protocol,
gracefulShutdownC <-chan struct{}, gracefulShutdownC <-chan struct{},
) (err error, recoverable bool) { ) (err error, recoverable bool) {
connectedFuse := &connectedFuse{ connectedFuse := &connectedFuse{
fuse: fuse, fuse: fuse,
backoff: backoff, backoff: backoff,
@ -351,14 +357,15 @@ func serveTunnel(
connIndex, connIndex,
nil, nil,
gracefulShutdownC, gracefulShutdownC,
config.ConnectionConfig.GracePeriod, config.GracePeriod,
) )
switch protocol { switch protocol {
case connection.QUIC, connection.QUICWarp: case connection.QUIC, connection.QUICWarp:
connOptions := config.ConnectionOptions(addr.UDP.String(), uint8(backoff.Retries())) connOptions := config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
return ServeQUIC(ctx, return ServeQUIC(ctx,
addr.UDP, addr.UDP,
configManager,
config, config,
connLog, connLog,
connOptions, connOptions,
@ -374,10 +381,11 @@ func serveTunnel(
return err, true return err, true
} }
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries())) connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
if err := ServeHTTP2( if err := ServeHTTP2(
ctx, ctx,
connLog, connLog,
configManager,
config, config,
edgeConn, edgeConn,
connOptions, connOptions,
@ -400,6 +408,7 @@ func serveTunnel(
ctx, ctx,
connLog, connLog,
credentialManager, credentialManager,
configManager,
config, config,
edgeConn, edgeConn,
connIndex, connIndex,
@ -426,6 +435,7 @@ func ServeH2mux(
ctx context.Context, ctx context.Context,
connLog *ConnAwareLogger, connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager, credentialManager *reconnectCredentialManager,
configManager *configManager,
config *TunnelConfig, config *TunnelConfig,
edgeConn net.Conn, edgeConn net.Conn,
connIndex uint8, connIndex uint8,
@ -437,7 +447,8 @@ func ServeH2mux(
connLog.Logger().Debug().Msgf("Connecting via h2mux") connLog.Logger().Debug().Msgf("Connecting via h2mux")
// Returns error from parsing the origin URL or handshake errors // Returns error from parsing the origin URL or handshake errors
handler, err, recoverable := connection.NewH2muxConnection( handler, err, recoverable := connection.NewH2muxConnection(
config.ConnectionConfig, configManager,
config.GracePeriod,
config.MuxerConfig, config.MuxerConfig,
edgeConn, edgeConn,
connIndex, connIndex,
@ -455,10 +466,10 @@ func ServeH2mux(
errGroup.Go(func() error { errGroup.Go(func() error {
if config.NamedTunnel != nil { if config.NamedTunnel != nil {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries())) connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries()))
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse) return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse)
} }
registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID) registrationOptions := config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse) return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
}) })
@ -472,6 +483,7 @@ func ServeH2mux(
func ServeHTTP2( func ServeHTTP2(
ctx context.Context, ctx context.Context,
connLog *ConnAwareLogger, connLog *ConnAwareLogger,
configManager *configManager,
config *TunnelConfig, config *TunnelConfig,
tlsServerConn net.Conn, tlsServerConn net.Conn,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
@ -483,7 +495,7 @@ func ServeHTTP2(
connLog.Logger().Debug().Msgf("Connecting via http2") connLog.Logger().Debug().Msgf("Connecting via http2")
h2conn := connection.NewHTTP2Connection( h2conn := connection.NewHTTP2Connection(
tlsServerConn, tlsServerConn,
config.ConnectionConfig, configManager,
connOptions, connOptions,
config.Observer, config.Observer,
connIndex, connIndex,
@ -511,6 +523,7 @@ func ServeHTTP2(
func ServeQUIC( func ServeQUIC(
ctx context.Context, ctx context.Context,
edgeAddr *net.UDPAddr, edgeAddr *net.UDPAddr,
configManager *configManager,
config *TunnelConfig, config *TunnelConfig,
connLogger *ConnAwareLogger, connLogger *ConnAwareLogger,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
@ -535,7 +548,7 @@ func ServeQUIC(
quicConfig, quicConfig,
edgeAddr, edgeAddr,
tlsConfig, tlsConfig,
config.ConnectionConfig.OriginProxy, configManager,
connOptions, connOptions,
controlStreamHandler, controlStreamHandler,
connLogger.Logger()) connLogger.Logger())

View File

@ -1,4 +1,4 @@
package origin package supervisor
import ( import (
"testing" "testing"
@ -32,11 +32,7 @@ func TestWaitForBackoffFallback(t *testing.T) {
} }
log := zerolog.Nop() log := zerolog.Nop()
resolveTTL := time.Duration(0) resolveTTL := time.Duration(0)
namedTunnel := &connection.NamedTunnelConfig{ namedTunnel := &connection.NamedTunnelProperties{}
Credentials: connection.Credentials{
AccountTag: "test-account",
},
}
mockFetcher := dynamicMockFetcher{ mockFetcher := dynamicMockFetcher{
protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}, protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}},
} }

View File

@ -1,4 +1,4 @@
package origin package supervisor
import ( import (
"fmt" "fmt"