TUN-5681: Add support for running tunnel using Token
This commit is contained in:
		
							parent
							
								
									22cd8ceb8c
								
							
						
					
					
						commit
						b6d7076400
					
				| 
						 | 
				
			
			@ -5,7 +5,7 @@ import (
 | 
			
		|||
)
 | 
			
		||||
 | 
			
		||||
type TunnelClient interface {
 | 
			
		||||
	CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
 | 
			
		||||
	CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error)
 | 
			
		||||
	GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
 | 
			
		||||
	DeleteTunnel(tunnelID uuid.UUID) error
 | 
			
		||||
	ListTunnels(filter *TunnelFilter) ([]*Tunnel, error)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -23,6 +23,11 @@ type Tunnel struct {
 | 
			
		|||
	Connections []Connection `json:"connections"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TunnelWithToken struct {
 | 
			
		||||
	Tunnel
 | 
			
		||||
	Token string `json:"token"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Connection struct {
 | 
			
		||||
	ColoName           string    `json:"colo_name"`
 | 
			
		||||
	ID                 uuid.UUID `json:"id"`
 | 
			
		||||
| 
						 | 
				
			
			@ -63,7 +68,7 @@ func (cp CleanupParams) encode() string {
 | 
			
		|||
	return cp.queryParams.Encode()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) {
 | 
			
		||||
func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*TunnelWithToken, error) {
 | 
			
		||||
	if name == "" {
 | 
			
		||||
		return nil, errors.New("tunnel name required")
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			@ -83,7 +88,11 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, er
 | 
			
		|||
 | 
			
		||||
	switch resp.StatusCode {
 | 
			
		||||
	case http.StatusOK:
 | 
			
		||||
		return unmarshalTunnel(resp.Body)
 | 
			
		||||
		var tunnel TunnelWithToken
 | 
			
		||||
		if serdeErr := parseResponse(resp.Body, &tunnel); err != nil {
 | 
			
		||||
			return nil, serdeErr
 | 
			
		||||
		}
 | 
			
		||||
		return &tunnel, nil
 | 
			
		||||
	case http.StatusConflict:
 | 
			
		||||
		return nil, ErrTunnelNameConflict
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -220,7 +220,9 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec
 | 
			
		|||
	}
 | 
			
		||||
	fmt.Println(" Keep this file secret. To revoke these credentials, delete the tunnel.")
 | 
			
		||||
	fmt.Printf("\nCreated tunnel %s with id %s\n", tunnel.Name, tunnel.ID)
 | 
			
		||||
	return tunnel, nil
 | 
			
		||||
	fmt.Printf("\nTunnel Token: %s\n", tunnel.Token)
 | 
			
		||||
 | 
			
		||||
	return &tunnel.Tunnel, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (sc *subcommandContext) list(filter *cfapi.TunnelFilter) ([]*cfapi.Tunnel, error) {
 | 
			
		||||
| 
						 | 
				
			
			@ -300,6 +302,12 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
 | 
			
		|||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return sc.runWithCredentials(credentials)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (sc *subcommandContext) runWithCredentials(credentials connection.Credentials) error {
 | 
			
		||||
	sc.log.Info().Str(LogFieldTunnelID, credentials.TunnelID.String()).Msg("Starting tunnel")
 | 
			
		||||
 | 
			
		||||
	return StartServer(
 | 
			
		||||
		sc.c,
 | 
			
		||||
		buildInfo,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,6 +2,7 @@ package tunnel
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
| 
						 | 
				
			
			@ -34,6 +35,7 @@ const (
 | 
			
		|||
	CredFileFlagAlias    = "cred-file"
 | 
			
		||||
	CredFileFlag         = "credentials-file"
 | 
			
		||||
	CredContentsFlag     = "credentials-contents"
 | 
			
		||||
	TunnelTokenFlag      = "token"
 | 
			
		||||
	overwriteDNSFlagName = "overwrite-dns"
 | 
			
		||||
 | 
			
		||||
	LogFieldTunnelID = "tunnelID"
 | 
			
		||||
| 
						 | 
				
			
			@ -118,6 +120,11 @@ var (
 | 
			
		|||
		Usage:   "Contents of the tunnel credentials JSON file to use. When provided along with credentials-file, this will take precedence.",
 | 
			
		||||
		EnvVars: []string{"TUNNEL_CRED_CONTENTS"},
 | 
			
		||||
	})
 | 
			
		||||
	tunnelTokenFlag = altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
		Name:    TunnelTokenFlag,
 | 
			
		||||
		Usage:   "The Tunnel token. When provided along with credentials, this will take precedence.",
 | 
			
		||||
		EnvVars: []string{"TUNNEL_TOKEN"},
 | 
			
		||||
	})
 | 
			
		||||
	forceDeleteFlag = &cli.BoolFlag{
 | 
			
		||||
		Name:    "force",
 | 
			
		||||
		Aliases: []string{"f"},
 | 
			
		||||
| 
						 | 
				
			
			@ -597,6 +604,7 @@ func buildRunCommand() *cli.Command {
 | 
			
		|||
		credentialsContentsFlag,
 | 
			
		||||
		selectProtocolFlag,
 | 
			
		||||
		featuresFlag,
 | 
			
		||||
		tunnelTokenFlag,
 | 
			
		||||
	}
 | 
			
		||||
	flags = append(flags, configureProxyFlags(false)...)
 | 
			
		||||
	return &cli.Command{
 | 
			
		||||
| 
						 | 
				
			
			@ -627,6 +635,21 @@ func runCommand(c *cli.Context) error {
 | 
			
		|||
	if c.NArg() > 1 {
 | 
			
		||||
		return cliutil.UsageError(`"cloudflared tunnel run" accepts only one argument, the ID or name of the tunnel to run.`)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if c.String("hostname") != "" {
 | 
			
		||||
		sc.log.Warn().Msg("The property `hostname` in your configuration is ignored because you configured a Named Tunnel " +
 | 
			
		||||
			"in the property `tunnel` to run. Make sure to provision the routing (e.g. via `cloudflared tunnel route dns/lb`) or else " +
 | 
			
		||||
			"your origin will not be reachable. You should remove the `hostname` property to avoid this warning.")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check if token is provided and if not use default tunnelID flag method
 | 
			
		||||
	if tokenStr := c.String(TunnelTokenFlag); tokenStr != "" {
 | 
			
		||||
		if token, err := parseToken(tokenStr); err == nil {
 | 
			
		||||
			return sc.runWithCredentials(token.Credentials())
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return cliutil.UsageError("Provided Tunnel token is not valid.")
 | 
			
		||||
	} else {
 | 
			
		||||
		tunnelRef := c.Args().First()
 | 
			
		||||
		if tunnelRef == "" {
 | 
			
		||||
			// see if tunnel id was in the config file
 | 
			
		||||
| 
						 | 
				
			
			@ -636,13 +659,21 @@ func runCommand(c *cli.Context) error {
 | 
			
		|||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	if c.String("hostname") != "" {
 | 
			
		||||
		sc.log.Warn().Msg("The property `hostname` in your configuration is ignored because you configured a Named Tunnel " +
 | 
			
		||||
			"in the property `tunnel` to run. Make sure to provision the routing (e.g. via `cloudflared tunnel route dns/lb`) or else " +
 | 
			
		||||
			"your origin will not be reachable. You should remove the `hostname` property to avoid this warning.")
 | 
			
		||||
		return runNamedTunnel(sc, tunnelRef)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
	return runNamedTunnel(sc, tunnelRef)
 | 
			
		||||
func parseToken(tokenStr string) (*connection.TunnelToken, error) {
 | 
			
		||||
	content, err := base64.StdEncoding.DecodeString(tokenStr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var token connection.TunnelToken
 | 
			
		||||
	if err := json.Unmarshal(content, &token); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return &token, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func runNamedTunnel(sc *subcommandContext, tunnelRef string) error {
 | 
			
		||||
| 
						 | 
				
			
			@ -650,9 +681,6 @@ func runNamedTunnel(sc *subcommandContext, tunnelRef string) error {
 | 
			
		|||
	if err != nil {
 | 
			
		||||
		return errors.Wrap(err, "error parsing tunnel ID")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sc.log.Info().Str(LogFieldTunnelID, tunnelID.String()).Msg("Starting tunnel")
 | 
			
		||||
 | 
			
		||||
	return sc.run(tunnelID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,14 +1,18 @@
 | 
			
		|||
package tunnel
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/google/uuid"
 | 
			
		||||
	homedir "github.com/mitchellh/go-homedir"
 | 
			
		||||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/cfapi"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/connection"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Test_fmtConnections(t *testing.T) {
 | 
			
		||||
| 
						 | 
				
			
			@ -177,3 +181,24 @@ func Test_validateHostname(t *testing.T) {
 | 
			
		|||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Test_TunnelToken(t *testing.T) {
 | 
			
		||||
	token, err := parseToken("aabc")
 | 
			
		||||
	require.Error(t, err)
 | 
			
		||||
	require.Nil(t, token)
 | 
			
		||||
 | 
			
		||||
	expectedToken := &connection.TunnelToken{
 | 
			
		||||
		AccountTag:   "abc",
 | 
			
		||||
		TunnelSecret: []byte("secret"),
 | 
			
		||||
		TunnelID:     uuid.New(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tokenJsonStr, err := json.Marshal(expectedToken)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	token64 := base64.StdEncoding.EncodeToString(tokenJsonStr)
 | 
			
		||||
 | 
			
		||||
	token, err = parseToken(token64)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	require.Equal(t, token, expectedToken)
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -50,6 +50,21 @@ func (c *Credentials) Auth() pogs.TunnelAuth {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TunnelToken are Credentials but encoded with custom fields namings.
 | 
			
		||||
type TunnelToken struct {
 | 
			
		||||
	AccountTag   string    `json:"a"`
 | 
			
		||||
	TunnelSecret []byte    `json:"s"`
 | 
			
		||||
	TunnelID     uuid.UUID `json:"t"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t TunnelToken) Credentials() Credentials {
 | 
			
		||||
	return Credentials{
 | 
			
		||||
		AccountTag:   t.AccountTag,
 | 
			
		||||
		TunnelSecret: t.TunnelSecret,
 | 
			
		||||
		TunnelID:     t.TunnelID,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ClassicTunnelProperties struct {
 | 
			
		||||
	Hostname   string
 | 
			
		||||
	OriginCert []byte
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue