TUN-9291: Remove dynamic reloading of features for datagram v3

During a refresh of the supported features via the DNS TXT record,
cloudflared would update the internal feature list, but would not
propagate this information to the edge during a new connection.

This meant that a situation could occur in which cloudflared would
think that the client's connection could support datagram V3, and
would setup that muxer locally, but would not propagate that information
to the edge during a register connection in the `ClientInfo` of the
`ConnectionOptions`. This meant that the edge still thought that the
client was setup to support datagram V2 and since the protocols are
not backwards compatible, the local muxer for datagram V3 would reject
the incoming RPC calls.

To address this, the feature list will be fetched only once during
client bootstrapping and will persist as-is until the client is restarted.
This helps reduce the complexity involved with different connections
having possibly different sets of features when connecting to the edge.
The features will now be tied to the client and never diverge across
connections.

Also, retires the use of `support_datagram_v3` in-favor of
`support_datagram_v3_1` to reduce the risk of reusing the feature key.
The `dv3` TXT feature key is also deprecated.

Closes TUN-9291
This commit is contained in:
Devin Carr 2025-05-07 23:21:08 +00:00
parent 40dc601e9d
commit ce27840573
8 changed files with 84 additions and 181 deletions

View File

@ -1,15 +0,0 @@
package tunnel
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/features"
)
func TestDedup(t *testing.T) {
expected := []string{"a", "b"}
actual := features.Dedup([]string{"a", "b", "a"})
require.ElementsMatch(t, expected, actual)
}

View File

@ -140,7 +140,7 @@ func prepareTunnelConfig(
transportProtocol := c.String(flags.Protocol) transportProtocol := c.String(flags.Protocol)
isPostQuantumEnforced := c.Bool(flags.PostQuantum) isPostQuantumEnforced := c.Bool(flags.PostQuantum)
featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, c.StringSlice("features"), c.Bool("post-quantum"), log) featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, c.StringSlice(flags.Features), c.Bool(flags.PostQuantum), log)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "Failed to create feature selector") return nil, nil, errors.Wrap(err, "Failed to create feature selector")
} }

View File

