diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index ef9a7e28..9f031641 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -5,6 +5,8 @@ import ( "fmt" "io/ioutil" mathRand "math/rand" + "net" + "net/netip" "os" "path/filepath" "strings" @@ -20,6 +22,7 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/edgediscovery/allregions" + "github.com/cloudflare/cloudflared/packet" "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/connection" @@ -384,6 +387,12 @@ func prepareTunnelConfig( NeedPQ: needPQ, PQKexIdx: pqKexIdx, } + packetConfig, err := newPacketConfig(c, log) + if err != nil { + log.Warn().Err(err).Msg("ICMP proxy feature is disabled") + } else { + tunnelConfig.PacketConfig = packetConfig + } orchestratorConfig := &orchestration.Config{ Ingress: &ingressRules, WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting), @@ -453,3 +462,147 @@ func parseConfigIPVersion(version string) (v allregions.ConfigIPVersion, err err } return } + +func newPacketConfig(c *cli.Context, logger *zerolog.Logger) (*packet.GlobalRouterConfig, error) { + ipv4Src, err := determineICMPv4Src(c.String("icmpv4-src"), logger) + if err != nil { + return nil, errors.Wrap(err, "failed to determine IPv4 source address for ICMP proxy") + } + logger.Info().Msgf("ICMP proxy will use %s as source for IPv4", ipv4Src) + + ipv6Src, zone, err := determineICMPv6Src(c.String("icmpv6-src"), logger, ipv4Src) + if err != nil { + return nil, errors.Wrap(err, "failed to determine IPv6 source address for ICMP proxy") + } + if zone != "" { + logger.Info().Msgf("ICMP proxy will use %s in zone %s as source for IPv6", ipv6Src, zone) + } else { + logger.Info().Msgf("ICMP proxy will use %s as source for IPv6", ipv6Src) + } + + icmpRouter, err := ingress.NewICMPRouter(ipv4Src, ipv6Src, zone, logger) + if err != nil { + return nil, err + } + return &packet.GlobalRouterConfig{ + ICMPRouter: icmpRouter, + IPv4Src: ipv4Src, + IPv6Src: ipv6Src, + Zone: zone, + }, nil +} + +func determineICMPv4Src(userDefinedSrc string, logger *zerolog.Logger) (netip.Addr, error) { + if userDefinedSrc != "" { + addr, err := netip.ParseAddr(userDefinedSrc) + if err != nil { + return netip.Addr{}, err + } + if addr.Is4() { + return addr, nil + } + return netip.Addr{}, fmt.Errorf("expect IPv4, but %s is IPv6", userDefinedSrc) + } + + addr, err := findLocalAddr(net.ParseIP("192.168.0.1"), 53) + if err != nil { + addr = netip.IPv4Unspecified() + logger.Debug().Err(err).Msgf("Failed to determine the IPv4 for this machine. It will use %s to send/listen for ICMPv4 echo", addr) + } + return addr, nil +} + +type interfaceIP struct { + name string + ip net.IP +} + +func determineICMPv6Src(userDefinedSrc string, logger *zerolog.Logger, ipv4Src netip.Addr) (addr netip.Addr, zone string, err error) { + if userDefinedSrc != "" { + userDefinedIP, zone, _ := strings.Cut(userDefinedSrc, "%") + addr, err := netip.ParseAddr(userDefinedIP) + if err != nil { + return netip.Addr{}, "", err + } + if addr.Is6() { + return addr, zone, nil + } + return netip.Addr{}, "", fmt.Errorf("expect IPv6, but %s is IPv4", userDefinedSrc) + } + + // Loop through all the interfaces, the preference is + // 1. The interface where ipv4Src is in + // 2. Interface with IPv6 address + // 3. Unspecified interface + + interfaces, err := net.Interfaces() + if err != nil { + return netip.IPv6Unspecified(), "", nil + } + + interfacesWithIPv6 := make([]interfaceIP, 0) + for _, interf := range interfaces { + interfaceAddrs, err := interf.Addrs() + if err != nil { + continue + } + + foundIPv4SrcInterface := false + for _, interfaceAddr := range interfaceAddrs { + if ipnet, ok := interfaceAddr.(*net.IPNet); ok { + ip := ipnet.IP + if ip.Equal(ipv4Src.AsSlice()) { + foundIPv4SrcInterface = true + } + if ip.To4() == nil { + interfacesWithIPv6 = append(interfacesWithIPv6, interfaceIP{ + name: interf.Name, + ip: ip, + }) + } + } + } + // Found the interface of ipv4Src. Loop through the addresses to see if there is an IPv6 + if foundIPv4SrcInterface { + for _, interfaceAddr := range interfaceAddrs { + if ipnet, ok := interfaceAddr.(*net.IPNet); ok { + ip := ipnet.IP + if ip.To4() == nil { + addr, err := netip.ParseAddr(ip.String()) + if err == nil { + return addr, interf.Name, nil + } + } + } + } + } + } + + for _, interf := range interfacesWithIPv6 { + addr, err := netip.ParseAddr(interf.ip.String()) + if err == nil { + return addr, interf.name, nil + } + } + logger.Debug().Err(err).Msgf("Failed to determine the IPv6 for this machine. It will use %s to send/listen for ICMPv6 echo", netip.IPv6Unspecified()) + + return netip.IPv6Unspecified(), "", nil +} + +// FindLocalAddr tries to dial UDP and returns the local address picked by the OS +func findLocalAddr(dst net.IP, port int) (netip.Addr, error) { + udpConn, err := net.DialUDP("udp", nil, &net.UDPAddr{ + IP: dst, + Port: port, + }) + if err != nil { + return netip.Addr{}, err + } + defer udpConn.Close() + localAddrPort, err := netip.ParseAddrPort(udpConn.LocalAddr().String()) + if err != nil { + return netip.Addr{}, err + } + localAddr := localAddrPort.Addr() + return localAddr, nil +} diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 71435465..4494e761 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -176,6 +176,16 @@ var ( Usage: "Base64 encoded secret to set for the tunnel. The decoded secret must be at least 32 bytes long. If not specified, a random 32-byte secret will be generated.", EnvVars: []string{"TUNNEL_CREATE_SECRET"}, } + icmpv4SrcFlag = &cli.StringFlag{ + Name: "icmpv4-src", + Usage: "Source address to send/receive ICMPv4 messages. If not provided cloudflared will dial a local address to determine the source IP or fallback to 0.0.0.0.", + EnvVars: []string{"TUNNEL_ICMPV4_SRC"}, + } + icmpv6SrcFlag = &cli.StringFlag{ + Name: "icmpv6-src", + Usage: "Source address and the interface name to send/receive ICMPv6 messages. If not provided cloudflared will dial a local address to determine the source IP or fallback to ::.", + EnvVars: []string{"TUNNEL_ICMPV6_SRC"}, + } ) func buildCreateCommand() *cli.Command { @@ -613,6 +623,8 @@ func buildRunCommand() *cli.Command { selectProtocolFlag, featuresFlag, tunnelTokenFlag, + icmpv4SrcFlag, + icmpv6SrcFlag, } flags = append(flags, configureProxyFlags(false)...) return &cli.Command{ diff --git a/component-tests/constants.py b/component-tests/constants.py index bb4b9d22..e46bac31 100644 --- a/component-tests/constants.py +++ b/component-tests/constants.py @@ -1,6 +1,7 @@ METRICS_PORT = 51000 MAX_RETRIES = 5 BACKOFF_SECS = 7 +MAX_LOG_LINES = 50 PROXY_DNS_PORT = 9053 diff --git a/component-tests/test_logging.py b/component-tests/test_logging.py index 51282ded..e9b0bfbd 100644 --- a/component-tests/test_logging.py +++ b/component-tests/test_logging.py @@ -2,6 +2,7 @@ import json import os +from constants import MAX_LOG_LINES from util import start_cloudflared, wait_tunnel_ready, send_requests # Rolling logger rotate log files after 1 MB @@ -11,14 +12,24 @@ expect_message = "Starting Hello" def assert_log_to_terminal(cloudflared): - stderr = cloudflared.stderr.read(1500) - assert expect_message.encode() in stderr, f"{stderr} doesn't contain {expect_message}" + for _ in range(0, MAX_LOG_LINES): + line = cloudflared.stderr.readline() + if not line: + break + if expect_message.encode() in line: + return + raise Exception(f"terminal log doesn't contain {expect_message}") def assert_log_in_file(file): with open(file, "r") as f: - log = f.read(2000) - assert expect_message in log, f"{log} doesn't contain {expect_message}" + for _ in range(0, MAX_LOG_LINES): + line = f.readline() + if not line: + break + if expect_message in line: + return + raise Exception(f"log file doesn't contain {expect_message}") def assert_json_log(file): diff --git a/connection/quic.go b/connection/quic.go index 86a5228c..95829a6e 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -65,7 +65,7 @@ func NewQUICConnection( connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler ControlStreamHandler, logger *zerolog.Logger, - icmpRouter packet.ICMPRouter, + packetRouterConfig *packet.GlobalRouterConfig, ) (*QUICConnection, error) { session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig) if err != nil { @@ -73,19 +73,14 @@ func NewQUICConnection( } sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) - var ( - datagramMuxer quicpogs.BaseDatagramMuxer - pr *packet.Router - ) - if icmpRouter != nil { - datagramMuxerV2 := quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan) - pr = packet.NewRouter(datagramMuxerV2, &returnPipe{muxer: datagramMuxerV2}, icmpRouter, logger) - datagramMuxer = datagramMuxerV2 - } else { - datagramMuxer = quicpogs.NewDatagramMuxer(session, logger, sessionDemuxChan) - } + datagramMuxer := quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan) sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) + var pr *packet.Router + if packetRouterConfig != nil { + pr = packet.NewRouter(packetRouterConfig, datagramMuxer, &returnPipe{muxer: datagramMuxer}, logger) + } + return &QUICConnection{ session: session, orchestrator: orchestrator, diff --git a/connection/quic_test.go b/connection/quic_test.go index 8904eeeb..c2990878 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -583,8 +583,12 @@ func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic. close(sessionDone) }() - // Send a message to the quic session on edge side, it should be deumx to this datagram session - muxedPayload := append(payload, sessionID[:]...) + // Send a message to the quic session on edge side, it should be deumx to this datagram v2 session + muxedPayload, err := quicpogs.SuffixSessionID(sessionID, payload) + require.NoError(t, err) + muxedPayload, err = quicpogs.SuffixType(muxedPayload, quicpogs.DatagramTypeUDP) + require.NoError(t, err) + err = edgeQUICSession.SendMessage(muxedPayload) require.NoError(t, err) diff --git a/ingress/icmp_darwin.go b/ingress/icmp_darwin.go index 352b5c54..2a57b415 100644 --- a/ingress/icmp_darwin.go +++ b/ingress/icmp_darwin.go @@ -113,11 +113,12 @@ func (snf echoFunnelID) String() string { return strconv.FormatUint(uint64(snf), 10) } -func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { - conn, err := newICMPConn(listenIP) +func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { + conn, err := newICMPConn(listenIP, zone) if err != nil { return nil, err } + logger.Info().Msgf("Created ICMP proxy listening on %s", conn.LocalAddr()) return &icmpProxy{ srcFunnelTracker: packet.NewFunnelTracker(), echoIDTracker: newEchoIDTracker(), diff --git a/ingress/icmp_generic.go b/ingress/icmp_generic.go index c685a2f4..e1c66e81 100644 --- a/ingress/icmp_generic.go +++ b/ingress/icmp_generic.go @@ -26,6 +26,6 @@ func (ip *icmpProxy) Serve(ctx context.Context) error { return errICMPProxyNotImplemented } -func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { +func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { return nil, errICMPProxyNotImplemented } diff --git a/ingress/icmp_linux.go b/ingress/icmp_linux.go index 5f122e49..bc7f968d 100644 --- a/ingress/icmp_linux.go +++ b/ingress/icmp_linux.go @@ -24,25 +24,27 @@ import ( type icmpProxy struct { srcFunnelTracker *packet.FunnelTracker listenIP netip.Addr + ipv6Zone string logger *zerolog.Logger idleTimeout time.Duration } -func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { - if err := testPermission(listenIP); err != nil { +func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { + if err := testPermission(listenIP, zone); err != nil { return nil, err } return &icmpProxy{ srcFunnelTracker: packet.NewFunnelTracker(), listenIP: listenIP, + ipv6Zone: zone, logger: logger, idleTimeout: idleTimeout, }, nil } -func testPermission(listenIP netip.Addr) error { +func testPermission(listenIP netip.Addr, zone string) error { // Opens a non-privileged ICMP socket. On Linux the group ID of the process needs to be in ping_group_range - conn, err := newICMPConn(listenIP) + conn, err := newICMPConn(listenIP, zone) if err != nil { // TODO: TUN-6715 check if cloudflared is in ping_group_range if the check failed. If not log instruction to // change the group ID @@ -63,10 +65,11 @@ func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FunnelUniPipe) er } newConnChan := make(chan *icmp.PacketConn, 1) newFunnelFunc := func() (packet.Funnel, error) { - conn, err := newICMPConn(ip.listenIP) + conn, err := newICMPConn(ip.listenIP, ip.ipv6Zone) if err != nil { return nil, errors.Wrap(err, "failed to open ICMP socket") } + ip.logger.Debug().Msgf("Opened ICMP socket listen on %s", conn.LocalAddr()) newConnChan <- conn localUDPAddr, ok := conn.LocalAddr().(*net.UDPAddr) if !ok { diff --git a/ingress/icmp_posix.go b/ingress/icmp_posix.go index 5c3c62b3..0badb685 100644 --- a/ingress/icmp_posix.go +++ b/ingress/icmp_posix.go @@ -16,12 +16,15 @@ import ( ) // Opens a non-privileged ICMP socket on Linux and Darwin -func newICMPConn(listenIP netip.Addr) (*icmp.PacketConn, error) { - network := "udp6" +func newICMPConn(listenIP netip.Addr, zone string) (*icmp.PacketConn, error) { if listenIP.Is4() { - network = "udp4" + return icmp.ListenPacket("udp4", listenIP.String()) } - return icmp.ListenPacket(network, listenIP.String()) + listenAddr := listenIP.String() + if zone != "" { + listenAddr = listenAddr + "%" + zone + } + return icmp.ListenPacket("udp6", listenAddr) } func netipAddr(addr net.Addr) (netip.Addr, bool) { diff --git a/ingress/icmp_posix_test.go b/ingress/icmp_posix_test.go index 397aa524..bbadd196 100644 --- a/ingress/icmp_posix_test.go +++ b/ingress/icmp_posix_test.go @@ -24,7 +24,7 @@ func TestFunnelIdleTimeout(t *testing.T) { startSeq = 8129 ) logger := zerolog.New(os.Stderr) - proxy, err := newICMPProxy(localhostIP, &logger, idleTimeout) + proxy, err := newICMPProxy(localhostIP, "", &logger, idleTimeout) require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) diff --git a/ingress/icmp_windows.go b/ingress/icmp_windows.go index 8ad313b5..bc98bcbe 100644 --- a/ingress/icmp_windows.go +++ b/ingress/icmp_windows.go @@ -224,7 +224,7 @@ type icmpProxy struct { encoderPool sync.Pool } -func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { +func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { var ( srcSocketAddr *sockAddrIn6 handle uintptr diff --git a/ingress/icmp_windows_test.go b/ingress/icmp_windows_test.go index 05f7e620..eefc654f 100644 --- a/ingress/icmp_windows_test.go +++ b/ingress/icmp_windows_test.go @@ -132,7 +132,7 @@ func TestSendEchoErrors(t *testing.T) { } func testSendEchoErrors(t *testing.T, listenIP netip.Addr) { - proxy, err := newICMPProxy(listenIP, &noopLogger, time.Second) + proxy, err := newICMPProxy(listenIP, "", &noopLogger, time.Second) require.NoError(t, err) echo := icmp.Echo{ diff --git a/ingress/origin_icmp_proxy.go b/ingress/origin_icmp_proxy.go index 4bd0c0fe..4b6826b4 100644 --- a/ingress/origin_icmp_proxy.go +++ b/ingress/origin_icmp_proxy.go @@ -33,10 +33,9 @@ type icmpRouter struct { // NewICMPRouter doesn't return an error if either ipv4 proxy or ipv6 proxy can be created. The machine might only // support one of them -func NewICMPRouter(logger *zerolog.Logger) (*icmpRouter, error) { - // TODO: TUN-6741: don't bind to all interface - ipv4Proxy, ipv4Err := newICMPProxy(netip.IPv4Unspecified(), logger, funnelIdleTimeout) - ipv6Proxy, ipv6Err := newICMPProxy(netip.IPv6Unspecified(), logger, funnelIdleTimeout) +func NewICMPRouter(ipv4Addr, ipv6Addr netip.Addr, ipv6Zone string, logger *zerolog.Logger) (*icmpRouter, error) { + ipv4Proxy, ipv4Err := newICMPProxy(ipv4Addr, "", logger, funnelIdleTimeout) + ipv6Proxy, ipv6Err := newICMPProxy(ipv6Addr, ipv6Zone, logger, funnelIdleTimeout) if ipv4Err != nil && ipv6Err != nil { return nil, fmt.Errorf("cannot create ICMPv4 proxy: %v nor ICMPv6 proxy: %v", ipv4Err, ipv6Err) } diff --git a/ingress/origin_icmp_proxy_test.go b/ingress/origin_icmp_proxy_test.go index e6f37a29..d080f46a 100644 --- a/ingress/origin_icmp_proxy_test.go +++ b/ingress/origin_icmp_proxy_test.go @@ -42,7 +42,7 @@ func testICMPRouterEcho(t *testing.T, sendIPv4 bool) { endSeq = 20 ) - router, err := NewICMPRouter(&noopLogger) + router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger) require.NoError(t, err) proxyDone := make(chan struct{}) @@ -106,7 +106,7 @@ func TestConcurrentRequestsToSameDst(t *testing.T) { endSeq = 5 ) - router, err := NewICMPRouter(&noopLogger) + router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger) require.NoError(t, err) proxyDone := make(chan struct{}) @@ -238,7 +238,7 @@ func TestICMPRouterRejectNotEcho(t *testing.T) { } func testICMPRouterRejectNotEcho(t *testing.T, srcDstIP netip.Addr, msgs []icmp.Message) { - router, err := NewICMPRouter(&noopLogger) + router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger) require.NoError(t, err) responder := echoFlowResponder{ diff --git a/packet/router.go b/packet/router.go index 83acc393..8e7e399d 100644 --- a/packet/router.go +++ b/packet/router.go @@ -2,18 +2,11 @@ package packet import ( "context" - "net" "net/netip" "github.com/rs/zerolog" ) -var ( - // Source IP in documentation range to return ICMP error messages if we can't determine the IP of this machine - icmpv4ErrFallbackSrc = netip.MustParseAddr("192.0.2.30") - icmpv6ErrFallbackSrc = netip.MustParseAddr("2001:db8::") -) - // ICMPRouter sends ICMP messages and listens for their responses type ICMPRouter interface { // Serve starts listening for responses to the requests until context is done @@ -28,32 +21,31 @@ type Upstream interface { ReceivePacket(ctx context.Context) (RawPacket, error) } +// Router routes packets between Upstream and ICMPRouter. Currently it rejects all other type of ICMP packets type Router struct { upstream Upstream returnPipe FunnelUniPipe - icmpProxy ICMPRouter + icmpRouter ICMPRouter ipv4Src netip.Addr ipv6Src netip.Addr logger *zerolog.Logger } -func NewRouter(upstream Upstream, returnPipe FunnelUniPipe, icmpProxy ICMPRouter, logger *zerolog.Logger) *Router { - ipv4Src, err := findLocalAddr(net.ParseIP("1.1.1.1"), 53) - if err != nil { - logger.Warn().Err(err).Msgf("Failed to determine the IPv4 for this machine. It will use %s as source IP for error messages such as ICMP TTL exceed", icmpv4ErrFallbackSrc) - ipv4Src = icmpv4ErrFallbackSrc - } - ipv6Src, err := findLocalAddr(net.ParseIP("2606:4700:4700::1111"), 53) - if err != nil { - logger.Warn().Err(err).Msgf("Failed to determine the IPv6 for this machine. It will use %s as source IP for error messages such as ICMP TTL exceed", icmpv6ErrFallbackSrc) - ipv6Src = icmpv6ErrFallbackSrc - } +// GlobalRouterConfig is the configuration shared by all instance of Router. +type GlobalRouterConfig struct { + ICMPRouter ICMPRouter + IPv4Src netip.Addr + IPv6Src netip.Addr + Zone string +} + +func NewRouter(globalConfig *GlobalRouterConfig, upstream Upstream, returnPipe FunnelUniPipe, logger *zerolog.Logger) *Router { return &Router{ upstream: upstream, returnPipe: returnPipe, - icmpProxy: icmpProxy, - ipv4Src: ipv4Src, - ipv6Src: ipv6Src, + icmpRouter: globalConfig.ICMPRouter, + ipv4Src: globalConfig.IPv4Src, + ipv6Src: globalConfig.IPv6Src, logger: logger, } } @@ -80,7 +72,7 @@ func (r *Router) Serve(ctx context.Context) error { } icmpPacket.TTL-- - if err := r.icmpProxy.Request(icmpPacket, r.returnPipe); err != nil { + if err := r.icmpRouter.Request(icmpPacket, r.returnPipe); err != nil { r.logger.Err(err). Str("src", icmpPacket.Src.String()). Str("dst", icmpPacket.Dst.String()). @@ -106,21 +98,3 @@ func (r *Router) sendTTLExceedMsg(pk *ICMP, rawPacket RawPacket, encoder *Encode } return r.returnPipe.SendPacket(pk.Src, encodedTTLExceed) } - -// findLocalAddr tries to dial UDP and returns the local address picked by the OS -func findLocalAddr(dst net.IP, port int) (netip.Addr, error) { - udpConn, err := net.DialUDP("udp", nil, &net.UDPAddr{ - IP: dst, - Port: port, - }) - if err != nil { - return netip.Addr{}, err - } - defer udpConn.Close() - localAddrPort, err := netip.ParseAddrPort(udpConn.LocalAddr().String()) - if err != nil { - return netip.Addr{}, err - } - localAddr := localAddrPort.Addr() - return localAddr, nil -} diff --git a/packet/router_test.go b/packet/router_test.go index 48afdc77..8b009c1c 100644 --- a/packet/router_test.go +++ b/packet/router_test.go @@ -26,7 +26,12 @@ func TestRouterReturnTTLExceed(t *testing.T) { returnPipe := &mockFunnelUniPipe{ uniPipe: make(chan RawPacket), } - router := NewRouter(upstream, returnPipe, &mockICMPRouter{}, &noopLogger) + packetConfig := &GlobalRouterConfig{ + ICMPRouter: &mockICMPRouter{}, + IPv4Src: netip.MustParseAddr("172.16.0.1"), + IPv6Src: netip.MustParseAddr("fd51:2391:523:f4ee::1"), + } + router := NewRouter(packetConfig, upstream, returnPipe, &noopLogger) ctx, cancel := context.WithCancel(context.Background()) routerStopped := make(chan struct{}) go func() { diff --git a/quic/datagram.go b/quic/datagram.go index 334754d4..23408cfe 100644 --- a/quic/datagram.go +++ b/quic/datagram.go @@ -49,7 +49,7 @@ func (dm *DatagramMuxer) SendToSession(session *packet.Session) error { packetTooBigDropped.Inc() return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(session.Payload), dm.mtu()) } - payloadWithMetadata, err := suffixSessionID(session.ID, session.Payload) + payloadWithMetadata, err := SuffixSessionID(session.ID, session.Payload) if err != nil { return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped") } @@ -112,7 +112,7 @@ func extractSessionID(b []byte) (uuid.UUID, []byte, error) { // SuffixSessionID appends the session ID at the end of the payload. Suffix is more performant than prefix because // the payload slice might already have enough capacity to append the session ID at the end -func suffixSessionID(sessionID uuid.UUID, b []byte) ([]byte, error) { +func SuffixSessionID(sessionID uuid.UUID, b []byte) ([]byte, error) { if len(b)+len(sessionID) > MaxDatagramFrameSize { return nil, fmt.Errorf("datagram size exceed %d", MaxDatagramFrameSize) } diff --git a/quic/datagram_test.go b/quic/datagram_test.go index bd55f425..69bb0b71 100644 --- a/quic/datagram_test.go +++ b/quic/datagram_test.go @@ -31,7 +31,7 @@ var ( func TestSuffixThenRemoveSessionID(t *testing.T) { msg := []byte(t.Name()) - msgWithID, err := suffixSessionID(testSessionID, msg) + msgWithID, err := SuffixSessionID(testSessionID, msg) require.NoError(t, err) require.Len(t, msgWithID, len(msg)+sessionIDLen) @@ -50,11 +50,11 @@ func TestRemoveSessionIDError(t *testing.T) { func TestSuffixSessionIDError(t *testing.T) { msg := make([]byte, MaxDatagramFrameSize-sessionIDLen) - _, err := suffixSessionID(testSessionID, msg) + _, err := SuffixSessionID(testSessionID, msg) require.NoError(t, err) msg = make([]byte, MaxDatagramFrameSize-sessionIDLen+1) - _, err = suffixSessionID(testSessionID, msg) + _, err = SuffixSessionID(testSessionID, msg) require.Error(t, err) } diff --git a/quic/datagramv2.go b/quic/datagramv2.go index 3f1c8f0e..373d8731 100644 --- a/quic/datagramv2.go +++ b/quic/datagramv2.go @@ -11,11 +11,11 @@ import ( "github.com/cloudflare/cloudflared/packet" ) -type datagramV2Type byte +type DatagramV2Type byte const ( - udp datagramV2Type = iota - ip + DatagramTypeUDP DatagramV2Type = iota + DatagramTypeIP ) const ( @@ -24,7 +24,7 @@ const ( packetChanCapacity = 128 ) -func suffixType(b []byte, datagramType datagramV2Type) ([]byte, error) { +func SuffixType(b []byte, datagramType DatagramV2Type) ([]byte, error) { if len(b)+typeIDLen > MaxDatagramFrameSize { return nil, fmt.Errorf("datagram size %d exceeds max frame size %d", len(b), MaxDatagramFrameSize) } @@ -65,11 +65,11 @@ func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error { packetTooBigDropped.Inc() return fmt.Errorf("origin UDP payload has %d bytes, which exceeds transport MTU %d", len(session.Payload), dm.mtu()) } - msgWithID, err := suffixSessionID(session.ID, session.Payload) + msgWithID, err := SuffixSessionID(session.ID, session.Payload) if err != nil { return errors.Wrap(err, "Failed to suffix session ID to datagram, it will be dropped") } - msgWithIDAndType, err := suffixType(msgWithID, udp) + msgWithIDAndType, err := SuffixType(msgWithID, DatagramTypeUDP) if err != nil { return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped") } @@ -82,7 +82,7 @@ func (dm *DatagramMuxerV2) SendToSession(session *packet.Session) error { // SendPacket suffix the datagram type to the packet. The other end of the QUIC connection can demultiplex by parsing // the payload as IP and look at the source and destination. func (dm *DatagramMuxerV2) SendPacket(pk packet.RawPacket) error { - payloadWithVersion, err := suffixType(pk.Data, ip) + payloadWithVersion, err := SuffixType(pk.Data, DatagramTypeIP) if err != nil { return errors.Wrap(err, "Failed to suffix datagram type, it will be dropped") } @@ -121,12 +121,12 @@ func (dm *DatagramMuxerV2) demux(ctx context.Context, msgWithType []byte) error if len(msgWithType) < typeIDLen { return fmt.Errorf("QUIC datagram should have at least %d byte", typeIDLen) } - msgType := datagramV2Type(msgWithType[len(msgWithType)-typeIDLen]) + msgType := DatagramV2Type(msgWithType[len(msgWithType)-typeIDLen]) msg := msgWithType[0 : len(msgWithType)-typeIDLen] switch msgType { - case udp: + case DatagramTypeUDP: return dm.handleSession(ctx, msg) - case ip: + case DatagramTypeIP: return dm.handlePacket(ctx, msg) default: return fmt.Errorf("Unexpected datagram type %d", msgType) diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index 04f5536e..9f6a2fe2 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -16,7 +16,6 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/h2mux" - "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" @@ -117,13 +116,6 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato connAwareLogger: log, } - icmpRouter, err := ingress.NewICMPRouter(config.Log) - if err != nil { - log.Logger().Warn().Err(err).Msg("Failed to create icmp router, ICMP proxy feature is disabled") - } else { - edgeTunnelServer.icmpRouter = icmpRouter - } - useReconnectToken := false if config.ClassicTunnel != nil { useReconnectToken = config.ClassicTunnel.UseReconnectToken @@ -151,9 +143,9 @@ func (s *Supervisor) Run( ctx context.Context, connectedSignal *signal.Signal, ) error { - if s.edgeTunnelServer.icmpRouter != nil { + if s.config.PacketConfig != nil { go func() { - if err := s.edgeTunnelServer.icmpRouter.Serve(ctx); err != nil { + if err := s.config.PacketConfig.ICMPRouter.Serve(ctx); err != nil { if errors.Is(err, net.ErrClosed) { s.log.Logger().Info().Err(err).Msg("icmp router terminated") } else { diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 22fb55e3..5e65f3d1 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -70,6 +70,7 @@ type TunnelConfig struct { MuxerConfig *connection.MuxerConfig ProtocolSelector connection.ProtocolSelector EdgeTLSConfigs map[connection.Protocol]*tls.Config + PacketConfig *packet.GlobalRouterConfig } func (c *TunnelConfig) registrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions { @@ -200,7 +201,6 @@ type EdgeTunnelServer struct { reconnectCh chan ReconnectSignal gracefulShutdownC <-chan struct{} tracker *tunnelstate.ConnTracker - icmpRouter packet.ICMPRouter connAwareLogger *ConnAwareLogger } @@ -661,7 +661,7 @@ func (e *EdgeTunnelServer) serveQUIC( connOptions, controlStreamHandler, connLogger.Logger(), - e.icmpRouter) + e.config.PacketConfig) if err != nil { if e.config.NeedPQ { handlePQTunnelError(err, e.config)