TUN-5681: Add support for running tunnel using Token

This commit is contained in:
João Oliveirinha 2022-02-21 11:49:13 +00:00
parent 22cd8ceb8c
commit b6d7076400
6 changed files with 101 additions and 16 deletions

View File

@ -5,7 +5,7 @@ import (
) )
type TunnelClient interface { type TunnelClient interface {
CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error)
GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
DeleteTunnel(tunnelID uuid.UUID) error DeleteTunnel(tunnelID uuid.UUID) error
ListTunnels(filter *TunnelFilter) ([]*Tunnel, error) ListTunnels(filter *TunnelFilter) ([]*Tunnel, error)

View File

@ -23,6 +23,11 @@ type Tunnel struct {
Connections []Connection `json:"connections"` Connections []Connection `json:"connections"`
} }
type TunnelWithToken struct {
Tunnel
Token string `json:"token"`
}
type Connection struct { type Connection struct {
ColoName string `json:"colo_name"` ColoName string `json:"colo_name"`
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
@ -63,7 +68,7 @@ func (cp CleanupParams) encode() string {
return cp.queryParams.Encode() return cp.queryParams.Encode()
} }
func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) { func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error) {
if name == "" { if name == "" {
return nil, errors.New("tunnel name required") return nil, errors.New("tunnel name required")
} }
@ -83,7 +88,11 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusOK: case http.StatusOK:
return unmarshalTunnel(resp.Body) var tunnel TunnelWithToken
if serdeErr := parseResponse(resp.Body, &tunnel); err != nil {
return nil, serdeErr
}
return &tunnel, nil
case http.StatusConflict: case http.StatusConflict:
return nil, ErrTunnelNameConflict return nil, ErrTunnelNameConflict
} }

View File

