diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 24ec3817..70527c47 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -215,7 +215,6 @@ var ( "overwrite-dns", "help", } - runQuickTunnel = RunQuickTunnel ) func Flags() []cli.Flag { @@ -287,7 +286,14 @@ See https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/in } } +// This is so that we can mock QuickTunnelRunner for TunnelCommand test cases +type QuickTunnelRunner func(*subcommandContext) error + func TunnelCommand(c *cli.Context) error { + return tunnelCommandImpl(c, RunQuickTunnel) +} + +func tunnelCommandImpl(c *cli.Context, quickTunnelRunner QuickTunnelRunner) error { sc, err := newSubcommandContext(c) if err != nil { return err @@ -316,7 +322,7 @@ func TunnelCommand(c *cli.Context) error { // We don't support running proxy-dns and a quick tunnel at the same time as the same process shouldRunQuickTunnel := c.IsSet("url") || c.IsSet("unix-socket") || c.IsSet(ingress.HelloWorldFlag) if !c.IsSet("proxy-dns") && c.String("quick-service") != "" && shouldRunQuickTunnel { - return runQuickTunnel(sc) + return quickTunnelRunner(sc) } // If user provides a config, check to see if they meant to use `tunnel run` instead diff --git a/cmd/cloudflared/tunnel/cmd_test.go b/cmd/cloudflared/tunnel/cmd_test.go index faf1de00..06c552a8 100644 --- a/cmd/cloudflared/tunnel/cmd_test.go +++ b/cmd/cloudflared/tunnel/cmd_test.go @@ -55,10 +55,8 @@ func TestShouldRunQuickTunnel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Mock RunQuickTunnel Function - originalRunQuickTunnel := runQuickTunnel - defer func() { runQuickTunnel = originalRunQuickTunnel }() mockCalled := false - runQuickTunnel = func(sc *subcommandContext) error { + mockQuickTunnelRunner := func(sc *subcommandContext) error { mockCalled = true return nil } @@ -69,7 +67,7 @@ func TestShouldRunQuickTunnel(t *testing.T) { context := cli.NewContext(app, set, nil) // Call TunnelCommand - err := TunnelCommand(context) + err := tunnelCommandImpl(context, mockQuickTunnelRunner) // Validate if tt.expectError {