diff --git a/cmd/cloudflared/app_resolver_service.go b/cmd/cloudflared/app_resolver_service.go index 4b922d20..86155556 100644 --- a/cmd/cloudflared/app_resolver_service.go +++ b/cmd/cloudflared/app_resolver_service.go @@ -11,8 +11,9 @@ const ( // ResolverServiceType is used to identify what kind of overwatch service this is ResolverServiceType = "resolver" - LogFieldResolverAddress = "resolverAddress" - LogFieldResolverPort = "resolverPort" + LogFieldResolverAddress = "resolverAddress" + LogFieldResolverPort = "resolverPort" + LogFieldResolverMaxUpstreamConns = "resolverMaxUpstreamConns" ) // ResolverService is used to wrap the tunneldns package's DNS over HTTP @@ -57,7 +58,7 @@ func (s *ResolverService) Shutdown() { func (s *ResolverService) Run() error { // create a listener l, err := tunneldns.CreateListener(s.resolver.AddressOrDefault(), s.resolver.PortOrDefault(), - s.resolver.UpstreamsOrDefault(), s.resolver.BootstrapsOrDefault(), s.log) + s.resolver.UpstreamsOrDefault(), s.resolver.BootstrapsOrDefault(), s.resolver.MaxUpstreamConnectionsOrDefault(), s.log) if err != nil { return err } @@ -74,6 +75,7 @@ func (s *ResolverService) Run() error { resolverLog := s.log.With(). Str(LogFieldResolverAddress, s.resolver.AddressOrDefault()). Uint16(LogFieldResolverPort, s.resolver.PortOrDefault()). + Int(LogFieldResolverMaxUpstreamConns, s.resolver.MaxUpstreamConnectionsOrDefault()). Logger() resolverLog.Info().Msg("Starting resolver") diff --git a/cmd/cloudflared/config/model.go b/cmd/cloudflared/config/model.go index e1c40233..f90fd1ec 100644 --- a/cmd/cloudflared/config/model.go +++ b/cmd/cloudflared/config/model.go @@ -5,6 +5,8 @@ import ( "fmt" "io" "strings" + + "github.com/cloudflare/cloudflared/tunneldns" ) // Forwarder represents a client side listener to forward traffic to the edge @@ -25,11 +27,12 @@ type Tunnel struct { // DNSResolver represents a client side DNS resolver type DNSResolver struct { - Enabled bool `json:"enabled"` - Address string `json:"address,omitempty"` - Port uint16 `json:"port,omitempty"` - Upstreams []string `json:"upstreams,omitempty"` - Bootstraps []string `json:"bootstraps,omitempty"` + Enabled bool `json:"enabled"` + Address string `json:"address,omitempty"` + Port uint16 `json:"port,omitempty"` + Upstreams []string `json:"upstreams,omitempty"` + Bootstraps []string `json:"bootstraps,omitempty"` + MaxUpstreamConnections int `json:"max_upstream_connections,omitempty"` } // Root is the base options to configure the service @@ -59,6 +62,7 @@ func (r *DNSResolver) Hash() string { io.WriteString(h, strings.Join(r.Bootstraps, ",")) io.WriteString(h, strings.Join(r.Upstreams, ",")) io.WriteString(h, fmt.Sprintf("%d", r.Port)) + io.WriteString(h, fmt.Sprintf("%d", r.MaxUpstreamConnections)) io.WriteString(h, fmt.Sprintf("%v", r.Enabled)) return fmt.Sprintf("%x", h.Sum(nil)) } @@ -99,3 +103,11 @@ func (r *DNSResolver) BootstrapsOrDefault() []string { } return []string{"https://162.159.36.1/dns-query", "https://162.159.46.1/dns-query", "https://[2606:4700:4700::1111]/dns-query", "https://[2606:4700:4700::1001]/dns-query"} } + +// MaxUpstreamConnectionsOrDefault return the max upstream connections or returns the default if negative +func (r *DNSResolver) MaxUpstreamConnectionsOrDefault() int { + if r.MaxUpstreamConnections >= 0 { + return r.MaxUpstreamConnections + } + return tunneldns.MaxUpstreamConnsDefault +} diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index a600cb70..a8be21d7 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -1011,6 +1011,13 @@ func configureProxyDNSFlags(shouldHide bool) []cli.Flag { EnvVars: []string{"TUNNEL_DNS_UPSTREAM"}, Hidden: shouldHide, }), + altsrc.NewIntFlag(&cli.IntFlag{ + Name: "proxy-dns-max-upstream-conns", + Usage: "Maximum concurrent connections to upstream. Setting to 0 means unlimited.", + Value: tunneldns.MaxUpstreamConnsDefault, + Hidden: shouldHide, + EnvVars: []string{"TUNNEL_DNS_MAX_UPSTREAM_CONNS"}, + }), altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ Name: "proxy-dns-bootstrap", Usage: "bootstrap endpoint URL, you can specify multiple endpoints for redundancy.", diff --git a/cmd/cloudflared/tunnel/server.go b/cmd/cloudflared/tunnel/server.go index 7f6f78ea..63ffad49 100644 --- a/cmd/cloudflared/tunnel/server.go +++ b/cmd/cloudflared/tunnel/server.go @@ -1,6 +1,8 @@ package tunnel import ( + "fmt" + "github.com/cloudflare/cloudflared/tunneldns" "github.com/pkg/errors" @@ -13,7 +15,11 @@ func runDNSProxyServer(c *cli.Context, dnsReadySignal chan struct{}, shutdownC < if port <= 0 || port > 65535 { return errors.New("The 'proxy-dns-port' must be a valid port number in <1, 65535> range.") } - listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(port), c.StringSlice("proxy-dns-upstream"), c.StringSlice("proxy-dns-bootstrap"), log) + maxUpstreamConnections := c.Int("proxy-dns-max-upstream-conns") + if maxUpstreamConnections < 0 { + return fmt.Errorf("'%s' must be 0 or higher", "proxy-dns-max-upstream-conns") + } + listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(port), c.StringSlice("proxy-dns-upstream"), c.StringSlice("proxy-dns-bootstrap"), maxUpstreamConnections, log) if err != nil { close(dnsReadySignal) listener.Stop() diff --git a/tunneldns/https_upstream.go b/tunneldns/https_upstream.go index 6ea66725..c5e1c19a 100644 --- a/tunneldns/https_upstream.go +++ b/tunneldns/https_upstream.go @@ -30,12 +30,12 @@ type UpstreamHTTPS struct { } // NewUpstreamHTTPS creates a new DNS over HTTPS upstream from endpoint -func NewUpstreamHTTPS(endpoint string, bootstraps []string, log *zerolog.Logger) (Upstream, error) { +func NewUpstreamHTTPS(endpoint string, bootstraps []string, maxConnections int, log *zerolog.Logger) (Upstream, error) { u, err := url.Parse(endpoint) if err != nil { return nil, err } - return &UpstreamHTTPS{client: configureClient(u.Hostname()), endpoint: u, bootstraps: bootstraps, log: log}, nil + return &UpstreamHTTPS{client: configureClient(u.Hostname(), maxConnections), endpoint: u, bootstraps: bootstraps, log: log}, nil } // Exchange provides an implementation for the Upstream interface @@ -122,17 +122,18 @@ func configureBootstrap(bootstrap string) (*url.URL, *http.Client, error) { return nil, nil, fmt.Errorf("bootstrap address of %s must be an IP address", b.Hostname()) } - return b, configureClient(b.Hostname()), nil + return b, configureClient(b.Hostname(), MaxUpstreamConnsDefault), nil } // configureClient will configure a HTTPS client for upstream DoH requests -func configureClient(hostname string) *http.Client { +func configureClient(hostname string, maxUpstreamConnections int) *http.Client { // Update TLS and HTTP client configuration tlsConfig := &tls.Config{ServerName: hostname} transport := &http.Transport{ TLSClientConfig: tlsConfig, DisableCompression: true, MaxIdleConns: 1, + MaxConnsPerHost: maxUpstreamConnections, Proxy: http.ProxyFromEnvironment, } _ = http2.ConfigureTransport(transport) diff --git a/tunneldns/tunnel.go b/tunneldns/tunnel.go index 2c006ac8..485277aa 100644 --- a/tunneldns/tunnel.go +++ b/tunneldns/tunnel.go @@ -21,8 +21,9 @@ import ( ) const ( - LogFieldAddress = "address" - LogFieldURL = "url" + LogFieldAddress = "address" + LogFieldURL = "url" + MaxUpstreamConnsDefault = 5 ) // Listener is an adapter between CoreDNS server and Warp runnable @@ -69,6 +70,12 @@ func Command(hidden bool) *cli.Command { Value: cli.NewStringSlice("https://162.159.36.1/dns-query", "https://162.159.46.1/dns-query", "https://[2606:4700:4700::1111]/dns-query", "https://[2606:4700:4700::1001]/dns-query"), EnvVars: []string{"TUNNEL_DNS_BOOTSTRAP"}, }, + &cli.IntFlag{ + Name: "max-upstream-conns", + Usage: "Maximum concurrent connections to upstream. Setting to 0 means unlimited.", + Value: MaxUpstreamConnsDefault, + EnvVars: []string{"TUNNEL_DNS_MAX_UPSTREAM_CONNS"}, + }, }, ArgsUsage: " ", // can't be the empty string or we get the default output Hidden: hidden, @@ -92,8 +99,10 @@ func Run(c *cli.Context) error { uint16(c.Int("port")), c.StringSlice("upstream"), c.StringSlice("bootstrap"), + c.Int("max-upstream-conns"), log, ) + if err != nil { log.Err(err).Msg("Failed to create the listeners") return err @@ -175,12 +184,12 @@ func (l *Listener) Stop() error { } // CreateListener configures the server and bound sockets -func CreateListener(address string, port uint16, upstreams []string, bootstraps []string, log *zerolog.Logger) (*Listener, error) { +func CreateListener(address string, port uint16, upstreams []string, bootstraps []string, maxUpstreamConnections int, log *zerolog.Logger) (*Listener, error) { // Build the list of upstreams upstreamList := make([]Upstream, 0) for _, url := range upstreams { log.Info().Str(LogFieldURL, url).Msg("Adding DNS upstream") - upstream, err := NewUpstreamHTTPS(url, bootstraps, log) + upstream, err := NewUpstreamHTTPS(url, bootstraps, maxUpstreamConnections, log) if err != nil { return nil, errors.Wrap(err, "failed to create HTTPS upstream") }