cloudflared-mirror/packet/router_test.go

224 lines
5.2 KiB
Go

package packet
import (
"bytes"
"context"
"fmt"
"net/netip"
"sync/atomic"
"testing"
"time"
"github.com/google/gopacket/layers"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
var (
	// noopLogger discards everything; the tests pass its address to NewRouter.
	noopLogger = zerolog.Nop()

	// packetConfig is the router configuration shared by the tests in this
	// file. Its ICMPRouter is a mock whose methods only return errors.
	packetConfig = &GlobalRouterConfig{
		IPv4Src:    netip.MustParseAddr("172.16.0.1"),
		IPv6Src:    netip.MustParseAddr("fd51:2391:523:f4ee::1"),
		ICMPRouter: &mockICMPRouter{},
	}
)
// TestRouterReturnTTLExceed verifies that, for both IPv4 and IPv6, the router
// answers an ICMP echo request carrying TTL 1 with an ICMP time-exceeded
// message (see assertTTLExceed for the exact checks).
func TestRouterReturnTTLExceed(t *testing.T) {
	upstream := &mockUpstream{
		source: make(chan RawPacket),
	}
	returnPipe := &mockFunnelUniPipe{
		uniPipe: make(chan RawPacket),
	}
	routerEnabled := &routerEnabledChecker{}
	routerEnabled.set(true)
	router := NewRouter(packetConfig, upstream, returnPipe, &noopLogger, routerEnabled.isEnabled)
	ctx, cancel := context.WithCancel(context.Background())
	// Make sure the context is cancelled even if an assertion below ends the
	// test early via t.FailNow; cancel is safe to call more than once.
	defer cancel()
	routerStopped := make(chan struct{})
	go func() {
		router.Serve(ctx)
		close(routerStopped)
	}()
	pk := ICMP{
		IP: &IP{
			Src:      netip.MustParseAddr("192.168.1.1"),
			Dst:      netip.MustParseAddr("10.0.0.1"),
			Protocol: layers.IPProtocolICMPv4,
			TTL:      1,
		},
		Message: &icmp.Message{
			Type: ipv4.ICMPTypeEcho,
			Code: 0,
			Body: &icmp.Echo{
				ID:   12481,
				Seq:  8036,
				Data: []byte("TTL exceed"),
			},
		},
	}
	assertTTLExceed(t, &pk, router.globalConfig.IPv4Src, upstream, returnPipe)
	pk = ICMP{
		IP: &IP{
			// Deliberately different from packetConfig.IPv6Src so that the
			// reply's source (router address) and destination (this address)
			// are distinguishable; otherwise a reply with swapped or echoed
			// addresses would still pass the checks in assertTTLExceed.
			Src:      netip.MustParseAddr("fd51:2391:523:f4ee::5"),
			Dst:      netip.MustParseAddr("fd51:2391:697:f4ee::2"),
			Protocol: layers.IPProtocolICMPv6,
			TTL:      1,
		},
		Message: &icmp.Message{
			Type: ipv6.ICMPTypeEchoRequest,
			Code: 0,
			Body: &icmp.Echo{
				ID:   42583,
				Seq:  7039,
				Data: []byte("TTL exceed"),
			},
		},
	}
	assertTTLExceed(t, &pk, router.globalConfig.IPv6Src, upstream, returnPipe)
	cancel()
	<-routerStopped
}
// TestRouterCheckEnabled verifies that the router only replies to packets
// while its enabled check returns true, and that toggling the check takes
// effect for packets sent afterwards.
func TestRouterCheckEnabled(t *testing.T) {
	upstream := &mockUpstream{
		source: make(chan RawPacket),
	}
	returnPipe := &mockFunnelUniPipe{
		uniPipe: make(chan RawPacket),
	}
	// The zero-value checker reports disabled, so the router starts off.
	routerEnabled := &routerEnabledChecker{}
	router := NewRouter(packetConfig, upstream, returnPipe, &noopLogger, routerEnabled.isEnabled)
	ctx, cancel := context.WithCancel(context.Background())
	// Make sure the context is cancelled even if an assertion below ends the
	// test early via t.FailNow; cancel is safe to call more than once.
	defer cancel()
	routerStopped := make(chan struct{})
	go func() {
		router.Serve(ctx)
		close(routerStopped)
	}()
	pk := ICMP{
		IP: &IP{
			Src:      netip.MustParseAddr("192.168.1.1"),
			Dst:      netip.MustParseAddr("10.0.0.1"),
			Protocol: layers.IPProtocolICMPv4,
			TTL:      1,
		},
		Message: &icmp.Message{
			Type: ipv4.ICMPTypeEcho,
			Code: 0,
			Body: &icmp.Echo{
				ID:   12481,
				Seq:  8036,
				Data: []byte(t.Name()),
			},
		},
	}

	// noReplyWindow is how long to watch for a reply that must NOT arrive.
	const noReplyWindow = 10 * time.Millisecond
	// replyTimeout bounds the wait for a reply that MUST arrive, so a broken
	// router fails the test instead of hanging until go test's global timeout.
	const replyTimeout = 5 * time.Second

	assertNoReply := func() {
		t.Helper()
		select {
		case <-time.After(noReplyWindow):
		case <-returnPipe.uniPipe:
			t.Error("Unexpected reply when router is disabled")
		}
	}

	// router is disabled
	require.NoError(t, upstream.send(&pk))
	assertNoReply()

	routerEnabled.set(true)
	// router is enabled, expects reply
	require.NoError(t, upstream.send(&pk))
	select {
	case <-returnPipe.uniPipe:
	case <-time.After(replyTimeout):
		t.Fatalf("no reply within %v while router is enabled", replyTimeout)
	}

	routerEnabled.set(false)
	// router is disabled again
	require.NoError(t, upstream.send(&pk))
	assertNoReply()

	cancel()
	<-routerStopped
}
// assertTTLExceed encodes originalPacket, feeds it to the router through
// upstream, and asserts that the reply read from returnPipe is an ICMP
// time-exceeded message that:
//   - is sent from expectedSrc to originalPacket's source,
//   - uses originalPacket's protocol and DefaultTTL,
//   - has code 0 and a valid ICMP checksum, and
//   - carries the original encoded packet as its body data.
//
// Callers in this file pass packets with TTL 1 so that the router generates
// the time-exceeded reply.
func assertTTLExceed(t *testing.T, originalPacket *ICMP, expectedSrc netip.Addr, upstream *mockUpstream, returnPipe *mockFunnelUniPipe) {
	t.Helper()
	// Upper bound on each channel operation below; without it a stuck router
	// would hang the test until go test's global timeout.
	const timeout = 5 * time.Second

	encoder := NewEncoder()
	rawPacket, err := encoder.Encode(originalPacket)
	require.NoError(t, err)

	select {
	case upstream.source <- rawPacket:
	case <-time.After(timeout):
		t.Fatalf("router did not read the packet within %v", timeout)
	}

	var resp RawPacket
	select {
	case resp = <-returnPipe.uniPipe:
	case <-time.After(timeout):
		t.Fatalf("no TTL exceeded reply within %v", timeout)
	}

	decoder := NewICMPDecoder()
	decoded, err := decoder.Decode(resp)
	require.NoError(t, err)
	require.Equal(t, expectedSrc, decoded.Src)
	require.Equal(t, originalPacket.Src, decoded.Dst)
	require.Equal(t, originalPacket.Protocol, decoded.Protocol)
	require.Equal(t, DefaultTTL, decoded.TTL)
	// The expected ICMP type depends on the IP version of the original packet.
	if originalPacket.Dst.Is4() {
		require.Equal(t, ipv4.ICMPTypeTimeExceeded, decoded.Type)
	} else {
		require.Equal(t, ipv6.ICMPTypeTimeExceeded, decoded.Type)
	}
	require.Equal(t, 0, decoded.Code)
	assertICMPChecksum(t, decoded)
	timeExceed, ok := decoded.Body.(*icmp.TimeExceeded)
	require.True(t, ok)
	require.True(t, bytes.Equal(rawPacket.Data, timeExceed.Data),
		"time exceeded body should contain the original packet: want %x, got %x", rawPacket.Data, timeExceed.Data)
}
// mockUpstream is the upstream handed to NewRouter in these tests. Packets
// pushed onto source are returned, one per call, by ReceivePacket.
type mockUpstream struct {
	source chan RawPacket
}

// send encodes pk and pushes the result onto source. It blocks until a
// receiver takes the packet. If encoding fails, nothing is sent and the
// error is returned.
func (ms *mockUpstream) send(pk Packet) error {
	enc := NewEncoder()
	raw, encodeErr := enc.Encode(pk)
	if encodeErr == nil {
		ms.source <- raw
	}
	return encodeErr
}

// ReceivePacket returns the next packet pushed onto source. If ctx is done
// first, it returns an empty RawPacket and ctx's error.
func (ms *mockUpstream) ReceivePacket(ctx context.Context) (RawPacket, error) {
	select {
	case raw := <-ms.source:
		return raw, nil
	case <-ctx.Done():
		var empty RawPacket
		return empty, ctx.Err()
	}
}
// mockICMPRouter fills the ICMPRouter slot of packetConfig. Both of its
// methods fail unconditionally with a "not implemented" error.
type mockICMPRouter struct{}

// Serve ignores ctx and returns a "not implemented" error straight away.
func (mockICMPRouter) Serve(_ context.Context) error {
	return fmt.Errorf("Serve not implemented by mockICMPRouter")
}

// Request ignores its arguments, sends nothing to the responder, and returns
// a "not implemented" error.
func (mockICMPRouter) Request(_ *ICMP, _ FunnelUniPipe) error {
	return fmt.Errorf("Request not implemented by mockICMPRouter")
}
// routerEnabledChecker is an on/off switch. Its isEnabled method is passed to
// NewRouter as the enabled check. All access to the flag goes through
// sync/atomic, so tests may flip it while the router goroutine reads it.
type routerEnabledChecker struct {
	// enabled holds 1 when switched on and 0 when switched off.
	enabled uint32
}

// isEnabled reports whether the switch is currently on.
func (rec *routerEnabledChecker) isEnabled() bool {
	return atomic.LoadUint32(&rec.enabled) != 0
}

// set switches the checker on (true) or off (false).
func (rec *routerEnabledChecker) set(enabled bool) {
	var flag uint32
	if enabled {
		flag = 1
	}
	atomic.StoreUint32(&rec.enabled, flag)
}