diff --git a/client/config.go b/client/config.go
new file mode 100644
index 00000000..b4e053e7
--- /dev/null
+++ b/client/config.go
@@ -0,0 +1,77 @@
+package client
+
+import (
+	"fmt"
+	"net"
+
+	"github.com/google/uuid"
+	"github.com/rs/zerolog"
+
+	"github.com/cloudflare/cloudflared/features"
+	"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
+)
+
+// Config captures the local client runtime configuration.
+type Config struct {
+	ConnectorID uuid.UUID
+	Version     string
+	Arch        string
+
+	featureSelector features.FeatureSelector
+}
+
+// NewConfig returns a Config for the given version and arch with a freshly generated random
+// ConnectorID. It fails only if the UUID cannot be generated. The featureSelector is retained so
+// that each connection can take its own feature snapshot.
+func NewConfig(version string, arch string, featureSelector features.FeatureSelector) (*Config, error) {
+	connectorID, err := uuid.NewRandom()
+	if err != nil {
+		return nil, fmt.Errorf("unable to generate a connector UUID: %w", err)
+	}
+	return &Config{
+		ConnectorID:     connectorID,
+		Version:         version,
+		Arch:            arch,
+		featureSelector: featureSelector,
+	}, nil
+}
+
+// ConnectionOptionsSnapshot is a snapshot of the current client information used to initialize a connection.
+//
+// The FeatureSnapshot is the features that are available for this connection. At the client level they may
+// change, but they will not change within the scope of this struct.
+type ConnectionOptionsSnapshot struct { + client pogs.ClientInfo + originLocalIP net.IP + numPreviousAttempts uint8 + FeatureSnapshot features.FeatureSnapshot +} + +func (c *Config) ConnectionOptionsSnapshot(originIP net.IP, previousAttempts uint8) *ConnectionOptionsSnapshot { + snapshot := c.featureSelector.Snapshot() + return &ConnectionOptionsSnapshot{ + client: pogs.ClientInfo{ + ClientID: c.ConnectorID[:], + Version: c.Version, + Arch: c.Arch, + Features: snapshot.FeaturesList, + }, + originLocalIP: originIP, + numPreviousAttempts: previousAttempts, + FeatureSnapshot: snapshot, + } +} + +func (c ConnectionOptionsSnapshot) ConnectionOptions() *pogs.ConnectionOptions { + return &pogs.ConnectionOptions{ + Client: c.client, + OriginLocalIP: c.originLocalIP, + ReplaceExisting: false, + CompressionQuality: 0, + NumPreviousAttempts: c.numPreviousAttempts, + } +} + +func (c ConnectionOptionsSnapshot) LogFields(event *zerolog.Event) *zerolog.Event { + return event.Strs("features", c.client.Features) +} diff --git a/client/config_test.go b/client/config_test.go new file mode 100644 index 00000000..5fe4e7c1 --- /dev/null +++ b/client/config_test.go @@ -0,0 +1,50 @@ +package client + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cloudflare/cloudflared/features" +) + +func TestGenerateConnectionOptions(t *testing.T) { + version := "1234" + arch := "linux_amd64" + originIP := net.ParseIP("192.168.1.1") + var previousAttempts uint8 = 4 + + config, err := NewConfig(version, arch, &mockFeatureSelector{}) + require.NoError(t, err) + require.Equal(t, version, config.Version) + require.Equal(t, arch, config.Arch) + + // Validate ConnectionOptionsSnapshot fields + connOptions := config.ConnectionOptionsSnapshot(originIP, previousAttempts) + require.Equal(t, version, connOptions.client.Version) + require.Equal(t, arch, connOptions.client.Arch) + require.Equal(t, config.ConnectorID[:], connOptions.client.ClientID) + + // Vaidate snapshot 
feature fields against the connOptions generated + snapshot := config.featureSelector.Snapshot() + require.Equal(t, features.DatagramV3, snapshot.DatagramVersion) + require.Equal(t, features.DatagramV3, connOptions.FeatureSnapshot.DatagramVersion) + + pogsConnOptions := connOptions.ConnectionOptions() + require.Equal(t, connOptions.client, pogsConnOptions.Client) + require.Equal(t, originIP, pogsConnOptions.OriginLocalIP) + require.False(t, pogsConnOptions.ReplaceExisting) + require.Equal(t, uint8(0), pogsConnOptions.CompressionQuality) + require.Equal(t, previousAttempts, pogsConnOptions.NumPreviousAttempts) +} + +type mockFeatureSelector struct{} + +func (m *mockFeatureSelector) Snapshot() features.FeatureSnapshot { + return features.FeatureSnapshot{ + PostQuantum: features.PostQuantumPrefer, + DatagramVersion: features.DatagramV3, + FeaturesList: []string{features.FeaturePostQuantum, features.FeatureDatagramV3_1}, + } +} diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 378802ce..89b5448d 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -15,7 +15,6 @@ import ( "github.com/coreos/go-systemd/v22/daemon" "github.com/facebookgo/grace/gracenet" "github.com/getsentry/sentry-go" - "github.com/google/uuid" "github.com/mitchellh/go-homedir" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -446,14 +445,7 @@ func StartServer( log.Err(err).Msg("Couldn't start tunnel") return err } - var clientID uuid.UUID - if tunnelConfig.NamedTunnel != nil { - clientID, err = uuid.FromBytes(tunnelConfig.NamedTunnel.Client.ClientID) - if err != nil { - // set to nil for classic tunnels - clientID = uuid.Nil - } - } + connectorID := tunnelConfig.ClientConfig.ConnectorID // Disable ICMP packet routing for quick tunnels if quickTunnelURL != "" { @@ -471,7 +463,7 @@ func StartServer( c.String("management-hostname"), c.Bool("management-diagnostics"), serviceIP, - clientID, + connectorID, c.String(cfdflags.ConnectorLabel), 
logger.ManagementLogger.Log, logger.ManagementLogger, @@ -503,14 +495,14 @@ func StartServer( sources = append(sources, ipv6.String()) } - readinessServer := metrics.NewReadyServer(clientID, tracker) + readinessServer := metrics.NewReadyServer(connectorID, tracker) cliFlags := nonSecretCliFlags(log, c, nonSecretFlagsList) diagnosticHandler := diagnostic.NewDiagnosticHandler( log, 0, diagnostic.NewSystemCollectorImpl(buildInfo.CloudflaredVersion), tunnelConfig.NamedTunnel.Credentials.TunnelID, - clientID, + connectorID, tracker, cliFlags, sources, diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index b78dba89..7961c813 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -10,13 +10,13 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" "golang.org/x/term" + "github.com/cloudflare/cloudflared/client" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/flags" "github.com/cloudflare/cloudflared/config" @@ -125,27 +125,29 @@ func prepareTunnelConfig( observer *connection.Observer, namedTunnel *connection.TunnelProperties, ) (*supervisor.TunnelConfig, *orchestration.Config, error) { - clientID, err := uuid.NewRandom() + transportProtocol := c.String(flags.Protocol) + isPostQuantumEnforced := c.Bool(flags.PostQuantum) + featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, c.StringSlice(flags.Features), isPostQuantumEnforced, log) if err != nil { - return nil, nil, errors.Wrap(err, "can't generate connector UUID") + return nil, nil, errors.Wrap(err, "Failed to create feature selector") } - log.Info().Msgf("Generated Connector ID: %s", clientID) + + clientConfig, err := client.NewConfig(info.Version(), info.OSArch(), featureSelector) + if err != nil { + return nil, nil, 
err + } + + log.Info().Msgf("Generated Connector ID: %s", clientConfig.ConnectorID) + tags, err := NewTagSliceFromCLI(c.StringSlice(flags.Tag)) if err != nil { log.Err(err).Msg("Tag parse failure") return nil, nil, errors.Wrap(err, "Tag parse failure") } - tags = append(tags, pogs.Tag{Name: "ID", Value: clientID.String()}) + tags = append(tags, pogs.Tag{Name: "ID", Value: clientConfig.ConnectorID.String()}) - transportProtocol := c.String(flags.Protocol) - isPostQuantumEnforced := c.Bool(flags.PostQuantum) - - featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, c.StringSlice(flags.Features), c.Bool(flags.PostQuantum), log) - if err != nil { - return nil, nil, errors.Wrap(err, "Failed to create feature selector") - } - clientFeatures := featureSelector.ClientFeatures() - pqMode := featureSelector.PostQuantumMode() + clientFeatures := featureSelector.Snapshot() + pqMode := clientFeatures.PostQuantum if pqMode == features.PostQuantumStrict { // Error if the user tries to force a non-quic transport protocol if transportProtocol != connection.AutoSelectFlag && transportProtocol != connection.QUIC.String() { @@ -154,12 +156,6 @@ func prepareTunnelConfig( transportProtocol = connection.QUIC.String() } - namedTunnel.Client = pogs.ClientInfo{ - ClientID: clientID[:], - Features: clientFeatures, - Version: info.Version(), - Arch: info.OSArch(), - } cfg := config.GetConfiguration() ingressRules, err := ingress.ParseIngressFromConfigAndCLI(cfg, c, log) if err != nil { @@ -224,10 +220,8 @@ func prepareTunnelConfig( } tunnelConfig := &supervisor.TunnelConfig{ + ClientConfig: clientConfig, GracePeriod: gracePeriod, - ReplaceExisting: c.Bool(flags.Force), - OSArch: info.OSArch(), - ClientID: clientID.String(), EdgeAddrs: c.StringSlice(flags.Edge), Region: resolvedRegion, EdgeIPVersion: edgeIPVersion, @@ -246,7 +240,6 @@ func prepareTunnelConfig( NamedTunnel: namedTunnel, ProtocolSelector: protocolSelector, EdgeTLSConfigs: edgeTLSConfigs, - 
FeatureSelector: featureSelector, MaxEdgeAddrRetries: uint8(c.Int(flags.MaxEdgeAddrRetries)), // nolint: gosec RPCTimeout: c.Duration(flags.RpcTimeout), WriteStreamTimeout: c.Duration(flags.WriteStreamTimeout), diff --git a/connection/connection.go b/connection/connection.go index 8c05eeea..4803e930 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -57,7 +57,6 @@ type Orchestrator interface { type TunnelProperties struct { Credentials Credentials - Client pogs.ClientInfo QuickTunnelUrl string } diff --git a/connection/http2.go b/connection/http2.go index 243c464d..c7e14c67 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -16,10 +16,10 @@ import ( "github.com/rs/zerolog" "golang.org/x/net/http2" + "github.com/cloudflare/cloudflared/client" cfdflow "github.com/cloudflare/cloudflared/flow" "github.com/cloudflare/cloudflared/tracing" - "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) // note: these constants are exported so we can reuse them in the edge-side code @@ -39,7 +39,7 @@ type HTTP2Connection struct { conn net.Conn server *http2.Server orchestrator Orchestrator - connOptions *pogs.ConnectionOptions + connOptions *client.ConnectionOptionsSnapshot observer *Observer connIndex uint8 @@ -54,7 +54,7 @@ type HTTP2Connection struct { func NewHTTP2Connection( conn net.Conn, orchestrator Orchestrator, - connOptions *pogs.ConnectionOptions, + connOptions *client.ConnectionOptionsSnapshot, observer *Observer, connIndex uint8, controlStreamHandler ControlStreamHandler, @@ -118,7 +118,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { var requestErr error switch connType { case TypeControlStream: - requestErr = c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, c.orchestrator) + requestErr = c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions.ConnectionOptions(), c.orchestrator) if requestErr != nil { c.controlStreamErr = requestErr } diff 
--git a/connection/http2_test.go b/connection/http2_test.go index d2045600..a1b01563 100644 --- a/connection/http2_test.go +++ b/connection/http2_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/net/http2" + "github.com/cloudflare/cloudflared/client" "github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tunnelrpc" @@ -51,7 +52,7 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) { cfdConn, // OriginProxy is set in testConfigManager testOrchestrator, - &pogs.ConnectionOptions{}, + &client.ConnectionOptionsSnapshot{}, obs, connIndex, controlStream, @@ -74,7 +75,7 @@ func TestHTTP2ConfigurationSet(t *testing.T) { require.NoError(t, err) reqBody := []byte(`{ -"version": 2, +"version": 2, "config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}} `) reader := bytes.NewReader(reqBody) diff --git a/connection/quic_connection.go b/connection/quic_connection.go index 3451369c..74e1f4a4 100644 --- a/connection/quic_connection.go +++ b/connection/quic_connection.go @@ -17,6 +17,7 @@ import ( "github.com/rs/zerolog" "golang.org/x/sync/errgroup" + "github.com/cloudflare/cloudflared/client" cfdflow "github.com/cloudflare/cloudflared/flow" cfdquic "github.com/cloudflare/cloudflared/quic" @@ -43,7 +44,7 @@ type quicConnection struct { orchestrator Orchestrator datagramHandler DatagramSessionHandler controlStreamHandler ControlStreamHandler - connOptions *pogs.ConnectionOptions + connOptions *client.ConnectionOptionsSnapshot connIndex uint8 rpcTimeout time.Duration @@ -59,7 +60,7 @@ func NewTunnelConnection( orchestrator Orchestrator, datagramSessionHandler DatagramSessionHandler, controlStreamHandler ControlStreamHandler, - connOptions *pogs.ConnectionOptions, + connOptions *client.ConnectionOptionsSnapshot, rpcTimeout time.Duration, streamWriteTimeout time.Duration, 
gracePeriod time.Duration, @@ -130,7 +131,7 @@ func (q *quicConnection) Serve(ctx context.Context) error { // serveControlStream will serve the RPC; blocking until the control plane is done. func (q *quicConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error { - return q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions, q.orchestrator) + return q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions.ConnectionOptions(), q.orchestrator) } // Close the connection with no errors specified. diff --git a/connection/quic_connection_test.go b/connection/quic_connection_test.go index 49968372..8027447f 100644 --- a/connection/quic_connection_test.go +++ b/connection/quic_connection_test.go @@ -29,6 +29,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/net/nettest" + "github.com/cloudflare/cloudflared/client" cfdflow "github.com/cloudflare/cloudflared/flow" "github.com/cloudflare/cloudflared/datagramsession" @@ -843,7 +844,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) &mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}}, datagramConn, fakeControlStream{}, - &pogs.ConnectionOptions{}, + &client.ConnectionOptionsSnapshot{}, 15*time.Second, 0*time.Second, 0*time.Second, diff --git a/features/features.go b/features/features.go index c16b51fd..2c7e6850 100644 --- a/features/features.go +++ b/features/features.go @@ -33,6 +33,15 @@ type staticFeatures struct { PostQuantumMode *PostQuantumMode } +type FeatureSnapshot struct { + PostQuantum PostQuantumMode + DatagramVersion DatagramVersion + + // We provide the list of features since we need it to send in the ConnectionOptions during connection + // registrations. 
+ FeaturesList []string +} + type PostQuantumMode uint8 const ( diff --git a/features/selector.go b/features/selector.go index c4925c45..38cc43b4 100644 --- a/features/selector.go +++ b/features/selector.go @@ -7,6 +7,7 @@ import ( "hash/fnv" "net" "slices" + "sync" "time" "github.com/rs/zerolog" @@ -15,22 +16,29 @@ import ( const ( featureSelectorHostname = "cfd-features.argotunnel.com" lookupTimeout = time.Second * 10 + defaultLookupFreq = time.Hour ) // If the TXT record adds other fields, the umarshal logic will ignore those keys // If the TXT record is missing a key, the field will unmarshal to the default Go value type featuresRecord struct { + DatagramV3Percentage uint32 `json:"dv3_1"` + // DatagramV3Percentage int32 `json:"dv3"` // Removed in TUN-9291 // PostQuantumPercentage int32 `json:"pq"` // Removed in TUN-7970 } -func NewFeatureSelector(ctx context.Context, accountTag string, cliFeatures []string, pq bool, logger *zerolog.Logger) (*FeatureSelector, error) { - return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq) +func NewFeatureSelector(ctx context.Context, accountTag string, cliFeatures []string, pq bool, logger *zerolog.Logger) (FeatureSelector, error) { + return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq, defaultLookupFreq) +} + +type FeatureSelector interface { + Snapshot() FeatureSnapshot } // FeatureSelector determines if this account will try new features; loaded once during startup. 
-type FeatureSelector struct { +type featureSelector struct { accountHash uint32 logger *zerolog.Logger resolver resolver @@ -38,10 +46,12 @@ type FeatureSelector struct { staticFeatures staticFeatures cliFeatures []string - features featuresRecord + // lock protects concurrent access to dynamic features + lock sync.RWMutex + remoteFeatures featuresRecord } -func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool) (*FeatureSelector, error) { +func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool, refreshFreq time.Duration) (*featureSelector, error) { // Combine default features and user-provided features var pqMode *PostQuantumMode if pq { @@ -52,7 +62,7 @@ func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog. staticFeatures := staticFeatures{ PostQuantumMode: pqMode, } - selector := &FeatureSelector{ + selector := &featureSelector{ accountHash: switchThreshold(accountTag), logger: logger, resolver: resolver, @@ -60,14 +70,32 @@ func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog. 
cliFeatures: dedupAndRemoveFeatures(cliFeatures), } - if err := selector.init(ctx); err != nil { + // Load the remote features + if err := selector.refresh(ctx); err != nil { logger.Err(err).Msg("Failed to fetch features, default to disable") } + // Spin off reloading routine + go selector.refreshLoop(ctx, refreshFreq) + return selector, nil } -func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode { +func (fs *featureSelector) Snapshot() FeatureSnapshot { + fs.lock.RLock() + defer fs.lock.RUnlock() + return FeatureSnapshot{ + PostQuantum: fs.postQuantumMode(), + DatagramVersion: fs.datagramVersion(), + FeaturesList: fs.clientFeatures(), + } +} + +func (fs *featureSelector) accountEnabled(percentage uint32) bool { + return percentage > fs.accountHash +} + +func (fs *featureSelector) postQuantumMode() PostQuantumMode { if fs.staticFeatures.PostQuantumMode != nil { return *fs.staticFeatures.PostQuantumMode } @@ -75,7 +103,7 @@ func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode { return PostQuantumPrefer } -func (fs *FeatureSelector) DatagramVersion() DatagramVersion { +func (fs *featureSelector) datagramVersion() DatagramVersion { // If user provides the feature via the cli, we take it as priority over remote feature evaluation if slices.Contains(fs.cliFeatures, FeatureDatagramV3_1) { return DatagramV3 @@ -85,16 +113,20 @@ func (fs *FeatureSelector) DatagramVersion() DatagramVersion { return DatagramV2 } + if fs.accountEnabled(fs.remoteFeatures.DatagramV3Percentage) { + return DatagramV3 + } + return DatagramV2 } -// ClientFeatures will return the list of currently available features that cloudflared should provide to the edge. -func (fs *FeatureSelector) ClientFeatures() []string { +// clientFeatures will return the list of currently available features that cloudflared should provide to the edge. 
+func (fs *featureSelector) clientFeatures() []string {
 	// Evaluate any remote features along with static feature list to construct the list of features
-	return dedupAndRemoveFeatures(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.DatagramVersion())}))
+	return dedupAndRemoveFeatures(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.datagramVersion())}))
 }
 
-func (fs *FeatureSelector) init(ctx context.Context) error {
+func (fs *featureSelector) refresh(ctx context.Context) error {
 	record, err := fs.resolver.lookupRecord(ctx)
 	if err != nil {
 		return err
@@ -105,11 +137,32 @@ func (fs *FeatureSelector) init(ctx context.Context) error {
 		return err
 	}
 
-	fs.features = features
+	fs.lock.Lock()
+	defer fs.lock.Unlock()
+
+	fs.remoteFeatures = features
 
 	return nil
 }
 
+// refreshLoop re-fetches the remote feature record every refreshFreq until ctx is
+// cancelled. A failed refresh is logged and the loop keeps running.
+func (fs *featureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) {
+	ticker := time.NewTicker(refreshFreq)
+	// Release the ticker when the loop exits instead of leaving it armed.
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			if err := fs.refresh(ctx); err != nil {
+				fs.logger.Err(err).Msg("Failed to refresh feature selector")
+			}
+		}
+	}
+}
+
 // resolver represents an object that can look up featuresRecord
 type resolver interface {
 	lookupRecord(ctx context.Context) ([]byte, error)
diff --git a/features/selector_test.go b/features/selector_test.go
index b7d67cc9..96284d78 100644
--- a/features/selector_test.go
+++ b/features/selector_test.go
@@ -3,17 +3,36 @@ package features
 import (
 	"context"
 	"encoding/json"
+	"fmt"
 	"testing"
+	"time"
 
 	"github.com/rs/zerolog"
 	"github.com/stretchr/testify/require"
 )
 
+const (
+	testAccountTag  = "123456"
+	testAccountHash = 74 // switchThreshold of `accountTag`
+)
+
 func TestUnmarshalFeaturesRecord(t *testing.T) {
 	tests := []struct {
 		record             []byte
 		expectedPercentage uint32
 	}{
+		{
+			record:             []byte(`{"dv3_1":0}`),
+			expectedPercentage: 0,
+		},
+		{
+			record:             []byte(`{"dv3_1":39}`),
+			expectedPercentage: 39,
+		},
+		{
+			record:             []byte(`{"dv3_1":100}`),
+			expectedPercentage:
100, + }, { record: []byte(`{}`), // Unmarshal to default struct if key is not present }, @@ -29,6 +48,7 @@ func TestUnmarshalFeaturesRecord(t *testing.T) { var features featuresRecord err := json.Unmarshal(test.record, &features) require.NoError(t, err) + require.Equal(t, test.expectedPercentage, features.DatagramV3Percentage, test) } } @@ -57,10 +77,11 @@ func TestFeaturePrecedenceEvaluationPostQuantum(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { resolver := &staticResolver{record: featuresRecord{}} - selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, []string{}, test.cli) + selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, []string{}, test.cli, time.Second) require.NoError(t, err) - require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures()) - require.Equal(t, test.expectedVersion, selector.PostQuantumMode()) + snapshot := selector.Snapshot() + require.ElementsMatch(t, test.expectedFeatures, snapshot.FeaturesList) + require.Equal(t, test.expectedVersion, snapshot.PostQuantum) }) } } @@ -100,10 +121,11 @@ func TestFeaturePrecedenceEvaluationDatagramVersion(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { resolver := &staticResolver{record: test.remote} - selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false) + selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false, time.Second) require.NoError(t, err) - require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures()) - require.Equal(t, test.expectedVersion, selector.DatagramVersion()) + snapshot := selector.Snapshot() + require.ElementsMatch(t, test.expectedFeatures, snapshot.FeaturesList) + require.Equal(t, test.expectedVersion, snapshot.DatagramVersion) }) } } @@ -133,34 +155,99 @@ func TestDeprecatedFeaturesRemoved(t *testing.T) { for _, test 
:= range tests { t.Run(test.name, func(t *testing.T) { resolver := &staticResolver{record: test.remote} - selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false) + selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false, time.Second) require.NoError(t, err) - require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures()) + snapshot := selector.Snapshot() + require.ElementsMatch(t, test.expectedFeatures, snapshot.FeaturesList) }) } } +func TestRefreshFeaturesRecord(t *testing.T) { + percentages := []uint32{0, 10, testAccountHash - 1, testAccountHash, testAccountHash + 1, 100, 101, 1000} + selector := newTestSelector(t, percentages, false, time.Minute) + + // Starting out should default to DatagramV2 + snapshot := selector.Snapshot() + require.Equal(t, DatagramV2, snapshot.DatagramVersion) + + for _, percentage := range percentages { + snapshot = selector.Snapshot() + if percentage > testAccountHash { + require.Equal(t, DatagramV3, snapshot.DatagramVersion) + } else { + require.Equal(t, DatagramV2, snapshot.DatagramVersion) + } + + // Manually progress the next refresh + _ = selector.refresh(context.Background()) + } + + // Make sure a resolver error doesn't override the last fetched features + snapshot = selector.Snapshot() + require.Equal(t, DatagramV3, snapshot.DatagramVersion) +} + +func TestSnapshotIsolation(t *testing.T) { + percentages := []uint32{testAccountHash, testAccountHash + 1} + selector := newTestSelector(t, percentages, false, time.Minute) + + // Starting out should default to DatagramV2 + snapshot := selector.Snapshot() + require.Equal(t, DatagramV2, snapshot.DatagramVersion) + + // Manually progress the next refresh + _ = selector.refresh(context.Background()) + + snapshot2 := selector.Snapshot() + require.Equal(t, DatagramV3, snapshot2.DatagramVersion) + require.NotEqual(t, snapshot.DatagramVersion, snapshot2.DatagramVersion) +} + func 
TestStaticFeatures(t *testing.T) { percentages := []uint32{0} // PostQuantum Enabled from user flag - selector := newTestSelector(t, percentages, true) - require.Equal(t, PostQuantumStrict, selector.PostQuantumMode()) + selector := newTestSelector(t, percentages, true, time.Second) + snapshot := selector.Snapshot() + require.Equal(t, PostQuantumStrict, snapshot.PostQuantum) // PostQuantum Disabled (or not set) - selector = newTestSelector(t, percentages, false) - require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode()) + selector = newTestSelector(t, percentages, false, time.Second) + snapshot = selector.Snapshot() + require.Equal(t, PostQuantumPrefer, snapshot.PostQuantum) } -func newTestSelector(t *testing.T, percentages []uint32, pq bool) *FeatureSelector { - accountTag := t.Name() +func newTestSelector(t *testing.T, percentages []uint32, pq bool, refreshFreq time.Duration) *featureSelector { logger := zerolog.Nop() - selector, err := newFeatureSelector(context.Background(), accountTag, &logger, &staticResolver{}, []string{}, pq) + resolver := &mockResolver{ + percentages: percentages, + } + + selector, err := newFeatureSelector(context.Background(), testAccountTag, &logger, resolver, []string{}, pq, refreshFreq) require.NoError(t, err) return selector } +type mockResolver struct { + nextIndex int + percentages []uint32 +} + +func (mr *mockResolver) lookupRecord(ctx context.Context) ([]byte, error) { + if mr.nextIndex >= len(mr.percentages) { + return nil, fmt.Errorf("no more record to lookup") + } + + record, err := json.Marshal(featuresRecord{ + DatagramV3Percentage: mr.percentages[mr.nextIndex], + }) + mr.nextIndex++ + + return record, err +} + type staticResolver struct { record featuresRecord } diff --git a/quic/v3/metrics.go b/quic/v3/metrics.go index 8f330af8..e51e3683 100644 --- a/quic/v3/metrics.go +++ b/quic/v3/metrics.go @@ -1,6 +1,8 @@ package v3 import ( + "fmt" + "github.com/prometheus/client_golang/prometheus" 
"github.com/cloudflare/cloudflared/quic" @@ -32,28 +34,28 @@ type metrics struct { } func (m *metrics) IncrementFlows(connIndex uint8) { - m.totalUDPFlows.WithLabelValues(string(connIndex)).Inc() - m.activeUDPFlows.WithLabelValues(string(connIndex)).Inc() + m.totalUDPFlows.WithLabelValues(fmt.Sprintf("%d", connIndex)).Inc() + m.activeUDPFlows.WithLabelValues(fmt.Sprintf("%d", connIndex)).Inc() } func (m *metrics) DecrementFlows(connIndex uint8) { - m.activeUDPFlows.WithLabelValues(string(connIndex)).Dec() + m.activeUDPFlows.WithLabelValues(fmt.Sprintf("%d", connIndex)).Dec() } func (m *metrics) PayloadTooLarge(connIndex uint8) { - m.payloadTooLarge.WithLabelValues(string(connIndex)).Inc() + m.payloadTooLarge.WithLabelValues(fmt.Sprintf("%d", connIndex)).Inc() } func (m *metrics) RetryFlowResponse(connIndex uint8) { - m.retryFlowResponses.WithLabelValues(string(connIndex)).Inc() + m.retryFlowResponses.WithLabelValues(fmt.Sprintf("%d", connIndex)).Inc() } func (m *metrics) MigrateFlow(connIndex uint8) { - m.migratedFlows.WithLabelValues(string(connIndex)).Inc() + m.migratedFlows.WithLabelValues(fmt.Sprintf("%d", connIndex)).Inc() } func (m *metrics) UnsupportedRemoteCommand(connIndex uint8, command string) { - m.unsupportedRemoteCommands.WithLabelValues(string(connIndex), command).Inc() + m.unsupportedRemoteCommands.WithLabelValues(fmt.Sprintf("%d", connIndex), command).Inc() } func NewMetrics(registerer prometheus.Registerer) Metrics { diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 01937756..c708c944 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -17,6 +17,7 @@ import ( "github.com/rs/zerolog" "golang.org/x/sync/errgroup" + "github.com/cloudflare/cloudflared/client" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery/allregions" @@ -38,10 +39,8 @@ const ( ) type TunnelConfig struct { + ClientConfig *client.Config GracePeriod time.Duration - 
ReplaceExisting bool - OSArch string - ClientID string CloseConnOnce *sync.Once // Used to close connectedSignal no more than once EdgeAddrs []string Region string @@ -72,22 +71,13 @@ type TunnelConfig struct { DisableQUICPathMTUDiscovery bool QUICConnectionLevelFlowControlLimit uint64 QUICStreamLevelFlowControlLimit uint64 - - FeatureSelector *features.FeatureSelector } -func (c *TunnelConfig) connectionOptions(originLocalAddr string, numPreviousAttempts uint8) *pogs.ConnectionOptions { +func (c *TunnelConfig) connectionOptions(originLocalAddr string, previousAttempts uint8) *client.ConnectionOptionsSnapshot { // attempt to parse out origin IP, but don't fail since it's informational field host, _, _ := net.SplitHostPort(originLocalAddr) originIP := net.ParseIP(host) - - return &pogs.ConnectionOptions{ - Client: c.NamedTunnel.Client, - OriginLocalIP: originIP, - ReplaceExisting: c.ReplaceExisting, - CompressionQuality: 0, - NumPreviousAttempts: numPreviousAttempts, - } + return c.ClientConfig.ConnectionOptionsSnapshot(originIP, previousAttempts) } func StartTunnelDaemon( @@ -463,6 +453,8 @@ func (e *EdgeTunnelServer) serveConnection( case connection.QUIC: // nolint: gosec connOptions := e.config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries())) + // nolint: zerologlint + connOptions.LogFields(connLog.Logger().Debug().Uint8(connection.LogFieldConnIndex, connIndex)).Msgf("Tunnel connection options") return e.serveQUIC(ctx, addr.UDP.AddrPort(), connLog, @@ -479,6 +471,8 @@ func (e *EdgeTunnelServer) serveConnection( // nolint: gosec connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries())) + // nolint: zerologlint + connOptions.LogFields(connLog.Logger().Debug().Uint8(connection.LogFieldConnIndex, connIndex)).Msgf("Tunnel connection options") if err := e.serveHTTP2( ctx, connLog, @@ -508,11 +502,11 @@ func (e *EdgeTunnelServer) serveHTTP2( ctx context.Context, connLog *ConnAwareLogger, tlsServerConn net.Conn, - 
connOptions *pogs.ConnectionOptions, + connOptions *client.ConnectionOptionsSnapshot, controlStreamHandler connection.ControlStreamHandler, connIndex uint8, ) error { - pqMode := e.config.FeatureSelector.PostQuantumMode() + pqMode := connOptions.FeatureSnapshot.PostQuantum if pqMode == features.PostQuantumStrict { return unrecoverableError{errors.New("HTTP/2 transport does not support post-quantum")} } @@ -550,19 +544,19 @@ func (e *EdgeTunnelServer) serveQUIC( ctx context.Context, edgeAddr netip.AddrPort, connLogger *ConnAwareLogger, - connOptions *pogs.ConnectionOptions, + connOptions *client.ConnectionOptionsSnapshot, controlStreamHandler connection.ControlStreamHandler, connIndex uint8, ) (err error, recoverable bool) { tlsConfig := e.config.EdgeTLSConfigs[connection.QUIC] - pqMode := e.config.FeatureSelector.PostQuantumMode() + pqMode := connOptions.FeatureSnapshot.PostQuantum curvePref, err := curvePreference(pqMode, fips.IsFipsEnabled(), tlsConfig.CurvePreferences) if err != nil { return err, true } - connLogger.Logger().Info().Msgf("Using %v as curve preferences", curvePref) + connLogger.Logger().Info().Msgf("Tunnel connection curve preferences: %v", curvePref) tlsConfig.CurvePreferences = curvePref @@ -600,12 +594,12 @@ func (e *EdgeTunnelServer) serveQUIC( if err != nil { connLogger.ConnAwareLogger().Err(err).Msgf("Failed to dial a quic connection") - e.reportErrorToSentry(err) + e.reportErrorToSentry(err, connOptions.FeatureSnapshot.PostQuantum) return err, true } var datagramSessionManager connection.DatagramSessionHandler - if e.config.FeatureSelector.DatagramVersion() == features.DatagramV3 { + if connOptions.FeatureSnapshot.DatagramVersion == features.DatagramV3 { datagramSessionManager = connection.NewDatagramV3Connection( ctx, conn, @@ -672,7 +666,7 @@ func (e *EdgeTunnelServer) serveQUIC( // The reportErrorToSentry is an helper function that handles // verifies if an error should be reported to Sentry. 
-func (e *EdgeTunnelServer) reportErrorToSentry(err error) { +func (e *EdgeTunnelServer) reportErrorToSentry(err error, pqMode features.PostQuantumMode) { dialErr, ok := err.(*connection.EdgeQuicDialError) if ok { // The TransportError provides an Unwrap function however @@ -681,7 +675,7 @@ func (e *EdgeTunnelServer) reportErrorToSentry(err error) { if ok && transportErr.ErrorCode.IsCryptoError() && fips.IsFipsEnabled() && - e.config.FeatureSelector.PostQuantumMode() == features.PostQuantumStrict { + pqMode == features.PostQuantumStrict { // Only report to Sentry when using FIPS, PQ, // and the error is a Crypto error reported by // an EdgeQuicDialError