diff --git a/cmd/cloudflared/app_resolver_service.go b/cmd/cloudflared/app_resolver_service.go index 4b922d20..86155556 100644 --- a/cmd/cloudflared/app_resolver_service.go +++ b/cmd/cloudflared/app_resolver_service.go @@ -11,8 +11,9 @@ const ( // ResolverServiceType is used to identify what kind of overwatch service this is ResolverServiceType = "resolver" - LogFieldResolverAddress = "resolverAddress" - LogFieldResolverPort = "resolverPort" + LogFieldResolverAddress = "resolverAddress" + LogFieldResolverPort = "resolverPort" + LogFieldResolverMaxUpstreamConns = "resolverMaxUpstreamConns" ) // ResolverService is used to wrap the tunneldns package's DNS over HTTP @@ -57,7 +58,7 @@ func (s *ResolverService) Shutdown() { func (s *ResolverService) Run() error { // create a listener l, err := tunneldns.CreateListener(s.resolver.AddressOrDefault(), s.resolver.PortOrDefault(), - s.resolver.UpstreamsOrDefault(), s.resolver.BootstrapsOrDefault(), s.log) + s.resolver.UpstreamsOrDefault(), s.resolver.BootstrapsOrDefault(), s.resolver.MaxUpstreamConnectionsOrDefault(), s.log) if err != nil { return err } @@ -74,6 +75,7 @@ func (s *ResolverService) Run() error { resolverLog := s.log.With(). Str(LogFieldResolverAddress, s.resolver.AddressOrDefault()). Uint16(LogFieldResolverPort, s.resolver.PortOrDefault()). + Int(LogFieldResolverMaxUpstreamConns, s.resolver.MaxUpstreamConnectionsOrDefault()). Logger() resolverLog.Info().Msg("Starting resolver") diff --git a/cmd/cloudflared/config/model.go b/cmd/cloudflared/config/model.go index e1c40233..ae7c9f89 100644 --- a/cmd/cloudflared/config/model.go +++ b/cmd/cloudflared/config/model.go @@ -30,6 +30,7 @@ type DNSResolver struct { Port uint16 `json:"port,omitempty"` Upstreams []string `json:"upstreams,omitempty"` Bootstraps []string `json:"bootstraps,omitempty"` + MaxUpstreamConnections int `json:"max_upstream_connections,omitempty"` } // Root is the base options to configure the service @@ -59,6 +60,7 @@ func (r *DNSResolver) Hash() string { io.WriteString(h, strings.Join(r.Bootstraps, ",")) io.WriteString(h, strings.Join(r.Upstreams, ",")) io.WriteString(h, fmt.Sprintf("%d", r.Port)) + io.WriteString(h, fmt.Sprintf("%d", r.MaxUpstreamConnections)) io.WriteString(h, fmt.Sprintf("%v", r.Enabled)) return fmt.Sprintf("%x", h.Sum(nil)) } @@ -99,3 +101,11 @@ func (r *DNSResolver) BootstrapsOrDefault() []string { } return []string{"https://162.159.36.1/dns-query", "https://162.159.46.1/dns-query", "https://[2606:4700:4700::1111]/dns-query", "https://[2606:4700:4700::1001]/dns-query"} } + +// MaxUpstreamConnectionsOrDefault return the max upstream connections or returns the default if 0 +func (r *DNSResolver) MaxUpstreamConnectionsOrDefault() int { + if r.MaxUpstreamConnections >= 0 { + return r.MaxUpstreamConnections + } + return 0 +} diff --git a/cmd/cloudflared/tunnel/server.go b/cmd/cloudflared/tunnel/server.go index b1f36222..c5b38e50 100644 --- a/cmd/cloudflared/tunnel/server.go +++ b/cmd/cloudflared/tunnel/server.go @@ -13,7 +13,11 @@ func runDNSProxyServer(c *cli.Context, dnsReadySignal, shutdownC chan struct{}, if port <= 0 || port > 65535 { return errors.New("The 'proxy-dns-port' must be a valid port number in <1, 65535> range.") } - listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(port), c.StringSlice("proxy-dns-upstream"), c.StringSlice("proxy-dns-bootstrap"), log) + maxUpstreamConnections := c.Int("proxy-dns-max-upstream-conns") + if maxUpstreamConnections < 0 { + return errors.New("'proxy-dns-max-upstream-conns' must be 0 or higher") + } + listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(port), c.StringSlice("proxy-dns-upstream"), c.StringSlice("proxy-dns-bootstrap"), maxUpstreamConnections, log) if err != nil { close(dnsReadySignal) listener.Stop() diff --git a/tunneldns/https_upstream.go b/tunneldns/https_upstream.go index 6ea66725..979f988a 100644 --- a/tunneldns/https_upstream.go +++ b/tunneldns/https_upstream.go @@ -23,19 +23,20 @@ const ( // UpstreamHTTPS is the upstream implementation for DNS over HTTPS service type UpstreamHTTPS struct { - client *http.Client - endpoint *url.URL - bootstraps []string - log *zerolog.Logger + client *http.Client + endpoint *url.URL + bootstraps []string + maxConnections int + log *zerolog.Logger } // NewUpstreamHTTPS creates a new DNS over HTTPS upstream from endpoint -func NewUpstreamHTTPS(endpoint string, bootstraps []string, log *zerolog.Logger) (Upstream, error) { +func NewUpstreamHTTPS(endpoint string, bootstraps []string, maxConnections int, log *zerolog.Logger) (Upstream, error) { u, err := url.Parse(endpoint) if err != nil { return nil, err } - return &UpstreamHTTPS{client: configureClient(u.Hostname()), endpoint: u, bootstraps: bootstraps, log: log}, nil + return &UpstreamHTTPS{client: configureClient(u.Hostname(), maxConnections), endpoint: u, bootstraps: bootstraps, log: log}, nil } // Exchange provides an implementation for the Upstream interface @@ -122,17 +123,18 @@ func configureBootstrap(bootstrap string) (*url.URL, *http.Client, error) { return nil, nil, fmt.Errorf("bootstrap address of %s must be an IP address", b.Hostname()) } - return b, configureClient(b.Hostname()), nil + return b, configureClient(b.Hostname(), 0), nil } // configureClient will configure a HTTPS client for upstream DoH requests -func configureClient(hostname string) *http.Client { +func configureClient(hostname string, maxUpstreamConnections int) *http.Client { // Update TLS and HTTP client configuration tlsConfig := &tls.Config{ServerName: hostname} transport := &http.Transport{ TLSClientConfig: tlsConfig, DisableCompression: true, MaxIdleConns: 1, + MaxConnsPerHost: maxUpstreamConnections, Proxy: http.ProxyFromEnvironment, } _ = http2.ConfigureTransport(transport) diff --git a/tunneldns/tunnel.go b/tunneldns/tunnel.go index 315ad1fc..78f66b85 100644 --- a/tunneldns/tunnel.go +++ b/tunneldns/tunnel.go @@ -68,6 +68,12 @@ func Command(hidden bool) *cli.Command { Value: cli.NewStringSlice("https://162.159.36.1/dns-query", "https://162.159.46.1/dns-query", "https://[2606:4700:4700::1111]/dns-query", "https://[2606:4700:4700::1001]/dns-query"), EnvVars: []string{"TUNNEL_DNS_BOOTSTRAP"}, }, + &cli.IntFlag{ + Name: "max-upstream-conns", + Usage: "Maximum concurrent connections to upstream, unlimited by default", + Value: 0, + EnvVars: []string{"TUNNEL_DNS_MAX_UPSTREAM_CONNS"}, + }, }, ArgsUsage: " ", // can't be the empty string or we get the default output Hidden: hidden, @@ -90,8 +96,10 @@ func Run(c *cli.Context) error { uint16(c.Uint("port")), c.StringSlice("upstream"), c.StringSlice("bootstrap"), + c.Int("max-upstream-conns"), log, ) + if err != nil { log.Err(err).Msg("Failed to create the listeners") return err @@ -173,12 +181,12 @@ func (l *Listener) Stop() error { } // CreateListener configures the server and bound sockets -func CreateListener(address string, port uint16, upstreams []string, bootstraps []string, log *zerolog.Logger) (*Listener, error) { +func CreateListener(address string, port uint16, upstreams []string, bootstraps []string, maxUpstreamConnections int, log *zerolog.Logger) (*Listener, error) { // Build the list of upstreams upstreamList := make([]Upstream, 0) for _, url := range upstreams { log.Info().Str(LogFieldURL, url).Msg("Adding DNS upstream") - upstream, err := NewUpstreamHTTPS(url, bootstraps, log) + upstream, err := NewUpstreamHTTPS(url, bootstraps, maxUpstreamConnections, log) if err != nil { return nil, errors.Wrap(err, "failed to create HTTPS upstream") }