From 9135a4837c16a8f1f0dc4073d339db0a03d6579f Mon Sep 17 00:00:00 2001 From: cloudflare-warp-bot Date: Thu, 3 May 2018 22:32:30 +0000 Subject: [PATCH] Release Argo Tunnel Client 2018.5.0 --- cmd/cloudflared/configuration.go | 297 ++++++++++++++++ cmd/cloudflared/generic_service.go | 2 +- cmd/cloudflared/hello.go | 191 +--------- cmd/cloudflared/linux_service.go | 6 +- cmd/cloudflared/logger.go | 65 ++++ cmd/cloudflared/login.go | 2 +- cmd/cloudflared/macos_service.go | 11 +- cmd/cloudflared/main.go | 524 ++++++---------------------- cmd/cloudflared/server.go | 27 ++ cmd/cloudflared/service_template.go | 6 +- cmd/cloudflared/signal.go | 54 +++ cmd/cloudflared/signal_test.go | 131 +++++++ cmd/cloudflared/update.go | 74 +++- cmd/cloudflared/windows_service.go | 11 +- h2mux/h2mux_test.go | 39 ++- hello/hello.go | 197 +++++++++++ hello/hello_test.go | 38 ++ origin/tunnel.go | 9 +- tlsconfig/cloudflare_ca.go | 4 +- tlsconfig/tlsconfig.go | 48 ++- tlsconfig/tlsconfig_test.go | 211 +++++++++++ tunneldns/https_upstream.go | 2 +- websocket/websocket.go | 31 +- websocket/websocket_test.go | 100 ++++++ 24 files changed, 1425 insertions(+), 655 deletions(-) create mode 100644 cmd/cloudflared/configuration.go create mode 100644 cmd/cloudflared/logger.go create mode 100644 cmd/cloudflared/server.go create mode 100644 cmd/cloudflared/signal.go create mode 100644 cmd/cloudflared/signal_test.go create mode 100644 hello/hello.go create mode 100644 hello/hello_test.go create mode 100644 tlsconfig/tlsconfig_test.go create mode 100644 websocket/websocket_test.go diff --git a/cmd/cloudflared/configuration.go b/cmd/cloudflared/configuration.go new file mode 100644 index 00000000..d0e926e2 --- /dev/null +++ b/cmd/cloudflared/configuration.go @@ -0,0 +1,297 @@ +package main + +import ( + "crypto/tls" + "crypto/x509" + "encoding/hex" + "fmt" + "io/ioutil" + "math/rand" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/cloudflare/cloudflared/origin" + "github.com/cloudflare/cloudflared/tlsconfig" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/validation" + + "github.com/sirupsen/logrus" + "gopkg.in/urfave/cli.v2" + "gopkg.in/urfave/cli.v2/altsrc" + + "github.com/mitchellh/go-homedir" + "github.com/pkg/errors" +) + +var ( + defaultConfigFiles = []string{"config.yml", "config.yaml"} + + // Launchd doesn't set root env variables, so there is default + // Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible + defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp", "/usr/local/etc/cloudflared", "/etc/cloudflared"} +) + +const defaultCredentialFile = "cert.pem" + +func fileExists(path string) (bool, error) { + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + // ignore missing files + return false, nil + } + return false, err + } + f.Close() + return true, nil +} + +// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs +// (differs by OS for legacy reasons) contains a cert.pem file, return empty string +func findDefaultOriginCertPath() string { + for _, defaultConfigDir := range defaultConfigDirs { + originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, defaultCredentialFile)) + if ok, _ := fileExists(originCertPath); ok { + return originCertPath + } + } + return "" +} + +// returns the first path that contains a config file. If none of the combination of +// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles +// contains a config file, return empty string +func findDefaultConfigPath() string { + for _, configDir := range defaultConfigDirs { + for _, configFile := range defaultConfigFiles { + dirPath, err := homedir.Expand(configDir) + if err != nil { + continue + } + path := filepath.Join(dirPath, configFile) + if ok, _ := fileExists(path); ok { + return path + } + } + } + return "" +} + +func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) { + if context.String("config") != "" { + return altsrc.NewYamlSourceFromFile(context.String("config")) + } + return nil, nil +} + +func generateRandomClientID() string { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + id := make([]byte, 32) + r.Read(id) + return hex.EncodeToString(id) +} + +func enoughOptionsSet(c *cli.Context) bool { + // For cloudflared to work, the user needs to at least provide a hostname, + // or runs as stand alone DNS proxy . + // When using sudo, use -E flag to preserve env vars + if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" && os.Getenv("TUNNEL_DNS") == "" { + if isRunningFromTerminal() { + logger.Errorf("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl) + logger.Infof("If you want to run Argo Tunnel client as a stand alone DNS proxy, run with --proxy-dns option or set TUNNEL_DNS environment variable.") + } else { + logger.Errorf("You need to specify all the options in a configuration file, or use environment variables. See %s and %s", serviceUrl, argumentsUrl) + logger.Infof("If you want to run Argo Tunnel client as a stand alone DNS proxy, specify proxy-dns option in the configuration file, or set TUNNEL_DNS environment variable.") + } + cli.ShowAppHelp(c) + return false + } + return true +} + +func handleDeprecatedOptions(c *cli.Context) { + // Fail if the user provided an old authentication method + if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") { + logger.Fatal("You don't need to give us your api-key anymore. Please use the new login method. Just run cloudflared login") + } +} + +// validate url. It can be either from --url or argument +func validateUrl(c *cli.Context) (string, error) { + var url = c.String("url") + if c.NArg() > 0 { + if c.IsSet("url") { + return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") + } + url = c.Args().Get(0) + } + validUrl, err := validation.ValidateUrl(url) + return validUrl, err +} + +func logClientOptions(c *cli.Context) { + flags := make(map[string]interface{}) + for _, flag := range c.LocalFlagNames() { + flags[flag] = c.Generic(flag) + } + if len(flags) > 0 { + logger.Infof("Flags %v", flags) + } + + envs := make(map[string]string) + // Find env variables for Argo Tunnel + for _, env := range os.Environ() { + // All Argo Tunnel env variables start with TUNNEL_ + if strings.Contains(env, "TUNNEL_") { + vars := strings.Split(env, "=") + if len(vars) == 2 { + envs[vars[0]] = vars[1] + } + } + } + if len(envs) > 0 { + logger.Infof("Environmental variables %v", envs) + } +} + +func dnsProxyStandAlone(c *cli.Context) bool { + return c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world")) +} + +func getOriginCert(c *cli.Context) []byte { + if c.String("origincert") == "" { + logger.Warnf("Cannot determine default origin certificate path. No file %s in %v", defaultCredentialFile, defaultConfigDirs) + if isRunningFromTerminal() { + logger.Fatalf("You need to specify the origin certificate path with --origincert option, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", argumentsUrl) + } else { + logger.Fatalf("You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", serviceUrl) + } + } + // Check that the user has acquired a certificate using the login command + originCertPath, err := homedir.Expand(c.String("origincert")) + if err != nil { + logger.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert")) + } + ok, err := fileExists(originCertPath) + if err != nil { + logger.Fatalf("Cannot check if origin cert exists at path %s", c.String("origincert")) + } + if !ok { + logger.Fatalf(`Cannot find a valid certificate for your origin at the path: + + %s + +If the path above is wrong, specify the path with the -origincert option. +If you don't have a certificate signed by Cloudflare, run the command: + + %s login +`, originCertPath, os.Args[0]) + } + // Easier to send the certificate as []byte via RPC than decoding it at this point + originCert, err := ioutil.ReadFile(originCertPath) + if err != nil { + logger.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath) + } + return originCert +} + +func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, logger, protoLogger *logrus.Logger) *origin.TunnelConfig { + hostname, err := validation.ValidateHostname(c.String("hostname")) + if err != nil { + logger.WithError(err).Fatal("Invalid hostname") + } + clientID := c.String("id") + if !c.IsSet("id") { + clientID = generateRandomClientID() + } + + tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) + if err != nil { + logger.WithError(err).Fatal("Tag parse failure") + } + + tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) + + url, err := validateUrl(c) + if err != nil { + logger.WithError(err).Fatal("Error validating url") + } + logger.Infof("Proxying tunnel requests to %s", url) + + originCert := getOriginCert(c) + + originCertPool, err := loadCertPool(c) + if err != nil { + logger.Fatal(err) + } + + tunnelMetrics := origin.NewTunnelMetrics() + httpTransport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: c.Duration("proxy-connect-timeout"), + KeepAlive: c.Duration("proxy-tcp-keepalive"), + DualStack: !c.Bool("proxy-no-happy-eyeballs"), + }).DialContext, + MaxIdleConns: c.Int("proxy-keepalive-connections"), + IdleConnTimeout: c.Duration("proxy-keepalive-timeout"), + TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"), + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{RootCAs: originCertPool}, + } + + if !c.IsSet("hello-world") && c.IsSet("origin-server-name") { + httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name") + } + + return &origin.TunnelConfig{ + EdgeAddrs: c.StringSlice("edge"), + OriginUrl: url, + Hostname: hostname, + OriginCert: originCert, + TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")), + ClientTlsConfig: httpTransport.TLSClientConfig, + Retries: c.Uint("retries"), + HeartbeatInterval: c.Duration("heartbeat-interval"), + MaxHeartbeats: c.Uint64("heartbeat-count"), + ClientID: clientID, + BuildInfo: buildInfo, + ReportedVersion: Version, + LBPool: c.String("lb-pool"), + Tags: tags, + HAConnections: c.Int("ha-connections"), + HTTPTransport: httpTransport, + Metrics: tunnelMetrics, + MetricsUpdateFreq: c.Duration("metrics-update-freq"), + ProtocolLogger: protoLogger, + Logger: logger, + IsAutoupdated: c.Bool("is-autoupdated"), + GracePeriod: c.Duration("grace-period"), + RunFromTerminal: isRunningFromTerminal(), + } +} + +func loadCertPool(c *cli.Context) (*x509.CertPool, error) { + const originCAPoolFlag = "origin-ca-pool" + originCAPoolFilename := c.String(originCAPoolFlag) + var originCustomCAPool []byte + + if originCAPoolFilename != "" { + var err error + originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, originCAPoolFlag)) + } + } + + originCertPool, err := tlsconfig.LoadOriginCertPool(originCustomCAPool) + if err != nil { + return nil, errors.Wrap(err, "error loading the certificate pool") + } + + return originCertPool, nil +} diff --git a/cmd/cloudflared/generic_service.go b/cmd/cloudflared/generic_service.go index 3eb4819e..a2fbf494 100644 --- a/cmd/cloudflared/generic_service.go +++ b/cmd/cloudflared/generic_service.go @@ -8,6 +8,6 @@ import ( cli "gopkg.in/urfave/cli.v2" ) -func runApp(app *cli.App) { +func runApp(app *cli.App, shutdownC chan struct{}) { app.Run(os.Args) } diff --git a/cmd/cloudflared/hello.go b/cmd/cloudflared/hello.go index b2cd332d..55f77269 100644 --- a/cmd/cloudflared/hello.go +++ b/cmd/cloudflared/hello.go @@ -1,204 +1,21 @@ package main import ( - "bytes" - "crypto/tls" - "encoding/json" "fmt" - "html/template" - "io/ioutil" - "net" - "net/http" - "os" - "time" - "github.com/gorilla/websocket" "gopkg.in/urfave/cli.v2" - "github.com/cloudflare/cloudflared/tlsconfig" + "github.com/cloudflare/cloudflared/hello" ) -type templateData struct { - ServerName string - Request *http.Request - Body string -} -type OriginUpTime struct { - StartTime time.Time `json:"startTime"` - UpTime string `json:"uptime"` -} - -const defaultServerName = "the Argo Tunnel test server" -const indexTemplate = ` - - - - - - - Argo Tunnel Connection - - - - - - - -
-
- - - - - - -

Congrats! You created your first tunnel!

-

- Argo Tunnel exposes locally running applications to the internet by - running an encrypted, virtual tunnel from your laptop or server to - Cloudflare's edge network. -

-

Ready for the next step?

- - Get started here - -
-

Request

-
-
Method: {{.Request.Method}}
-
Protocol: {{.Request.Proto}}
-
Request URL: {{.Request.URL}}
-
Transfer encoding: {{.Request.TransferEncoding}}
-
Host: {{.Request.Host}}
-
Remote address: {{.Request.RemoteAddr}}
-
Request URI: {{.Request.RequestURI}}
-{{range $key, $value := .Request.Header}} -
Header: {{$key}}, Value: {{$value}}
-{{end}} -
Body: {{.Body}}
-
-
-
-
- - -` - -func hello(c *cli.Context) error { +func helloWorld(c *cli.Context) error { address := fmt.Sprintf(":%d", c.Int("port")) - listener, err := createListener(address) + listener, err := hello.CreateTLSListener(address) if err != nil { return err } defer listener.Close() - err = startHelloWorldServer(listener, nil) + err = hello.StartHelloWorldServer(logger, listener, nil) return err } - -func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error { - logger.Infof("Starting Hello World server at %s", listener.Addr()) - serverName := defaultServerName - if hostname, err := os.Hostname(); err == nil { - serverName = hostname - } - - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - } - - httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil} - go func() { - <-shutdownC - httpServer.Close() - }() - - http.HandleFunc("/uptime", uptimeHandler(time.Now())) - http.HandleFunc("/ws", websocketHandler(upgrader)) - http.HandleFunc("/", rootHandler(serverName)) - err := httpServer.Serve(listener) - return err -} - -func createListener(address string) (net.Listener, error) { - certificate, err := tlsconfig.GetHelloCertificate() - if err != nil { - return nil, err - } - - // If the port in address is empty, a port number is automatically chosen - listener, err := tls.Listen( - "tcp", - address, - &tls.Config{Certificates: []tls.Certificate{certificate}}) - - return listener, err -} - -func uptimeHandler(startTime time.Time) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // Note that if autoupdate is enabled, the uptime is reset when a new client - // release is available - resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()} - respJson, err := json.Marshal(resp) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - } else { - w.Header().Set("Content-Type", "application/json") - w.Write(respJson) - } - } -} - -func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - - for { - mt, message, err := conn.ReadMessage() - if err != nil { - break - } - - if err := conn.WriteMessage(mt, message); err != nil { - break - } - } - } -} - -func rootHandler(serverName string) http.HandlerFunc { - responseTemplate := template.Must(template.New("index").Parse(indexTemplate)) - return func(w http.ResponseWriter, r *http.Request) { - var buffer bytes.Buffer - var body string - rawBody, err := ioutil.ReadAll(r.Body) - if err == nil { - body = string(rawBody) - } else { - body = "" - } - err = responseTemplate.Execute(&buffer, &templateData{ - ServerName: serverName, - Request: r, - Body: body, - }) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintf(w, "error: %v", err) - } else { - buffer.WriteTo(w) - } - } -} diff --git a/cmd/cloudflared/linux_service.go b/cmd/cloudflared/linux_service.go index 4e859617..19185171 100644 --- a/cmd/cloudflared/linux_service.go +++ b/cmd/cloudflared/linux_service.go @@ -10,7 +10,7 @@ import ( cli "gopkg.in/urfave/cli.v2" ) -func runApp(app *cli.App) { +func runApp(app *cli.App, shutdownC chan struct{}) { app.Commands = append(app.Commands, &cli.Command{ Name: "service", Usage: "Manages the Argo Tunnel system service", @@ -183,9 +183,9 @@ func installLinuxService(c *cli.Context) error { defaultConfigDir := filepath.Dir(c.String("config")) defaultConfigFile := filepath.Base(c.String("config")) - if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile); err != nil { + if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile, defaultCredentialFile); err != nil { logger.WithError(err).Infof("Failed to copy user configuration. Before running the service, ensure that %s contains two files, %s and %s", - serviceConfigDir, credentialFile, defaultConfigFiles[0]) + serviceConfigDir, defaultCredentialFile, defaultConfigFiles[0]) return err } diff --git a/cmd/cloudflared/logger.go b/cmd/cloudflared/logger.go new file mode 100644 index 00000000..74a0b221 --- /dev/null +++ b/cmd/cloudflared/logger.go @@ -0,0 +1,65 @@ +package main + +import ( + "fmt" + "os" + + "github.com/cloudflare/cloudflared/log" + + "github.com/rifflock/lfshook" + "github.com/sirupsen/logrus" + "gopkg.in/urfave/cli.v2" + + "github.com/mitchellh/go-homedir" + "github.com/pkg/errors" +) + +var logger = log.CreateLogger() + +func configMainLogger(c *cli.Context) { + logLevel, err := logrus.ParseLevel(c.String("loglevel")) + if err != nil { + logger.WithError(err).Fatal("Unknown logging level specified") + } + logger.SetLevel(logLevel) +} + +func configProtoLogger(c *cli.Context) *logrus.Logger { + protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel")) + if err != nil { + logger.WithError(err).Fatal("Unknown protocol logging level specified") + } + protoLogger := logrus.New() + protoLogger.Level = protoLogLevel + return protoLogger +} + +func initLogFile(c *cli.Context, loggers ...*logrus.Logger) error { + filePath, err := homedir.Expand(c.String("logfile")) + if err != nil { + return errors.Wrap(err, "Cannot resolve logfile path") + } + + fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC + // do not truncate log file if the client has been autoupdated + if c.Bool("is-autoupdated") { + fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE + } + f, err := os.OpenFile(filePath, fileMode, 0664) + if err != nil { + errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath)) + } + defer f.Close() + pathMap := lfshook.PathMap{ + logrus.InfoLevel: filePath, + logrus.ErrorLevel: filePath, + logrus.FatalLevel: filePath, + logrus.PanicLevel: filePath, + } + + for _, l := range loggers { + l.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{})) + } + + return nil +} diff --git a/cmd/cloudflared/login.go b/cmd/cloudflared/login.go index 1ba234ec..bbffa29e 100644 --- a/cmd/cloudflared/login.go +++ b/cmd/cloudflared/login.go @@ -35,7 +35,7 @@ func login(c *cli.Context) error { if err != nil { return err } - path := filepath.Join(configPath, credentialFile) + path := filepath.Join(configPath, defaultCredentialFile) fileInfo, err := os.Stat(path) if err == nil && fileInfo.Size() > 0 { fmt.Fprintf(os.Stderr, `You have an existing certificate at %s which login would overwrite. diff --git a/cmd/cloudflared/macos_service.go b/cmd/cloudflared/macos_service.go index 37737c57..30d703c6 100644 --- a/cmd/cloudflared/macos_service.go +++ b/cmd/cloudflared/macos_service.go @@ -13,7 +13,7 @@ const ( launchdIdentifier = "com.cloudflare.cloudflared" ) -func runApp(app *cli.App) { +func runApp(app *cli.App, shutdownC chan struct{}) { app.Commands = append(app.Commands, &cli.Command{ Name: "service", Usage: "Manages the Argo Tunnel launch agent", @@ -91,12 +91,12 @@ func stderrPath() string { func installLaunchd(c *cli.Context) error { if isRootUser() { logger.Infof("Installing Argo Tunnel client as a system launch daemon. " + - "Argo Tunnel client will run at boot") + "Argo Tunnel client will run at boot") } else { logger.Infof("Installing Argo Tunnel client as an user launch agent. " + - "Note that Argo Tunnel client will only run when the user is logged in. " + - "If you want to run Argo Tunnel client at boot, install with root permission. " + - "For more information, visit https://developers.cloudflare.com/argo-tunnel/reference/service/") + "Note that Argo Tunnel client will only run when the user is logged in. " + + "If you want to run Argo Tunnel client at boot, install with root permission. " + + "For more information, visit https://developers.cloudflare.com/argo-tunnel/reference/service/") } etPath, err := os.Executable() if err != nil { @@ -120,7 +120,6 @@ func installLaunchd(c *cli.Context) error { } func uninstallLaunchd(c *cli.Context) error { - if isRootUser() { logger.Infof("Uninstalling Argo Tunnel as a system launch daemon") } else { diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go index c0b6ebb1..da17fcca 100644 --- a/cmd/cloudflared/main.go +++ b/cmd/cloudflared/main.go @@ -1,69 +1,47 @@ package main import ( - "crypto/tls" - "encoding/hex" "fmt" - "io/ioutil" - "math/rand" - "net" - "net/http" "os" - "os/signal" - "path/filepath" - "strings" "sync" - "syscall" "time" - "github.com/cloudflare/cloudflared/log" + "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/metrics" "github.com/cloudflare/cloudflared/origin" - "github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tunneldns" - tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" - "github.com/cloudflare/cloudflared/validation" - "github.com/facebookgo/grace/gracenet" "github.com/getsentry/raven-go" "github.com/mitchellh/go-homedir" - "github.com/rifflock/lfshook" - "github.com/sirupsen/logrus" - "golang.org/x/crypto/ssh/terminal" "gopkg.in/urfave/cli.v2" "gopkg.in/urfave/cli.v2/altsrc" "github.com/coreos/go-systemd/daemon" - "github.com/pkg/errors" + "github.com/facebookgo/grace/gracenet" ) const ( - sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878" - credentialFile = "cert.pem" - quickStartUrl = "https://developers.cloudflare.com/argo-tunnel/quickstart/quickstart/" - noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/" - licenseUrl = "https://developers.cloudflare.com/argo-tunnel/licence/" + sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878" + developerPortal = "https://developers.cloudflare.com/argo-tunnel" + quickStartUrl = developerPortal + "/quickstart/quickstart/" + serviceUrl = developerPortal + "/reference/service/" + argumentsUrl = developerPortal + "/reference/arguments/" + licenseUrl = developerPortal + "/licence/" ) -var listeners = gracenet.Net{} -var Version = "DEV" -var BuildTime = "unknown" -var logger = log.CreateLogger() -var defaultConfigFiles = []string{"config.yml", "config.yaml"} - -// Launchd doesn't set root env variables, so there is default -// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible -var defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp", "/usr/local/etc/cloudflared", "/etc/cloudflared"} - -// Shutdown channel used by the app. When closed, app must terminate. -// May be closed by the Windows service runner. -var shutdownC chan struct{} +var ( + Version = "DEV" + BuildTime = "unknown" +) func main() { metrics.RegisterBuildInfo(BuildTime, Version) raven.SetDSN(sentryDSN) raven.SetRelease(Version) - shutdownC = make(chan struct{}) + + // Shutdown channel used by the app. When closed, app must terminate. + // May be closed by the Windows service runner. + shutdownC := make(chan struct{}) app := &cli.App{} app.Name = "cloudflared" @@ -119,6 +97,11 @@ func main() { EnvVars: []string{"TUNNEL_ORIGIN_CERT"}, Value: findDefaultOriginCertPath(), }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "origin-ca-pool", + Usage: "Path to the CA for the certificate of your origin. This option should be used only if your certificate is not signed by Cloudflare.", + EnvVars: []string{"TUNNEL_ORIGIN_CA_POOL"}, + }), altsrc.NewStringFlag(&cli.StringFlag{ Name: "url", Value: "https://localhost:8080", @@ -293,10 +276,13 @@ func main() { }), } app.Action = func(c *cli.Context) error { - raven.CapturePanic(func() { startServer(c) }, nil) + raven.CapturePanic(func() { startServer(c, shutdownC) }, nil) return nil } app.Before = func(context *cli.Context) error { + if context.String("config") == "" { + logger.Warnf("Cannot determine default configuration path. No file %v in %v", defaultConfigFiles, defaultConfigDirs) + } inputSource, err := findInputSourceContext(context) if err != nil { logger.WithError(err).Infof("Cannot load configuration from %s", context.String("config")) @@ -337,7 +323,7 @@ func main() { }, { Name: "hello", - Action: hello, + Action: helloWorld, Usage: "Run a simple \"Hello World\" server for testing Argo Tunnel.", Flags: []cli.Flag{ &cli.IntFlag{ @@ -381,233 +367,130 @@ func main() { ArgsUsage: " ", // can't be the empty string or we get the default output }, } - runApp(app) + runApp(app, shutdownC) } -func startServer(c *cli.Context) { +func startServer(c *cli.Context, shutdownC chan struct{}) { var wg sync.WaitGroup + listeners := gracenet.Net{} errC := make(chan error) connectedSignal := make(chan struct{}) dnsReadySignal := make(chan struct{}) + graceShutdownSignal := make(chan struct{}) - // If the user choose to supply all options through env variables, - // c.NumFlags() == 0 && c.NArg() == 0. For cloudflared to work, the user needs to at - // least provide a hostname. - if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" { - logger.Infof("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl) - cli.ShowAppHelp(c) + // check whether client provides enough flags or env variables. If not, print help. + if ok := enoughOptionsSet(c); !ok { return } - logLevel, err := logrus.ParseLevel(c.String("loglevel")) - if err != nil { - logger.WithError(err).Fatal("Unknown logging level specified") - } - logger.SetLevel(logLevel) - - protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel")) - if err != nil { - logger.WithError(err).Fatal("Unknown protocol logging level specified") - } - protoLogger := logrus.New() - protoLogger.Level = protoLogLevel + configMainLogger(c) + protoLogger := configProtoLogger(c) if c.String("logfile") != "" { - if err := initLogFile(c, protoLogger); err != nil { + if err := initLogFile(c, logger, protoLogger); err != nil { logger.Error(err) } } + handleDeprecatedOptions(c) + buildInfo := origin.GetBuildInfo() logger.Infof("Build info: %+v", *buildInfo) logger.Infof("Version %s", Version) logClientOptions(c) if c.IsSet("proxy-dns") { - port := c.Int("proxy-dns-port") - if port <= 0 || port > 65535 { - logger.Fatal("The 'proxy-dns-port' must be a valid port number in <1, 65535> range.") - } wg.Add(1) - listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(port), c.StringSlice("proxy-dns-upstream")) - if err != nil { - close(dnsReadySignal) - listener.Stop() - logger.WithError(err).Fatal("Cannot create the DNS over HTTPS proxy server") - } go func() { - err := listener.Start(dnsReadySignal) - if err != nil { - logger.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server") - } else { - <-shutdownC - } - listener.Stop() - wg.Done() + defer wg.Done() + runDNSProxyServer(c, dnsReadySignal, shutdownC) }() } else { close(dnsReadySignal) } - isRunningFromTerminal := isRunningFromTerminal() - if isAutoupdateEnabled(c, isRunningFromTerminal) { - // Wait for proxy-dns to come up (if used) - <-dnsReadySignal - if initUpdate() { + // Wait for proxy-dns to come up (if used) + <-dnsReadySignal + + // update needs to be after DNS proxy is up to resolve equinox server address + if isAutoupdateEnabled(c) { + if initUpdate(&listeners) { return } logger.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq")) - go autoupdate(c.Duration("autoupdate-freq"), shutdownC) + go autoupdate(c.Duration("autoupdate-freq"), &listeners, shutdownC) } - // Serve DNS proxy stand-alone if no hostname or tag or app is going to run - if c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world")) { - go writePidFile(connectedSignal, c.String("pidfile")) - close(connectedSignal) - runServer(c, &wg, errC, shutdownC) - return - } - - hostname, err := validation.ValidateHostname(c.String("hostname")) - if err != nil { - logger.WithError(err).Fatal("Invalid hostname") - } - clientID := c.String("id") - if !c.IsSet("id") { - clientID = generateRandomClientID() - } - - tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) - if err != nil { - logger.WithError(err).Fatal("Tag parse failure") - } - - tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) - if c.IsSet("hello-world") { - wg.Add(1) - listener, err := createListener("127.0.0.1:") - if err != nil { - listener.Close() - logger.WithError(err).Fatal("Cannot start Hello World Server") - } - go func() { - startHelloWorldServer(listener, shutdownC) - wg.Done() - listener.Close() - }() - c.Set("url", "https://"+listener.Addr().String()) - } - - url, err := validateUrl(c) - if err != nil { - logger.WithError(err).Fatal("Error validating url") - } - logger.Infof("Proxying tunnel requests to %s", url) - - // Fail if the user provided an old authentication method - if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") { - logger.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflared login") - } - - // Check that the user has acquired a certificate using the log in command - originCertPath, err := homedir.Expand(c.String("origincert")) - if err != nil { - logger.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert")) - } - ok, err := fileExists(originCertPath) - if err != nil { - logger.Fatalf("Cannot check if origin cert exists at path %s", c.String("origincert")) - } - if !ok { - logger.Fatalf(`Cannot find a valid certificate for your origin at the path: - - %s - -If the path above is wrong, specify the path with the -origincert option. -If you don't have a certificate signed by Cloudflare, run the command: - - %s login -`, originCertPath, os.Args[0]) - } - // Easier to send the certificate as []byte via RPC than decoding it at this point - originCert, err := ioutil.ReadFile(originCertPath) - if err != nil { - logger.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath) - } - - tunnelMetrics := origin.NewTunnelMetrics() - httpTransport := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: c.Duration("proxy-connect-timeout"), - KeepAlive: c.Duration("proxy-tcp-keepalive"), - DualStack: !c.Bool("proxy-no-happy-eyeballs"), - }).DialContext, - MaxIdleConns: c.Int("proxy-keepalive-connections"), - IdleConnTimeout: c.Duration("proxy-keepalive-timeout"), - TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"), - ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()}, - } - - if !c.IsSet("hello-world") && c.IsSet("origin-server-name") { - httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name") - } - - tunnelConfig := &origin.TunnelConfig{ - EdgeAddrs: c.StringSlice("edge"), - OriginUrl: url, - Hostname: hostname, - OriginCert: originCert, - TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")), - ClientTlsConfig: httpTransport.TLSClientConfig, - Retries: c.Uint("retries"), - HeartbeatInterval: c.Duration("heartbeat-interval"), - MaxHeartbeats: c.Uint64("heartbeat-count"), - ClientID: clientID, - BuildInfo: buildInfo, - ReportedVersion: Version, - LBPool: c.String("lb-pool"), - Tags: tags, - HAConnections: c.Int("ha-connections"), - HTTPTransport: httpTransport, - Metrics: tunnelMetrics, - MetricsUpdateFreq: c.Duration("metrics-update-freq"), - ProtocolLogger: protoLogger, - Logger: logger, - IsAutoupdated: c.Bool("is-autoupdated"), - GracePeriod: c.Duration("grace-period"), - RunFromTerminal: isRunningFromTerminal, - } - - go writePidFile(connectedSignal, c.String("pidfile")) - wg.Add(1) - go func() { - errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal) - wg.Done() - }() - - runServer(c, &wg, errC, shutdownC) -} - -func runServer(c *cli.Context, wg *sync.WaitGroup, errC chan error, shutdownC chan struct{}) { - wg.Add(1) metricsListener, err := listeners.Listen("tcp", c.String("metrics")) if err != nil { logger.WithError(err).Fatal("Error opening metrics server listener") } + defer metricsListener.Close() + wg.Add(1) go func() { + defer wg.Done() errC <- metrics.ServeMetrics(metricsListener, shutdownC, logger) - wg.Done() }() + // Serve DNS proxy stand-alone if no hostname or tag or app is going to run + if dnsProxyStandAlone(c) { + if c.IsSet("pidfile") { + go writePidFile(connectedSignal, c.String("pidfile")) + close(connectedSignal) + } + // no grace period, handle SIGINT/SIGTERM immediately + waitToShutdown(&wg, errC, shutdownC, graceShutdownSignal, 0) + return + } + + if c.IsSet("hello-world") { + helloListener, err := hello.CreateTLSListener("127.0.0.1:") + if err != nil { + logger.WithError(err).Fatal("Cannot start Hello World Server") + } + defer helloListener.Close() + wg.Add(1) + go func() { + defer wg.Done() + hello.StartHelloWorldServer(logger, helloListener, shutdownC) + }() + c.Set("url", "https://"+helloListener.Addr().String()) + } + + tunnelConfig := prepareTunnelConfig(c, buildInfo, logger, protoLogger) + + if c.IsSet("pidFile") { + go writePidFile(connectedSignal, c.String("pidfile")) + } + + wg.Add(1) + go func() { + defer wg.Done() + errC <- origin.StartTunnelDaemon(tunnelConfig, graceShutdownSignal, connectedSignal) + }() + + waitToShutdown(&wg, errC, shutdownC, graceShutdownSignal, c.Duration("grace-period")) +} + +func waitToShutdown(wg *sync.WaitGroup, + errC chan error, + shutdownC, graceShutdownSignal chan struct{}, + gracePeriod time.Duration, +) { + var err error + if gracePeriod > 0 { + err = waitForSignalWithGraceShutdown(errC, shutdownC, graceShutdownSignal, gracePeriod) + } else { + err = waitForSignal(errC, shutdownC) + close(graceShutdownSignal) + } + var errCode int - err = WaitForSignal(errC, shutdownC) if err != nil { logger.WithError(err).Fatal("Quitting due to error") raven.CaptureErrorAndWait(err, nil) errCode = 1 } else { - logger.Info("Graceful shutdown...") + logger.Info("Quitting...") } // Wait for clean exit, discarding all errors go func() { @@ -618,126 +501,6 @@ func runServer(c *cli.Context, wg *sync.WaitGroup, errC chan error, shutdownC ch os.Exit(errCode) } -func WaitForSignal(errC chan error, shutdownC chan struct{}) error { - signals := make(chan os.Signal, 10) - signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) - defer signal.Stop(signals) - - select { - case err := <-errC: - close(shutdownC) - return err - case <-signals: - close(shutdownC) - case <-shutdownC: - } - - return nil -} - -func update(_ *cli.Context) error { - if updateApplied() { - os.Exit(64) - } - return nil -} - -func initUpdate() bool { - if updateApplied() { - os.Args = append(os.Args, "--is-autoupdated=true") - if _, err := listeners.StartProcess(); err != nil { - logger.WithError(err).Error("Unable to restart server automatically") - return false - } - return true - } - return false -} - -func autoupdate(freq time.Duration, shutdownC chan struct{}) { - for { - if updateApplied() { - os.Args = append(os.Args, "--is-autoupdated=true") - if _, err := listeners.StartProcess(); err != nil { - logger.WithError(err).Error("Unable to restart server automatically") - } - close(shutdownC) - return - } - time.Sleep(freq) - } -} - -func updateApplied() bool { - releaseInfo := checkForUpdates() - if releaseInfo.Updated { - logger.Infof("Updated to version %s", releaseInfo.Version) - return true - } - if releaseInfo.Error != nil { - logger.WithError(releaseInfo.Error).Error("Update check failed") - } - return false -} - -func fileExists(path string) (bool, error) { - f, err := os.Open(path) - if err != nil { - if os.IsNotExist(err) { - // ignore missing files - return false, nil - } - return false, err - } - f.Close() - return true, nil -} - -// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs -// (differs by OS for legacy reasons) contains a cert.pem file, return empty string -func findDefaultOriginCertPath() string { - for _, defaultConfigDir := range defaultConfigDirs { - originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, credentialFile)) - if ok, _ := fileExists(originCertPath); ok { - return originCertPath - } - } - return "" -} - -// returns the firt path that contains a config file. If none of the combination of -// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles -// contains a config file, return empty string -func findDefaultConfigPath() string { - for _, configDir := range defaultConfigDirs { - for _, configFile := range defaultConfigFiles { - dirPath, err := homedir.Expand(configDir) - if err != nil { - continue - } - path := filepath.Join(dirPath, configFile) - if ok, _ := fileExists(path); ok { - return path - } - } - } - return "" -} - -func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) { - if context.String("config") != "" { - return altsrc.NewYamlSourceFromFile(context.String("config")) - } - return nil, nil -} - -func generateRandomClientID() string { - r := rand.New(rand.NewSource(time.Now().UnixNano())) - id := make([]byte, 32) - r.Read(id) - return hex.EncodeToString(id) -} - func writePidFile(waitForSignal chan struct{}, pidFile string) { <-waitForSignal daemon.SdNotify(false, "READY=1") @@ -752,87 +515,6 @@ func writePidFile(waitForSignal chan struct{}, pidFile string) { fmt.Fprintf(file, "%d", os.Getpid()) } -// validate url. It can be either from --url or argument -func validateUrl(c *cli.Context) (string, error) { - var url = c.String("url") - if c.NArg() > 0 { - if c.IsSet("url") { - return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") - } - url = c.Args().Get(0) - } - validUrl, err := validation.ValidateUrl(url) - return validUrl, err -} - -func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error { - filePath, err := homedir.Expand(c.String("logfile")) - if err != nil { - return errors.Wrap(err, "Cannot resolve logfile path") - } - - fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC - // do not truncate log file if the client has been autoupdated - if c.Bool("is-autoupdated") { - fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE - } - f, err := os.OpenFile(filePath, fileMode, 0664) - if err != nil { - errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath)) - } - defer f.Close() - pathMap := lfshook.PathMap{ - logrus.InfoLevel: filePath, - logrus.ErrorLevel: filePath, - logrus.FatalLevel: filePath, - logrus.PanicLevel: filePath, - } - - logger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{})) - protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{})) - - return nil -} - -func logClientOptions(c *cli.Context) { - flags := make(map[string]interface{}) - for _, flag := range c.LocalFlagNames() { - flags[flag] = c.Generic(flag) - } - if len(flags) > 0 { - logger.Infof("Flags %v", flags) - } - - envs := make(map[string]string) - // Find env variables for Argo Tunnel - for _, env := range os.Environ() { - // All Argo Tunnel env variables start with TUNNEL_ - if strings.Contains(env, "TUNNEL_") { - vars := strings.Split(env, "=") - if len(vars) == 2 { - envs[vars[0]] = vars[1] - } - } - } - if len(envs) > 0 { - logger.Infof("Environmental variables %v", envs) - } -} - -func isAutoupdateEnabled(c *cli.Context, isRunningFromTerminal bool) bool { - if isRunningFromTerminal { - logger.Info(noAutoupdateMessage) - return false - } - - return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 -} - - -func isRunningFromTerminal() bool { - return terminal.IsTerminal(int(os.Stdout.Fd())) -} - func userHomeDir() string { // This returns the home dir of the executing user using OS-specific method // for discovering the home dir. It's not recommended to call this function diff --git a/cmd/cloudflared/server.go b/cmd/cloudflared/server.go new file mode 100644 index 00000000..30e65b5f --- /dev/null +++ b/cmd/cloudflared/server.go @@ -0,0 +1,27 @@ +package main + +import ( + "github.com/cloudflare/cloudflared/tunneldns" + + "gopkg.in/urfave/cli.v2" +) + +func runDNSProxyServer(c *cli.Context, dnsReadySignal, shutdownC chan struct{}) { + port := c.Int("proxy-dns-port") + if port <= 0 || port > 65535 { + logger.Fatal("The 'proxy-dns-port' must be a valid port number in <1, 65535> range.") + } + listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(port), c.StringSlice("proxy-dns-upstream")) + if err != nil { + close(dnsReadySignal) + listener.Stop() + logger.WithError(err).Fatal("Cannot create the DNS over HTTPS proxy server") + } + + err = listener.Start(dnsReadySignal) + if err != nil { + logger.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server") + } + <-shutdownC + listener.Stop() +} diff --git a/cmd/cloudflared/service_template.go b/cmd/cloudflared/service_template.go index afec0fd7..63e47bbd 100644 --- a/cmd/cloudflared/service_template.go +++ b/cmd/cloudflared/service_template.go @@ -119,7 +119,7 @@ func openFile(path string, create bool) (file *os.File, exists bool, err error) return file, false, err } -func copyCertificate(srcConfigDir, destConfigDir string) error { +func copyCertificate(srcConfigDir, destConfigDir, credentialFile string) error { destCredentialPath := filepath.Join(destConfigDir, credentialFile) destFile, exists, err := openFile(destCredentialPath, true) if err != nil { @@ -146,12 +146,12 @@ func copyCertificate(srcConfigDir, destConfigDir string) error { return nil } -func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile string) error { +func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile, defaultCredentialFile string) error { if err := ensureConfigDirExists(serviceConfigDir); err != nil { return err } - if err := copyCertificate(defaultConfigDir, serviceConfigDir); err != nil { + if err := copyCertificate(defaultConfigDir, serviceConfigDir, defaultCredentialFile); err != nil { return err } diff --git a/cmd/cloudflared/signal.go b/cmd/cloudflared/signal.go new file mode 100644 index 00000000..0953e7e9 --- /dev/null +++ b/cmd/cloudflared/signal.go @@ -0,0 +1,54 @@ +package main + +import ( + "os" + "os/signal" + "syscall" + "time" +) + +func waitForSignal(errC chan error, shutdownC chan struct{}) error { + signals := make(chan os.Signal, 10) + signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) + defer signal.Stop(signals) + + select { + case err := <-errC: + close(shutdownC) + return err + case <-signals: + close(shutdownC) + case <-shutdownC: + } + return nil +} + +func waitForSignalWithGraceShutdown(errC chan error, shutdownC, graceShutdownSignal chan struct{}, gracePeriod time.Duration) error { + signals := make(chan os.Signal, 10) + signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) + defer signal.Stop(signals) + + select { + case err := <-errC: + close(graceShutdownSignal) + close(shutdownC) + return err + case <-signals: + close(graceShutdownSignal) + logger.Infof("Initiating graceful shutdown...") + // Unregister signal handler early, so the client can send a second SIGTERM/SIGINT + // to force shutdown cloudflared + signal.Stop(signals) + graceTimerTick := time.Tick(gracePeriod) + // send close signal via shutdownC when grace period expires or when an + // error is encountered. + select { + case <-graceTimerTick: + case <-errC: + } + close(shutdownC) + case <-shutdownC: + } + + return nil +} diff --git a/cmd/cloudflared/signal_test.go b/cmd/cloudflared/signal_test.go new file mode 100644 index 00000000..a56edd32 --- /dev/null +++ b/cmd/cloudflared/signal_test.go @@ -0,0 +1,131 @@ +package main + +import ( + "fmt" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +const tick = 100 * time.Millisecond + +var ( + serverErr = fmt.Errorf("server error") + shutdownErr = fmt.Errorf("receive shutdown") + graceShutdownErr = fmt.Errorf("receive grace shutdown") +) + +func testChannelClosed(t *testing.T, c chan struct{}) { + select { + case <-c: + return + default: + t.Fatal("Channel should be readable") + } +} + +func TestWaitForSignal(t *testing.T) { + // Test handling server error + errC := make(chan error) + shutdownC := make(chan struct{}) + + go func() { + errC <- serverErr + }() + + err := waitForSignal(errC, shutdownC) + assert.Equal(t, serverErr, err) + testChannelClosed(t, shutdownC) + + // Test handling SIGTERM & SIGINT + for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { + errC = make(chan error) + shutdownC = make(chan struct{}) + + go func(shutdownC chan struct{}) { + <-shutdownC + errC <- shutdownErr + }(shutdownC) + + go func(sig syscall.Signal) { + // sleep for a tick to prevent sending signal before calling waitForSignal + time.Sleep(tick) + syscall.Kill(syscall.Getpid(), sig) + }(sig) + + err = waitForSignal(errC, shutdownC) + assert.Equal(t, nil, err) + assert.Equal(t, shutdownErr, <-errC) + testChannelClosed(t, shutdownC) + } +} + +func TestWaitForSignalWithGraceShutdown(t *testing.T) { + // Test server returning error + errC := make(chan error) + shutdownC := make(chan struct{}) + graceshutdownC := make(chan struct{}) + + go func() { + errC <- serverErr + }() + + err := waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick) + assert.Equal(t, serverErr, err) + testChannelClosed(t, shutdownC) + testChannelClosed(t, graceshutdownC) + + // Test handling SIGTERM & SIGINT + for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { + //var wg sync.WaitGroup + errC := make(chan error) + shutdownC = make(chan struct{}) + graceshutdownC = make(chan struct{}) + + go func(shutdownC, graceshutdownC chan struct{}) { + <-graceshutdownC + <-shutdownC + errC <- graceShutdownErr + }(shutdownC, graceshutdownC) + + go func(sig syscall.Signal) { + // sleep for a tick to prevent sending signal before calling waitForSignalWithGraceShutdown + time.Sleep(tick) + syscall.Kill(syscall.Getpid(), sig) + }(sig) + + err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick) + assert.Equal(t, nil, err) + assert.Equal(t, graceShutdownErr, <-errC) + testChannelClosed(t, shutdownC) + testChannelClosed(t, graceshutdownC) + } + + // Test handling SIGTERM & SIGINT, server send error before end of grace period + for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} { + errC := make(chan error) + shutdownC = make(chan struct{}) + graceshutdownC = make(chan struct{}) + + go func(shutdownC, graceshutdownC chan struct{}) { + <-graceshutdownC + errC <- graceShutdownErr + <-shutdownC + errC <- shutdownErr + }(shutdownC, graceshutdownC) + + go func(sig syscall.Signal) { + // sleep for a tick to prevent sending signal before calling waitForSignalWithGraceShutdown + time.Sleep(tick) + syscall.Kill(syscall.Getpid(), sig) + }(sig) + + err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick) + assert.Equal(t, nil, err) + assert.Equal(t, shutdownErr, <-errC) + testChannelClosed(t, shutdownC) + testChannelClosed(t, graceshutdownC) + } +} diff --git a/cmd/cloudflared/update.go b/cmd/cloudflared/update.go index ccce2e03..5be63f6f 100644 --- a/cmd/cloudflared/update.go +++ b/cmd/cloudflared/update.go @@ -1,8 +1,20 @@ package main -import "github.com/equinox-io/equinox" +import ( + "os" + "time" -const appID = "app_idCzgxYerVD" + "golang.org/x/crypto/ssh/terminal" + "gopkg.in/urfave/cli.v2" + + "github.com/equinox-io/equinox" + "github.com/facebookgo/grace/gracenet" +) + +const ( + appID = "app_idCzgxYerVD" + noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/" +) var publicKey = []byte(` -----BEGIN ECDSA PUBLIC KEY----- @@ -39,3 +51,61 @@ func checkForUpdates() ReleaseInfo { return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion} } + +func update(_ *cli.Context) error { + if updateApplied() { + os.Exit(64) + } + return nil +} + +func initUpdate(listeners *gracenet.Net) bool { + if updateApplied() { + os.Args = append(os.Args, "--is-autoupdated=true") + if _, err := listeners.StartProcess(); err != nil { + logger.WithError(err).Error("Unable to restart server automatically") + return false + } + return true + } + return false +} + +func autoupdate(freq time.Duration, listeners *gracenet.Net, shutdownC chan struct{}) { + for { + if updateApplied() { + os.Args = append(os.Args, "--is-autoupdated=true") + if _, err := listeners.StartProcess(); err != nil { + logger.WithError(err).Error("Unable to restart server automatically") + } + close(shutdownC) + return + } + time.Sleep(freq) + } +} + +func updateApplied() bool { + releaseInfo := checkForUpdates() + if releaseInfo.Updated { + logger.Infof("Updated to version %s", releaseInfo.Version) + return true + } + if releaseInfo.Error != nil { + logger.WithError(releaseInfo.Error).Error("Update check failed") + } + return false +} + +func isAutoupdateEnabled(c *cli.Context) bool { + if isRunningFromTerminal() { + logger.Info(noAutoupdateMessage) + return false + } + + return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 +} + +func isRunningFromTerminal() bool { + return terminal.IsTerminal(int(os.Stdout.Fd())) +} diff --git a/cmd/cloudflared/windows_service.go b/cmd/cloudflared/windows_service.go index f1442863..a30c26bc 100644 --- a/cmd/cloudflared/windows_service.go +++ b/cmd/cloudflared/windows_service.go @@ -21,7 +21,7 @@ const ( windowsServiceDescription = "Argo Tunnel agent" ) -func runApp(app *cli.App) { +func runApp(app *cli.App, shutdownC chan struct{}) { app.Commands = append(app.Commands, &cli.Command{ Name: "service", Usage: "Manages the Argo Tunnel Windows service", @@ -59,7 +59,7 @@ func runApp(app *cli.App) { elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName)) // Run executes service name by calling windowsService which is a Handler // interface that implements Execute method - err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog}) + err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog, shutdownC: shutdownC}) if err != nil { elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err)) return @@ -68,8 +68,9 @@ func runApp(app *cli.App) { } type windowsService struct { - app *cli.App - elog *eventlog.Log + app *cli.App + elog *eventlog.Log + shutdownC chan struct{} } // called by the package code at the start of the service @@ -98,7 +99,7 @@ loop: } } } - close(shutdownC) + close(s.shutdownC) changes <- svc.Status{State: svc.StopPending} return } diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go index 49480cf4..8e566b5c 100644 --- a/h2mux/h2mux_test.go +++ b/h2mux/h2mux_test.go @@ -135,7 +135,7 @@ func TestSingleStream(t *testing.T) { t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value) } stream.WriteHeaders([]Header{ - Header{Name: "response-header", Value: "responseValue"}, + {Name: "response-header", Value: "responseValue"}, }) buf := []byte("Hello world") stream.Write(buf) @@ -153,7 +153,7 @@ func TestSingleStream(t *testing.T) { muxPair.HandshakeAndServe(t) stream, err := muxPair.EdgeMux.OpenStream( - []Header{Header{Name: "test-header", Value: "headerValue"}}, + []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) if err != nil { @@ -194,6 +194,7 @@ func TestSingleStream(t *testing.T) { func TestSingleStreamLargeResponseBody(t *testing.T) { muxPair := NewDefaultMuxerPair() bodySize := 1 << 24 + streamReady := make(chan struct{}) muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { if len(stream.Headers) != 1 { t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) @@ -205,25 +206,30 @@ func TestSingleStreamLargeResponseBody(t *testing.T) { t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value) } stream.WriteHeaders([]Header{ - Header{Name: "response-header", Value: "responseValue"}, + {Name: "response-header", Value: "responseValue"}, }) payload := make([]byte, bodySize) for i := range payload { payload[i] = byte(i % 256) } + t.Log("Writing payload...") n, err := stream.Write(payload) + t.Logf("Wrote %d bytes into the stream", n) if err != nil { t.Fatalf("origin write error: %s", err) } if n != len(payload) { t.Fatalf("origin short write: %d/%d bytes", n, len(payload)) } + t.Log("Payload written; signaling that the stream is ready") + streamReady <- struct{}{} + return nil }) muxPair.HandshakeAndServe(t) stream, err := muxPair.EdgeMux.OpenStream( - []Header{Header{Name: "test-header", Value: "headerValue"}}, + []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) if err != nil { @@ -239,6 +245,10 @@ func TestSingleStreamLargeResponseBody(t *testing.T) { t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value) } responseBody := make([]byte, bodySize) + + <-streamReady + t.Log("Received stream ready signal; resuming the test") + n, err := io.ReadFull(stream, responseBody) if err != nil { t.Fatalf("error from (*MuxedStream).Read: %s", err) @@ -261,7 +271,7 @@ func TestMultipleStreams(t *testing.T) { } log.Debugf("Got request for stream %s", stream.Headers[0].Value) stream.WriteHeaders([]Header{ - Header{Name: "response-token", Value: stream.Headers[0].Value}, + {Name: "response-token", Value: stream.Headers[0].Value}, }) log.Debugf("Wrote headers for stream %s", stream.Headers[0].Value) stream.Write([]byte("OK")) @@ -277,7 +287,7 @@ func TestMultipleStreams(t *testing.T) { defer wg.Done() tokenString := fmt.Sprintf("%d", tokenId) stream, err := muxPair.EdgeMux.OpenStream( - []Header{Header{Name: "client-token", Value: tokenString}}, + []Header{{Name: "client-token", Value: tokenString}}, nil, ) log.Debugf("Got headers for stream %d", tokenId) @@ -328,6 +338,7 @@ func TestMultipleStreams(t *testing.T) { func TestMultipleStreamsFlowControl(t *testing.T) { maxStreams := 32 errorsC := make(chan error, maxStreams) + streamReady := make(chan struct{}) responseSizes := make([]int32, maxStreams) for i := 0; i < maxStreams; i++ { responseSizes[i] = rand.Int31n(int32(defaultWindowSize << 4)) @@ -344,13 +355,14 @@ func TestMultipleStreamsFlowControl(t *testing.T) { t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value) } stream.WriteHeaders([]Header{ - Header{Name: "response-header", Value: "responseValue"}, + {Name: "response-header", Value: "responseValue"}, }) payload := make([]byte, responseSizes[(stream.streamID-2)/2]) for i := range payload { payload[i] = byte(i % 256) } n, err := stream.Write(payload) + streamReady <- struct{}{} if err != nil { t.Fatalf("origin write error: %s", err) } @@ -367,7 +379,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) { go func(tokenId int) { defer wg.Done() stream, err := muxPair.EdgeMux.OpenStream( - []Header{Header{Name: "test-header", Value: "headerValue"}}, + []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) if err != nil { @@ -387,6 +399,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) { return } + <-streamReady responseBody := make([]byte, responseSizes[(stream.streamID-2)/2]) n, err := io.ReadFull(stream, responseBody) if err != nil { @@ -417,7 +430,7 @@ func TestGracefulShutdown(t *testing.T) { muxPair := NewDefaultMuxerPair() muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { stream.WriteHeaders([]Header{ - Header{Name: "response-header", Value: "responseValue"}, + {Name: "response-header", Value: "responseValue"}, }) <-sendC log.Debugf("Writing %d bytes", len(responseBuf)) @@ -436,7 +449,7 @@ func TestGracefulShutdown(t *testing.T) { muxPair.HandshakeAndServe(t) stream, err := muxPair.EdgeMux.OpenStream( - []Header{Header{Name: "test-header", Value: "headerValue"}}, + []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) // Start graceful shutdown of the edge mux - this should also close the origin mux when done @@ -469,7 +482,7 @@ func TestUnexpectedShutdown(t *testing.T) { muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { defer close(handlerFinishC) stream.WriteHeaders([]Header{ - Header{Name: "response-header", Value: "responseValue"}, + {Name: "response-header", Value: "responseValue"}, }) <-sendC n, err := stream.Read([]byte{0}) @@ -490,7 +503,7 @@ func TestUnexpectedShutdown(t *testing.T) { muxPair.HandshakeAndServe(t) stream, err := muxPair.EdgeMux.OpenStream( - []Header{Header{Name: "test-header", Value: "headerValue"}}, + []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) // Close the underlying connection before telling the origin to write. @@ -552,7 +565,7 @@ func TestOpenAfterDisconnect(t *testing.T) { } _, err := muxPair.EdgeMux.OpenStream( - []Header{Header{Name: "test-header", Value: "headerValue"}}, + []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) if err != ErrConnectionClosed { diff --git a/hello/hello.go b/hello/hello.go new file mode 100644 index 00000000..7c9149cd --- /dev/null +++ b/hello/hello.go @@ -0,0 +1,197 @@ +package hello + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "html/template" + "io/ioutil" + "net" + "net/http" + "os" + "time" + + "github.com/gorilla/websocket" + "github.com/sirupsen/logrus" + + "github.com/cloudflare/cloudflared/tlsconfig" +) + +type templateData struct { + ServerName string + Request *http.Request + Body string +} + +type OriginUpTime struct { + StartTime time.Time `json:"startTime"` + UpTime string `json:"uptime"` +} + +const defaultServerName = "the Argo Tunnel test server" +const indexTemplate = ` + + + + + + + Argo Tunnel Connection + + + + + + + +
+
+ + + + + + +

Congrats! You created your first tunnel!

+

+ Argo Tunnel exposes locally running applications to the internet by + running an encrypted, virtual tunnel from your laptop or server to + Cloudflare's edge network. +

+

Ready for the next step?

+ + Get started here + +
+

Request

+
+
Method: {{.Request.Method}}
+
Protocol: {{.Request.Proto}}
+
Request URL: {{.Request.URL}}
+
Transfer encoding: {{.Request.TransferEncoding}}
+
Host: {{.Request.Host}}
+
Remote address: {{.Request.RemoteAddr}}
+
Request URI: {{.Request.RequestURI}}
+{{range $key, $value := .Request.Header}} +
Header: {{$key}}, Value: {{$value}}
+{{end}} +
Body: {{.Body}}
+
+
+
+
+ + +` + + +func StartHelloWorldServer(logger *logrus.Logger, listener net.Listener, shutdownC <-chan struct{}) error { + logger.Infof("Starting Hello World server at %s", listener.Addr()) + serverName := defaultServerName + if hostname, err := os.Hostname(); err == nil { + serverName = hostname + } + + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil} + go func() { + <-shutdownC + httpServer.Close() + }() + + http.HandleFunc("/uptime", uptimeHandler(time.Now())) + http.HandleFunc("/ws", websocketHandler(logger, upgrader)) + http.HandleFunc("/", rootHandler(serverName)) + err := httpServer.Serve(listener) + return err +} + +func CreateTLSListener(address string) (net.Listener, error) { + certificate, err := tlsconfig.GetHelloCertificate() + if err != nil { + return nil, err + } + + // If the port in address is empty, a port number is automatically chosen + listener, err := tls.Listen( + "tcp", + address, + &tls.Config{Certificates: []tls.Certificate{certificate}}) + + return listener, err +} + + +func uptimeHandler(startTime time.Time) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Note that if autoupdate is enabled, the uptime is reset when a new client + // release is available + resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()} + respJson, err := json.Marshal(resp) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } else { + w.Header().Set("Content-Type", "application/json") + w.Write(respJson) + } + } +} + +// This handler will echo message +func websocketHandler(logger *logrus.Logger, upgrader websocket.Upgrader) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + for { + mt, message, err := conn.ReadMessage() + if err != nil { + logger.WithError(err).Error("websocket read message error") + break + } + + if err := conn.WriteMessage(mt, message); err != nil { + logger.WithError(err).Error("websocket write message error") + break + } + } + } +} + +func rootHandler(serverName string) http.HandlerFunc { + responseTemplate := template.Must(template.New("index").Parse(indexTemplate)) + return func(w http.ResponseWriter, r *http.Request) { + var buffer bytes.Buffer + var body string + rawBody, err := ioutil.ReadAll(r.Body) + if err == nil { + body = string(rawBody) + } else { + body = "" + } + err = responseTemplate.Execute(&buffer, &templateData{ + ServerName: serverName, + Request: r, + Body: body, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "error: %v", err) + } else { + buffer.WriteTo(w) + } + } +} diff --git a/hello/hello_test.go b/hello/hello_test.go new file mode 100644 index 00000000..3a049c8a --- /dev/null +++ b/hello/hello_test.go @@ -0,0 +1,38 @@ +package hello + +import ( + "testing" +) + +func TestCreateTLSListenerHostAndPortSuccess(t *testing.T) { + listener, err := CreateTLSListener("localhost:1234") + defer listener.Close() + if err != nil { + t.Fatal(err) + } + if listener.Addr().String() == "" { + t.Fatal("Fail to find available port") + } +} + +func TestCreateTLSListenerOnlyHostSuccess(t *testing.T) { + listener, err := CreateTLSListener("localhost:") + defer listener.Close() + if err != nil { + t.Fatal(err) + } + if listener.Addr().String() == "" { + t.Fatal("Fail to find available port") + } +} + +func TestCreateTLSListenerOnlyPortSuccess(t *testing.T) { + listener, err := CreateTLSListener(":8888") + defer listener.Close() + if err != nil { + t.Fatal(err) + } + if listener.Addr().String() == "" { + t.Fatal("Fail to find available port") + } +} diff --git a/origin/tunnel.go b/origin/tunnel.go index 8257d3d1..a7d0c579 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -123,7 +123,12 @@ func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connecte } } -func ServeTunnelLoop(ctx context.Context, config *TunnelConfig, addr *net.TCPAddr, connectionID uint8, connectedSignal chan struct{}) error { +func ServeTunnelLoop(ctx context.Context, + config *TunnelConfig, + addr *net.TCPAddr, + connectionID uint8, + connectedSignal chan struct{}, +) error { config.Metrics.incrementHaConnections() defer config.Metrics.decrementHaConnections() backoff := BackoffHandler{MaxRetries: config.Retries} @@ -482,6 +487,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { } else { stream.WriteHeaders(H1ResponseToH2Response(response)) defer conn.Close() + // Copy to/from stream to the undelying connection. Use the underlying + // connection because cloudflared doesn't operate on the message themselves websocket.Stream(conn.UnderlyingConn(), stream) h.metrics.incrementResponses(h.connectionID, "200") h.logResponse(response, cfRay) diff --git a/tlsconfig/cloudflare_ca.go b/tlsconfig/cloudflare_ca.go index f1444a7d..202be9eb 100644 --- a/tlsconfig/cloudflare_ca.go +++ b/tlsconfig/cloudflare_ca.go @@ -5,7 +5,7 @@ import ( ) // TODO: remove the Origin CA root certs when migrated to Authenticated Origin Pull certs -const cloudflareRootCA = ` +var cloudflareRootCA = []byte(` Issuer: C=US, ST=California, L=San Francisco, O=CloudFlare, Inc., OU=CloudFlare Origin SSL ECC Certificate Authority -----BEGIN CERTIFICATE----- MIICiDCCAi6gAwIBAgIUXZP3MWb8MKwBE1Qbawsp1sfA/Y4wCgYIKoZIzj0EAwIw @@ -83,7 +83,7 @@ Bz+1CD4D/bWrs3cC9+kk/jFmrrAymZlkFX8tDb5aXASSLJjUjcptci9SKqtI2h0J wUGkD7+bQAr+7vr8/R+CBmNMe7csE8NeEX6lVMF7Dh0a1YKQa6hUN18bBuYgTMuT QzMmZpRpIBB321ZBlcnlxiTJvWxvbCPHKHj20VwwAz7LONF59s84ZsOqfoBv8gKM s0s5dsq5zpLeaw== ------END CERTIFICATE-----` +-----END CERTIFICATE-----`) func GetCloudflareRootCA() *x509.CertPool { ca := x509.NewCertPool() diff --git a/tlsconfig/tlsconfig.go b/tlsconfig/tlsconfig.go index ba76cf97..2ecda972 100644 --- a/tlsconfig/tlsconfig.go +++ b/tlsconfig/tlsconfig.go @@ -9,6 +9,7 @@ import ( "net" "github.com/cloudflare/cloudflared/log" + "github.com/pkg/errors" "gopkg.in/urfave/cli.v2" ) @@ -64,21 +65,27 @@ func LoadCert(certPath string) *x509.CertPool { return ca } -func LoadOriginCertsPool() *x509.CertPool { +func LoadGlobalCertPool() (*x509.CertPool, error) { + success := false + // First, obtain the system certificate pool certPool, systemCertPoolErr := x509.SystemCertPool() if systemCertPoolErr != nil { logger.Warnf("error obtaining the system certificates: %s", systemCertPoolErr) certPool = x509.NewCertPool() + } else { + success = true } // Next, append the Cloudflare CA pool into the system pool - if !certPool.AppendCertsFromPEM([]byte(cloudflareRootCA)) { - logger.Warn("could not append the CF certificate to the system certificate pool") + if !certPool.AppendCertsFromPEM(cloudflareRootCA) { + logger.Warn("could not append the CF certificate to the cloudflared certificate pool") + } else { + success = true + } - if systemCertPoolErr != nil { // Obtaining both certificates failed; this is a fatal error - logger.WithError(systemCertPoolErr).Fatalf("Error loading the certificate pool") - } + if success != true { // Obtaining any of the CAs has failed; this is a fatal error + return nil, errors.New("error loading any of the CAs into the global certificate pool") } // Finally, add the Hello certificate into the pool (since it's self-signed) @@ -89,7 +96,34 @@ func LoadOriginCertsPool() *x509.CertPool { certPool.AddCert(helloCertificate) - return certPool + return certPool, nil +} + +func LoadOriginCertPool(originCAPoolPEM []byte) (*x509.CertPool, error) { + success := false + + // Get the global pool + certPool, globalPoolErr := LoadGlobalCertPool() + if globalPoolErr != nil { + certPool = x509.NewCertPool() + } else { + success = true + } + + // Then, add any custom origin CA pool the user may have passed + if originCAPoolPEM != nil { + if !certPool.AppendCertsFromPEM(originCAPoolPEM) { + logger.Warn("could not append the provided origin CA to the cloudflared certificate pool") + } else { + success = true + } + } + + if success != true { + return nil, errors.New("error loading any of the CAs into the origin certificate pool") + } + + return certPool, nil } func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config { diff --git a/tlsconfig/tlsconfig_test.go b/tlsconfig/tlsconfig_test.go new file mode 100644 index 00000000..48e9028f --- /dev/null +++ b/tlsconfig/tlsconfig_test.go @@ -0,0 +1,211 @@ +package tlsconfig + +import ( + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Generated using `openssl req -newkey rsa:512 -nodes -x509 -days 3650` +var samplePEM = []byte(` +-----BEGIN CERTIFICATE----- +MIIB4DCCAYoCCQCb/H0EUrdXEjANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv +dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE +AwwIVGVzdCBPbmUwHhcNMTgwNDI2MTYxMDUxWhcNMjgwNDIzMTYxMDUxWjB3MQsw +CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG +A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn +eTERMA8GA1UEAwwIVGVzdCBPbmUwXDANBgkqhkiG9w0BAQEFAANLADBIAkEAwVQD +K0SJ25UFLznm2pU3zhzMEvpDEofHVNnCjk4mlDrtVop7PkKZ8pDEmuQANltUrxC8 +yHBE2wXMv+GlH+bDtwIDAQABMA0GCSqGSIb3DQEBCwUAA0EAjVYQzozIFPkt/HRY +uUoZ8zEHIDICb0syFf5VAjm9AgTwIPzUmD+c5vl6LWDnxq7L45nLCzhhQ6YmiwDz +X7Wcyg== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIB4DCCAYoCCQDZfCdAJ+mwzDANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV +UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv +dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE +AwwIVGVzdCBUd28wHhcNMTgwNDI2MTYxMTIwWhcNMjgwNDIzMTYxMTIwWjB3MQsw +CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG +A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn +eTERMA8GA1UEAwwIVGVzdCBUd28wXDANBgkqhkiG9w0BAQEFAANLADBIAkEAoHKp +ROVK3zCSsH7ocYeyRAML4V7SFAbZcb4WIwDnE08oMBVRkQVcW5tqEkvG3RiClfzV +wZIJ3CfqKIeSNSDU9wIDAQABMA0GCSqGSIb3DQEBCwUAA0EAJw2gUbnPiq4C2p5b +iWzlA9Q7aKo+VQ4H7IZS7tTccr59nVjvH/TG3eWujpnocr4TOqW9M3CK1DF9mUGP +3pQ3Jg== +-----END CERTIFICATE----- +`) + +var systemCertPoolSubjects []*pkix.Name + +type certificateFixture struct { + ou string + cn string +} + +func TestMain(m *testing.M) { + systemCertPool, err := x509.SystemCertPool() + if isUnrecoverableError(err) { + os.Exit(1) + } + + if systemCertPool == nil { + // On Windows, let's just assume the system cert pool was empty + systemCertPool = x509.NewCertPool() + } + + systemCertPoolSubjects, err = getCertPoolSubjects(systemCertPool) + if err != nil { + os.Exit(1) + } + + os.Exit(m.Run()) +} + +func TestLoadOriginCertPoolJustSystemPool(t *testing.T) { + certPoolSubjects := loadCertPoolSubjects(t, nil) + extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects) + + // Remove extra subjects from the cert pool + var filteredSystemCertPoolSubjects []*pkix.Name + + t.Log(extraSubjects) + +OUTER: + for _, subject := range certPoolSubjects { + for _, extraSubject := range extraSubjects { + if subject == extraSubject { + t.Log(extraSubject) + continue OUTER + } + } + + filteredSystemCertPoolSubjects = append(filteredSystemCertPoolSubjects, subject) + } + + assert.Equal(t, len(filteredSystemCertPoolSubjects), len(systemCertPoolSubjects)) + + difference := subjectSubtract(systemCertPoolSubjects, filteredSystemCertPoolSubjects) + assert.Equal(t, 0, len(difference)) +} + +func TestLoadOriginCertPoolCFCertificates(t *testing.T) { + certPoolSubjects := loadCertPoolSubjects(t, nil) + + extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects) + + expected := []*certificateFixture{ + {ou: "CloudFlare Origin SSL ECC Certificate Authority"}, + {ou: "CloudFlare Origin SSL Certificate Authority"}, + {cn: "origin-pull.cloudflare.net"}, + {cn: "Argo Tunnel Sample Hello Server Certificate"}, + } + + assertFixturesMatchSubjects(t, expected, extraSubjects) +} + +func TestLoadOriginCertPoolWithExtraPEMs(t *testing.T) { + certPoolWithoutPEMSubjects := loadCertPoolSubjects(t, nil) + certPoolWithPEMSubjects := loadCertPoolSubjects(t, samplePEM) + + difference := subjectSubtract(certPoolWithoutPEMSubjects, certPoolWithPEMSubjects) + + assert.Equal(t, 2, len(difference)) + + expected := []*certificateFixture{ + {cn: "Test One"}, + {cn: "Test Two"}, + } + + assertFixturesMatchSubjects(t, expected, difference) +} + +func loadCertPoolSubjects(t *testing.T, originCAPoolPEM []byte) []*pkix.Name { + certPool, err := LoadOriginCertPool(originCAPoolPEM) + if isUnrecoverableError(err) { + t.Fatal(err) + } + assert.NotEmpty(t, certPool.Subjects()) + certPoolSubjects, err := getCertPoolSubjects(certPool) + if err != nil { + t.Fatal(err) + } + + return certPoolSubjects +} + +func assertFixturesMatchSubjects(t *testing.T, fixtures []*certificateFixture, subjects []*pkix.Name) { + assert.Equal(t, len(fixtures), len(subjects)) + + for _, fixture := range fixtures { + found := false + for _, subject := range subjects { + found = found || fixtureMatchesSubjectPredicate(fixture, subject) + } + + if !found { + t.Fail() + } + } +} + +func fixtureMatchesSubjectPredicate(fixture *certificateFixture, subject *pkix.Name) bool { + cnMatch := true + if fixture.cn != "" { + cnMatch = fixture.cn == subject.CommonName + } + + ouMatch := true + if fixture.ou != "" { + ouMatch = len(subject.OrganizationalUnit) > 0 && fixture.ou == subject.OrganizationalUnit[0] + } + + return cnMatch && ouMatch +} + +func subjectSubtract(left []*pkix.Name, right []*pkix.Name) []*pkix.Name { + var difference []*pkix.Name + + var found bool + for _, r := range right { + found = false + for _, l := range left { + if (*l).String() == (*r).String() { + found = true + } + } + + if !found { + difference = append(difference, r) + } + } + + return difference +} + +func getCertPoolSubjects(certPool *x509.CertPool) ([]*pkix.Name, error) { + var subjects []*pkix.Name + + for _, subject := range certPool.Subjects() { + var sequence pkix.RDNSequence + _, err := asn1.Unmarshal(subject, &sequence) + if err != nil { + return nil, err + } + + name := pkix.Name{} + name.FillFromRDNSequence(&sequence) + + subjects = append(subjects, &name) + } + + return subjects, nil +} + +func isUnrecoverableError(err error) bool { + return err != nil && err.Error() != "crypto/x509: system root pool is not available on Windows" +} diff --git a/tunneldns/https_upstream.go b/tunneldns/https_upstream.go index 1c681dfc..03d7520c 100644 --- a/tunneldns/https_upstream.go +++ b/tunneldns/https_upstream.go @@ -43,7 +43,7 @@ func NewUpstreamHTTPS(endpoint string) (Upstream, error) { http2.ConfigureTransport(transport) client := &http.Client{ - Timeout: time.Second * defaultTimeout, + Timeout: defaultTimeout, Transport: transport, } diff --git a/websocket/websocket.go b/websocket/websocket.go index c70b2a7d..371269b2 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -13,16 +13,28 @@ import ( "github.com/gorilla/websocket" ) +var stripWebsocketHeaders = []string { + "Upgrade", + "Connection", + "Sec-Websocket-Key", + "Sec-Websocket-Version", + "Sec-Websocket-Extensions", +} + // IsWebSocketUpgrade checks to see if the request is a WebSocket connection. func IsWebSocketUpgrade(req *http.Request) bool { return websocket.IsWebSocketUpgrade(req) } -// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing. +// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing +// the connection. The response body may not contain the entire response and does +// not need to be closed by the application. func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.Conn, *http.Response, error) { req.URL.Scheme = changeRequestScheme(req) + wsHeaders := websocketHeaders(req) + d := &websocket.Dialer{TLSClientConfig: tlsClientConfig} - conn, response, err := d.Dial(req.URL.String(), nil) + conn, response, err := d.Dial(req.URL.String(), wsHeaders) if err != nil { return nil, nil, err } @@ -62,6 +74,21 @@ func Stream(conn, backendConn io.ReadWriter) { <-proxyDone } +// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key, +// Sec-WebSocket-Version and Sec-Websocket-Extensions headers. +// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194. +func websocketHeaders(req *http.Request) http.Header { + wsHeaders := make(http.Header) + for key, val := range req.Header { + wsHeaders[key] = val + } + // Assume the header keys are in canonical format. + for _, header := range stripWebsocketHeaders { + wsHeaders.Del(header) + } + return wsHeaders +} + // sha1Base64 sha1 and then base64 encodes str. func sha1Base64(str string) string { hasher := sha1.New() diff --git a/websocket/websocket_test.go b/websocket/websocket_test.go new file mode 100644 index 00000000..45eacd95 --- /dev/null +++ b/websocket/websocket_test.go @@ -0,0 +1,100 @@ +package websocket + +import ( + "crypto/tls" + "io" + "math/rand" + "net/http" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "golang.org/x/net/websocket" + + "github.com/cloudflare/cloudflared/hello" + "github.com/cloudflare/cloudflared/tlsconfig" +) + +const ( + // example in Sec-Websocket-Key in rfc6455 + testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ==" + // example Sec-Websocket-Accept in rfc6455 + testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" +) + +func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request { + req, err := http.NewRequest("GET", url, stream) + if err != nil { + t.Fatalf("testRequestHeader error") + } + + req.Header.Add("Connection", "Upgrade") + req.Header.Add("Upgrade", "WebSocket") + req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey) + req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol") + req.Header.Add("Sec-Websocket-Version", "13") + req.Header.Add("User-Agent", "curl/7.59.0") + + return req +} + +func websocketClientTLSConfig(t *testing.T) *tls.Config { + certPool, err := tlsconfig.LoadOriginCertPool(nil) + assert.NoError(t, err) + assert.NotNil(t, certPool) + return &tls.Config{RootCAs: certPool} +} + +func TestWebsocketHeaders(t *testing.T) { + req := testRequest(t, "http://example.com", nil) + wsHeaders := websocketHeaders(req) + for _, header := range stripWebsocketHeaders { + assert.Empty(t, wsHeaders[header]) + } + assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent")) +} + +func TestGenerateAcceptKey(t *testing.T) { + req := testRequest(t, "http://example.com", nil) + assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(req)) +} + +func TestServe(t *testing.T) { + logger := logrus.New() + shutdownC := make(chan struct{}) + errC := make(chan error) + listener, err := hello.CreateTLSListener("localhost:1111") + assert.NoError(t, err) + defer listener.Close() + + go func() { + errC <- hello.StartHelloWorldServer(logger, listener, shutdownC) + }() + + req := testRequest(t, "https://localhost:1111/ws", nil) + + tlsConfig := websocketClientTLSConfig(t) + assert.NotNil(t, tlsConfig) + conn, resp, err := ClientConnect(req, tlsConfig) + assert.NoError(t, err) + assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept")) + + for i := 0; i < 1000; i++ { + messageSize := rand.Int() % 2048 + 1 + clientMessage := make([]byte, messageSize) + // rand.Read always returns len(clientMessage) and a nil error + rand.Read(clientMessage) + err = conn.WriteMessage(websocket.BinaryFrame, clientMessage) + assert.NoError(t, err) + + messageType, message, err := conn.ReadMessage() + assert.NoError(t, err) + assert.Equal(t, websocket.BinaryFrame, messageType) + assert.Equal(t, clientMessage, message) + } + + conn.Close() + close(shutdownC) + <-errC +}