diff --git a/CHANGES.md b/CHANGES.md index b3574850..c3b34105 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,7 @@ +## 2025.7.1 +### Notices +- `cloudflared` will no longer officially support Debian and Ubuntu distros that reached end-of-life: `buster`, `bullseye`, `impish`, `trusty`. + ## 2025.1.1 ### New Features - This release introduces the use of new Post Quantum curves and the ability to use Post Quantum curves when running tunnels with the QUIC protocol this applies to non-FIPS and FIPS builds. diff --git a/Dockerfile b/Dockerfile index d055fed1..fd1676e2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,8 +27,11 @@ LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared # copy our compiled binary COPY --from=builder --chown=nonroot /go/src/github.com/cloudflare/cloudflared/cloudflared /usr/local/bin/ -# run as non-privileged user -USER nonroot +# run as nonroot user +# We need to use numeric user id's because Kubernetes doesn't support strings: +# https://github.com/kubernetes/kubernetes/blob/v1.33.2/pkg/kubelet/kuberuntime/security_context_others.go#L49 +# The `nonroot` user maps to `65532`, from: https://github.com/GoogleContainerTools/distroless/blob/main/common/variables.bzl#L18 +USER 65532:65532 # command / entrypoint of container ENTRYPOINT ["cloudflared", "--no-autoupdate"] diff --git a/Dockerfile.amd64 b/Dockerfile.amd64 index 4afb8827..b00ed3cb 100644 --- a/Dockerfile.amd64 +++ b/Dockerfile.amd64 @@ -22,8 +22,11 @@ LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared # copy our compiled binary COPY --from=builder --chown=nonroot /go/src/github.com/cloudflare/cloudflared/cloudflared /usr/local/bin/ -# run as non-privileged user -USER nonroot +# run as nonroot user +# We need to use numeric user id's because Kubernetes doesn't support strings: +# https://github.com/kubernetes/kubernetes/blob/v1.33.2/pkg/kubelet/kuberuntime/security_context_others.go#L49 +# The `nonroot` user maps to `65532`, from: https://github.com/GoogleContainerTools/distroless/blob/main/common/variables.bzl#L18 +USER 65532:65532 # command / entrypoint of container ENTRYPOINT ["cloudflared", "--no-autoupdate"] diff --git a/Dockerfile.arm64 b/Dockerfile.arm64 index 6e28377b..3bf0ebbf 100644 --- a/Dockerfile.arm64 +++ b/Dockerfile.arm64 @@ -22,8 +22,11 @@ LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared # copy our compiled binary COPY --from=builder --chown=nonroot /go/src/github.com/cloudflare/cloudflared/cloudflared /usr/local/bin/ -# run as non-privileged user -USER nonroot +# run as nonroot user +# We need to use numeric user id's because Kubernetes doesn't support strings: +# https://github.com/kubernetes/kubernetes/blob/v1.33.2/pkg/kubelet/kuberuntime/security_context_others.go#L49 +# The `nonroot` user maps to `65532`, from: https://github.com/GoogleContainerTools/distroless/blob/main/common/variables.bzl#L18 +USER 65532:65532 # command / entrypoint of container ENTRYPOINT ["cloudflared", "--no-autoupdate"] diff --git a/RELEASE_NOTES b/RELEASE_NOTES index 781b94d5..6454069d 100644 --- a/RELEASE_NOTES +++ b/RELEASE_NOTES @@ -1,3 +1,13 @@ +2025.7.0 +- 2025-07-03 TUN-9540: Use numeric user id for Dockerfiles +- 2025-07-01 TUN-9161: Remove P256Kyber768Draft00PQKex curve from nonFips curve preferences +- 2025-07-01 TUN-9531: Bump go-boring from 1.24.2 to 1.24.4 +- 2025-07-01 TUN-9511: Add metrics for virtual DNS origin +- 2025-06-30 TUN-9470: Add OriginDialerService to include TCP +- 2025-06-30 TUN-9473: Add --dns-resolver-addrs flag +- 2025-06-27 TUN-9472: Add virtual DNS service +- 2025-06-23 TUN-9469: Centralize UDP origin proxy dialing as ingress service + 2025.6.1 - 2025-06-16 TUN-9467: add vulncheck to cloudflared - 2025-06-16 TUN-9495: Remove references to cloudflare-go diff --git a/cfsetup.yaml b/cfsetup.yaml index cfd7885c..dd01f650 100644 --- a/cfsetup.yaml +++ b/cfsetup.yaml @@ -1,9 +1,9 @@ -pinned_go: &pinned_go go-boring=1.24.2-1 +pinned_go: &pinned_go go-boring=1.24.4-1 build_dir: &build_dir /cfsetup_build default-flavor: bookworm -bullseye: &bullseye +bookworm: &bookworm build-linux: build_dir: *build_dir builddeps: &build_deps @@ -253,5 +253,4 @@ bullseye: &bullseye - pip install pynacl==1.4.0 pygithub==1.55 boto3==1.22.9 python-gnupg==0.4.9 - make r2-linux-release -bookworm: *bullseye -trixie: *bullseye +trixie: *bookworm diff --git a/cmd/cloudflared/flags/flags.go b/cmd/cloudflared/flags/flags.go index a7bf1b7e..975ee401 100644 --- a/cmd/cloudflared/flags/flags.go +++ b/cmd/cloudflared/flags/flags.go @@ -157,4 +157,7 @@ const ( // ApiURL is the command line flag used to define the base URL of the API ApiURL = "api-url" + + // Virtual DNS resolver service resolver addresses to use instead of dynamically fetching them from the OS. + VirtualDNSServiceResolverAddresses = "dns-resolver-addrs" ) diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 7961c813..63f78426 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -11,6 +11,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" "github.com/rs/zerolog" "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" @@ -25,6 +26,7 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/features" "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/ingress/origins" "github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/tlsconfig" @@ -219,6 +221,27 @@ func prepareTunnelConfig( resolvedRegion = endpoint } + warpRoutingConfig := ingress.NewWarpRoutingConfig(&cfg.WarpRouting) + + // Setup origin dialer service and virtual services + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: ingress.NewDialer(warpRoutingConfig), + TCPWriteTimeout: c.Duration(flags.WriteStreamTimeout), + }, log) + + // Setup DNS Resolver Service + originMetrics := origins.NewMetrics(prometheus.DefaultRegisterer) + dnsResolverAddrs := c.StringSlice(flags.VirtualDNSServiceResolverAddresses) + dnsService := origins.NewDNSResolverService(origins.NewDNSDialer(), log, originMetrics) + if len(dnsResolverAddrs) > 0 { + addrs, err := parseResolverAddrPorts(dnsResolverAddrs) + if err != nil { + return nil, nil, fmt.Errorf("invalid %s provided: %w", flags.VirtualDNSServiceResolverAddresses, err) + } + dnsService = origins.NewStaticDNSResolverService(addrs, origins.NewDNSDialer(), log, originMetrics) + } + originDialerService.AddReservedService(dnsService, []netip.AddrPort{origins.VirtualDNSServiceAddr}) + tunnelConfig := &supervisor.TunnelConfig{ ClientConfig: clientConfig, GracePeriod: gracePeriod, @@ -246,6 +269,8 @@ func prepareTunnelConfig( DisableQUICPathMTUDiscovery: c.Bool(flags.QuicDisablePathMTUDiscovery), QUICConnectionLevelFlowControlLimit: c.Uint64(flags.QuicConnLevelFlowControlLimit), QUICStreamLevelFlowControlLimit: c.Uint64(flags.QuicStreamLevelFlowControlLimit), + OriginDNSService: dnsService, + OriginDialerService: originDialerService, } icmpRouter, err := newICMPRouter(c, log) if err != nil { @@ -254,10 +279,10 @@ func prepareTunnelConfig( tunnelConfig.ICMPRouterServer = icmpRouter } orchestratorConfig := &orchestration.Config{ - Ingress: &ingressRules, - WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting), - ConfigurationFlags: parseConfigFlags(c), - WriteTimeout: tunnelConfig.WriteStreamTimeout, + Ingress: &ingressRules, + WarpRouting: warpRoutingConfig, + OriginDialerService: originDialerService, + ConfigurationFlags: parseConfigFlags(c), } return tunnelConfig, orchestratorConfig, nil } @@ -494,3 +519,19 @@ func findLocalAddr(dst net.IP, port int) (netip.Addr, error) { localAddr := localAddrPort.Addr() return localAddr, nil } + +func parseResolverAddrPorts(input []string) ([]netip.AddrPort, error) { + // We don't allow more than 10 resolvers to be provided statically for the resolver service. + if len(input) > 10 { + return nil, errors.New("too many addresses provided, max: 10") + } + addrs := make([]netip.AddrPort, 0, len(input)) + for _, val := range input { + addr, err := netip.ParseAddrPort(val) + if err != nil { + return nil, err + } + addrs = append(addrs, addr) + } + return addrs, nil +} diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 4be655a0..f89e05c1 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -241,6 +241,11 @@ var ( Usage: "Overrides the remote configuration for max active private network flows (TCP/UDP) that this cloudflared instance supports", EnvVars: []string{"TUNNEL_MAX_ACTIVE_FLOWS"}, } + dnsResolverAddrsFlag = &cli.StringSliceFlag{ + Name: flags.VirtualDNSServiceResolverAddresses, + Usage: "Overrides the dynamic DNS resolver resolution to use these address:port's instead.", + EnvVars: []string{"TUNNEL_DNS_RESOLVER_ADDRS"}, + } ) func buildCreateCommand() *cli.Command { @@ -718,6 +723,7 @@ func buildRunCommand() *cli.Command { icmpv4SrcFlag, icmpv6SrcFlag, maxActiveFlowsFlag, + dnsResolverAddrsFlag, } flags = append(flags, configureProxyFlags(false)...) return &cli.Command{ diff --git a/connection/quic_connection_test.go b/connection/quic_connection_test.go index 251f8630..8765fd29 100644 --- a/connection/quic_connection_test.go +++ b/connection/quic_connection_test.go @@ -30,6 +30,7 @@ import ( "golang.org/x/net/nettest" "github.com/cloudflare/cloudflared/client" + "github.com/cloudflare/cloudflared/config" cfdflow "github.com/cloudflare/cloudflared/flow" "github.com/cloudflare/cloudflared/datagramsession" @@ -823,6 +824,15 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) sessionManager := datagramsession.NewManager(&log, datagramMuxer.SendToSession, sessionDemuxChan) var connIndex uint8 = 0 packetRouter := ingress.NewPacketRouter(nil, datagramMuxer, connIndex, &log) + testDefaultDialer := ingress.NewDialer(ingress.WarpRoutingConfig{ + ConnectTimeout: config.CustomDuration{Duration: 1 * time.Second}, + TCPKeepAlive: config.CustomDuration{Duration: 15 * time.Second}, + MaxActiveFlows: 0, + }) + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &log) datagramConn := &datagramV2Connection{ conn, @@ -830,6 +840,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) sessionManager, cfdflow.NewLimiter(0), datagramMuxer, + originDialer, packetRouter, 15 * time.Second, 0 * time.Second, diff --git a/connection/quic_datagram_v2.go b/connection/quic_datagram_v2.go index 01e13466..aebead70 100644 --- a/connection/quic_datagram_v2.go +++ b/connection/quic_datagram_v2.go @@ -4,9 +4,11 @@ import ( "context" "fmt" "net" + "net/netip" "time" "github.com/google/uuid" + "github.com/pkg/errors" pkgerrors "github.com/pkg/errors" "github.com/quic-go/quic-go" "github.com/rs/zerolog" @@ -32,6 +34,10 @@ const ( demuxChanCapacity = 16 ) +var ( + errInvalidDestinationIP = errors.New("unable to parse destination IP") +) + // DatagramSessionHandler is a service that can serve datagrams for a connection and handle sessions from incoming // connection streams. type DatagramSessionHandler interface { @@ -51,7 +57,10 @@ type datagramV2Connection struct { // datagramMuxer mux/demux datagrams from quic connection datagramMuxer *cfdquic.DatagramMuxerV2 - packetRouter *ingress.PacketRouter + // originDialer is the origin dialer for UDP requests + originDialer ingress.OriginUDPDialer + // packetRouter acts as the origin router for ICMP requests + packetRouter *ingress.PacketRouter rpcTimeout time.Duration streamWriteTimeout time.Duration @@ -61,6 +70,7 @@ type datagramV2Connection struct { func NewDatagramV2Connection(ctx context.Context, conn quic.Connection, + originDialer ingress.OriginUDPDialer, icmpRouter ingress.ICMPRouter, index uint8, rpcTimeout time.Duration, @@ -79,6 +89,7 @@ func NewDatagramV2Connection(ctx context.Context, sessionManager: sessionManager, flowLimiter: flowLimiter, datagramMuxer: datagramMuxer, + originDialer: originDialer, packetRouter: packetRouter, rpcTimeout: rpcTimeout, streamWriteTimeout: streamWriteTimeout, @@ -128,12 +139,29 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID tracing.EndWithErrorStatus(registerSpan, err) return nil, err } + // We need to force the net.IP to IPv4 (if it's an IPv4 address) otherwise the net.IP conversion from capnp + // will be a IPv4-mapped-IPv6 address. + // In the case that the address is IPv6 we leave it untouched and parse it as normal. + ip := dstIP.To4() + if ip == nil { + ip = dstIP + } + // Parse the dstIP and dstPort into a netip.AddrPort + // This should never fail because the IP was already parsed as a valid net.IP + destAddr, ok := netip.AddrFromSlice(ip) + if !ok { + log.Err(errInvalidDestinationIP).Msgf("Failed to parse destination proxy IP: %s", ip) + tracing.EndWithErrorStatus(registerSpan, errInvalidDestinationIP) + q.flowLimiter.Release() + return nil, errInvalidDestinationIP + } + dstAddrPort := netip.AddrPortFrom(destAddr, dstPort) // Each session is a series of datagram from an eyeball to a dstIP:dstPort. // (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket. - originProxy, err := ingress.DialUDP(dstIP, dstPort) + originProxy, err := q.originDialer.DialUDP(dstAddrPort) if err != nil { - log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort) + log.Err(err).Msgf("Failed to create udp proxy to %s", dstAddrPort) tracing.EndWithErrorStatus(registerSpan, err) q.flowLimiter.Release() return nil, err diff --git a/connection/quic_datagram_v2_test.go b/connection/quic_datagram_v2_test.go index af58ffdb..e4edac46 100644 --- a/connection/quic_datagram_v2_test.go +++ b/connection/quic_datagram_v2_test.go @@ -84,6 +84,7 @@ func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) { t.Context(), conn, nil, + nil, 0, 0*time.Second, 0*time.Second, diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index f7e08004..139877ad 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -19,7 +19,7 @@ import ( type OriginConnection interface { // Stream should generally be implemented as a bidirectional io.Copy. Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) - Close() + Close() error } type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger) @@ -48,16 +48,7 @@ func (tc *tcpConnection) Write(b []byte) (int, error) { } } - nBytes, err := tc.Conn.Write(b) - if err != nil { - tc.logger.Err(err).Msg("Error writing to the TCP connection") - } - - return nBytes, err -} - -func (tc *tcpConnection) Close() { - tc.Conn.Close() + return tc.Conn.Write(b) } // tcpOverWSConnection is an OriginConnection that streams to TCP over WS. @@ -75,8 +66,8 @@ func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWri wsConn.Close() } -func (wc *tcpOverWSConnection) Close() { - wc.conn.Close() +func (wc *tcpOverWSConnection) Close() error { + return wc.conn.Close() } // socksProxyOverWSConnection is an OriginConnection that streams SOCKS connections over WS. @@ -95,5 +86,6 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io. wsConn.Close() } -func (sp *socksProxyOverWSConnection) Close() { +func (sp *socksProxyOverWSConnection) Close() error { + return nil } diff --git a/ingress/origin_dialer.go b/ingress/origin_dialer.go new file mode 100644 index 00000000..36ade327 --- /dev/null +++ b/ingress/origin_dialer.go @@ -0,0 +1,146 @@ +package ingress + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "time" + + "github.com/rs/zerolog" +) + +// OriginTCPDialer provides a TCP dial operation to a requested address. +type OriginTCPDialer interface { + DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error) +} + +// OriginUDPDialer provides a UDP dial operation to a requested address. +type OriginUDPDialer interface { + DialUDP(addr netip.AddrPort) (net.Conn, error) +} + +// OriginDialer provides both TCP and UDP dial operations to an address. +type OriginDialer interface { + OriginTCPDialer + OriginUDPDialer +} + +type OriginConfig struct { + // The default Dialer used if no reserved services are found for an origin request. + DefaultDialer OriginDialer + // Timeout on write operations for TCP connections to the origin. + TCPWriteTimeout time.Duration +} + +// OriginDialerService provides a proxy TCP and UDP dialer to origin services while allowing reserved +// services to be provided. These reserved services are assigned to specific [netip.AddrPort]s +// and provide their own [OriginDialer]'s to handle origin dialing per protocol. +type OriginDialerService struct { + // Reserved TCP services for reserved AddrPort values + reservedTCPServices map[netip.AddrPort]OriginTCPDialer + // Reserved UDP services for reserved AddrPort values + reservedUDPServices map[netip.AddrPort]OriginUDPDialer + // The default Dialer used if no reserved services are found for an origin request + defaultDialer OriginDialer + defaultDialerM sync.RWMutex + // Write timeout for TCP connections + writeTimeout time.Duration + + logger *zerolog.Logger +} + +func NewOriginDialer(config OriginConfig, logger *zerolog.Logger) *OriginDialerService { + return &OriginDialerService{ + reservedTCPServices: map[netip.AddrPort]OriginTCPDialer{}, + reservedUDPServices: map[netip.AddrPort]OriginUDPDialer{}, + defaultDialer: config.DefaultDialer, + writeTimeout: config.TCPWriteTimeout, + logger: logger, + } +} + +// AddReservedService adds a reserved virtual service to dial to. +// Not locked and expected to be initialized before calling first dial and not afterwards. +func (d *OriginDialerService) AddReservedService(service OriginDialer, addrs []netip.AddrPort) { + for _, addr := range addrs { + d.reservedTCPServices[addr] = service + d.reservedUDPServices[addr] = service + } +} + +// UpdateDefaultDialer updates the default dialer. +func (d *OriginDialerService) UpdateDefaultDialer(dialer *Dialer) { + d.defaultDialerM.Lock() + defer d.defaultDialerM.Unlock() + d.defaultDialer = dialer +} + +// DialTCP will perform a dial TCP to the requested addr. +func (d *OriginDialerService) DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error) { + conn, err := d.dialTCP(ctx, addr) + if err != nil { + return nil, err + } + // Assign the write timeout for the TCP operations + return &tcpConnection{ + Conn: conn, + writeTimeout: d.writeTimeout, + logger: d.logger, + }, nil +} + +func (d *OriginDialerService) dialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error) { + // Check to see if any reserved services are available for this addr and call their dialer instead. + if dialer, ok := d.reservedTCPServices[addr]; ok { + return dialer.DialTCP(ctx, addr) + } + d.defaultDialerM.RLock() + dialer := d.defaultDialer + d.defaultDialerM.RUnlock() + return dialer.DialTCP(ctx, addr) +} + +// DialUDP will perform a dial UDP to the requested addr. +func (d *OriginDialerService) DialUDP(addr netip.AddrPort) (net.Conn, error) { + // Check to see if any reserved services are available for this addr and call their dialer instead. + if dialer, ok := d.reservedUDPServices[addr]; ok { + return dialer.DialUDP(addr) + } + d.defaultDialerM.RLock() + dialer := d.defaultDialer + d.defaultDialerM.RUnlock() + return dialer.DialUDP(addr) +} + +type Dialer struct { + Dialer net.Dialer +} + +func NewDialer(config WarpRoutingConfig) *Dialer { + return &Dialer{ + Dialer: net.Dialer{ + Timeout: config.ConnectTimeout.Duration, + KeepAlive: config.TCPKeepAlive.Duration, + }, + } +} + +func (d *Dialer) DialTCP(ctx context.Context, dest netip.AddrPort) (net.Conn, error) { + conn, err := d.Dialer.DialContext(ctx, "tcp", dest.String()) + if err != nil { + return nil, fmt.Errorf("unable to dial tcp to origin %s: %w", dest, err) + } + + return conn, nil +} + +func (d *Dialer) DialUDP(dest netip.AddrPort) (net.Conn, error) { + conn, err := d.Dialer.Dial("udp", dest.String()) + if err != nil { + return nil, fmt.Errorf("unable to dial udp to origin %s: %w", dest, err) + } + + return conn, nil +} diff --git a/ingress/origin_udp_proxy.go b/ingress/origin_udp_proxy.go deleted file mode 100644 index 012c05c0..00000000 --- a/ingress/origin_udp_proxy.go +++ /dev/null @@ -1,46 +0,0 @@ -package ingress - -import ( - "fmt" - "io" - "net" - "net/netip" -) - -type UDPProxy interface { - io.ReadWriteCloser - LocalAddr() net.Addr -} - -type udpProxy struct { - *net.UDPConn -} - -func DialUDP(dstIP net.IP, dstPort uint16) (UDPProxy, error) { - dstAddr := &net.UDPAddr{ - IP: dstIP, - Port: int(dstPort), - } - - // We use nil as local addr to force runtime to find the best suitable local address IP given the destination - // address as context. - udpConn, err := net.DialUDP("udp", nil, dstAddr) - if err != nil { - return nil, fmt.Errorf("unable to create UDP proxy to origin (%v:%v): %w", dstIP, dstPort, err) - } - - return &udpProxy{udpConn}, nil -} - -func DialUDPAddrPort(dest netip.AddrPort) (*net.UDPConn, error) { - addr := net.UDPAddrFromAddrPort(dest) - - // We use nil as local addr to force runtime to find the best suitable local address IP given the destination - // address as context. - udpConn, err := net.DialUDP("udp", nil, addr) - if err != nil { - return nil, fmt.Errorf("unable to dial udp to origin %s: %w", dest, err) - } - - return udpConn, nil -} diff --git a/ingress/origins/dns.go b/ingress/origins/dns.go new file mode 100644 index 00000000..c09c581d --- /dev/null +++ b/ingress/origins/dns.go @@ -0,0 +1,219 @@ +package origins + +import ( + "context" + "crypto/rand" + "math/big" + "net" + "net/netip" + "slices" + "sync" + "time" + + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/ingress" +) + +const ( + // We need a DNS record: + // 1. That will be around for as long as cloudflared is + // 2. That Cloudflare controls: to allow us to make changes if needed + // 3. That is an external record to a typical customer's network: enforcing that the DNS request go to the + // local DNS resolver over any local /etc/host configurations setup. + // 4. That cloudflared would normally query: ensuring that users with a positive security model for DNS queries + // don't need to adjust anything. + // + // This hostname is one that used during the edge discovery process and as such satisfies the above constraints. + defaultLookupHost = "region1.v2.argotunnel.com" + defaultResolverPort uint16 = 53 + + // We want the refresh time to be short to accommodate DNS resolver changes locally, but not too frequent as to + // shuffle the resolver if multiple are configured. + refreshFreq = 5 * time.Minute + refreshTimeout = 5 * time.Second +) + +var ( + // Virtual DNS service address + VirtualDNSServiceAddr = netip.AddrPortFrom(netip.MustParseAddr("2606:4700:0cf1:2000:0000:0000:0000:0001"), 53) + + defaultResolverAddr = netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), defaultResolverPort) +) + +type netDial func(network string, address string) (net.Conn, error) + +// DNSResolverService will make DNS requests to the local DNS resolver via the Dial method. +type DNSResolverService struct { + addresses []netip.AddrPort + addressesM sync.RWMutex + static bool + dialer ingress.OriginDialer + resolver peekResolver + logger *zerolog.Logger + metrics Metrics +} + +func NewDNSResolverService(dialer ingress.OriginDialer, logger *zerolog.Logger, metrics Metrics) *DNSResolverService { + return &DNSResolverService{ + addresses: []netip.AddrPort{defaultResolverAddr}, + dialer: dialer, + resolver: &resolver{dialFunc: net.Dial}, + logger: logger, + metrics: metrics, + } +} + +func NewStaticDNSResolverService(resolverAddrs []netip.AddrPort, dialer ingress.OriginDialer, logger *zerolog.Logger, metrics Metrics) *DNSResolverService { + s := NewDNSResolverService(dialer, logger, metrics) + s.addresses = resolverAddrs + s.static = true + return s +} + +func (s *DNSResolverService) DialTCP(ctx context.Context, _ netip.AddrPort) (net.Conn, error) { + s.metrics.IncrementDNSTCPRequests() + dest := s.getAddress() + // The dialer ignores the provided address because the request will instead go to the local DNS resolver. + return s.dialer.DialTCP(ctx, dest) +} + +func (s *DNSResolverService) DialUDP(_ netip.AddrPort) (net.Conn, error) { + s.metrics.IncrementDNSUDPRequests() + dest := s.getAddress() + // The dialer ignores the provided address because the request will instead go to the local DNS resolver. + return s.dialer.DialUDP(dest) +} + +// StartRefreshLoop is a routine that is expected to run in the background to update the DNS local resolver if +// adjusted while the cloudflared process is running. +// Does not run when the resolver was provided with external resolver addresses via CLI. +func (s *DNSResolverService) StartRefreshLoop(ctx context.Context) { + if s.static { + s.logger.Debug().Msgf("Canceled DNS local resolver refresh loop because static resolver addresses were provided: %s", s.addresses) + return + } + // Call update first to load an address before handling traffic + err := s.update(ctx) + if err != nil { + s.logger.Err(err).Msg("Failed to initialize DNS local resolver") + } + for { + select { + case <-ctx.Done(): + return + case <-time.Tick(refreshFreq): + err := s.update(ctx) + if err != nil { + s.logger.Err(err).Msg("Failed to refresh DNS local resolver") + } + } + } +} + +func (s *DNSResolverService) update(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, refreshTimeout) + defer cancel() + // Make a standard DNS request to a well-known DNS record that will last a long time + _, err := s.resolver.lookupNetIP(ctx, defaultLookupHost) + if err != nil { + return err + } + + // Validate the address before updating internal reference + _, address := s.resolver.addr() + peekAddrPort, err := netip.ParseAddrPort(address) + if err == nil { + s.setAddress(peekAddrPort) + return nil + } + // It's possible that the address didn't have an attached port, attempt to parse just the address and use + // the default port 53 + peekAddr, err := netip.ParseAddr(address) + if err != nil { + return err + } + s.setAddress(netip.AddrPortFrom(peekAddr, defaultResolverPort)) + return nil +} + +// returns the address from the peekResolver or from the static addresses if provided. +// If multiple addresses are provided in the static addresses pick one randomly. +func (s *DNSResolverService) getAddress() netip.AddrPort { + s.addressesM.RLock() + defer s.addressesM.RUnlock() + l := len(s.addresses) + if l <= 0 { + return defaultResolverAddr + } + if l == 1 { + return s.addresses[0] + } + // Only initialize the random selection if there is more than one element in the list. + var i int64 = 0 + r, err := rand.Int(rand.Reader, big.NewInt(int64(l))) + // We ignore errors from crypto rand and use index 0; this should be extremely unlikely and the + // list index doesn't need to be cryptographically secure, but linters insist. + if err == nil { + i = r.Int64() + } + return s.addresses[i] +} + +// lock and update the address used for the local DNS resolver +func (s *DNSResolverService) setAddress(addr netip.AddrPort) { + s.addressesM.Lock() + defer s.addressesM.Unlock() + if !slices.Contains(s.addresses, addr) { + s.logger.Debug().Msgf("Updating DNS local resolver: %s", addr) + } + // We only store one address when reading the peekResolver, so we just replace the whole list. + s.addresses = []netip.AddrPort{addr} +} + +type peekResolver interface { + addr() (network string, address string) + lookupNetIP(ctx context.Context, host string) ([]netip.Addr, error) +} + +// resolver is a shim that inspects the go runtime's DNS resolution process to capture the DNS resolver +// address used to complete a DNS request. +type resolver struct { + network string + address string + dialFunc netDial +} + +func (r *resolver) addr() (network string, address string) { + return r.network, r.address +} + +func (r *resolver) lookupNetIP(ctx context.Context, host string) ([]netip.Addr, error) { + resolver := &net.Resolver{ + PreferGo: true, + // Use the peekDial to inspect the results of the DNS resolver used during the LookupIPAddr call. + Dial: r.peekDial, + } + return resolver.LookupNetIP(ctx, "ip", host) +} + +func (r *resolver) peekDial(ctx context.Context, network, address string) (net.Conn, error) { + r.network = network + r.address = address + return r.dialFunc(network, address) +} + +// NewDNSDialer creates a custom dialer for the DNS resolver service to utilize. +func NewDNSDialer() *ingress.Dialer { + return &ingress.Dialer{ + Dialer: net.Dialer{ + // We want short timeouts for the DNS requests + Timeout: 5 * time.Second, + // We do not want keep alive since the edge will not reuse TCP connections per request + KeepAlive: -1, + KeepAliveConfig: net.KeepAliveConfig{ + Enable: false, + }, + }, + } +} diff --git a/ingress/origins/dns_test.go b/ingress/origins/dns_test.go new file mode 100644 index 00000000..c44db1ea --- /dev/null +++ b/ingress/origins/dns_test.go @@ -0,0 +1,195 @@ +package origins + +import ( + "context" + "errors" + "net" + "net/netip" + "slices" + "testing" + "time" + + "github.com/rs/zerolog" +) + +func TestDNSResolver_DefaultResolver(t *testing.T) { + log := zerolog.Nop() + service := NewDNSResolverService(NewDNSDialer(), &log, &noopMetrics{}) + mockResolver := &mockPeekResolver{ + address: "127.0.0.2:53", + } + service.resolver = mockResolver + validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses) +} + +func TestStaticDNSResolver_DefaultResolver(t *testing.T) { + log := zerolog.Nop() + addresses := []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:53"), netip.MustParseAddrPort("1.0.0.1:53")} + service := NewStaticDNSResolverService(addresses, NewDNSDialer(), &log, &noopMetrics{}) + mockResolver := &mockPeekResolver{ + address: "127.0.0.2:53", + } + service.resolver = mockResolver + validateAddrs(t, addresses, service.addresses) +} + +func TestDNSResolver_UpdateResolverAddress(t *testing.T) { + log := zerolog.Nop() + service := NewDNSResolverService(NewDNSDialer(), &log, &noopMetrics{}) + + mockResolver := &mockPeekResolver{} + service.resolver = mockResolver + + tests := []struct { + addr string + expected netip.AddrPort + }{ + {"127.0.0.2:53", netip.MustParseAddrPort("127.0.0.2:53")}, + // missing port should be added (even though this is unlikely to happen) + {"127.0.0.3", netip.MustParseAddrPort("127.0.0.3:53")}, + } + + for _, test := range tests { + mockResolver.address = test.addr + // Update the resolver address + err := service.update(t.Context()) + if err != nil { + t.Error(err) + } + // Validate expected + validateAddrs(t, []netip.AddrPort{test.expected}, service.addresses) + } +} + +func TestStaticDNSResolver_RefreshLoopExits(t *testing.T) { + log := zerolog.Nop() + addresses := []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:53"), netip.MustParseAddrPort("1.0.0.1:53")} + service := NewStaticDNSResolverService(addresses, NewDNSDialer(), &log, &noopMetrics{}) + + mockResolver := &mockPeekResolver{ + address: "127.0.0.2:53", + } + service.resolver = mockResolver + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + go service.StartRefreshLoop(ctx) + + // Wait for the refresh loop to end _and_ not update the addresses + time.Sleep(10 * time.Millisecond) + + // Validate expected + validateAddrs(t, addresses, service.addresses) +} + +func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) { + log := zerolog.Nop() + service := NewDNSResolverService(NewDNSDialer(), &log, &noopMetrics{}) + mockResolver := &mockPeekResolver{} + service.resolver = mockResolver + + invalidAddresses := []string{ + "999.999.999.999", + "localhost", + "255.255.255", + } + + for _, addr := range invalidAddresses { + mockResolver.address = addr + // Update the resolver address should not update for these invalid addresses + err := service.update(t.Context()) + if err == nil { + t.Error("service update should throw an error") + } + // Validate expected + validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses) + } +} + +func TestDNSResolver_UpdateResolverErrorIgnored(t *testing.T) { + log := zerolog.Nop() + service := NewDNSResolverService(NewDNSDialer(), &log, &noopMetrics{}) + resolverErr := errors.New("test resolver error") + mockResolver := &mockPeekResolver{err: resolverErr} + service.resolver = mockResolver + + // Update the resolver address should not update when the resolver cannot complete the lookup + err := service.update(t.Context()) + if err != resolverErr { + t.Error("service update should throw an error") + } + // Validate expected + validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses) +} + +func TestDNSResolver_DialUDPUsesResolvedAddress(t *testing.T) { + log := zerolog.Nop() + mockDialer := &mockDialer{expected: defaultResolverAddr} + service := NewDNSResolverService(mockDialer, &log, &noopMetrics{}) + mockResolver := &mockPeekResolver{} + service.resolver = mockResolver + + // Attempt a dial to 127.0.0.2:53 which should be ignored and instead resolve to 127.0.0.1:53 + _, err := service.DialUDP(netip.MustParseAddrPort("127.0.0.2:53")) + if err != nil { + t.Error(err) + } +} + +func TestDNSResolver_DialTCPUsesResolvedAddress(t *testing.T) { + log := zerolog.Nop() + mockDialer := &mockDialer{expected: defaultResolverAddr} + service := NewDNSResolverService(mockDialer, &log, &noopMetrics{}) + mockResolver := &mockPeekResolver{} + service.resolver = mockResolver + + // Attempt a dial to 127.0.0.2:53 which should be ignored and instead resolve to 127.0.0.1:53 + _, err := service.DialTCP(t.Context(), netip.MustParseAddrPort("127.0.0.2:53")) + if err != nil { + t.Error(err) + } +} + +type mockPeekResolver struct { + err error + address string +} + +func (r *mockPeekResolver) addr() (network, address string) { + return "udp", r.address +} + +func (r *mockPeekResolver) lookupNetIP(ctx context.Context, host string) ([]netip.Addr, error) { + // We can return an empty result as it doesn't matter as long as the lookup doesn't fail + return []netip.Addr{}, r.err +} + +type mockDialer struct { + expected netip.AddrPort +} + +func (d *mockDialer) DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error) { + if d.expected != addr { + return nil, errors.New("unexpected address dialed") + } + return nil, nil +} + +func (d *mockDialer) DialUDP(addr netip.AddrPort) (net.Conn, error) { + if d.expected != addr { + return nil, errors.New("unexpected address dialed") + } + return nil, nil +} + +func validateAddrs(t *testing.T, expected []netip.AddrPort, actual []netip.AddrPort) { + if len(actual) != len(expected) { + t.Errorf("addresses should only contain one element: %s", actual) + } + for _, e := range expected { + if !slices.Contains(actual, e) { + t.Errorf("missing address: %s in %s", e, actual) + } + } +} diff --git a/ingress/origins/metrics.go b/ingress/origins/metrics.go new file mode 100644 index 00000000..8586ec29 --- /dev/null +++ b/ingress/origins/metrics.go @@ -0,0 +1,40 @@ +package origins + +import ( + "github.com/prometheus/client_golang/prometheus" +) + +const ( + namespace = "cloudflared" + subsystem = "virtual_origins" +) + +type Metrics interface { + IncrementDNSUDPRequests() + IncrementDNSTCPRequests() +} + +type metrics struct { + dnsResolverRequests *prometheus.CounterVec +} + +func (m *metrics) IncrementDNSUDPRequests() { + m.dnsResolverRequests.WithLabelValues("udp").Inc() +} + +func (m *metrics) IncrementDNSTCPRequests() { + m.dnsResolverRequests.WithLabelValues("tcp").Inc() +} + +func NewMetrics(registerer prometheus.Registerer) Metrics { + m := &metrics{ + dnsResolverRequests: prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "dns_requests_total", + Help: "Total count of DNS requests that have been proxied to the virtual DNS resolver origin", + }, []string{"protocol"}), + } + registerer.MustRegister(m.dnsResolverRequests) + return m +} diff --git a/ingress/origins/metrics_test.go b/ingress/origins/metrics_test.go new file mode 100644 index 00000000..311b1fa0 --- /dev/null +++ b/ingress/origins/metrics_test.go @@ -0,0 +1,6 @@ +package origins + +type noopMetrics struct{} + +func (noopMetrics) IncrementDNSUDPRequests() {} +func (noopMetrics) IncrementDNSTCPRequests() {} diff --git a/orchestration/config.go b/orchestration/config.go index 04c7a0ab..b87b69a6 100644 --- a/orchestration/config.go +++ b/orchestration/config.go @@ -2,7 +2,6 @@ package orchestration import ( "encoding/json" - "time" "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/ingress" @@ -20,9 +19,9 @@ type newLocalConfig struct { // Config is the original config as read and parsed by cloudflared. type Config struct { - Ingress *ingress.Ingress - WarpRouting ingress.WarpRoutingConfig - WriteTimeout time.Duration + Ingress *ingress.Ingress + WarpRouting ingress.WarpRoutingConfig + OriginDialerService *ingress.OriginDialerService // Extra settings used to configure this instance but that are not eligible for remotely management // ie. (--protocol, --loglevel, ...) diff --git a/orchestration/orchestrator.go b/orchestration/orchestrator.go index abfd1f9b..9840bd36 100644 --- a/orchestration/orchestrator.go +++ b/orchestration/orchestrator.go @@ -38,7 +38,9 @@ type Orchestrator struct { tags []pogs.Tag // flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit. flowLimiter cfdflow.Limiter - log *zerolog.Logger + // Origin dialer service to manage egress socket dialing. + originDialerService *ingress.OriginDialerService + log *zerolog.Logger // orchestrator must not handle any more updates after shutdownC is closed shutdownC <-chan struct{} @@ -50,18 +52,20 @@ func NewOrchestrator(ctx context.Context, config *Config, tags []pogs.Tag, internalRules []ingress.Rule, - log *zerolog.Logger) (*Orchestrator, error) { + log *zerolog.Logger, +) (*Orchestrator, error) { o := &Orchestrator{ // Lowest possible version, any remote configuration will have version higher than this // Starting at -1 allows a configuration migration (local to remote) to override the current configuration as it // will start at version 0. - currentVersion: -1, - internalRules: internalRules, - config: config, - tags: tags, - flowLimiter: cfdflow.NewLimiter(config.WarpRouting.MaxActiveFlows), - log: log, - shutdownC: ctx.Done(), + currentVersion: -1, + internalRules: internalRules, + config: config, + tags: tags, + flowLimiter: cfdflow.NewLimiter(config.WarpRouting.MaxActiveFlows), + originDialerService: config.OriginDialerService, + log: log, + shutdownC: ctx.Done(), } if err := o.updateIngress(*config.Ingress, config.WarpRouting); err != nil { return nil, err @@ -175,7 +179,15 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i // Update the flow limit since the configuration might have changed o.flowLimiter.SetLimit(warpRouting.MaxActiveFlows) - proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.flowLimiter, o.config.WriteTimeout, o.log) + // Update the origin dialer service with the new dialer settings + // We need to update the dialer here instead of creating a new instance of OriginDialerService because it has + // its own references and go routines. Specifically, the UDP dialer is a reference to this same service all the + // way into the datagram manager. Reconstructing the datagram manager is not something we currently provide during + // runtime in response to a configuration push except when starting a tunnel connection. + o.originDialerService.UpdateDefaultDialer(ingress.NewDialer(warpRouting)) + + // Create and replace the origin proxy with a new instance + proxy := proxy.NewOriginProxy(ingressRules, o.originDialerService, o.tags, o.flowLimiter, o.log) o.proxy.Store(proxy) o.config.Ingress = &ingressRules o.config.WarpRouting = warpRouting diff --git a/orchestration/orchestrator_test.go b/orchestration/orchestrator_test.go index aeed4860..7a14b2d4 100644 --- a/orchestration/orchestrator_test.go +++ b/orchestration/orchestrator_test.go @@ -41,6 +41,11 @@ var ( Value: "test", }, } + testDefaultDialer = ingress.NewDialer(ingress.WarpRoutingConfig{ + ConnectTimeout: config.CustomDuration{Duration: 1 * time.Second}, + TCPKeepAlive: config.CustomDuration{Duration: 15 * time.Second}, + MaxActiveFlows: 0, + }) ) // TestUpdateConfiguration tests that @@ -50,8 +55,13 @@ var ( // - configurations can be deserialized // - receiving an old version is noop func TestUpdateConfiguration(t *testing.T) { + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) initConfig := &Config{ - Ingress: &ingress.Ingress{}, + Ingress: &ingress.Ingress{}, + OriginDialerService: originDialer, } orchestrator, err := NewOrchestrator(t.Context(), initConfig, testTags, []ingress.Rule{ingress.NewManagementRule(management.New("management.argotunnel.com", false, "1.1.1.1:80", uuid.Nil, "", &testLogger, nil))}, &testLogger) require.NoError(t, err) @@ -179,8 +189,13 @@ func TestUpdateConfiguration(t *testing.T) { // Validates that a new version 0 will be applied if the configuration is loaded locally. // This will happen when a locally managed tunnel is migrated to remote configuration and receives its first configuration. func TestUpdateConfiguration_FromMigration(t *testing.T) { + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) initConfig := &Config{ - Ingress: &ingress.Ingress{}, + Ingress: &ingress.Ingress{}, + OriginDialerService: originDialer, } orchestrator, err := NewOrchestrator(t.Context(), initConfig, testTags, []ingress.Rule{}, &testLogger) require.NoError(t, err) @@ -205,8 +220,13 @@ func TestUpdateConfiguration_FromMigration(t *testing.T) { // Validates that the default ingress rule will be set if there is no rule provided from the remote. func TestUpdateConfiguration_WithoutIngressRule(t *testing.T) { + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) initConfig := &Config{ - Ingress: &ingress.Ingress{}, + Ingress: &ingress.Ingress{}, + OriginDialerService: originDialer, } orchestrator, err := NewOrchestrator(t.Context(), initConfig, testTags, []ingress.Rule{}, &testLogger) require.NoError(t, err) @@ -244,6 +264,11 @@ func TestConcurrentUpdateAndRead(t *testing.T) { require.NoError(t, err) defer tcpOrigin.Close() + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) + var ( configJSONV1 = []byte(fmt.Sprintf(` { @@ -296,7 +321,8 @@ func TestConcurrentUpdateAndRead(t *testing.T) { appliedV2 = make(chan struct{}) initConfig = &Config{ - Ingress: &ingress.Ingress{}, + Ingress: &ingress.Ingress{}, + OriginDialerService: originDialer, } ) @@ -313,7 +339,7 @@ func TestConcurrentUpdateAndRead(t *testing.T) { go func() { serveTCPOrigin(t, tcpOrigin, &wg) }() - for i := 0; i < concurrentRequests; i++ { + for i := range concurrentRequests { originProxy, err := orchestrator.GetOriginProxy() require.NoError(t, err) wg.Add(1) @@ -323,48 +349,37 @@ func TestConcurrentUpdateAndRead(t *testing.T) { assert.NoError(t, err, "proxyHTTP %d failed %v", i, err) defer resp.Body.Close() - var warpRoutingDisabled bool // The response can be from initOrigin, http_status:204 or http_status:418 switch resp.StatusCode { - // v1 proxy, warp enabled + // v1 proxy case 200: body, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, t.Name(), string(body)) - warpRoutingDisabled = false - // v2 proxy, warp disabled + // v2 proxy case 204: assert.Greater(t, i, concurrentRequests/4) - warpRoutingDisabled = true - // v3 proxy, warp enabled + // v3 proxy case 418: assert.Greater(t, i, concurrentRequests/2) - warpRoutingDisabled = false } // Once we have originProxy, it won't be changed by configuration updates. // We can infer the version by the ProxyHTTP response code pr, pw := io.Pipe() - w := newRespReadWriteFlusher() // Write TCP message and make sure it's echo back. This has to be done in a go routune since ProxyTCP doesn't // return until the stream is closed. - if !warpRoutingDisabled { - wg.Add(1) - go func() { - defer wg.Done() - defer pw.Close() - tcpEyeball(t, pw, tcpBody, w) - }() - } + wg.Add(1) + go func() { + defer wg.Done() + defer pw.Close() + tcpEyeball(t, pw, tcpBody, w) + }() err = proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), w, pr) - if warpRoutingDisabled { - assert.Error(t, err, "expect proxyTCP %d to return error", i) - } else { - assert.NoError(t, err, "proxyTCP %d failed %v", i, err) - } + assert.NoError(t, err, "proxyTCP %d failed %v", i, err) }(i, originProxy) if i == concurrentRequests/4 { @@ -406,39 +421,47 @@ func TestOverrideWarpRoutingConfigWithLocalValues(t *testing.T) { require.EqualValues(t, expectedValue, warpRouting["maxActiveFlows"]) } - remoteValue := uint64(100) - remoteIngress := ingress.Ingress{} + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) + + // All the possible values set for MaxActiveFlows from the various points that can provide it: + // 1. Initialized value + // 2. Local CLI flag config + // 3. Remote configuration value + initValue := uint64(0) + localValue := uint64(100) + remoteValue := uint64(500) + + initConfig := &Config{ + Ingress: &ingress.Ingress{}, + WarpRouting: ingress.WarpRoutingConfig{ + MaxActiveFlows: initValue, + }, + OriginDialerService: originDialer, + ConfigurationFlags: map[string]string{ + flags.MaxActiveFlows: fmt.Sprintf("%d", localValue), + }, + } + + // We expect the local configuration flag to be the starting value + orchestrator, err := NewOrchestrator(ctx, initConfig, testTags, []ingress.Rule{}, &testLogger) + require.NoError(t, err) + + assertMaxActiveFlows(orchestrator, localValue) + + // Assigning the MaxActiveFlows in the remote config should be ignored over the local config remoteWarpConfig := ingress.WarpRoutingConfig{ MaxActiveFlows: remoteValue, } - remoteConfig := &Config{ - Ingress: &remoteIngress, - WarpRouting: remoteWarpConfig, - ConfigurationFlags: map[string]string{}, - } - orchestrator, err := NewOrchestrator(ctx, remoteConfig, testTags, []ingress.Rule{}, &testLogger) - require.NoError(t, err) - assertMaxActiveFlows(orchestrator, remoteValue) - - // Add a local override for the maxActiveFlows - localValue := uint64(500) - remoteConfig.ConfigurationFlags[flags.MaxActiveFlows] = fmt.Sprintf("%d", localValue) // Force a configuration refresh - err = orchestrator.updateIngress(remoteIngress, remoteWarpConfig) + err = orchestrator.updateIngress(ingress.Ingress{}, remoteWarpConfig) require.NoError(t, err) // Check the value being used is the local one assertMaxActiveFlows(orchestrator, localValue) - - // Remove local override for the maxActiveFlows - delete(remoteConfig.ConfigurationFlags, flags.MaxActiveFlows) - // Force a configuration refresh - err = orchestrator.updateIngress(remoteIngress, remoteWarpConfig) - require.NoError(t, err) - - // Check the value being used is now the remote again - assertMaxActiveFlows(orchestrator, remoteValue) } func proxyHTTP(originProxy connection.OriginProxy, hostname string) (*http.Response, error) { @@ -546,6 +569,10 @@ func updateWithValidation(t *testing.T, orchestrator *Orchestrator, version int3 // TestClosePreviousProxies makes sure proxies started in the previous configuration version are shutdown func TestClosePreviousProxies(t *testing.T) { + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) var ( hostname = "hello.tunnel1.org" configWithHelloWorld = []byte(fmt.Sprintf(` @@ -576,7 +603,8 @@ func TestClosePreviousProxies(t *testing.T) { } `) initConfig = &Config{ - Ingress: &ingress.Ingress{}, + Ingress: &ingress.Ingress{}, + OriginDialerService: originDialer, } ) @@ -638,8 +666,13 @@ func TestPersistentConnection(t *testing.T) { hostname = "http://ws.tunnel.org" ) msg := t.Name() + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &testLogger) initConfig := &Config{ - Ingress: &ingress.Ingress{}, + Ingress: &ingress.Ingress{}, + OriginDialerService: originDialer, } orchestrator, err := NewOrchestrator(t.Context(), initConfig, testTags, []ingress.Rule{}, &testLogger) require.NoError(t, err) @@ -752,8 +785,9 @@ func TestSerializeLocalConfig(t *testing.T) { ConfigurationFlags: map[string]string{"a": "b"}, } - result, _ := json.Marshal(c) - fmt.Println(string(result)) + result, err := json.Marshal(c) + require.NoError(t, err) + require.JSONEq(t, `{"__configuration_flags":{"a":"b"},"ingress":[],"warp-routing":{"connectTimeout":0,"tcpKeepAlive":0}}`, string(result)) } func wsEcho(w http.ResponseWriter, r *http.Request) { diff --git a/proxy/proxy.go b/proxy/proxy.go index c16555ea..6b404189 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -5,11 +5,11 @@ import ( "fmt" "io" "net/http" + "net/netip" "strconv" "time" "github.com/pkg/errors" - pkgerrors "github.com/pkg/errors" "github.com/rs/zerolog" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -35,7 +35,7 @@ const ( // Proxy represents a means to Proxy between cloudflared and the origin services. type Proxy struct { ingressRules ingress.Ingress - warpRouting *ingress.WarpRoutingService + originDialer ingress.OriginTCPDialer tags []pogs.Tag flowLimiter cfdflow.Limiter log *zerolog.Logger @@ -44,21 +44,19 @@ type Proxy struct { // NewOriginProxy returns a new instance of the Proxy struct. func NewOriginProxy( ingressRules ingress.Ingress, - warpRouting ingress.WarpRoutingConfig, + originDialer ingress.OriginDialer, tags []pogs.Tag, flowLimiter cfdflow.Limiter, - writeTimeout time.Duration, log *zerolog.Logger, ) *Proxy { proxy := &Proxy{ ingressRules: ingressRules, + originDialer: originDialer, tags: tags, flowLimiter: flowLimiter, log: log, } - proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout) - return proxy } @@ -146,24 +144,18 @@ func (p *Proxy) ProxyHTTP( // ProxyTCP proxies to a TCP connection between the origin service and cloudflared. func (p *Proxy) ProxyTCP( ctx context.Context, - rwa connection.ReadWriteAcker, + conn connection.ReadWriteAcker, req *connection.TCPRequest, ) error { incrementTCPRequests() defer decrementTCPConcurrentRequests() - if p.warpRouting == nil { - err := errors.New(`cloudflared received a request from WARP client, but your configuration has disabled ingress from WARP clients. To enable this, set "warp-routing:\n\t enabled: true" in your config.yaml`) - p.log.Error().Msg(err.Error()) - return err - } - logger := newTCPLogger(p.log, req) // Try to start a new flow if err := p.flowLimiter.Acquire(management.TCP.String()); err != nil { logger.Warn().Msg("Too many concurrent flows being handled, rejecting tcp proxy") - return pkgerrors.Wrap(err, "failed to start tcp flow due to rate limiting") + return errors.Wrap(err, "failed to start tcp flow due to rate limiting") } defer p.flowLimiter.Release() @@ -173,7 +165,14 @@ func (p *Proxy) ProxyTCP( tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, &logger) logger.Debug().Msg("tcp proxy stream started") - if err := p.proxyStream(tracedCtx, rwa, req.Dest, p.warpRouting.Proxy, &logger); err != nil { + // Parse the destination into a netip.AddrPort + dest, err := netip.ParseAddrPort(req.Dest) + if err != nil { + logRequestError(&logger, err) + return err + } + + if err := p.proxyTCPStream(tracedCtx, conn, dest, p.originDialer, &logger); err != nil { logRequestError(&logger, err) return err } @@ -291,14 +290,14 @@ func (p *Proxy) proxyStream( tr *tracing.TracedContext, rwa connection.ReadWriteAcker, dest string, - connectionProxy ingress.StreamBasedOriginProxy, + originDialer ingress.StreamBasedOriginProxy, logger *zerolog.Logger, ) error { ctx := tr.Context _, connectSpan := tr.Tracer().Start(ctx, "stream-connect") start := time.Now() - originConn, err := connectionProxy.EstablishConnection(ctx, dest, logger) + originConn, err := originDialer.EstablishConnection(ctx, dest, logger) if err != nil { connectStreamErrors.Inc() tracing.EndWithErrorStatus(connectSpan, err) @@ -322,6 +321,45 @@ func (p *Proxy) proxyStream( return nil } +// proxyTCPStream proxies private network type TCP connections as a stream towards an available origin. +// +// This is different than proxyStream because it's not leveraged ingress rule services and uses the +// originDialer from OriginDialerService. +func (p *Proxy) proxyTCPStream( + tr *tracing.TracedContext, + tunnelConn connection.ReadWriteAcker, + dest netip.AddrPort, + originDialer ingress.OriginTCPDialer, + logger *zerolog.Logger, +) error { + ctx := tr.Context + _, connectSpan := tr.Tracer().Start(ctx, "stream-connect") + + start := time.Now() + originConn, err := originDialer.DialTCP(ctx, dest) + if err != nil { + connectStreamErrors.Inc() + tracing.EndWithErrorStatus(connectSpan, err) + return err + } + connectSpan.End() + defer originConn.Close() + logger.Debug().Msg("origin connection established") + + encodedSpans := tr.GetSpans() + + if err := tunnelConn.AckConnection(encodedSpans); err != nil { + connectStreamErrors.Inc() + return err + } + + connectLatency.Observe(float64(time.Since(start).Milliseconds())) + logger.Debug().Msg("proxy stream acknowledged") + + stream.Pipe(tunnelConn, originConn, logger) + return nil +} + func (p *Proxy) proxyLocalRequest(proxy ingress.HTTPLocalProxy, w connection.ResponseWriter, req *http.Request, isWebsocket bool) { if isWebsocket { // These headers are added since they are stripped off during an eyeball request to origintunneld, but they diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 27a7283e..f97978db 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -34,17 +34,17 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/ingress" - "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) var ( - testTags = []pogs.Tag{{Name: "Name", Value: "value"}} - noWarpRouting = ingress.WarpRoutingConfig{} - testWarpRouting = ingress.WarpRoutingConfig{ - ConnectTimeout: config.CustomDuration{Duration: time.Second}, - } + testTags = []pogs.Tag{{Name: "Name", Value: "value"}} + testDefaultDialer = ingress.NewDialer(ingress.WarpRoutingConfig{ + ConnectTimeout: config.CustomDuration{Duration: 1 * time.Second}, + TCPKeepAlive: config.CustomDuration{Duration: 15 * time.Second}, + MaxActiveFlows: 0, + }) ) type mockHTTPRespWriter struct { @@ -163,7 +163,12 @@ func TestProxySingleOrigin(t *testing.T) { require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log) + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &log) + + proxy := NewOriginProxy(ingressRule, originDialer, testTags, cfdflow.NewLimiter(0), &log) t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxySSE", testProxySSE(proxy)) @@ -358,7 +363,7 @@ type MultipleIngressTest struct { } func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.UnvalidatedIngressRule, tests []MultipleIngressTest) { - ingress, err := ingress.ParseIngress(&config.Configuration{ + ingressRule, err := ingress.ParseIngress(&config.Configuration{ TunnelID: t.Name(), Ingress: unvalidatedIngress, }) @@ -367,9 +372,14 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat log := zerolog.Nop() ctx, cancel := context.WithCancel(t.Context()) - require.NoError(t, ingress.StartOrigins(&log, ctx.Done())) + require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done())) - proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log) + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &log) + + proxy := NewOriginProxy(ingressRule, originDialer, testTags, cfdflow.NewLimiter(0), &log) for _, test := range tests { responseWriter := newMockHTTPRespWriter() @@ -417,7 +427,12 @@ func TestProxyError(t *testing.T) { log := zerolog.Nop() - proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log) + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 1 * time.Second, + }, &log) + + proxy := NewOriginProxy(ing, originDialer, testTags, cfdflow.NewLimiter(0), &log) responseWriter := newMockHTTPRespWriter() req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) @@ -468,7 +483,7 @@ func (r *replayer) Bytes() []byte { // WS - TCP: When a tcp based ingress is configured on the origin and the // eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access). func TestConnections(t *testing.T) { - logger := logger.Create(nil) + log := zerolog.Nop() replayer := &replayer{rw: bytes.NewBuffer([]byte{})} type args struct { ingressServiceScheme string @@ -476,9 +491,6 @@ func TestConnections(t *testing.T) { eyeballResponseWriter connection.ResponseWriter eyeballRequestBody io.ReadCloser - // Can be set to nil to show warp routing is not enabled. - warpRoutingService *ingress.WarpRoutingService - // eyeball connection type. connectionType connection.Type @@ -489,6 +501,11 @@ func TestConnections(t *testing.T) { flowLimiterResponse error } + originDialer := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) + type want struct { message []byte headers http.Header @@ -531,7 +548,6 @@ func TestConnections(t *testing.T) { originService: runEchoTCPService, eyeballResponseWriter: newTCPRespWriter(replayer), eyeballRequestBody: newTCPRequestBody([]byte("test2")), - warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)), connectionType: connection.TypeTCP, requestHeaders: map[string][]string{ "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, @@ -549,7 +565,6 @@ func TestConnections(t *testing.T) { originService: runEchoWSService, // eyeballResponseWriter gets set after roundtrip dial. eyeballRequestBody: newPipedWSRequestBody([]byte("test3")), - warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)), requestHeaders: map[string][]string{ "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, }, @@ -602,23 +617,6 @@ func TestConnections(t *testing.T) { headers: map[string][]string{}, }, }, - { - name: "tcp-tcp proxy without warpRoutingService enabled", - args: args{ - ingressServiceScheme: "tcp://", - originService: runEchoTCPService, - eyeballResponseWriter: newTCPRespWriter(replayer), - eyeballRequestBody: newTCPRequestBody([]byte("test2")), - connectionType: connection.TypeTCP, - requestHeaders: map[string][]string{ - "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, - }, - }, - want: want{ - message: []byte{}, - err: true, - }, - }, { name: "ws-ws proxy when origin is different", args: args{ @@ -671,7 +669,6 @@ func TestConnections(t *testing.T) { originService: runEchoTCPService, eyeballResponseWriter: newTCPRespWriter(replayer), eyeballRequestBody: newTCPRequestBody([]byte("rate-limited")), - warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)), connectionType: connection.TypeTCP, requestHeaders: map[string][]string{ "Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, @@ -694,7 +691,7 @@ func TestConnections(t *testing.T) { test.args.originService(t, ln) ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String()) - _ = ingressRule.StartOrigins(logger, ctx.Done()) + _ = ingressRule.StartOrigins(&log, ctx.Done()) // Mock flow limiter ctrl := gomock.NewController(t) @@ -703,8 +700,7 @@ func TestConnections(t *testing.T) { flowLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.flowLimiterResponse) flowLimiter.EXPECT().Release().AnyTimes() - proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, flowLimiter, time.Duration(0), logger) - proxy.warpRouting = test.args.warpRoutingService + proxy := NewOriginProxy(ingressRule, originDialer, testTags, flowLimiter, &log) dest := ln.Addr().String() req, err := http.NewRequest( diff --git a/quic/v3/manager.go b/quic/v3/manager.go index c32cc563..d2456ff4 100644 --- a/quic/v3/manager.go +++ b/quic/v3/manager.go @@ -2,12 +2,11 @@ package v3 import ( "errors" - "net" - "net/netip" "sync" "github.com/rs/zerolog" + "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/management" cfdflow "github.com/cloudflare/cloudflared/flow" @@ -38,18 +37,16 @@ type SessionManager interface { UnregisterSession(requestID RequestID) } -type DialUDP func(dest netip.AddrPort) (*net.UDPConn, error) - type sessionManager struct { sessions map[RequestID]Session mutex sync.RWMutex - originDialer DialUDP + originDialer ingress.OriginUDPDialer limiter cfdflow.Limiter metrics Metrics log *zerolog.Logger } -func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP, limiter cfdflow.Limiter) SessionManager { +func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer ingress.OriginUDPDialer, limiter cfdflow.Limiter) SessionManager { return &sessionManager{ sessions: make(map[RequestID]Session), originDialer: originDialer, @@ -76,7 +73,7 @@ func (s *sessionManager) RegisterSession(request *UDPSessionRegistrationDatagram } // Attempt to bind the UDP socket for the new session - origin, err := s.originDialer(request.Dest) + origin, err := s.originDialer.DialUDP(request.Dest) if err != nil { return nil, err } diff --git a/quic/v3/manager_test.go b/quic/v3/manager_test.go index 759b08c6..80daf685 100644 --- a/quic/v3/manager_test.go +++ b/quic/v3/manager_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/mocks" cfdflow "github.com/cloudflare/cloudflared/flow" @@ -18,9 +19,21 @@ import ( v3 "github.com/cloudflare/cloudflared/quic/v3" ) +var ( + testDefaultDialer = ingress.NewDialer(ingress.WarpRoutingConfig{ + ConnectTimeout: config.CustomDuration{Duration: 1 * time.Second}, + TCPKeepAlive: config.CustomDuration{Duration: 15 * time.Second}, + MaxActiveFlows: 0, + }) +) + func TestRegisterSession(t *testing.T) { log := zerolog.Nop() - manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)) + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) + manager := v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)) request := v3.UDPSessionRegistrationDatagram{ RequestID: testRequestID, @@ -76,7 +89,11 @@ func TestRegisterSession(t *testing.T) { func TestGetSession_Empty(t *testing.T) { log := zerolog.Nop() - manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)) + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) + manager := v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)) _, err := manager.GetSession(testRequestID) if !errors.Is(err, v3.ErrSessionNotFound) { @@ -86,6 +103,10 @@ func TestGetSession_Empty(t *testing.T) { func TestRegisterSessionRateLimit(t *testing.T) { log := zerolog.Nop() + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) ctrl := gomock.NewController(t) flowLimiterMock := mocks.NewMockLimiter(ctrl) @@ -93,7 +114,7 @@ func TestRegisterSessionRateLimit(t *testing.T) { flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows) flowLimiterMock.EXPECT().Release().Times(0) - manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, flowLimiterMock) + manager := v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, flowLimiterMock) request := v3.UDPSessionRegistrationDatagram{ RequestID: testRequestID, diff --git a/quic/v3/muxer_test.go b/quic/v3/muxer_test.go index 555489f5..729abd3c 100644 --- a/quic/v3/muxer_test.go +++ b/quic/v3/muxer_test.go @@ -88,7 +88,11 @@ func (m *mockEyeball) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawP func TestDatagramConn_New(t *testing.T) { log := zerolog.Nop() - conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) + conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) if conn == nil { t.Fatal("expected valid connection") } @@ -96,8 +100,12 @@ func TestDatagramConn_New(t *testing.T) { func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) { log := zerolog.Nop() + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) quic := newMockQuicConn() - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) payload := []byte{0xef, 0xef} err := conn.SendUDPSessionDatagram(payload) @@ -111,8 +119,12 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) { func TestDatagramConn_SendUDPSessionResponse(t *testing.T) { log := zerolog.Nop() + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) quic := newMockQuicConn() - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable) require.NoError(t, err) @@ -133,8 +145,12 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) { func TestDatagramConnServe_ApplicationClosed(t *testing.T) { log := zerolog.Nop() + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) quic := newMockQuicConn() - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second) defer cancel() @@ -146,11 +162,15 @@ func TestDatagramConnServe_ApplicationClosed(t *testing.T) { func TestDatagramConnServe_ConnectionClosed(t *testing.T) { log := zerolog.Nop() + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) quic := newMockQuicConn() ctx, cancel := context.WithTimeout(t.Context(), 1*time.Second) defer cancel() quic.ctx = ctx - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) err := conn.Serve(t.Context()) if !errors.Is(err, context.DeadlineExceeded) { @@ -160,8 +180,12 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) { func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) { log := zerolog.Nop() + originDialerService := ingress.NewOriginDialer(ingress.OriginConfig{ + DefaultDialer: testDefaultDialer, + TCPWriteTimeout: 0, + }, &log) quic := &mockQuicConnReadError{err: net.ErrClosed} - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, originDialerService, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log) err := conn.Serve(t.Context()) if !errors.Is(err, net.ErrClosed) { diff --git a/release_pkgs.py b/release_pkgs.py index 92d38195..aa5d5dd1 100644 --- a/release_pkgs.py +++ b/release_pkgs.py @@ -346,8 +346,7 @@ def parse_args(): ) parser.add_argument( - "--deb-based-releases", default=["any", "bookworm", "bullseye", "buster", "noble", "jammy", "impish", "focal", "bionic", - "xenial", "trusty"], + "--deb-based-releases", default=["any", "bookworm", "noble", "jammy", "focal", "bionic", "xenial"], help="list of debian based releases that need to be packaged for" ) diff --git a/supervisor/pqtunnels.go b/supervisor/pqtunnels.go index 2eaad9e8..30eb2e87 100644 --- a/supervisor/pqtunnels.go +++ b/supervisor/pqtunnels.go @@ -17,8 +17,8 @@ const ( ) var ( - nonFipsPostQuantumStrictPKex []tls.CurveID = []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex} - nonFipsPostQuantumPreferPKex []tls.CurveID = []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex} + nonFipsPostQuantumStrictPKex []tls.CurveID = []tls.CurveID{X25519MLKEM768PQKex} + nonFipsPostQuantumPreferPKex []tls.CurveID = []tls.CurveID{X25519MLKEM768PQKex} fipsPostQuantumStrictPKex []tls.CurveID = []tls.CurveID{P256Kyber768Draft00PQKex} fipsPostQuantumPreferPKex []tls.CurveID = []tls.CurveID{P256Kyber768Draft00PQKex, tls.CurveP256} ) diff --git a/supervisor/pqtunnels_test.go b/supervisor/pqtunnels_test.go index 383200db..3be54460 100644 --- a/supervisor/pqtunnels_test.go +++ b/supervisor/pqtunnels_test.go @@ -2,12 +2,16 @@ package supervisor import ( "crypto/tls" + "net/http" + "net/http/httptest" + "slices" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/cloudflare/cloudflared/features" + "github.com/cloudflare/cloudflared/fips" ) func TestCurvePreferences(t *testing.T) { @@ -48,7 +52,7 @@ func TestCurvePreferences(t *testing.T) { pqMode: features.PostQuantumPrefer, fipsEnabled: false, currentCurves: []tls.CurveID{tls.CurveP256}, - expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256}, + expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, tls.CurveP256}, }, { name: "Non FIPS with Prefer PQ - no duplicates", @@ -62,14 +66,14 @@ func TestCurvePreferences(t *testing.T) { pqMode: features.PostQuantumPrefer, fipsEnabled: false, currentCurves: []tls.CurveID{tls.CurveP256, X25519Kyber768Draft00PQKex}, - expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256}, + expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, tls.CurveP256, X25519Kyber768Draft00PQKex}, }, { name: "Non FIPS with Strict PQ", pqMode: features.PostQuantumStrict, fipsEnabled: false, currentCurves: []tls.CurveID{tls.CurveP256}, - expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex}, + expectedCurves: []tls.CurveID{X25519MLKEM768PQKex}, }, } @@ -82,3 +86,34 @@ func TestCurvePreferences(t *testing.T) { }) } } + +func runClientServerHandshake(t *testing.T, curves []tls.CurveID) []tls.CurveID { + var advertisedCurves []tls.CurveID + ts := httptest.NewUnstartedServer(nil) + ts.TLS = &tls.Config{ // nolint: gosec + GetConfigForClient: func(chi *tls.ClientHelloInfo) (*tls.Config, error) { + advertisedCurves = slices.Clone(chi.SupportedCurves) + return nil, nil + }, + } + ts.StartTLS() + defer ts.Close() + clientTlsConfig := ts.Client().Transport.(*http.Transport).TLSClientConfig + clientTlsConfig.CurvePreferences = curves + resp, err := ts.Client().Head(ts.URL) + if err != nil { + t.Error(err) + return nil + } + defer resp.Body.Close() + return advertisedCurves +} + +func TestSupportedCurvesNegotiation(t *testing.T) { + for _, tcase := range []features.PostQuantumMode{features.PostQuantumPrefer} { + curves, err := curvePreference(tcase, fips.IsFipsEnabled(), make([]tls.CurveID, 0)) + require.NoError(t, err) + advertisedCurves := runClientServerHandshake(t, curves) + assert.Equal(t, curves, advertisedCurves) + } +} diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index df8bbd46..cb25d68a 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -13,7 +13,6 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" - "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/orchestration" v3 "github.com/cloudflare/cloudflared/quic/v3" "github.com/cloudflare/cloudflared/retry" @@ -78,7 +77,8 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato edgeBindAddr := config.EdgeBindAddr datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer) - sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetFlowLimiter()) + + sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, config.OriginDialerService, orchestrator.GetFlowLimiter()) edgeTunnelServer := EdgeTunnelServer{ config: config, @@ -125,6 +125,9 @@ func (s *Supervisor) Run( }() } + // Setup DNS Resolver refresh + go s.config.OriginDNSService.StartRefreshLoop(ctx) + if err := s.initialize(ctx, connectedSignal); err != nil { if err == errEarlyShutdown { return nil diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index c708c944..b73eecb9 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -24,6 +24,7 @@ import ( "github.com/cloudflare/cloudflared/features" "github.com/cloudflare/cloudflared/fips" "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/ingress/origins" "github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/orchestration" quicpogs "github.com/cloudflare/cloudflared/quic" @@ -60,10 +61,12 @@ type TunnelConfig struct { NeedPQ bool - NamedTunnel *connection.TunnelProperties - ProtocolSelector connection.ProtocolSelector - EdgeTLSConfigs map[connection.Protocol]*tls.Config - ICMPRouterServer ingress.ICMPRouterServer + NamedTunnel *connection.TunnelProperties + ProtocolSelector connection.ProtocolSelector + EdgeTLSConfigs map[connection.Protocol]*tls.Config + ICMPRouterServer ingress.ICMPRouterServer + OriginDNSService *origins.DNSResolverService + OriginDialerService *ingress.OriginDialerService RPCTimeout time.Duration WriteStreamTimeout time.Duration @@ -613,6 +616,7 @@ func (e *EdgeTunnelServer) serveQUIC( datagramSessionManager = connection.NewDatagramV2Connection( ctx, conn, + e.config.OriginDialerService, e.config.ICMPRouterServer, connIndex, e.config.RPCTimeout,