diff --git a/features/selector.go b/features/selector.go index 64587bb6..3e952dd2 100644 --- a/features/selector.go +++ b/features/selector.go @@ -21,9 +21,8 @@ const ( type PostQuantumMode uint8 const ( - PostQuantumDisabled PostQuantumMode = iota // Prefer post quantum, but fallback if connection cannot be established - PostQuantumPrefer + PostQuantumPrefer PostQuantumMode = iota // If the user passes the --post-quantum flag, we override // CurvePreferences to only support hybrid post-quantum key agreements. PostQuantumStrict @@ -31,9 +30,8 @@ const ( // If the TXT record adds other fields, the umarshal logic will ignore those keys // If the TXT record is missing a key, the field will unmarshal to the default Go value -type featuresRecord struct { - PostQuantumPercentage int32 `json:"pq"` -} +// pq was removed in TUN-7970 +type featuresRecord struct{} func NewFeatureSelector(ctx context.Context, accountTag string, staticFeatures StaticFeatures, logger *zerolog.Logger) (*FeatureSelector, error) { return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), staticFeatures, defaultRefreshFreq) @@ -70,7 +68,7 @@ func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog. logger.Err(err).Msg("Failed to fetch features, default to disable") } - go selector.refreshLoop(ctx, refreshFreq) + // Run refreshLoop next time we have a new feature to rollout return selector, nil } @@ -80,13 +78,7 @@ func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode { return *fs.staticFeatures.PostQuantumMode } - fs.lock.RLock() - defer fs.lock.RUnlock() - - if fs.features.PostQuantumPercentage > fs.accountHash { - return PostQuantumPrefer - } - return PostQuantumDisabled + return PostQuantumPrefer } func (fs *FeatureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) { @@ -115,9 +107,6 @@ func (fs *FeatureSelector) refresh(ctx context.Context) error { return err } - pq_enabled := features.PostQuantumPercentage > fs.accountHash - fs.logger.Debug().Int32("account_hash", fs.accountHash).Int32("pq_perct", features.PostQuantumPercentage).Bool("pq_enabled", pq_enabled).Msg("Refreshed feature") - fs.lock.Lock() defer fs.lock.Unlock() diff --git a/features/selector_test.go b/features/selector_test.go index 4decba7b..b8739c6b 100644 --- a/features/selector_test.go +++ b/features/selector_test.go @@ -13,92 +13,48 @@ import ( func TestUnmarshalFeaturesRecord(t *testing.T) { tests := []struct { - record []byte - expectedPercentage int32 - expectedErr bool + record []byte }{ { - record: []byte(`{"pq":0}`), - expectedPercentage: 0, + record: []byte(`{"pq":0}`), }, { - record: []byte(`{"pq":39}`), - expectedPercentage: 39, + record: []byte(`{"pq":39}`), }, { - record: []byte(`{"pq":100}`), - expectedPercentage: 100, + record: []byte(`{"pq":100}`), }, { - record: []byte(`{}`), - expectedPercentage: 0, // Unmarshal to default struct if key is not present + record: []byte(`{}`), // Unmarshal to default struct if key is not present }, { - record: []byte(`{"kyber":768}`), - expectedPercentage: 0, // Unmarshal to default struct if key is not present - }, - { - record: []byte(`{"pq":"kyber768"}`), - expectedErr: true, + record: []byte(`{"kyber":768}`), // Unmarshal to default struct if key is not present }, } for _, test := range tests { var features featuresRecord err := json.Unmarshal(test.record, &features) - if test.expectedErr { - require.Error(t, err, test) - } else { - require.NoError(t, err) - require.Equal(t, test.expectedPercentage, features.PostQuantumPercentage, test) - } + require.NoError(t, err) + require.Equal(t, featuresRecord{}, features) } } -func TestRefreshFeaturesRecord(t *testing.T) { - // The hash of the accountTag is 82 - accountTag := t.Name() - threshold := switchThreshold(accountTag) - - percentages := []int32{0, 10, 80, 83, 100} - refreshFreq := time.Millisecond * 10 - selector := newTestSelector(t, percentages, nil, refreshFreq) - - for _, percentage := range percentages { - if percentage > threshold { - require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode()) - } else { - require.Equal(t, PostQuantumDisabled, selector.PostQuantumMode()) - } - - time.Sleep(refreshFreq + time.Millisecond) - } - - // Make sure error doesn't override the last fetched features - require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode()) -} - func TestStaticFeatures(t *testing.T) { - percentages := []int32{0} pqMode := PostQuantumStrict - selector := newTestSelector(t, percentages, &pqMode, time.Millisecond*10) + selector := newTestSelector(t, &pqMode, time.Millisecond*10) require.Equal(t, PostQuantumStrict, selector.PostQuantumMode()) + + // No StaticFeatures configured + selector = newTestSelector(t, nil, time.Millisecond*10) + require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode()) } -// Verify that if the first refresh fails, the selector will use default features -func TestFailedRefreshInitToDefault(t *testing.T) { - selector := newTestSelector(t, []int32{}, nil, time.Second) - require.Equal(t, featuresRecord{}, selector.features) - require.Equal(t, PostQuantumDisabled, selector.PostQuantumMode()) -} - -func newTestSelector(t *testing.T, percentages []int32, pqMode *PostQuantumMode, refreshFreq time.Duration) *FeatureSelector { +func newTestSelector(t *testing.T, pqMode *PostQuantumMode, refreshFreq time.Duration) *FeatureSelector { accountTag := t.Name() logger := zerolog.Nop() - resolver := &mockResolver{ - percentages: percentages, - } + resolver := &mockResolver{} staticFeatures := StaticFeatures{ PostQuantumMode: pqMode, @@ -109,20 +65,8 @@ func newTestSelector(t *testing.T, percentages []int32, pqMode *PostQuantumMode, return selector } -type mockResolver struct { - nextIndex int - percentages []int32 -} +type mockResolver struct{} func (mr *mockResolver) lookupRecord(ctx context.Context) ([]byte, error) { - if mr.nextIndex >= len(mr.percentages) { - return nil, fmt.Errorf("no more record to lookup") - } - - record, err := json.Marshal(featuresRecord{ - PostQuantumPercentage: mr.percentages[mr.nextIndex], - }) - mr.nextIndex++ - - return record, err + return nil, fmt.Errorf("mockResolver hasn't implement lookupRecord") } diff --git a/supervisor/pqtunnels.go b/supervisor/pqtunnels.go index 19a9b742..f8dce98d 100644 --- a/supervisor/pqtunnels.go +++ b/supervisor/pqtunnels.go @@ -1,12 +1,8 @@ package supervisor import ( - "bytes" "crypto/tls" - "encoding/json" "fmt" - "net/http" - "sync" "github.com/cloudflare/cloudflared/features" ) @@ -20,84 +16,6 @@ const ( PQKexName = "X25519Kyber768Draft00" ) -var ( - pqtMux sync.Mutex // protects pqtSubmitted and pqtWaitForMessage - pqtSubmitted bool // whether an error has already been submitted - - // Number of errors to ignore before printing elaborate instructions. - pqtWaitForMessage int -) - -func handlePQTunnelError(rep error, config *TunnelConfig) { - needToMessage := false - - pqtMux.Lock() - needToSubmit := !pqtSubmitted - if needToSubmit { - pqtSubmitted = true - } - pqtWaitForMessage-- - if pqtWaitForMessage < 0 { - pqtWaitForMessage = 5 - needToMessage = true - } - pqtMux.Unlock() - - if needToMessage { - config.Log.Info().Msgf( - "\n\n" + - "===================================================================================\n" + - "You are hitting an error while using the experimental post-quantum tunnels feature.\n" + - "\n" + - "Please check:\n" + - "\n" + - " https://pqtunnels.cloudflareresearch.com\n" + - "\n" + - "for known problems.\n" + - "===================================================================================\n\n", - ) - } - - if needToSubmit { - go submitPQTunnelError(rep, config) - } -} - -func submitPQTunnelError(rep error, config *TunnelConfig) { - body, err := json.Marshal(struct { - Group int `json:"g"` - Message string `json:"m"` - Version string `json:"v"` - }{ - Group: int(PQKex), - Message: rep.Error(), - Version: config.ReportedVersion, - }) - if err != nil { - config.Log.Err(err).Msg("Failed to create error report") - return - } - - resp, err := http.Post( - "https://pqtunnels.cloudflareresearch.com", - "application/json", - bytes.NewBuffer(body), - ) - if err != nil { - config.Log.Err(err).Msg( - "Failed to submit post-quantum tunnel error report", - ) - return - } - if resp.StatusCode != 200 { - config.Log.Error().Msgf( - "Failed to submit post-quantum tunnel error report: status %d", - resp.StatusCode, - ) - } - resp.Body.Close() -} - func curvePreference(pqMode features.PostQuantumMode, currentCurve []tls.CurveID) ([]tls.CurveID, error) { switch pqMode { case features.PostQuantumStrict: @@ -113,15 +31,6 @@ func curvePreference(pqMode features.PostQuantumMode, currentCurve []tls.CurveID return append([]tls.CurveID{PQKex}, currentCurve...), nil } return currentCurve, nil - case features.PostQuantumDisabled: - curvePref := currentCurve - // Remove PQ from curve preference - for i, curve := range currentCurve { - if curve == PQKex { - curvePref = append(curvePref[:i], curvePref[i+1:]...) - } - } - return curvePref, nil default: return nil, fmt.Errorf("Unexpected post quantum mode") } diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 46fed38d..8e9afbd3 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -616,10 +616,6 @@ func (e *EdgeTunnelServer) serveQUIC( e.config.UDPUnregisterSessionTimeout, ) if err != nil { - if pqMode == features.PostQuantumStrict || pqMode == features.PostQuantumPrefer { - handlePQTunnelError(err, e.config) - } - connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection") return err, true }