diff --git a/cmd/cloudflared/generic_service.go b/cmd/cloudflared/generic_service.go new file mode 100644 index 00000000..3eb4819e --- /dev/null +++ b/cmd/cloudflared/generic_service.go @@ -0,0 +1,13 @@ +// +build !windows,!darwin,!linux + +package main + +import ( + "os" + + cli "gopkg.in/urfave/cli.v2" +) + +func runApp(app *cli.App) { + app.Run(os.Args) +} diff --git a/cmd/cloudflared/hello.go b/cmd/cloudflared/hello.go new file mode 100644 index 00000000..9f73ade8 --- /dev/null +++ b/cmd/cloudflared/hello.go @@ -0,0 +1,204 @@ +package main + +import ( + "bytes" + "crypto/tls" + "encoding/json" + "fmt" + "html/template" + "io/ioutil" + "net" + "net/http" + "os" + "time" + + "github.com/gorilla/websocket" + "gopkg.in/urfave/cli.v2" + + "github.com/cloudflare/cloudflared/tlsconfig" +) + +type templateData struct { + ServerName string + Request *http.Request + Body string +} + +type OriginUpTime struct { + StartTime time.Time `json:"startTime"` + UpTime string `json:"uptime"` +} + +const defaultServerName = "the Argo Tunnel test server" +const indexTemplate = ` + + + + + + + Argo Tunnel Connection + + + + + + + +
+
+ + + + + + +

Congrats! You created your first tunnel!

+

+ Argo Tunnel exposes locally running applications to the internet by + running an encrypted, virtual tunnel from your laptop or server to + Cloudflare's edge network. +

+

Ready for the next step?

+ + Get started here + +
+

Request

