package features

import (
	"context"
	"encoding/json"
	"fmt"
	"testing"
	"time"

	"github.com/rs/zerolog"
	"github.com/stretchr/testify/require"
)

// TestUnmarshalFeaturesRecord checks that the "dv3" key of a JSON features
// record is decoded into featuresRecord.DatagramV3Percentage, and that the
// field keeps its zero value when the key is absent from the record.
func TestUnmarshalFeaturesRecord(t *testing.T) {
	tests := []struct {
		record             []byte
		expectedPercentage int32
	}{
		{
			record:             []byte(`{"dv3":0}`),
			expectedPercentage: 0,
		},
		{
			record:             []byte(`{"dv3":39}`),
			expectedPercentage: 39,
		},
		{
			record:             []byte(`{"dv3":100}`),
			expectedPercentage: 100,
		},
		{
			record: []byte(`{}`), // Unmarshal to default struct if key is not present
		},
		{
			record: []byte(`{"kyber":768}`), // Unmarshal to default struct if key is not present
		},
	}

	for _, test := range tests {
		var features featuresRecord
		err := json.Unmarshal(test.record, &features)
		// Report the raw record as text so a failing case is easy to identify.
		require.NoError(t, err, "record: %s", test.record)
		require.Equal(t, test.expectedPercentage, features.DatagramV3Percentage, "record: %s", test.record)
	}
}

// TestFeaturePrecedenceEvaluationPostQuantum checks the post-quantum mode and
// the client feature list reported by a selector, with and without the
// post-quantum flag being set by the user.
func TestFeaturePrecedenceEvaluationPostQuantum(t *testing.T) {
	logger := zerolog.Nop()
	tests := []struct {
		name             string
		cli              bool
		expectedFeatures []string
		expectedVersion  PostQuantumMode
	}{
		{
			name:             "default",
			cli:              false,
			expectedFeatures: defaultFeatures,
			expectedVersion:  PostQuantumPrefer,
		},
		{
			name:             "user_specified",
			cli:              true,
			expectedFeatures: Dedup(append(defaultFeatures, FeaturePostQuantum)),
			expectedVersion:  PostQuantumStrict,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// Cancel the context when the subtest ends so that any work the
			// selector ties to it does not outlive the subtest.
			ctx, cancel := context.WithCancel(context.Background())
			t.Cleanup(cancel)

			resolver := &staticResolver{record: featuresRecord{}}
			selector, err := newFeatureSelector(ctx, test.name, &logger, resolver, []string{}, test.cli, time.Second)
			require.NoError(t, err)
			require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
			require.Equal(t, test.expectedVersion, selector.PostQuantumMode())
		})
	}
}

// TestFeaturePrecedenceEvaluationDatagramVersion checks which datagram version
// and client feature list a selector reports for each combination of
// user-specified (cli) features and remotely resolved features record.
func TestFeaturePrecedenceEvaluationDatagramVersion(t *testing.T) {
	logger := zerolog.Nop()
	tests := []struct {
		name             string
		cli              []string
		remote           featuresRecord
		expectedFeatures []string
		expectedVersion  DatagramVersion
	}{
		{
			name:             "default",
			cli:              []string{},
			remote:           featuresRecord{},
			expectedFeatures: defaultFeatures,
			expectedVersion:  DatagramV2,
		},
		{
			name:             "user_specified_v2",
			cli:              []string{FeatureDatagramV2},
			remote:           featuresRecord{},
			expectedFeatures: defaultFeatures,
			expectedVersion:  DatagramV2,
		},
		{
			name:             "user_specified_v3",
			cli:              []string{FeatureDatagramV3},
			remote:           featuresRecord{},
			expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
			expectedVersion:  FeatureDatagramV3,
		},
		{
			name: "remote_specified_v3",
			cli:  []string{},
			remote: featuresRecord{
				DatagramV3Percentage: 100,
			},
			expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
			expectedVersion:  FeatureDatagramV3,
		},
		{
			name: "remote_and_user_specified_v3",
			cli:  []string{FeatureDatagramV3},
			remote: featuresRecord{
				DatagramV3Percentage: 100,
			},
			expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
			expectedVersion:  FeatureDatagramV3,
		},
		{
			// An explicit user choice of v2 is expected to win over a remote
			// record that enables v3.
			name: "remote_v3_and_user_specified_v2",
			cli:  []string{FeatureDatagramV2},
			remote: featuresRecord{
				DatagramV3Percentage: 100,
			},
			expectedFeatures: defaultFeatures,
			expectedVersion:  DatagramV2,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			// Cancel the context when the subtest ends so that any work the
			// selector ties to it does not outlive the subtest.
			ctx, cancel := context.WithCancel(context.Background())
			t.Cleanup(cancel)

			resolver := &staticResolver{record: test.remote}
			selector, err := newFeatureSelector(ctx, test.name, &logger, resolver, test.cli, false, time.Second)
			require.NoError(t, err)
			require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
			require.Equal(t, test.expectedVersion, selector.DatagramVersion())
		})
	}
}

