diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 1ed2289a..5afd801f 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -384,7 +384,11 @@ func StartServer( observer.RegisterSink(app) } - return waitToShutdown(&wg, cancel, errC, graceShutdownC, c.Duration("grace-period"), log) + gracePeriod, err := gracePeriod(c) + if err != nil { + return err + } + return waitToShutdown(&wg, cancel, errC, graceShutdownC, gracePeriod, log) } func waitToShutdown(wg *sync.WaitGroup, diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index dc264c31..901d192c 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/google/uuid" homedir "github.com/mitchellh/go-homedir" @@ -260,9 +261,13 @@ func prepareTunnelConfig( } originProxy := origin.NewOriginProxy(ingressRules, warpRoutingService, tags, log) + gracePeriod, err := gracePeriod(c) + if err != nil { + return nil, ingress.Ingress{}, err + } connectionConfig := &connection.Config{ OriginProxy: originProxy, - GracePeriod: c.Duration("grace-period"), + GracePeriod: gracePeriod, ReplaceExisting: c.Bool("force"), } muxerConfig := &connection.MuxerConfig{ @@ -300,6 +305,14 @@ func prepareTunnelConfig( }, ingressRules, nil } +func gracePeriod(c *cli.Context) (time.Duration, error) { + period := c.Duration("grace-period") + if period > connection.MaxGracePeriod { + return time.Duration(0), fmt.Errorf("grace-period must be equal or less than %v", connection.MaxGracePeriod) + } + return period, nil +} + func isWarpRoutingEnabled(warpConfig config.WarpRoutingConfig, isNamedTunnel bool) bool { return warpConfig.Enabled && isNamedTunnel } diff --git a/connection/connection.go b/connection/connection.go index dbe5ef1e..f061672c 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -18,6 +18,7 @@ import ( const ( lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" LogFieldConnIndex = "connIndex" + MaxGracePeriod = time.Minute * 3 ) var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))