+
+
Method: {{.Request.Method}}
+
Protocol: {{.Request.Proto}}
+
Request URL: {{.Request.URL}}
+
Transfer encoding: {{.Request.TransferEncoding}}
+
Host: {{.Request.Host}}
+
Remote address: {{.Request.RemoteAddr}}
+
Request URI: {{.Request.RequestURI}}
+{{range $key, $value := .Request.Header}} +
Header: {{$key}}, Value: {{$value}}
+{{end}} +
Body: {{.Body}}
+
+
+
+
+ + +` + +func hello(c *cli.Context) error { + address := fmt.Sprintf(":%d", c.Int("port")) + listener, err := createListener(address) + if err != nil { + return err + } + defer listener.Close() + err = startHelloWorldServer(listener, nil) + return err +} + +func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error { + Log.Infof("Starting Hello World server at %s", listener.Addr()) + serverName := defaultServerName + if hostname, err := os.Hostname(); err == nil { + serverName = hostname + } + + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil} + go func() { + <-shutdownC + httpServer.Close() + }() + + http.HandleFunc("/uptime", uptimeHandler(time.Now())) + http.HandleFunc("/ws", websocketHandler(upgrader)) + http.HandleFunc("/", rootHandler(serverName)) + err := httpServer.Serve(listener) + return err +} + +func createListener(address string) (net.Listener, error) { + certificate, err := tlsconfig.GetHelloCertificate() + if err != nil { + return nil, err + } + + // If the port in address is empty, a port number is automatically chosen + listener, err := tls.Listen( + "tcp", + address, + &tls.Config{Certificates: []tls.Certificate{certificate}}) + + return listener, err +} + +func uptimeHandler(startTime time.Time) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Note that if autoupdate is enabled, the uptime is reset when a new client + // release is available + resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()} + respJson, err := json.Marshal(resp) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } else { + w.Header().Set("Content-Type", "application/json") + w.Write(respJson) + } + } +} + +func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + for { + mt, message, err := conn.ReadMessage() + if err != nil { + break + } + + if err := conn.WriteMessage(mt, message); err != nil { + break + } + } + } +} + +func rootHandler(serverName string) http.HandlerFunc { + responseTemplate := template.Must(template.New("index").Parse(indexTemplate)) + return func(w http.ResponseWriter, r *http.Request) { + var buffer bytes.Buffer + var body string + rawBody, err := ioutil.ReadAll(r.Body) + if err == nil { + body = string(rawBody) + } else { + body = "" + } + err = responseTemplate.Execute(&buffer, &templateData{ + ServerName: serverName, + Request: r, + Body: body, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "error: %v", err) + } else { + buffer.WriteTo(w) + } + } +} diff --git a/cmd/cloudflared/hello_test.go b/cmd/cloudflared/hello_test.go new file mode 100644 index 00000000..f6e8842a --- /dev/null +++ b/cmd/cloudflared/hello_test.go @@ -0,0 +1,35 @@ +package main + +import ( + "testing" +) + +func TestCreateListenerHostAndPortSuccess(t *testing.T) { + listener, err := createListener("localhost:1234") + if err != nil { + t.Fatal(err) + } + if listener.Addr().String() == "" { + t.Fatal("Fail to find available port") + } +} + +func TestCreateListenerOnlyHostSuccess(t *testing.T) { + listener, err := createListener("localhost:") + if err != nil { + t.Fatal(err) + } + if listener.Addr().String() == "" { + t.Fatal("Fail to find available port") + } +} + +func TestCreateListenerOnlyPortSuccess(t *testing.T) { + listener, err := createListener(":8888") + if err != nil { + t.Fatal(err) + } + if listener.Addr().String() == "" { + t.Fatal("Fail to find available port") + } +} diff --git a/cmd/cloudflared/linux_service.go b/cmd/cloudflared/linux_service.go new file mode 100644 index 00000000..d0b15937 --- /dev/null +++ b/cmd/cloudflared/linux_service.go @@ -0,0 +1,292 @@ +// +build linux + +package main + +import ( + "fmt" + "os" + "path/filepath" + + cli "gopkg.in/urfave/cli.v2" +) + +func runApp(app *cli.App) { + app.Commands = append(app.Commands, &cli.Command{ + Name: "service", + Usage: "Manages the Argo Tunnel system service", + Subcommands: []*cli.Command{ + &cli.Command{ + Name: "install", + Usage: "Install Argo Tunnel as a system service", + Action: installLinuxService, + }, + &cli.Command{ + Name: "uninstall", + Usage: "Uninstall the Argo Tunnel service", + Action: uninstallLinuxService, + }, + }, + }) + app.Run(os.Args) +} + +const serviceConfigDir = "/etc/cloudflared" + +var systemdTemplates = []ServiceTemplate{ + { + Path: "/etc/systemd/system/cloudflared.service", + Content: `[Unit] +Description=Argo Tunnel +After=network.target + +[Service] +TimeoutStartSec=0 +Type=notify +ExecStart={{ .Path }} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --no-autoupdate +Restart=on-failure +RestartSec=5s + +[Install] +WantedBy=multi-user.target +`, + }, + { + Path: "/etc/systemd/system/cloudflared-update.service", + Content: `[Unit] +Description=Update Argo Tunnel +After=network.target + +[Service] +ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 64 ]; then systemctl restart cloudflared; exit 0; fi; exit $code' +`, + }, + { + Path: "/etc/systemd/system/cloudflared-update.timer", + Content: `[Unit] +Description=Update Argo Tunnel + +[Timer] +OnUnitActiveSec=1d + +[Install] +WantedBy=timers.target +`, + }, +} + +var sysvTemplate = ServiceTemplate{ + Path: "/etc/init.d/cloudflared", + FileMode: 0755, + Content: `# For RedHat and cousins: +# chkconfig: 2345 99 01 +# description: Argo Tunnel agent +# processname: {{.Path}} +### BEGIN INIT INFO +# Provides: {{.Path}} +# Required-Start: +# Required-Stop: +# Default-Start: 2 3 4 5 +# Default-Stop: 0 1 6 +# Short-Description: Argo Tunnel +# Description: Argo Tunnel agent +### END INIT INFO +cmd="{{.Path}} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --pidfile /var/run/$name.pid --autoupdate-freq 24h0m0s" +name=$(basename $(readlink -f $0)) +pid_file="/var/run/$name.pid" +stdout_log="/var/log/$name.log" +stderr_log="/var/log/$name.err" +[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name +get_pid() { + cat "$pid_file" +} +is_running() { + [ -f "$pid_file" ] && ps $(get_pid) > /dev/null 2>&1 +} +case "$1" in + start) + if is_running; then + echo "Already started" + else + echo "Starting $name" + $cmd >> "$stdout_log" 2>> "$stderr_log" & + echo $! > "$pid_file" + if ! is_running; then + echo "Unable to start, see $stdout_log and $stderr_log" + exit 1 + fi + fi + ;; + stop) + if is_running; then + echo -n "Stopping $name.." + kill $(get_pid) + for i in {1..10} + do + if ! is_running; then + break + fi + echo -n "." + sleep 1 + done + echo + if is_running; then + echo "Not stopped; may still be shutting down or shutdown may have failed" + exit 1 + else + echo "Stopped" + if [ -f "$pid_file" ]; then + rm "$pid_file" + fi + fi + else + echo "Not running" + fi + ;; + restart) + $0 stop + if is_running; then + echo "Unable to stop, will not attempt to start" + exit 1 + fi + $0 start + ;; + status) + if is_running; then + echo "Running" + else + echo "Stopped" + exit 1 + fi + ;; + *) + echo "Usage: $0 {start|stop|restart|status}" + exit 1 + ;; +esac +exit 0 +`, +} + +func isSystemd() bool { + if _, err := os.Stat("/run/systemd/system"); err == nil { + return true + } + return false +} + +func installLinuxService(c *cli.Context) error { + etPath, err := os.Executable() + if err != nil { + return fmt.Errorf("error determining executable path: %v", err) + } + templateArgs := ServiceTemplateArgs{Path: etPath} + + defaultConfigDir := filepath.Dir(c.String("config")) + defaultConfigFile := filepath.Base(c.String("config")) + if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile); err != nil { + Log.WithError(err).Infof("Failed to copy user configuration. Before running the service, ensure that %s contains two files, %s and %s", + serviceConfigDir, credentialFile, defaultConfigFiles[0]) + return err + } + + switch { + case isSystemd(): + Log.Infof("Using Systemd") + return installSystemd(&templateArgs) + default: + Log.Infof("Using Sysv") + return installSysv(&templateArgs) + } +} + +func installSystemd(templateArgs *ServiceTemplateArgs) error { + for _, serviceTemplate := range systemdTemplates { + err := serviceTemplate.Generate(templateArgs) + if err != nil { + Log.WithError(err).Infof("error generating service template") + return err + } + } + if err := runCommand("systemctl", "enable", "cloudflared.service"); err != nil { + Log.WithError(err).Infof("systemctl enable cloudflared.service error") + return err + } + if err := runCommand("systemctl", "start", "cloudflared-update.timer"); err != nil { + Log.WithError(err).Infof("systemctl start cloudflared-update.timer error") + return err + } + Log.Infof("systemctl daemon-reload") + return runCommand("systemctl", "daemon-reload") +} + +func installSysv(templateArgs *ServiceTemplateArgs) error { + confPath, err := sysvTemplate.ResolvePath() + if err != nil { + Log.WithError(err).Infof("error resolving system path") + return err + } + if err := sysvTemplate.Generate(templateArgs); err != nil { + Log.WithError(err).Infof("error generating system template") + return err + } + for _, i := range [...]string{"2", "3", "4", "5"} { + if err := os.Symlink(confPath, "/etc/rc"+i+".d/S50et"); err != nil { + continue + } + } + for _, i := range [...]string{"0", "1", "6"} { + if err := os.Symlink(confPath, "/etc/rc"+i+".d/K02et"); err != nil { + continue + } + } + return nil +} + +func uninstallLinuxService(c *cli.Context) error { + switch { + case isSystemd(): + Log.Infof("Using Systemd") + return uninstallSystemd() + default: + Log.Infof("Using Sysv") + return uninstallSysv() + } +} + +func uninstallSystemd() error { + if err := runCommand("systemctl", "disable", "cloudflared.service"); err != nil { + Log.WithError(err).Infof("systemctl disable cloudflared.service error") + return err + } + if err := runCommand("systemctl", "stop", "cloudflared-update.timer"); err != nil { + Log.WithError(err).Infof("systemctl stop cloudflared-update.timer error") + return err + } + for _, serviceTemplate := range systemdTemplates { + if err := serviceTemplate.Remove(); err != nil { + Log.WithError(err).Infof("error removing service template") + return err + } + } + Log.Infof("Successfully uninstall cloudflared service") + return nil +} + +func uninstallSysv() error { + if err := sysvTemplate.Remove(); err != nil { + Log.WithError(err).Infof("error removing service template") + return err + } + for _, i := range [...]string{"2", "3", "4", "5"} { + if err := os.Remove("/etc/rc" + i + ".d/S50et"); err != nil { + continue + } + } + for _, i := range [...]string{"0", "1", "6"} { + if err := os.Remove("/etc/rc" + i + ".d/K02et"); err != nil { + continue + } + } + Log.Infof("Successfully uninstall cloudflared service") + return nil +} diff --git a/cmd/cloudflared/login.go b/cmd/cloudflared/login.go new file mode 100644 index 00000000..fa3c675e --- /dev/null +++ b/cmd/cloudflared/login.go @@ -0,0 +1,194 @@ +package main + +import ( + "crypto/rand" + "encoding/base32" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "runtime" + "syscall" + "time" + + homedir "github.com/mitchellh/go-homedir" + cli "gopkg.in/urfave/cli.v2" +) + +const baseLoginURL = "https://www.cloudflare.com/a/warp" +const baseCertStoreURL = "https://login.cloudflarewarp.com" +const clientTimeout = time.Minute * 20 + +func login(c *cli.Context) error { + configPath, err := homedir.Expand(defaultConfigDirs[0]) + if err != nil { + return err + } + ok, err := fileExists(configPath) + if !ok && err == nil { + // create config directory if doesn't already exist + err = os.Mkdir(configPath, 0700) + } + if err != nil { + return err + } + path := filepath.Join(configPath, credentialFile) + fileInfo, err := os.Stat(path) + if err == nil && fileInfo.Size() > 0 { + fmt.Fprintf(os.Stderr, `You have an existing certificate at %s which login would overwrite. +If this is intentional, please move or delete that file then run this command again. +`, path) + return nil + } + if err != nil && err.(*os.PathError).Err != syscall.ENOENT { + return err + } + + // for local debugging + baseURL := baseCertStoreURL + if c.IsSet("url") { + baseURL = c.String("url") + } + // Generate a random post URL + certURL := baseURL + generateRandomPath() + loginURL, err := url.Parse(baseLoginURL) + if err != nil { + // shouldn't happen, URL is hardcoded + return err + } + loginURL.RawQuery = "callback=" + url.QueryEscape(certURL) + + err = open(loginURL.String()) + if err != nil { + fmt.Fprintf(os.Stderr, `Please open the following URL and log in with your Cloudflare account: + +%s + +Leave cloudflared running to install the certificate automatically. +`, loginURL.String()) + } else { + fmt.Fprintf(os.Stderr, `A browser window should have opened at the following URL: + +%s + +If the browser failed to open, open it yourself and visit the URL above. + +`, loginURL.String()) + } + + if download(certURL, path) { + fmt.Fprintf(os.Stderr, `You have successfully logged in. +If you wish to copy your credentials to a server, they have been saved to: +%s +`, path) + } else { + fmt.Fprintf(os.Stderr, `Failed to write the certificate due to the following error: +%v + +Your browser will download the certificate instead. You will have to manually +copy it to the following path: + +%s + +`, err, path) + } + return nil +} + +// generateRandomPath generates a random URL to associate with the certificate. +func generateRandomPath() string { + randomBytes := make([]byte, 40) + _, err := rand.Read(randomBytes) + if err != nil { + panic(err) + } + return "/" + base32.StdEncoding.EncodeToString(randomBytes) +} + +// open opens the specified URL in the default browser of the user. +func open(url string) error { + var cmd string + var args []string + + switch runtime.GOOS { + case "windows": + cmd = "cmd" + args = []string{"/c", "start"} + case "darwin": + cmd = "open" + default: // "linux", "freebsd", "openbsd", "netbsd" + cmd = "xdg-open" + } + args = append(args, url) + return exec.Command(cmd, args...).Start() +} + +func download(certURL, filePath string) bool { + client := &http.Client{Timeout: clientTimeout} + // attempt a (long-running) certificate get + for i := 0; i < 20; i++ { + ok, err := tryDownload(client, certURL, filePath) + if ok { + putSuccess(client, certURL) + return true + } + if err != nil { + Log.WithError(err).Error("Error fetching certificate") + return false + } + } + return false +} + +func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err error) { + resp, err := client.Get(certURL) + if err != nil { + return false, err + } + defer resp.Body.Close() + if resp.StatusCode == 404 { + return false, nil + } + if resp.StatusCode != 200 { + return false, fmt.Errorf("Unexpected HTTP error code %d", resp.StatusCode) + } + if resp.Header.Get("Content-Type") != "application/x-pem-file" { + return false, fmt.Errorf("Unexpected content type %s", resp.Header.Get("Content-Type")) + } + // write response + file, err := os.Create(filePath) + if err != nil { + return false, err + } + defer file.Close() + written, err := io.Copy(file, resp.Body) + switch { + case err != nil: + return false, err + case resp.ContentLength != written && resp.ContentLength != -1: + return false, fmt.Errorf("Short read (%d bytes) from server while writing certificate", written) + default: + return true, nil + } +} + +func putSuccess(client *http.Client, certURL string) { + // indicate success to the relay server + req, err := http.NewRequest("PUT", certURL+"/ok", nil) + if err != nil { + Log.WithError(err).Error("HTTP request error") + return + } + resp, err := client.Do(req) + if err != nil { + Log.WithError(err).Error("HTTP error") + return + } + resp.Body.Close() + if resp.StatusCode != 200 { + Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode) + } +} diff --git a/cmd/cloudflared/macos_service.go b/cmd/cloudflared/macos_service.go new file mode 100644 index 00000000..95be8454 --- /dev/null +++ b/cmd/cloudflared/macos_service.go @@ -0,0 +1,97 @@ +// +build darwin + +package main + +import ( + "fmt" + "os" + + "gopkg.in/urfave/cli.v2" +) + +const launchAgentIdentifier = "com.cloudflare.cloudflared" + +func runApp(app *cli.App) { + app.Commands = append(app.Commands, &cli.Command{ + Name: "service", + Usage: "Manages the Argo Tunnel launch agent", + Subcommands: []*cli.Command{ + { + Name: "install", + Usage: "Install Argo Tunnel as an user launch agent", + Action: installLaunchd, + }, + { + Name: "uninstall", + Usage: "Uninstall the Argo Tunnel launch agent", + Action: uninstallLaunchd, + }, + }, + }) + app.Run(os.Args) +} + +var launchdTemplate = ServiceTemplate{ + Path: fmt.Sprintf("~/Library/LaunchAgents/%s.plist", launchAgentIdentifier), + Content: fmt.Sprintf(` + + + + Label + %s + Program + {{ .Path }} + RunAtLoad + + StandardOutPath + /tmp/%s.out.log + StandardErrorPath + /tmp/%s.err.log + KeepAlive + + NetworkState + + + ThrottleInterval + 20 + +`, launchAgentIdentifier, launchAgentIdentifier, launchAgentIdentifier), +} + +func installLaunchd(c *cli.Context) error { + Log.Infof("Installing Argo Tunnel as an user launch agent") + etPath, err := os.Executable() + if err != nil { + Log.WithError(err).Infof("error determining executable path") + return fmt.Errorf("error determining executable path: %v", err) + } + templateArgs := ServiceTemplateArgs{Path: etPath} + err = launchdTemplate.Generate(&templateArgs) + if err != nil { + Log.WithError(err).Infof("error generating launchd template") + return err + } + plistPath, err := launchdTemplate.ResolvePath() + if err != nil { + Log.WithError(err).Infof("error resolving launchd template path") + return err + } + Log.Infof("Outputs are logged in %s and %s", fmt.Sprintf("/tmp/%s.out.log", launchAgentIdentifier), fmt.Sprintf("/tmp/%s.err.log", launchAgentIdentifier)) + return runCommand("launchctl", "load", plistPath) +} + +func uninstallLaunchd(c *cli.Context) error { + Log.Infof("Uninstalling Argo Tunnel as an user launch agent") + plistPath, err := launchdTemplate.ResolvePath() + if err != nil { + Log.WithError(err).Infof("error resolving launchd template path") + return err + } + err = runCommand("launchctl", "unload", plistPath) + if err != nil { + Log.WithError(err).Infof("error unloading") + return err + } + Log.Infof("Outputs are logged in %s and %s", fmt.Sprintf("/tmp/%s.out.log", launchAgentIdentifier), fmt.Sprintf("/tmp/%s.err.log", launchAgentIdentifier)) + return launchdTemplate.Remove() +} diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go new file mode 100644 index 00000000..06bf641f --- /dev/null +++ b/cmd/cloudflared/main.go @@ -0,0 +1,794 @@ +package main + +import ( + "crypto/tls" + "encoding/hex" + "fmt" + "io/ioutil" + "math/rand" + "net" + "net/http" + "os" + "os/signal" + "path/filepath" + "runtime" + "strings" + "sync" + "syscall" + "time" + + "github.com/cloudflare/cloudflared/metrics" + "github.com/cloudflare/cloudflared/origin" + "github.com/cloudflare/cloudflared/tlsconfig" + "github.com/cloudflare/cloudflared/tunneldns" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/validation" + + "github.com/facebookgo/grace/gracenet" + "github.com/getsentry/raven-go" + "github.com/mitchellh/go-homedir" + "github.com/rifflock/lfshook" + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh/terminal" + "gopkg.in/urfave/cli.v2" + "gopkg.in/urfave/cli.v2/altsrc" + + "github.com/coreos/go-systemd/daemon" + "github.com/pkg/errors" +) + +const ( + sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878" + credentialFile = "cert.pem" + quickStartUrl = "https://developers.cloudflare.com/argo-tunnel/quickstart/quickstart/" + noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/" + licenseUrl = "https://developers.cloudflare.com/argo-tunnel/licence/" +) + +var listeners = gracenet.Net{} +var Version = "DEV" +var BuildTime = "unknown" +var Log *logrus.Logger +var defaultConfigFiles = []string{"config.yml", "config.yaml"} + +// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible +var defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp"} + +// Shutdown channel used by the app. When closed, app must terminate. +// May be closed by the Windows service runner. +var shutdownC chan struct{} + +type BuildAndRuntimeInfo struct { + GoOS string `json:"go_os"` + GoVersion string `json:"go_version"` + GoArch string `json:"go_arch"` + WarpVersion string `json:"warp_version"` + WarpFlags map[string]interface{} `json:"warp_flags"` + WarpEnvs map[string]string `json:"warp_envs"` +} + +func main() { + metrics.RegisterBuildInfo(BuildTime, Version) + raven.SetDSN(sentryDSN) + raven.SetRelease(Version) + shutdownC = make(chan struct{}) + app := &cli.App{} + app.Name = "cloudflared" + app.Copyright = fmt.Sprintf(`(c) %d Cloudflare Inc. + Use is subject to the license agreement at %s`, time.Now().Year(), licenseUrl) + app.Usage = "Cloudflare reverse tunnelling proxy agent" + app.ArgsUsage = "origin-url" + app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime) + app.Description = `A reverse tunnel proxy agent that connects to Cloudflare's infrastructure. + Upon connecting, you are assigned a unique subdomain on cftunnel.com. + You need to specify a hostname on a zone you control. + A DNS record will be created to CNAME your hostname to the unique subdomain on cftunnel.com. + + Requests made to Cloudflare's servers for your hostname will be proxied + through the tunnel to your local webserver.` + app.Flags = []cli.Flag{ + &cli.StringFlag{ + Name: "config", + Usage: "Specifies a config file in YAML format.", + Value: findDefaultConfigPath(), + }, + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "autoupdate-freq", + Usage: "Autoupdate frequency. Default is 24h.", + Value: time.Hour * 24, + }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: "no-autoupdate", + Usage: "Disable periodic check for updates, restarting the server with the new version.", + Value: false, + }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: "is-autoupdated", + Usage: "Signal the new process that Argo Tunnel client has been autoupdated", + Value: false, + Hidden: true, + }), + altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ + Name: "edge", + Usage: "Address of the Cloudflare tunnel server.", + EnvVars: []string{"TUNNEL_EDGE"}, + Hidden: true, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "cacert", + Usage: "Certificate Authority authenticating the Cloudflare tunnel connection.", + EnvVars: []string{"TUNNEL_CACERT"}, + Hidden: true, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "origincert", + Usage: "Path to the certificate generated for your origin when you run cloudflared login.", + EnvVars: []string{"TUNNEL_ORIGIN_CERT"}, + Value: findDefaultOriginCertPath(), + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "url", + Value: "https://localhost:8080", + Usage: "Connect to the local webserver at `URL`.", + EnvVars: []string{"TUNNEL_URL"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "hostname", + Usage: "Set a hostname on a Cloudflare zone to route traffic through this tunnel.", + EnvVars: []string{"TUNNEL_HOSTNAME"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "origin-server-name", + Usage: "Hostname on the origin server certificate.", + EnvVars: []string{"TUNNEL_ORIGIN_SERVER_NAME"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "id", + Usage: "A unique identifier used to tie connections to this tunnel instance.", + EnvVars: []string{"TUNNEL_ID"}, + Hidden: true, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "lb-pool", + Usage: "The name of a (new/existing) load balancing pool to add this origin to.", + EnvVars: []string{"TUNNEL_LB_POOL"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "api-key", + Usage: "This parameter has been deprecated since version 2017.10.1.", + EnvVars: []string{"TUNNEL_API_KEY"}, + Hidden: true, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "api-email", + Usage: "This parameter has been deprecated since version 2017.10.1.", + EnvVars: []string{"TUNNEL_API_EMAIL"}, + Hidden: true, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "api-ca-key", + Usage: "This parameter has been deprecated since version 2017.10.1.", + EnvVars: []string{"TUNNEL_API_CA_KEY"}, + Hidden: true, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "metrics", + Value: "localhost:", + Usage: "Listen address for metrics reporting.", + EnvVars: []string{"TUNNEL_METRICS"}, + }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "metrics-update-freq", + Usage: "Frequency to update tunnel metrics", + Value: time.Second * 5, + EnvVars: []string{"TUNNEL_METRICS_UPDATE_FREQ"}, + }), + altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ + Name: "tag", + Usage: "Custom tags used to identify this tunnel, in format `KEY=VALUE`. Multiple tags may be specified", + EnvVars: []string{"TUNNEL_TAG"}, + }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "heartbeat-interval", + Usage: "Minimum idle time before sending a heartbeat.", + Value: time.Second * 5, + Hidden: true, + }), + altsrc.NewUint64Flag(&cli.Uint64Flag{ + Name: "heartbeat-count", + Usage: "Minimum number of unacked heartbeats to send before closing the connection.", + Value: 5, + Hidden: true, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "loglevel", + Value: "info", + Usage: "Application logging level {panic, fatal, error, warn, info, debug}", + EnvVars: []string{"TUNNEL_LOGLEVEL"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "proto-loglevel", + Value: "warn", + Usage: "Protocol logging level {panic, fatal, error, warn, info, debug}", + EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL"}, + }), + altsrc.NewUintFlag(&cli.UintFlag{ + Name: "retries", + Value: 5, + Usage: "Maximum number of retries for connection/protocol errors.", + EnvVars: []string{"TUNNEL_RETRIES"}, + }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: "hello-world", + Value: false, + Usage: "Run Hello World Server", + EnvVars: []string{"TUNNEL_HELLO_WORLD"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "pidfile", + Usage: "Write the application's PID to this file after first successful connection.", + EnvVars: []string{"TUNNEL_PIDFILE"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "logfile", + Usage: "Save application log to this file for reporting issues.", + EnvVars: []string{"TUNNEL_LOGFILE"}, + }), + altsrc.NewIntFlag(&cli.IntFlag{ + Name: "ha-connections", + Value: 4, + Hidden: true, + }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "proxy-connect-timeout", + Usage: "HTTP proxy timeout for establishing a new connection", + Value: time.Second * 30, + }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "proxy-tls-timeout", + Usage: "HTTP proxy timeout for completing a TLS handshake", + Value: time.Second * 10, + }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "proxy-tcp-keepalive", + Usage: "HTTP proxy TCP keepalive duration", + Value: time.Second * 30, + }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: "proxy-no-happy-eyeballs", + Usage: "HTTP proxy should disable \"happy eyeballs\" for IPv4/v6 fallback", + }), + altsrc.NewIntFlag(&cli.IntFlag{ + Name: "proxy-keepalive-connections", + Usage: "HTTP proxy maximum keepalive connection pool size", + Value: 100, + }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "proxy-keepalive-timeout", + Usage: "HTTP proxy timeout for closing an idle connection", + Value: time.Second * 90, + }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: "proxy-dns", + Usage: "Run a DNS over HTTPS proxy server.", + EnvVars: []string{"TUNNEL_DNS"}, + }), + altsrc.NewUintFlag(&cli.UintFlag{ + Name: "proxy-dns-port", + Value: 53, + Usage: "Listen on given port for the DNS over HTTPS proxy server.", + EnvVars: []string{"TUNNEL_DNS_PORT"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "proxy-dns-address", + Usage: "Listen address for the DNS over HTTPS proxy server.", + Value: "localhost", + EnvVars: []string{"TUNNEL_DNS_ADDRESS"}, + }), + altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ + Name: "proxy-dns-upstream", + Usage: "Upstream endpoint URL, you can specify multiple endpoints for redundancy.", + Value: cli.NewStringSlice("https://dns.cloudflare.com/.well-known/dns-query"), + EnvVars: []string{"TUNNEL_DNS_UPSTREAM"}, + }), + } + app.Action = func(c *cli.Context) error { + raven.CapturePanic(func() { startServer(c) }, nil) + return nil + } + app.Before = func(context *cli.Context) error { + Log = logrus.New() + inputSource, err := findInputSourceContext(context) + if err != nil { + Log.WithError(err).Infof("Cannot load configuration from %s", context.String("config")) + return err + } else if inputSource != nil { + err := altsrc.ApplyInputSourceValues(context, inputSource, app.Flags) + if err != nil { + Log.WithError(err).Infof("Cannot apply configuration from %s", context.String("config")) + return err + } + Log.Infof("Applied configuration from %s", context.String("config")) + } + return nil + } + app.Commands = []*cli.Command{ + { + Name: "update", + Action: update, + Usage: "Update the agent if a new version exists", + ArgsUsage: " ", + Description: `Looks for a new version on the offical download server. + If a new version exists, updates the agent binary and quits. + Otherwise, does nothing. + + To determine if an update happened in a script, check for error code 64.`, + }, + { + Name: "login", + Action: login, + Usage: "Generate a configuration file with your login details", + ArgsUsage: " ", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "url", + Hidden: true, + }, + }, + }, + { + Name: "hello", + Action: hello, + Usage: "Run a simple \"Hello World\" server for testing Argo Tunnel.", + Flags: []cli.Flag{ + &cli.IntFlag{ + Name: "port", + Usage: "Listen on the selected port.", + Value: 8080, + }, + }, + ArgsUsage: " ", // can't be the empty string or we get the default output + }, + { + Name: "proxy-dns", + Action: tunneldns.Run, + Usage: "Run a DNS over HTTPS proxy server.", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "metrics", + Value: "localhost:", + Usage: "Listen address for metrics reporting.", + EnvVars: []string{"TUNNEL_METRICS"}, + }, + &cli.StringFlag{ + Name: "address", + Usage: "Listen address for the DNS over HTTPS proxy server.", + Value: "localhost", + EnvVars: []string{"TUNNEL_DNS_ADDRESS"}, + }, + &cli.IntFlag{ + Name: "port", + Usage: "Listen on given port for the DNS over HTTPS proxy server.", + Value: 53, + EnvVars: []string{"TUNNEL_DNS_PORT"}, + }, + &cli.StringSliceFlag{ + Name: "upstream", + Usage: "Upstream endpoint URL, you can specify multiple endpoints for redundancy.", + Value: cli.NewStringSlice("https://dns.cloudflare.com/.well-known/dns-query"), + EnvVars: []string{"TUNNEL_DNS_UPSTREAM"}, + }, + }, + ArgsUsage: " ", // can't be the empty string or we get the default output + }, + } + runApp(app) +} + +func startServer(c *cli.Context) { + var wg sync.WaitGroup + errC := make(chan error) + wg.Add(2) + + // If the user choose to supply all options through env variables, + // c.NumFlags() == 0 && c.NArg() == 0. For cloudflared to work, the user needs to at + // least provide a hostname. + if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" { + Log.Infof("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl) + cli.ShowAppHelp(c) + return + } + logLevel, err := logrus.ParseLevel(c.String("loglevel")) + if err != nil { + Log.WithError(err).Fatal("Unknown logging level specified") + } + Log.SetLevel(logLevel) + + protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel")) + if err != nil { + Log.WithError(err).Fatal("Unknown protocol logging level specified") + } + protoLogger := logrus.New() + protoLogger.Level = protoLogLevel + + if c.String("logfile") != "" { + if err := initLogFile(c, protoLogger); err != nil { + Log.Error(err) + } + } + + if isAutoupdateEnabled(c) { + if initUpdate() { + return + } + Log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq")) + go autoupdate(c.Duration("autoupdate-freq"), shutdownC) + } + + hostname, err := validation.ValidateHostname(c.String("hostname")) + if err != nil { + Log.WithError(err).Fatal("Invalid hostname") + } + clientID := c.String("id") + if !c.IsSet("id") { + clientID = generateRandomClientID() + } + + tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) + if err != nil { + Log.WithError(err).Fatal("Tag parse failure") + } + + tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) + if c.IsSet("hello-world") { + wg.Add(1) + listener, err := createListener("127.0.0.1:") + if err != nil { + listener.Close() + Log.WithError(err).Fatal("Cannot start Hello World Server") + } + go func() { + startHelloWorldServer(listener, shutdownC) + wg.Done() + listener.Close() + }() + c.Set("url", "https://"+listener.Addr().String()) + } + + if c.IsSet("proxy-dns") { + wg.Add(1) + listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(c.Uint("proxy-dns-port")), c.StringSlice("proxy-dns-upstream")) + if err != nil { + listener.Stop() + Log.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server") + } + go func() { + listener.Start() + <-shutdownC + listener.Stop() + wg.Done() + }() + } + + url, err := validateUrl(c) + if err != nil { + Log.WithError(err).Fatal("Error validating url") + } + Log.Infof("Proxying tunnel requests to %s", url) + + // Fail if the user provided an old authentication method + if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") { + Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflared login") + } + + // Check that the user has acquired a certificate using the log in command + originCertPath, err := homedir.Expand(c.String("origincert")) + if err != nil { + Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert")) + } + ok, err := fileExists(originCertPath) + if err != nil { + Log.Fatalf("Cannot check if origin cert exists at path %s", c.String("origincert")) + } + if !ok { + Log.Fatalf(`Cannot find a valid certificate for your origin at the path: + + %s + +If the path above is wrong, specify the path with the -origincert option. +If you don't have a certificate signed by Cloudflare, run the command: + + %s login +`, originCertPath, os.Args[0]) + } + // Easier to send the certificate as []byte via RPC than decoding it at this point + originCert, err := ioutil.ReadFile(originCertPath) + if err != nil { + Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath) + } + + tunnelMetrics := origin.NewTunnelMetrics() + httpTransport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: c.Duration("proxy-connect-timeout"), + KeepAlive: c.Duration("proxy-tcp-keepalive"), + DualStack: !c.Bool("proxy-no-happy-eyeballs"), + }).DialContext, + MaxIdleConns: c.Int("proxy-keepalive-connections"), + IdleConnTimeout: c.Duration("proxy-keepalive-timeout"), + TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"), + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()}, + } + + if !c.IsSet("hello-world") && c.IsSet("origin-server-name") { + httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name") + } + + tunnelConfig := &origin.TunnelConfig{ + EdgeAddrs: c.StringSlice("edge"), + OriginUrl: url, + Hostname: hostname, + OriginCert: originCert, + TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")), + ClientTlsConfig: httpTransport.TLSClientConfig, + Retries: c.Uint("retries"), + HeartbeatInterval: c.Duration("heartbeat-interval"), + MaxHeartbeats: c.Uint64("heartbeat-count"), + ClientID: clientID, + ReportedVersion: Version, + LBPool: c.String("lb-pool"), + Tags: tags, + HAConnections: c.Int("ha-connections"), + HTTPTransport: httpTransport, + Metrics: tunnelMetrics, + MetricsUpdateFreq: c.Duration("metrics-update-freq"), + ProtocolLogger: protoLogger, + Logger: Log, + IsAutoupdated: c.Bool("is-autoupdated"), + } + connectedSignal := make(chan struct{}) + + go writePidFile(connectedSignal, c.String("pidfile")) + go func() { + errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal) + wg.Done() + }() + + metricsListener, err := listeners.Listen("tcp", c.String("metrics")) + if err != nil { + Log.WithError(err).Fatal("Error opening metrics server listener") + } + go func() { + errC <- metrics.ServeMetrics(metricsListener, shutdownC) + wg.Done() + }() + + var errCode int + err = WaitForSignal(errC, shutdownC) + if err != nil { + Log.WithError(err).Error("Quitting due to error") + raven.CaptureErrorAndWait(err, nil) + errCode = 1 + } else { + Log.Info("Quitting...") + } + // Wait for clean exit, discarding all errors + go func() { + for range errC { + } + }() + wg.Wait() + os.Exit(errCode) +} + +func WaitForSignal(errC chan error, shutdownC chan struct{}) error { + signals := make(chan os.Signal, 10) + signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) + defer signal.Stop(signals) + select { + case err := <-errC: + close(shutdownC) + return err + case <-signals: + close(shutdownC) + case <-shutdownC: + } + return nil +} + +func update(_ *cli.Context) error { + if updateApplied() { + os.Exit(64) + } + return nil +} + +func initUpdate() bool { + if updateApplied() { + os.Args = append(os.Args, "--is-autoupdated=true") + if _, err := listeners.StartProcess(); err != nil { + Log.WithError(err).Error("Unable to restart server automatically") + return false + } + return true + } + return false +} + +func autoupdate(freq time.Duration, shutdownC chan struct{}) { + for { + if updateApplied() { + os.Args = append(os.Args, "--is-autoupdated=true") + if _, err := listeners.StartProcess(); err != nil { + Log.WithError(err).Error("Unable to restart server automatically") + } + close(shutdownC) + return + } + time.Sleep(freq) + } +} + +func updateApplied() bool { + releaseInfo := checkForUpdates() + if releaseInfo.Updated { + Log.Infof("Updated to version %s", releaseInfo.Version) + return true + } + if releaseInfo.Error != nil { + Log.WithError(releaseInfo.Error).Error("Update check failed") + } + return false +} + +func fileExists(path string) (bool, error) { + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + // ignore missing files + return false, nil + } + return false, err + } + f.Close() + return true, nil +} + +// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs +// (differs by OS for legacy reasons) contains a cert.pem file, return empty string +func findDefaultOriginCertPath() string { + for _, defaultConfigDir := range defaultConfigDirs { + originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, credentialFile)) + if ok, _ := fileExists(originCertPath); ok { + return originCertPath + } + } + return "" +} + +// returns the firt path that contains a config file. If none of the combination of +// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles +// contains a config file, return empty string +func findDefaultConfigPath() string { + for _, configDir := range defaultConfigDirs { + for _, configFile := range defaultConfigFiles { + dirPath, err := homedir.Expand(configDir) + if err != nil { + return "" + } + path := filepath.Join(dirPath, configFile) + if ok, _ := fileExists(path); ok { + return path + } + } + } + return "" +} + +func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) { + if context.String("config") != "" { + return altsrc.NewYamlSourceFromFile(context.String("config")) + } + return nil, nil +} + +func generateRandomClientID() string { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + id := make([]byte, 32) + r.Read(id) + return hex.EncodeToString(id) +} + +func writePidFile(waitForSignal chan struct{}, pidFile string) { + <-waitForSignal + daemon.SdNotify(false, "READY=1") + if pidFile == "" { + return + } + file, err := os.Create(pidFile) + if err != nil { + Log.WithError(err).Errorf("Unable to write pid to %s", pidFile) + } + defer file.Close() + fmt.Fprintf(file, "%d", os.Getpid()) +} + +// validate url. It can be either from --url or argument +func validateUrl(c *cli.Context) (string, error) { + var url = c.String("url") + if c.NArg() > 0 { + if c.IsSet("url") { + return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") + } + url = c.Args().Get(0) + } + validUrl, err := validation.ValidateUrl(url) + return validUrl, err +} + +func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error { + filePath, err := homedir.Expand(c.String("logfile")) + if err != nil { + return errors.Wrap(err, "Cannot resolve logfile path") + } + + fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC + // do not truncate log file if the client has been autoupdated + if c.Bool("is-autoupdated") { + fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE + } + f, err := os.OpenFile(filePath, fileMode, 0664) + if err != nil { + errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath)) + } + defer f.Close() + pathMap := lfshook.PathMap{ + logrus.InfoLevel: filePath, + logrus.ErrorLevel: filePath, + logrus.FatalLevel: filePath, + logrus.PanicLevel: filePath, + } + + Log.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{})) + protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{})) + + flags := make(map[string]interface{}) + envs := make(map[string]string) + + for _, flag := range c.LocalFlagNames() { + flags[flag] = c.Generic(flag) + } + + // Find env variables for Argo Tunnel + for _, env := range os.Environ() { + // All Argo Tunnel env variables start with TUNNEL_ + if strings.Contains(env, "TUNNEL_") { + vars := strings.Split(env, "=") + if len(vars) == 2 { + envs[vars[0]] = vars[1] + } + } + } + + Log.Infof("Argo Tunnel build and runtime configuration: %+v", BuildAndRuntimeInfo{ + GoOS: runtime.GOOS, + GoVersion: runtime.Version(), + GoArch: runtime.GOARCH, + WarpVersion: Version, + WarpFlags: flags, + WarpEnvs: envs, + }) + + return nil +} + +func isAutoupdateEnabled(c *cli.Context) bool { + if terminal.IsTerminal(int(os.Stdout.Fd())) { + Log.Info(noAutoupdateMessage) + return false + } + + return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 +} diff --git a/cmd/cloudflared/service_template.go b/cmd/cloudflared/service_template.go new file mode 100644 index 00000000..063e26cb --- /dev/null +++ b/cmd/cloudflared/service_template.go @@ -0,0 +1,192 @@ +package main + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "path/filepath" + "text/template" + + "github.com/mitchellh/go-homedir" +) + +type ServiceTemplate struct { + Path string + Content string + FileMode os.FileMode +} + +type ServiceTemplateArgs struct { + Path string +} + +func (st *ServiceTemplate) ResolvePath() (string, error) { + resolvedPath, err := homedir.Expand(st.Path) + if err != nil { + return "", fmt.Errorf("error resolving path %s: %v", st.Path, err) + } + return resolvedPath, nil +} + +func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error { + tmpl, err := template.New(st.Path).Parse(st.Content) + if err != nil { + return fmt.Errorf("error generating %s template: %v", st.Path, err) + } + resolvedPath, err := st.ResolvePath() + if err != nil { + return err + } + var buffer bytes.Buffer + err = tmpl.Execute(&buffer, args) + if err != nil { + return fmt.Errorf("error generating %s: %v", st.Path, err) + } + fileMode := os.FileMode(0644) + if st.FileMode != 0 { + fileMode = st.FileMode + } + err = ioutil.WriteFile(resolvedPath, buffer.Bytes(), fileMode) + if err != nil { + return fmt.Errorf("error writing %s: %v", resolvedPath, err) + } + return nil +} + +func (st *ServiceTemplate) Remove() error { + resolvedPath, err := st.ResolvePath() + if err != nil { + return err + } + err = os.Remove(resolvedPath) + if err != nil { + return fmt.Errorf("error deleting %s: %v", resolvedPath, err) + } + return nil +} + +func runCommand(command string, args ...string) error { + cmd := exec.Command(command, args...) + stderr, err := cmd.StderrPipe() + if err != nil { + Log.WithError(err).Infof("error getting stderr pipe") + return fmt.Errorf("error getting stderr pipe: %v", err) + } + err = cmd.Start() + if err != nil { + Log.WithError(err).Infof("error starting %s", command) + return fmt.Errorf("error starting %s: %v", command, err) + } + commandErr, _ := ioutil.ReadAll(stderr) + if len(commandErr) > 0 { + Log.Errorf("%s: %s", command, commandErr) + } + err = cmd.Wait() + if err != nil { + Log.WithError(err).Infof("%s returned error", command) + return fmt.Errorf("%s returned with error: %v", command, err) + } + return nil +} + +func ensureConfigDirExists(configDir string) error { + ok, err := fileExists(configDir) + if !ok && err == nil { + err = os.Mkdir(configDir, 0700) + } + return err +} + +// openFile opens the file at path. If create is set and the file exists, returns nil, true, nil +func openFile(path string, create bool) (file *os.File, exists bool, err error) { + expandedPath, err := homedir.Expand(path) + if err != nil { + return nil, false, err + } + if create { + fileInfo, err := os.Stat(expandedPath) + if err == nil && fileInfo.Size() > 0 { + return nil, true, nil + } + file, err = os.OpenFile(expandedPath, os.O_RDWR|os.O_CREATE, 0600) + } else { + file, err = os.Open(expandedPath) + } + return file, false, err +} + +func copyCertificate(srcConfigDir, destConfigDir string) error { + destCredentialPath := filepath.Join(destConfigDir, credentialFile) + destFile, exists, err := openFile(destCredentialPath, true) + if err != nil { + return err + } else if exists { + // credentials already exist, do nothing + return nil + } + defer destFile.Close() + + srcCredentialPath := filepath.Join(srcConfigDir, credentialFile) + srcFile, _, err := openFile(srcCredentialPath, false) + if err != nil { + return err + } + defer srcFile.Close() + + // Copy certificate + _, err = io.Copy(destFile, srcFile) + if err != nil { + return fmt.Errorf("unable to copy %s to %s: %v", srcCredentialPath, destCredentialPath, err) + } + + return nil +} + +func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile string) error { + if err := ensureConfigDirExists(serviceConfigDir); err != nil { + return err + } + + if err := copyCertificate(defaultConfigDir, serviceConfigDir); err != nil { + return err + } + + // Copy or create config + destConfigPath := filepath.Join(serviceConfigDir, defaultConfigFile) + destFile, exists, err := openFile(destConfigPath, true) + if err != nil { + Log.WithError(err).Infof("cannot open %s", destConfigPath) + return err + } else if exists { + // config already exists, do nothing + return nil + } + defer destFile.Close() + + srcConfigPath := filepath.Join(defaultConfigDir, defaultConfigFile) + srcFile, _, err := openFile(srcConfigPath, false) + if err != nil { + fmt.Println("Your service needs a config file that at least specifies the hostname option.") + fmt.Println("Type in a hostname now, or leave it blank and create the config file later.") + fmt.Print("Hostname: ") + reader := bufio.NewReader(os.Stdin) + input, _ := reader.ReadString('\n') + if input == "" { + return err + } + fmt.Fprintf(destFile, "hostname: %s\n", input) + } else { + defer srcFile.Close() + _, err = io.Copy(destFile, srcFile) + if err != nil { + return fmt.Errorf("unable to copy %s to %s: %v", srcConfigPath, destConfigPath, err) + } + Log.Infof("Copied %s to %s", srcConfigPath, destConfigPath) + } + + return nil +} diff --git a/cmd/cloudflared/tag.go b/cmd/cloudflared/tag.go new file mode 100644 index 00000000..fde590ba --- /dev/null +++ b/cmd/cloudflared/tag.go @@ -0,0 +1,32 @@ +package main + +import ( + "fmt" + "regexp" + + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" +) + +// Restrict key names to characters allowed in an HTTP header name. +// Restrict key values to printable characters (what is recognised as data in an HTTP header value). +var tagRegexp = regexp.MustCompile("^([a-zA-Z0-9!#$%&'*+\\-.^_`|~]+)=([[:print:]]+)$") + +func NewTagFromCLI(compoundTag string) (tunnelpogs.Tag, bool) { + matches := tagRegexp.FindStringSubmatch(compoundTag) + if len(matches) == 0 { + return tunnelpogs.Tag{}, false + } + return tunnelpogs.Tag{Name: matches[1], Value: matches[2]}, true +} + +func NewTagSliceFromCLI(tags []string) ([]tunnelpogs.Tag, error) { + var tagSlice []tunnelpogs.Tag + for _, compoundTag := range tags { + if tag, ok := NewTagFromCLI(compoundTag); ok { + tagSlice = append(tagSlice, tag) + } else { + return nil, fmt.Errorf("Cannot parse tag value %s", compoundTag) + } + } + return tagSlice, nil +} diff --git a/cmd/cloudflared/tag_test.go b/cmd/cloudflared/tag_test.go new file mode 100644 index 00000000..25bf3243 --- /dev/null +++ b/cmd/cloudflared/tag_test.go @@ -0,0 +1,46 @@ +package main + +import ( + "testing" + + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + + "github.com/stretchr/testify/assert" +) + +func TestSingleTag(t *testing.T) { + testCases := []struct { + Input string + Output tunnelpogs.Tag + Fail bool + }{ + {Input: "x=y", Output: tunnelpogs.Tag{Name: "x", Value: "y"}}, + {Input: "More-Complex=Tag Values", Output: tunnelpogs.Tag{Name: "More-Complex", Value: "Tag Values"}}, + {Input: "First=Equals=Wins", Output: tunnelpogs.Tag{Name: "First", Value: "Equals=Wins"}}, + {Input: "x=", Fail: true}, + {Input: "=y", Fail: true}, + {Input: "=", Fail: true}, + {Input: "No spaces allowed=in key names", Fail: true}, + {Input: "omg\nwtf=bbq", Fail: true}, + } + for i, testCase := range testCases { + tag, ok := NewTagFromCLI(testCase.Input) + assert.Equalf(t, !testCase.Fail, ok, "mismatched success for test case %d", i) + assert.Equalf(t, testCase.Output, tag, "mismatched output for test case %d", i) + } +} + +func TestTagSlice(t *testing.T) { + tagSlice, err := NewTagSliceFromCLI([]string{"a=b", "c=d", "e=f"}) + assert.NoError(t, err) + assert.Len(t, tagSlice, 3) + assert.Equal(t, "a", tagSlice[0].Name) + assert.Equal(t, "b", tagSlice[0].Value) + assert.Equal(t, "c", tagSlice[1].Name) + assert.Equal(t, "d", tagSlice[1].Value) + assert.Equal(t, "e", tagSlice[2].Name) + assert.Equal(t, "f", tagSlice[2].Value) + + tagSlice, err = NewTagSliceFromCLI([]string{"a=b", "=", "e=f"}) + assert.Error(t, err) +} diff --git a/cmd/cloudflared/update.go b/cmd/cloudflared/update.go new file mode 100644 index 00000000..dc2f0cd0 --- /dev/null +++ b/cmd/cloudflared/update.go @@ -0,0 +1,41 @@ +package main + +import "github.com/equinox-io/equinox" + +const appID = "app_cwbQae3Tpea" + +var publicKey = []byte(` +-----BEGIN ECDSA PUBLIC KEY----- +MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE4OWZocTVZ8Do/L6ScLdkV+9A0IYMHoOf +dsCmJ/QZ6aw0w9qkkwEpne1Lmo6+0pGexZzFZOH6w5amShn+RXt7qkSid9iWlzGq +EKx0BZogHSor9Wy5VztdFaAaVbsJiCbO +-----END ECDSA PUBLIC KEY----- +`) + +type ReleaseInfo struct { + Updated bool + Version string + Error error +} + +func checkForUpdates() ReleaseInfo { + var opts equinox.Options + if err := opts.SetPublicKeyPEM(publicKey); err != nil { + return ReleaseInfo{Error: err} + } + + resp, err := equinox.Check(appID, opts) + switch { + case err == equinox.NotAvailableErr: + return ReleaseInfo{} + case err != nil: + return ReleaseInfo{Error: err} + } + + err = resp.Apply() + if err != nil { + return ReleaseInfo{Error: err} + } + + return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion} +} diff --git a/cmd/cloudflared/windows_service.go b/cmd/cloudflared/windows_service.go new file mode 100644 index 00000000..b003bb50 --- /dev/null +++ b/cmd/cloudflared/windows_service.go @@ -0,0 +1,166 @@ +// +build windows + +package main + +// Copypasta from the example files: +// https://github.com/golang/sys/blob/master/windows/svc/example + +import ( + "fmt" + "os" + + cli "gopkg.in/urfave/cli.v2" + + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/eventlog" + "golang.org/x/sys/windows/svc/mgr" +) + +const ( + windowsServiceName = "Cloudflared" + windowsServiceDescription = "Argo Tunnel agent" +) + +func runApp(app *cli.App) { + app.Commands = append(app.Commands, &cli.Command{ + Name: "service", + Usage: "Manages the Argo Tunnel Windows service", + Subcommands: []*cli.Command{ + &cli.Command{ + Name: "install", + Usage: "Install Argo Tunnel as a Windows service", + Action: installWindowsService, + }, + &cli.Command{ + Name: "uninstall", + Usage: "Uninstall the Argo Tunnel service", + Action: uninstallWindowsService, + }, + }, + }) + + isIntSess, err := svc.IsAnInteractiveSession() + if err != nil { + Log.Fatalf("failed to determine if we are running in an interactive session: %v", err) + } + + if isIntSess { + app.Run(os.Args) + return + } + + elog, err := eventlog.Open(windowsServiceName) + if err != nil { + Log.WithError(err).Infof("Cannot open event log for %s", windowsServiceName) + return + } + defer elog.Close() + + elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName)) + // Run executes service name by calling windowsService which is a Handler + // interface that implements Execute method + err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog}) + if err != nil { + elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err)) + return + } + elog.Info(1, fmt.Sprintf("%s service stopped", windowsServiceName)) +} + +type windowsService struct { + app *cli.App + elog *eventlog.Log +} + +// called by the package code at the start of the service +func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) { + const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown + changes <- svc.Status{State: svc.StartPending} + go s.app.Run(args) + + changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} +loop: + for { + select { + case c := <-r: + switch c.Cmd { + case svc.Interrogate: + s.elog.Info(1, fmt.Sprintf("control request 1 #%d", c)) + changes <- c.CurrentStatus + case svc.Stop: + s.elog.Info(1, "received stop control request") + break loop + case svc.Shutdown: + s.elog.Info(1, "received shutdown control request") + break loop + default: + s.elog.Error(1, fmt.Sprintf("unexpected control request #%d", c)) + } + } + } + close(shutdownC) + changes <- svc.Status{State: svc.StopPending} + return +} + +func installWindowsService(c *cli.Context) error { + Log.Infof("Installing Argo Tunnel Windows service") + exepath, err := os.Executable() + if err != nil { + Log.Infof("Cannot find path name that start the process") + return err + } + m, err := mgr.Connect() + if err != nil { + Log.WithError(err).Infof("Cannot establish a connection to the service control manager") + return err + } + defer m.Disconnect() + s, err := m.OpenService(windowsServiceName) + if err == nil { + s.Close() + Log.Errorf("service %s already exists", windowsServiceName) + return fmt.Errorf("service %s already exists", windowsServiceName) + } + config := mgr.Config{StartType: mgr.StartAutomatic, DisplayName: windowsServiceDescription} + s, err = m.CreateService(windowsServiceName, exepath, config) + if err != nil { + Log.Infof("Cannot install service %s", windowsServiceName) + return err + } + defer s.Close() + err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info) + if err != nil { + s.Delete() + Log.WithError(err).Infof("Cannot install event logger") + return fmt.Errorf("SetupEventLogSource() failed: %s", err) + } + return nil +} + +func uninstallWindowsService(c *cli.Context) error { + Log.Infof("Uninstalling Argo Tunnel Windows Service") + m, err := mgr.Connect() + if err != nil { + Log.Infof("Cannot establish a connection to the service control manager") + return err + } + defer m.Disconnect() + s, err := m.OpenService(windowsServiceName) + if err != nil { + Log.Infof("service %s is not installed", windowsServiceName) + return fmt.Errorf("service %s is not installed", windowsServiceName) + } + defer s.Close() + err = s.Delete() + if err != nil { + Log.Errorf("Cannot delete service %s", windowsServiceName) + return err + } + err = eventlog.Remove(windowsServiceName) + if err != nil { + Log.Infof("Cannot remove event logger") + return fmt.Errorf("RemoveEventLogSource() failed: %s", err) + } + return nil +} diff --git a/h2mux/activestreammap.go b/h2mux/activestreammap.go index 9059e972..80b54d38 100644 --- a/h2mux/activestreammap.go +++ b/h2mux/activestreammap.go @@ -24,12 +24,6 @@ type activeStreamMap struct { ignoreNewStreams bool } -type FlowControlMetrics struct { - AverageReceiveWindowSize, AverageSendWindowSize float64 - MinReceiveWindowSize, MaxReceiveWindowSize uint32 - MinSendWindowSize, MaxSendWindowSize uint32 -} - func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap { m := &activeStreamMap{ streams: make(map[uint32]*MuxedStream), @@ -169,45 +163,3 @@ func (m *activeStreamMap) Abort() { } m.ignoreNewStreams = true } - -func (m *activeStreamMap) Metrics() *FlowControlMetrics { - m.Lock() - defer m.Unlock() - var averageReceiveWindowSize, averageSendWindowSize float64 - var minReceiveWindowSize, maxReceiveWindowSize, minSendWindowSize, maxSendWindowSize uint32 - i := 0 - // The first variable in the range expression for map is the key, not index. - for _, stream := range m.streams { - // iterative mean: a(t+1) = a(t) + (a(t)-x)/(t+1) - windows := stream.FlowControlWindow() - averageReceiveWindowSize += (float64(windows.receiveWindow) - averageReceiveWindowSize) / float64(i+1) - averageSendWindowSize += (float64(windows.sendWindow) - averageSendWindowSize) / float64(i+1) - if i == 0 { - maxReceiveWindowSize = windows.receiveWindow - minReceiveWindowSize = windows.receiveWindow - maxSendWindowSize = windows.sendWindow - minSendWindowSize = windows.sendWindow - } else { - if windows.receiveWindow > maxReceiveWindowSize { - maxReceiveWindowSize = windows.receiveWindow - } else if windows.receiveWindow < minReceiveWindowSize { - minReceiveWindowSize = windows.receiveWindow - } - - if windows.sendWindow > maxSendWindowSize { - maxSendWindowSize = windows.sendWindow - } else if windows.sendWindow < minSendWindowSize { - minSendWindowSize = windows.sendWindow - } - } - i++ - } - return &FlowControlMetrics{ - MinReceiveWindowSize: minReceiveWindowSize, - MaxReceiveWindowSize: maxReceiveWindowSize, - AverageReceiveWindowSize: averageReceiveWindowSize, - MinSendWindowSize: minSendWindowSize, - MaxSendWindowSize: maxSendWindowSize, - AverageSendWindowSize: averageSendWindowSize, - } -} diff --git a/h2mux/bytes_counter.go b/h2mux/bytes_counter.go new file mode 100644 index 00000000..0cd290fd --- /dev/null +++ b/h2mux/bytes_counter.go @@ -0,0 +1,18 @@ +package h2mux + +import ( + "sync/atomic" +) + +type AtomicCounter struct { + count uint64 +} + +func (c *AtomicCounter) IncrementBy(number uint64) { + atomic.AddUint64(&c.count, number) +} + +// Get returns the current value of counter and reset it to 0 +func (c *AtomicCounter) Count() uint64 { + return atomic.SwapUint64(&c.count, 0) +} diff --git a/h2mux/bytes_counter_test.go b/h2mux/bytes_counter_test.go new file mode 100644 index 00000000..da579aaf --- /dev/null +++ b/h2mux/bytes_counter_test.go @@ -0,0 +1,23 @@ +package h2mux + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCounter(t *testing.T) { + var wg sync.WaitGroup + wg.Add(dataPoints) + c := AtomicCounter{} + for i := 0; i < dataPoints; i++ { + go func() { + defer wg.Done() + c.IncrementBy(uint64(1)) + }() + } + wg.Wait() + assert.Equal(t, uint64(dataPoints), c.Count()) + assert.Equal(t, uint64(0), c.Count()) +} diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index 928b162b..e476431b 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -59,6 +59,8 @@ type Muxer struct { muxReader *MuxReader // muxWriter is the write process. muxWriter *MuxWriter + // muxMetricsUpdater is the process to update metrics + muxMetricsUpdater *muxMetricsUpdater // newStreamChan is used to create new streams on the writer thread. // The writer will assign the next available stream ID. newStreamChan chan MuxedStreamRequest @@ -133,6 +135,11 @@ func Handshake( // set up reader/writer pair ready for serve streamErrors := NewStreamErrorMap() goAwayChan := make(chan http2.ErrCode, 1) + updateRTTChan := make(chan *roundTripMeasurement, 1) + updateReceiveWindowChan := make(chan uint32, 1) + updateSendWindowChan := make(chan uint32, 1) + updateInBoundBytesChan := make(chan uint64) + updateOutBoundBytesChan := make(chan uint64) pingTimestamp := NewPingTimestamp() connActive := NewSignal() idleDuration := config.HeartbeatInterval @@ -149,34 +156,48 @@ func Handshake( m.explicitShutdown = NewBooleanFuse() m.muxReader = &MuxReader{ - f: m.f, - handler: m.config.Handler, - streams: m.streams, - readyList: m.readyList, - streamErrors: streamErrors, - goAwayChan: goAwayChan, - abortChan: m.abortChan, - pingTimestamp: pingTimestamp, - connActive: connActive, - initialStreamWindow: defaultWindowSize, - streamWindowMax: maxWindowSize, - r: m.r, + f: m.f, + handler: m.config.Handler, + streams: m.streams, + readyList: m.readyList, + streamErrors: streamErrors, + goAwayChan: goAwayChan, + abortChan: m.abortChan, + pingTimestamp: pingTimestamp, + connActive: connActive, + initialStreamWindow: defaultWindowSize, + streamWindowMax: maxWindowSize, + r: m.r, + updateRTTChan: updateRTTChan, + updateReceiveWindowChan: updateReceiveWindowChan, + updateSendWindowChan: updateSendWindowChan, + updateInBoundBytesChan: updateInBoundBytesChan, } m.muxWriter = &MuxWriter{ - f: m.f, - streams: m.streams, - streamErrors: streamErrors, - readyStreamChan: m.readyList.ReadyChannel(), - newStreamChan: m.newStreamChan, - goAwayChan: goAwayChan, - abortChan: m.abortChan, - pingTimestamp: pingTimestamp, - idleTimer: NewIdleTimer(idleDuration, maxRetries), - connActiveChan: connActive.WaitChannel(), - maxFrameSize: defaultFrameSize, + f: m.f, + streams: m.streams, + streamErrors: streamErrors, + readyStreamChan: m.readyList.ReadyChannel(), + newStreamChan: m.newStreamChan, + goAwayChan: goAwayChan, + abortChan: m.abortChan, + pingTimestamp: pingTimestamp, + idleTimer: NewIdleTimer(idleDuration, maxRetries), + connActiveChan: connActive.WaitChannel(), + maxFrameSize: defaultFrameSize, + updateReceiveWindowChan: updateReceiveWindowChan, + updateSendWindowChan: updateSendWindowChan, + updateOutBoundBytesChan: updateOutBoundBytesChan, } m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer) - + m.muxMetricsUpdater = newMuxMetricsUpdater( + updateRTTChan, + updateReceiveWindowChan, + updateSendWindowChan, + updateInBoundBytesChan, + updateOutBoundBytesChan, + m.abortChan, + ) return m, nil } @@ -246,9 +267,13 @@ func (m *Muxer) Serve() error { m.w.Close() m.abort() }() + go func() { + errChan <- m.muxMetricsUpdater.run(logger) + }() err := <-errChan go func() { - // discard error as other handler closes + // discard errors as other handler and muxMetricsUpdater close + <-errChan <-errChan close(errChan) }() @@ -318,14 +343,8 @@ func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, erro } } -// Return the estimated round-trip time. -func (m *Muxer) RTT() RTTMeasurement { - return m.muxReader.RTT() -} - -// Return min/max/average of send/receive window for all streams on this connection -func (m *Muxer) FlowControlMetrics() *FlowControlMetrics { - return m.muxReader.FlowControlMetrics() +func (m *Muxer) Metrics() *MuxerMetrics { + return m.muxMetricsUpdater.Metrics() } func (m *Muxer) abort() { diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go index 49480cf4..3f9d1a75 100644 --- a/h2mux/h2mux_test.go +++ b/h2mux/h2mux_test.go @@ -14,6 +14,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" ) func TestMain(m *testing.M) { @@ -134,9 +135,12 @@ func TestSingleStream(t *testing.T) { if stream.Headers[0].Value != "headerValue" { t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value) } - stream.WriteHeaders([]Header{ + headers := []Header{ Header{Name: "response-header", Value: "responseValue"}, - }) + } + stream.WriteHeaders(headers) + assert.Equal(t, headers, stream.writeHeaders) + assert.False(t, stream.headersSent) buf := []byte("Hello world") stream.Write(buf) // after this receive, the edge closed the stream diff --git a/h2mux/muxedstream.go b/h2mux/muxedstream.go index f0367ce9..f9997858 100644 --- a/h2mux/muxedstream.go +++ b/h2mux/muxedstream.go @@ -19,6 +19,7 @@ type MuxedStream struct { receiveWindowCurrentMax uint32 // limit set in http2 spec. 2^31-1 receiveWindowMax uint32 + // nonzero if a WINDOW_UPDATE frame for a stream needs to be sent windowUpdate uint32 @@ -39,10 +40,6 @@ type MuxedStream struct { receivedEOF bool } -type flowControlWindow struct { - receiveWindow, sendWindow uint32 -} - func (s *MuxedStream) Read(p []byte) (n int, err error) { return s.readBuffer.Read(p) } @@ -101,17 +98,21 @@ func (s *MuxedStream) WriteHeaders(headers []Header) error { return ErrStreamHeadersSent } s.writeHeaders = headers + s.headersSent = false s.writeNotify() return nil } -func (s *MuxedStream) FlowControlWindow() *flowControlWindow { +func (s *MuxedStream) getReceiveWindow() uint32 { s.writeLock.Lock() defer s.writeLock.Unlock() - return &flowControlWindow{ - receiveWindow: s.receiveWindow, - sendWindow: s.sendWindow, - } + return s.receiveWindow +} + +func (s *MuxedStream) getSendWindow() uint32 { + s.writeLock.Lock() + defer s.writeLock.Unlock() + return s.sendWindow } // writeNotify must happen while holding writeLock. @@ -209,9 +210,7 @@ func (s *MuxedStream) getChunk() *streamChunk { } // Copies at most s.sendWindow bytes - //log.Infof("writeBuffer len %d stream %d", s.writeBuffer.Len(), s.streamID) writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow)) - //log.Infof("writeLen %d stream %d", writeLen, s.streamID) s.sendWindow -= uint32(writeLen) s.receiveWindow += s.windowUpdate s.windowUpdate = 0 diff --git a/h2mux/muxmetrics.go b/h2mux/muxmetrics.go new file mode 100644 index 00000000..ebac9935 --- /dev/null +++ b/h2mux/muxmetrics.go @@ -0,0 +1,232 @@ +package h2mux + +import ( + "sync" + "time" + + "github.com/golang-collections/collections/queue" + + log "github.com/sirupsen/logrus" +) + +// data points used to compute average receive window and send window size +const ( + // data points used to compute average receive window and send window size + dataPoints = 100 + // updateFreq is set to 1 sec so we can get inbound & outbound byes/sec + updateFreq = time.Second +) + +type muxMetricsUpdater struct { + // rttData keeps record of rtt, rttMin, rttMax and last measured time + rttData *rttData + // receiveWindowData keeps record of receive window measurement + receiveWindowData *flowControlData + // sendWindowData keeps record of send window measurement + sendWindowData *flowControlData + // inBoundRate is incoming bytes/sec + inBoundRate *rate + // outBoundRate is outgoing bytes/sec + outBoundRate *rate + // updateRTTChan is the channel to receive new RTT measurement from muxReader + updateRTTChan <-chan *roundTripMeasurement + //updateReceiveWindowChan is the channel to receive updated receiveWindow size from muxReader and muxWriter + updateReceiveWindowChan <-chan uint32 + //updateSendWindowChan is the channel to receive updated sendWindow size from muxReader and muxWriter + updateSendWindowChan <-chan uint32 + // updateInBoundBytesChan us the channel to receive bytesRead from muxReader + updateInBoundBytesChan <-chan uint64 + // updateOutBoundBytesChan us the channel to receive bytesWrote from muxWriter + updateOutBoundBytesChan <-chan uint64 + // shutdownC is to signal the muxerMetricsUpdater to shutdown + abortChan <-chan struct{} +} + +type MuxerMetrics struct { + RTT, RTTMin, RTTMax time.Duration + ReceiveWindowAve, SendWindowAve float64 + ReceiveWindowMin, ReceiveWindowMax, SendWindowMin, SendWindowMax uint32 + InBoundRateCurr, InBoundRateMin, InBoundRateMax uint64 + OutBoundRateCurr, OutBoundRateMin, OutBoundRateMax uint64 +} + +type roundTripMeasurement struct { + receiveTime, sendTime time.Time +} + +type rttData struct { + rtt, rttMin, rttMax time.Duration + lastMeasurementTime time.Time + lock sync.RWMutex +} + +type flowControlData struct { + sum uint64 + min, max uint32 + queue *queue.Queue + lock sync.RWMutex +} + +type rate struct { + curr uint64 + min, max uint64 + lock sync.RWMutex +} + +func newMuxMetricsUpdater( + updateRTTChan <-chan *roundTripMeasurement, + updateReceiveWindowChan <-chan uint32, + updateSendWindowChan <-chan uint32, + updateInBoundBytesChan <-chan uint64, + updateOutBoundBytesChan <-chan uint64, + abortChan <-chan struct{}, +) *muxMetricsUpdater { + return &muxMetricsUpdater{ + rttData: newRTTData(), + receiveWindowData: newFlowControlData(), + sendWindowData: newFlowControlData(), + inBoundRate: newRate(), + outBoundRate: newRate(), + updateRTTChan: updateRTTChan, + updateReceiveWindowChan: updateReceiveWindowChan, + updateSendWindowChan: updateSendWindowChan, + updateInBoundBytesChan: updateInBoundBytesChan, + updateOutBoundBytesChan: updateOutBoundBytesChan, + abortChan: abortChan, + } +} + +func (updater *muxMetricsUpdater) Metrics() *MuxerMetrics { + m := &MuxerMetrics{} + m.RTT, m.RTTMin, m.RTTMax = updater.rttData.metrics() + m.ReceiveWindowAve, m.ReceiveWindowMin, m.ReceiveWindowMax = updater.receiveWindowData.metrics() + m.SendWindowAve, m.SendWindowMin, m.SendWindowMax = updater.sendWindowData.metrics() + m.InBoundRateCurr, m.InBoundRateMin, m.InBoundRateMax = updater.inBoundRate.get() + m.OutBoundRateCurr, m.OutBoundRateMin, m.OutBoundRateMax = updater.outBoundRate.get() + return m +} + +func (updater *muxMetricsUpdater) run(parentLogger *log.Entry) error { + logger := parentLogger.WithFields(log.Fields{ + "subsystem": "mux", + "dir": "metrics", + }) + defer logger.Debug("event loop finished") + for { + select { + case <-updater.abortChan: + logger.Infof("Stopping mux metrics updater") + return nil + case roundTripMeasurement := <-updater.updateRTTChan: + go updater.rttData.update(roundTripMeasurement) + logger.Debug("Update rtt") + case receiveWindow := <-updater.updateReceiveWindowChan: + go updater.receiveWindowData.update(receiveWindow) + logger.Debug("Update receive window") + case sendWindow := <-updater.updateSendWindowChan: + go updater.sendWindowData.update(sendWindow) + logger.Debug("Update send window") + case inBoundBytes := <-updater.updateInBoundBytesChan: + // inBoundBytes is bytes/sec because the update interval is 1 sec + go updater.inBoundRate.update(inBoundBytes) + logger.Debugf("Inbound bytes %d", inBoundBytes) + case outBoundBytes := <-updater.updateOutBoundBytesChan: + // outBoundBytes is bytes/sec because the update interval is 1 sec + go updater.outBoundRate.update(outBoundBytes) + logger.Debugf("Outbound bytes %d", outBoundBytes) + } + } +} + +func newRTTData() *rttData { + return &rttData{} +} + +func (r *rttData) update(measurement *roundTripMeasurement) { + r.lock.Lock() + defer r.lock.Unlock() + // discard pings before lastMeasurementTime + if r.lastMeasurementTime.After(measurement.sendTime) { + return + } + r.lastMeasurementTime = measurement.sendTime + r.rtt = measurement.receiveTime.Sub(measurement.sendTime) + if r.rttMax < r.rtt { + r.rttMax = r.rtt + } + if r.rttMin == 0 || r.rttMin > r.rtt { + r.rttMin = r.rtt + } +} + +func (r *rttData) metrics() (rtt, rttMin, rttMax time.Duration) { + r.lock.RLock() + defer r.lock.RUnlock() + return r.rtt, r.rttMin, r.rttMax +} + +func newFlowControlData() *flowControlData { + return &flowControlData{queue: queue.New()} +} + +func (f *flowControlData) update(measurement uint32) { + f.lock.Lock() + defer f.lock.Unlock() + var firstItem uint32 + // store new data into queue, remove oldest data if queue is full + f.queue.Enqueue(measurement) + if f.queue.Len() > dataPoints { + // data type should always be uint32 + firstItem = f.queue.Dequeue().(uint32) + } + // if (measurement - firstItem) < 0, uint64(measurement - firstItem) + // will overflow and become a large positive number + f.sum += uint64(measurement) + f.sum -= uint64(firstItem) + if measurement > f.max { + f.max = measurement + } + if f.min == 0 || measurement < f.min { + f.min = measurement + } +} + +// caller of ave() should acquire lock first +func (f *flowControlData) ave() float64 { + if f.queue.Len() == 0 { + return 0 + } + return float64(f.sum) / float64(f.queue.Len()) +} + +func (f *flowControlData) metrics() (ave float64, min, max uint32) { + f.lock.RLock() + defer f.lock.RUnlock() + return f.ave(), f.min, f.max +} + +func newRate() *rate { + return &rate{} +} + +func (r *rate) update(measurement uint64) { + r.lock.Lock() + defer r.lock.Unlock() + r.curr = measurement + // if measurement is 0, then there is no incoming/outgoing connection, don't update min/max + if r.curr == 0 { + return + } + if measurement > r.max { + r.max = measurement + } + if r.min == 0 || measurement < r.min { + r.min = measurement + } +} + +func (r *rate) get() (curr, min, max uint64) { + r.lock.RLock() + defer r.lock.RUnlock() + return r.curr, r.min, r.max +} diff --git a/h2mux/muxmetrics_test.go b/h2mux/muxmetrics_test.go new file mode 100644 index 00000000..7f4cddf2 --- /dev/null +++ b/h2mux/muxmetrics_test.go @@ -0,0 +1,176 @@ +package h2mux + +import ( + "sync" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func ave(sum uint64, len int) float64 { + return float64(sum) / float64(len) +} + +func TestRTTUpdate(t *testing.T) { + r := newRTTData() + start := time.Now() + // send at 0 ms, receive at 2 ms, RTT = 2ms + m := &roundTripMeasurement{receiveTime: start.Add(2 * time.Millisecond), sendTime: start} + r.update(m) + assert.Equal(t, start, r.lastMeasurementTime) + assert.Equal(t, 2*time.Millisecond, r.rtt) + assert.Equal(t, 2*time.Millisecond, r.rttMin) + assert.Equal(t, 2*time.Millisecond, r.rttMax) + + // send at 3 ms, receive at 6 ms, RTT = 3ms + m = &roundTripMeasurement{receiveTime: start.Add(6 * time.Millisecond), sendTime: start.Add(3 * time.Millisecond)} + r.update(m) + assert.Equal(t, start.Add(3*time.Millisecond), r.lastMeasurementTime) + assert.Equal(t, 3*time.Millisecond, r.rtt) + assert.Equal(t, 2*time.Millisecond, r.rttMin) + assert.Equal(t, 3*time.Millisecond, r.rttMax) + + // send at 7 ms, receive at 8 ms, RTT = 1ms + m = &roundTripMeasurement{receiveTime: start.Add(8 * time.Millisecond), sendTime: start.Add(7 * time.Millisecond)} + r.update(m) + assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime) + assert.Equal(t, 1*time.Millisecond, r.rtt) + assert.Equal(t, 1*time.Millisecond, r.rttMin) + assert.Equal(t, 3*time.Millisecond, r.rttMax) + + // send at -4 ms, receive at 0 ms, RTT = 4ms, but this ping is before last measurement + // so it will be discarded + m = &roundTripMeasurement{receiveTime: start, sendTime: start.Add(-2 * time.Millisecond)} + r.update(m) + assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime) + assert.Equal(t, 1*time.Millisecond, r.rtt) + assert.Equal(t, 1*time.Millisecond, r.rttMin) + assert.Equal(t, 3*time.Millisecond, r.rttMax) +} + +func TestFlowControlDataUpdate(t *testing.T) { + f := newFlowControlData() + assert.Equal(t, 0, f.queue.Len()) + assert.Equal(t, float64(0), f.ave()) + + var sum uint64 + min := maxWindowSize - dataPoints + max := maxWindowSize + for i := 1; i <= dataPoints; i++ { + size := maxWindowSize - uint32(i) + f.update(size) + assert.Equal(t, max - uint32(1), f.max) + assert.Equal(t, size, f.min) + + assert.Equal(t, i, f.queue.Len()) + + sum += uint64(size) + assert.Equal(t, sum, f.sum) + assert.Equal(t, ave(sum, f.queue.Len()), f.ave()) + } + + // queue is full, should start to dequeue first element + for i := 1; i <= dataPoints; i++ { + f.update(max) + assert.Equal(t, max, f.max) + assert.Equal(t, min, f.min) + + assert.Equal(t, dataPoints, f.queue.Len()) + + sum += uint64(i) + assert.Equal(t, sum, f.sum) + assert.Equal(t, ave(sum, dataPoints), f.ave()) + } +} + +func TestMuxMetricsUpdater(t *testing.T) { + updateRTTChan := make(chan *roundTripMeasurement) + updateReceiveWindowChan := make(chan uint32) + updateSendWindowChan := make(chan uint32) + updateInBoundBytesChan := make(chan uint64) + updateOutBoundBytesChan := make(chan uint64) + abortChan := make(chan struct{}) + errChan := make(chan error) + m := newMuxMetricsUpdater(updateRTTChan, + updateReceiveWindowChan, + updateSendWindowChan, + updateInBoundBytesChan, + updateOutBoundBytesChan, + abortChan, + ) + logger := log.NewEntry(log.New()) + + go func() { + errChan <- m.run(logger) + }() + + var wg sync.WaitGroup + wg.Add(2) + + // mock muxReader + readerStart := time.Now() + rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart} + updateRTTChan <- rm + go func() { + defer wg.Done() + // Becareful if dataPoints is not divisibile by 4 + readerSend := readerStart.Add(time.Millisecond) + for i := 1; i <= dataPoints/4; i++ { + readerReceive := readerSend.Add(time.Duration(i) * time.Millisecond) + rm := &roundTripMeasurement{receiveTime: readerReceive, sendTime: readerSend} + updateRTTChan <- rm + readerSend = readerReceive.Add(time.Millisecond) + + updateReceiveWindowChan <- uint32(i) + updateSendWindowChan <- uint32(i) + + updateInBoundBytesChan <- uint64(i) + } + }() + + // mock muxWriter + go func() { + defer wg.Done() + for j := dataPoints/4 + 1; j <= dataPoints/2; j++ { + updateReceiveWindowChan <- uint32(j) + updateSendWindowChan <- uint32(j) + + // should always be disgard since the send time is before readerSend + rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart.Add(-time.Duration(j*dataPoints) * time.Millisecond)} + updateRTTChan <- rm + + updateOutBoundBytesChan <- uint64(j) + } + + }() + wg.Wait() + + metrics := m.Metrics() + points := dataPoints / 2 + assert.Equal(t, time.Millisecond, metrics.RTTMin) + assert.Equal(t, time.Duration(dataPoints/4)*time.Millisecond, metrics.RTTMax) + + // sum(1..i) = i*(i+1)/2, ave(1..i) = i*(i+1)/2/i = (i+1)/2 + assert.Equal(t, float64(points+1)/float64(2), metrics.ReceiveWindowAve) + assert.Equal(t, uint32(1), metrics.ReceiveWindowMin) + assert.Equal(t, uint32(points), metrics.ReceiveWindowMax) + + assert.Equal(t, float64(points+1)/float64(2), metrics.SendWindowAve) + assert.Equal(t, uint32(1), metrics.SendWindowMin) + assert.Equal(t, uint32(points), metrics.SendWindowMax) + + assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateCurr) + assert.Equal(t, uint64(1), metrics.InBoundRateMin) + assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateMax) + + assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateCurr) + assert.Equal(t, uint64(dataPoints/4+1), metrics.OutBoundRateMin) + assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateMax) + + close(abortChan) + assert.Nil(t, <-errChan) + close(errChan) + +} diff --git a/h2mux/muxreader.go b/h2mux/muxreader.go index a0b56366..3530dab8 100644 --- a/h2mux/muxreader.go +++ b/h2mux/muxreader.go @@ -3,7 +3,6 @@ package h2mux import ( "encoding/binary" "io" - "sync" "time" log "github.com/sirupsen/logrus" @@ -34,14 +33,18 @@ type MuxReader struct { initialStreamWindow uint32 // The max value for the send window of a stream. streamWindowMax uint32 - // windowMetrics keeps track of min/max/average of send/receive windows for all streams - flowControlMetrics *FlowControlMetrics - metricsMutex sync.Mutex // r is a reference to the underlying connection used when shutting down. r io.Closer - // rttMeasurement measures RTT based on ping timestamps. - rttMeasurement RTTMeasurement - rttMutex sync.Mutex + // updateRTTChan is the channel to send new RTT measurement to muxerMetricsUpdater + updateRTTChan chan<- *roundTripMeasurement + // updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater + updateReceiveWindowChan chan<- uint32 + // updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater + updateSendWindowChan chan<- uint32 + // bytesRead is the amount of bytes read from data frame since the last time we send bytes read to metrics + bytesRead AtomicCounter + // updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater + updateInBoundBytesChan chan<- uint64 } func (r *MuxReader) Shutdown() { @@ -57,28 +60,26 @@ func (r *MuxReader) Shutdown() { }() } -func (r *MuxReader) RTT() RTTMeasurement { - r.rttMutex.Lock() - defer r.rttMutex.Unlock() - return r.rttMeasurement -} - -func (r *MuxReader) FlowControlMetrics() *FlowControlMetrics { - r.metricsMutex.Lock() - defer r.metricsMutex.Unlock() - if r.flowControlMetrics != nil { - return r.flowControlMetrics - } - // No metrics available yet - return &FlowControlMetrics{} -} - func (r *MuxReader) run(parentLogger *log.Entry) error { logger := parentLogger.WithFields(log.Fields{ "subsystem": "mux", "dir": "read", }) defer logger.Debug("event loop finished") + + // routine to periodically update bytesRead + go func() { + tickC := time.Tick(updateFreq) + for { + select { + case <-r.abortChan: + return + case <-tickC: + r.updateInBoundBytesChan <- r.bytesRead.Count() + } + } + }() + for { frame, err := r.f.ReadFrame() if err != nil { @@ -120,6 +121,8 @@ func (r *MuxReader) run(parentLogger *log.Entry) error { r.receivePingData(f) case *http2.GoAwayFrame: err = r.receiveGoAway(f) + // The receiver of a flow-controlled frame sends a WINDOW_UPDATE frame as it + // consumes data and frees up space in flow-control windows case *http2.WindowUpdateFrame: err = r.updateStreamWindow(f) default: @@ -236,10 +239,11 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.E } data := frame.Data() if len(data) > 0 { - _, err = stream.readBuffer.Write(data) + n, err := stream.readBuffer.Write(data) if err != nil { return r.streamError(stream.streamID, http2.ErrCodeInternal) } + r.bytesRead.IncrementBy(uint64(n)) } if frame.Header().Flags.Has(http2.FlagDataEndStream) { if stream.receiveEOF() { @@ -253,6 +257,7 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.E if !stream.consumeReceiveWindow(uint32(len(data))) { return r.streamError(stream.streamID, http2.ErrCodeFlowControl) } + r.updateReceiveWindowChan <- stream.getReceiveWindow() return nil } @@ -263,10 +268,14 @@ func (r *MuxReader) receivePingData(frame *http2.PingFrame) { r.pingTimestamp.Set(ts) return } - r.rttMutex.Lock() - r.rttMeasurement.Update(time.Unix(0, ts)) - r.rttMutex.Unlock() - r.flowControlMetrics = r.streams.Metrics() + + // Update updates the computed values with a new measurement. + // outgoingTime is the time that the probe was sent. + // We assume that time.Now() is the time we received that probe. + r.updateRTTChan <- &roundTripMeasurement{ + receiveTime: time.Now(), + sendTime: time.Unix(0, ts), + } } // Receive a GOAWAY from the peer. Gracefully shut down our connection. @@ -293,6 +302,7 @@ func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error { return nil } stream.replenishSendWindow(frame.Increment) + r.updateSendWindowChan <- stream.getSendWindow() return nil } diff --git a/h2mux/muxwriter.go b/h2mux/muxwriter.go index 1cfbf1f3..8ea4c675 100644 --- a/h2mux/muxwriter.go +++ b/h2mux/muxwriter.go @@ -40,6 +40,14 @@ type MuxWriter struct { headerEncoder *hpack.Encoder // headerBuffer is the temporary buffer used by headerEncoder. headerBuffer bytes.Buffer + // updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater + updateReceiveWindowChan chan<- uint32 + // updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater + updateSendWindowChan chan<- uint32 + // bytesWrote is the amount of bytes wrote to data frame since the last time we send bytes wrote to metrics + bytesWrote AtomicCounter + // updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater + updateOutBoundBytesChan chan<- uint64 } type MuxedStreamRequest struct { @@ -64,6 +72,20 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error { "dir": "write", }) defer logger.Debug("event loop finished") + + // routine to periodically communicate bytesWrote + go func() { + tickC := time.Tick(updateFreq) + for { + select { + case <-w.abortChan: + return + case <-tickC: + w.updateOutBoundBytesChan <- w.bytesWrote.Count() + } + } + }() + for { select { case <-w.abortChan: @@ -141,7 +163,8 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error { func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error { logger.Debug("writable") chunk := stream.getChunk() - + w.updateReceiveWindowChan <- stream.getReceiveWindow() + w.updateSendWindowChan <- stream.getSendWindow() if chunk.sendHeadersFrame() { err := w.writeHeaders(chunk.streamID, chunk.headers) if err != nil { @@ -154,7 +177,9 @@ func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) erro if chunk.sendWindowUpdateFrame() { // Send a WINDOW_UPDATE frame to update our receive window. // If the Stream ID is zero, the window update applies to the connection as a whole - // A WINDOW_UPDATE in a specific stream applies to the connection-level flow control as well. + // RFC7540 section-6.9.1 "A receiver that receives a flow-controlled frame MUST + // always account for its contribution against the connection flow-control + // window, unless the receiver treats this as a connection error" err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate) if err != nil { logger.WithError(err).Warn("error writing window update") @@ -170,6 +195,8 @@ func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) erro logger.WithError(err).Warn("error writing data") return err } + // update the amount of data wrote + w.bytesWrote.IncrementBy(uint64(len(payload))) logger.WithField("len", len(payload)).Debug("output data") if sentEOF { @@ -214,19 +241,16 @@ func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error { return err } blockSize := int(w.maxFrameSize) - continuation := false endHeaders := len(encodedHeaders) == 0 for !endHeaders && err == nil { blockFragment := encodedHeaders if len(encodedHeaders) > blockSize { blockFragment = blockFragment[:blockSize] encodedHeaders = encodedHeaders[blockSize:] - } else { - endHeaders = true - } - if continuation { + // Send CONTINUATION frame if the headers can't be fit into 1 frame err = w.f.WriteContinuation(streamID, endHeaders, blockFragment) } else { + endHeaders = true err = w.f.WriteHeaders(http2.HeadersFrameParam{ StreamID: streamID, EndHeaders: endHeaders, diff --git a/h2mux/rtt.go b/h2mux/rtt.go index 1c42ff82..350233e3 100644 --- a/h2mux/rtt.go +++ b/h2mux/rtt.go @@ -2,7 +2,6 @@ package h2mux import ( "sync/atomic" - "time" ) // PingTimestamp is an atomic interface around ping timestamping and signalling. @@ -28,26 +27,3 @@ func (pt *PingTimestamp) Get() int64 { func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} { return pt.signal.WaitChannel() } - -// RTTMeasurement encapsulates a continuous round trip time measurement. -type RTTMeasurement struct { - Current, Min, Max time.Duration - lastMeasurementTime time.Time -} - -// Update updates the computed values with a new measurement. -// outgoingTime is the time that the probe was sent. -// We assume that time.Now() is the time we received that probe. -func (r *RTTMeasurement) Update(outgoingTime time.Time) { - if !r.lastMeasurementTime.Before(outgoingTime) { - return - } - r.lastMeasurementTime = outgoingTime - r.Current = time.Since(outgoingTime) - if r.Max < r.Current { - r.Max = r.Current - } - if r.Min > r.Current { - r.Min = r.Current - } -} diff --git a/origin/metrics.go b/origin/metrics.go index 314d4060..9908c554 100644 --- a/origin/metrics.go +++ b/origin/metrics.go @@ -2,13 +2,32 @@ package origin import ( "sync" + "time" - "github.com/cloudflare/cloudflare-warp/h2mux" + "github.com/cloudflare/cloudflared/h2mux" "github.com/prometheus/client_golang/prometheus" ) -type TunnelMetrics struct { +type muxerMetrics struct { + rtt *prometheus.GaugeVec + rttMin *prometheus.GaugeVec + rttMax *prometheus.GaugeVec + receiveWindowAve *prometheus.GaugeVec + sendWindowAve *prometheus.GaugeVec + receiveWindowMin *prometheus.GaugeVec + receiveWindowMax *prometheus.GaugeVec + sendWindowMin *prometheus.GaugeVec + sendWindowMax *prometheus.GaugeVec + inBoundRateCurr *prometheus.GaugeVec + inBoundRateMin *prometheus.GaugeVec + inBoundRateMax *prometheus.GaugeVec + outBoundRateCurr *prometheus.GaugeVec + outBoundRateMin *prometheus.GaugeVec + outBoundRateMax *prometheus.GaugeVec +} + +type tunnelMetrics struct { haConnections prometheus.Gauge totalRequests prometheus.Counter requestsPerTunnel *prometheus.CounterVec @@ -20,16 +39,7 @@ type TunnelMetrics struct { maxConcurrentRequestsPerTunnel *prometheus.GaugeVec // concurrentRequests records max count of concurrent requests for each tunnel maxConcurrentRequests map[string]uint64 - rtt prometheus.Gauge - rttMin prometheus.Gauge - rttMax prometheus.Gauge timerRetries prometheus.Gauge - receiveWindowSizeAve prometheus.Gauge - sendWindowSizeAve prometheus.Gauge - receiveWindowSizeMin prometheus.Gauge - receiveWindowSizeMax prometheus.Gauge - sendWindowSizeMin prometheus.Gauge - sendWindowSizeMax prometheus.Gauge responseByCode *prometheus.CounterVec responseCodePerTunnel *prometheus.CounterVec serverLocations *prometheus.GaugeVec @@ -37,10 +47,189 @@ type TunnelMetrics struct { locationLock sync.Mutex // oldServerLocations stores the last server the tunnel was connected to oldServerLocations map[string]string + + muxerMetrics *muxerMetrics +} + +func newMuxerMetrics() *muxerMetrics { + rtt := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "rtt", + Help: "Round-trip time in millisecond", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(rtt) + + rttMin := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "rtt_min", + Help: "Shortest round-trip time in millisecond", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(rttMin) + + rttMax := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "rtt_max", + Help: "Longest round-trip time in millisecond", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(rttMax) + + receiveWindowAve := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "receive_window_ave", + Help: "Average receive window size in bytes", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(receiveWindowAve) + + sendWindowAve := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "send_window_ave", + Help: "Average send window size in bytes", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(sendWindowAve) + + receiveWindowMin := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "receive_window_min", + Help: "Smallest receive window size in bytes", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(receiveWindowMin) + + receiveWindowMax := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "receive_window_max", + Help: "Largest receive window size in bytes", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(receiveWindowMax) + + sendWindowMin := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "send_window_min", + Help: "Smallest send window size in bytes", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(sendWindowMin) + + sendWindowMax := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "send_window_max", + Help: "Largest send window size in bytes", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(sendWindowMax) + + inBoundRateCurr := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "inbound_bytes_per_sec_curr", + Help: "Current inbounding bytes per second, 0 if there is no incoming connection", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(inBoundRateCurr) + + inBoundRateMin := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "inbound_bytes_per_sec_min", + Help: "Minimum non-zero inbounding bytes per second", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(inBoundRateMin) + + inBoundRateMax := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "inbound_bytes_per_sec_max", + Help: "Maximum inbounding bytes per second", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(inBoundRateMax) + + outBoundRateCurr := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "outbound_bytes_per_sec_curr", + Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(outBoundRateCurr) + + outBoundRateMin := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "outbound_bytes_per_sec_min", + Help: "Minimum non-zero outbounding bytes per second", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(outBoundRateMin) + + outBoundRateMax := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "outbound_bytes_per_sec_max", + Help: "Maximum outbounding bytes per second", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(outBoundRateMax) + + return &muxerMetrics{ + rtt: rtt, + rttMin: rttMin, + rttMax: rttMax, + receiveWindowAve: receiveWindowAve, + sendWindowAve: sendWindowAve, + receiveWindowMin: receiveWindowMin, + receiveWindowMax: receiveWindowMax, + sendWindowMin: sendWindowMin, + sendWindowMax: sendWindowMax, + inBoundRateCurr: inBoundRateCurr, + inBoundRateMin: inBoundRateMin, + inBoundRateMax: inBoundRateMax, + outBoundRateCurr: outBoundRateCurr, + outBoundRateMin: outBoundRateMin, + outBoundRateMax: outBoundRateMax, + } +} + +func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) { + m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT)) + m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin)) + m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax)) + m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve) + m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve) + m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin)) + m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax)) + m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin)) + m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax)) + m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr)) + m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin)) + m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax)) + m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr)) + m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin)) + m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax)) +} + +func convertRTTMilliSec(t time.Duration) float64 { + return float64(t / time.Millisecond) } // Metrics that can be collected without asking the edge -func NewTunnelMetrics() *TunnelMetrics { +func NewTunnelMetrics() *tunnelMetrics { haConnections := prometheus.NewGauge( prometheus.GaugeOpts{ Name: "ha_connections", @@ -82,27 +271,6 @@ func NewTunnelMetrics() *TunnelMetrics { ) prometheus.MustRegister(maxConcurrentRequestsPerTunnel) - rtt := prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "rtt", - Help: "Round-trip time", - }) - prometheus.MustRegister(rtt) - - rttMin := prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "rtt_min", - Help: "Shortest round-trip time", - }) - prometheus.MustRegister(rttMin) - - rttMax := prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "rtt_max", - Help: "Longest round-trip time", - }) - prometheus.MustRegister(rttMax) - timerRetries := prometheus.NewGauge( prometheus.GaugeOpts{ Name: "timer_retries", @@ -110,48 +278,6 @@ func NewTunnelMetrics() *TunnelMetrics { }) prometheus.MustRegister(timerRetries) - receiveWindowSizeAve := prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "receive_window_ave", - Help: "Average receive window size", - }) - prometheus.MustRegister(receiveWindowSizeAve) - - sendWindowSizeAve := prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "send_window_ave", - Help: "Average send window size", - }) - prometheus.MustRegister(sendWindowSizeAve) - - receiveWindowSizeMin := prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "receive_window_min", - Help: "Smallest receive window size", - }) - prometheus.MustRegister(receiveWindowSizeMin) - - receiveWindowSizeMax := prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "receive_window_max", - Help: "Largest receive window size", - }) - prometheus.MustRegister(receiveWindowSizeMax) - - sendWindowSizeMin := prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "send_window_min", - Help: "Smallest send window size", - }) - prometheus.MustRegister(sendWindowSizeMin) - - sendWindowSizeMax := prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "send_window_max", - Help: "Largest send window size", - }) - prometheus.MustRegister(sendWindowSizeMax) - responseByCode := prometheus.NewCounterVec( prometheus.CounterOpts{ Name: "response_by_code", @@ -179,7 +305,7 @@ func NewTunnelMetrics() *TunnelMetrics { ) prometheus.MustRegister(serverLocations) - return &TunnelMetrics{ + return &tunnelMetrics{ haConnections: haConnections, totalRequests: totalRequests, requestsPerTunnel: requestsPerTunnel, @@ -187,41 +313,28 @@ func NewTunnelMetrics() *TunnelMetrics { concurrentRequests: make(map[string]uint64), maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel, maxConcurrentRequests: make(map[string]uint64), - rtt: rtt, - rttMin: rttMin, - rttMax: rttMax, - timerRetries: timerRetries, - receiveWindowSizeAve: receiveWindowSizeAve, - sendWindowSizeAve: sendWindowSizeAve, - receiveWindowSizeMin: receiveWindowSizeMin, - receiveWindowSizeMax: receiveWindowSizeMax, - sendWindowSizeMin: sendWindowSizeMin, - sendWindowSizeMax: sendWindowSizeMax, - responseByCode: responseByCode, - responseCodePerTunnel: responseCodePerTunnel, - serverLocations: serverLocations, - oldServerLocations: make(map[string]string), + timerRetries: timerRetries, + responseByCode: responseByCode, + responseCodePerTunnel: responseCodePerTunnel, + serverLocations: serverLocations, + oldServerLocations: make(map[string]string), + muxerMetrics: newMuxerMetrics(), } } -func (t *TunnelMetrics) incrementHaConnections() { +func (t *tunnelMetrics) incrementHaConnections() { t.haConnections.Inc() } -func (t *TunnelMetrics) decrementHaConnections() { +func (t *tunnelMetrics) decrementHaConnections() { t.haConnections.Dec() } -func (t *TunnelMetrics) updateTunnelFlowControlMetrics(metrics *h2mux.FlowControlMetrics) { - t.receiveWindowSizeAve.Set(float64(metrics.AverageReceiveWindowSize)) - t.sendWindowSizeAve.Set(float64(metrics.AverageSendWindowSize)) - t.receiveWindowSizeMin.Set(float64(metrics.MinReceiveWindowSize)) - t.receiveWindowSizeMax.Set(float64(metrics.MaxReceiveWindowSize)) - t.sendWindowSizeMin.Set(float64(metrics.MinSendWindowSize)) - t.sendWindowSizeMax.Set(float64(metrics.MaxSendWindowSize)) +func (t *tunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) { + t.muxerMetrics.update(connectionID, metrics) } -func (t *TunnelMetrics) incrementRequests(connectionID string) { +func (t *tunnelMetrics) incrementRequests(connectionID string) { t.concurrentRequestsLock.Lock() var concurrentRequests uint64 var ok bool @@ -243,7 +356,7 @@ func (t *TunnelMetrics) incrementRequests(connectionID string) { t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc() } -func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) { +func (t *tunnelMetrics) decrementConcurrentRequests(connectionID string) { t.concurrentRequestsLock.Lock() if _, ok := t.concurrentRequests[connectionID]; ok { t.concurrentRequests[connectionID] -= 1 @@ -255,13 +368,13 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) { t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec() } -func (t *TunnelMetrics) incrementResponses(connectionID, code string) { +func (t *tunnelMetrics) incrementResponses(connectionID, code string) { t.responseByCode.WithLabelValues(code).Inc() t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc() } -func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) { +func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) { t.locationLock.Lock() defer t.locationLock.Unlock() if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc { diff --git a/origin/tunnel.go b/origin/tunnel.go index dfeb3ee8..e75988b4 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -15,11 +15,11 @@ import ( "golang.org/x/net/context" - "github.com/cloudflare/cloudflare-warp/h2mux" - "github.com/cloudflare/cloudflare-warp/tunnelrpc" - tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" - "github.com/cloudflare/cloudflare-warp/validation" - "github.com/cloudflare/cloudflare-warp/websocket" + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/tunnelrpc" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/cloudflare/cloudflared/validation" + "github.com/cloudflare/cloudflared/websocket" raven "github.com/getsentry/raven-go" "github.com/pkg/errors" @@ -53,7 +53,7 @@ type TunnelConfig struct { Tags []tunnelpogs.Tag HAConnections int HTTPTransport http.RoundTripper - Metrics *TunnelMetrics + Metrics *tunnelMetrics MetricsUpdateFreq time.Duration ProtocolLogger *logrus.Logger Logger *logrus.Logger @@ -185,6 +185,7 @@ func ServeTunnel( serveCtx, serveCancel := context.WithCancel(ctx) registerErrC := make(chan error, 1) go func() { + defer wg.Done() err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP) if err == nil { connectedFuse.Fuse(true) @@ -193,18 +194,18 @@ func ServeTunnel( serveCancel() } registerErrC <- err - wg.Done() }() updateMetricsTickC := time.Tick(config.MetricsUpdateFreq) go func() { defer wg.Done() + connectionTag := uint8ToString(connectionID) for { select { case <-serveCtx.Done(): handler.muxer.Shutdown() return case <-updateMetricsTickC: - handler.UpdateMetrics() + handler.UpdateMetrics(connectionTag) } } }() @@ -303,7 +304,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi func LogServerInfo(logger *logrus.Entry, promise tunnelrpc.ServerInfo_Promise, connectionID uint8, - metrics *TunnelMetrics, + metrics *tunnelMetrics, ) { serverInfoMessage, err := promise.Struct() if err != nil { @@ -356,13 +357,17 @@ func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) { return } +func FindCfRayHeader(h1 *http.Request) string { + return h1.Header.Get("Cf-Ray") +} + type TunnelHandler struct { originUrl string muxer *h2mux.Muxer httpClient http.RoundTripper tlsConfig *tls.Config tags []tunnelpogs.Tag - metrics *TunnelMetrics + metrics *tunnelMetrics // connectionID is only used by metrics, and prometheus requires labels to be string connectionID string } @@ -435,7 +440,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { Log.WithError(err).Error("invalid request received") } h.AppendTagHeaders(req) - + cfRay := FindCfRayHeader(req) + h.logRequest(req, cfRay) if websocket.IsWebSocketUpgrade(req) { conn, response, err := websocket.ClientConnect(req, h.tlsConfig) if err != nil { @@ -444,6 +450,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { stream.WriteHeaders(H1ResponseToH2Response(response)) defer conn.Close() websocket.Stream(conn.UnderlyingConn(), stream) + h.metrics.incrementResponses(h.connectionID, "200") + h.logResponse(response, cfRay) } } else { response, err := h.httpClient.RoundTrip(req) @@ -454,6 +462,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { stream.WriteHeaders(H1ResponseToH2Response(response)) io.Copy(stream, response.Body) h.metrics.incrementResponses(h.connectionID, "200") + h.logResponse(response, cfRay) } } h.metrics.decrementConcurrentRequests(h.connectionID) @@ -467,9 +476,27 @@ func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) { h.metrics.incrementResponses(h.connectionID, "502") } -func (h *TunnelHandler) UpdateMetrics() { - flowCtlMetrics := h.muxer.FlowControlMetrics() - h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics) +func (h *TunnelHandler) logRequest(req *http.Request, cfRay string) { + if cfRay != "" { + Log.WithField("CF-RAY", cfRay).Infof("%s %s %s", req.Method, req.URL, req.Proto) + } else { + Log.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto) + } + Log.Debugf("Request Headers %+v", req.Header) +} + +func (h *TunnelHandler) logResponse(r *http.Response, cfRay string) { + if cfRay != "" { + Log.WithField("CF-RAY", cfRay).Infof("%s", r.Status) + } else { + Log.Infof("%s", r.Status) + } + Log.Debugf("Response Headers %+v", r.Header) +} + + +func (h *TunnelHandler) UpdateMetrics(connectionID string) { + h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics()) } func uint8ToString(input uint8) string { diff --git a/tlsconfig/hello_ca.go b/tlsconfig/hello_ca.go index bb49093c..7ca15868 100644 --- a/tlsconfig/hello_ca.go +++ b/tlsconfig/hello_ca.go @@ -11,28 +11,28 @@ const ( BgUrgQQAIg== -----END EC PARAMETERS----- -----BEGIN EC PRIVATE KEY----- -MIGkAgEBBDAdyQBXfxTDCQSOT0HugmH9pVBtIw8t5dYvm6HxGlNq6P57v5GeN02Z -dH9FRl7+VSWgBwYFK4EEACKhZANiAATqpFzTxxV7D+/oqhKCTR6BEM9elTfKaRQE -FsLufcmaTMw/9tTwgpHKao/QsLKDTNbQhbSQLkcmpCQKlSGhl+pCrqNt/oYUAhav -UIwpwGiLCqGH/R2AqWLKRPOa/Rufs/U= +MIGkAgEBBDBGGfwhIJdiUiJUVIItqJjEIMmlXxsMa8TQeer47+g+cIZ466rgg8EK ++Mdn6BY48GCgBwYFK4EEACKhZANiAASW//A9iDbPKg3OLkn7yJqLer32g9I5lBKR +tPc/zBubQLLz9lAaYI6AOQiJXhGr5JkKmQfi1sYHK5rJITPFy4W8Et4hHLdazDZH +WnEd+TStQABFUjrhtqXPWmGKcly0pOE= -----END EC PRIVATE KEY-----` helloCRT = ` -----BEGIN CERTIFICATE----- -MIICkDCCAhigAwIBAgIJAPtKfUjc2lwGMAkGByqGSM49BAEwgYoxCzAJBgNVBAYT -AlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQKDBBD -bG91ZGZsYXJlLCBJbmMuMT8wPQYDVQQDDDZDbG91ZGZsYXJlIEFyZ28gVHVubmVs -IFNhbXBsZSBIZWxsbyBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMjE1MjAxNjU5 -WhcNMjgwMjEzMjAxNjU5WjCBijELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVRleGFz -MQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENsb3VkZmxhcmUsIEluYy4xPzA9 -BgNVBAMMNkNsb3VkZmxhcmUgQXJnbyBUdW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZl -ciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49AgEGBSuBBAAiA2IABOqkXNPHFXsP7+iq -EoJNHoEQz16VN8ppFAQWwu59yZpMzD/21PCCkcpqj9CwsoNM1tCFtJAuRyakJAqV -IaGX6kKuo23+hhQCFq9QjCnAaIsKoYf9HYCpYspE85r9G5+z9aNJMEcwRQYDVR0R -BD4wPIIJbG9jYWxob3N0ggp3YXJwLWhlbGxvggt3YXJwMi1oZWxsb4cEfwAAAYcQ -AAAAAAAAAAAAAAAAAAAAATAJBgcqhkjOPQQBA2cAMGQCMHyVPufXZ6vQo6XRWRa0 -dAwtfgesOdZVP2Wt+t5v8jOIQQh1IQXYk5GtyoZGSObjhQIwd1fRgAyKXaZt+1DV -ZtHTdf8pMvESfJsSd8AB1eQ6q+pAiRUYyaxcE1Mlo2YY5o+g +MIICiDCCAg6gAwIBAgIJAJ/FfkBTtbuIMAkGByqGSM49BAEwfzELMAkGA1UEBhMC +VVMxDjAMBgNVBAgMBVRleGFzMQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENs +b3VkZmxhcmUsIEluYy4xNDAyBgNVBAMMK0FyZ28gVHVubmVsIFNhbXBsZSBIZWxs +byBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMzE5MjMwNTMyWhcNMjgwMzE2MjMw +NTMyWjB/MQswCQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1 +c3RpbjEZMBcGA1UECgwQQ2xvdWRmbGFyZSwgSW5jLjE0MDIGA1UEAwwrQXJnbyBU +dW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZlciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49 +AgEGBSuBBAAiA2IABJb/8D2INs8qDc4uSfvImot6vfaD0jmUEpG09z/MG5tAsvP2 +UBpgjoA5CIleEavkmQqZB+LWxgcrmskhM8XLhbwS3iEct1rMNkdacR35NK1AAEVS +OuG2pc9aYYpyXLSk4aNXMFUwUwYDVR0RBEwwSoIJbG9jYWxob3N0ghFjbG91ZGZs +YXJlZC1oZWxsb4ISY2xvdWRmbGFyZWQyLWhlbGxvhwR/AAABhxAAAAAAAAAAAAAA +AAAAAAABMAkGByqGSM49BAEDaQAwZgIxAPxkdghH6y8xLMnY9Bom3Llf4NYM6yB9 +PD1YsaNUJTsxjTk3YY1Jsp+yzK0yUKtTZwIxAPcdvqCF2/iR9H288pCT1TgtO0a9 +cJL9RY1lq7DIGN37v1ZXReWaD+3hNokY8NriVg== -----END CERTIFICATE-----` ) diff --git a/tunneldns/https_proxy.go b/tunneldns/https_proxy.go new file mode 100644 index 00000000..3f6140b8 --- /dev/null +++ b/tunneldns/https_proxy.go @@ -0,0 +1,38 @@ +package tunneldns + +import ( + "github.com/coredns/coredns/plugin" + "github.com/miekg/dns" + "github.com/pkg/errors" + "golang.org/x/net/context" +) + +// Upstream is a simplified interface for proxy destination +type Upstream interface { + Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) +} + +// ProxyPlugin is a simplified DNS proxy using a generic upstream interface +type ProxyPlugin struct { + Upstreams []Upstream + Next plugin.Handler +} + +// ServeDNS implements interface for CoreDNS plugin +func (p ProxyPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + var reply *dns.Msg + var backendErr error + + for _, upstream := range p.Upstreams { + reply, backendErr = upstream.Exchange(ctx, r) + if backendErr == nil { + w.WriteMsg(reply) + return 0, nil + } + } + + return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams") +} + +// Name implements interface for CoreDNS plugin +func (p ProxyPlugin) Name() string { return "proxy" } diff --git a/tunneldns/https_upstream.go b/tunneldns/https_upstream.go new file mode 100644 index 00000000..eee448c4 --- /dev/null +++ b/tunneldns/https_upstream.go @@ -0,0 +1,97 @@ +package tunneldns + +import ( + "bytes" + "crypto/tls" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "time" + + "github.com/miekg/dns" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "golang.org/x/net/context" +) + +const ( + defaultTimeout = 5 * time.Second +) + +// UpstreamHTTPS is the upstream implementation for DNS over HTTPS service +type UpstreamHTTPS struct { + client *http.Client + endpoint *url.URL +} + +// NewUpstreamHTTPS creates a new DNS over HTTPS upstream from hostname +func NewUpstreamHTTPS(endpoint string) (Upstream, error) { + u, err := url.Parse(endpoint) + if err != nil { + return nil, err + } + + // Update TLS and HTTP client configuration + tls := &tls.Config{ServerName: u.Hostname()} + client := &http.Client{ + Timeout: time.Second * defaultTimeout, + Transport: &http.Transport{TLSClientConfig: tls}, + } + + return &UpstreamHTTPS{client: client, endpoint: u}, nil +} + +// Exchange provides an implementation for the Upstream interface +func (u *UpstreamHTTPS) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { + queryBuf, err := query.Pack() + if err != nil { + return nil, errors.Wrap(err, "failed to pack DNS query") + } + + // No content negotiation for now, use DNS wire format + buf, backendErr := u.exchangeWireformat(queryBuf) + if backendErr == nil { + response := &dns.Msg{} + if err := response.Unpack(buf); err != nil { + return nil, errors.Wrap(err, "failed to unpack DNS response from body") + } + + response.Id = query.Id + return response, nil + } + + log.WithError(backendErr).Errorf("failed to connect to an HTTPS backend %q", u.endpoint) + return nil, backendErr +} + +// Perform message exchange with the default UDP wireformat defined in current draft +// https://datatracker.ietf.org/doc/draft-ietf-doh-dns-over-https +func (u *UpstreamHTTPS) exchangeWireformat(msg []byte) ([]byte, error) { + req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg)) + if err != nil { + return nil, errors.Wrap(err, "failed to create an HTTPS request") + } + + req.Header.Add("Content-Type", "application/dns-udpwireformat") + req.Host = u.endpoint.Hostname() + + resp, err := u.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to perform an HTTPS request") + } + + // Check response status code + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("returned status code %d", resp.StatusCode) + } + + // Read wireformat response from the body + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read the response body") + } + + return buf, nil +} diff --git a/tunneldns/metrics.go b/tunneldns/metrics.go new file mode 100644 index 00000000..5f688186 --- /dev/null +++ b/tunneldns/metrics.go @@ -0,0 +1,45 @@ +package tunneldns + +import ( + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/metrics/vars" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/rcode" + "github.com/coredns/coredns/request" + "github.com/miekg/dns" + "github.com/prometheus/client_golang/prometheus" + "golang.org/x/net/context" +) + +// MetricsPlugin is an adapter for CoreDNS and built-in metrics +type MetricsPlugin struct { + Next plugin.Handler +} + +// NewMetricsPlugin creates a plugin with configured metrics +func NewMetricsPlugin(next plugin.Handler) *MetricsPlugin { + prometheus.MustRegister(vars.RequestCount) + prometheus.MustRegister(vars.RequestDuration) + prometheus.MustRegister(vars.RequestSize) + prometheus.MustRegister(vars.RequestDo) + prometheus.MustRegister(vars.RequestType) + prometheus.MustRegister(vars.ResponseSize) + prometheus.MustRegister(vars.ResponseRcode) + return &MetricsPlugin{Next: next} +} + +// ServeDNS implements the CoreDNS plugin interface +func (p MetricsPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + + rw := dnstest.NewRecorder(w) + status, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, rw, r) + + // Update built-in metrics + vars.Report(state, ".", rcode.ToString(rw.Rcode), rw.Len, rw.Start) + + return status, err +} + +// Name implements the CoreDNS plugin interface +func (p MetricsPlugin) Name() string { return "metrics" } diff --git a/tunneldns/tunnel.go b/tunneldns/tunnel.go new file mode 100644 index 00000000..f4b5eed7 --- /dev/null +++ b/tunneldns/tunnel.go @@ -0,0 +1,144 @@ +package tunneldns + +import ( + "fmt" + "net" + "os" + "os/signal" + "strconv" + "sync" + "syscall" + + "gopkg.in/urfave/cli.v2" + + "github.com/cloudflare/cloudflared/metrics" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/cache" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +// Listener is an adapter between CoreDNS server and Warp runnable +type Listener struct { + server *dnsserver.Server + wg sync.WaitGroup +} + +// Run implements a foreground runner +func Run(c *cli.Context) error { + metricsListener, err := net.Listen("tcp", c.String("metrics")) + if err != nil { + log.WithError(err).Fatal("Failed to open the metrics listener") + } + + go metrics.ServeMetrics(metricsListener, nil) + + listener, err := CreateListener(c.String("address"), uint16(c.Uint("port")), c.StringSlice("upstream")) + if err != nil { + log.WithError(err).Errorf("Failed to create the listeners") + return err + } + + // Try to start the server + err = listener.Start() + if err != nil { + log.WithError(err).Errorf("Failed to start the listeners") + return listener.Stop() + } + + // Wait for signal + signals := make(chan os.Signal, 10) + signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) + defer signal.Stop(signals) + <-signals + + // Shut down server + err = listener.Stop() + if err != nil { + log.WithError(err).Errorf("failed to stop") + } + return err +} + +// Create a CoreDNS server plugin from configuration +func createConfig(address string, port uint16, p plugin.Handler) *dnsserver.Config { + c := &dnsserver.Config{ + Zone: ".", + Transport: "dns", + ListenHosts: []string{address}, + Port: strconv.FormatUint(uint64(port), 10), + } + + c.AddPlugin(func(next plugin.Handler) plugin.Handler { return p }) + return c +} + +// Start blocks for serving requests +func (l *Listener) Start() error { + log.WithField("addr", l.server.Address()).Infof("Starting DNS over HTTPS proxy server") + + // Start UDP listener + if udp, err := l.server.ListenPacket(); err == nil { + l.wg.Add(1) + go func() { + l.server.ServePacket(udp) + l.wg.Done() + }() + } else { + return errors.Wrap(err, "failed to create a UDP listener") + } + + // Start TCP listener + tcp, err := l.server.Listen() + if err == nil { + l.wg.Add(1) + go func() { + l.server.Serve(tcp) + l.wg.Done() + }() + } + + return errors.Wrap(err, "failed to create a TCP listener") +} + +// Stop signals server shutdown and blocks until completed +func (l *Listener) Stop() error { + if err := l.server.Stop(); err != nil { + return err + } + + l.wg.Wait() + return nil +} + +// CreateListener configures the server and bound sockets +func CreateListener(address string, port uint16, upstreams []string) (*Listener, error) { + // Build the list of upstreams + upstreamList := make([]Upstream, 0) + for _, url := range upstreams { + log.WithField("url", url).Infof("Adding DNS upstream") + upstream, err := NewUpstreamHTTPS(url) + if err != nil { + return nil, errors.Wrap(err, "failed to create HTTPS upstream") + } + upstreamList = append(upstreamList, upstream) + } + + // Create a local cache with HTTPS proxy plugin + chain := cache.New() + chain.Next = ProxyPlugin{ + Upstreams: upstreamList, + } + + // Format an endpoint + endpoint := fmt.Sprintf("dns://%s:%d", address, port) + + // Create the actual middleware server + server, err := dnsserver.NewServer(endpoint, []*dnsserver.Config{createConfig(address, port, NewMetricsPlugin(chain))}) + if err != nil { + return nil, err + } + + return &Listener{server: server}, nil +} diff --git a/tunnelrpc/pogs/tunnelrpc.go b/tunnelrpc/pogs/tunnelrpc.go index 4f5280e4..f70545b3 100644 --- a/tunnelrpc/pogs/tunnelrpc.go +++ b/tunnelrpc/pogs/tunnelrpc.go @@ -1,7 +1,7 @@ package pogs import ( - "github.com/cloudflare/cloudflare-warp/tunnelrpc" + "github.com/cloudflare/cloudflared/tunnelrpc" "golang.org/x/net/context" "zombiezen.com/go/capnproto2" "zombiezen.com/go/capnproto2/pogs" diff --git a/tunnelrpc/tunnelrpc.capnp b/tunnelrpc/tunnelrpc.capnp index 636f317a..3886f534 100644 --- a/tunnelrpc/tunnelrpc.capnp +++ b/tunnelrpc/tunnelrpc.capnp @@ -1,7 +1,7 @@ using Go = import "go.capnp"; @0xdb8274f9144abc7e; $Go.package("tunnelrpc"); -$Go.import("github.com/cloudflare/cloudflare-warp/tunnelrpc"); +$Go.import("github.com/cloudflare/cloudflared/tunnelrpc"); struct Authentication { key @0 :Text; @@ -35,7 +35,7 @@ struct RegistrationOptions { connectionId @6 :UInt8; # origin LAN IP originLocalIp @7 :Text; - # whether Warp client has been autoupdated + # whether Argo Tunnel client has been autoupdated isAutoupdated @8 :Bool; } diff --git a/validation/validation.go b/validation/validation.go index eb4eb90b..746bf9c0 100644 --- a/validation/validation.go +++ b/validation/validation.go @@ -119,7 +119,7 @@ func validateScheme(scheme string) error { return nil } } - return fmt.Errorf("Currently Cloudflare-Warp does not support %s protocol.", scheme) + return fmt.Errorf("Currently Argo Tunnel does not support %s protocol.", scheme) } func validateIP(scheme, host, port string) (string, error) { diff --git a/validation/validation_test.go b/validation/validation_test.go index 0be8066b..49a2cf10 100644 --- a/validation/validation_test.go +++ b/validation/validation_test.go @@ -126,7 +126,7 @@ func TestValidateUrl(t *testing.T) { assert.Equal(t, "https://hello.example.com", validUrl) validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt") - assert.Equal(t, "Currently Cloudflare-Warp does not support ftp protocol.", err.Error()) + assert.Equal(t, "Currently Argo Tunnel does not support ftp protocol.", err.Error()) assert.Empty(t, validUrl) validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")