// TestRefreshFeaturesRecord feeds the selector a sequence of percentages, one
// per lookup, and checks that the reported datagram version follows the most
// recently fetched percentage relative to the account's switch threshold.
func TestRefreshFeaturesRecord(t *testing.T) {
	// The hash of the accountTag is 82
	accountTag := t.Name()
	threshold := switchThreshold(accountTag)

	percentages := []int32{0, 10, 81, 82, 83, 100, 101, 1000}
	refreshFreq := time.Millisecond * 10
	selector := newTestSelector(t, percentages, false, refreshFreq)

	// Starting out should default to DatagramV2
	require.Equal(t, DatagramV2, selector.DatagramVersion())

	for _, percentage := range percentages {
		if percentage > threshold {
			require.Equal(t, DatagramV3, selector.DatagramVersion())
		} else {
			require.Equal(t, DatagramV2, selector.DatagramVersion())
		}

		// Wait slightly longer than one refresh interval; presumably this
		// gives the selector time to look up the next percentage.
		time.Sleep(refreshFreq + time.Millisecond)
	}

	// Make sure error doesn't override the last fetched features
	require.Equal(t, DatagramV3, selector.DatagramVersion())
}

// TestStaticFeatures checks that the post-quantum mode follows the
// user-provided flag passed when the selector is created.
func TestStaticFeatures(t *testing.T) {
	percentages := []int32{0}

	// PostQuantum Enabled from user flag
	selector := newTestSelector(t, percentages, true, time.Millisecond*10)
	require.Equal(t, PostQuantumStrict, selector.PostQuantumMode())

	// PostQuantum Disabled (or not set)
	selector = newTestSelector(t, percentages, false, time.Millisecond*10)
	require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode())
}

// newTestSelector builds a FeatureSelector for the calling test. The test name
// is used as the account tag, and lookups are served by a mockResolver that
// yields the given percentages in order. The test fails immediately if the
// selector cannot be created.
func newTestSelector(t *testing.T, percentages []int32, pq bool, refreshFreq time.Duration) *FeatureSelector {
	// Attribute assertion failures to the caller rather than to this helper.
	t.Helper()

	accountTag := t.Name()
	logger := zerolog.Nop()

	resolver := &mockResolver{
		percentages: percentages,
	}

	// Cancel the context once the test (and its subtests) complete so that
	// any refresh work the selector ties to it does not leak into other tests.
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	selector, err := newFeatureSelector(ctx, accountTag, &logger, resolver, []string{}, pq, refreshFreq)
	require.NoError(t, err)

	return selector
}

// mockResolver serves one features record per lookup, taking the
// DatagramV3Percentage values from percentages in order.
type mockResolver struct {
	nextIndex   int
	percentages []int32
}

// lookupRecord returns the JSON encoding of a featuresRecord holding the next
// percentage and advances to the following one. Once every percentage has been
// served it returns an error.
func (mr *mockResolver) lookupRecord(ctx context.Context) ([]byte, error) {
	if mr.nextIndex >= len(mr.percentages) {
		return nil, fmt.Errorf("no more record to lookup")
	}

	record, err := json.Marshal(featuresRecord{
		DatagramV3Percentage: mr.percentages[mr.nextIndex],
	})
	mr.nextIndex++

	return record, err
}

// staticResolver serves the same features record on every lookup.
type staticResolver struct {
	record featuresRecord
}

// lookupRecord returns the JSON encoding of the configured record.
func (r *staticResolver) lookupRecord(ctx context.Context) ([]byte, error) {
	return json.Marshal(r.record)
}