diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index db1cb5af..9b95c11f 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -215,6 +215,7 @@ var ( "overwrite-dns", "help", } + runQuickTunnel = RunQuickTunnel ) func Flags() []cli.Flag { @@ -313,9 +314,9 @@ func TunnelCommand(c *cli.Context) error { // Run a quick tunnel // A unauthenticated named tunnel hosted on ..com // We don't support running proxy-dns and a quick tunnel at the same time as the same process - shouldRunQuickTunnel := c.IsSet("url") || c.IsSet(ingress.HelloWorldFlag) + shouldRunQuickTunnel := c.IsSet("url") || c.IsSet("unix-socket") || c.IsSet(ingress.HelloWorldFlag) if !c.IsSet("proxy-dns") && c.String("quick-service") != "" && shouldRunQuickTunnel { - return RunQuickTunnel(sc) + return runQuickTunnel(sc) } // If user provides a config, check to see if they meant to use `tunnel run` instead diff --git a/cmd/cloudflared/tunnel/cmd_test.go b/cmd/cloudflared/tunnel/cmd_test.go index b29b3966..506ad01d 100644 --- a/cmd/cloudflared/tunnel/cmd_test.go +++ b/cmd/cloudflared/tunnel/cmd_test.go @@ -1,9 +1,12 @@ package tunnel import ( + "flag" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v2" ) func TestHostnameFromURI(t *testing.T) { @@ -15,3 +18,83 @@ func TestHostnameFromURI(t *testing.T) { assert.Equal(t, "", hostnameFromURI("trash")) assert.Equal(t, "", hostnameFromURI("https://awesomesauce.com")) } + +func TestShouldRunQuickTunnel(t *testing.T) { + tests := []struct { + name string + flags map[string]string + expectQuickTunnel bool + expectError bool + }{ + { + name: "Quick tunnel with URL set", + flags: map[string]string{"url": "http://127.0.0.1:8080", "quick-service": "https://fakeapi.trycloudflare.com"}, + expectQuickTunnel: true, + expectError: false, + }, + { + name: "Quick tunnel with unix-socket set", + flags: map[string]string{"unix-socket": "/tmp/socket", "quick-service": "https://fakeapi.trycloudflare.com"}, + expectQuickTunnel: true, + expectError: false, + }, + { + name: "Quick tunnel with hello-world flag", + flags: map[string]string{"hello-world": "true", "quick-service": "https://fakeapi.trycloudflare.com"}, + expectQuickTunnel: true, + expectError: false, + }, + { + name: "Quick tunnel with proxy-dns (invalid combo)", + flags: map[string]string{"url": "http://127.0.0.1:9090", "proxy-dns": "true", "quick-service": "https://fakeapi.trycloudflare.com"}, + expectQuickTunnel: false, + expectError: true, + }, + { + name: "No quick-service set", + flags: map[string]string{"url": "http://127.0.0.1:9090"}, + expectQuickTunnel: false, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock RunQuickTunnel Function + originalRunQuickTunnel := runQuickTunnel + defer func() { runQuickTunnel = originalRunQuickTunnel }() + mockCalled := false + runQuickTunnel = func(sc *subcommandContext) error { + mockCalled = true + return nil + } + + // Mock App Context + app := &cli.App{} + set := flagSetFromMap(tt.flags) + context := cli.NewContext(app, set, nil) + + // Call TunnelCommand + err := TunnelCommand(context) + + // Validate + if tt.expectError { + require.Error(t, err) + } else if tt.expectQuickTunnel { + assert.True(t, mockCalled) + require.NoError(t, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func flagSetFromMap(flags map[string]string) *flag.FlagSet { + set := flag.NewFlagSet("test", 0) + for key, value := range flags { + set.String(key, "", "") + set.Set(key, value) + } + return set +}