diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 78f3402e..7de76a54 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -392,7 +392,7 @@ func StartServer( observer.SendURL(quickTunnelURL) } - tunnelConfig, orchestratorConfig, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel) + tunnelConfig, orchestratorConfig, err := prepareTunnelConfig(ctx, c, info, log, logTransport, observer, namedTunnel) if err != nil { log.Err(err).Msg("Couldn't start tunnel") return err diff --git a/cmd/cloudflared/tunnel/config_test.go b/cmd/cloudflared/tunnel/config_test.go index 929d7dfe..0ca6c436 100644 --- a/cmd/cloudflared/tunnel/config_test.go +++ b/cmd/cloudflared/tunnel/config_test.go @@ -4,10 +4,12 @@ import ( "testing" "github.com/stretchr/testify/require" + + "github.com/cloudflare/cloudflared/features" ) func TestDedup(t *testing.T) { expected := []string{"a", "b"} - actual := dedup([]string{"a", "b", "a"}) + actual := features.Dedup([]string{"a", "b", "a"}) require.ElementsMatch(t, expected, actual) } diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 427ac87f..68780b00 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -1,6 +1,7 @@ package tunnel import ( + "context" "crypto/tls" "fmt" "net" @@ -112,6 +113,7 @@ func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelPrope } func prepareTunnelConfig( + ctx context.Context, c *cli.Context, info *cliutil.BuildInfo, log, logTransport *zerolog.Logger, @@ -131,22 +133,36 @@ func prepareTunnelConfig( tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID.String()}) transportProtocol := c.String("protocol") - needPQ := c.Bool("post-quantum") - if needPQ { + + clientFeatures := features.Dedup(append(c.StringSlice("features"), features.DefaultFeatures...)) + + staticFeatures := features.StaticFeatures{} + if c.Bool("post-quantum") { if FipsEnabled { return nil, nil, fmt.Errorf("post-quantum not supported in FIPS mode") } + pqMode := features.PostQuantumStrict + staticFeatures.PostQuantumMode = &pqMode + } + featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, staticFeatures, log) + if err != nil { + return nil, nil, errors.Wrap(err, "Failed to create feature selector") + } + pqMode := featureSelector.PostQuantumMode() + if pqMode == features.PostQuantumStrict { // Error if the user tries to force a non-quic transport protocol if transportProtocol != connection.AutoSelectFlag && transportProtocol != connection.QUIC.String() { return nil, nil, fmt.Errorf("post-quantum is only supported with the quic transport") } transportProtocol = connection.QUIC.String() + clientFeatures = append(clientFeatures, features.FeaturePostQuantum) + + log.Info().Msgf( + "Using hybrid post-quantum key agreement %s", + supervisor.PQKexName, + ) } - clientFeatures := dedup(append(c.StringSlice("features"), features.DefaultFeatures...)) - if needPQ { - clientFeatures = append(clientFeatures, features.FeaturePostQuantum) - } namedTunnel.Client = tunnelpogs.ClientInfo{ ClientID: clientID[:], Features: clientFeatures, @@ -202,13 +218,6 @@ func prepareTunnelConfig( log.Warn().Str("edgeIPVersion", edgeIPVersion.String()).Err(err).Msg("Overriding edge-ip-version") } - if needPQ { - log.Info().Msgf( - "Using hybrid post-quantum key agreement %s", - supervisor.PQKexName, - ) - } - tunnelConfig := &supervisor.TunnelConfig{ GracePeriod: gracePeriod, ReplaceExisting: c.Bool("force"), @@ -233,7 +242,7 @@ func prepareTunnelConfig( NamedTunnel: namedTunnel, ProtocolSelector: protocolSelector, EdgeTLSConfigs: edgeTLSConfigs, - NeedPQ: needPQ, + FeatureSelector: featureSelector, MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")), UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag), DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery), @@ -276,25 +285,6 @@ func isRunningFromTerminal() bool { return term.IsTerminal(int(os.Stdout.Fd())) } -// Remove any duplicates from the slice -func dedup(slice []string) []string { - - // Convert the slice into a set - set := make(map[string]bool, 0) - for _, str := range slice { - set[str] = true - } - - // Convert the set back into a slice - keys := make([]string, len(set)) - i := 0 - for str := range set { - keys[i] = str - i++ - } - return keys -} - // ParseConfigIPVersion returns the IP version from possible expected values from config func parseConfigIPVersion(version string) (v allregions.ConfigIPVersion, err error) { switch version { diff --git a/features/features.go b/features/features.go index d6aaa6a6..76f8ff8f 100644 --- a/features/features.go +++ b/features/features.go @@ -28,3 +28,22 @@ func Contains(feature string) bool { } return false } + +// Remove any duplicates from the slice +func Dedup(slice []string) []string { + + // Convert the slice into a set + set := make(map[string]bool, 0) + for _, str := range slice { + set[str] = true + } + + // Convert the set back into a slice + keys := make([]string, len(set)) + i := 0 + for str := range set { + keys[i] = str + i++ + } + return keys +} diff --git a/features/selector.go b/features/selector.go new file mode 100644 index 00000000..64587bb6 --- /dev/null +++ b/features/selector.go @@ -0,0 +1,164 @@ +package features + +import ( + "context" + "encoding/json" + "fmt" + "hash/fnv" + "net" + "sync" + "time" + + "github.com/rs/zerolog" +) + +const ( + featureSelectorHostname = "cfd-features.argotunnel.com" + defaultRefreshFreq = time.Hour * 6 + lookupTimeout = time.Second * 10 +) + +type PostQuantumMode uint8 + +const ( + PostQuantumDisabled PostQuantumMode = iota + // Prefer post quantum, but fallback if connection cannot be established + PostQuantumPrefer + // If the user passes the --post-quantum flag, we override + // CurvePreferences to only support hybrid post-quantum key agreements. + PostQuantumStrict +) + +// If the TXT record adds other fields, the umarshal logic will ignore those keys +// If the TXT record is missing a key, the field will unmarshal to the default Go value +type featuresRecord struct { + PostQuantumPercentage int32 `json:"pq"` +} + +func NewFeatureSelector(ctx context.Context, accountTag string, staticFeatures StaticFeatures, logger *zerolog.Logger) (*FeatureSelector, error) { + return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), staticFeatures, defaultRefreshFreq) +} + +// FeatureSelector determines if this account will try new features. It preiodically queries a DNS TXT record +// to see which features are turned on +type FeatureSelector struct { + accountHash int32 + logger *zerolog.Logger + resolver resolver + + staticFeatures StaticFeatures + + // lock protects concurrent access to dynamic features + lock sync.RWMutex + features featuresRecord +} + +// Features set by user provided flags +type StaticFeatures struct { + PostQuantumMode *PostQuantumMode +} + +func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, staticFeatures StaticFeatures, refreshFreq time.Duration) (*FeatureSelector, error) { + selector := &FeatureSelector{ + accountHash: switchThreshold(accountTag), + logger: logger, + resolver: resolver, + staticFeatures: staticFeatures, + } + + if err := selector.refresh(ctx); err != nil { + logger.Err(err).Msg("Failed to fetch features, default to disable") + } + + go selector.refreshLoop(ctx, refreshFreq) + + return selector, nil +} + +func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode { + if fs.staticFeatures.PostQuantumMode != nil { + return *fs.staticFeatures.PostQuantumMode + } + + fs.lock.RLock() + defer fs.lock.RUnlock() + + if fs.features.PostQuantumPercentage > fs.accountHash { + return PostQuantumPrefer + } + return PostQuantumDisabled +} + +func (fs *FeatureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) { + ticker := time.NewTicker(refreshFreq) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + err := fs.refresh(ctx) + if err != nil { + fs.logger.Err(err).Msg("Failed to refresh feature selector") + } + } + } +} + +func (fs *FeatureSelector) refresh(ctx context.Context) error { + record, err := fs.resolver.lookupRecord(ctx) + if err != nil { + return err + } + + var features featuresRecord + if err := json.Unmarshal(record, &features); err != nil { + return err + } + + pq_enabled := features.PostQuantumPercentage > fs.accountHash + fs.logger.Debug().Int32("account_hash", fs.accountHash).Int32("pq_perct", features.PostQuantumPercentage).Bool("pq_enabled", pq_enabled).Msg("Refreshed feature") + + fs.lock.Lock() + defer fs.lock.Unlock() + + fs.features = features + + return nil +} + +// resolver represents an object that can look up featuresRecord +type resolver interface { + lookupRecord(ctx context.Context) ([]byte, error) +} + +type dnsResolver struct { + resolver *net.Resolver +} + +func newDNSResolver() *dnsResolver { + return &dnsResolver{ + resolver: net.DefaultResolver, + } +} + +func (dr *dnsResolver) lookupRecord(ctx context.Context) ([]byte, error) { + ctx, cancel := context.WithTimeout(ctx, lookupTimeout) + defer cancel() + + records, err := dr.resolver.LookupTXT(ctx, featureSelectorHostname) + if err != nil { + return nil, err + } + + if len(records) == 0 { + return nil, fmt.Errorf("No TXT record found for %s to determine which features to opt-in", featureSelectorHostname) + } + + return []byte(records[0]), nil +} + +func switchThreshold(accountTag string) int32 { + h := fnv.New32a() + _, _ = h.Write([]byte(accountTag)) + return int32(h.Sum32() % 100) +} diff --git a/features/selector_test.go b/features/selector_test.go new file mode 100644 index 00000000..4decba7b --- /dev/null +++ b/features/selector_test.go @@ -0,0 +1,128 @@ +package features + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" +) + +func TestUnmarshalFeaturesRecord(t *testing.T) { + tests := []struct { + record []byte + expectedPercentage int32 + expectedErr bool + }{ + { + record: []byte(`{"pq":0}`), + expectedPercentage: 0, + }, + { + record: []byte(`{"pq":39}`), + expectedPercentage: 39, + }, + { + record: []byte(`{"pq":100}`), + expectedPercentage: 100, + }, + { + record: []byte(`{}`), + expectedPercentage: 0, // Unmarshal to default struct if key is not present + }, + { + record: []byte(`{"kyber":768}`), + expectedPercentage: 0, // Unmarshal to default struct if key is not present + }, + { + record: []byte(`{"pq":"kyber768"}`), + expectedErr: true, + }, + } + + for _, test := range tests { + var features featuresRecord + err := json.Unmarshal(test.record, &features) + if test.expectedErr { + require.Error(t, err, test) + } else { + require.NoError(t, err) + require.Equal(t, test.expectedPercentage, features.PostQuantumPercentage, test) + } + } +} + +func TestRefreshFeaturesRecord(t *testing.T) { + // The hash of the accountTag is 82 + accountTag := t.Name() + threshold := switchThreshold(accountTag) + + percentages := []int32{0, 10, 80, 83, 100} + refreshFreq := time.Millisecond * 10 + selector := newTestSelector(t, percentages, nil, refreshFreq) + + for _, percentage := range percentages { + if percentage > threshold { + require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode()) + } else { + require.Equal(t, PostQuantumDisabled, selector.PostQuantumMode()) + } + + time.Sleep(refreshFreq + time.Millisecond) + } + + // Make sure error doesn't override the last fetched features + require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode()) +} + +func TestStaticFeatures(t *testing.T) { + percentages := []int32{0} + pqMode := PostQuantumStrict + selector := newTestSelector(t, percentages, &pqMode, time.Millisecond*10) + require.Equal(t, PostQuantumStrict, selector.PostQuantumMode()) +} + +// Verify that if the first refresh fails, the selector will use default features +func TestFailedRefreshInitToDefault(t *testing.T) { + selector := newTestSelector(t, []int32{}, nil, time.Second) + require.Equal(t, featuresRecord{}, selector.features) + require.Equal(t, PostQuantumDisabled, selector.PostQuantumMode()) +} + +func newTestSelector(t *testing.T, percentages []int32, pqMode *PostQuantumMode, refreshFreq time.Duration) *FeatureSelector { + accountTag := t.Name() + logger := zerolog.Nop() + + resolver := &mockResolver{ + percentages: percentages, + } + + staticFeatures := StaticFeatures{ + PostQuantumMode: pqMode, + } + selector, err := newFeatureSelector(context.Background(), accountTag, &logger, resolver, staticFeatures, refreshFreq) + require.NoError(t, err) + + return selector +} + +type mockResolver struct { + nextIndex int + percentages []int32 +} + +func (mr *mockResolver) lookupRecord(ctx context.Context) ([]byte, error) { + if mr.nextIndex >= len(mr.percentages) { + return nil, fmt.Errorf("no more record to lookup") + } + + record, err := json.Marshal(featuresRecord{ + PostQuantumPercentage: mr.percentages[mr.nextIndex], + }) + mr.nextIndex++ + + return record, err +} diff --git a/supervisor/pqtunnels.go b/supervisor/pqtunnels.go index 67a02e07..19a9b742 100644 --- a/supervisor/pqtunnels.go +++ b/supervisor/pqtunnels.go @@ -4,8 +4,11 @@ import ( "bytes" "crypto/tls" "encoding/json" + "fmt" "net/http" "sync" + + "github.com/cloudflare/cloudflared/features" ) // When experimental post-quantum tunnels are enabled, and we're hitting an @@ -94,3 +97,32 @@ func submitPQTunnelError(rep error, config *TunnelConfig) { } resp.Body.Close() } + +func curvePreference(pqMode features.PostQuantumMode, currentCurve []tls.CurveID) ([]tls.CurveID, error) { + switch pqMode { + case features.PostQuantumStrict: + // If the user passes the -post-quantum flag, we override + // CurvePreferences to only support hybrid post-quantum key agreements. + return []tls.CurveID{PQKex}, nil + case features.PostQuantumPrefer: + if len(currentCurve) == 0 { + return []tls.CurveID{PQKex}, nil + } + + if currentCurve[0] != PQKex { + return append([]tls.CurveID{PQKex}, currentCurve...), nil + } + return currentCurve, nil + case features.PostQuantumDisabled: + curvePref := currentCurve + // Remove PQ from curve preference + for i, curve := range currentCurve { + if curve == PQKex { + curvePref = append(curvePref[:i], curvePref[i+1:]...) + } + } + return curvePref, nil + default: + return nil, fmt.Errorf("Unexpected post quantum mode") + } +} diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 463b7d2b..55edf1b8 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -69,6 +69,8 @@ type TunnelConfig struct { UDPUnregisterSessionTimeout time.Duration DisableQUICPathMTUDiscovery bool + + FeatureSelector *features.FeatureSelector } func (c *TunnelConfig) registrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { @@ -536,7 +538,8 @@ func (e *EdgeTunnelServer) serveHTTP2( controlStreamHandler connection.ControlStreamHandler, connIndex uint8, ) error { - if e.config.NeedPQ { + pqMode := e.config.FeatureSelector.PostQuantumMode() + if pqMode == features.PostQuantumStrict { return unrecoverableError{errors.New("HTTP/2 transport does not support post-quantum")} } @@ -579,14 +582,18 @@ func (e *EdgeTunnelServer) serveQUIC( ) (err error, recoverable bool) { tlsConfig := e.config.EdgeTLSConfigs[connection.QUIC] - if e.config.NeedPQ { - // If the user passes the -post-quantum flag, we override - // CurvePreferences to only support hybrid post-quantum key agreements. - tlsConfig.CurvePreferences = []tls.CurveID{ - PQKex, - } + pqMode := e.config.FeatureSelector.PostQuantumMode() + if pqMode == features.PostQuantumStrict || pqMode == features.PostQuantumPrefer { + connOptions.Client.Features = features.Dedup(append(connOptions.Client.Features, features.FeaturePostQuantum)) } + curvePref, err := curvePreference(pqMode, tlsConfig.CurvePreferences) + if err != nil { + return err, true + } + + tlsConfig.CurvePreferences = curvePref + quicConfig := &quic.Config{ HandshakeIdleTimeout: quicpogs.HandshakeIdleTimeout, MaxIdleTimeout: quicpogs.MaxIdleTimeout, @@ -614,7 +621,7 @@ func (e *EdgeTunnelServer) serveQUIC( e.config.UDPUnregisterSessionTimeout, ) if err != nil { - if e.config.NeedPQ { + if pqMode == features.PostQuantumStrict || pqMode == features.PostQuantumPrefer { handlePQTunnelError(err, e.config) }