@ -220,7 +220,9 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec
} }
fmt.Println(" Keep this file secret. To revoke these credentials, delete the tunnel.") fmt.Println(" Keep this file secret. To revoke these credentials, delete the tunnel.")
fmt.Printf("\nCreated tunnel %s with id %s\n", tunnel.Name, tunnel.ID) fmt.Printf("\nCreated tunnel %s with id %s\n", tunnel.Name, tunnel.ID)
return tunnel, nil fmt.Printf("\nTunnel Token: %s\n", tunnel.Token)
return &tunnel.Tunnel, nil
} }
func (sc *subcommandContext) list(filter *cfapi.TunnelFilter) ([]*cfapi.Tunnel, error) { func (sc *subcommandContext) list(filter *cfapi.TunnelFilter) ([]*cfapi.Tunnel, error) {
@ -300,6 +302,12 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
return err return err
} }
return sc.runWithCredentials(credentials)
}
func (sc *subcommandContext) runWithCredentials(credentials connection.Credentials) error {
sc.log.Info().Str(LogFieldTunnelID, credentials.TunnelID.String()).Msg("Starting tunnel")
return StartServer( return StartServer(
sc.c, sc.c,
buildInfo, buildInfo,

View File

@ -2,6 +2,7 @@ package tunnel
import ( import (
"crypto/rand" "crypto/rand"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -34,6 +35,7 @@ const (
CredFileFlagAlias = "cred-file" CredFileFlagAlias = "cred-file"
CredFileFlag = "credentials-file" CredFileFlag = "credentials-file"
CredContentsFlag = "credentials-contents" CredContentsFlag = "credentials-contents"
TunnelTokenFlag = "token"
overwriteDNSFlagName = "overwrite-dns" overwriteDNSFlagName = "overwrite-dns"
LogFieldTunnelID = "tunnelID" LogFieldTunnelID = "tunnelID"
@ -118,6 +120,11 @@ var (
Usage: "Contents of the tunnel credentials JSON file to use. When provided along with credentials-file, this will take precedence.", Usage: "Contents of the tunnel credentials JSON file to use. When provided along with credentials-file, this will take precedence.",
EnvVars: []string{"TUNNEL_CRED_CONTENTS"}, EnvVars: []string{"TUNNEL_CRED_CONTENTS"},
}) })
tunnelTokenFlag = altsrc.NewStringFlag(&cli.StringFlag{
Name: TunnelTokenFlag,
Usage: "The Tunnel token. When provided along with credentials, this will take precedence.",
EnvVars: []string{"TUNNEL_TOKEN"},
})
forceDeleteFlag = &cli.BoolFlag{ forceDeleteFlag = &cli.BoolFlag{
Name: "force", Name: "force",
Aliases: []string{"f"}, Aliases: []string{"f"},
@ -597,6 +604,7 @@ func buildRunCommand() *cli.Command {
credentialsContentsFlag, credentialsContentsFlag,
selectProtocolFlag, selectProtocolFlag,
featuresFlag, featuresFlag,
tunnelTokenFlag,
} }
flags = append(flags, configureProxyFlags(false)...) flags = append(flags, configureProxyFlags(false)...)
return &cli.Command{ return &cli.Command{
@ -627,14 +635,6 @@ func runCommand(c *cli.Context) error {
if c.NArg() > 1 { if c.NArg() > 1 {
return cliutil.UsageError(`"cloudflared tunnel run" accepts only one argument, the ID or name of the tunnel to run.`) return cliutil.UsageError(`"cloudflared tunnel run" accepts only one argument, the ID or name of the tunnel to run.`)
} }
tunnelRef := c.Args().First()
if tunnelRef == "" {
// see if tunnel id was in the config file
tunnelRef = config.GetConfiguration().TunnelID
if tunnelRef == "" {
return cliutil.UsageError(`"cloudflared tunnel run" requires the ID or name of the tunnel to run as the last command line argument or in the configuration file.`)
}
}
if c.String("hostname") != "" { if c.String("hostname") != "" {
sc.log.Warn().Msg("The property `hostname` in your configuration is ignored because you configured a Named Tunnel " + sc.log.Warn().Msg("The property `hostname` in your configuration is ignored because you configured a Named Tunnel " +
@ -642,7 +642,38 @@ func runCommand(c *cli.Context) error {
"your origin will not be reachable. You should remove the `hostname` property to avoid this warning.") "your origin will not be reachable. You should remove the `hostname` property to avoid this warning.")
} }
return runNamedTunnel(sc, tunnelRef) // Check if token is provided and if not use default tunnelID flag method
if tokenStr := c.String(TunnelTokenFlag); tokenStr != "" {
if token, err := parseToken(tokenStr); err == nil {
return sc.runWithCredentials(token.Credentials())
}
return cliutil.UsageError("Provided Tunnel token is not valid.")
} else {
tunnelRef := c.Args().First()
if tunnelRef == "" {
// see if tunnel id was in the config file
tunnelRef = config.GetConfiguration().TunnelID
if tunnelRef == "" {
return cliutil.UsageError(`"cloudflared tunnel run" requires the ID or name of the tunnel to run as the last command line argument or in the configuration file.`)
}
}
return runNamedTunnel(sc, tunnelRef)
}
}
func parseToken(tokenStr string) (*connection.TunnelToken, error) {
content, err := base64.StdEncoding.DecodeString(tokenStr)
if err != nil {
return nil, err
}
var token connection.TunnelToken
if err := json.Unmarshal(content, &token); err != nil {
return nil, err
}
return &token, nil
} }
func runNamedTunnel(sc *subcommandContext, tunnelRef string) error { func runNamedTunnel(sc *subcommandContext, tunnelRef string) error {
@ -650,9 +681,6 @@ func runNamedTunnel(sc *subcommandContext, tunnelRef string) error {
if err != nil { if err != nil {
return errors.Wrap(err, "error parsing tunnel ID") return errors.Wrap(err, "error parsing tunnel ID")
} }
sc.log.Info().Str(LogFieldTunnelID, tunnelID.String()).Msg("Starting tunnel")
return sc.run(tunnelID) return sc.run(tunnelID)
} }

View File

@ -1,14 +1,18 @@
package tunnel package tunnel
import ( import (
"encoding/base64"
"encoding/json"
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/google/uuid" "github.com/google/uuid"
homedir "github.com/mitchellh/go-homedir" homedir "github.com/mitchellh/go-homedir"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/connection"
) )
func Test_fmtConnections(t *testing.T) { func Test_fmtConnections(t *testing.T) {
@ -177,3 +181,24 @@ func Test_validateHostname(t *testing.T) {
}) })
} }
} }
func Test_TunnelToken(t *testing.T) {
token, err := parseToken("aabc")
require.Error(t, err)
require.Nil(t, token)
expectedToken := &connection.TunnelToken{
AccountTag: "abc",
TunnelSecret: []byte("secret"),
TunnelID: uuid.New(),
}
tokenJsonStr, err := json.Marshal(expectedToken)
require.NoError(t, err)
token64 := base64.StdEncoding.EncodeToString(tokenJsonStr)
token, err = parseToken(token64)
require.NoError(t, err)
require.Equal(t, token, expectedToken)
}

View File

@ -50,6 +50,21 @@ func (c *Credentials) Auth() pogs.TunnelAuth {
} }
} }
// TunnelToken are Credentials but encoded with custom fields namings.
type TunnelToken struct {
AccountTag string `json:"a"`
TunnelSecret []byte `json:"s"`
TunnelID uuid.UUID `json:"t"`
}
func (t TunnelToken) Credentials() Credentials {
return Credentials{
AccountTag: t.AccountTag,
TunnelSecret: t.TunnelSecret,
TunnelID: t.TunnelID,
}
}
type ClassicTunnelProperties struct { type ClassicTunnelProperties struct {
Hostname string Hostname string
OriginCert []byte OriginCert []byte