diff --git a/connection/connection.go b/connection/connection.go index 5d2db19c..cf17d506 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -39,6 +39,7 @@ type Orchestrator interface { UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse GetConfigJSON() ([]byte, error) GetOriginProxy() (OriginProxy, error) + WarpRoutingEnabled() (enabled bool) } type NamedTunnelProperties struct { diff --git a/connection/connection_test.go b/connection/connection_test.go index 3708e16a..14c20c0d 100644 --- a/connection/connection_test.go +++ b/connection/connection_test.go @@ -56,6 +56,10 @@ func (mcr *mockOrchestrator) GetOriginProxy() (OriginProxy, error) { return mcr.originProxy, nil } +func (mcr *mockOrchestrator) WarpRoutingEnabled() (enabled bool) { + return true +} + type mockOriginProxy struct{} func (moc *mockOriginProxy) ProxyHTTP( diff --git a/connection/quic.go b/connection/quic.go index 3004b929..a6d6f6f0 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -75,7 +75,7 @@ func NewQUICConnection( sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) datagramMuxer := quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan) sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) - packetRouter := packet.NewRouter(packetRouterConfig, datagramMuxer, &returnPipe{muxer: datagramMuxer}, logger) + packetRouter := packet.NewRouter(packetRouterConfig, datagramMuxer, &returnPipe{muxer: datagramMuxer}, logger, orchestrator.WarpRoutingEnabled) return &QUICConnection{ session: session, diff --git a/orchestration/orchestrator.go b/orchestration/orchestrator.go index 0659411a..6bf62a5d 100644 --- a/orchestration/orchestrator.go +++ b/orchestration/orchestrator.go @@ -27,10 +27,12 @@ type Orchestrator struct { // Used by UpdateConfig to make sure one update at a time lock sync.RWMutex // Underlying value is proxy.Proxy, can be read without the lock, but still needs the lock to update - proxy atomic.Value - config *Config - tags []tunnelpogs.Tag - log *zerolog.Logger + proxy atomic.Value + // TODO: TUN-6815 Use atomic.Bool once we upgrade to go 1.19. 1 Means enabled and 0 means disabled + warpRoutingEnabled uint32 + config *Config + tags []tunnelpogs.Tag + log *zerolog.Logger // orchestrator must not handle any more updates after shutdownC is closed shutdownC <-chan struct{} @@ -122,6 +124,11 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i o.proxy.Store(newProxy) o.config.Ingress = &ingressRules o.config.WarpRouting = warpRouting + if warpRouting.Enabled { + atomic.StoreUint32(&o.warpRoutingEnabled, 1) + } else { + atomic.StoreUint32(&o.warpRoutingEnabled, 0) + } // If proxyShutdownC is nil, there is no previous running proxy if o.proxyShutdownC != nil { @@ -190,6 +197,14 @@ func (o *Orchestrator) GetOriginProxy() (connection.OriginProxy, error) { return proxy, nil } +// TODO: TUN-6815 consider storing WarpRouting.Enabled as atomic.Bool once we upgrade to go 1.19 +func (o *Orchestrator) WarpRoutingEnabled() (enabled bool) { + if atomic.LoadUint32(&o.warpRoutingEnabled) == 0 { + return false + } + return true +} + func (o *Orchestrator) waitToCloseLastProxy() { <-o.shutdownC o.lock.Lock() diff --git a/orchestration/orchestrator_test.go b/orchestration/orchestrator_test.go index 5c18b3a7..d3e1ee62 100644 --- a/orchestration/orchestrator_test.go +++ b/orchestration/orchestrator_test.go @@ -55,6 +55,7 @@ func TestUpdateConfiguration(t *testing.T) { initOriginProxy, err := orchestrator.GetOriginProxy() require.NoError(t, err) require.IsType(t, &proxy.Proxy{}, initOriginProxy) + require.False(t, orchestrator.WarpRoutingEnabled()) configJSONV2 := []byte(` { @@ -122,6 +123,7 @@ func TestUpdateConfiguration(t *testing.T) { require.Equal(t, false, configV2.Ingress.Rules[2].Config.NoTLSVerify) require.Equal(t, true, configV2.Ingress.Rules[2].Config.NoHappyEyeballs) require.True(t, configV2.WarpRouting.Enabled) + require.Equal(t, configV2.WarpRouting.Enabled, orchestrator.WarpRoutingEnabled()) require.Equal(t, configV2.WarpRouting.ConnectTimeout.Duration, 10*time.Second) originProxyV2, err := orchestrator.GetOriginProxy() @@ -166,6 +168,7 @@ func TestUpdateConfiguration(t *testing.T) { require.True(t, configV10.Ingress.Rules[0].Matches("blogs.tunnel.io", "/2022/02/10")) require.Equal(t, ingress.HelloWorldService, configV10.Ingress.Rules[0].Service.String()) require.False(t, configV10.WarpRouting.Enabled) + require.Equal(t, configV10.WarpRouting.Enabled, orchestrator.WarpRoutingEnabled()) originProxyV10, err := orchestrator.GetOriginProxy() require.NoError(t, err) diff --git a/packet/router.go b/packet/router.go index 3fc451fc..29a5d839 100644 --- a/packet/router.go +++ b/packet/router.go @@ -23,10 +23,11 @@ type Upstream interface { // Router routes packets between Upstream and ICMPRouter. Currently it rejects all other type of ICMP packets type Router struct { - upstream Upstream - returnPipe FunnelUniPipe - globalConfig *GlobalRouterConfig - logger *zerolog.Logger + upstream Upstream + returnPipe FunnelUniPipe + globalConfig *GlobalRouterConfig + logger *zerolog.Logger + checkRouterEnabledFunc func() bool } // GlobalRouterConfig is the configuration shared by all instance of Router. @@ -37,12 +38,13 @@ type GlobalRouterConfig struct { Zone string } -func NewRouter(globalConfig *GlobalRouterConfig, upstream Upstream, returnPipe FunnelUniPipe, logger *zerolog.Logger) *Router { +func NewRouter(globalConfig *GlobalRouterConfig, upstream Upstream, returnPipe FunnelUniPipe, logger *zerolog.Logger, checkRouterEnabledFunc func() bool) *Router { return &Router{ - upstream: upstream, - returnPipe: returnPipe, - globalConfig: globalConfig, - logger: logger, + upstream: upstream, + returnPipe: returnPipe, + globalConfig: globalConfig, + logger: logger, + checkRouterEnabledFunc: checkRouterEnabledFunc, } } @@ -54,10 +56,16 @@ func (r *Router) Serve(ctx context.Context) error { if err != nil { return err } + // Drop packets if ICMPRouter wasn't created if r.globalConfig == nil { continue } + + if enabled := r.checkRouterEnabledFunc(); !enabled { + continue + } + icmpPacket, err := icmpDecoder.Decode(rawPacket) if err != nil { r.logger.Err(err).Msg("Failed to decode ICMP packet from quic datagram") diff --git a/packet/router_test.go b/packet/router_test.go index c1056450..4998456e 100644 --- a/packet/router_test.go +++ b/packet/router_test.go @@ -5,7 +5,9 @@ import ( "context" "fmt" "net/netip" + "sync/atomic" "testing" + "time" "github.com/google/gopacket/layers" "github.com/rs/zerolog" @@ -16,7 +18,12 @@ import ( ) var ( - noopLogger = zerolog.Nop() + noopLogger = zerolog.Nop() + packetConfig = &GlobalRouterConfig{ + ICMPRouter: &mockICMPRouter{}, + IPv4Src: netip.MustParseAddr("172.16.0.1"), + IPv6Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"), + } ) func TestRouterReturnTTLExceed(t *testing.T) { @@ -26,12 +33,9 @@ func TestRouterReturnTTLExceed(t *testing.T) { returnPipe := &mockFunnelUniPipe{ uniPipe: make(chan RawPacket), } - packetConfig := &GlobalRouterConfig{ - ICMPRouter: &mockICMPRouter{}, - IPv4Src: netip.MustParseAddr("172.16.0.1"), - IPv6Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"), - } - router := NewRouter(packetConfig, upstream, returnPipe, &noopLogger) + routerEnabled := &routerEnabledChecker{} + routerEnabled.set(true) + router := NewRouter(packetConfig, upstream, returnPipe, &noopLogger, routerEnabled.isEnabled) ctx, cancel := context.WithCancel(context.Background()) routerStopped := make(chan struct{}) go func() { @@ -80,12 +84,71 @@ func TestRouterReturnTTLExceed(t *testing.T) { <-routerStopped } +func TestRouterCheckEnabled(t *testing.T) { + upstream := &mockUpstream{ + source: make(chan RawPacket), + } + returnPipe := &mockFunnelUniPipe{ + uniPipe: make(chan RawPacket), + } + routerEnabled := &routerEnabledChecker{} + router := NewRouter(packetConfig, upstream, returnPipe, &noopLogger, routerEnabled.isEnabled) + ctx, cancel := context.WithCancel(context.Background()) + routerStopped := make(chan struct{}) + go func() { + router.Serve(ctx) + close(routerStopped) + }() + + pk := ICMP{ + IP: &IP{ + Src: netip.MustParseAddr("192.168.1.1"), + Dst: netip.MustParseAddr("10.0.0.1"), + Protocol: layers.IPProtocolICMPv4, + TTL: 1, + }, + Message: &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: 12481, + Seq: 8036, + Data: []byte(t.Name()), + }, + }, + } + + // router is disabled + require.NoError(t, upstream.send(&pk)) + select { + case <-time.After(time.Millisecond * 10): + case <-returnPipe.uniPipe: + t.Error("Unexpected reply when router is disabled") + } + routerEnabled.set(true) + // router is enabled, expects reply + require.NoError(t, upstream.send(&pk)) + <-returnPipe.uniPipe + + routerEnabled.set(false) + // router is disabled + require.NoError(t, upstream.send(&pk)) + select { + case <-time.After(time.Millisecond * 10): + case <-returnPipe.uniPipe: + t.Error("Unexpected reply when router is disabled") + } + + cancel() + <-routerStopped +} + func assertTTLExceed(t *testing.T, originalPacket *ICMP, expectedSrc netip.Addr, upstream *mockUpstream, returnPipe *mockFunnelUniPipe) { encoder := NewEncoder() rawPacket, err := encoder.Encode(originalPacket) require.NoError(t, err) - upstream.source <- rawPacket + resp := <-returnPipe.uniPipe decoder := NewICMPDecoder() decoded, err := decoder.Decode(resp) @@ -111,6 +174,16 @@ type mockUpstream struct { source chan RawPacket } +func (ms *mockUpstream) send(pk Packet) error { + encoder := NewEncoder() + rawPacket, err := encoder.Encode(pk) + if err != nil { + return err + } + ms.source <- rawPacket + return nil +} + func (ms *mockUpstream) ReceivePacket(ctx context.Context) (RawPacket, error) { select { case <-ctx.Done(): @@ -129,3 +202,22 @@ func (mir mockICMPRouter) Serve(ctx context.Context) error { func (mir mockICMPRouter) Request(pk *ICMP, responder FunnelUniPipe) error { return fmt.Errorf("Request not implemented by mockICMPRouter") } + +type routerEnabledChecker struct { + enabled uint32 +} + +func (rec *routerEnabledChecker) isEnabled() bool { + if atomic.LoadUint32(&rec.enabled) == 0 { + return false + } + return true +} + +func (rec *routerEnabledChecker) set(enabled bool) { + if enabled { + atomic.StoreUint32(&rec.enabled, 1) + } else { + atomic.StoreUint32(&rec.enabled, 0) + } +}