diff --git a/cmd/cloudflared/configuration.go b/cmd/cloudflared/configuration.go
new file mode 100644
index 00000000..d0e926e2
--- /dev/null
+++ b/cmd/cloudflared/configuration.go
@@ -0,0 +1,297 @@
+package main
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "encoding/hex"
+ "fmt"
+ "io/ioutil"
+ "math/rand"
+ "net"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/cloudflare/cloudflared/origin"
+ "github.com/cloudflare/cloudflared/tlsconfig"
+ tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
+ "github.com/cloudflare/cloudflared/validation"
+
+ "github.com/sirupsen/logrus"
+ "gopkg.in/urfave/cli.v2"
+ "gopkg.in/urfave/cli.v2/altsrc"
+
+ "github.com/mitchellh/go-homedir"
+ "github.com/pkg/errors"
+)
+
+var (
+ defaultConfigFiles = []string{"config.yml", "config.yaml"}
+
+ // Launchd doesn't set root env variables, so there is default
+ // Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
+ defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp", "/usr/local/etc/cloudflared", "/etc/cloudflared"}
+)
+
+const defaultCredentialFile = "cert.pem"
+
+func fileExists(path string) (bool, error) {
+ f, err := os.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ // ignore missing files
+ return false, nil
+ }
+ return false, err
+ }
+ f.Close()
+ return true, nil
+}
+
+// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs
+// (differs by OS for legacy reasons) contains a cert.pem file, return empty string
+func findDefaultOriginCertPath() string {
+ for _, defaultConfigDir := range defaultConfigDirs {
+ originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, defaultCredentialFile))
+ if ok, _ := fileExists(originCertPath); ok {
+ return originCertPath
+ }
+ }
+ return ""
+}
+
+// returns the first path that contains a config file. If none of the combination of
+// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles
+// contains a config file, return empty string
+func findDefaultConfigPath() string {
+ for _, configDir := range defaultConfigDirs {
+ for _, configFile := range defaultConfigFiles {
+ dirPath, err := homedir.Expand(configDir)
+ if err != nil {
+ continue
+ }
+ path := filepath.Join(dirPath, configFile)
+ if ok, _ := fileExists(path); ok {
+ return path
+ }
+ }
+ }
+ return ""
+}
+
+func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
+ if context.String("config") != "" {
+ return altsrc.NewYamlSourceFromFile(context.String("config"))
+ }
+ return nil, nil
+}
+
+func generateRandomClientID() string {
+ r := rand.New(rand.NewSource(time.Now().UnixNano()))
+ id := make([]byte, 32)
+ r.Read(id)
+ return hex.EncodeToString(id)
+}
+
+func enoughOptionsSet(c *cli.Context) bool {
+ // For cloudflared to work, the user needs to at least provide a hostname,
+ // or runs as stand alone DNS proxy .
+ // When using sudo, use -E flag to preserve env vars
+ if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" && os.Getenv("TUNNEL_DNS") == "" {
+ if isRunningFromTerminal() {
+ logger.Errorf("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl)
+ logger.Infof("If you want to run Argo Tunnel client as a stand alone DNS proxy, run with --proxy-dns option or set TUNNEL_DNS environment variable.")
+ } else {
+ logger.Errorf("You need to specify all the options in a configuration file, or use environment variables. See %s and %s", serviceUrl, argumentsUrl)
+ logger.Infof("If you want to run Argo Tunnel client as a stand alone DNS proxy, specify proxy-dns option in the configuration file, or set TUNNEL_DNS environment variable.")
+ }
+ cli.ShowAppHelp(c)
+ return false
+ }
+ return true
+}
+
+func handleDeprecatedOptions(c *cli.Context) {
+ // Fail if the user provided an old authentication method
+ if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
+ logger.Fatal("You don't need to give us your api-key anymore. Please use the new login method. Just run cloudflared login")
+ }
+}
+
+// validate url. It can be either from --url or argument
+func validateUrl(c *cli.Context) (string, error) {
+ var url = c.String("url")
+ if c.NArg() > 0 {
+ if c.IsSet("url") {
+ return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
+ }
+ url = c.Args().Get(0)
+ }
+ validUrl, err := validation.ValidateUrl(url)
+ return validUrl, err
+}
+
+func logClientOptions(c *cli.Context) {
+ flags := make(map[string]interface{})
+ for _, flag := range c.LocalFlagNames() {
+ flags[flag] = c.Generic(flag)
+ }
+ if len(flags) > 0 {
+ logger.Infof("Flags %v", flags)
+ }
+
+ envs := make(map[string]string)
+ // Find env variables for Argo Tunnel
+ for _, env := range os.Environ() {
+ // All Argo Tunnel env variables start with TUNNEL_
+ if strings.Contains(env, "TUNNEL_") {
+ vars := strings.Split(env, "=")
+ if len(vars) == 2 {
+ envs[vars[0]] = vars[1]
+ }
+ }
+ }
+ if len(envs) > 0 {
+ logger.Infof("Environmental variables %v", envs)
+ }
+}
+
+func dnsProxyStandAlone(c *cli.Context) bool {
+ return c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world"))
+}
+
+func getOriginCert(c *cli.Context) []byte {
+ if c.String("origincert") == "" {
+ logger.Warnf("Cannot determine default origin certificate path. No file %s in %v", defaultCredentialFile, defaultConfigDirs)
+ if isRunningFromTerminal() {
+ logger.Fatalf("You need to specify the origin certificate path with --origincert option, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", argumentsUrl)
+ } else {
+ logger.Fatalf("You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", serviceUrl)
+ }
+ }
+ // Check that the user has acquired a certificate using the login command
+ originCertPath, err := homedir.Expand(c.String("origincert"))
+ if err != nil {
+ logger.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
+ }
+ ok, err := fileExists(originCertPath)
+ if err != nil {
+ logger.Fatalf("Cannot check if origin cert exists at path %s", c.String("origincert"))
+ }
+ if !ok {
+ logger.Fatalf(`Cannot find a valid certificate for your origin at the path:
+
+ %s
+
+If the path above is wrong, specify the path with the -origincert option.
+If you don't have a certificate signed by Cloudflare, run the command:
+
+ %s login
+`, originCertPath, os.Args[0])
+ }
+ // Easier to send the certificate as []byte via RPC than decoding it at this point
+ originCert, err := ioutil.ReadFile(originCertPath)
+ if err != nil {
+ logger.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
+ }
+ return originCert
+}
+
+func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, logger, protoLogger *logrus.Logger) *origin.TunnelConfig {
+ hostname, err := validation.ValidateHostname(c.String("hostname"))
+ if err != nil {
+ logger.WithError(err).Fatal("Invalid hostname")
+ }
+ clientID := c.String("id")
+ if !c.IsSet("id") {
+ clientID = generateRandomClientID()
+ }
+
+ tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
+ if err != nil {
+ logger.WithError(err).Fatal("Tag parse failure")
+ }
+
+ tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
+
+ url, err := validateUrl(c)
+ if err != nil {
+ logger.WithError(err).Fatal("Error validating url")
+ }
+ logger.Infof("Proxying tunnel requests to %s", url)
+
+ originCert := getOriginCert(c)
+
+ originCertPool, err := loadCertPool(c)
+ if err != nil {
+ logger.Fatal(err)
+ }
+
+ tunnelMetrics := origin.NewTunnelMetrics()
+ httpTransport := &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ DialContext: (&net.Dialer{
+ Timeout: c.Duration("proxy-connect-timeout"),
+ KeepAlive: c.Duration("proxy-tcp-keepalive"),
+ DualStack: !c.Bool("proxy-no-happy-eyeballs"),
+ }).DialContext,
+ MaxIdleConns: c.Int("proxy-keepalive-connections"),
+ IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
+ TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
+ ExpectContinueTimeout: 1 * time.Second,
+ TLSClientConfig: &tls.Config{RootCAs: originCertPool},
+ }
+
+ if !c.IsSet("hello-world") && c.IsSet("origin-server-name") {
+ httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
+ }
+
+ return &origin.TunnelConfig{
+ EdgeAddrs: c.StringSlice("edge"),
+ OriginUrl: url,
+ Hostname: hostname,
+ OriginCert: originCert,
+ TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
+ ClientTlsConfig: httpTransport.TLSClientConfig,
+ Retries: c.Uint("retries"),
+ HeartbeatInterval: c.Duration("heartbeat-interval"),
+ MaxHeartbeats: c.Uint64("heartbeat-count"),
+ ClientID: clientID,
+ BuildInfo: buildInfo,
+ ReportedVersion: Version,
+ LBPool: c.String("lb-pool"),
+ Tags: tags,
+ HAConnections: c.Int("ha-connections"),
+ HTTPTransport: httpTransport,
+ Metrics: tunnelMetrics,
+ MetricsUpdateFreq: c.Duration("metrics-update-freq"),
+ ProtocolLogger: protoLogger,
+ Logger: logger,
+ IsAutoupdated: c.Bool("is-autoupdated"),
+ GracePeriod: c.Duration("grace-period"),
+ RunFromTerminal: isRunningFromTerminal(),
+ }
+}
+
+func loadCertPool(c *cli.Context) (*x509.CertPool, error) {
+ const originCAPoolFlag = "origin-ca-pool"
+ originCAPoolFilename := c.String(originCAPoolFlag)
+ var originCustomCAPool []byte
+
+ if originCAPoolFilename != "" {
+ var err error
+ originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename)
+ if err != nil {
+ return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, originCAPoolFlag))
+ }
+ }
+
+ originCertPool, err := tlsconfig.LoadOriginCertPool(originCustomCAPool)
+ if err != nil {
+ return nil, errors.Wrap(err, "error loading the certificate pool")
+ }
+
+ return originCertPool, nil
+}
diff --git a/cmd/cloudflared/generic_service.go b/cmd/cloudflared/generic_service.go
index 3eb4819e..a2fbf494 100644
--- a/cmd/cloudflared/generic_service.go
+++ b/cmd/cloudflared/generic_service.go
@@ -8,6 +8,6 @@ import (
cli "gopkg.in/urfave/cli.v2"
)
-func runApp(app *cli.App) {
+func runApp(app *cli.App, shutdownC chan struct{}) {
app.Run(os.Args)
}
diff --git a/cmd/cloudflared/hello.go b/cmd/cloudflared/hello.go
index b2cd332d..55f77269 100644
--- a/cmd/cloudflared/hello.go
+++ b/cmd/cloudflared/hello.go
@@ -1,204 +1,21 @@
package main
import (
- "bytes"
- "crypto/tls"
- "encoding/json"
"fmt"
- "html/template"
- "io/ioutil"
- "net"
- "net/http"
- "os"
- "time"
- "github.com/gorilla/websocket"
"gopkg.in/urfave/cli.v2"
- "github.com/cloudflare/cloudflared/tlsconfig"
+ "github.com/cloudflare/cloudflared/hello"
)
-type templateData struct {
- ServerName string
- Request *http.Request
- Body string
-}
-type OriginUpTime struct {
- StartTime time.Time `json:"startTime"`
- UpTime string `json:"uptime"`
-}
-
-const defaultServerName = "the Argo Tunnel test server"
-const indexTemplate = `
-
-
-
-
-
-
- Argo Tunnel Connection
-
-
-
-
-
-
-
-
-
-
-
Congrats! You created your first tunnel!
-
- Argo Tunnel exposes locally running applications to the internet by
- running an encrypted, virtual tunnel from your laptop or server to
- Cloudflare's edge network.
-
-
-
-`
-
-func hello(c *cli.Context) error {
+func helloWorld(c *cli.Context) error {
address := fmt.Sprintf(":%d", c.Int("port"))
- listener, err := createListener(address)
+ listener, err := hello.CreateTLSListener(address)
if err != nil {
return err
}
defer listener.Close()
- err = startHelloWorldServer(listener, nil)
+ err = hello.StartHelloWorldServer(logger, listener, nil)
return err
}
-
-func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
- logger.Infof("Starting Hello World server at %s", listener.Addr())
- serverName := defaultServerName
- if hostname, err := os.Hostname(); err == nil {
- serverName = hostname
- }
-
- upgrader := websocket.Upgrader{
- ReadBufferSize: 1024,
- WriteBufferSize: 1024,
- }
-
- httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
- go func() {
- <-shutdownC
- httpServer.Close()
- }()
-
- http.HandleFunc("/uptime", uptimeHandler(time.Now()))
- http.HandleFunc("/ws", websocketHandler(upgrader))
- http.HandleFunc("/", rootHandler(serverName))
- err := httpServer.Serve(listener)
- return err
-}
-
-func createListener(address string) (net.Listener, error) {
- certificate, err := tlsconfig.GetHelloCertificate()
- if err != nil {
- return nil, err
- }
-
- // If the port in address is empty, a port number is automatically chosen
- listener, err := tls.Listen(
- "tcp",
- address,
- &tls.Config{Certificates: []tls.Certificate{certificate}})
-
- return listener, err
-}
-
-func uptimeHandler(startTime time.Time) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- // Note that if autoupdate is enabled, the uptime is reset when a new client
- // release is available
- resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()}
- respJson, err := json.Marshal(resp)
- if err != nil {
- w.WriteHeader(http.StatusInternalServerError)
- } else {
- w.Header().Set("Content-Type", "application/json")
- w.Write(respJson)
- }
- }
-}
-
-func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- conn, err := upgrader.Upgrade(w, r, nil)
- if err != nil {
- return
- }
- defer conn.Close()
-
- for {
- mt, message, err := conn.ReadMessage()
- if err != nil {
- break
- }
-
- if err := conn.WriteMessage(mt, message); err != nil {
- break
- }
- }
- }
-}
-
-func rootHandler(serverName string) http.HandlerFunc {
- responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
- return func(w http.ResponseWriter, r *http.Request) {
- var buffer bytes.Buffer
- var body string
- rawBody, err := ioutil.ReadAll(r.Body)
- if err == nil {
- body = string(rawBody)
- } else {
- body = ""
- }
- err = responseTemplate.Execute(&buffer, &templateData{
- ServerName: serverName,
- Request: r,
- Body: body,
- })
- if err != nil {
- w.WriteHeader(http.StatusInternalServerError)
- fmt.Fprintf(w, "error: %v", err)
- } else {
- buffer.WriteTo(w)
- }
- }
-}
diff --git a/cmd/cloudflared/linux_service.go b/cmd/cloudflared/linux_service.go
index 4e859617..19185171 100644
--- a/cmd/cloudflared/linux_service.go
+++ b/cmd/cloudflared/linux_service.go
@@ -10,7 +10,7 @@ import (
cli "gopkg.in/urfave/cli.v2"
)
-func runApp(app *cli.App) {
+func runApp(app *cli.App, shutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel system service",
@@ -183,9 +183,9 @@ func installLinuxService(c *cli.Context) error {
defaultConfigDir := filepath.Dir(c.String("config"))
defaultConfigFile := filepath.Base(c.String("config"))
- if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile); err != nil {
+ if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile, defaultCredentialFile); err != nil {
logger.WithError(err).Infof("Failed to copy user configuration. Before running the service, ensure that %s contains two files, %s and %s",
- serviceConfigDir, credentialFile, defaultConfigFiles[0])
+ serviceConfigDir, defaultCredentialFile, defaultConfigFiles[0])
return err
}
diff --git a/cmd/cloudflared/logger.go b/cmd/cloudflared/logger.go
new file mode 100644
index 00000000..74a0b221
--- /dev/null
+++ b/cmd/cloudflared/logger.go
@@ -0,0 +1,65 @@
+package main
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/cloudflare/cloudflared/log"
+
+ "github.com/rifflock/lfshook"
+ "github.com/sirupsen/logrus"
+ "gopkg.in/urfave/cli.v2"
+
+ "github.com/mitchellh/go-homedir"
+ "github.com/pkg/errors"
+)
+
+var logger = log.CreateLogger()
+
+func configMainLogger(c *cli.Context) {
+ logLevel, err := logrus.ParseLevel(c.String("loglevel"))
+ if err != nil {
+ logger.WithError(err).Fatal("Unknown logging level specified")
+ }
+ logger.SetLevel(logLevel)
+}
+
+func configProtoLogger(c *cli.Context) *logrus.Logger {
+ protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
+ if err != nil {
+ logger.WithError(err).Fatal("Unknown protocol logging level specified")
+ }
+ protoLogger := logrus.New()
+ protoLogger.Level = protoLogLevel
+ return protoLogger
+}
+
+func initLogFile(c *cli.Context, loggers ...*logrus.Logger) error {
+ filePath, err := homedir.Expand(c.String("logfile"))
+ if err != nil {
+ return errors.Wrap(err, "Cannot resolve logfile path")
+ }
+
+ fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC
+ // do not truncate log file if the client has been autoupdated
+ if c.Bool("is-autoupdated") {
+ fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE
+ }
+ f, err := os.OpenFile(filePath, fileMode, 0664)
+ if err != nil {
+ errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath))
+ }
+ defer f.Close()
+ pathMap := lfshook.PathMap{
+ logrus.InfoLevel: filePath,
+ logrus.ErrorLevel: filePath,
+ logrus.FatalLevel: filePath,
+ logrus.PanicLevel: filePath,
+ }
+
+ for _, l := range loggers {
+ l.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
+ }
+
+ return nil
+}
diff --git a/cmd/cloudflared/login.go b/cmd/cloudflared/login.go
index 1ba234ec..bbffa29e 100644
--- a/cmd/cloudflared/login.go
+++ b/cmd/cloudflared/login.go
@@ -35,7 +35,7 @@ func login(c *cli.Context) error {
if err != nil {
return err
}
- path := filepath.Join(configPath, credentialFile)
+ path := filepath.Join(configPath, defaultCredentialFile)
fileInfo, err := os.Stat(path)
if err == nil && fileInfo.Size() > 0 {
fmt.Fprintf(os.Stderr, `You have an existing certificate at %s which login would overwrite.
diff --git a/cmd/cloudflared/macos_service.go b/cmd/cloudflared/macos_service.go
index 37737c57..30d703c6 100644
--- a/cmd/cloudflared/macos_service.go
+++ b/cmd/cloudflared/macos_service.go
@@ -13,7 +13,7 @@ const (
launchdIdentifier = "com.cloudflare.cloudflared"
)
-func runApp(app *cli.App) {
+func runApp(app *cli.App, shutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel launch agent",
@@ -91,12 +91,12 @@ func stderrPath() string {
func installLaunchd(c *cli.Context) error {
if isRootUser() {
logger.Infof("Installing Argo Tunnel client as a system launch daemon. " +
- "Argo Tunnel client will run at boot")
+ "Argo Tunnel client will run at boot")
} else {
logger.Infof("Installing Argo Tunnel client as an user launch agent. " +
- "Note that Argo Tunnel client will only run when the user is logged in. " +
- "If you want to run Argo Tunnel client at boot, install with root permission. " +
- "For more information, visit https://developers.cloudflare.com/argo-tunnel/reference/service/")
+ "Note that Argo Tunnel client will only run when the user is logged in. " +
+ "If you want to run Argo Tunnel client at boot, install with root permission. " +
+ "For more information, visit https://developers.cloudflare.com/argo-tunnel/reference/service/")
}
etPath, err := os.Executable()
if err != nil {
@@ -120,7 +120,6 @@ func installLaunchd(c *cli.Context) error {
}
func uninstallLaunchd(c *cli.Context) error {
-
if isRootUser() {
logger.Infof("Uninstalling Argo Tunnel as a system launch daemon")
} else {
diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go
index c0b6ebb1..da17fcca 100644
--- a/cmd/cloudflared/main.go
+++ b/cmd/cloudflared/main.go
@@ -1,69 +1,47 @@
package main
import (
- "crypto/tls"
- "encoding/hex"
"fmt"
- "io/ioutil"
- "math/rand"
- "net"
- "net/http"
"os"
- "os/signal"
- "path/filepath"
- "strings"
"sync"
- "syscall"
"time"
- "github.com/cloudflare/cloudflared/log"
+ "github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/origin"
- "github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns"
- tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
- "github.com/cloudflare/cloudflared/validation"
- "github.com/facebookgo/grace/gracenet"
"github.com/getsentry/raven-go"
"github.com/mitchellh/go-homedir"
- "github.com/rifflock/lfshook"
- "github.com/sirupsen/logrus"
- "golang.org/x/crypto/ssh/terminal"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
"github.com/coreos/go-systemd/daemon"
- "github.com/pkg/errors"
+ "github.com/facebookgo/grace/gracenet"
)
const (
- sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
- credentialFile = "cert.pem"
- quickStartUrl = "https://developers.cloudflare.com/argo-tunnel/quickstart/quickstart/"
- noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/"
- licenseUrl = "https://developers.cloudflare.com/argo-tunnel/licence/"
+ sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
+ developerPortal = "https://developers.cloudflare.com/argo-tunnel"
+ quickStartUrl = developerPortal + "/quickstart/quickstart/"
+ serviceUrl = developerPortal + "/reference/service/"
+ argumentsUrl = developerPortal + "/reference/arguments/"
+ licenseUrl = developerPortal + "/licence/"
)
-var listeners = gracenet.Net{}
-var Version = "DEV"
-var BuildTime = "unknown"
-var logger = log.CreateLogger()
-var defaultConfigFiles = []string{"config.yml", "config.yaml"}
-
-// Launchd doesn't set root env variables, so there is default
-// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
-var defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp", "/usr/local/etc/cloudflared", "/etc/cloudflared"}
-
-// Shutdown channel used by the app. When closed, app must terminate.
-// May be closed by the Windows service runner.
-var shutdownC chan struct{}
+var (
+ Version = "DEV"
+ BuildTime = "unknown"
+)
func main() {
metrics.RegisterBuildInfo(BuildTime, Version)
raven.SetDSN(sentryDSN)
raven.SetRelease(Version)
- shutdownC = make(chan struct{})
+
+ // Shutdown channel used by the app. When closed, app must terminate.
+ // May be closed by the Windows service runner.
+ shutdownC := make(chan struct{})
app := &cli.App{}
app.Name = "cloudflared"
@@ -119,6 +97,11 @@ func main() {
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: findDefaultOriginCertPath(),
}),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "origin-ca-pool",
+ Usage: "Path to the CA for the certificate of your origin. This option should be used only if your certificate is not signed by Cloudflare.",
+ EnvVars: []string{"TUNNEL_ORIGIN_CA_POOL"},
+ }),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "url",
Value: "https://localhost:8080",
@@ -293,10 +276,13 @@ func main() {
}),
}
app.Action = func(c *cli.Context) error {
- raven.CapturePanic(func() { startServer(c) }, nil)
+ raven.CapturePanic(func() { startServer(c, shutdownC) }, nil)
return nil
}
app.Before = func(context *cli.Context) error {
+ if context.String("config") == "" {
+ logger.Warnf("Cannot determine default configuration path. No file %v in %v", defaultConfigFiles, defaultConfigDirs)
+ }
inputSource, err := findInputSourceContext(context)
if err != nil {
logger.WithError(err).Infof("Cannot load configuration from %s", context.String("config"))
@@ -337,7 +323,7 @@ func main() {
},
{
Name: "hello",
- Action: hello,
+ Action: helloWorld,
Usage: "Run a simple \"Hello World\" server for testing Argo Tunnel.",
Flags: []cli.Flag{
&cli.IntFlag{
@@ -381,233 +367,130 @@ func main() {
ArgsUsage: " ", // can't be the empty string or we get the default output
},
}
- runApp(app)
+ runApp(app, shutdownC)
}
-func startServer(c *cli.Context) {
+func startServer(c *cli.Context, shutdownC chan struct{}) {
var wg sync.WaitGroup
+ listeners := gracenet.Net{}
errC := make(chan error)
connectedSignal := make(chan struct{})
dnsReadySignal := make(chan struct{})
+ graceShutdownSignal := make(chan struct{})
- // If the user choose to supply all options through env variables,
- // c.NumFlags() == 0 && c.NArg() == 0. For cloudflared to work, the user needs to at
- // least provide a hostname.
- if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" {
- logger.Infof("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl)
- cli.ShowAppHelp(c)
+ // check whether client provides enough flags or env variables. If not, print help.
+ if ok := enoughOptionsSet(c); !ok {
return
}
- logLevel, err := logrus.ParseLevel(c.String("loglevel"))
- if err != nil {
- logger.WithError(err).Fatal("Unknown logging level specified")
- }
- logger.SetLevel(logLevel)
-
- protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
- if err != nil {
- logger.WithError(err).Fatal("Unknown protocol logging level specified")
- }
- protoLogger := logrus.New()
- protoLogger.Level = protoLogLevel
+ configMainLogger(c)
+ protoLogger := configProtoLogger(c)
if c.String("logfile") != "" {
- if err := initLogFile(c, protoLogger); err != nil {
+ if err := initLogFile(c, logger, protoLogger); err != nil {
logger.Error(err)
}
}
+ handleDeprecatedOptions(c)
+
buildInfo := origin.GetBuildInfo()
logger.Infof("Build info: %+v", *buildInfo)
logger.Infof("Version %s", Version)
logClientOptions(c)
if c.IsSet("proxy-dns") {
- port := c.Int("proxy-dns-port")
- if port <= 0 || port > 65535 {
- logger.Fatal("The 'proxy-dns-port' must be a valid port number in <1, 65535> range.")
- }
wg.Add(1)
- listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(port), c.StringSlice("proxy-dns-upstream"))
- if err != nil {
- close(dnsReadySignal)
- listener.Stop()
- logger.WithError(err).Fatal("Cannot create the DNS over HTTPS proxy server")
- }
go func() {
- err := listener.Start(dnsReadySignal)
- if err != nil {
- logger.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server")
- } else {
- <-shutdownC
- }
- listener.Stop()
- wg.Done()
+ defer wg.Done()
+ runDNSProxyServer(c, dnsReadySignal, shutdownC)
}()
} else {
close(dnsReadySignal)
}
- isRunningFromTerminal := isRunningFromTerminal()
- if isAutoupdateEnabled(c, isRunningFromTerminal) {
- // Wait for proxy-dns to come up (if used)
- <-dnsReadySignal
- if initUpdate() {
+ // Wait for proxy-dns to come up (if used)
+ <-dnsReadySignal
+
+ // update needs to be after DNS proxy is up to resolve equinox server address
+ if isAutoupdateEnabled(c) {
+ if initUpdate(&listeners) {
return
}
logger.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
- go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
+ go autoupdate(c.Duration("autoupdate-freq"), &listeners, shutdownC)
}
- // Serve DNS proxy stand-alone if no hostname or tag or app is going to run
- if c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world")) {
- go writePidFile(connectedSignal, c.String("pidfile"))
- close(connectedSignal)
- runServer(c, &wg, errC, shutdownC)
- return
- }
-
- hostname, err := validation.ValidateHostname(c.String("hostname"))
- if err != nil {
- logger.WithError(err).Fatal("Invalid hostname")
- }
- clientID := c.String("id")
- if !c.IsSet("id") {
- clientID = generateRandomClientID()
- }
-
- tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
- if err != nil {
- logger.WithError(err).Fatal("Tag parse failure")
- }
-
- tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
- if c.IsSet("hello-world") {
- wg.Add(1)
- listener, err := createListener("127.0.0.1:")
- if err != nil {
- listener.Close()
- logger.WithError(err).Fatal("Cannot start Hello World Server")
- }
- go func() {
- startHelloWorldServer(listener, shutdownC)
- wg.Done()
- listener.Close()
- }()
- c.Set("url", "https://"+listener.Addr().String())
- }
-
- url, err := validateUrl(c)
- if err != nil {
- logger.WithError(err).Fatal("Error validating url")
- }
- logger.Infof("Proxying tunnel requests to %s", url)
-
- // Fail if the user provided an old authentication method
- if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
- logger.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflared login")
- }
-
- // Check that the user has acquired a certificate using the log in command
- originCertPath, err := homedir.Expand(c.String("origincert"))
- if err != nil {
- logger.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
- }
- ok, err := fileExists(originCertPath)
- if err != nil {
- logger.Fatalf("Cannot check if origin cert exists at path %s", c.String("origincert"))
- }
- if !ok {
- logger.Fatalf(`Cannot find a valid certificate for your origin at the path:
-
- %s
-
-If the path above is wrong, specify the path with the -origincert option.
-If you don't have a certificate signed by Cloudflare, run the command:
-
- %s login
-`, originCertPath, os.Args[0])
- }
- // Easier to send the certificate as []byte via RPC than decoding it at this point
- originCert, err := ioutil.ReadFile(originCertPath)
- if err != nil {
- logger.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
- }
-
- tunnelMetrics := origin.NewTunnelMetrics()
- httpTransport := &http.Transport{
- Proxy: http.ProxyFromEnvironment,
- DialContext: (&net.Dialer{
- Timeout: c.Duration("proxy-connect-timeout"),
- KeepAlive: c.Duration("proxy-tcp-keepalive"),
- DualStack: !c.Bool("proxy-no-happy-eyeballs"),
- }).DialContext,
- MaxIdleConns: c.Int("proxy-keepalive-connections"),
- IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
- TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
- ExpectContinueTimeout: 1 * time.Second,
- TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
- }
-
- if !c.IsSet("hello-world") && c.IsSet("origin-server-name") {
- httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
- }
-
- tunnelConfig := &origin.TunnelConfig{
- EdgeAddrs: c.StringSlice("edge"),
- OriginUrl: url,
- Hostname: hostname,
- OriginCert: originCert,
- TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
- ClientTlsConfig: httpTransport.TLSClientConfig,
- Retries: c.Uint("retries"),
- HeartbeatInterval: c.Duration("heartbeat-interval"),
- MaxHeartbeats: c.Uint64("heartbeat-count"),
- ClientID: clientID,
- BuildInfo: buildInfo,
- ReportedVersion: Version,
- LBPool: c.String("lb-pool"),
- Tags: tags,
- HAConnections: c.Int("ha-connections"),
- HTTPTransport: httpTransport,
- Metrics: tunnelMetrics,
- MetricsUpdateFreq: c.Duration("metrics-update-freq"),
- ProtocolLogger: protoLogger,
- Logger: logger,
- IsAutoupdated: c.Bool("is-autoupdated"),
- GracePeriod: c.Duration("grace-period"),
- RunFromTerminal: isRunningFromTerminal,
- }
-
- go writePidFile(connectedSignal, c.String("pidfile"))
- wg.Add(1)
- go func() {
- errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
- wg.Done()
- }()
-
- runServer(c, &wg, errC, shutdownC)
-}
-
-func runServer(c *cli.Context, wg *sync.WaitGroup, errC chan error, shutdownC chan struct{}) {
- wg.Add(1)
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
if err != nil {
logger.WithError(err).Fatal("Error opening metrics server listener")
}
+ defer metricsListener.Close()
+ wg.Add(1)
go func() {
+ defer wg.Done()
errC <- metrics.ServeMetrics(metricsListener, shutdownC, logger)
- wg.Done()
}()
+ // Serve DNS proxy stand-alone if no hostname or tag or app is going to run
+ if dnsProxyStandAlone(c) {
+ if c.IsSet("pidfile") {
+ go writePidFile(connectedSignal, c.String("pidfile"))
+ close(connectedSignal)
+ }
+ // no grace period, handle SIGINT/SIGTERM immediately
+ waitToShutdown(&wg, errC, shutdownC, graceShutdownSignal, 0)
+ return
+ }
+
+ if c.IsSet("hello-world") {
+ helloListener, err := hello.CreateTLSListener("127.0.0.1:")
+ if err != nil {
+ logger.WithError(err).Fatal("Cannot start Hello World Server")
+ }
+ defer helloListener.Close()
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ hello.StartHelloWorldServer(logger, helloListener, shutdownC)
+ }()
+ c.Set("url", "https://"+helloListener.Addr().String())
+ }
+
+ tunnelConfig := prepareTunnelConfig(c, buildInfo, logger, protoLogger)
+
+ if c.IsSet("pidFile") {
+ go writePidFile(connectedSignal, c.String("pidfile"))
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ errC <- origin.StartTunnelDaemon(tunnelConfig, graceShutdownSignal, connectedSignal)
+ }()
+
+ waitToShutdown(&wg, errC, shutdownC, graceShutdownSignal, c.Duration("grace-period"))
+}
+
+func waitToShutdown(wg *sync.WaitGroup,
+ errC chan error,
+ shutdownC, graceShutdownSignal chan struct{},
+ gracePeriod time.Duration,
+) {
+ var err error
+ if gracePeriod > 0 {
+ err = waitForSignalWithGraceShutdown(errC, shutdownC, graceShutdownSignal, gracePeriod)
+ } else {
+ err = waitForSignal(errC, shutdownC)
+ close(graceShutdownSignal)
+ }
+
var errCode int
- err = WaitForSignal(errC, shutdownC)
if err != nil {
logger.WithError(err).Fatal("Quitting due to error")
raven.CaptureErrorAndWait(err, nil)
errCode = 1
} else {
- logger.Info("Graceful shutdown...")
+ logger.Info("Quitting...")
}
// Wait for clean exit, discarding all errors
go func() {
@@ -618,126 +501,6 @@ func runServer(c *cli.Context, wg *sync.WaitGroup, errC chan error, shutdownC ch
os.Exit(errCode)
}
-func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
- signals := make(chan os.Signal, 10)
- signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
- defer signal.Stop(signals)
-
- select {
- case err := <-errC:
- close(shutdownC)
- return err
- case <-signals:
- close(shutdownC)
- case <-shutdownC:
- }
-
- return nil
-}
-
-func update(_ *cli.Context) error {
- if updateApplied() {
- os.Exit(64)
- }
- return nil
-}
-
-func initUpdate() bool {
- if updateApplied() {
- os.Args = append(os.Args, "--is-autoupdated=true")
- if _, err := listeners.StartProcess(); err != nil {
- logger.WithError(err).Error("Unable to restart server automatically")
- return false
- }
- return true
- }
- return false
-}
-
-func autoupdate(freq time.Duration, shutdownC chan struct{}) {
- for {
- if updateApplied() {
- os.Args = append(os.Args, "--is-autoupdated=true")
- if _, err := listeners.StartProcess(); err != nil {
- logger.WithError(err).Error("Unable to restart server automatically")
- }
- close(shutdownC)
- return
- }
- time.Sleep(freq)
- }
-}
-
-func updateApplied() bool {
- releaseInfo := checkForUpdates()
- if releaseInfo.Updated {
- logger.Infof("Updated to version %s", releaseInfo.Version)
- return true
- }
- if releaseInfo.Error != nil {
- logger.WithError(releaseInfo.Error).Error("Update check failed")
- }
- return false
-}
-
-func fileExists(path string) (bool, error) {
- f, err := os.Open(path)
- if err != nil {
- if os.IsNotExist(err) {
- // ignore missing files
- return false, nil
- }
- return false, err
- }
- f.Close()
- return true, nil
-}
-
-// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs
-// (differs by OS for legacy reasons) contains a cert.pem file, return empty string
-func findDefaultOriginCertPath() string {
- for _, defaultConfigDir := range defaultConfigDirs {
- originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, credentialFile))
- if ok, _ := fileExists(originCertPath); ok {
- return originCertPath
- }
- }
- return ""
-}
-
-// returns the firt path that contains a config file. If none of the combination of
-// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles
-// contains a config file, return empty string
-func findDefaultConfigPath() string {
- for _, configDir := range defaultConfigDirs {
- for _, configFile := range defaultConfigFiles {
- dirPath, err := homedir.Expand(configDir)
- if err != nil {
- continue
- }
- path := filepath.Join(dirPath, configFile)
- if ok, _ := fileExists(path); ok {
- return path
- }
- }
- }
- return ""
-}
-
-func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
- if context.String("config") != "" {
- return altsrc.NewYamlSourceFromFile(context.String("config"))
- }
- return nil, nil
-}
-
-func generateRandomClientID() string {
- r := rand.New(rand.NewSource(time.Now().UnixNano()))
- id := make([]byte, 32)
- r.Read(id)
- return hex.EncodeToString(id)
-}
-
func writePidFile(waitForSignal chan struct{}, pidFile string) {
<-waitForSignal
daemon.SdNotify(false, "READY=1")
@@ -752,87 +515,6 @@ func writePidFile(waitForSignal chan struct{}, pidFile string) {
fmt.Fprintf(file, "%d", os.Getpid())
}
-// validate url. It can be either from --url or argument
-func validateUrl(c *cli.Context) (string, error) {
- var url = c.String("url")
- if c.NArg() > 0 {
- if c.IsSet("url") {
- return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
- }
- url = c.Args().Get(0)
- }
- validUrl, err := validation.ValidateUrl(url)
- return validUrl, err
-}
-
-func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error {
- filePath, err := homedir.Expand(c.String("logfile"))
- if err != nil {
- return errors.Wrap(err, "Cannot resolve logfile path")
- }
-
- fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC
- // do not truncate log file if the client has been autoupdated
- if c.Bool("is-autoupdated") {
- fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE
- }
- f, err := os.OpenFile(filePath, fileMode, 0664)
- if err != nil {
- errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath))
- }
- defer f.Close()
- pathMap := lfshook.PathMap{
- logrus.InfoLevel: filePath,
- logrus.ErrorLevel: filePath,
- logrus.FatalLevel: filePath,
- logrus.PanicLevel: filePath,
- }
-
- logger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
- protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
-
- return nil
-}
-
-func logClientOptions(c *cli.Context) {
- flags := make(map[string]interface{})
- for _, flag := range c.LocalFlagNames() {
- flags[flag] = c.Generic(flag)
- }
- if len(flags) > 0 {
- logger.Infof("Flags %v", flags)
- }
-
- envs := make(map[string]string)
- // Find env variables for Argo Tunnel
- for _, env := range os.Environ() {
- // All Argo Tunnel env variables start with TUNNEL_
- if strings.Contains(env, "TUNNEL_") {
- vars := strings.Split(env, "=")
- if len(vars) == 2 {
- envs[vars[0]] = vars[1]
- }
- }
- }
- if len(envs) > 0 {
- logger.Infof("Environmental variables %v", envs)
- }
-}
-
-func isAutoupdateEnabled(c *cli.Context, isRunningFromTerminal bool) bool {
- if isRunningFromTerminal {
- logger.Info(noAutoupdateMessage)
- return false
- }
-
- return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0
-}
-
-
-func isRunningFromTerminal() bool {
- return terminal.IsTerminal(int(os.Stdout.Fd()))
-}
-
func userHomeDir() string {
// This returns the home dir of the executing user using OS-specific method
// for discovering the home dir. It's not recommended to call this function
diff --git a/cmd/cloudflared/server.go b/cmd/cloudflared/server.go
new file mode 100644
index 00000000..30e65b5f
--- /dev/null
+++ b/cmd/cloudflared/server.go
@@ -0,0 +1,27 @@
+package main
+
+import (
+ "github.com/cloudflare/cloudflared/tunneldns"
+
+ "gopkg.in/urfave/cli.v2"
+)
+
+func runDNSProxyServer(c *cli.Context, dnsReadySignal, shutdownC chan struct{}) {
+ port := c.Int("proxy-dns-port")
+ if port <= 0 || port > 65535 {
+ logger.Fatal("The 'proxy-dns-port' must be a valid port number in <1, 65535> range.")
+ }
+ listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(port), c.StringSlice("proxy-dns-upstream"))
+ if err != nil {
+ close(dnsReadySignal)
+ listener.Stop()
+ logger.WithError(err).Fatal("Cannot create the DNS over HTTPS proxy server")
+ }
+
+ err = listener.Start(dnsReadySignal)
+ if err != nil {
+ logger.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server")
+ }
+ <-shutdownC
+ listener.Stop()
+}
diff --git a/cmd/cloudflared/service_template.go b/cmd/cloudflared/service_template.go
index afec0fd7..63e47bbd 100644
--- a/cmd/cloudflared/service_template.go
+++ b/cmd/cloudflared/service_template.go
@@ -119,7 +119,7 @@ func openFile(path string, create bool) (file *os.File, exists bool, err error)
return file, false, err
}
-func copyCertificate(srcConfigDir, destConfigDir string) error {
+func copyCertificate(srcConfigDir, destConfigDir, credentialFile string) error {
destCredentialPath := filepath.Join(destConfigDir, credentialFile)
destFile, exists, err := openFile(destCredentialPath, true)
if err != nil {
@@ -146,12 +146,12 @@ func copyCertificate(srcConfigDir, destConfigDir string) error {
return nil
}
-func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile string) error {
+func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile, defaultCredentialFile string) error {
if err := ensureConfigDirExists(serviceConfigDir); err != nil {
return err
}
- if err := copyCertificate(defaultConfigDir, serviceConfigDir); err != nil {
+ if err := copyCertificate(defaultConfigDir, serviceConfigDir, defaultCredentialFile); err != nil {
return err
}
diff --git a/cmd/cloudflared/signal.go b/cmd/cloudflared/signal.go
new file mode 100644
index 00000000..0953e7e9
--- /dev/null
+++ b/cmd/cloudflared/signal.go
@@ -0,0 +1,54 @@
+package main
+
+import (
+ "os"
+ "os/signal"
+ "syscall"
+ "time"
+)
+
+func waitForSignal(errC chan error, shutdownC chan struct{}) error {
+ signals := make(chan os.Signal, 10)
+ signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
+ defer signal.Stop(signals)
+
+ select {
+ case err := <-errC:
+ close(shutdownC)
+ return err
+ case <-signals:
+ close(shutdownC)
+ case <-shutdownC:
+ }
+ return nil
+}
+
+func waitForSignalWithGraceShutdown(errC chan error, shutdownC, graceShutdownSignal chan struct{}, gracePeriod time.Duration) error {
+ signals := make(chan os.Signal, 10)
+ signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
+ defer signal.Stop(signals)
+
+ select {
+ case err := <-errC:
+ close(graceShutdownSignal)
+ close(shutdownC)
+ return err
+ case <-signals:
+ close(graceShutdownSignal)
+ logger.Infof("Initiating graceful shutdown...")
+ // Unregister signal handler early, so the client can send a second SIGTERM/SIGINT
+ // to force shutdown cloudflared
+ signal.Stop(signals)
+ graceTimerTick := time.Tick(gracePeriod)
+ // send close signal via shutdownC when grace period expires or when an
+ // error is encountered.
+ select {
+ case <-graceTimerTick:
+ case <-errC:
+ }
+ close(shutdownC)
+ case <-shutdownC:
+ }
+
+ return nil
+}
diff --git a/cmd/cloudflared/signal_test.go b/cmd/cloudflared/signal_test.go
new file mode 100644
index 00000000..a56edd32
--- /dev/null
+++ b/cmd/cloudflared/signal_test.go
@@ -0,0 +1,131 @@
+package main
+
+import (
+ "fmt"
+ "syscall"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+const tick = 100 * time.Millisecond
+
+var (
+ serverErr = fmt.Errorf("server error")
+ shutdownErr = fmt.Errorf("receive shutdown")
+ graceShutdownErr = fmt.Errorf("receive grace shutdown")
+)
+
+func testChannelClosed(t *testing.T, c chan struct{}) {
+ select {
+ case <-c:
+ return
+ default:
+ t.Fatal("Channel should be readable")
+ }
+}
+
+func TestWaitForSignal(t *testing.T) {
+ // Test handling server error
+ errC := make(chan error)
+ shutdownC := make(chan struct{})
+
+ go func() {
+ errC <- serverErr
+ }()
+
+ err := waitForSignal(errC, shutdownC)
+ assert.Equal(t, serverErr, err)
+ testChannelClosed(t, shutdownC)
+
+ // Test handling SIGTERM & SIGINT
+ for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
+ errC = make(chan error)
+ shutdownC = make(chan struct{})
+
+ go func(shutdownC chan struct{}) {
+ <-shutdownC
+ errC <- shutdownErr
+ }(shutdownC)
+
+ go func(sig syscall.Signal) {
+ // sleep for a tick to prevent sending signal before calling waitForSignal
+ time.Sleep(tick)
+ syscall.Kill(syscall.Getpid(), sig)
+ }(sig)
+
+ err = waitForSignal(errC, shutdownC)
+ assert.Equal(t, nil, err)
+ assert.Equal(t, shutdownErr, <-errC)
+ testChannelClosed(t, shutdownC)
+ }
+}
+
+func TestWaitForSignalWithGraceShutdown(t *testing.T) {
+ // Test server returning error
+ errC := make(chan error)
+ shutdownC := make(chan struct{})
+ graceshutdownC := make(chan struct{})
+
+ go func() {
+ errC <- serverErr
+ }()
+
+ err := waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
+ assert.Equal(t, serverErr, err)
+ testChannelClosed(t, shutdownC)
+ testChannelClosed(t, graceshutdownC)
+
+ // Test handling SIGTERM & SIGINT
+ for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
+ //var wg sync.WaitGroup
+ errC := make(chan error)
+ shutdownC = make(chan struct{})
+ graceshutdownC = make(chan struct{})
+
+ go func(shutdownC, graceshutdownC chan struct{}) {
+ <-graceshutdownC
+ <-shutdownC
+ errC <- graceShutdownErr
+ }(shutdownC, graceshutdownC)
+
+ go func(sig syscall.Signal) {
+ // sleep for a tick to prevent sending signal before calling waitForSignalWithGraceShutdown
+ time.Sleep(tick)
+ syscall.Kill(syscall.Getpid(), sig)
+ }(sig)
+
+ err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
+ assert.Equal(t, nil, err)
+ assert.Equal(t, graceShutdownErr, <-errC)
+ testChannelClosed(t, shutdownC)
+ testChannelClosed(t, graceshutdownC)
+ }
+
+ // Test handling SIGTERM & SIGINT, server send error before end of grace period
+ for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
+ errC := make(chan error)
+ shutdownC = make(chan struct{})
+ graceshutdownC = make(chan struct{})
+
+ go func(shutdownC, graceshutdownC chan struct{}) {
+ <-graceshutdownC
+ errC <- graceShutdownErr
+ <-shutdownC
+ errC <- shutdownErr
+ }(shutdownC, graceshutdownC)
+
+ go func(sig syscall.Signal) {
+ // sleep for a tick to prevent sending signal before calling waitForSignalWithGraceShutdown
+ time.Sleep(tick)
+ syscall.Kill(syscall.Getpid(), sig)
+ }(sig)
+
+ err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
+ assert.Equal(t, nil, err)
+ assert.Equal(t, shutdownErr, <-errC)
+ testChannelClosed(t, shutdownC)
+ testChannelClosed(t, graceshutdownC)
+ }
+}
diff --git a/cmd/cloudflared/update.go b/cmd/cloudflared/update.go
index ccce2e03..5be63f6f 100644
--- a/cmd/cloudflared/update.go
+++ b/cmd/cloudflared/update.go
@@ -1,8 +1,20 @@
package main
-import "github.com/equinox-io/equinox"
+import (
+ "os"
+ "time"
-const appID = "app_idCzgxYerVD"
+ "golang.org/x/crypto/ssh/terminal"
+ "gopkg.in/urfave/cli.v2"
+
+ "github.com/equinox-io/equinox"
+ "github.com/facebookgo/grace/gracenet"
+)
+
+const (
+ appID = "app_idCzgxYerVD"
+ noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/"
+)
var publicKey = []byte(`
-----BEGIN ECDSA PUBLIC KEY-----
@@ -39,3 +51,61 @@ func checkForUpdates() ReleaseInfo {
return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion}
}
+
+func update(_ *cli.Context) error {
+ if updateApplied() {
+ os.Exit(64)
+ }
+ return nil
+}
+
+func initUpdate(listeners *gracenet.Net) bool {
+ if updateApplied() {
+ os.Args = append(os.Args, "--is-autoupdated=true")
+ if _, err := listeners.StartProcess(); err != nil {
+ logger.WithError(err).Error("Unable to restart server automatically")
+ return false
+ }
+ return true
+ }
+ return false
+}
+
+func autoupdate(freq time.Duration, listeners *gracenet.Net, shutdownC chan struct{}) {
+ for {
+ if updateApplied() {
+ os.Args = append(os.Args, "--is-autoupdated=true")
+ if _, err := listeners.StartProcess(); err != nil {
+ logger.WithError(err).Error("Unable to restart server automatically")
+ }
+ close(shutdownC)
+ return
+ }
+ time.Sleep(freq)
+ }
+}
+
+func updateApplied() bool {
+ releaseInfo := checkForUpdates()
+ if releaseInfo.Updated {
+ logger.Infof("Updated to version %s", releaseInfo.Version)
+ return true
+ }
+ if releaseInfo.Error != nil {
+ logger.WithError(releaseInfo.Error).Error("Update check failed")
+ }
+ return false
+}
+
+func isAutoupdateEnabled(c *cli.Context) bool {
+ if isRunningFromTerminal() {
+ logger.Info(noAutoupdateMessage)
+ return false
+ }
+
+ return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0
+}
+
+func isRunningFromTerminal() bool {
+ return terminal.IsTerminal(int(os.Stdout.Fd()))
+}
diff --git a/cmd/cloudflared/windows_service.go b/cmd/cloudflared/windows_service.go
index f1442863..a30c26bc 100644
--- a/cmd/cloudflared/windows_service.go
+++ b/cmd/cloudflared/windows_service.go
@@ -21,7 +21,7 @@ const (
windowsServiceDescription = "Argo Tunnel agent"
)
-func runApp(app *cli.App) {
+func runApp(app *cli.App, shutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel Windows service",
@@ -59,7 +59,7 @@ func runApp(app *cli.App) {
elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName))
// Run executes service name by calling windowsService which is a Handler
// interface that implements Execute method
- err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog})
+ err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog, shutdownC: shutdownC})
if err != nil {
elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
return
@@ -68,8 +68,9 @@ func runApp(app *cli.App) {
}
type windowsService struct {
- app *cli.App
- elog *eventlog.Log
+ app *cli.App
+ elog *eventlog.Log
+ shutdownC chan struct{}
}
// called by the package code at the start of the service
@@ -98,7 +99,7 @@ loop:
}
}
}
- close(shutdownC)
+ close(s.shutdownC)
changes <- svc.Status{State: svc.StopPending}
return
}
diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go
index 49480cf4..8e566b5c 100644
--- a/h2mux/h2mux_test.go
+++ b/h2mux/h2mux_test.go
@@ -135,7 +135,7 @@ func TestSingleStream(t *testing.T) {
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
}
stream.WriteHeaders([]Header{
- Header{Name: "response-header", Value: "responseValue"},
+ {Name: "response-header", Value: "responseValue"},
})
buf := []byte("Hello world")
stream.Write(buf)
@@ -153,7 +153,7 @@ func TestSingleStream(t *testing.T) {
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
- []Header{Header{Name: "test-header", Value: "headerValue"}},
+ []Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != nil {
@@ -194,6 +194,7 @@ func TestSingleStream(t *testing.T) {
func TestSingleStreamLargeResponseBody(t *testing.T) {
muxPair := NewDefaultMuxerPair()
bodySize := 1 << 24
+ streamReady := make(chan struct{})
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
if len(stream.Headers) != 1 {
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
@@ -205,25 +206,30 @@ func TestSingleStreamLargeResponseBody(t *testing.T) {
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
}
stream.WriteHeaders([]Header{
- Header{Name: "response-header", Value: "responseValue"},
+ {Name: "response-header", Value: "responseValue"},
})
payload := make([]byte, bodySize)
for i := range payload {
payload[i] = byte(i % 256)
}
+ t.Log("Writing payload...")
n, err := stream.Write(payload)
+ t.Logf("Wrote %d bytes into the stream", n)
if err != nil {
t.Fatalf("origin write error: %s", err)
}
if n != len(payload) {
t.Fatalf("origin short write: %d/%d bytes", n, len(payload))
}
+ t.Log("Payload written; signaling that the stream is ready")
+ streamReady <- struct{}{}
+
return nil
})
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
- []Header{Header{Name: "test-header", Value: "headerValue"}},
+ []Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != nil {
@@ -239,6 +245,10 @@ func TestSingleStreamLargeResponseBody(t *testing.T) {
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
}
responseBody := make([]byte, bodySize)
+
+ <-streamReady
+ t.Log("Received stream ready signal; resuming the test")
+
n, err := io.ReadFull(stream, responseBody)
if err != nil {
t.Fatalf("error from (*MuxedStream).Read: %s", err)
@@ -261,7 +271,7 @@ func TestMultipleStreams(t *testing.T) {
}
log.Debugf("Got request for stream %s", stream.Headers[0].Value)
stream.WriteHeaders([]Header{
- Header{Name: "response-token", Value: stream.Headers[0].Value},
+ {Name: "response-token", Value: stream.Headers[0].Value},
})
log.Debugf("Wrote headers for stream %s", stream.Headers[0].Value)
stream.Write([]byte("OK"))
@@ -277,7 +287,7 @@ func TestMultipleStreams(t *testing.T) {
defer wg.Done()
tokenString := fmt.Sprintf("%d", tokenId)
stream, err := muxPair.EdgeMux.OpenStream(
- []Header{Header{Name: "client-token", Value: tokenString}},
+ []Header{{Name: "client-token", Value: tokenString}},
nil,
)
log.Debugf("Got headers for stream %d", tokenId)
@@ -328,6 +338,7 @@ func TestMultipleStreams(t *testing.T) {
func TestMultipleStreamsFlowControl(t *testing.T) {
maxStreams := 32
errorsC := make(chan error, maxStreams)
+ streamReady := make(chan struct{})
responseSizes := make([]int32, maxStreams)
for i := 0; i < maxStreams; i++ {
responseSizes[i] = rand.Int31n(int32(defaultWindowSize << 4))
@@ -344,13 +355,14 @@ func TestMultipleStreamsFlowControl(t *testing.T) {
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
}
stream.WriteHeaders([]Header{
- Header{Name: "response-header", Value: "responseValue"},
+ {Name: "response-header", Value: "responseValue"},
})
payload := make([]byte, responseSizes[(stream.streamID-2)/2])
for i := range payload {
payload[i] = byte(i % 256)
}
n, err := stream.Write(payload)
+ streamReady <- struct{}{}
if err != nil {
t.Fatalf("origin write error: %s", err)
}
@@ -367,7 +379,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) {
go func(tokenId int) {
defer wg.Done()
stream, err := muxPair.EdgeMux.OpenStream(
- []Header{Header{Name: "test-header", Value: "headerValue"}},
+ []Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != nil {
@@ -387,6 +399,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) {
return
}
+ <-streamReady
responseBody := make([]byte, responseSizes[(stream.streamID-2)/2])
n, err := io.ReadFull(stream, responseBody)
if err != nil {
@@ -417,7 +430,7 @@ func TestGracefulShutdown(t *testing.T) {
muxPair := NewDefaultMuxerPair()
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
stream.WriteHeaders([]Header{
- Header{Name: "response-header", Value: "responseValue"},
+ {Name: "response-header", Value: "responseValue"},
})
<-sendC
log.Debugf("Writing %d bytes", len(responseBuf))
@@ -436,7 +449,7 @@ func TestGracefulShutdown(t *testing.T) {
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
- []Header{Header{Name: "test-header", Value: "headerValue"}},
+ []Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
// Start graceful shutdown of the edge mux - this should also close the origin mux when done
@@ -469,7 +482,7 @@ func TestUnexpectedShutdown(t *testing.T) {
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
defer close(handlerFinishC)
stream.WriteHeaders([]Header{
- Header{Name: "response-header", Value: "responseValue"},
+ {Name: "response-header", Value: "responseValue"},
})
<-sendC
n, err := stream.Read([]byte{0})
@@ -490,7 +503,7 @@ func TestUnexpectedShutdown(t *testing.T) {
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
- []Header{Header{Name: "test-header", Value: "headerValue"}},
+ []Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
// Close the underlying connection before telling the origin to write.
@@ -552,7 +565,7 @@ func TestOpenAfterDisconnect(t *testing.T) {
}
_, err := muxPair.EdgeMux.OpenStream(
- []Header{Header{Name: "test-header", Value: "headerValue"}},
+ []Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != ErrConnectionClosed {
diff --git a/hello/hello.go b/hello/hello.go
new file mode 100644
index 00000000..7c9149cd
--- /dev/null
+++ b/hello/hello.go
@@ -0,0 +1,197 @@
+package hello
+
+import (
+ "bytes"
+ "crypto/tls"
+ "encoding/json"
+ "fmt"
+ "html/template"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "os"
+ "time"
+
+ "github.com/gorilla/websocket"
+ "github.com/sirupsen/logrus"
+
+ "github.com/cloudflare/cloudflared/tlsconfig"
+)
+
+type templateData struct {
+ ServerName string
+ Request *http.Request
+ Body string
+}
+
+type OriginUpTime struct {
+ StartTime time.Time `json:"startTime"`
+ UpTime string `json:"uptime"`
+}
+
+const defaultServerName = "the Argo Tunnel test server"
+const indexTemplate = `
+
+
+
+
+
+
+ Argo Tunnel Connection
+
+
+
+
+
+
+
+
+
+
+
Congrats! You created your first tunnel!
+
+ Argo Tunnel exposes locally running applications to the internet by
+ running an encrypted, virtual tunnel from your laptop or server to
+ Cloudflare's edge network.
+