@ -10,7 +10,7 @@ import (
"github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/management"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
// registerClient derives a named tunnel rpc client that can then be used to register and unregister connections. // registerClient derives a named tunnel rpc client that can then be used to register and unregister connections.
@ -36,7 +36,7 @@ type controlStream struct {
// ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown. // ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown.
type ControlStreamHandler interface { type ControlStreamHandler interface {
// ServeControlStream handles the control plane of the transport in the current goroutine calling this // ServeControlStream handles the control plane of the transport in the current goroutine calling this
ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, tunnelConfigGetter TunnelConfigJSONGetter) error ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *pogs.ConnectionOptions, tunnelConfigGetter TunnelConfigJSONGetter) error
// IsStopped tells whether the method above has finished // IsStopped tells whether the method above has finished
IsStopped() bool IsStopped() bool
} }
@ -78,7 +78,7 @@ func NewControlStream(
func (c *controlStream) ServeControlStream( func (c *controlStream) ServeControlStream(
ctx context.Context, ctx context.Context,
rw io.ReadWriteCloser, rw io.ReadWriteCloser,
connOptions *tunnelpogs.ConnectionOptions, connOptions *pogs.ConnectionOptions,
tunnelConfigGetter TunnelConfigJSONGetter, tunnelConfigGetter TunnelConfigJSONGetter,
) error { ) error {
registrationClient := c.registerClientFunc(ctx, rw, c.registerTimeout) registrationClient := c.registerClientFunc(ctx, rw, c.registerTimeout)

View File

@ -19,7 +19,7 @@ import (
cfdflow "github.com/cloudflare/cloudflared/flow" cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
// note: these constants are exported so we can reuse them in the edge-side code // note: these constants are exported so we can reuse them in the edge-side code
@ -39,7 +39,7 @@ type HTTP2Connection struct {
conn net.Conn conn net.Conn
server *http2.Server server *http2.Server
orchestrator Orchestrator orchestrator Orchestrator
connOptions *tunnelpogs.ConnectionOptions connOptions *pogs.ConnectionOptions
observer *Observer observer *Observer
connIndex uint8 connIndex uint8
@ -54,7 +54,7 @@ type HTTP2Connection struct {
func NewHTTP2Connection( func NewHTTP2Connection(
conn net.Conn, conn net.Conn,
orchestrator Orchestrator, orchestrator Orchestrator,
connOptions *tunnelpogs.ConnectionOptions, connOptions *pogs.ConnectionOptions,
observer *Observer, observer *Observer,
connIndex uint8, connIndex uint8,
controlStreamHandler ControlStreamHandler, controlStreamHandler ControlStreamHandler,

View File

@ -22,7 +22,6 @@ import (
cfdquic "github.com/cloudflare/cloudflared/quic" cfdquic "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic" rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic"
) )
@ -44,7 +43,7 @@ type quicConnection struct {
orchestrator Orchestrator orchestrator Orchestrator
datagramHandler DatagramSessionHandler datagramHandler DatagramSessionHandler
controlStreamHandler ControlStreamHandler controlStreamHandler ControlStreamHandler
connOptions *tunnelpogs.ConnectionOptions connOptions *pogs.ConnectionOptions
connIndex uint8 connIndex uint8
rpcTimeout time.Duration rpcTimeout time.Duration
@ -235,7 +234,7 @@ func (q *quicConnection) dispatchRequest(ctx context.Context, stream *rpcquic.Re
} }
// UpdateConfiguration is the RPC method invoked by edge when there is a new configuration // UpdateConfiguration is the RPC method invoked by edge when there is a new configuration
func (q *quicConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse { func (q *quicConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *pogs.UpdateConfigurationResponse {
return q.orchestrator.UpdateConfig(version, config) return q.orchestrator.UpdateConfig(version, config)
} }

View File

@ -1,5 +1,7 @@
package features package features
import "slices"
const ( const (
FeatureSerializedHeaders = "serialized_headers" FeatureSerializedHeaders = "serialized_headers"
FeatureQuickReconnects = "quick_reconnects" FeatureQuickReconnects = "quick_reconnects"
@ -8,7 +10,9 @@ const (
FeaturePostQuantum = "postquantum" FeaturePostQuantum = "postquantum"
FeatureQUICSupportEOF = "support_quic_eof" FeatureQUICSupportEOF = "support_quic_eof"
FeatureManagementLogs = "management_logs" FeatureManagementLogs = "management_logs"
FeatureDatagramV3 = "support_datagram_v3" FeatureDatagramV3_1 = "support_datagram_v3_1"
DeprecatedFeatureDatagramV3 = "support_datagram_v3" // Deprecated: TUN-9291
) )
var defaultFeatures = []string{ var defaultFeatures = []string{
@ -19,6 +23,11 @@ var defaultFeatures = []string{
FeatureManagementLogs, FeatureManagementLogs,
} }
// List of features that are no longer in-use.
var deprecatedFeatures = []string{
DeprecatedFeatureDatagramV3,
}
// Features set by user provided flags // Features set by user provided flags
type staticFeatures struct { type staticFeatures struct {
PostQuantumMode *PostQuantumMode PostQuantumMode *PostQuantumMode
@ -40,15 +49,19 @@ const (
// DatagramV2 is the currently supported datagram protocol for UDP and ICMP packets // DatagramV2 is the currently supported datagram protocol for UDP and ICMP packets
DatagramV2 DatagramVersion = FeatureDatagramV2 DatagramV2 DatagramVersion = FeatureDatagramV2
// DatagramV3 is a new datagram protocol for UDP and ICMP packets. It is not backwards compatible with datagram v2. // DatagramV3 is a new datagram protocol for UDP and ICMP packets. It is not backwards compatible with datagram v2.
DatagramV3 DatagramVersion = FeatureDatagramV3 DatagramV3 DatagramVersion = FeatureDatagramV3_1
) )
// Remove any duplicates from the slice // Remove any duplicate features from the list and remove deprecated features
func Dedup(slice []string) []string { func dedupAndRemoveFeatures(features []string) []string {
// Convert the slice into a set // Convert the slice into a set
set := make(map[string]bool, 0) set := map[string]bool{}
for _, str := range slice { for _, feature := range features {
set[str] = true // Remove deprecated features from the provided list
if slices.Contains(deprecatedFeatures, feature) {
continue
}
set[feature] = true
} }
// Convert the set back into a slice // Convert the set back into a slice

View File

@ -7,7 +7,6 @@ import (
"hash/fnv" "hash/fnv"
"net" "net"
"slices" "slices"
"sync"
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -15,7 +14,6 @@ import (
const ( const (
featureSelectorHostname = "cfd-features.argotunnel.com" featureSelectorHostname = "cfd-features.argotunnel.com"
defaultRefreshFreq = time.Hour * 6
lookupTimeout = time.Second * 10 lookupTimeout = time.Second * 10
) )
@ -23,32 +21,27 @@ const (
// If the TXT record is missing a key, the field will unmarshal to the default Go value // If the TXT record is missing a key, the field will unmarshal to the default Go value
type featuresRecord struct { type featuresRecord struct {
// support_datagram_v3 // DatagramV3Percentage int32 `json:"dv3"` // Removed in TUN-9291
DatagramV3Percentage int32 `json:"dv3"`
// PostQuantumPercentage int32 `json:"pq"` // Removed in TUN-7970 // PostQuantumPercentage int32 `json:"pq"` // Removed in TUN-7970
} }
func NewFeatureSelector(ctx context.Context, accountTag string, cliFeatures []string, pq bool, logger *zerolog.Logger) (*FeatureSelector, error) { func NewFeatureSelector(ctx context.Context, accountTag string, cliFeatures []string, pq bool, logger *zerolog.Logger) (*FeatureSelector, error) {
return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq, defaultRefreshFreq) return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq)
} }
// FeatureSelector determines if this account will try new features. It periodically queries a DNS TXT record // FeatureSelector determines if this account will try new features; loaded once during startup.
// to see which features are turned on.
type FeatureSelector struct { type FeatureSelector struct {
accountHash int32 accountHash uint32
logger *zerolog.Logger logger *zerolog.Logger
resolver resolver resolver resolver
staticFeatures staticFeatures staticFeatures staticFeatures
cliFeatures []string cliFeatures []string
// lock protects concurrent access to dynamic features
lock sync.RWMutex
features featuresRecord features featuresRecord
} }
func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool, refreshFreq time.Duration) (*FeatureSelector, error) { func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool) (*FeatureSelector, error) {
// Combine default features and user-provided features // Combine default features and user-provided features
var pqMode *PostQuantumMode var pqMode *PostQuantumMode
if pq { if pq {
@ -64,22 +57,16 @@ func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.
logger: logger, logger: logger,
resolver: resolver, resolver: resolver,
staticFeatures: staticFeatures, staticFeatures: staticFeatures,
cliFeatures: Dedup(cliFeatures), cliFeatures: dedupAndRemoveFeatures(cliFeatures),
} }
if err := selector.refresh(ctx); err != nil { if err := selector.init(ctx); err != nil {
logger.Err(err).Msg("Failed to fetch features, default to disable") logger.Err(err).Msg("Failed to fetch features, default to disable")
} }
go selector.refreshLoop(ctx, refreshFreq)
return selector, nil return selector, nil
} }
func (fs *FeatureSelector) accountEnabled(percentage int32) bool {
return percentage > fs.accountHash
}
func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode { func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode {
if fs.staticFeatures.PostQuantumMode != nil { if fs.staticFeatures.PostQuantumMode != nil {
return *fs.staticFeatures.PostQuantumMode return *fs.staticFeatures.PostQuantumMode
@ -89,11 +76,8 @@ func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode {
} }
func (fs *FeatureSelector) DatagramVersion() DatagramVersion { func (fs *FeatureSelector) DatagramVersion() DatagramVersion {
fs.lock.RLock()
defer fs.lock.RUnlock()
// If user provides the feature via the cli, we take it as priority over remote feature evaluation // If user provides the feature via the cli, we take it as priority over remote feature evaluation
if slices.Contains(fs.cliFeatures, FeatureDatagramV3) { if slices.Contains(fs.cliFeatures, FeatureDatagramV3_1) {
return DatagramV3 return DatagramV3
} }
// If the user specifies DatagramV2, we also take that over remote // If the user specifies DatagramV2, we also take that over remote
@ -101,36 +85,16 @@ func (fs *FeatureSelector) DatagramVersion() DatagramVersion {
return DatagramV2 return DatagramV2
} }
if fs.accountEnabled(fs.features.DatagramV3Percentage) {
return DatagramV3
}
return DatagramV2 return DatagramV2
} }
// ClientFeatures will return the list of currently available features that cloudflared should provide to the edge. // ClientFeatures will return the list of currently available features that cloudflared should provide to the edge.
//
// This list is dynamic and can change in-between returns.
func (fs *FeatureSelector) ClientFeatures() []string { func (fs *FeatureSelector) ClientFeatures() []string {
// Evaluate any remote features along with static feature list to construct the list of features // Evaluate any remote features along with static feature list to construct the list of features
return Dedup(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.DatagramVersion())})) return dedupAndRemoveFeatures(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.DatagramVersion())}))
} }
func (fs *FeatureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) { func (fs *FeatureSelector) init(ctx context.Context) error {
ticker := time.NewTicker(refreshFreq)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
err := fs.refresh(ctx)
if err != nil {
fs.logger.Err(err).Msg("Failed to refresh feature selector")
}
}
}
}
func (fs *FeatureSelector) refresh(ctx context.Context) error {
record, err := fs.resolver.lookupRecord(ctx) record, err := fs.resolver.lookupRecord(ctx)
if err != nil { if err != nil {
return err return err
@ -141,9 +105,6 @@ func (fs *FeatureSelector) refresh(ctx context.Context) error {
return err return err
} }
fs.lock.Lock()
defer fs.lock.Unlock()
fs.features = features fs.features = features
return nil return nil
@ -180,8 +141,8 @@ func (dr *dnsResolver) lookupRecord(ctx context.Context) ([]byte, error) {
return []byte(records[0]), nil return []byte(records[0]), nil
} }
func switchThreshold(accountTag string) int32 { func switchThreshold(accountTag string) uint32 {
h := fnv.New32a() h := fnv.New32a()
_, _ = h.Write([]byte(accountTag)) _, _ = h.Write([]byte(accountTag))
return int32(h.Sum32() % 100) return h.Sum32() % 100
} }

View File

@ -3,9 +3,7 @@ package features
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"testing" "testing"
"time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -14,33 +12,23 @@ import (
func TestUnmarshalFeaturesRecord(t *testing.T) { func TestUnmarshalFeaturesRecord(t *testing.T) {
tests := []struct { tests := []struct {
record []byte record []byte
expectedPercentage int32 expectedPercentage uint32
}{ }{
{
record: []byte(`{"dv3":0}`),
expectedPercentage: 0,
},
{
record: []byte(`{"dv3":39}`),
expectedPercentage: 39,
},
{
record: []byte(`{"dv3":100}`),
expectedPercentage: 100,
},
{ {
record: []byte(`{}`), // Unmarshal to default struct if key is not present record: []byte(`{}`), // Unmarshal to default struct if key is not present
}, },
{ {
record: []byte(`{"kyber":768}`), // Unmarshal to default struct if key is not present record: []byte(`{"kyber":768}`), // Unmarshal to default struct if key is not present
}, },
{
record: []byte(`{"pq": 101,"dv3":100}`), // Expired keys don't unmarshal to anything
},
} }
for _, test := range tests { for _, test := range tests {
var features featuresRecord var features featuresRecord
err := json.Unmarshal(test.record, &features) err := json.Unmarshal(test.record, &features)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, test.expectedPercentage, features.DatagramV3Percentage, test)
} }
} }
@ -61,7 +49,7 @@ func TestFeaturePrecedenceEvaluationPostQuantum(t *testing.T) {
{ {
name: "user_specified", name: "user_specified",
cli: true, cli: true,
expectedFeatures: Dedup(append(defaultFeatures, FeaturePostQuantum)), expectedFeatures: dedupAndRemoveFeatures(append(defaultFeatures, FeaturePostQuantum)),
expectedVersion: PostQuantumStrict, expectedVersion: PostQuantumStrict,
}, },
} }
@ -69,7 +57,7 @@ func TestFeaturePrecedenceEvaluationPostQuantum(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
resolver := &staticResolver{record: featuresRecord{}} resolver := &staticResolver{record: featuresRecord{}}
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, []string{}, test.cli, time.Second) selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, []string{}, test.cli)
require.NoError(t, err) require.NoError(t, err)
require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures()) require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
require.Equal(t, test.expectedVersion, selector.PostQuantumMode()) require.Equal(t, test.expectedVersion, selector.PostQuantumMode())
@ -102,44 +90,17 @@ func TestFeaturePrecedenceEvaluationDatagramVersion(t *testing.T) {
}, },
{ {
name: "user_specified_v3", name: "user_specified_v3",
cli: []string{FeatureDatagramV3}, cli: []string{FeatureDatagramV3_1},
remote: featuresRecord{}, remote: featuresRecord{},
expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)), expectedFeatures: dedupAndRemoveFeatures(append(defaultFeatures, FeatureDatagramV3_1)),
expectedVersion: FeatureDatagramV3, expectedVersion: FeatureDatagramV3_1,
},
{
name: "remote_specified_v3",
cli: []string{},
remote: featuresRecord{
DatagramV3Percentage: 100,
},
expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
expectedVersion: FeatureDatagramV3,
},
{
name: "remote_and_user_specified_v3",
cli: []string{FeatureDatagramV3},
remote: featuresRecord{
DatagramV3Percentage: 100,
},
expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
expectedVersion: FeatureDatagramV3,
},
{
name: "remote_v3_and_user_specified_v2",
cli: []string{FeatureDatagramV2},
remote: featuresRecord{
DatagramV3Percentage: 100,
},
expectedFeatures: defaultFeatures,
expectedVersion: DatagramV2,
}, },
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
resolver := &staticResolver{record: test.remote} resolver := &staticResolver{record: test.remote}
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false, time.Second) selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false)
require.NoError(t, err) require.NoError(t, err)
require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures()) require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
require.Equal(t, test.expectedVersion, selector.DatagramVersion()) require.Equal(t, test.expectedVersion, selector.DatagramVersion())
@ -147,75 +108,59 @@ func TestFeaturePrecedenceEvaluationDatagramVersion(t *testing.T) {
} }
} }
func TestRefreshFeaturesRecord(t *testing.T) { func TestDeprecatedFeaturesRemoved(t *testing.T) {
// The hash of the accountTag is 82 logger := zerolog.Nop()
accountTag := t.Name() tests := []struct {
threshold := switchThreshold(accountTag) name string
cli []string
percentages := []int32{0, 10, 81, 82, 83, 100, 101, 1000} remote featuresRecord
refreshFreq := time.Millisecond * 10 expectedFeatures []string
selector := newTestSelector(t, percentages, false, refreshFreq) }{
{
// Starting out should default to DatagramV2 name: "no_removals",
require.Equal(t, DatagramV2, selector.DatagramVersion()) cli: []string{},
remote: featuresRecord{},
for _, percentage := range percentages { expectedFeatures: defaultFeatures,
if percentage > threshold { },
require.Equal(t, DatagramV3, selector.DatagramVersion()) {
} else { name: "support_datagram_v3",
require.Equal(t, DatagramV2, selector.DatagramVersion()) cli: []string{DeprecatedFeatureDatagramV3},
remote: featuresRecord{},
expectedFeatures: defaultFeatures,
},
} }
time.Sleep(refreshFreq + time.Millisecond) for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
resolver := &staticResolver{record: test.remote}
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false)
require.NoError(t, err)
require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
})
} }
// Make sure error doesn't override the last fetched features
require.Equal(t, DatagramV3, selector.DatagramVersion())
} }
func TestStaticFeatures(t *testing.T) { func TestStaticFeatures(t *testing.T) {
percentages := []int32{0} percentages := []uint32{0}
// PostQuantum Enabled from user flag // PostQuantum Enabled from user flag
selector := newTestSelector(t, percentages, true, time.Millisecond*10) selector := newTestSelector(t, percentages, true)
require.Equal(t, PostQuantumStrict, selector.PostQuantumMode()) require.Equal(t, PostQuantumStrict, selector.PostQuantumMode())
// PostQuantum Disabled (or not set) // PostQuantum Disabled (or not set)
selector = newTestSelector(t, percentages, false, time.Millisecond*10) selector = newTestSelector(t, percentages, false)
require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode()) require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode())
} }
func newTestSelector(t *testing.T, percentages []int32, pq bool, refreshFreq time.Duration) *FeatureSelector { func newTestSelector(t *testing.T, percentages []uint32, pq bool) *FeatureSelector {
accountTag := t.Name() accountTag := t.Name()
logger := zerolog.Nop() logger := zerolog.Nop()
resolver := &mockResolver{ selector, err := newFeatureSelector(context.Background(), accountTag, &logger, &staticResolver{}, []string{}, pq)
percentages: percentages,
}
selector, err := newFeatureSelector(context.Background(), accountTag, &logger, resolver, []string{}, pq, refreshFreq)
require.NoError(t, err) require.NoError(t, err)
return selector return selector
} }
type mockResolver struct {
nextIndex int
percentages []int32
}
func (mr *mockResolver) lookupRecord(ctx context.Context) ([]byte, error) {
if mr.nextIndex >= len(mr.percentages) {
return nil, fmt.Errorf("no more record to lookup")
}
record, err := json.Marshal(featuresRecord{
DatagramV3Percentage: mr.percentages[mr.nextIndex],
})
mr.nextIndex++
return record, err
}
type staticResolver struct { type staticResolver struct {
record featuresRecord record featuresRecord
} }