diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index aa678231..c8d565f2 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -31,7 +31,6 @@ import ( "github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/diagnostic" "github.com/cloudflare/cloudflared/edgediscovery" - "github.com/cloudflare/cloudflared/features" "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/management" @@ -515,26 +514,23 @@ func StartServer( tunnelConfig.ICMPRouterServer = nil } - internalRules := []ingress.Rule{} - if features.Contains(features.FeatureManagementLogs) { - serviceIP := c.String("service-op-ip") - if edgeAddrs, err := edgediscovery.ResolveEdge(log, tunnelConfig.Region, tunnelConfig.EdgeIPVersion); err == nil { - if serviceAddr, err := edgeAddrs.GetAddrForRPC(); err == nil { - serviceIP = serviceAddr.TCP.String() - } + serviceIP := c.String("service-op-ip") + if edgeAddrs, err := edgediscovery.ResolveEdge(log, tunnelConfig.Region, tunnelConfig.EdgeIPVersion); err == nil { + if serviceAddr, err := edgeAddrs.GetAddrForRPC(); err == nil { + serviceIP = serviceAddr.TCP.String() } - - mgmt := management.New( - c.String("management-hostname"), - c.Bool("management-diagnostics"), - serviceIP, - clientID, - c.String(connectorLabelFlag), - logger.ManagementLogger.Log, - logger.ManagementLogger, - ) - internalRules = []ingress.Rule{ingress.NewManagementRule(mgmt)} } + + mgmt := management.New( + c.String("management-hostname"), + c.Bool("management-diagnostics"), + serviceIP, + clientID, + c.String(connectorLabelFlag), + logger.ManagementLogger.Log, + logger.ManagementLogger, + ) + internalRules := []ingress.Rule{ingress.NewManagementRule(mgmt)} orchestrator, err := orchestration.NewOrchestrator(ctx, orchestratorConfig, tunnelConfig.Tags, internalRules, tunnelConfig.Log) if err != nil { return err diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index e04a1c76..c5983273 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -137,20 +137,15 @@ func prepareTunnelConfig( transportProtocol := c.String("protocol") - clientFeatures := features.Dedup(append(c.StringSlice("features"), features.DefaultFeatures...)) - - staticFeatures := features.StaticFeatures{} - if c.Bool("post-quantum") { - if FipsEnabled { - return nil, nil, fmt.Errorf("post-quantum not supported in FIPS mode") - } - pqMode := features.PostQuantumStrict - staticFeatures.PostQuantumMode = &pqMode + if c.Bool("post-quantum") && FipsEnabled { + return nil, nil, fmt.Errorf("post-quantum not supported in FIPS mode") } - featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, staticFeatures, log) + + featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, c.StringSlice("features"), c.Bool("post-quantum"), log) if err != nil { return nil, nil, errors.Wrap(err, "Failed to create feature selector") } + clientFeatures := featureSelector.ClientFeatures() pqMode := featureSelector.PostQuantumMode() if pqMode == features.PostQuantumStrict { // Error if the user tries to force a non-quic transport protocol @@ -158,7 +153,6 @@ func prepareTunnelConfig( return nil, nil, fmt.Errorf("post-quantum is only supported with the quic transport") } transportProtocol = connection.QUIC.String() - clientFeatures = append(clientFeatures, features.FeaturePostQuantum) log.Info().Msgf( "Using hybrid post-quantum key agreement %s", diff --git a/features/features.go b/features/features.go index 574f55ae..d1476285 100644 --- a/features/features.go +++ b/features/features.go @@ -12,7 +12,7 @@ const ( ) var ( - DefaultFeatures = []string{ + defaultFeatures = []string{ FeatureAllowRemoteConfig, FeatureSerializedHeaders, FeatureDatagramV2, @@ -21,15 +21,30 @@ var ( } ) -func Contains(feature string) bool { - for _, f := range DefaultFeatures { - if f == feature { - return true - } - } - return false +// Features set by user provided flags +type staticFeatures struct { + PostQuantumMode *PostQuantumMode } +type PostQuantumMode uint8 + +const ( + // Prefer post quantum, but fallback if connection cannot be established + PostQuantumPrefer PostQuantumMode = iota + // If the user passes the --post-quantum flag, we override + // CurvePreferences to only support hybrid post-quantum key agreements. + PostQuantumStrict +) + +type DatagramVersion string + +const ( + // DatagramV2 is the currently supported datagram protocol for UDP and ICMP packets + DatagramV2 DatagramVersion = FeatureDatagramV2 + // DatagramV3 is a new datagram protocol for UDP and ICMP packets. It is not backwards compatible with datagram v2. + DatagramV3 DatagramVersion = FeatureDatagramV3 +) + // Remove any duplicates from the slice func Dedup(slice []string) []string { diff --git a/features/selector.go b/features/selector.go index 3e952dd2..a97a0134 100644 --- a/features/selector.go +++ b/features/selector.go @@ -6,6 +6,7 @@ import ( "fmt" "hash/fnv" "net" + "slices" "sync" "time" @@ -18,61 +19,67 @@ const ( lookupTimeout = time.Second * 10 ) -type PostQuantumMode uint8 - -const ( - // Prefer post quantum, but fallback if connection cannot be established - PostQuantumPrefer PostQuantumMode = iota - // If the user passes the --post-quantum flag, we override - // CurvePreferences to only support hybrid post-quantum key agreements. - PostQuantumStrict -) - // If the TXT record adds other fields, the umarshal logic will ignore those keys // If the TXT record is missing a key, the field will unmarshal to the default Go value -// pq was removed in TUN-7970 -type featuresRecord struct{} -func NewFeatureSelector(ctx context.Context, accountTag string, staticFeatures StaticFeatures, logger *zerolog.Logger) (*FeatureSelector, error) { - return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), staticFeatures, defaultRefreshFreq) +type featuresRecord struct { + // support_datagram_v3 + DatagramV3Percentage int32 `json:"dv3"` + + // PostQuantumPercentage int32 `json:"pq"` // Removed in TUN-7970 } -// FeatureSelector determines if this account will try new features. It preiodically queries a DNS TXT record -// to see which features are turned on +func NewFeatureSelector(ctx context.Context, accountTag string, cliFeatures []string, pq bool, logger *zerolog.Logger) (*FeatureSelector, error) { + return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq, defaultRefreshFreq) +} + +// FeatureSelector determines if this account will try new features. It periodically queries a DNS TXT record +// to see which features are turned on. type FeatureSelector struct { accountHash int32 logger *zerolog.Logger resolver resolver - staticFeatures StaticFeatures + staticFeatures staticFeatures + cliFeatures []string // lock protects concurrent access to dynamic features lock sync.RWMutex features featuresRecord } -// Features set by user provided flags -type StaticFeatures struct { - PostQuantumMode *PostQuantumMode -} - -func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, staticFeatures StaticFeatures, refreshFreq time.Duration) (*FeatureSelector, error) { +func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool, refreshFreq time.Duration) (*FeatureSelector, error) { + // Combine default features and user-provided features + var pqMode *PostQuantumMode + if pq { + mode := PostQuantumStrict + pqMode = &mode + cliFeatures = append(cliFeatures, FeaturePostQuantum) + } + staticFeatures := staticFeatures{ + PostQuantumMode: pqMode, + } selector := &FeatureSelector{ accountHash: switchThreshold(accountTag), logger: logger, resolver: resolver, staticFeatures: staticFeatures, + cliFeatures: Dedup(cliFeatures), } if err := selector.refresh(ctx); err != nil { logger.Err(err).Msg("Failed to fetch features, default to disable") } - // Run refreshLoop next time we have a new feature to rollout + go selector.refreshLoop(ctx, refreshFreq) return selector, nil } +func (fs *FeatureSelector) accountEnabled(percentage int32) bool { + return percentage > fs.accountHash +} + func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode { if fs.staticFeatures.PostQuantumMode != nil { return *fs.staticFeatures.PostQuantumMode @@ -81,6 +88,33 @@ func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode { return PostQuantumPrefer } +func (fs *FeatureSelector) DatagramVersion() DatagramVersion { + fs.lock.RLock() + defer fs.lock.RUnlock() + + // If user provides the feature via the cli, we take it as priority over remote feature evaluation + if slices.Contains(fs.cliFeatures, FeatureDatagramV3) { + return DatagramV3 + } + // If the user specifies DatagramV2, we also take that over remote + if slices.Contains(fs.cliFeatures, FeatureDatagramV2) { + return DatagramV2 + } + + if fs.accountEnabled(fs.features.DatagramV3Percentage) { + return DatagramV3 + } + return DatagramV2 +} + +// ClientFeatures will return the list of currently available features that cloudflared should provide to the edge. +// +// This list is dynamic and can change in-between returns. +func (fs *FeatureSelector) ClientFeatures() []string { + // Evaluate any remote features along with static feature list to construct the list of features + return Dedup(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.DatagramVersion())})) +} + func (fs *FeatureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) { ticker := time.NewTicker(refreshFreq) for { diff --git a/features/selector_test.go b/features/selector_test.go index b8739c6b..5c57e3e6 100644 --- a/features/selector_test.go +++ b/features/selector_test.go @@ -13,16 +13,20 @@ import ( func TestUnmarshalFeaturesRecord(t *testing.T) { tests := []struct { - record []byte + record []byte + expectedPercentage int32 }{ { - record: []byte(`{"pq":0}`), + record: []byte(`{"dv3":0}`), + expectedPercentage: 0, }, { - record: []byte(`{"pq":39}`), + record: []byte(`{"dv3":39}`), + expectedPercentage: 39, }, { - record: []byte(`{"pq":100}`), + record: []byte(`{"dv3":100}`), + expectedPercentage: 100, }, { record: []byte(`{}`), // Unmarshal to default struct if key is not present @@ -36,37 +40,186 @@ func TestUnmarshalFeaturesRecord(t *testing.T) { var features featuresRecord err := json.Unmarshal(test.record, &features) require.NoError(t, err) - require.Equal(t, featuresRecord{}, features) + require.Equal(t, test.expectedPercentage, features.DatagramV3Percentage, test) } } +func TestFeaturePrecedenceEvaluationPostQuantum(t *testing.T) { + logger := zerolog.Nop() + tests := []struct { + name string + cli bool + expectedFeatures []string + expectedVersion PostQuantumMode + }{ + { + name: "default", + cli: false, + expectedFeatures: defaultFeatures, + expectedVersion: PostQuantumPrefer, + }, + { + name: "user_specified", + cli: true, + expectedFeatures: Dedup(append(defaultFeatures, FeaturePostQuantum)), + expectedVersion: PostQuantumStrict, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + resolver := &staticResolver{record: featuresRecord{}} + selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, []string{}, test.cli, time.Second) + require.NoError(t, err) + require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures()) + require.Equal(t, test.expectedVersion, selector.PostQuantumMode()) + }) + } +} + +func TestFeaturePrecedenceEvaluationDatagramVersion(t *testing.T) { + logger := zerolog.Nop() + tests := []struct { + name string + cli []string + remote featuresRecord + expectedFeatures []string + expectedVersion DatagramVersion + }{ + { + name: "default", + cli: []string{}, + remote: featuresRecord{}, + expectedFeatures: defaultFeatures, + expectedVersion: DatagramV2, + }, + { + name: "user_specified_v2", + cli: []string{FeatureDatagramV2}, + remote: featuresRecord{}, + expectedFeatures: defaultFeatures, + expectedVersion: DatagramV2, + }, + { + name: "user_specified_v3", + cli: []string{FeatureDatagramV3}, + remote: featuresRecord{}, + expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)), + expectedVersion: FeatureDatagramV3, + }, + { + name: "remote_specified_v3", + cli: []string{}, + remote: featuresRecord{ + DatagramV3Percentage: 100, + }, + expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)), + expectedVersion: FeatureDatagramV3, + }, + { + name: "remote_and_user_specified_v3", + cli: []string{FeatureDatagramV3}, + remote: featuresRecord{ + DatagramV3Percentage: 100, + }, + expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)), + expectedVersion: FeatureDatagramV3, + }, + { + name: "remote_v3_and_user_specified_v2", + cli: []string{FeatureDatagramV2}, + remote: featuresRecord{ + DatagramV3Percentage: 100, + }, + expectedFeatures: defaultFeatures, + expectedVersion: DatagramV2, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + resolver := &staticResolver{record: test.remote} + selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false, time.Second) + require.NoError(t, err) + require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures()) + require.Equal(t, test.expectedVersion, selector.DatagramVersion()) + }) + } +} + +func TestRefreshFeaturesRecord(t *testing.T) { + // The hash of the accountTag is 82 + accountTag := t.Name() + threshold := switchThreshold(accountTag) + + percentages := []int32{0, 10, 81, 82, 83, 100, 101, 1000} + refreshFreq := time.Millisecond * 10 + selector := newTestSelector(t, percentages, false, refreshFreq) + + // Starting out should default to DatagramV2 + require.Equal(t, DatagramV2, selector.DatagramVersion()) + + for _, percentage := range percentages { + if percentage > threshold { + require.Equal(t, DatagramV3, selector.DatagramVersion()) + } else { + require.Equal(t, DatagramV2, selector.DatagramVersion()) + } + + time.Sleep(refreshFreq + time.Millisecond) + } + + // Make sure error doesn't override the last fetched features + require.Equal(t, DatagramV3, selector.DatagramVersion()) +} + func TestStaticFeatures(t *testing.T) { - pqMode := PostQuantumStrict - selector := newTestSelector(t, &pqMode, time.Millisecond*10) + percentages := []int32{0} + // PostQuantum Enabled from user flag + selector := newTestSelector(t, percentages, true, time.Millisecond*10) require.Equal(t, PostQuantumStrict, selector.PostQuantumMode()) - // No StaticFeatures configured - selector = newTestSelector(t, nil, time.Millisecond*10) + // PostQuantum Disabled (or not set) + selector = newTestSelector(t, percentages, false, time.Millisecond*10) require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode()) } -func newTestSelector(t *testing.T, pqMode *PostQuantumMode, refreshFreq time.Duration) *FeatureSelector { +func newTestSelector(t *testing.T, percentages []int32, pq bool, refreshFreq time.Duration) *FeatureSelector { accountTag := t.Name() logger := zerolog.Nop() - resolver := &mockResolver{} - - staticFeatures := StaticFeatures{ - PostQuantumMode: pqMode, + resolver := &mockResolver{ + percentages: percentages, } - selector, err := newFeatureSelector(context.Background(), accountTag, &logger, resolver, staticFeatures, refreshFreq) + + selector, err := newFeatureSelector(context.Background(), accountTag, &logger, resolver, []string{}, pq, refreshFreq) require.NoError(t, err) return selector } -type mockResolver struct{} +type mockResolver struct { + nextIndex int + percentages []int32 +} func (mr *mockResolver) lookupRecord(ctx context.Context) ([]byte, error) { - return nil, fmt.Errorf("mockResolver hasn't implement lookupRecord") + if mr.nextIndex >= len(mr.percentages) { + return nil, fmt.Errorf("no more record to lookup") + } + + record, err := json.Marshal(featuresRecord{ + DatagramV3Percentage: mr.percentages[mr.nextIndex], + }) + mr.nextIndex++ + + return record, err +} + +type staticResolver struct { + record featuresRecord +} + +func (r *staticResolver) lookupRecord(ctx context.Context) ([]byte, error) { + return json.Marshal(r.record) } diff --git a/logger/create.go b/logger/create.go index 83048e59..4a298ad4 100644 --- a/logger/create.go +++ b/logger/create.go @@ -16,7 +16,6 @@ import ( "golang.org/x/term" "gopkg.in/natefinch/lumberjack.v2" - "github.com/cloudflare/cloudflared/features" "github.com/cloudflare/cloudflared/management" ) @@ -46,11 +45,7 @@ func init() { zerolog.TimeFieldFormat = time.RFC3339 zerolog.TimestampFunc = utcNow - if features.Contains(features.FeatureManagementLogs) { - // Management logger needs to be initialized before any of the other loggers as to not capture - // it's own logging events. - ManagementLogger = management.NewLogger() - } + ManagementLogger = management.NewLogger() } func utcNow() time.Time { @@ -124,10 +119,7 @@ func newZerolog(loggerConfig *Config) *zerolog.Logger { writers = append(writers, rollingLogger) } - var managementWriter zerolog.LevelWriter - if features.Contains(features.FeatureManagementLogs) { - managementWriter = ManagementLogger - } + managementWriter := ManagementLogger level, levelErr := zerolog.ParseLevel(loggerConfig.MinLevel) if levelErr != nil { diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 09983e11..43349c19 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -7,7 +7,6 @@ import ( "net" "net/netip" "runtime/debug" - "slices" "strings" "sync" "time" @@ -554,10 +553,6 @@ func (e *EdgeTunnelServer) serveQUIC( tlsConfig := e.config.EdgeTLSConfigs[connection.QUIC] pqMode := e.config.FeatureSelector.PostQuantumMode() - if pqMode == features.PostQuantumStrict || pqMode == features.PostQuantumPrefer { - connOptions.Client.Features = features.Dedup(append(connOptions.Client.Features, features.FeaturePostQuantum)) - } - curvePref, err := curvePreference(pqMode, tlsConfig.CurvePreferences) if err != nil { return err, true @@ -602,7 +597,7 @@ func (e *EdgeTunnelServer) serveQUIC( } var datagramSessionManager connection.DatagramSessionHandler - if slices.Contains(connOptions.Client.Features, features.FeatureDatagramV3) { + if e.config.FeatureSelector.DatagramVersion() == features.DatagramV3 { datagramSessionManager = connection.NewDatagramV3Connection( ctx, conn,