diff --git a/cmd/cloudflared/tunnel/config_test.go b/cmd/cloudflared/tunnel/config_test.go deleted file mode 100644 index 0ca6c436..00000000 --- a/cmd/cloudflared/tunnel/config_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package tunnel - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/cloudflare/cloudflared/features" -) - -func TestDedup(t *testing.T) { - expected := []string{"a", "b"} - actual := features.Dedup([]string{"a", "b", "a"}) - require.ElementsMatch(t, expected, actual) -} diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index fc21c7ec..b78dba89 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -140,7 +140,7 @@ func prepareTunnelConfig( transportProtocol := c.String(flags.Protocol) isPostQuantumEnforced := c.Bool(flags.PostQuantum) - featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, c.StringSlice("features"), c.Bool("post-quantum"), log) + featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, c.StringSlice(flags.Features), c.Bool(flags.PostQuantum), log) if err != nil { return nil, nil, errors.Wrap(err, "Failed to create feature selector") } diff --git a/connection/control.go b/connection/control.go index 2e5f1e35..d590b078 100644 --- a/connection/control.go +++ b/connection/control.go @@ -10,7 +10,7 @@ import ( "github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/tunnelrpc" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) // registerClient derives a named tunnel rpc client that can then be used to register and unregister connections. @@ -36,7 +36,7 @@ type controlStream struct { // ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown. type ControlStreamHandler interface { // ServeControlStream handles the control plane of the transport in the current goroutine calling this - ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, tunnelConfigGetter TunnelConfigJSONGetter) error + ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *pogs.ConnectionOptions, tunnelConfigGetter TunnelConfigJSONGetter) error // IsStopped tells whether the method above has finished IsStopped() bool } @@ -78,7 +78,7 @@ func NewControlStream( func (c *controlStream) ServeControlStream( ctx context.Context, rw io.ReadWriteCloser, - connOptions *tunnelpogs.ConnectionOptions, + connOptions *pogs.ConnectionOptions, tunnelConfigGetter TunnelConfigJSONGetter, ) error { registrationClient := c.registerClientFunc(ctx, rw, c.registerTimeout) diff --git a/connection/http2.go b/connection/http2.go index 15afcbde..243c464d 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -19,7 +19,7 @@ import ( cfdflow "github.com/cloudflare/cloudflared/flow" "github.com/cloudflare/cloudflared/tracing" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) // note: these constants are exported so we can reuse them in the edge-side code @@ -39,7 +39,7 @@ type HTTP2Connection struct { conn net.Conn server *http2.Server orchestrator Orchestrator - connOptions *tunnelpogs.ConnectionOptions + connOptions *pogs.ConnectionOptions observer *Observer connIndex uint8 @@ -54,7 +54,7 @@ type HTTP2Connection struct { func NewHTTP2Connection( conn net.Conn, orchestrator Orchestrator, - connOptions *tunnelpogs.ConnectionOptions, + connOptions *pogs.ConnectionOptions, observer *Observer, connIndex uint8, controlStreamHandler ControlStreamHandler, diff --git a/connection/quic_connection.go b/connection/quic_connection.go index 6addfd60..3451369c 100644 --- a/connection/quic_connection.go +++ b/connection/quic_connection.go @@ -22,7 +22,6 @@ import ( cfdquic "github.com/cloudflare/cloudflared/quic" "github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic" ) @@ -44,7 +43,7 @@ type quicConnection struct { orchestrator Orchestrator datagramHandler DatagramSessionHandler controlStreamHandler ControlStreamHandler - connOptions *tunnelpogs.ConnectionOptions + connOptions *pogs.ConnectionOptions connIndex uint8 rpcTimeout time.Duration @@ -235,7 +234,7 @@ func (q *quicConnection) dispatchRequest(ctx context.Context, stream *rpcquic.Re } // UpdateConfiguration is the RPC method invoked by edge when there is a new configuration -func (q *quicConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { +func (q *quicConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *pogs.UpdateConfigurationResponse { return q.orchestrator.UpdateConfig(version, config) } diff --git a/features/features.go b/features/features.go index 25b5dc8b..c16b51fd 100644 --- a/features/features.go +++ b/features/features.go @@ -1,5 +1,7 @@ package features +import "slices" + const ( FeatureSerializedHeaders = "serialized_headers" FeatureQuickReconnects = "quick_reconnects" @@ -8,7 +10,9 @@ const ( FeaturePostQuantum = "postquantum" FeatureQUICSupportEOF = "support_quic_eof" FeatureManagementLogs = "management_logs" - FeatureDatagramV3 = "support_datagram_v3" + FeatureDatagramV3_1 = "support_datagram_v3_1" + + DeprecatedFeatureDatagramV3 = "support_datagram_v3" // Deprecated: TUN-9291 ) var defaultFeatures = []string{ @@ -19,6 +23,11 @@ var defaultFeatures = []string{ FeatureManagementLogs, } +// List of features that are no longer in-use. +var deprecatedFeatures = []string{ + DeprecatedFeatureDatagramV3, +} + // Features set by user provided flags type staticFeatures struct { PostQuantumMode *PostQuantumMode @@ -40,15 +49,19 @@ const ( // DatagramV2 is the currently supported datagram protocol for UDP and ICMP packets DatagramV2 DatagramVersion = FeatureDatagramV2 // DatagramV3 is a new datagram protocol for UDP and ICMP packets. It is not backwards compatible with datagram v2. - DatagramV3 DatagramVersion = FeatureDatagramV3 + DatagramV3 DatagramVersion = FeatureDatagramV3_1 ) -// Remove any duplicates from the slice -func Dedup(slice []string) []string { +// Remove any duplicate features from the list and remove deprecated features +func dedupAndRemoveFeatures(features []string) []string { // Convert the slice into a set - set := make(map[string]bool, 0) - for _, str := range slice { - set[str] = true + set := map[string]bool{} + for _, feature := range features { + // Remove deprecated features from the provided list + if slices.Contains(deprecatedFeatures, feature) { + continue + } + set[feature] = true } // Convert the set back into a slice diff --git a/features/selector.go b/features/selector.go index a97a0134..c4925c45 100644 --- a/features/selector.go +++ b/features/selector.go @@ -7,7 +7,6 @@ import ( "hash/fnv" "net" "slices" - "sync" "time" "github.com/rs/zerolog" @@ -15,7 +14,6 @@ import ( const ( featureSelectorHostname = "cfd-features.argotunnel.com" - defaultRefreshFreq = time.Hour * 6 lookupTimeout = time.Second * 10 ) @@ -23,32 +21,27 @@ const ( // If the TXT record is missing a key, the field will unmarshal to the default Go value type featuresRecord struct { - // support_datagram_v3 - DatagramV3Percentage int32 `json:"dv3"` - + // DatagramV3Percentage int32 `json:"dv3"` // Removed in TUN-9291 // PostQuantumPercentage int32 `json:"pq"` // Removed in TUN-7970 } func NewFeatureSelector(ctx context.Context, accountTag string, cliFeatures []string, pq bool, logger *zerolog.Logger) (*FeatureSelector, error) { - return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq, defaultRefreshFreq) + return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq) } -// FeatureSelector determines if this account will try new features. It periodically queries a DNS TXT record -// to see which features are turned on. +// FeatureSelector determines if this account will try new features; loaded once during startup. type FeatureSelector struct { - accountHash int32 + accountHash uint32 logger *zerolog.Logger resolver resolver staticFeatures staticFeatures cliFeatures []string - // lock protects concurrent access to dynamic features - lock sync.RWMutex features featuresRecord } -func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool, refreshFreq time.Duration) (*FeatureSelector, error) { +func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool) (*FeatureSelector, error) { // Combine default features and user-provided features var pqMode *PostQuantumMode if pq { @@ -64,22 +57,16 @@ func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog. logger: logger, resolver: resolver, staticFeatures: staticFeatures, - cliFeatures: Dedup(cliFeatures), + cliFeatures: dedupAndRemoveFeatures(cliFeatures), } - if err := selector.refresh(ctx); err != nil { + if err := selector.init(ctx); err != nil { logger.Err(err).Msg("Failed to fetch features, default to disable") } - go selector.refreshLoop(ctx, refreshFreq) - return selector, nil } -func (fs *FeatureSelector) accountEnabled(percentage int32) bool { - return percentage > fs.accountHash -} - func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode { if fs.staticFeatures.PostQuantumMode != nil { return *fs.staticFeatures.PostQuantumMode @@ -89,11 +76,8 @@ func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode { } func (fs *FeatureSelector) DatagramVersion() DatagramVersion { - fs.lock.RLock() - defer fs.lock.RUnlock() - // If user provides the feature via the cli, we take it as priority over remote feature evaluation - if slices.Contains(fs.cliFeatures, FeatureDatagramV3) { + if slices.Contains(fs.cliFeatures, FeatureDatagramV3_1) { return DatagramV3 } // If the user specifies DatagramV2, we also take that over remote @@ -101,36 +85,16 @@ func (fs *FeatureSelector) DatagramVersion() DatagramVersion { return DatagramV2 } - if fs.accountEnabled(fs.features.DatagramV3Percentage) { - return DatagramV3 - } return DatagramV2 } // ClientFeatures will return the list of currently available features that cloudflared should provide to the edge. -// -// This list is dynamic and can change in-between returns. func (fs *FeatureSelector) ClientFeatures() []string { // Evaluate any remote features along with static feature list to construct the list of features - return Dedup(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.DatagramVersion())})) + return dedupAndRemoveFeatures(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.DatagramVersion())})) } -func (fs *FeatureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) { - ticker := time.NewTicker(refreshFreq) - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - err := fs.refresh(ctx) - if err != nil { - fs.logger.Err(err).Msg("Failed to refresh feature selector") - } - } - } -} - -func (fs *FeatureSelector) refresh(ctx context.Context) error { +func (fs *FeatureSelector) init(ctx context.Context) error { record, err := fs.resolver.lookupRecord(ctx) if err != nil { return err @@ -141,9 +105,6 @@ func (fs *FeatureSelector) refresh(ctx context.Context) error { return err } - fs.lock.Lock() - defer fs.lock.Unlock() - fs.features = features return nil @@ -180,8 +141,8 @@ func (dr *dnsResolver) lookupRecord(ctx context.Context) ([]byte, error) { return []byte(records[0]), nil } -func switchThreshold(accountTag string) int32 { +func switchThreshold(accountTag string) uint32 { h := fnv.New32a() _, _ = h.Write([]byte(accountTag)) - return int32(h.Sum32() % 100) + return h.Sum32() % 100 } diff --git a/features/selector_test.go b/features/selector_test.go index 5c57e3e6..b7d67cc9 100644 --- a/features/selector_test.go +++ b/features/selector_test.go @@ -3,9 +3,7 @@ package features import ( "context" "encoding/json" - "fmt" "testing" - "time" "github.com/rs/zerolog" "github.com/stretchr/testify/require" @@ -14,33 +12,23 @@ import ( func TestUnmarshalFeaturesRecord(t *testing.T) { tests := []struct { record []byte - expectedPercentage int32 + expectedPercentage uint32 }{ - { - record: []byte(`{"dv3":0}`), - expectedPercentage: 0, - }, - { - record: []byte(`{"dv3":39}`), - expectedPercentage: 39, - }, - { - record: []byte(`{"dv3":100}`), - expectedPercentage: 100, - }, { record: []byte(`{}`), // Unmarshal to default struct if key is not present }, { record: []byte(`{"kyber":768}`), // Unmarshal to default struct if key is not present }, + { + record: []byte(`{"pq": 101,"dv3":100}`), // Expired keys don't unmarshal to anything + }, } for _, test := range tests { var features featuresRecord err := json.Unmarshal(test.record, &features) require.NoError(t, err) - require.Equal(t, test.expectedPercentage, features.DatagramV3Percentage, test) } } @@ -61,7 +49,7 @@ func TestFeaturePrecedenceEvaluationPostQuantum(t *testing.T) { { name: "user_specified", cli: true, - expectedFeatures: Dedup(append(defaultFeatures, FeaturePostQuantum)), + expectedFeatures: dedupAndRemoveFeatures(append(defaultFeatures, FeaturePostQuantum)), expectedVersion: PostQuantumStrict, }, } @@ -69,7 +57,7 @@ func TestFeaturePrecedenceEvaluationPostQuantum(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { resolver := &staticResolver{record: featuresRecord{}} - selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, []string{}, test.cli, time.Second) + selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, []string{}, test.cli) require.NoError(t, err) require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures()) require.Equal(t, test.expectedVersion, selector.PostQuantumMode()) @@ -102,44 +90,17 @@ func TestFeaturePrecedenceEvaluationDatagramVersion(t *testing.T) { }, { name: "user_specified_v3", - cli: []string{FeatureDatagramV3}, + cli: []string{FeatureDatagramV3_1}, remote: featuresRecord{}, - expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)), - expectedVersion: FeatureDatagramV3, - }, - { - name: "remote_specified_v3", - cli: []string{}, - remote: featuresRecord{ - DatagramV3Percentage: 100, - }, - expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)), - expectedVersion: FeatureDatagramV3, - }, - { - name: "remote_and_user_specified_v3", - cli: []string{FeatureDatagramV3}, - remote: featuresRecord{ - DatagramV3Percentage: 100, - }, - expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)), - expectedVersion: FeatureDatagramV3, - }, - { - name: "remote_v3_and_user_specified_v2", - cli: []string{FeatureDatagramV2}, - remote: featuresRecord{ - DatagramV3Percentage: 100, - }, - expectedFeatures: defaultFeatures, - expectedVersion: DatagramV2, + expectedFeatures: dedupAndRemoveFeatures(append(defaultFeatures, FeatureDatagramV3_1)), + expectedVersion: FeatureDatagramV3_1, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { resolver := &staticResolver{record: test.remote} - selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false, time.Second) + selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false) require.NoError(t, err) require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures()) require.Equal(t, test.expectedVersion, selector.DatagramVersion()) @@ -147,75 +108,59 @@ func TestFeaturePrecedenceEvaluationDatagramVersion(t *testing.T) { } } -func TestRefreshFeaturesRecord(t *testing.T) { - // The hash of the accountTag is 82 - accountTag := t.Name() - threshold := switchThreshold(accountTag) - - percentages := []int32{0, 10, 81, 82, 83, 100, 101, 1000} - refreshFreq := time.Millisecond * 10 - selector := newTestSelector(t, percentages, false, refreshFreq) - - // Starting out should default to DatagramV2 - require.Equal(t, DatagramV2, selector.DatagramVersion()) - - for _, percentage := range percentages { - if percentage > threshold { - require.Equal(t, DatagramV3, selector.DatagramVersion()) - } else { - require.Equal(t, DatagramV2, selector.DatagramVersion()) - } - - time.Sleep(refreshFreq + time.Millisecond) +func TestDeprecatedFeaturesRemoved(t *testing.T) { + logger := zerolog.Nop() + tests := []struct { + name string + cli []string + remote featuresRecord + expectedFeatures []string + }{ + { + name: "no_removals", + cli: []string{}, + remote: featuresRecord{}, + expectedFeatures: defaultFeatures, + }, + { + name: "support_datagram_v3", + cli: []string{DeprecatedFeatureDatagramV3}, + remote: featuresRecord{}, + expectedFeatures: defaultFeatures, + }, } - // Make sure error doesn't override the last fetched features - require.Equal(t, DatagramV3, selector.DatagramVersion()) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + resolver := &staticResolver{record: test.remote} + selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false) + require.NoError(t, err) + require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures()) + }) + } } func TestStaticFeatures(t *testing.T) { - percentages := []int32{0} + percentages := []uint32{0} // PostQuantum Enabled from user flag - selector := newTestSelector(t, percentages, true, time.Millisecond*10) + selector := newTestSelector(t, percentages, true) require.Equal(t, PostQuantumStrict, selector.PostQuantumMode()) // PostQuantum Disabled (or not set) - selector = newTestSelector(t, percentages, false, time.Millisecond*10) + selector = newTestSelector(t, percentages, false) require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode()) } -func newTestSelector(t *testing.T, percentages []int32, pq bool, refreshFreq time.Duration) *FeatureSelector { +func newTestSelector(t *testing.T, percentages []uint32, pq bool) *FeatureSelector { accountTag := t.Name() logger := zerolog.Nop() - resolver := &mockResolver{ - percentages: percentages, - } - - selector, err := newFeatureSelector(context.Background(), accountTag, &logger, resolver, []string{}, pq, refreshFreq) + selector, err := newFeatureSelector(context.Background(), accountTag, &logger, &staticResolver{}, []string{}, pq) require.NoError(t, err) return selector } -type mockResolver struct { - nextIndex int - percentages []int32 -} - -func (mr *mockResolver) lookupRecord(ctx context.Context) ([]byte, error) { - if mr.nextIndex >= len(mr.percentages) { - return nil, fmt.Errorf("no more record to lookup") - } - - record, err := json.Marshal(featuresRecord{ - DatagramV3Percentage: mr.percentages[mr.nextIndex], - }) - mr.nextIndex++ - - return record, err -} - type staticResolver struct { record featuresRecord }