diff --git a/ingress/icmp_darwin.go b/ingress/icmp_darwin.go index 640a2bb1..dd03e46b 100644 --- a/ingress/icmp_darwin.go +++ b/ingress/icmp_darwin.go @@ -157,6 +157,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa } span.SetAttributes(attribute.Int("assignedEchoID", int(assignedEchoID))) + shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder.datagramMuxer, pk, originalEcho.ID) newFunnelFunc := func() (packet.Funnel, error) { originalEcho, err := getICMPEcho(pk.Message) if err != nil { @@ -170,7 +171,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa return icmpFlow, nil } funnelID := echoFunnelID(assignedEchoID) - funnel, isNew, err := ip.srcFunnelTracker.GetOrRegister(funnelID, newFunnelFunc) + funnel, isNew, err := ip.srcFunnelTracker.GetOrRegister(funnelID, shouldReplaceFunnelFunc, newFunnelFunc) if err != nil { tracing.EndWithErrorStatus(span, err) return err diff --git a/ingress/icmp_linux.go b/ingress/icmp_linux.go index b40d88bb..894025c2 100644 --- a/ingress/icmp_linux.go +++ b/ingress/icmp_linux.go @@ -112,6 +112,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa attribute.Int("seq", originalEcho.Seq), ) + shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder.datagramMuxer, pk, originalEcho.ID) newFunnelFunc := func() (packet.Funnel, error) { conn, err := newICMPConn(ip.listenIP, ip.ipv6Zone) if err != nil { @@ -137,7 +138,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa dstIP: pk.Dst, originalEchoID: originalEcho.ID, } - funnel, isNew, err := ip.srcFunnelTracker.GetOrRegister(funnelID, newFunnelFunc) + funnel, isNew, err := ip.srcFunnelTracker.GetOrRegister(funnelID, shouldReplaceFunnelFunc, newFunnelFunc) if err != nil { tracing.EndWithErrorStatus(span, err) return err diff --git a/ingress/icmp_posix.go b/ingress/icmp_posix.go index 95962667..504df60a 100644 --- a/ingress/icmp_posix.go +++ b/ingress/icmp_posix.go @@ -10,6 +10,7 @@ import ( "net/netip" "github.com/google/gopacket/layers" + "github.com/rs/zerolog" "golang.org/x/net/icmp" "github.com/cloudflare/cloudflared/packet" @@ -174,3 +175,30 @@ func toICMPEchoFlow(funnel packet.Funnel) (*icmpEchoFlow, error) { } return icmpFlow, nil } + +func createShouldReplaceFunnelFunc(logger *zerolog.Logger, muxer muxer, pk *packet.ICMP, originalEchoID int) func(packet.Funnel) bool { + return func(existing packet.Funnel) bool { + existingFlow, err := toICMPEchoFlow(existing) + if err != nil { + logger.Err(err). + Str("src", pk.Src.String()). + Str("dst", pk.Dst.String()). + Int("originalEchoID", originalEchoID). + Msg("Funnel of wrong type found") + return true + } + // Each quic connection should have a unique muxer. + // If the existing flow has a different muxer, there's a new quic connection where return packets should be + // routed. Otherwise, return packets will be send to the first observed incoming connection, rather than the + // most recently observed connection. + if existingFlow.responder.datagramMuxer != muxer { + logger.Debug(). + Str("src", pk.Src.String()). + Str("dst", pk.Dst.String()). + Int("originalEchoID", originalEchoID). + Msg("Replacing funnel with new responder") + return true + } + return false + } +} diff --git a/ingress/icmp_posix_test.go b/ingress/icmp_posix_test.go index a857dacc..2e81a65b 100644 --- a/ingress/icmp_posix_test.go +++ b/ingress/icmp_posix_test.go @@ -52,18 +52,28 @@ func TestFunnelIdleTimeout(t *testing.T) { }, }, } + funnelID := flow3Tuple{ + srcIP: pk.Src, + dstIP: pk.Dst, + originalEchoID: echoID, + } muxer := newMockMuxer(0) responder := packetResponder{ datagramMuxer: muxer, } require.NoError(t, proxy.Request(ctx, &pk, &responder)) validateEchoFlow(t, <-muxer.cfdToEdge, &pk) + funnel1, found := proxy.srcFunnelTracker.Get(funnelID) + require.True(t, found) // Send second request, should reuse the funnel require.NoError(t, proxy.Request(ctx, &pk, &packetResponder{ - datagramMuxer: nil, + datagramMuxer: muxer, })) validateEchoFlow(t, <-muxer.cfdToEdge, &pk) + funnel2, found := proxy.srcFunnelTracker.Get(funnelID) + require.True(t, found) + require.Equal(t, funnel1, funnel2) time.Sleep(idleTimeout * 2) newMuxer := newMockMuxer(0) diff --git a/packet/funnel.go b/packet/funnel.go index 3b5cfeb6..c76e4a03 100644 --- a/packet/funnel.go +++ b/packet/funnel.go @@ -108,13 +108,23 @@ func (ft *FunnelTracker) Get(id FunnelID) (Funnel, bool) { return funnel, ok } -// Registers a funnel. It replaces the current funnel. -func (ft *FunnelTracker) GetOrRegister(id FunnelID, newFunnelFunc func() (Funnel, error)) (funnel Funnel, new bool, err error) { +// Registers a funnel. If the `id` is already registered and `shouldReplaceFunc` returns true, it closes and replaces +// the current funnel. If `newFunnelFunc` returns an error, the `id` will remain unregistered, even if it was registered +// when calling this function. +func (ft *FunnelTracker) GetOrRegister( + id FunnelID, + shouldReplaceFunc func(Funnel) bool, + newFunnelFunc func() (Funnel, error), +) (funnel Funnel, new bool, err error) { ft.lock.Lock() defer ft.lock.Unlock() currentFunnel, exists := ft.funnels[id] if exists { - return currentFunnel, false, nil + if !shouldReplaceFunc(currentFunnel) { + return currentFunnel, false, nil + } + currentFunnel.Close() + delete(ft.funnels, id) } newFunnel, err := newFunnelFunc() if err != nil { @@ -124,7 +134,7 @@ func (ft *FunnelTracker) GetOrRegister(id FunnelID, newFunnelFunc func() (Funnel return newFunnel, true, nil } -// Unregisters a funnel if the funnel equals to the current funnel +// Unregisters and closes a funnel if the funnel equals to the current funnel func (ft *FunnelTracker) Unregister(id FunnelID, funnel Funnel) (deleted bool) { ft.lock.Lock() defer ft.lock.Unlock() diff --git a/packet/funnel_test.go b/packet/funnel_test.go index 08dc291f..762c917d 100644 --- a/packet/funnel_test.go +++ b/packet/funnel_test.go @@ -1,6 +1,13 @@ package packet -import "net/netip" +import ( + "fmt" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/require" +) type mockFunnelUniPipe struct { uniPipe chan RawPacket @@ -14,3 +21,118 @@ func (mfui *mockFunnelUniPipe) SendPacket(dst netip.Addr, pk RawPacket) error { func (mfui *mockFunnelUniPipe) Close() error { return nil } + +func TestFunnelRegistration(t *testing.T) { + id := testFunnelID{"id1"} + funnelErr := fmt.Errorf("expected error") + newFunnelFuncErr := func() (Funnel, error) { return nil, funnelErr } + newFunnelFuncUncalled := func() (Funnel, error) { + require.FailNow(t, "a new funnel should not be created") + panic("unreached") + } + funnel1, newFunnelFunc1 := newFunnelAndFunc("funnel1") + funnel2, newFunnelFunc2 := newFunnelAndFunc("funnel2") + + ft := NewFunnelTracker() + // Register funnel1 + funnel, new, err := ft.GetOrRegister(id, shouldReplaceFalse, newFunnelFunc1) + require.NoError(t, err) + require.True(t, new) + require.Equal(t, funnel1, funnel) + // Register funnel, no replace + funnel, new, err = ft.GetOrRegister(id, shouldReplaceFalse, newFunnelFuncUncalled) + require.NoError(t, err) + require.False(t, new) + require.Equal(t, funnel1, funnel) + // Register funnel2, replace + funnel, new, err = ft.GetOrRegister(id, shouldReplaceTrue, newFunnelFunc2) + require.NoError(t, err) + require.True(t, new) + require.Equal(t, funnel2, funnel) + require.True(t, funnel1.closed) + // Register funnel error, replace + funnel, new, err = ft.GetOrRegister(id, shouldReplaceTrue, newFunnelFuncErr) + require.ErrorIs(t, err, funnelErr) + require.False(t, new) + require.Nil(t, funnel) + require.True(t, funnel2.closed) +} + +func TestFunnelUnregister(t *testing.T) { + id := testFunnelID{"id1"} + funnel1, newFunnelFunc1 := newFunnelAndFunc("funnel1") + funnel2, newFunnelFunc2 := newFunnelAndFunc("funnel2") + funnel3, newFunnelFunc3 := newFunnelAndFunc("funnel3") + + ft := NewFunnelTracker() + // Register & unregister + _, _, err := ft.GetOrRegister(id, shouldReplaceFalse, newFunnelFunc1) + require.NoError(t, err) + require.True(t, ft.Unregister(id, funnel1)) + require.True(t, funnel1.closed) + require.True(t, ft.Unregister(id, funnel1)) + // Register, replace, and unregister + _, _, err = ft.GetOrRegister(id, shouldReplaceFalse, newFunnelFunc2) + require.NoError(t, err) + _, _, err = ft.GetOrRegister(id, shouldReplaceTrue, newFunnelFunc3) + require.NoError(t, err) + require.True(t, funnel2.closed) + require.False(t, ft.Unregister(id, funnel2)) + require.True(t, ft.Unregister(id, funnel3)) + require.True(t, funnel3.closed) +} + +func shouldReplaceFalse(_ Funnel) bool { + return false +} + +func shouldReplaceTrue(_ Funnel) bool { + return true +} + +func newFunnelAndFunc(id string) (*testFunnel, func() (Funnel, error)) { + funnel := newTestFunnel(id) + funnelFunc := func() (Funnel, error) { + return funnel, nil + } + return funnel, funnelFunc +} + +type testFunnelID struct { + id string +} + +func (t testFunnelID) Type() string { + return "testFunnelID" +} + +func (t testFunnelID) String() string { + return t.id +} + +type testFunnel struct { + id string + closed bool +} + +func newTestFunnel(id string) *testFunnel { + return &testFunnel{ + id, + false, + } +} + +func (tf *testFunnel) Close() error { + tf.closed = true + return nil +} + +func (tf *testFunnel) Equal(other Funnel) bool { + return tf.id == other.(*testFunnel).id +} + +func (tf *testFunnel) LastActive() time.Time { + return time.Now() +} + +func (tf *testFunnel) UpdateLastActive() {}