diff --git a/connection/quic.go b/connection/quic.go index 640c7be2..86a5228c 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -51,7 +51,7 @@ type QUICConnection struct { sessionManager datagramsession.Manager // datagramMuxer mux/demux datagrams from quic connection datagramMuxer quicpogs.BaseDatagramMuxer - packetRouter *packetRouter + packetRouter *packet.Router controlStreamHandler ControlStreamHandler connOptions *tunnelpogs.ConnectionOptions } @@ -65,7 +65,7 @@ func NewQUICConnection( connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler ControlStreamHandler, logger *zerolog.Logger, - icmpProxy ingress.ICMPProxy, + icmpRouter packet.ICMPRouter, ) (*QUICConnection, error) { session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig) if err != nil { @@ -75,15 +75,12 @@ func NewQUICConnection( sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) var ( datagramMuxer quicpogs.BaseDatagramMuxer - pr *packetRouter + pr *packet.Router ) - if icmpProxy != nil { - pr = &packetRouter{ - muxer: quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan), - icmpProxy: icmpProxy, - logger: logger, - } - datagramMuxer = pr.muxer + if icmpRouter != nil { + datagramMuxerV2 := quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan) + pr = packet.NewRouter(datagramMuxerV2, &returnPipe{muxer: datagramMuxerV2}, icmpRouter, logger) + datagramMuxer = datagramMuxerV2 } else { datagramMuxer = quicpogs.NewDatagramMuxer(session, logger, sessionDemuxChan) } @@ -139,7 +136,7 @@ func (q *QUICConnection) Serve(ctx context.Context) error { if q.packetRouter != nil { errGroup.Go(func() error { defer cancel() - return q.packetRouter.serve(ctx) + return q.packetRouter.Serve(ctx) }) } @@ -348,50 +345,6 @@ func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32, return q.orchestrator.UpdateConfig(version, config) } -type packetRouter struct { - muxer *quicpogs.DatagramMuxerV2 - icmpProxy ingress.ICMPProxy - logger *zerolog.Logger -} - -func (pr *packetRouter) serve(ctx context.Context) error { - icmpDecoder := packet.NewICMPDecoder() - for { - pk, err := pr.muxer.ReceivePacket(ctx) - if err != nil { - return err - } - icmpPacket, err := icmpDecoder.Decode(pk) - if err != nil { - pr.logger.Err(err).Msg("Failed to decode ICMP packet from quic datagram") - continue - } - - flowPipe := muxerResponder{muxer: pr.muxer} - if err := pr.icmpProxy.Request(icmpPacket, &flowPipe); err != nil { - pr.logger.Err(err). - Str("src", icmpPacket.Src.String()). - Str("dst", icmpPacket.Dst.String()). - Interface("type", icmpPacket.Type). - Msg("Failed to send ICMP packet") - continue - } - } -} - -// muxerResponder wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface -type muxerResponder struct { - muxer *quicpogs.DatagramMuxerV2 -} - -func (mr *muxerResponder) SendPacket(dst netip.Addr, pk packet.RawPacket) error { - return mr.muxer.SendPacket(pk) -} - -func (mr *muxerResponder) Close() error { - return nil -} - // streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to // the client. type streamReadWriteAcker struct { @@ -538,3 +491,16 @@ func (np *nopCloserReadWriter) Close() error { return nil } + +// returnPipe wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface +type returnPipe struct { + muxer *quicpogs.DatagramMuxerV2 +} + +func (rp *returnPipe) SendPacket(dst netip.Addr, pk packet.RawPacket) error { + return rp.muxer.SendPacket(pk) +} + +func (rp *returnPipe) Close() error { + return nil +} diff --git a/ingress/icmp_darwin.go b/ingress/icmp_darwin.go index ace04f23..73ac4de1 100644 --- a/ingress/icmp_darwin.go +++ b/ingress/icmp_darwin.go @@ -115,7 +115,7 @@ func (snf echoFunnelID) String() string { return strconv.FormatUint(uint64(snf), 10) } -func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (ICMPProxy, error) { +func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { conn, err := newICMPConn(listenIP) if err != nil { return nil, err diff --git a/ingress/icmp_generic.go b/ingress/icmp_generic.go index 781dc65b..0e481006 100644 --- a/ingress/icmp_generic.go +++ b/ingress/icmp_generic.go @@ -3,14 +3,29 @@ package ingress import ( + "context" "fmt" "net/netip" "runtime" "time" "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/packet" ) -func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (ICMPProxy, error) { - return nil, fmt.Errorf("ICMP proxy is not implemented on %s", runtime.GOOS) +var errICMPProxyNotImplemented = fmt.Errorf("ICMP proxy is not implemented on %s", runtime.GOOS) + +type icmpProxy struct{} + +func (ip icmpProxy) Request(pk *packet.ICMP, responder packet.FunnelUniPipe) error { + return errICMPProxyNotImplemented +} + +func (ip *icmpProxy) Serve(ctx context.Context) error { + return errICMPProxyNotImplemented +} + +func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { + return nil, errICMPProxyNotImplemented } diff --git a/ingress/icmp_linux.go b/ingress/icmp_linux.go index 21da6ce1..3b7ae828 100644 --- a/ingress/icmp_linux.go +++ b/ingress/icmp_linux.go @@ -28,7 +28,7 @@ type icmpProxy struct { idleTimeout time.Duration } -func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (ICMPProxy, error) { +func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { if err := testPermission(listenIP); err != nil { return nil, err } diff --git a/ingress/icmp_posix.go b/ingress/icmp_posix.go index 927c02f9..a2a0a5b9 100644 --- a/ingress/icmp_posix.go +++ b/ingress/icmp_posix.go @@ -96,6 +96,7 @@ func (ief *icmpEchoFlow) returnToSrc(reply *echoReply) error { Src: reply.from, Dst: ief.Src, Protocol: layers.IPProtocol(reply.msg.Type.Protocol()), + TTL: packet.DefaultTTL, }, Message: reply.msg, } diff --git a/ingress/icmp_windows.go b/ingress/icmp_windows.go index c1ebb511..baba8d4d 100644 --- a/ingress/icmp_windows.go +++ b/ingress/icmp_windows.go @@ -224,7 +224,7 @@ type icmpProxy struct { encoderPool sync.Pool } -func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (ICMPProxy, error) { +func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { var ( srcSocketAddr *sockAddrIn6 handle uintptr @@ -302,6 +302,7 @@ func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, d Src: request.Dst, Dst: request.Src, Protocol: layers.IPProtocol(request.Type.Protocol()), + TTL: packet.DefaultTTL, }, Message: &icmp.Message{ Type: replyType, diff --git a/ingress/icmp_windows_test.go b/ingress/icmp_windows_test.go index 5361ac44..05f7e620 100644 --- a/ingress/icmp_windows_test.go +++ b/ingress/icmp_windows_test.go @@ -134,7 +134,6 @@ func TestSendEchoErrors(t *testing.T) { func testSendEchoErrors(t *testing.T, listenIP netip.Addr) { proxy, err := newICMPProxy(listenIP, &noopLogger, time.Second) require.NoError(t, err) - winProxy := proxy.(*icmpProxy) echo := icmp.Echo{ ID: 6193, @@ -145,7 +144,7 @@ func testSendEchoErrors(t *testing.T, listenIP netip.Addr) { if listenIP.Is6() { documentIP = netip.MustParseAddr("2001:db8::1") } - resp, err := winProxy.icmpEchoRoundtrip(documentIP, &echo) + resp, err := proxy.icmpEchoRoundtrip(documentIP, &echo) require.Error(t, err) require.Nil(t, resp) } diff --git a/ingress/origin_icmp_proxy.go b/ingress/origin_icmp_proxy.go index c114eb75..4bd0c0fe 100644 --- a/ingress/origin_icmp_proxy.go +++ b/ingress/origin_icmp_proxy.go @@ -26,22 +26,14 @@ var ( errPacketNil = fmt.Errorf("packet is nil") ) -// ICMPProxy sends ICMP messages and listens for their responses -type ICMPProxy interface { - // Serve starts listening for responses to the requests until context is done - Serve(ctx context.Context) error - // Request sends an ICMP message - Request(pk *packet.ICMP, responder packet.FunnelUniPipe) error -} - type icmpRouter struct { - ipv4Proxy ICMPProxy - ipv6Proxy ICMPProxy + ipv4Proxy *icmpProxy + ipv6Proxy *icmpProxy } -// NewICMPProxy doesn't return an error if either ipv4 proxy or ipv6 proxy can be created. The machine might only +// NewICMPRouter doesn't return an error if either ipv4 proxy or ipv6 proxy can be created. The machine might only // support one of them -func NewICMPProxy(logger *zerolog.Logger) (ICMPProxy, error) { +func NewICMPRouter(logger *zerolog.Logger) (*icmpRouter, error) { // TODO: TUN-6741: don't bind to all interface ipv4Proxy, ipv4Err := newICMPProxy(netip.IPv4Unspecified(), logger, funnelIdleTimeout) ipv6Proxy, ipv6Err := newICMPProxy(netip.IPv6Unspecified(), logger, funnelIdleTimeout) @@ -49,11 +41,11 @@ func NewICMPProxy(logger *zerolog.Logger) (ICMPProxy, error) { return nil, fmt.Errorf("cannot create ICMPv4 proxy: %v nor ICMPv6 proxy: %v", ipv4Err, ipv6Err) } if ipv4Err != nil { - logger.Warn().Err(ipv4Err).Msg("failed to create ICMPv4 proxy, only ICMPv6 proxy is created") + logger.Debug().Err(ipv4Err).Msg("failed to create ICMPv4 proxy, only ICMPv6 proxy is created") ipv4Proxy = nil } if ipv6Err != nil { - logger.Warn().Err(ipv6Err).Msg("failed to create ICMPv6 proxy, only ICMPv4 proxy is created") + logger.Debug().Err(ipv6Err).Msg("failed to create ICMPv6 proxy, only ICMPv4 proxy is created") ipv6Proxy = nil } return &icmpRouter{ diff --git a/ingress/origin_icmp_proxy_test.go b/ingress/origin_icmp_proxy_test.go index 55379dcb..5684a608 100644 --- a/ingress/origin_icmp_proxy_test.go +++ b/ingress/origin_icmp_proxy_test.go @@ -30,24 +30,24 @@ var ( // Note: if this test fails on your device under Linux, then most likely you need to make sure that your user // is allowed in ping_group_range. See the following gist for how to do that: // https://github.com/ValentinBELYN/icmplib/blob/main/docs/6-use-icmplib-without-privileges.md -func TestICMPProxyEcho(t *testing.T) { - testICMPProxyEcho(t, true) - testICMPProxyEcho(t, false) +func TestICMPRouterEcho(t *testing.T) { + testICMPRouterEcho(t, true) + testICMPRouterEcho(t, false) } -func testICMPProxyEcho(t *testing.T, sendIPv4 bool) { +func testICMPRouterEcho(t *testing.T, sendIPv4 bool) { const ( echoID = 36571 endSeq = 20 ) - proxy, err := NewICMPProxy(&noopLogger) + router, err := NewICMPRouter(&noopLogger) require.NoError(t, err) proxyDone := make(chan struct{}) ctx, cancel := context.WithCancel(context.Background()) go func() { - proxy.Serve(ctx) + router.Serve(ctx) close(proxyDone) }() @@ -67,6 +67,7 @@ func testICMPProxyEcho(t *testing.T, sendIPv4 bool) { Src: localIP, Dst: localIP, Protocol: protocol, + TTL: packet.DefaultTTL, } } @@ -88,7 +89,7 @@ func testICMPProxyEcho(t *testing.T, sendIPv4 bool) { }, }, } - require.NoError(t, proxy.Request(&pk, &responder)) + require.NoError(t, router.Request(&pk, &responder)) responder.validate(t, &pk) } } @@ -97,7 +98,7 @@ func testICMPProxyEcho(t *testing.T, sendIPv4 bool) { } // TestICMPProxyRejectNotEcho makes sure it rejects messages other than echo -func TestICMPProxyRejectNotEcho(t *testing.T) { +func TestICMPRouterRejectNotEcho(t *testing.T) { msgs := []icmp.Message{ { Type: ipv4.ICMPTypeDestinationUnreachable, @@ -122,7 +123,7 @@ func TestICMPProxyRejectNotEcho(t *testing.T) { }, }, } - testICMPProxyRejectNotEcho(t, localhostIP, msgs) + testICMPRouterRejectNotEcho(t, localhostIP, msgs) msgsV6 := []icmp.Message{ { Type: ipv6.ICMPTypeDestinationUnreachable, @@ -147,11 +148,11 @@ func TestICMPProxyRejectNotEcho(t *testing.T) { }, }, } - testICMPProxyRejectNotEcho(t, localhostIPv6, msgsV6) + testICMPRouterRejectNotEcho(t, localhostIPv6, msgsV6) } -func testICMPProxyRejectNotEcho(t *testing.T, srcDstIP netip.Addr, msgs []icmp.Message) { - proxy, err := NewICMPProxy(&noopLogger) +func testICMPRouterRejectNotEcho(t *testing.T, srcDstIP netip.Addr, msgs []icmp.Message) { + router, err := NewICMPRouter(&noopLogger) require.NoError(t, err) responder := echoFlowResponder{ @@ -168,10 +169,11 @@ func testICMPProxyRejectNotEcho(t *testing.T, srcDstIP netip.Addr, msgs []icmp.M Src: srcDstIP, Dst: srcDstIP, Protocol: protocol, + TTL: packet.DefaultTTL, }, Message: &m, } - require.Error(t, proxy.Request(&pk, &responder)) + require.Error(t, router.Request(&pk, &responder)) } } diff --git a/packet/decoder.go b/packet/decoder.go index 147cbd16..2737b18e 100644 --- a/packet/decoder.go +++ b/packet/decoder.go @@ -16,8 +16,8 @@ func FindProtocol(p []byte) (layers.IPProtocol, error) { } switch version { case 4: - if len(p) < ipv4HeaderLen { - return 0, fmt.Errorf("IPv4 packet should have at least %d bytes, got %d bytes", ipv4HeaderLen, len(p)) + if len(p) < ipv4MinHeaderLen { + return 0, fmt.Errorf("IPv4 packet should have at least %d bytes, got %d bytes", ipv4MinHeaderLen, len(p)) } // Protocol is in the 10th byte of IPv4 header return layers.IPProtocol(p[9]), nil diff --git a/packet/decoder_test.go b/packet/decoder_test.go index 6db377dd..c07896e4 100644 --- a/packet/decoder_test.go +++ b/packet/decoder_test.go @@ -61,11 +61,13 @@ func TestDecodeICMP(t *testing.T) { Src: netip.MustParseAddr("172.16.0.1"), Dst: netip.MustParseAddr("10.0.0.1"), Protocol: layers.IPProtocolICMPv4, + TTL: DefaultTTL, } ipv6Packet = IP{ Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"), Dst: netip.MustParseAddr("fd51:2391:697:f4ee::2"), Protocol: layers.IPProtocolICMPv6, + TTL: DefaultTTL, } icmpID = 100 icmpSeq = 52819 @@ -171,7 +173,7 @@ func TestDecodeBadPackets(t *testing.T) { SrcIP: srcIPv4, DstIP: dstIPv4, Protocol: layers.IPProtocolICMPv4, - TTL: defaultTTL, + TTL: DefaultTTL, } icmpLayer := layers.ICMPv4{ TypeCode: layers.CreateICMPv4TypeCode(uint8(ipv4.ICMPTypeEcho), 0), @@ -231,6 +233,7 @@ func assertIPLayer(t *testing.T, expected, actual *IP) { require.Equal(t, expected.Src, actual.Src) require.Equal(t, expected.Dst, actual.Dst) require.Equal(t, expected.Protocol, actual.Protocol) + require.Equal(t, expected.TTL, actual.TTL) } type UDP struct { diff --git a/packet/funnel_test.go b/packet/funnel_test.go new file mode 100644 index 00000000..08dc291f --- /dev/null +++ b/packet/funnel_test.go @@ -0,0 +1,16 @@ +package packet + +import "net/netip" + +type mockFunnelUniPipe struct { + uniPipe chan RawPacket +} + +func (mfui *mockFunnelUniPipe) SendPacket(dst netip.Addr, pk RawPacket) error { + mfui.uniPipe <- pk + return nil +} + +func (mfui *mockFunnelUniPipe) Close() error { + return nil +} diff --git a/packet/packet.go b/packet/packet.go index 62b32ed8..b691790f 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -7,12 +7,20 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) const ( - defaultTTL uint8 = 64 - ipv4HeaderLen = 20 - ipv6HeaderLen = 40 + ipv4MinHeaderLen = 20 + ipv6HeaderLen = 40 + ipv4MinMTU = 576 + ipv6MinMTU = 1280 + icmpHeaderLen = 8 + // https://www.rfc-editor.org/rfc/rfc792 and https://datatracker.ietf.org/doc/html/rfc4443#section-3.3 define 2 codes. + // 0 = ttl exceed in transit, 1 = fragment reassembly time exceeded + icmpTTLExceedInTransitCode = 0 + DefaultTTL uint8 = 255 ) // Packet represents an IP packet or a packet that is encapsulated by IP @@ -28,6 +36,7 @@ type IP struct { Src netip.Addr Dst netip.Addr Protocol layers.IPProtocol + TTL uint8 } func newIPv4(ipLayer *layers.IPv4) (*IP, error) { @@ -43,6 +52,7 @@ func newIPv4(ipLayer *layers.IPv4) (*IP, error) { Src: src, Dst: dst, Protocol: ipLayer.Protocol, + TTL: ipLayer.TTL, }, nil } @@ -59,6 +69,7 @@ func newIPv6(ipLayer *layers.IPv6) (*IP, error) { Src: src, Dst: dst, Protocol: ipLayer.NextHeader, + TTL: ipLayer.HopLimit, }, nil } @@ -78,7 +89,7 @@ func (ip *IP) EncodeLayers() ([]gopacket.SerializableLayer, error) { SrcIP: ip.Src.AsSlice(), DstIP: ip.Dst.AsSlice(), Protocol: layers.IPProtocol(ip.Protocol), - TTL: defaultTTL, + TTL: ip.TTL, }, }, nil } else { @@ -88,7 +99,7 @@ func (ip *IP) EncodeLayers() ([]gopacket.SerializableLayer, error) { SrcIP: ip.Src.AsSlice(), DstIP: ip.Dst.AsSlice(), NextHeader: layers.IPProtocol(ip.Protocol), - HopLimit: defaultTTL, + HopLimit: ip.TTL, }, }, nil } @@ -113,3 +124,51 @@ func (i *ICMP) EncodeLayers() ([]gopacket.SerializableLayer, error) { icmpLayer := gopacket.Payload(msg) return append(ipLayers, icmpLayer), nil } + +func NewICMPTTLExceedPacket(originalIP *IP, originalPacket RawPacket, routerIP netip.Addr) *ICMP { + var ( + protocol layers.IPProtocol + icmpType icmp.Type + ) + if originalIP.Dst.Is4() { + protocol = layers.IPProtocolICMPv4 + icmpType = ipv4.ICMPTypeTimeExceeded + } else { + protocol = layers.IPProtocolICMPv6 + icmpType = ipv6.ICMPTypeTimeExceeded + } + return &ICMP{ + IP: &IP{ + Src: routerIP, + Dst: originalIP.Src, + Protocol: protocol, + TTL: DefaultTTL, + }, + Message: &icmp.Message{ + Type: icmpType, + Code: icmpTTLExceedInTransitCode, + Body: &icmp.TimeExceeded{ + Data: originalDatagram(originalPacket, originalIP.Dst.Is4()), + }, + }, + } +} + +// originalDatagram returns a slice of the original datagram for ICMP error messages +// https://www.rfc-editor.org/rfc/rfc1812#section-4.3.2.3 suggests to copy as much without exceeding 576 bytes. +// https://datatracker.ietf.org/doc/html/rfc4443#section-3.3 suggests to copy as much without exceeding 1280 bytes +func originalDatagram(originalPacket RawPacket, isIPv4 bool) []byte { + var upperBound int + if isIPv4 { + upperBound = ipv4MinMTU - ipv4MinHeaderLen - icmpHeaderLen + if upperBound > len(originalPacket.Data) { + upperBound = len(originalPacket.Data) + } + } else { + upperBound = ipv6MinMTU - ipv6HeaderLen - icmpHeaderLen + if upperBound > len(originalPacket.Data) { + upperBound = len(originalPacket.Data) + } + } + return originalPacket.Data[:upperBound] +} diff --git a/packet/packet_test.go b/packet/packet_test.go new file mode 100644 index 00000000..8b0b314f --- /dev/null +++ b/packet/packet_test.go @@ -0,0 +1,104 @@ +package packet + +import ( + "bytes" + "net/netip" + "testing" + + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +func TestNewICMPTTLExceedPacket(t *testing.T) { + ipv4Packet := IP{ + Src: netip.MustParseAddr("192.168.1.1"), + Dst: netip.MustParseAddr("10.0.0.1"), + Protocol: layers.IPProtocolICMPv4, + TTL: 0, + } + icmpV4Packet := ICMP{ + IP: &ipv4Packet, + Message: &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: 25821, + Seq: 58129, + Data: []byte("test ttl=0"), + }, + }, + } + assertTTLExceedPacket(t, &icmpV4Packet) + icmpV4Packet.Body = &icmp.Echo{ + ID: 3487, + Seq: 19183, + Data: make([]byte, ipv4MinMTU), + } + assertTTLExceedPacket(t, &icmpV4Packet) + ipv6Packet := IP{ + Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"), + Dst: netip.MustParseAddr("fd51:2391:697:f4ee::2"), + Protocol: layers.IPProtocolICMPv6, + TTL: 0, + } + icmpV6Packet := ICMP{ + IP: &ipv6Packet, + Message: &icmp.Message{ + Type: ipv6.ICMPTypeEchoRequest, + Code: 0, + Body: &icmp.Echo{ + ID: 25821, + Seq: 58129, + Data: []byte("test ttl=0"), + }, + }, + } + assertTTLExceedPacket(t, &icmpV6Packet) + icmpV6Packet.Body = &icmp.Echo{ + ID: 1497, + Seq: 39284, + Data: make([]byte, ipv6MinMTU), + } + assertTTLExceedPacket(t, &icmpV6Packet) +} + +func assertTTLExceedPacket(t *testing.T, pk *ICMP) { + encoder := NewEncoder() + rawPacket, err := encoder.Encode(pk) + require.NoError(t, err) + + minMTU := ipv4MinMTU + headerLen := ipv4MinHeaderLen + routerIP := netip.MustParseAddr("172.16.0.3") + if pk.Dst.Is6() { + minMTU = ipv6MinMTU + headerLen = ipv6HeaderLen + routerIP = netip.MustParseAddr("fd51:2391:697:f4ee::3") + } + + ttlExceedPacket := NewICMPTTLExceedPacket(pk.IP, rawPacket, routerIP) + require.Equal(t, routerIP, ttlExceedPacket.Src) + require.Equal(t, pk.Src, ttlExceedPacket.Dst) + require.Equal(t, pk.Protocol, ttlExceedPacket.Protocol) + require.Equal(t, DefaultTTL, ttlExceedPacket.TTL) + + timeExceed, ok := ttlExceedPacket.Body.(*icmp.TimeExceeded) + require.True(t, ok) + if len(rawPacket.Data) > minMTU { + require.True(t, bytes.Equal(timeExceed.Data, rawPacket.Data[:minMTU-headerLen-icmpHeaderLen])) + } else { + require.True(t, bytes.Equal(timeExceed.Data, rawPacket.Data)) + } + + rawTTLExceedPacket, err := encoder.Encode(ttlExceedPacket) + require.NoError(t, err) + if len(rawPacket.Data) > minMTU { + require.Len(t, rawTTLExceedPacket.Data, minMTU) + } else { + require.Len(t, rawTTLExceedPacket.Data, headerLen+icmpHeaderLen+len(rawPacket.Data)) + require.True(t, bytes.Equal(rawPacket.Data, rawTTLExceedPacket.Data[headerLen+icmpHeaderLen:])) + } +} diff --git a/packet/router.go b/packet/router.go new file mode 100644 index 00000000..83acc393 --- /dev/null +++ b/packet/router.go @@ -0,0 +1,126 @@ +package packet + +import ( + "context" + "net" + "net/netip" + + "github.com/rs/zerolog" +) + +var ( + // Source IP in documentation range to return ICMP error messages if we can't determine the IP of this machine + icmpv4ErrFallbackSrc = netip.MustParseAddr("192.0.2.30") + icmpv6ErrFallbackSrc = netip.MustParseAddr("2001:db8::") +) + +// ICMPRouter sends ICMP messages and listens for their responses +type ICMPRouter interface { + // Serve starts listening for responses to the requests until context is done + Serve(ctx context.Context) error + // Request sends an ICMP message + Request(pk *ICMP, responder FunnelUniPipe) error +} + +// Upstream of raw packets +type Upstream interface { + // ReceivePacket waits for the next raw packet from upstream + ReceivePacket(ctx context.Context) (RawPacket, error) +} + +type Router struct { + upstream Upstream + returnPipe FunnelUniPipe + icmpProxy ICMPRouter + ipv4Src netip.Addr + ipv6Src netip.Addr + logger *zerolog.Logger +} + +func NewRouter(upstream Upstream, returnPipe FunnelUniPipe, icmpProxy ICMPRouter, logger *zerolog.Logger) *Router { + ipv4Src, err := findLocalAddr(net.ParseIP("1.1.1.1"), 53) + if err != nil { + logger.Warn().Err(err).Msgf("Failed to determine the IPv4 for this machine. It will use %s as source IP for error messages such as ICMP TTL exceed", icmpv4ErrFallbackSrc) + ipv4Src = icmpv4ErrFallbackSrc + } + ipv6Src, err := findLocalAddr(net.ParseIP("2606:4700:4700::1111"), 53) + if err != nil { + logger.Warn().Err(err).Msgf("Failed to determine the IPv6 for this machine. It will use %s as source IP for error messages such as ICMP TTL exceed", icmpv6ErrFallbackSrc) + ipv6Src = icmpv6ErrFallbackSrc + } + return &Router{ + upstream: upstream, + returnPipe: returnPipe, + icmpProxy: icmpProxy, + ipv4Src: ipv4Src, + ipv6Src: ipv6Src, + logger: logger, + } +} + +func (r *Router) Serve(ctx context.Context) error { + icmpDecoder := NewICMPDecoder() + encoder := NewEncoder() + for { + rawPacket, err := r.upstream.ReceivePacket(ctx) + if err != nil { + return err + } + icmpPacket, err := icmpDecoder.Decode(rawPacket) + if err != nil { + r.logger.Err(err).Msg("Failed to decode ICMP packet from quic datagram") + continue + } + + if icmpPacket.TTL <= 1 { + if err := r.sendTTLExceedMsg(icmpPacket, rawPacket, encoder); err != nil { + r.logger.Err(err).Msg("Failed to return ICMP TTL exceed error") + } + continue + } + icmpPacket.TTL-- + + if err := r.icmpProxy.Request(icmpPacket, r.returnPipe); err != nil { + r.logger.Err(err). + Str("src", icmpPacket.Src.String()). + Str("dst", icmpPacket.Dst.String()). + Interface("type", icmpPacket.Type). + Msg("Failed to send ICMP packet") + continue + } + } +} + +func (r *Router) sendTTLExceedMsg(pk *ICMP, rawPacket RawPacket, encoder *Encoder) error { + var srcIP netip.Addr + if pk.Dst.Is4() { + srcIP = r.ipv4Src + } else { + srcIP = r.ipv6Src + } + ttlExceedPacket := NewICMPTTLExceedPacket(pk.IP, rawPacket, srcIP) + + encodedTTLExceed, err := encoder.Encode(ttlExceedPacket) + if err != nil { + return err + } + return r.returnPipe.SendPacket(pk.Src, encodedTTLExceed) +} + +// findLocalAddr tries to dial UDP and returns the local address picked by the OS +func findLocalAddr(dst net.IP, port int) (netip.Addr, error) { + udpConn, err := net.DialUDP("udp", nil, &net.UDPAddr{ + IP: dst, + Port: port, + }) + if err != nil { + return netip.Addr{}, err + } + defer udpConn.Close() + localAddrPort, err := netip.ParseAddrPort(udpConn.LocalAddr().String()) + if err != nil { + return netip.Addr{}, err + } + localAddr := localAddrPort.Addr() + return localAddr, nil +} diff --git a/packet/router_test.go b/packet/router_test.go new file mode 100644 index 00000000..14e49986 --- /dev/null +++ b/packet/router_test.go @@ -0,0 +1,125 @@ +package packet + +import ( + "bytes" + "context" + "fmt" + "net/netip" + "testing" + + "github.com/google/gopacket/layers" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var ( + noopLogger = zerolog.Nop() +) + +func TestRouterReturnTTLExceed(t *testing.T) { + upstream := &mockUpstream{ + source: make(chan RawPacket), + } + returnPipe := &mockFunnelUniPipe{ + uniPipe: make(chan RawPacket), + } + router := NewRouter(upstream, returnPipe, &mockICMPRouter{}, &noopLogger) + ctx, cancel := context.WithCancel(context.Background()) + routerStopped := make(chan struct{}) + go func() { + router.Serve(ctx) + close(routerStopped) + }() + + pk := ICMP{ + IP: &IP{ + Src: netip.MustParseAddr("192.168.1.1"), + Dst: netip.MustParseAddr("10.0.0.1"), + Protocol: layers.IPProtocolICMPv4, + TTL: 1, + }, + Message: &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: 12481, + Seq: 8036, + Data: []byte("TTL exceed"), + }, + }, + } + assertTTLExceed(t, &pk, router.ipv4Src, upstream, returnPipe) + pk = ICMP{ + IP: &IP{ + Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"), + Dst: netip.MustParseAddr("fd51:2391:697:f4ee::2"), + Protocol: layers.IPProtocolICMPv6, + TTL: 1, + }, + Message: &icmp.Message{ + Type: ipv6.ICMPTypeEchoRequest, + Code: 0, + Body: &icmp.Echo{ + ID: 42583, + Seq: 7039, + Data: []byte("TTL exceed"), + }, + }, + } + assertTTLExceed(t, &pk, router.ipv6Src, upstream, returnPipe) + + cancel() + <-routerStopped +} + +func assertTTLExceed(t *testing.T, originalPacket *ICMP, expectedSrc netip.Addr, upstream *mockUpstream, returnPipe *mockFunnelUniPipe) { + encoder := NewEncoder() + rawPacket, err := encoder.Encode(originalPacket) + require.NoError(t, err) + + upstream.source <- rawPacket + resp := <-returnPipe.uniPipe + decoder := NewICMPDecoder() + decoded, err := decoder.Decode(resp) + require.NoError(t, err) + + require.Equal(t, expectedSrc, decoded.Src) + require.Equal(t, originalPacket.Src, decoded.Dst) + require.Equal(t, originalPacket.Protocol, decoded.Protocol) + require.Equal(t, DefaultTTL, decoded.TTL) + if originalPacket.Dst.Is4() { + require.Equal(t, ipv4.ICMPTypeTimeExceeded, decoded.Type) + } else { + require.Equal(t, ipv6.ICMPTypeTimeExceeded, decoded.Type) + } + require.Equal(t, 0, decoded.Code) + timeExceed, ok := decoded.Body.(*icmp.TimeExceeded) + require.True(t, ok) + require.True(t, bytes.Equal(rawPacket.Data, timeExceed.Data)) +} + +type mockUpstream struct { + source chan RawPacket +} + +func (ms *mockUpstream) ReceivePacket(ctx context.Context) (RawPacket, error) { + select { + case <-ctx.Done(): + return RawPacket{}, ctx.Err() + case pk := <-ms.source: + return pk, nil + } +} + +type mockICMPRouter struct{} + +func (mir mockICMPRouter) Serve(ctx context.Context) error { + return fmt.Errorf("Serve not implemented by mockICMPRouter") +} + +func (mir mockICMPRouter) Request(pk *ICMP, responder FunnelUniPipe) error { + return fmt.Errorf("Request not implemented by mockICMPRouter") +} diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index fc80ed1c..ebd0e449 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -117,11 +117,11 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato } if useDatagramV2(config) { // TODO: TUN-6701: Decouple upgrade of datagram v2 and using icmp proxy - icmpProxy, err := ingress.NewICMPProxy(config.Log) + icmpRouter, err := ingress.NewICMPRouter(config.Log) if err != nil { - log.Logger().Warn().Err(err).Msg("Failed to create icmp proxy, will continue to use datagram v1") + log.Logger().Warn().Err(err).Msg("Failed to create icmp router, will continue to use datagram v1") } else { - edgeTunnelServer.icmpProxy = icmpProxy + edgeTunnelServer.icmpRouter = icmpRouter } } @@ -152,10 +152,10 @@ func (s *Supervisor) Run( ctx context.Context, connectedSignal *signal.Signal, ) error { - if s.edgeTunnelServer.icmpProxy != nil { + if s.edgeTunnelServer.icmpRouter != nil { go func() { - if err := s.edgeTunnelServer.icmpProxy.Serve(ctx); err != nil { - s.log.Logger().Err(err).Msg("icmp proxy terminated") + if err := s.edgeTunnelServer.icmpRouter.Serve(ctx); err != nil { + s.log.Logger().Err(err).Msg("icmp router terminated") } }() } diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 34f6cecc..22fb55e3 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -20,8 +20,8 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/h2mux" - "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/orchestration" + "github.com/cloudflare/cloudflared/packet" quicpogs "github.com/cloudflare/cloudflared/quic" "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" @@ -200,7 +200,7 @@ type EdgeTunnelServer struct { reconnectCh chan ReconnectSignal gracefulShutdownC <-chan struct{} tracker *tunnelstate.ConnTracker - icmpProxy ingress.ICMPProxy + icmpRouter packet.ICMPRouter connAwareLogger *ConnAwareLogger } @@ -661,7 +661,7 @@ func (e *EdgeTunnelServer) serveQUIC( connOptions, controlStreamHandler, connLogger.Logger(), - e.icmpProxy) + e.icmpRouter) if err != nil { if e.config.NeedPQ { handlePQTunnelError(err, e.config)