diff --git a/cmd/cloudflared/generic_service.go b/cmd/cloudflared/generic_service.go
new file mode 100644
index 00000000..3eb4819e
--- /dev/null
+++ b/cmd/cloudflared/generic_service.go
@@ -0,0 +1,13 @@
+// +build !windows,!darwin,!linux
+
+package main
+
+import (
+ "os"
+
+ cli "gopkg.in/urfave/cli.v2"
+)
+
+func runApp(app *cli.App) {
+ app.Run(os.Args)
+}
diff --git a/cmd/cloudflared/hello.go b/cmd/cloudflared/hello.go
new file mode 100644
index 00000000..9f73ade8
--- /dev/null
+++ b/cmd/cloudflared/hello.go
@@ -0,0 +1,204 @@
+package main
+
+import (
+ "bytes"
+ "crypto/tls"
+ "encoding/json"
+ "fmt"
+ "html/template"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "os"
+ "time"
+
+ "github.com/gorilla/websocket"
+ "gopkg.in/urfave/cli.v2"
+
+ "github.com/cloudflare/cloudflared/tlsconfig"
+)
+
+type templateData struct {
+ ServerName string
+ Request *http.Request
+ Body string
+}
+
+type OriginUpTime struct {
+ StartTime time.Time `json:"startTime"`
+ UpTime string `json:"uptime"`
+}
+
+const defaultServerName = "the Argo Tunnel test server"
+const indexTemplate = `
+
+
+
+
+
+
+ Argo Tunnel Connection
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Congrats! You created your first tunnel!
+
+ Argo Tunnel exposes locally running applications to the internet by
+ running an encrypted, virtual tunnel from your laptop or server to
+ Cloudflare's edge network.
+
+
Ready for the next step?
+
+ Get started here
+
+
+ Request
+
+ Method: {{.Request.Method}}
+ Protocol: {{.Request.Proto}}
+ Request URL: {{.Request.URL}}
+ Transfer encoding: {{.Request.TransferEncoding}}
+ Host: {{.Request.Host}}
+ Remote address: {{.Request.RemoteAddr}}
+ Request URI: {{.Request.RequestURI}}
+{{range $key, $value := .Request.Header}}
+ Header: {{$key}}, Value: {{$value}}
+{{end}}
+ Body: {{.Body}}
+
+
+
+
+
+
+`
+
+func hello(c *cli.Context) error {
+ address := fmt.Sprintf(":%d", c.Int("port"))
+ listener, err := createListener(address)
+ if err != nil {
+ return err
+ }
+ defer listener.Close()
+ err = startHelloWorldServer(listener, nil)
+ return err
+}
+
+func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
+ Log.Infof("Starting Hello World server at %s", listener.Addr())
+ serverName := defaultServerName
+ if hostname, err := os.Hostname(); err == nil {
+ serverName = hostname
+ }
+
+ upgrader := websocket.Upgrader{
+ ReadBufferSize: 1024,
+ WriteBufferSize: 1024,
+ }
+
+ httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
+ go func() {
+ <-shutdownC
+ httpServer.Close()
+ }()
+
+ http.HandleFunc("/uptime", uptimeHandler(time.Now()))
+ http.HandleFunc("/ws", websocketHandler(upgrader))
+ http.HandleFunc("/", rootHandler(serverName))
+ err := httpServer.Serve(listener)
+ return err
+}
+
+func createListener(address string) (net.Listener, error) {
+ certificate, err := tlsconfig.GetHelloCertificate()
+ if err != nil {
+ return nil, err
+ }
+
+ // If the port in address is empty, a port number is automatically chosen
+ listener, err := tls.Listen(
+ "tcp",
+ address,
+ &tls.Config{Certificates: []tls.Certificate{certificate}})
+
+ return listener, err
+}
+
+func uptimeHandler(startTime time.Time) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ // Note that if autoupdate is enabled, the uptime is reset when a new client
+ // release is available
+ resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()}
+ respJson, err := json.Marshal(resp)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ } else {
+ w.Header().Set("Content-Type", "application/json")
+ w.Write(respJson)
+ }
+ }
+}
+
+func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ return
+ }
+ defer conn.Close()
+
+ for {
+ mt, message, err := conn.ReadMessage()
+ if err != nil {
+ break
+ }
+
+ if err := conn.WriteMessage(mt, message); err != nil {
+ break
+ }
+ }
+ }
+}
+
+func rootHandler(serverName string) http.HandlerFunc {
+ responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
+ return func(w http.ResponseWriter, r *http.Request) {
+ var buffer bytes.Buffer
+ var body string
+ rawBody, err := ioutil.ReadAll(r.Body)
+ if err == nil {
+ body = string(rawBody)
+ } else {
+ body = ""
+ }
+ err = responseTemplate.Execute(&buffer, &templateData{
+ ServerName: serverName,
+ Request: r,
+ Body: body,
+ })
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ fmt.Fprintf(w, "error: %v", err)
+ } else {
+ buffer.WriteTo(w)
+ }
+ }
+}
diff --git a/cmd/cloudflared/hello_test.go b/cmd/cloudflared/hello_test.go
new file mode 100644
index 00000000..f6e8842a
--- /dev/null
+++ b/cmd/cloudflared/hello_test.go
@@ -0,0 +1,35 @@
+package main
+
+import (
+ "testing"
+)
+
+func TestCreateListenerHostAndPortSuccess(t *testing.T) {
+ listener, err := createListener("localhost:1234")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if listener.Addr().String() == "" {
+ t.Fatal("Fail to find available port")
+ }
+}
+
+func TestCreateListenerOnlyHostSuccess(t *testing.T) {
+ listener, err := createListener("localhost:")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if listener.Addr().String() == "" {
+ t.Fatal("Fail to find available port")
+ }
+}
+
+func TestCreateListenerOnlyPortSuccess(t *testing.T) {
+ listener, err := createListener(":8888")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if listener.Addr().String() == "" {
+ t.Fatal("Fail to find available port")
+ }
+}
diff --git a/cmd/cloudflared/linux_service.go b/cmd/cloudflared/linux_service.go
new file mode 100644
index 00000000..d0b15937
--- /dev/null
+++ b/cmd/cloudflared/linux_service.go
@@ -0,0 +1,292 @@
+// +build linux
+
+package main
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+
+ cli "gopkg.in/urfave/cli.v2"
+)
+
+func runApp(app *cli.App) {
+ app.Commands = append(app.Commands, &cli.Command{
+ Name: "service",
+ Usage: "Manages the Argo Tunnel system service",
+ Subcommands: []*cli.Command{
+ &cli.Command{
+ Name: "install",
+ Usage: "Install Argo Tunnel as a system service",
+ Action: installLinuxService,
+ },
+ &cli.Command{
+ Name: "uninstall",
+ Usage: "Uninstall the Argo Tunnel service",
+ Action: uninstallLinuxService,
+ },
+ },
+ })
+ app.Run(os.Args)
+}
+
+const serviceConfigDir = "/etc/cloudflared"
+
+var systemdTemplates = []ServiceTemplate{
+ {
+ Path: "/etc/systemd/system/cloudflared.service",
+ Content: `[Unit]
+Description=Argo Tunnel
+After=network.target
+
+[Service]
+TimeoutStartSec=0
+Type=notify
+ExecStart={{ .Path }} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --no-autoupdate
+Restart=on-failure
+RestartSec=5s
+
+[Install]
+WantedBy=multi-user.target
+`,
+ },
+ {
+ Path: "/etc/systemd/system/cloudflared-update.service",
+ Content: `[Unit]
+Description=Update Argo Tunnel
+After=network.target
+
+[Service]
+ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 64 ]; then systemctl restart cloudflared; exit 0; fi; exit $code'
+`,
+ },
+ {
+ Path: "/etc/systemd/system/cloudflared-update.timer",
+ Content: `[Unit]
+Description=Update Argo Tunnel
+
+[Timer]
+OnUnitActiveSec=1d
+
+[Install]
+WantedBy=timers.target
+`,
+ },
+}
+
+var sysvTemplate = ServiceTemplate{
+ Path: "/etc/init.d/cloudflared",
+ FileMode: 0755,
+ Content: `# For RedHat and cousins:
+# chkconfig: 2345 99 01
+# description: Argo Tunnel agent
+# processname: {{.Path}}
+### BEGIN INIT INFO
+# Provides: {{.Path}}
+# Required-Start:
+# Required-Stop:
+# Default-Start: 2 3 4 5
+# Default-Stop: 0 1 6
+# Short-Description: Argo Tunnel
+# Description: Argo Tunnel agent
+### END INIT INFO
+cmd="{{.Path}} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --pidfile /var/run/$name.pid --autoupdate-freq 24h0m0s"
+name=$(basename $(readlink -f $0))
+pid_file="/var/run/$name.pid"
+stdout_log="/var/log/$name.log"
+stderr_log="/var/log/$name.err"
+[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
+get_pid() {
+ cat "$pid_file"
+}
+is_running() {
+ [ -f "$pid_file" ] && ps $(get_pid) > /dev/null 2>&1
+}
+case "$1" in
+ start)
+ if is_running; then
+ echo "Already started"
+ else
+ echo "Starting $name"
+ $cmd >> "$stdout_log" 2>> "$stderr_log" &
+ echo $! > "$pid_file"
+ if ! is_running; then
+ echo "Unable to start, see $stdout_log and $stderr_log"
+ exit 1
+ fi
+ fi
+ ;;
+ stop)
+ if is_running; then
+ echo -n "Stopping $name.."
+ kill $(get_pid)
+ for i in {1..10}
+ do
+ if ! is_running; then
+ break
+ fi
+ echo -n "."
+ sleep 1
+ done
+ echo
+ if is_running; then
+ echo "Not stopped; may still be shutting down or shutdown may have failed"
+ exit 1
+ else
+ echo "Stopped"
+ if [ -f "$pid_file" ]; then
+ rm "$pid_file"
+ fi
+ fi
+ else
+ echo "Not running"
+ fi
+ ;;
+ restart)
+ $0 stop
+ if is_running; then
+ echo "Unable to stop, will not attempt to start"
+ exit 1
+ fi
+ $0 start
+ ;;
+ status)
+ if is_running; then
+ echo "Running"
+ else
+ echo "Stopped"
+ exit 1
+ fi
+ ;;
+ *)
+ echo "Usage: $0 {start|stop|restart|status}"
+ exit 1
+ ;;
+esac
+exit 0
+`,
+}
+
+func isSystemd() bool {
+ if _, err := os.Stat("/run/systemd/system"); err == nil {
+ return true
+ }
+ return false
+}
+
+func installLinuxService(c *cli.Context) error {
+ etPath, err := os.Executable()
+ if err != nil {
+ return fmt.Errorf("error determining executable path: %v", err)
+ }
+ templateArgs := ServiceTemplateArgs{Path: etPath}
+
+ defaultConfigDir := filepath.Dir(c.String("config"))
+ defaultConfigFile := filepath.Base(c.String("config"))
+ if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile); err != nil {
+ Log.WithError(err).Infof("Failed to copy user configuration. Before running the service, ensure that %s contains two files, %s and %s",
+ serviceConfigDir, credentialFile, defaultConfigFiles[0])
+ return err
+ }
+
+ switch {
+ case isSystemd():
+ Log.Infof("Using Systemd")
+ return installSystemd(&templateArgs)
+ default:
+ Log.Infof("Using Sysv")
+ return installSysv(&templateArgs)
+ }
+}
+
+func installSystemd(templateArgs *ServiceTemplateArgs) error {
+ for _, serviceTemplate := range systemdTemplates {
+ err := serviceTemplate.Generate(templateArgs)
+ if err != nil {
+ Log.WithError(err).Infof("error generating service template")
+ return err
+ }
+ }
+ if err := runCommand("systemctl", "enable", "cloudflared.service"); err != nil {
+ Log.WithError(err).Infof("systemctl enable cloudflared.service error")
+ return err
+ }
+ if err := runCommand("systemctl", "start", "cloudflared-update.timer"); err != nil {
+ Log.WithError(err).Infof("systemctl start cloudflared-update.timer error")
+ return err
+ }
+ Log.Infof("systemctl daemon-reload")
+ return runCommand("systemctl", "daemon-reload")
+}
+
+func installSysv(templateArgs *ServiceTemplateArgs) error {
+ confPath, err := sysvTemplate.ResolvePath()
+ if err != nil {
+ Log.WithError(err).Infof("error resolving system path")
+ return err
+ }
+ if err := sysvTemplate.Generate(templateArgs); err != nil {
+ Log.WithError(err).Infof("error generating system template")
+ return err
+ }
+ for _, i := range [...]string{"2", "3", "4", "5"} {
+ if err := os.Symlink(confPath, "/etc/rc"+i+".d/S50et"); err != nil {
+ continue
+ }
+ }
+ for _, i := range [...]string{"0", "1", "6"} {
+ if err := os.Symlink(confPath, "/etc/rc"+i+".d/K02et"); err != nil {
+ continue
+ }
+ }
+ return nil
+}
+
+func uninstallLinuxService(c *cli.Context) error {
+ switch {
+ case isSystemd():
+ Log.Infof("Using Systemd")
+ return uninstallSystemd()
+ default:
+ Log.Infof("Using Sysv")
+ return uninstallSysv()
+ }
+}
+
+func uninstallSystemd() error {
+ if err := runCommand("systemctl", "disable", "cloudflared.service"); err != nil {
+ Log.WithError(err).Infof("systemctl disable cloudflared.service error")
+ return err
+ }
+ if err := runCommand("systemctl", "stop", "cloudflared-update.timer"); err != nil {
+ Log.WithError(err).Infof("systemctl stop cloudflared-update.timer error")
+ return err
+ }
+ for _, serviceTemplate := range systemdTemplates {
+ if err := serviceTemplate.Remove(); err != nil {
+ Log.WithError(err).Infof("error removing service template")
+ return err
+ }
+ }
+ Log.Infof("Successfully uninstall cloudflared service")
+ return nil
+}
+
+func uninstallSysv() error {
+ if err := sysvTemplate.Remove(); err != nil {
+ Log.WithError(err).Infof("error removing service template")
+ return err
+ }
+ for _, i := range [...]string{"2", "3", "4", "5"} {
+ if err := os.Remove("/etc/rc" + i + ".d/S50et"); err != nil {
+ continue
+ }
+ }
+ for _, i := range [...]string{"0", "1", "6"} {
+ if err := os.Remove("/etc/rc" + i + ".d/K02et"); err != nil {
+ continue
+ }
+ }
+ Log.Infof("Successfully uninstall cloudflared service")
+ return nil
+}
diff --git a/cmd/cloudflared/login.go b/cmd/cloudflared/login.go
new file mode 100644
index 00000000..fa3c675e
--- /dev/null
+++ b/cmd/cloudflared/login.go
@@ -0,0 +1,194 @@
+package main
+
+import (
+ "crypto/rand"
+ "encoding/base32"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "syscall"
+ "time"
+
+ homedir "github.com/mitchellh/go-homedir"
+ cli "gopkg.in/urfave/cli.v2"
+)
+
+const baseLoginURL = "https://www.cloudflare.com/a/warp"
+const baseCertStoreURL = "https://login.cloudflarewarp.com"
+const clientTimeout = time.Minute * 20
+
+func login(c *cli.Context) error {
+ configPath, err := homedir.Expand(defaultConfigDirs[0])
+ if err != nil {
+ return err
+ }
+ ok, err := fileExists(configPath)
+ if !ok && err == nil {
+ // create config directory if doesn't already exist
+ err = os.Mkdir(configPath, 0700)
+ }
+ if err != nil {
+ return err
+ }
+ path := filepath.Join(configPath, credentialFile)
+ fileInfo, err := os.Stat(path)
+ if err == nil && fileInfo.Size() > 0 {
+ fmt.Fprintf(os.Stderr, `You have an existing certificate at %s which login would overwrite.
+If this is intentional, please move or delete that file then run this command again.
+`, path)
+ return nil
+ }
+ if err != nil && err.(*os.PathError).Err != syscall.ENOENT {
+ return err
+ }
+
+ // for local debugging
+ baseURL := baseCertStoreURL
+ if c.IsSet("url") {
+ baseURL = c.String("url")
+ }
+ // Generate a random post URL
+ certURL := baseURL + generateRandomPath()
+ loginURL, err := url.Parse(baseLoginURL)
+ if err != nil {
+ // shouldn't happen, URL is hardcoded
+ return err
+ }
+ loginURL.RawQuery = "callback=" + url.QueryEscape(certURL)
+
+ err = open(loginURL.String())
+ if err != nil {
+ fmt.Fprintf(os.Stderr, `Please open the following URL and log in with your Cloudflare account:
+
+%s
+
+Leave cloudflared running to install the certificate automatically.
+`, loginURL.String())
+ } else {
+ fmt.Fprintf(os.Stderr, `A browser window should have opened at the following URL:
+
+%s
+
+If the browser failed to open, open it yourself and visit the URL above.
+
+`, loginURL.String())
+ }
+
+ if download(certURL, path) {
+ fmt.Fprintf(os.Stderr, `You have successfully logged in.
+If you wish to copy your credentials to a server, they have been saved to:
+%s
+`, path)
+ } else {
+ fmt.Fprintf(os.Stderr, `Failed to write the certificate due to the following error:
+%v
+
+Your browser will download the certificate instead. You will have to manually
+copy it to the following path:
+
+%s
+
+`, err, path)
+ }
+ return nil
+}
+
+// generateRandomPath generates a random URL to associate with the certificate.
+func generateRandomPath() string {
+ randomBytes := make([]byte, 40)
+ _, err := rand.Read(randomBytes)
+ if err != nil {
+ panic(err)
+ }
+ return "/" + base32.StdEncoding.EncodeToString(randomBytes)
+}
+
+// open opens the specified URL in the default browser of the user.
+func open(url string) error {
+ var cmd string
+ var args []string
+
+ switch runtime.GOOS {
+ case "windows":
+ cmd = "cmd"
+ args = []string{"/c", "start"}
+ case "darwin":
+ cmd = "open"
+ default: // "linux", "freebsd", "openbsd", "netbsd"
+ cmd = "xdg-open"
+ }
+ args = append(args, url)
+ return exec.Command(cmd, args...).Start()
+}
+
+func download(certURL, filePath string) bool {
+ client := &http.Client{Timeout: clientTimeout}
+ // attempt a (long-running) certificate get
+ for i := 0; i < 20; i++ {
+ ok, err := tryDownload(client, certURL, filePath)
+ if ok {
+ putSuccess(client, certURL)
+ return true
+ }
+ if err != nil {
+ Log.WithError(err).Error("Error fetching certificate")
+ return false
+ }
+ }
+ return false
+}
+
+func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err error) {
+ resp, err := client.Get(certURL)
+ if err != nil {
+ return false, err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode == 404 {
+ return false, nil
+ }
+ if resp.StatusCode != 200 {
+ return false, fmt.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
+ }
+ if resp.Header.Get("Content-Type") != "application/x-pem-file" {
+ return false, fmt.Errorf("Unexpected content type %s", resp.Header.Get("Content-Type"))
+ }
+ // write response
+ file, err := os.Create(filePath)
+ if err != nil {
+ return false, err
+ }
+ defer file.Close()
+ written, err := io.Copy(file, resp.Body)
+ switch {
+ case err != nil:
+ return false, err
+ case resp.ContentLength != written && resp.ContentLength != -1:
+ return false, fmt.Errorf("Short read (%d bytes) from server while writing certificate", written)
+ default:
+ return true, nil
+ }
+}
+
+func putSuccess(client *http.Client, certURL string) {
+ // indicate success to the relay server
+ req, err := http.NewRequest("PUT", certURL+"/ok", nil)
+ if err != nil {
+ Log.WithError(err).Error("HTTP request error")
+ return
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ Log.WithError(err).Error("HTTP error")
+ return
+ }
+ resp.Body.Close()
+ if resp.StatusCode != 200 {
+ Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
+ }
+}
diff --git a/cmd/cloudflared/macos_service.go b/cmd/cloudflared/macos_service.go
new file mode 100644
index 00000000..95be8454
--- /dev/null
+++ b/cmd/cloudflared/macos_service.go
@@ -0,0 +1,97 @@
+// +build darwin
+
+package main
+
+import (
+ "fmt"
+ "os"
+
+ "gopkg.in/urfave/cli.v2"
+)
+
+const launchAgentIdentifier = "com.cloudflare.cloudflared"
+
+func runApp(app *cli.App) {
+ app.Commands = append(app.Commands, &cli.Command{
+ Name: "service",
+ Usage: "Manages the Argo Tunnel launch agent",
+ Subcommands: []*cli.Command{
+ {
+ Name: "install",
+ Usage: "Install Argo Tunnel as an user launch agent",
+ Action: installLaunchd,
+ },
+ {
+ Name: "uninstall",
+ Usage: "Uninstall the Argo Tunnel launch agent",
+ Action: uninstallLaunchd,
+ },
+ },
+ })
+ app.Run(os.Args)
+}
+
+var launchdTemplate = ServiceTemplate{
+ Path: fmt.Sprintf("~/Library/LaunchAgents/%s.plist", launchAgentIdentifier),
+ Content: fmt.Sprintf(`
+
+
+
+ Label
+ %s
+ Program
+ {{ .Path }}
+ RunAtLoad
+
+ StandardOutPath
+ /tmp/%s.out.log
+ StandardErrorPath
+ /tmp/%s.err.log
+ KeepAlive
+
+ NetworkState
+
+
+ ThrottleInterval
+ 20
+
+ `, launchAgentIdentifier, launchAgentIdentifier, launchAgentIdentifier),
+}
+
+func installLaunchd(c *cli.Context) error {
+ Log.Infof("Installing Argo Tunnel as an user launch agent")
+ etPath, err := os.Executable()
+ if err != nil {
+ Log.WithError(err).Infof("error determining executable path")
+ return fmt.Errorf("error determining executable path: %v", err)
+ }
+ templateArgs := ServiceTemplateArgs{Path: etPath}
+ err = launchdTemplate.Generate(&templateArgs)
+ if err != nil {
+ Log.WithError(err).Infof("error generating launchd template")
+ return err
+ }
+ plistPath, err := launchdTemplate.ResolvePath()
+ if err != nil {
+ Log.WithError(err).Infof("error resolving launchd template path")
+ return err
+ }
+ Log.Infof("Outputs are logged in %s and %s", fmt.Sprintf("/tmp/%s.out.log", launchAgentIdentifier), fmt.Sprintf("/tmp/%s.err.log", launchAgentIdentifier))
+ return runCommand("launchctl", "load", plistPath)
+}
+
+func uninstallLaunchd(c *cli.Context) error {
+ Log.Infof("Uninstalling Argo Tunnel as an user launch agent")
+ plistPath, err := launchdTemplate.ResolvePath()
+ if err != nil {
+ Log.WithError(err).Infof("error resolving launchd template path")
+ return err
+ }
+ err = runCommand("launchctl", "unload", plistPath)
+ if err != nil {
+ Log.WithError(err).Infof("error unloading")
+ return err
+ }
+ Log.Infof("Outputs are logged in %s and %s", fmt.Sprintf("/tmp/%s.out.log", launchAgentIdentifier), fmt.Sprintf("/tmp/%s.err.log", launchAgentIdentifier))
+ return launchdTemplate.Remove()
+}
diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go
new file mode 100644
index 00000000..06bf641f
--- /dev/null
+++ b/cmd/cloudflared/main.go
@@ -0,0 +1,794 @@
+package main
+
+import (
+ "crypto/tls"
+ "encoding/hex"
+ "fmt"
+ "io/ioutil"
+ "math/rand"
+ "net"
+ "net/http"
+ "os"
+ "os/signal"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/cloudflare/cloudflared/metrics"
+ "github.com/cloudflare/cloudflared/origin"
+ "github.com/cloudflare/cloudflared/tlsconfig"
+ "github.com/cloudflare/cloudflared/tunneldns"
+ tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
+ "github.com/cloudflare/cloudflared/validation"
+
+ "github.com/facebookgo/grace/gracenet"
+ "github.com/getsentry/raven-go"
+ "github.com/mitchellh/go-homedir"
+ "github.com/rifflock/lfshook"
+ "github.com/sirupsen/logrus"
+ "golang.org/x/crypto/ssh/terminal"
+ "gopkg.in/urfave/cli.v2"
+ "gopkg.in/urfave/cli.v2/altsrc"
+
+ "github.com/coreos/go-systemd/daemon"
+ "github.com/pkg/errors"
+)
+
+const (
+ sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
+ credentialFile = "cert.pem"
+ quickStartUrl = "https://developers.cloudflare.com/argo-tunnel/quickstart/quickstart/"
+ noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/"
+ licenseUrl = "https://developers.cloudflare.com/argo-tunnel/licence/"
+)
+
+var listeners = gracenet.Net{}
+var Version = "DEV"
+var BuildTime = "unknown"
+var Log *logrus.Logger
+var defaultConfigFiles = []string{"config.yml", "config.yaml"}
+
+// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
+var defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp"}
+
+// Shutdown channel used by the app. When closed, app must terminate.
+// May be closed by the Windows service runner.
+var shutdownC chan struct{}
+
+type BuildAndRuntimeInfo struct {
+ GoOS string `json:"go_os"`
+ GoVersion string `json:"go_version"`
+ GoArch string `json:"go_arch"`
+ WarpVersion string `json:"warp_version"`
+ WarpFlags map[string]interface{} `json:"warp_flags"`
+ WarpEnvs map[string]string `json:"warp_envs"`
+}
+
+func main() {
+ metrics.RegisterBuildInfo(BuildTime, Version)
+ raven.SetDSN(sentryDSN)
+ raven.SetRelease(Version)
+ shutdownC = make(chan struct{})
+ app := &cli.App{}
+ app.Name = "cloudflared"
+ app.Copyright = fmt.Sprintf(`(c) %d Cloudflare Inc.
+ Use is subject to the license agreement at %s`, time.Now().Year(), licenseUrl)
+ app.Usage = "Cloudflare reverse tunnelling proxy agent"
+ app.ArgsUsage = "origin-url"
+ app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime)
+ app.Description = `A reverse tunnel proxy agent that connects to Cloudflare's infrastructure.
+ Upon connecting, you are assigned a unique subdomain on cftunnel.com.
+ You need to specify a hostname on a zone you control.
+ A DNS record will be created to CNAME your hostname to the unique subdomain on cftunnel.com.
+
+ Requests made to Cloudflare's servers for your hostname will be proxied
+ through the tunnel to your local webserver.`
+ app.Flags = []cli.Flag{
+ &cli.StringFlag{
+ Name: "config",
+ Usage: "Specifies a config file in YAML format.",
+ Value: findDefaultConfigPath(),
+ },
+ altsrc.NewDurationFlag(&cli.DurationFlag{
+ Name: "autoupdate-freq",
+ Usage: "Autoupdate frequency. Default is 24h.",
+ Value: time.Hour * 24,
+ }),
+ altsrc.NewBoolFlag(&cli.BoolFlag{
+ Name: "no-autoupdate",
+ Usage: "Disable periodic check for updates, restarting the server with the new version.",
+ Value: false,
+ }),
+ altsrc.NewBoolFlag(&cli.BoolFlag{
+ Name: "is-autoupdated",
+ Usage: "Signal the new process that Argo Tunnel client has been autoupdated",
+ Value: false,
+ Hidden: true,
+ }),
+ altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
+ Name: "edge",
+ Usage: "Address of the Cloudflare tunnel server.",
+ EnvVars: []string{"TUNNEL_EDGE"},
+ Hidden: true,
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "cacert",
+ Usage: "Certificate Authority authenticating the Cloudflare tunnel connection.",
+ EnvVars: []string{"TUNNEL_CACERT"},
+ Hidden: true,
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "origincert",
+ Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
+ EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
+ Value: findDefaultOriginCertPath(),
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "url",
+ Value: "https://localhost:8080",
+ Usage: "Connect to the local webserver at `URL`.",
+ EnvVars: []string{"TUNNEL_URL"},
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "hostname",
+ Usage: "Set a hostname on a Cloudflare zone to route traffic through this tunnel.",
+ EnvVars: []string{"TUNNEL_HOSTNAME"},
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "origin-server-name",
+ Usage: "Hostname on the origin server certificate.",
+ EnvVars: []string{"TUNNEL_ORIGIN_SERVER_NAME"},
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "id",
+ Usage: "A unique identifier used to tie connections to this tunnel instance.",
+ EnvVars: []string{"TUNNEL_ID"},
+ Hidden: true,
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "lb-pool",
+ Usage: "The name of a (new/existing) load balancing pool to add this origin to.",
+ EnvVars: []string{"TUNNEL_LB_POOL"},
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "api-key",
+ Usage: "This parameter has been deprecated since version 2017.10.1.",
+ EnvVars: []string{"TUNNEL_API_KEY"},
+ Hidden: true,
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "api-email",
+ Usage: "This parameter has been deprecated since version 2017.10.1.",
+ EnvVars: []string{"TUNNEL_API_EMAIL"},
+ Hidden: true,
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "api-ca-key",
+ Usage: "This parameter has been deprecated since version 2017.10.1.",
+ EnvVars: []string{"TUNNEL_API_CA_KEY"},
+ Hidden: true,
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "metrics",
+ Value: "localhost:",
+ Usage: "Listen address for metrics reporting.",
+ EnvVars: []string{"TUNNEL_METRICS"},
+ }),
+ altsrc.NewDurationFlag(&cli.DurationFlag{
+ Name: "metrics-update-freq",
+ Usage: "Frequency to update tunnel metrics",
+ Value: time.Second * 5,
+ EnvVars: []string{"TUNNEL_METRICS_UPDATE_FREQ"},
+ }),
+ altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
+ Name: "tag",
+ Usage: "Custom tags used to identify this tunnel, in format `KEY=VALUE`. Multiple tags may be specified",
+ EnvVars: []string{"TUNNEL_TAG"},
+ }),
+ altsrc.NewDurationFlag(&cli.DurationFlag{
+ Name: "heartbeat-interval",
+ Usage: "Minimum idle time before sending a heartbeat.",
+ Value: time.Second * 5,
+ Hidden: true,
+ }),
+ altsrc.NewUint64Flag(&cli.Uint64Flag{
+ Name: "heartbeat-count",
+ Usage: "Minimum number of unacked heartbeats to send before closing the connection.",
+ Value: 5,
+ Hidden: true,
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "loglevel",
+ Value: "info",
+ Usage: "Application logging level {panic, fatal, error, warn, info, debug}",
+ EnvVars: []string{"TUNNEL_LOGLEVEL"},
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "proto-loglevel",
+ Value: "warn",
+ Usage: "Protocol logging level {panic, fatal, error, warn, info, debug}",
+ EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL"},
+ }),
+ altsrc.NewUintFlag(&cli.UintFlag{
+ Name: "retries",
+ Value: 5,
+ Usage: "Maximum number of retries for connection/protocol errors.",
+ EnvVars: []string{"TUNNEL_RETRIES"},
+ }),
+ altsrc.NewBoolFlag(&cli.BoolFlag{
+ Name: "hello-world",
+ Value: false,
+ Usage: "Run Hello World Server",
+ EnvVars: []string{"TUNNEL_HELLO_WORLD"},
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "pidfile",
+ Usage: "Write the application's PID to this file after first successful connection.",
+ EnvVars: []string{"TUNNEL_PIDFILE"},
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "logfile",
+ Usage: "Save application log to this file for reporting issues.",
+ EnvVars: []string{"TUNNEL_LOGFILE"},
+ }),
+ altsrc.NewIntFlag(&cli.IntFlag{
+ Name: "ha-connections",
+ Value: 4,
+ Hidden: true,
+ }),
+ altsrc.NewDurationFlag(&cli.DurationFlag{
+ Name: "proxy-connect-timeout",
+ Usage: "HTTP proxy timeout for establishing a new connection",
+ Value: time.Second * 30,
+ }),
+ altsrc.NewDurationFlag(&cli.DurationFlag{
+ Name: "proxy-tls-timeout",
+ Usage: "HTTP proxy timeout for completing a TLS handshake",
+ Value: time.Second * 10,
+ }),
+ altsrc.NewDurationFlag(&cli.DurationFlag{
+ Name: "proxy-tcp-keepalive",
+ Usage: "HTTP proxy TCP keepalive duration",
+ Value: time.Second * 30,
+ }),
+ altsrc.NewBoolFlag(&cli.BoolFlag{
+ Name: "proxy-no-happy-eyeballs",
+ Usage: "HTTP proxy should disable \"happy eyeballs\" for IPv4/v6 fallback",
+ }),
+ altsrc.NewIntFlag(&cli.IntFlag{
+ Name: "proxy-keepalive-connections",
+ Usage: "HTTP proxy maximum keepalive connection pool size",
+ Value: 100,
+ }),
+ altsrc.NewDurationFlag(&cli.DurationFlag{
+ Name: "proxy-keepalive-timeout",
+ Usage: "HTTP proxy timeout for closing an idle connection",
+ Value: time.Second * 90,
+ }),
+ altsrc.NewBoolFlag(&cli.BoolFlag{
+ Name: "proxy-dns",
+ Usage: "Run a DNS over HTTPS proxy server.",
+ EnvVars: []string{"TUNNEL_DNS"},
+ }),
+ altsrc.NewUintFlag(&cli.UintFlag{
+ Name: "proxy-dns-port",
+ Value: 53,
+ Usage: "Listen on given port for the DNS over HTTPS proxy server.",
+ EnvVars: []string{"TUNNEL_DNS_PORT"},
+ }),
+ altsrc.NewStringFlag(&cli.StringFlag{
+ Name: "proxy-dns-address",
+ Usage: "Listen address for the DNS over HTTPS proxy server.",
+ Value: "localhost",
+ EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
+ }),
+ altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
+ Name: "proxy-dns-upstream",
+ Usage: "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
+ Value: cli.NewStringSlice("https://dns.cloudflare.com/.well-known/dns-query"),
+ EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
+ }),
+ }
+ app.Action = func(c *cli.Context) error {
+ raven.CapturePanic(func() { startServer(c) }, nil)
+ return nil
+ }
+ app.Before = func(context *cli.Context) error {
+ Log = logrus.New()
+ inputSource, err := findInputSourceContext(context)
+ if err != nil {
+ Log.WithError(err).Infof("Cannot load configuration from %s", context.String("config"))
+ return err
+ } else if inputSource != nil {
+ err := altsrc.ApplyInputSourceValues(context, inputSource, app.Flags)
+ if err != nil {
+ Log.WithError(err).Infof("Cannot apply configuration from %s", context.String("config"))
+ return err
+ }
+ Log.Infof("Applied configuration from %s", context.String("config"))
+ }
+ return nil
+ }
+ app.Commands = []*cli.Command{
+ {
+ Name: "update",
+ Action: update,
+ Usage: "Update the agent if a new version exists",
+ ArgsUsage: " ",
+ Description: `Looks for a new version on the offical download server.
+ If a new version exists, updates the agent binary and quits.
+ Otherwise, does nothing.
+
+ To determine if an update happened in a script, check for error code 64.`,
+ },
+ {
+ Name: "login",
+ Action: login,
+ Usage: "Generate a configuration file with your login details",
+ ArgsUsage: " ",
+ Flags: []cli.Flag{
+ &cli.StringFlag{
+ Name: "url",
+ Hidden: true,
+ },
+ },
+ },
+ {
+ Name: "hello",
+ Action: hello,
+ Usage: "Run a simple \"Hello World\" server for testing Argo Tunnel.",
+ Flags: []cli.Flag{
+ &cli.IntFlag{
+ Name: "port",
+ Usage: "Listen on the selected port.",
+ Value: 8080,
+ },
+ },
+ ArgsUsage: " ", // can't be the empty string or we get the default output
+ },
+ {
+ Name: "proxy-dns",
+ Action: tunneldns.Run,
+ Usage: "Run a DNS over HTTPS proxy server.",
+ Flags: []cli.Flag{
+ &cli.StringFlag{
+ Name: "metrics",
+ Value: "localhost:",
+ Usage: "Listen address for metrics reporting.",
+ EnvVars: []string{"TUNNEL_METRICS"},
+ },
+ &cli.StringFlag{
+ Name: "address",
+ Usage: "Listen address for the DNS over HTTPS proxy server.",
+ Value: "localhost",
+ EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
+ },
+ &cli.IntFlag{
+ Name: "port",
+ Usage: "Listen on given port for the DNS over HTTPS proxy server.",
+ Value: 53,
+ EnvVars: []string{"TUNNEL_DNS_PORT"},
+ },
+ &cli.StringSliceFlag{
+ Name: "upstream",
+ Usage: "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
+ Value: cli.NewStringSlice("https://dns.cloudflare.com/.well-known/dns-query"),
+ EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
+ },
+ },
+ ArgsUsage: " ", // can't be the empty string or we get the default output
+ },
+ }
+ runApp(app)
+}
+
+func startServer(c *cli.Context) {
+ var wg sync.WaitGroup
+ errC := make(chan error)
+ wg.Add(2)
+
+ // If the user choose to supply all options through env variables,
+ // c.NumFlags() == 0 && c.NArg() == 0. For cloudflared to work, the user needs to at
+ // least provide a hostname.
+ if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" {
+ Log.Infof("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl)
+ cli.ShowAppHelp(c)
+ return
+ }
+ logLevel, err := logrus.ParseLevel(c.String("loglevel"))
+ if err != nil {
+ Log.WithError(err).Fatal("Unknown logging level specified")
+ }
+ Log.SetLevel(logLevel)
+
+ protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
+ if err != nil {
+ Log.WithError(err).Fatal("Unknown protocol logging level specified")
+ }
+ protoLogger := logrus.New()
+ protoLogger.Level = protoLogLevel
+
+ if c.String("logfile") != "" {
+ if err := initLogFile(c, protoLogger); err != nil {
+ Log.Error(err)
+ }
+ }
+
+ if isAutoupdateEnabled(c) {
+ if initUpdate() {
+ return
+ }
+ Log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
+ go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
+ }
+
+ hostname, err := validation.ValidateHostname(c.String("hostname"))
+ if err != nil {
+ Log.WithError(err).Fatal("Invalid hostname")
+ }
+ clientID := c.String("id")
+ if !c.IsSet("id") {
+ clientID = generateRandomClientID()
+ }
+
+ tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
+ if err != nil {
+ Log.WithError(err).Fatal("Tag parse failure")
+ }
+
+ tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
+ if c.IsSet("hello-world") {
+ wg.Add(1)
+ listener, err := createListener("127.0.0.1:")
+ if err != nil {
+ listener.Close()
+ Log.WithError(err).Fatal("Cannot start Hello World Server")
+ }
+ go func() {
+ startHelloWorldServer(listener, shutdownC)
+ wg.Done()
+ listener.Close()
+ }()
+ c.Set("url", "https://"+listener.Addr().String())
+ }
+
+ if c.IsSet("proxy-dns") {
+ wg.Add(1)
+ listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(c.Uint("proxy-dns-port")), c.StringSlice("proxy-dns-upstream"))
+ if err != nil {
+ listener.Stop()
+ Log.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server")
+ }
+ go func() {
+ listener.Start()
+ <-shutdownC
+ listener.Stop()
+ wg.Done()
+ }()
+ }
+
+ url, err := validateUrl(c)
+ if err != nil {
+ Log.WithError(err).Fatal("Error validating url")
+ }
+ Log.Infof("Proxying tunnel requests to %s", url)
+
+ // Fail if the user provided an old authentication method
+ if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
+ Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflared login")
+ }
+
+ // Check that the user has acquired a certificate using the log in command
+ originCertPath, err := homedir.Expand(c.String("origincert"))
+ if err != nil {
+ Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
+ }
+ ok, err := fileExists(originCertPath)
+ if err != nil {
+ Log.Fatalf("Cannot check if origin cert exists at path %s", c.String("origincert"))
+ }
+ if !ok {
+ Log.Fatalf(`Cannot find a valid certificate for your origin at the path:
+
+ %s
+
+If the path above is wrong, specify the path with the -origincert option.
+If you don't have a certificate signed by Cloudflare, run the command:
+
+ %s login
+`, originCertPath, os.Args[0])
+ }
+ // Easier to send the certificate as []byte via RPC than decoding it at this point
+ originCert, err := ioutil.ReadFile(originCertPath)
+ if err != nil {
+ Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
+ }
+
+ tunnelMetrics := origin.NewTunnelMetrics()
+ httpTransport := &http.Transport{
+ Proxy: http.ProxyFromEnvironment,
+ DialContext: (&net.Dialer{
+ Timeout: c.Duration("proxy-connect-timeout"),
+ KeepAlive: c.Duration("proxy-tcp-keepalive"),
+ DualStack: !c.Bool("proxy-no-happy-eyeballs"),
+ }).DialContext,
+ MaxIdleConns: c.Int("proxy-keepalive-connections"),
+ IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
+ TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
+ ExpectContinueTimeout: 1 * time.Second,
+ TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
+ }
+
+ if !c.IsSet("hello-world") && c.IsSet("origin-server-name") {
+ httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
+ }
+
+ tunnelConfig := &origin.TunnelConfig{
+ EdgeAddrs: c.StringSlice("edge"),
+ OriginUrl: url,
+ Hostname: hostname,
+ OriginCert: originCert,
+ TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
+ ClientTlsConfig: httpTransport.TLSClientConfig,
+ Retries: c.Uint("retries"),
+ HeartbeatInterval: c.Duration("heartbeat-interval"),
+ MaxHeartbeats: c.Uint64("heartbeat-count"),
+ ClientID: clientID,
+ ReportedVersion: Version,
+ LBPool: c.String("lb-pool"),
+ Tags: tags,
+ HAConnections: c.Int("ha-connections"),
+ HTTPTransport: httpTransport,
+ Metrics: tunnelMetrics,
+ MetricsUpdateFreq: c.Duration("metrics-update-freq"),
+ ProtocolLogger: protoLogger,
+ Logger: Log,
+ IsAutoupdated: c.Bool("is-autoupdated"),
+ }
+ connectedSignal := make(chan struct{})
+
+ go writePidFile(connectedSignal, c.String("pidfile"))
+ go func() {
+ errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
+ wg.Done()
+ }()
+
+ metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
+ if err != nil {
+ Log.WithError(err).Fatal("Error opening metrics server listener")
+ }
+ go func() {
+ errC <- metrics.ServeMetrics(metricsListener, shutdownC)
+ wg.Done()
+ }()
+
+ var errCode int
+ err = WaitForSignal(errC, shutdownC)
+ if err != nil {
+ Log.WithError(err).Error("Quitting due to error")
+ raven.CaptureErrorAndWait(err, nil)
+ errCode = 1
+ } else {
+ Log.Info("Quitting...")
+ }
+ // Wait for clean exit, discarding all errors
+ go func() {
+ for range errC {
+ }
+ }()
+ wg.Wait()
+ os.Exit(errCode)
+}
+
+func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
+ signals := make(chan os.Signal, 10)
+ signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
+ defer signal.Stop(signals)
+ select {
+ case err := <-errC:
+ close(shutdownC)
+ return err
+ case <-signals:
+ close(shutdownC)
+ case <-shutdownC:
+ }
+ return nil
+}
+
+func update(_ *cli.Context) error {
+ if updateApplied() {
+ os.Exit(64)
+ }
+ return nil
+}
+
+func initUpdate() bool {
+ if updateApplied() {
+ os.Args = append(os.Args, "--is-autoupdated=true")
+ if _, err := listeners.StartProcess(); err != nil {
+ Log.WithError(err).Error("Unable to restart server automatically")
+ return false
+ }
+ return true
+ }
+ return false
+}
+
+func autoupdate(freq time.Duration, shutdownC chan struct{}) {
+ for {
+ if updateApplied() {
+ os.Args = append(os.Args, "--is-autoupdated=true")
+ if _, err := listeners.StartProcess(); err != nil {
+ Log.WithError(err).Error("Unable to restart server automatically")
+ }
+ close(shutdownC)
+ return
+ }
+ time.Sleep(freq)
+ }
+}
+
+func updateApplied() bool {
+ releaseInfo := checkForUpdates()
+ if releaseInfo.Updated {
+ Log.Infof("Updated to version %s", releaseInfo.Version)
+ return true
+ }
+ if releaseInfo.Error != nil {
+ Log.WithError(releaseInfo.Error).Error("Update check failed")
+ }
+ return false
+}
+
+func fileExists(path string) (bool, error) {
+ f, err := os.Open(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ // ignore missing files
+ return false, nil
+ }
+ return false, err
+ }
+ f.Close()
+ return true, nil
+}
+
+// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs
+// (differs by OS for legacy reasons) contains a cert.pem file, return empty string
+func findDefaultOriginCertPath() string {
+ for _, defaultConfigDir := range defaultConfigDirs {
+ originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, credentialFile))
+ if ok, _ := fileExists(originCertPath); ok {
+ return originCertPath
+ }
+ }
+ return ""
+}
+
+// returns the firt path that contains a config file. If none of the combination of
+// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles
+// contains a config file, return empty string
+func findDefaultConfigPath() string {
+ for _, configDir := range defaultConfigDirs {
+ for _, configFile := range defaultConfigFiles {
+ dirPath, err := homedir.Expand(configDir)
+ if err != nil {
+ return ""
+ }
+ path := filepath.Join(dirPath, configFile)
+ if ok, _ := fileExists(path); ok {
+ return path
+ }
+ }
+ }
+ return ""
+}
+
+func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
+ if context.String("config") != "" {
+ return altsrc.NewYamlSourceFromFile(context.String("config"))
+ }
+ return nil, nil
+}
+
+func generateRandomClientID() string {
+ r := rand.New(rand.NewSource(time.Now().UnixNano()))
+ id := make([]byte, 32)
+ r.Read(id)
+ return hex.EncodeToString(id)
+}
+
+func writePidFile(waitForSignal chan struct{}, pidFile string) {
+ <-waitForSignal
+ daemon.SdNotify(false, "READY=1")
+ if pidFile == "" {
+ return
+ }
+ file, err := os.Create(pidFile)
+ if err != nil {
+ Log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
+ }
+ defer file.Close()
+ fmt.Fprintf(file, "%d", os.Getpid())
+}
+
+// validate url. It can be either from --url or argument
+func validateUrl(c *cli.Context) (string, error) {
+ var url = c.String("url")
+ if c.NArg() > 0 {
+ if c.IsSet("url") {
+ return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
+ }
+ url = c.Args().Get(0)
+ }
+ validUrl, err := validation.ValidateUrl(url)
+ return validUrl, err
+}
+
+func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error {
+ filePath, err := homedir.Expand(c.String("logfile"))
+ if err != nil {
+ return errors.Wrap(err, "Cannot resolve logfile path")
+ }
+
+ fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC
+ // do not truncate log file if the client has been autoupdated
+ if c.Bool("is-autoupdated") {
+ fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE
+ }
+ f, err := os.OpenFile(filePath, fileMode, 0664)
+ if err != nil {
+ errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath))
+ }
+ defer f.Close()
+ pathMap := lfshook.PathMap{
+ logrus.InfoLevel: filePath,
+ logrus.ErrorLevel: filePath,
+ logrus.FatalLevel: filePath,
+ logrus.PanicLevel: filePath,
+ }
+
+ Log.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
+ protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
+
+ flags := make(map[string]interface{})
+ envs := make(map[string]string)
+
+ for _, flag := range c.LocalFlagNames() {
+ flags[flag] = c.Generic(flag)
+ }
+
+ // Find env variables for Argo Tunnel
+ for _, env := range os.Environ() {
+ // All Argo Tunnel env variables start with TUNNEL_
+ if strings.Contains(env, "TUNNEL_") {
+ vars := strings.Split(env, "=")
+ if len(vars) == 2 {
+ envs[vars[0]] = vars[1]
+ }
+ }
+ }
+
+ Log.Infof("Argo Tunnel build and runtime configuration: %+v", BuildAndRuntimeInfo{
+ GoOS: runtime.GOOS,
+ GoVersion: runtime.Version(),
+ GoArch: runtime.GOARCH,
+ WarpVersion: Version,
+ WarpFlags: flags,
+ WarpEnvs: envs,
+ })
+
+ return nil
+}
+
+func isAutoupdateEnabled(c *cli.Context) bool {
+ if terminal.IsTerminal(int(os.Stdout.Fd())) {
+ Log.Info(noAutoupdateMessage)
+ return false
+ }
+
+ return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0
+}
diff --git a/cmd/cloudflared/service_template.go b/cmd/cloudflared/service_template.go
new file mode 100644
index 00000000..063e26cb
--- /dev/null
+++ b/cmd/cloudflared/service_template.go
@@ -0,0 +1,192 @@
+package main
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "text/template"
+
+ "github.com/mitchellh/go-homedir"
+)
+
+type ServiceTemplate struct {
+ Path string
+ Content string
+ FileMode os.FileMode
+}
+
+type ServiceTemplateArgs struct {
+ Path string
+}
+
+func (st *ServiceTemplate) ResolvePath() (string, error) {
+ resolvedPath, err := homedir.Expand(st.Path)
+ if err != nil {
+ return "", fmt.Errorf("error resolving path %s: %v", st.Path, err)
+ }
+ return resolvedPath, nil
+}
+
+func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
+ tmpl, err := template.New(st.Path).Parse(st.Content)
+ if err != nil {
+ return fmt.Errorf("error generating %s template: %v", st.Path, err)
+ }
+ resolvedPath, err := st.ResolvePath()
+ if err != nil {
+ return err
+ }
+ var buffer bytes.Buffer
+ err = tmpl.Execute(&buffer, args)
+ if err != nil {
+ return fmt.Errorf("error generating %s: %v", st.Path, err)
+ }
+ fileMode := os.FileMode(0644)
+ if st.FileMode != 0 {
+ fileMode = st.FileMode
+ }
+ err = ioutil.WriteFile(resolvedPath, buffer.Bytes(), fileMode)
+ if err != nil {
+ return fmt.Errorf("error writing %s: %v", resolvedPath, err)
+ }
+ return nil
+}
+
+func (st *ServiceTemplate) Remove() error {
+ resolvedPath, err := st.ResolvePath()
+ if err != nil {
+ return err
+ }
+ err = os.Remove(resolvedPath)
+ if err != nil {
+ return fmt.Errorf("error deleting %s: %v", resolvedPath, err)
+ }
+ return nil
+}
+
+func runCommand(command string, args ...string) error {
+ cmd := exec.Command(command, args...)
+ stderr, err := cmd.StderrPipe()
+ if err != nil {
+ Log.WithError(err).Infof("error getting stderr pipe")
+ return fmt.Errorf("error getting stderr pipe: %v", err)
+ }
+ err = cmd.Start()
+ if err != nil {
+ Log.WithError(err).Infof("error starting %s", command)
+ return fmt.Errorf("error starting %s: %v", command, err)
+ }
+ commandErr, _ := ioutil.ReadAll(stderr)
+ if len(commandErr) > 0 {
+ Log.Errorf("%s: %s", command, commandErr)
+ }
+ err = cmd.Wait()
+ if err != nil {
+ Log.WithError(err).Infof("%s returned error", command)
+ return fmt.Errorf("%s returned with error: %v", command, err)
+ }
+ return nil
+}
+
+func ensureConfigDirExists(configDir string) error {
+ ok, err := fileExists(configDir)
+ if !ok && err == nil {
+ err = os.Mkdir(configDir, 0700)
+ }
+ return err
+}
+
+// openFile opens the file at path. If create is set and the file exists, returns nil, true, nil
+func openFile(path string, create bool) (file *os.File, exists bool, err error) {
+ expandedPath, err := homedir.Expand(path)
+ if err != nil {
+ return nil, false, err
+ }
+ if create {
+ fileInfo, err := os.Stat(expandedPath)
+ if err == nil && fileInfo.Size() > 0 {
+ return nil, true, nil
+ }
+ file, err = os.OpenFile(expandedPath, os.O_RDWR|os.O_CREATE, 0600)
+ } else {
+ file, err = os.Open(expandedPath)
+ }
+ return file, false, err
+}
+
+func copyCertificate(srcConfigDir, destConfigDir string) error {
+ destCredentialPath := filepath.Join(destConfigDir, credentialFile)
+ destFile, exists, err := openFile(destCredentialPath, true)
+ if err != nil {
+ return err
+ } else if exists {
+ // credentials already exist, do nothing
+ return nil
+ }
+ defer destFile.Close()
+
+ srcCredentialPath := filepath.Join(srcConfigDir, credentialFile)
+ srcFile, _, err := openFile(srcCredentialPath, false)
+ if err != nil {
+ return err
+ }
+ defer srcFile.Close()
+
+ // Copy certificate
+ _, err = io.Copy(destFile, srcFile)
+ if err != nil {
+ return fmt.Errorf("unable to copy %s to %s: %v", srcCredentialPath, destCredentialPath, err)
+ }
+
+ return nil
+}
+
+func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile string) error {
+ if err := ensureConfigDirExists(serviceConfigDir); err != nil {
+ return err
+ }
+
+ if err := copyCertificate(defaultConfigDir, serviceConfigDir); err != nil {
+ return err
+ }
+
+ // Copy or create config
+ destConfigPath := filepath.Join(serviceConfigDir, defaultConfigFile)
+ destFile, exists, err := openFile(destConfigPath, true)
+ if err != nil {
+ Log.WithError(err).Infof("cannot open %s", destConfigPath)
+ return err
+ } else if exists {
+ // config already exists, do nothing
+ return nil
+ }
+ defer destFile.Close()
+
+ srcConfigPath := filepath.Join(defaultConfigDir, defaultConfigFile)
+ srcFile, _, err := openFile(srcConfigPath, false)
+ if err != nil {
+ fmt.Println("Your service needs a config file that at least specifies the hostname option.")
+ fmt.Println("Type in a hostname now, or leave it blank and create the config file later.")
+ fmt.Print("Hostname: ")
+ reader := bufio.NewReader(os.Stdin)
+ input, _ := reader.ReadString('\n')
+ if input == "" {
+ return err
+ }
+ fmt.Fprintf(destFile, "hostname: %s\n", input)
+ } else {
+ defer srcFile.Close()
+ _, err = io.Copy(destFile, srcFile)
+ if err != nil {
+ return fmt.Errorf("unable to copy %s to %s: %v", srcConfigPath, destConfigPath, err)
+ }
+ Log.Infof("Copied %s to %s", srcConfigPath, destConfigPath)
+ }
+
+ return nil
+}
diff --git a/cmd/cloudflared/tag.go b/cmd/cloudflared/tag.go
new file mode 100644
index 00000000..fde590ba
--- /dev/null
+++ b/cmd/cloudflared/tag.go
@@ -0,0 +1,32 @@
+package main
+
+import (
+ "fmt"
+ "regexp"
+
+ tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
+)
+
+// Restrict key names to characters allowed in an HTTP header name.
+// Restrict key values to printable characters (what is recognised as data in an HTTP header value).
+var tagRegexp = regexp.MustCompile("^([a-zA-Z0-9!#$%&'*+\\-.^_`|~]+)=([[:print:]]+)$")
+
+func NewTagFromCLI(compoundTag string) (tunnelpogs.Tag, bool) {
+ matches := tagRegexp.FindStringSubmatch(compoundTag)
+ if len(matches) == 0 {
+ return tunnelpogs.Tag{}, false
+ }
+ return tunnelpogs.Tag{Name: matches[1], Value: matches[2]}, true
+}
+
+func NewTagSliceFromCLI(tags []string) ([]tunnelpogs.Tag, error) {
+ var tagSlice []tunnelpogs.Tag
+ for _, compoundTag := range tags {
+ if tag, ok := NewTagFromCLI(compoundTag); ok {
+ tagSlice = append(tagSlice, tag)
+ } else {
+ return nil, fmt.Errorf("Cannot parse tag value %s", compoundTag)
+ }
+ }
+ return tagSlice, nil
+}
diff --git a/cmd/cloudflared/tag_test.go b/cmd/cloudflared/tag_test.go
new file mode 100644
index 00000000..25bf3243
--- /dev/null
+++ b/cmd/cloudflared/tag_test.go
@@ -0,0 +1,46 @@
+package main
+
+import (
+ "testing"
+
+ tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestSingleTag(t *testing.T) {
+ testCases := []struct {
+ Input string
+ Output tunnelpogs.Tag
+ Fail bool
+ }{
+ {Input: "x=y", Output: tunnelpogs.Tag{Name: "x", Value: "y"}},
+ {Input: "More-Complex=Tag Values", Output: tunnelpogs.Tag{Name: "More-Complex", Value: "Tag Values"}},
+ {Input: "First=Equals=Wins", Output: tunnelpogs.Tag{Name: "First", Value: "Equals=Wins"}},
+ {Input: "x=", Fail: true},
+ {Input: "=y", Fail: true},
+ {Input: "=", Fail: true},
+ {Input: "No spaces allowed=in key names", Fail: true},
+ {Input: "omg\nwtf=bbq", Fail: true},
+ }
+ for i, testCase := range testCases {
+ tag, ok := NewTagFromCLI(testCase.Input)
+ assert.Equalf(t, !testCase.Fail, ok, "mismatched success for test case %d", i)
+ assert.Equalf(t, testCase.Output, tag, "mismatched output for test case %d", i)
+ }
+}
+
+func TestTagSlice(t *testing.T) {
+ tagSlice, err := NewTagSliceFromCLI([]string{"a=b", "c=d", "e=f"})
+ assert.NoError(t, err)
+ assert.Len(t, tagSlice, 3)
+ assert.Equal(t, "a", tagSlice[0].Name)
+ assert.Equal(t, "b", tagSlice[0].Value)
+ assert.Equal(t, "c", tagSlice[1].Name)
+ assert.Equal(t, "d", tagSlice[1].Value)
+ assert.Equal(t, "e", tagSlice[2].Name)
+ assert.Equal(t, "f", tagSlice[2].Value)
+
+ tagSlice, err = NewTagSliceFromCLI([]string{"a=b", "=", "e=f"})
+ assert.Error(t, err)
+}
diff --git a/cmd/cloudflared/update.go b/cmd/cloudflared/update.go
new file mode 100644
index 00000000..dc2f0cd0
--- /dev/null
+++ b/cmd/cloudflared/update.go
@@ -0,0 +1,41 @@
+package main
+
+import "github.com/equinox-io/equinox"
+
+const appID = "app_cwbQae3Tpea"
+
+var publicKey = []byte(`
+-----BEGIN ECDSA PUBLIC KEY-----
+MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE4OWZocTVZ8Do/L6ScLdkV+9A0IYMHoOf
+dsCmJ/QZ6aw0w9qkkwEpne1Lmo6+0pGexZzFZOH6w5amShn+RXt7qkSid9iWlzGq
+EKx0BZogHSor9Wy5VztdFaAaVbsJiCbO
+-----END ECDSA PUBLIC KEY-----
+`)
+
+type ReleaseInfo struct {
+ Updated bool
+ Version string
+ Error error
+}
+
+func checkForUpdates() ReleaseInfo {
+ var opts equinox.Options
+ if err := opts.SetPublicKeyPEM(publicKey); err != nil {
+ return ReleaseInfo{Error: err}
+ }
+
+ resp, err := equinox.Check(appID, opts)
+ switch {
+ case err == equinox.NotAvailableErr:
+ return ReleaseInfo{}
+ case err != nil:
+ return ReleaseInfo{Error: err}
+ }
+
+ err = resp.Apply()
+ if err != nil {
+ return ReleaseInfo{Error: err}
+ }
+
+ return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion}
+}
diff --git a/cmd/cloudflared/windows_service.go b/cmd/cloudflared/windows_service.go
new file mode 100644
index 00000000..b003bb50
--- /dev/null
+++ b/cmd/cloudflared/windows_service.go
@@ -0,0 +1,166 @@
+// +build windows
+
+package main
+
+// Copypasta from the example files:
+// https://github.com/golang/sys/blob/master/windows/svc/example
+
+import (
+ "fmt"
+ "os"
+
+ cli "gopkg.in/urfave/cli.v2"
+
+ "golang.org/x/sys/windows/svc"
+ "golang.org/x/sys/windows/svc/eventlog"
+ "golang.org/x/sys/windows/svc/mgr"
+)
+
+const (
+ windowsServiceName = "Cloudflared"
+ windowsServiceDescription = "Argo Tunnel agent"
+)
+
+func runApp(app *cli.App) {
+ app.Commands = append(app.Commands, &cli.Command{
+ Name: "service",
+ Usage: "Manages the Argo Tunnel Windows service",
+ Subcommands: []*cli.Command{
+ &cli.Command{
+ Name: "install",
+ Usage: "Install Argo Tunnel as a Windows service",
+ Action: installWindowsService,
+ },
+ &cli.Command{
+ Name: "uninstall",
+ Usage: "Uninstall the Argo Tunnel service",
+ Action: uninstallWindowsService,
+ },
+ },
+ })
+
+ isIntSess, err := svc.IsAnInteractiveSession()
+ if err != nil {
+ Log.Fatalf("failed to determine if we are running in an interactive session: %v", err)
+ }
+
+ if isIntSess {
+ app.Run(os.Args)
+ return
+ }
+
+ elog, err := eventlog.Open(windowsServiceName)
+ if err != nil {
+ Log.WithError(err).Infof("Cannot open event log for %s", windowsServiceName)
+ return
+ }
+ defer elog.Close()
+
+ elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName))
+ // Run executes service name by calling windowsService which is a Handler
+ // interface that implements Execute method
+ err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog})
+ if err != nil {
+ elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
+ return
+ }
+ elog.Info(1, fmt.Sprintf("%s service stopped", windowsServiceName))
+}
+
+type windowsService struct {
+ app *cli.App
+ elog *eventlog.Log
+}
+
+// called by the package code at the start of the service
+func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) {
+ const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
+ changes <- svc.Status{State: svc.StartPending}
+ go s.app.Run(args)
+
+ changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
+loop:
+ for {
+ select {
+ case c := <-r:
+ switch c.Cmd {
+ case svc.Interrogate:
+ s.elog.Info(1, fmt.Sprintf("control request 1 #%d", c))
+ changes <- c.CurrentStatus
+ case svc.Stop:
+ s.elog.Info(1, "received stop control request")
+ break loop
+ case svc.Shutdown:
+ s.elog.Info(1, "received shutdown control request")
+ break loop
+ default:
+ s.elog.Error(1, fmt.Sprintf("unexpected control request #%d", c))
+ }
+ }
+ }
+ close(shutdownC)
+ changes <- svc.Status{State: svc.StopPending}
+ return
+}
+
+func installWindowsService(c *cli.Context) error {
+ Log.Infof("Installing Argo Tunnel Windows service")
+ exepath, err := os.Executable()
+ if err != nil {
+ Log.Infof("Cannot find path name that start the process")
+ return err
+ }
+ m, err := mgr.Connect()
+ if err != nil {
+ Log.WithError(err).Infof("Cannot establish a connection to the service control manager")
+ return err
+ }
+ defer m.Disconnect()
+ s, err := m.OpenService(windowsServiceName)
+ if err == nil {
+ s.Close()
+ Log.Errorf("service %s already exists", windowsServiceName)
+ return fmt.Errorf("service %s already exists", windowsServiceName)
+ }
+ config := mgr.Config{StartType: mgr.StartAutomatic, DisplayName: windowsServiceDescription}
+ s, err = m.CreateService(windowsServiceName, exepath, config)
+ if err != nil {
+ Log.Infof("Cannot install service %s", windowsServiceName)
+ return err
+ }
+ defer s.Close()
+ err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info)
+ if err != nil {
+ s.Delete()
+ Log.WithError(err).Infof("Cannot install event logger")
+ return fmt.Errorf("SetupEventLogSource() failed: %s", err)
+ }
+ return nil
+}
+
+func uninstallWindowsService(c *cli.Context) error {
+ Log.Infof("Uninstalling Argo Tunnel Windows Service")
+ m, err := mgr.Connect()
+ if err != nil {
+ Log.Infof("Cannot establish a connection to the service control manager")
+ return err
+ }
+ defer m.Disconnect()
+ s, err := m.OpenService(windowsServiceName)
+ if err != nil {
+ Log.Infof("service %s is not installed", windowsServiceName)
+ return fmt.Errorf("service %s is not installed", windowsServiceName)
+ }
+ defer s.Close()
+ err = s.Delete()
+ if err != nil {
+ Log.Errorf("Cannot delete service %s", windowsServiceName)
+ return err
+ }
+ err = eventlog.Remove(windowsServiceName)
+ if err != nil {
+ Log.Infof("Cannot remove event logger")
+ return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
+ }
+ return nil
+}
diff --git a/h2mux/activestreammap.go b/h2mux/activestreammap.go
index 9059e972..80b54d38 100644
--- a/h2mux/activestreammap.go
+++ b/h2mux/activestreammap.go
@@ -24,12 +24,6 @@ type activeStreamMap struct {
ignoreNewStreams bool
}
-type FlowControlMetrics struct {
- AverageReceiveWindowSize, AverageSendWindowSize float64
- MinReceiveWindowSize, MaxReceiveWindowSize uint32
- MinSendWindowSize, MaxSendWindowSize uint32
-}
-
func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap {
m := &activeStreamMap{
streams: make(map[uint32]*MuxedStream),
@@ -169,45 +163,3 @@ func (m *activeStreamMap) Abort() {
}
m.ignoreNewStreams = true
}
-
-func (m *activeStreamMap) Metrics() *FlowControlMetrics {
- m.Lock()
- defer m.Unlock()
- var averageReceiveWindowSize, averageSendWindowSize float64
- var minReceiveWindowSize, maxReceiveWindowSize, minSendWindowSize, maxSendWindowSize uint32
- i := 0
- // The first variable in the range expression for map is the key, not index.
- for _, stream := range m.streams {
- // iterative mean: a(t+1) = a(t) + (a(t)-x)/(t+1)
- windows := stream.FlowControlWindow()
- averageReceiveWindowSize += (float64(windows.receiveWindow) - averageReceiveWindowSize) / float64(i+1)
- averageSendWindowSize += (float64(windows.sendWindow) - averageSendWindowSize) / float64(i+1)
- if i == 0 {
- maxReceiveWindowSize = windows.receiveWindow
- minReceiveWindowSize = windows.receiveWindow
- maxSendWindowSize = windows.sendWindow
- minSendWindowSize = windows.sendWindow
- } else {
- if windows.receiveWindow > maxReceiveWindowSize {
- maxReceiveWindowSize = windows.receiveWindow
- } else if windows.receiveWindow < minReceiveWindowSize {
- minReceiveWindowSize = windows.receiveWindow
- }
-
- if windows.sendWindow > maxSendWindowSize {
- maxSendWindowSize = windows.sendWindow
- } else if windows.sendWindow < minSendWindowSize {
- minSendWindowSize = windows.sendWindow
- }
- }
- i++
- }
- return &FlowControlMetrics{
- MinReceiveWindowSize: minReceiveWindowSize,
- MaxReceiveWindowSize: maxReceiveWindowSize,
- AverageReceiveWindowSize: averageReceiveWindowSize,
- MinSendWindowSize: minSendWindowSize,
- MaxSendWindowSize: maxSendWindowSize,
- AverageSendWindowSize: averageSendWindowSize,
- }
-}
diff --git a/h2mux/bytes_counter.go b/h2mux/bytes_counter.go
new file mode 100644
index 00000000..0cd290fd
--- /dev/null
+++ b/h2mux/bytes_counter.go
@@ -0,0 +1,18 @@
+package h2mux
+
+import (
+ "sync/atomic"
+)
+
+type AtomicCounter struct {
+ count uint64
+}
+
+func (c *AtomicCounter) IncrementBy(number uint64) {
+ atomic.AddUint64(&c.count, number)
+}
+
+// Get returns the current value of counter and reset it to 0
+func (c *AtomicCounter) Count() uint64 {
+ return atomic.SwapUint64(&c.count, 0)
+}
diff --git a/h2mux/bytes_counter_test.go b/h2mux/bytes_counter_test.go
new file mode 100644
index 00000000..da579aaf
--- /dev/null
+++ b/h2mux/bytes_counter_test.go
@@ -0,0 +1,23 @@
+package h2mux
+
+import (
+ "sync"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestCounter(t *testing.T) {
+ var wg sync.WaitGroup
+ wg.Add(dataPoints)
+ c := AtomicCounter{}
+ for i := 0; i < dataPoints; i++ {
+ go func() {
+ defer wg.Done()
+ c.IncrementBy(uint64(1))
+ }()
+ }
+ wg.Wait()
+ assert.Equal(t, uint64(dataPoints), c.Count())
+ assert.Equal(t, uint64(0), c.Count())
+}
diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go
index 928b162b..e476431b 100644
--- a/h2mux/h2mux.go
+++ b/h2mux/h2mux.go
@@ -59,6 +59,8 @@ type Muxer struct {
muxReader *MuxReader
// muxWriter is the write process.
muxWriter *MuxWriter
+ // muxMetricsUpdater is the process to update metrics
+ muxMetricsUpdater *muxMetricsUpdater
// newStreamChan is used to create new streams on the writer thread.
// The writer will assign the next available stream ID.
newStreamChan chan MuxedStreamRequest
@@ -133,6 +135,11 @@ func Handshake(
// set up reader/writer pair ready for serve
streamErrors := NewStreamErrorMap()
goAwayChan := make(chan http2.ErrCode, 1)
+ updateRTTChan := make(chan *roundTripMeasurement, 1)
+ updateReceiveWindowChan := make(chan uint32, 1)
+ updateSendWindowChan := make(chan uint32, 1)
+ updateInBoundBytesChan := make(chan uint64)
+ updateOutBoundBytesChan := make(chan uint64)
pingTimestamp := NewPingTimestamp()
connActive := NewSignal()
idleDuration := config.HeartbeatInterval
@@ -149,34 +156,48 @@ func Handshake(
m.explicitShutdown = NewBooleanFuse()
m.muxReader = &MuxReader{
- f: m.f,
- handler: m.config.Handler,
- streams: m.streams,
- readyList: m.readyList,
- streamErrors: streamErrors,
- goAwayChan: goAwayChan,
- abortChan: m.abortChan,
- pingTimestamp: pingTimestamp,
- connActive: connActive,
- initialStreamWindow: defaultWindowSize,
- streamWindowMax: maxWindowSize,
- r: m.r,
+ f: m.f,
+ handler: m.config.Handler,
+ streams: m.streams,
+ readyList: m.readyList,
+ streamErrors: streamErrors,
+ goAwayChan: goAwayChan,
+ abortChan: m.abortChan,
+ pingTimestamp: pingTimestamp,
+ connActive: connActive,
+ initialStreamWindow: defaultWindowSize,
+ streamWindowMax: maxWindowSize,
+ r: m.r,
+ updateRTTChan: updateRTTChan,
+ updateReceiveWindowChan: updateReceiveWindowChan,
+ updateSendWindowChan: updateSendWindowChan,
+ updateInBoundBytesChan: updateInBoundBytesChan,
}
m.muxWriter = &MuxWriter{
- f: m.f,
- streams: m.streams,
- streamErrors: streamErrors,
- readyStreamChan: m.readyList.ReadyChannel(),
- newStreamChan: m.newStreamChan,
- goAwayChan: goAwayChan,
- abortChan: m.abortChan,
- pingTimestamp: pingTimestamp,
- idleTimer: NewIdleTimer(idleDuration, maxRetries),
- connActiveChan: connActive.WaitChannel(),
- maxFrameSize: defaultFrameSize,
+ f: m.f,
+ streams: m.streams,
+ streamErrors: streamErrors,
+ readyStreamChan: m.readyList.ReadyChannel(),
+ newStreamChan: m.newStreamChan,
+ goAwayChan: goAwayChan,
+ abortChan: m.abortChan,
+ pingTimestamp: pingTimestamp,
+ idleTimer: NewIdleTimer(idleDuration, maxRetries),
+ connActiveChan: connActive.WaitChannel(),
+ maxFrameSize: defaultFrameSize,
+ updateReceiveWindowChan: updateReceiveWindowChan,
+ updateSendWindowChan: updateSendWindowChan,
+ updateOutBoundBytesChan: updateOutBoundBytesChan,
}
m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
-
+ m.muxMetricsUpdater = newMuxMetricsUpdater(
+ updateRTTChan,
+ updateReceiveWindowChan,
+ updateSendWindowChan,
+ updateInBoundBytesChan,
+ updateOutBoundBytesChan,
+ m.abortChan,
+ )
return m, nil
}
@@ -246,9 +267,13 @@ func (m *Muxer) Serve() error {
m.w.Close()
m.abort()
}()
+ go func() {
+ errChan <- m.muxMetricsUpdater.run(logger)
+ }()
err := <-errChan
go func() {
- // discard error as other handler closes
+ // discard errors as other handler and muxMetricsUpdater close
+ <-errChan
<-errChan
close(errChan)
}()
@@ -318,14 +343,8 @@ func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, erro
}
}
-// Return the estimated round-trip time.
-func (m *Muxer) RTT() RTTMeasurement {
- return m.muxReader.RTT()
-}
-
-// Return min/max/average of send/receive window for all streams on this connection
-func (m *Muxer) FlowControlMetrics() *FlowControlMetrics {
- return m.muxReader.FlowControlMetrics()
+func (m *Muxer) Metrics() *MuxerMetrics {
+ return m.muxMetricsUpdater.Metrics()
}
func (m *Muxer) abort() {
diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go
index 49480cf4..3f9d1a75 100644
--- a/h2mux/h2mux_test.go
+++ b/h2mux/h2mux_test.go
@@ -14,6 +14,7 @@ import (
"time"
log "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
)
func TestMain(m *testing.M) {
@@ -134,9 +135,12 @@ func TestSingleStream(t *testing.T) {
if stream.Headers[0].Value != "headerValue" {
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
}
- stream.WriteHeaders([]Header{
+ headers := []Header{
Header{Name: "response-header", Value: "responseValue"},
- })
+ }
+ stream.WriteHeaders(headers)
+ assert.Equal(t, headers, stream.writeHeaders)
+ assert.False(t, stream.headersSent)
buf := []byte("Hello world")
stream.Write(buf)
// after this receive, the edge closed the stream
diff --git a/h2mux/muxedstream.go b/h2mux/muxedstream.go
index f0367ce9..f9997858 100644
--- a/h2mux/muxedstream.go
+++ b/h2mux/muxedstream.go
@@ -19,6 +19,7 @@ type MuxedStream struct {
receiveWindowCurrentMax uint32
// limit set in http2 spec. 2^31-1
receiveWindowMax uint32
+
// nonzero if a WINDOW_UPDATE frame for a stream needs to be sent
windowUpdate uint32
@@ -39,10 +40,6 @@ type MuxedStream struct {
receivedEOF bool
}
-type flowControlWindow struct {
- receiveWindow, sendWindow uint32
-}
-
func (s *MuxedStream) Read(p []byte) (n int, err error) {
return s.readBuffer.Read(p)
}
@@ -101,17 +98,21 @@ func (s *MuxedStream) WriteHeaders(headers []Header) error {
return ErrStreamHeadersSent
}
s.writeHeaders = headers
+ s.headersSent = false
s.writeNotify()
return nil
}
-func (s *MuxedStream) FlowControlWindow() *flowControlWindow {
+func (s *MuxedStream) getReceiveWindow() uint32 {
s.writeLock.Lock()
defer s.writeLock.Unlock()
- return &flowControlWindow{
- receiveWindow: s.receiveWindow,
- sendWindow: s.sendWindow,
- }
+ return s.receiveWindow
+}
+
+func (s *MuxedStream) getSendWindow() uint32 {
+ s.writeLock.Lock()
+ defer s.writeLock.Unlock()
+ return s.sendWindow
}
// writeNotify must happen while holding writeLock.
@@ -209,9 +210,7 @@ func (s *MuxedStream) getChunk() *streamChunk {
}
// Copies at most s.sendWindow bytes
- //log.Infof("writeBuffer len %d stream %d", s.writeBuffer.Len(), s.streamID)
writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow))
- //log.Infof("writeLen %d stream %d", writeLen, s.streamID)
s.sendWindow -= uint32(writeLen)
s.receiveWindow += s.windowUpdate
s.windowUpdate = 0
diff --git a/h2mux/muxmetrics.go b/h2mux/muxmetrics.go
new file mode 100644
index 00000000..ebac9935
--- /dev/null
+++ b/h2mux/muxmetrics.go
@@ -0,0 +1,232 @@
+package h2mux
+
+import (
+ "sync"
+ "time"
+
+ "github.com/golang-collections/collections/queue"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// data points used to compute average receive window and send window size
+const (
+ // data points used to compute average receive window and send window size
+ dataPoints = 100
+ // updateFreq is set to 1 sec so we can get inbound & outbound byes/sec
+ updateFreq = time.Second
+)
+
+type muxMetricsUpdater struct {
+ // rttData keeps record of rtt, rttMin, rttMax and last measured time
+ rttData *rttData
+ // receiveWindowData keeps record of receive window measurement
+ receiveWindowData *flowControlData
+ // sendWindowData keeps record of send window measurement
+ sendWindowData *flowControlData
+ // inBoundRate is incoming bytes/sec
+ inBoundRate *rate
+ // outBoundRate is outgoing bytes/sec
+ outBoundRate *rate
+ // updateRTTChan is the channel to receive new RTT measurement from muxReader
+ updateRTTChan <-chan *roundTripMeasurement
+ //updateReceiveWindowChan is the channel to receive updated receiveWindow size from muxReader and muxWriter
+ updateReceiveWindowChan <-chan uint32
+ //updateSendWindowChan is the channel to receive updated sendWindow size from muxReader and muxWriter
+ updateSendWindowChan <-chan uint32
+ // updateInBoundBytesChan us the channel to receive bytesRead from muxReader
+ updateInBoundBytesChan <-chan uint64
+ // updateOutBoundBytesChan us the channel to receive bytesWrote from muxWriter
+ updateOutBoundBytesChan <-chan uint64
+ // shutdownC is to signal the muxerMetricsUpdater to shutdown
+ abortChan <-chan struct{}
+}
+
+type MuxerMetrics struct {
+ RTT, RTTMin, RTTMax time.Duration
+ ReceiveWindowAve, SendWindowAve float64
+ ReceiveWindowMin, ReceiveWindowMax, SendWindowMin, SendWindowMax uint32
+ InBoundRateCurr, InBoundRateMin, InBoundRateMax uint64
+ OutBoundRateCurr, OutBoundRateMin, OutBoundRateMax uint64
+}
+
+type roundTripMeasurement struct {
+ receiveTime, sendTime time.Time
+}
+
+type rttData struct {
+ rtt, rttMin, rttMax time.Duration
+ lastMeasurementTime time.Time
+ lock sync.RWMutex
+}
+
+type flowControlData struct {
+ sum uint64
+ min, max uint32
+ queue *queue.Queue
+ lock sync.RWMutex
+}
+
+type rate struct {
+ curr uint64
+ min, max uint64
+ lock sync.RWMutex
+}
+
+func newMuxMetricsUpdater(
+ updateRTTChan <-chan *roundTripMeasurement,
+ updateReceiveWindowChan <-chan uint32,
+ updateSendWindowChan <-chan uint32,
+ updateInBoundBytesChan <-chan uint64,
+ updateOutBoundBytesChan <-chan uint64,
+ abortChan <-chan struct{},
+) *muxMetricsUpdater {
+ return &muxMetricsUpdater{
+ rttData: newRTTData(),
+ receiveWindowData: newFlowControlData(),
+ sendWindowData: newFlowControlData(),
+ inBoundRate: newRate(),
+ outBoundRate: newRate(),
+ updateRTTChan: updateRTTChan,
+ updateReceiveWindowChan: updateReceiveWindowChan,
+ updateSendWindowChan: updateSendWindowChan,
+ updateInBoundBytesChan: updateInBoundBytesChan,
+ updateOutBoundBytesChan: updateOutBoundBytesChan,
+ abortChan: abortChan,
+ }
+}
+
+func (updater *muxMetricsUpdater) Metrics() *MuxerMetrics {
+ m := &MuxerMetrics{}
+ m.RTT, m.RTTMin, m.RTTMax = updater.rttData.metrics()
+ m.ReceiveWindowAve, m.ReceiveWindowMin, m.ReceiveWindowMax = updater.receiveWindowData.metrics()
+ m.SendWindowAve, m.SendWindowMin, m.SendWindowMax = updater.sendWindowData.metrics()
+ m.InBoundRateCurr, m.InBoundRateMin, m.InBoundRateMax = updater.inBoundRate.get()
+ m.OutBoundRateCurr, m.OutBoundRateMin, m.OutBoundRateMax = updater.outBoundRate.get()
+ return m
+}
+
+func (updater *muxMetricsUpdater) run(parentLogger *log.Entry) error {
+ logger := parentLogger.WithFields(log.Fields{
+ "subsystem": "mux",
+ "dir": "metrics",
+ })
+ defer logger.Debug("event loop finished")
+ for {
+ select {
+ case <-updater.abortChan:
+ logger.Infof("Stopping mux metrics updater")
+ return nil
+ case roundTripMeasurement := <-updater.updateRTTChan:
+ go updater.rttData.update(roundTripMeasurement)
+ logger.Debug("Update rtt")
+ case receiveWindow := <-updater.updateReceiveWindowChan:
+ go updater.receiveWindowData.update(receiveWindow)
+ logger.Debug("Update receive window")
+ case sendWindow := <-updater.updateSendWindowChan:
+ go updater.sendWindowData.update(sendWindow)
+ logger.Debug("Update send window")
+ case inBoundBytes := <-updater.updateInBoundBytesChan:
+ // inBoundBytes is bytes/sec because the update interval is 1 sec
+ go updater.inBoundRate.update(inBoundBytes)
+ logger.Debugf("Inbound bytes %d", inBoundBytes)
+ case outBoundBytes := <-updater.updateOutBoundBytesChan:
+ // outBoundBytes is bytes/sec because the update interval is 1 sec
+ go updater.outBoundRate.update(outBoundBytes)
+ logger.Debugf("Outbound bytes %d", outBoundBytes)
+ }
+ }
+}
+
+func newRTTData() *rttData {
+ return &rttData{}
+}
+
+func (r *rttData) update(measurement *roundTripMeasurement) {
+ r.lock.Lock()
+ defer r.lock.Unlock()
+ // discard pings before lastMeasurementTime
+ if r.lastMeasurementTime.After(measurement.sendTime) {
+ return
+ }
+ r.lastMeasurementTime = measurement.sendTime
+ r.rtt = measurement.receiveTime.Sub(measurement.sendTime)
+ if r.rttMax < r.rtt {
+ r.rttMax = r.rtt
+ }
+ if r.rttMin == 0 || r.rttMin > r.rtt {
+ r.rttMin = r.rtt
+ }
+}
+
+func (r *rttData) metrics() (rtt, rttMin, rttMax time.Duration) {
+ r.lock.RLock()
+ defer r.lock.RUnlock()
+ return r.rtt, r.rttMin, r.rttMax
+}
+
+func newFlowControlData() *flowControlData {
+ return &flowControlData{queue: queue.New()}
+}
+
+func (f *flowControlData) update(measurement uint32) {
+ f.lock.Lock()
+ defer f.lock.Unlock()
+ var firstItem uint32
+ // store new data into queue, remove oldest data if queue is full
+ f.queue.Enqueue(measurement)
+ if f.queue.Len() > dataPoints {
+ // data type should always be uint32
+ firstItem = f.queue.Dequeue().(uint32)
+ }
+ // if (measurement - firstItem) < 0, uint64(measurement - firstItem)
+ // will overflow and become a large positive number
+ f.sum += uint64(measurement)
+ f.sum -= uint64(firstItem)
+ if measurement > f.max {
+ f.max = measurement
+ }
+ if f.min == 0 || measurement < f.min {
+ f.min = measurement
+ }
+}
+
+// caller of ave() should acquire lock first
+func (f *flowControlData) ave() float64 {
+ if f.queue.Len() == 0 {
+ return 0
+ }
+ return float64(f.sum) / float64(f.queue.Len())
+}
+
+func (f *flowControlData) metrics() (ave float64, min, max uint32) {
+ f.lock.RLock()
+ defer f.lock.RUnlock()
+ return f.ave(), f.min, f.max
+}
+
+func newRate() *rate {
+ return &rate{}
+}
+
+func (r *rate) update(measurement uint64) {
+ r.lock.Lock()
+ defer r.lock.Unlock()
+ r.curr = measurement
+ // if measurement is 0, then there is no incoming/outgoing connection, don't update min/max
+ if r.curr == 0 {
+ return
+ }
+ if measurement > r.max {
+ r.max = measurement
+ }
+ if r.min == 0 || measurement < r.min {
+ r.min = measurement
+ }
+}
+
+func (r *rate) get() (curr, min, max uint64) {
+ r.lock.RLock()
+ defer r.lock.RUnlock()
+ return r.curr, r.min, r.max
+}
diff --git a/h2mux/muxmetrics_test.go b/h2mux/muxmetrics_test.go
new file mode 100644
index 00000000..7f4cddf2
--- /dev/null
+++ b/h2mux/muxmetrics_test.go
@@ -0,0 +1,176 @@
+package h2mux
+
+import (
+ "sync"
+ "testing"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
+)
+
+func ave(sum uint64, len int) float64 {
+ return float64(sum) / float64(len)
+}
+
+func TestRTTUpdate(t *testing.T) {
+ r := newRTTData()
+ start := time.Now()
+ // send at 0 ms, receive at 2 ms, RTT = 2ms
+ m := &roundTripMeasurement{receiveTime: start.Add(2 * time.Millisecond), sendTime: start}
+ r.update(m)
+ assert.Equal(t, start, r.lastMeasurementTime)
+ assert.Equal(t, 2*time.Millisecond, r.rtt)
+ assert.Equal(t, 2*time.Millisecond, r.rttMin)
+ assert.Equal(t, 2*time.Millisecond, r.rttMax)
+
+ // send at 3 ms, receive at 6 ms, RTT = 3ms
+ m = &roundTripMeasurement{receiveTime: start.Add(6 * time.Millisecond), sendTime: start.Add(3 * time.Millisecond)}
+ r.update(m)
+ assert.Equal(t, start.Add(3*time.Millisecond), r.lastMeasurementTime)
+ assert.Equal(t, 3*time.Millisecond, r.rtt)
+ assert.Equal(t, 2*time.Millisecond, r.rttMin)
+ assert.Equal(t, 3*time.Millisecond, r.rttMax)
+
+ // send at 7 ms, receive at 8 ms, RTT = 1ms
+ m = &roundTripMeasurement{receiveTime: start.Add(8 * time.Millisecond), sendTime: start.Add(7 * time.Millisecond)}
+ r.update(m)
+ assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime)
+ assert.Equal(t, 1*time.Millisecond, r.rtt)
+ assert.Equal(t, 1*time.Millisecond, r.rttMin)
+ assert.Equal(t, 3*time.Millisecond, r.rttMax)
+
+ // send at -4 ms, receive at 0 ms, RTT = 4ms, but this ping is before last measurement
+ // so it will be discarded
+ m = &roundTripMeasurement{receiveTime: start, sendTime: start.Add(-2 * time.Millisecond)}
+ r.update(m)
+ assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime)
+ assert.Equal(t, 1*time.Millisecond, r.rtt)
+ assert.Equal(t, 1*time.Millisecond, r.rttMin)
+ assert.Equal(t, 3*time.Millisecond, r.rttMax)
+}
+
+func TestFlowControlDataUpdate(t *testing.T) {
+ f := newFlowControlData()
+ assert.Equal(t, 0, f.queue.Len())
+ assert.Equal(t, float64(0), f.ave())
+
+ var sum uint64
+ min := maxWindowSize - dataPoints
+ max := maxWindowSize
+ for i := 1; i <= dataPoints; i++ {
+ size := maxWindowSize - uint32(i)
+ f.update(size)
+ assert.Equal(t, max - uint32(1), f.max)
+ assert.Equal(t, size, f.min)
+
+ assert.Equal(t, i, f.queue.Len())
+
+ sum += uint64(size)
+ assert.Equal(t, sum, f.sum)
+ assert.Equal(t, ave(sum, f.queue.Len()), f.ave())
+ }
+
+ // queue is full, should start to dequeue first element
+ for i := 1; i <= dataPoints; i++ {
+ f.update(max)
+ assert.Equal(t, max, f.max)
+ assert.Equal(t, min, f.min)
+
+ assert.Equal(t, dataPoints, f.queue.Len())
+
+ sum += uint64(i)
+ assert.Equal(t, sum, f.sum)
+ assert.Equal(t, ave(sum, dataPoints), f.ave())
+ }
+}
+
+func TestMuxMetricsUpdater(t *testing.T) {
+ updateRTTChan := make(chan *roundTripMeasurement)
+ updateReceiveWindowChan := make(chan uint32)
+ updateSendWindowChan := make(chan uint32)
+ updateInBoundBytesChan := make(chan uint64)
+ updateOutBoundBytesChan := make(chan uint64)
+ abortChan := make(chan struct{})
+ errChan := make(chan error)
+ m := newMuxMetricsUpdater(updateRTTChan,
+ updateReceiveWindowChan,
+ updateSendWindowChan,
+ updateInBoundBytesChan,
+ updateOutBoundBytesChan,
+ abortChan,
+ )
+ logger := log.NewEntry(log.New())
+
+ go func() {
+ errChan <- m.run(logger)
+ }()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ // mock muxReader
+ readerStart := time.Now()
+ rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart}
+ updateRTTChan <- rm
+ go func() {
+ defer wg.Done()
+ // Becareful if dataPoints is not divisibile by 4
+ readerSend := readerStart.Add(time.Millisecond)
+ for i := 1; i <= dataPoints/4; i++ {
+ readerReceive := readerSend.Add(time.Duration(i) * time.Millisecond)
+ rm := &roundTripMeasurement{receiveTime: readerReceive, sendTime: readerSend}
+ updateRTTChan <- rm
+ readerSend = readerReceive.Add(time.Millisecond)
+
+ updateReceiveWindowChan <- uint32(i)
+ updateSendWindowChan <- uint32(i)
+
+ updateInBoundBytesChan <- uint64(i)
+ }
+ }()
+
+ // mock muxWriter
+ go func() {
+ defer wg.Done()
+ for j := dataPoints/4 + 1; j <= dataPoints/2; j++ {
+ updateReceiveWindowChan <- uint32(j)
+ updateSendWindowChan <- uint32(j)
+
+ // should always be disgard since the send time is before readerSend
+ rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart.Add(-time.Duration(j*dataPoints) * time.Millisecond)}
+ updateRTTChan <- rm
+
+ updateOutBoundBytesChan <- uint64(j)
+ }
+
+ }()
+ wg.Wait()
+
+ metrics := m.Metrics()
+ points := dataPoints / 2
+ assert.Equal(t, time.Millisecond, metrics.RTTMin)
+ assert.Equal(t, time.Duration(dataPoints/4)*time.Millisecond, metrics.RTTMax)
+
+ // sum(1..i) = i*(i+1)/2, ave(1..i) = i*(i+1)/2/i = (i+1)/2
+ assert.Equal(t, float64(points+1)/float64(2), metrics.ReceiveWindowAve)
+ assert.Equal(t, uint32(1), metrics.ReceiveWindowMin)
+ assert.Equal(t, uint32(points), metrics.ReceiveWindowMax)
+
+ assert.Equal(t, float64(points+1)/float64(2), metrics.SendWindowAve)
+ assert.Equal(t, uint32(1), metrics.SendWindowMin)
+ assert.Equal(t, uint32(points), metrics.SendWindowMax)
+
+ assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateCurr)
+ assert.Equal(t, uint64(1), metrics.InBoundRateMin)
+ assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateMax)
+
+ assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateCurr)
+ assert.Equal(t, uint64(dataPoints/4+1), metrics.OutBoundRateMin)
+ assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateMax)
+
+ close(abortChan)
+ assert.Nil(t, <-errChan)
+ close(errChan)
+
+}
diff --git a/h2mux/muxreader.go b/h2mux/muxreader.go
index a0b56366..3530dab8 100644
--- a/h2mux/muxreader.go
+++ b/h2mux/muxreader.go
@@ -3,7 +3,6 @@ package h2mux
import (
"encoding/binary"
"io"
- "sync"
"time"
log "github.com/sirupsen/logrus"
@@ -34,14 +33,18 @@ type MuxReader struct {
initialStreamWindow uint32
// The max value for the send window of a stream.
streamWindowMax uint32
- // windowMetrics keeps track of min/max/average of send/receive windows for all streams
- flowControlMetrics *FlowControlMetrics
- metricsMutex sync.Mutex
// r is a reference to the underlying connection used when shutting down.
r io.Closer
- // rttMeasurement measures RTT based on ping timestamps.
- rttMeasurement RTTMeasurement
- rttMutex sync.Mutex
+ // updateRTTChan is the channel to send new RTT measurement to muxerMetricsUpdater
+ updateRTTChan chan<- *roundTripMeasurement
+ // updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
+ updateReceiveWindowChan chan<- uint32
+ // updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
+ updateSendWindowChan chan<- uint32
+ // bytesRead is the amount of bytes read from data frame since the last time we send bytes read to metrics
+ bytesRead AtomicCounter
+ // updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
+ updateInBoundBytesChan chan<- uint64
}
func (r *MuxReader) Shutdown() {
@@ -57,28 +60,26 @@ func (r *MuxReader) Shutdown() {
}()
}
-func (r *MuxReader) RTT() RTTMeasurement {
- r.rttMutex.Lock()
- defer r.rttMutex.Unlock()
- return r.rttMeasurement
-}
-
-func (r *MuxReader) FlowControlMetrics() *FlowControlMetrics {
- r.metricsMutex.Lock()
- defer r.metricsMutex.Unlock()
- if r.flowControlMetrics != nil {
- return r.flowControlMetrics
- }
- // No metrics available yet
- return &FlowControlMetrics{}
-}
-
func (r *MuxReader) run(parentLogger *log.Entry) error {
logger := parentLogger.WithFields(log.Fields{
"subsystem": "mux",
"dir": "read",
})
defer logger.Debug("event loop finished")
+
+ // routine to periodically update bytesRead
+ go func() {
+ tickC := time.Tick(updateFreq)
+ for {
+ select {
+ case <-r.abortChan:
+ return
+ case <-tickC:
+ r.updateInBoundBytesChan <- r.bytesRead.Count()
+ }
+ }
+ }()
+
for {
frame, err := r.f.ReadFrame()
if err != nil {
@@ -120,6 +121,8 @@ func (r *MuxReader) run(parentLogger *log.Entry) error {
r.receivePingData(f)
case *http2.GoAwayFrame:
err = r.receiveGoAway(f)
+ // The receiver of a flow-controlled frame sends a WINDOW_UPDATE frame as it
+ // consumes data and frees up space in flow-control windows
case *http2.WindowUpdateFrame:
err = r.updateStreamWindow(f)
default:
@@ -236,10 +239,11 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.E
}
data := frame.Data()
if len(data) > 0 {
- _, err = stream.readBuffer.Write(data)
+ n, err := stream.readBuffer.Write(data)
if err != nil {
return r.streamError(stream.streamID, http2.ErrCodeInternal)
}
+ r.bytesRead.IncrementBy(uint64(n))
}
if frame.Header().Flags.Has(http2.FlagDataEndStream) {
if stream.receiveEOF() {
@@ -253,6 +257,7 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.E
if !stream.consumeReceiveWindow(uint32(len(data))) {
return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
}
+ r.updateReceiveWindowChan <- stream.getReceiveWindow()
return nil
}
@@ -263,10 +268,14 @@ func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
r.pingTimestamp.Set(ts)
return
}
- r.rttMutex.Lock()
- r.rttMeasurement.Update(time.Unix(0, ts))
- r.rttMutex.Unlock()
- r.flowControlMetrics = r.streams.Metrics()
+
+ // Update updates the computed values with a new measurement.
+ // outgoingTime is the time that the probe was sent.
+ // We assume that time.Now() is the time we received that probe.
+ r.updateRTTChan <- &roundTripMeasurement{
+ receiveTime: time.Now(),
+ sendTime: time.Unix(0, ts),
+ }
}
// Receive a GOAWAY from the peer. Gracefully shut down our connection.
@@ -293,6 +302,7 @@ func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
return nil
}
stream.replenishSendWindow(frame.Increment)
+ r.updateSendWindowChan <- stream.getSendWindow()
return nil
}
diff --git a/h2mux/muxwriter.go b/h2mux/muxwriter.go
index 1cfbf1f3..8ea4c675 100644
--- a/h2mux/muxwriter.go
+++ b/h2mux/muxwriter.go
@@ -40,6 +40,14 @@ type MuxWriter struct {
headerEncoder *hpack.Encoder
// headerBuffer is the temporary buffer used by headerEncoder.
headerBuffer bytes.Buffer
+ // updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
+ updateReceiveWindowChan chan<- uint32
+ // updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
+ updateSendWindowChan chan<- uint32
+ // bytesWrote is the amount of bytes wrote to data frame since the last time we send bytes wrote to metrics
+ bytesWrote AtomicCounter
+ // updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
+ updateOutBoundBytesChan chan<- uint64
}
type MuxedStreamRequest struct {
@@ -64,6 +72,20 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error {
"dir": "write",
})
defer logger.Debug("event loop finished")
+
+ // routine to periodically communicate bytesWrote
+ go func() {
+ tickC := time.Tick(updateFreq)
+ for {
+ select {
+ case <-w.abortChan:
+ return
+ case <-tickC:
+ w.updateOutBoundBytesChan <- w.bytesWrote.Count()
+ }
+ }
+ }()
+
for {
select {
case <-w.abortChan:
@@ -141,7 +163,8 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error {
func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
logger.Debug("writable")
chunk := stream.getChunk()
-
+ w.updateReceiveWindowChan <- stream.getReceiveWindow()
+ w.updateSendWindowChan <- stream.getSendWindow()
if chunk.sendHeadersFrame() {
err := w.writeHeaders(chunk.streamID, chunk.headers)
if err != nil {
@@ -154,7 +177,9 @@ func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) erro
if chunk.sendWindowUpdateFrame() {
// Send a WINDOW_UPDATE frame to update our receive window.
// If the Stream ID is zero, the window update applies to the connection as a whole
- // A WINDOW_UPDATE in a specific stream applies to the connection-level flow control as well.
+ // RFC7540 section-6.9.1 "A receiver that receives a flow-controlled frame MUST
+ // always account for its contribution against the connection flow-control
+ // window, unless the receiver treats this as a connection error"
err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
if err != nil {
logger.WithError(err).Warn("error writing window update")
@@ -170,6 +195,8 @@ func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) erro
logger.WithError(err).Warn("error writing data")
return err
}
+ // update the amount of data wrote
+ w.bytesWrote.IncrementBy(uint64(len(payload)))
logger.WithField("len", len(payload)).Debug("output data")
if sentEOF {
@@ -214,19 +241,16 @@ func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
return err
}
blockSize := int(w.maxFrameSize)
- continuation := false
endHeaders := len(encodedHeaders) == 0
for !endHeaders && err == nil {
blockFragment := encodedHeaders
if len(encodedHeaders) > blockSize {
blockFragment = blockFragment[:blockSize]
encodedHeaders = encodedHeaders[blockSize:]
- } else {
- endHeaders = true
- }
- if continuation {
+ // Send CONTINUATION frame if the headers can't be fit into 1 frame
err = w.f.WriteContinuation(streamID, endHeaders, blockFragment)
} else {
+ endHeaders = true
err = w.f.WriteHeaders(http2.HeadersFrameParam{
StreamID: streamID,
EndHeaders: endHeaders,
diff --git a/h2mux/rtt.go b/h2mux/rtt.go
index 1c42ff82..350233e3 100644
--- a/h2mux/rtt.go
+++ b/h2mux/rtt.go
@@ -2,7 +2,6 @@ package h2mux
import (
"sync/atomic"
- "time"
)
// PingTimestamp is an atomic interface around ping timestamping and signalling.
@@ -28,26 +27,3 @@ func (pt *PingTimestamp) Get() int64 {
func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} {
return pt.signal.WaitChannel()
}
-
-// RTTMeasurement encapsulates a continuous round trip time measurement.
-type RTTMeasurement struct {
- Current, Min, Max time.Duration
- lastMeasurementTime time.Time
-}
-
-// Update updates the computed values with a new measurement.
-// outgoingTime is the time that the probe was sent.
-// We assume that time.Now() is the time we received that probe.
-func (r *RTTMeasurement) Update(outgoingTime time.Time) {
- if !r.lastMeasurementTime.Before(outgoingTime) {
- return
- }
- r.lastMeasurementTime = outgoingTime
- r.Current = time.Since(outgoingTime)
- if r.Max < r.Current {
- r.Max = r.Current
- }
- if r.Min > r.Current {
- r.Min = r.Current
- }
-}
diff --git a/origin/metrics.go b/origin/metrics.go
index 314d4060..9908c554 100644
--- a/origin/metrics.go
+++ b/origin/metrics.go
@@ -2,13 +2,32 @@ package origin
import (
"sync"
+ "time"
- "github.com/cloudflare/cloudflare-warp/h2mux"
+ "github.com/cloudflare/cloudflared/h2mux"
"github.com/prometheus/client_golang/prometheus"
)
-type TunnelMetrics struct {
+type muxerMetrics struct {
+ rtt *prometheus.GaugeVec
+ rttMin *prometheus.GaugeVec
+ rttMax *prometheus.GaugeVec
+ receiveWindowAve *prometheus.GaugeVec
+ sendWindowAve *prometheus.GaugeVec
+ receiveWindowMin *prometheus.GaugeVec
+ receiveWindowMax *prometheus.GaugeVec
+ sendWindowMin *prometheus.GaugeVec
+ sendWindowMax *prometheus.GaugeVec
+ inBoundRateCurr *prometheus.GaugeVec
+ inBoundRateMin *prometheus.GaugeVec
+ inBoundRateMax *prometheus.GaugeVec
+ outBoundRateCurr *prometheus.GaugeVec
+ outBoundRateMin *prometheus.GaugeVec
+ outBoundRateMax *prometheus.GaugeVec
+}
+
+type tunnelMetrics struct {
haConnections prometheus.Gauge
totalRequests prometheus.Counter
requestsPerTunnel *prometheus.CounterVec
@@ -20,16 +39,7 @@ type TunnelMetrics struct {
maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
// concurrentRequests records max count of concurrent requests for each tunnel
maxConcurrentRequests map[string]uint64
- rtt prometheus.Gauge
- rttMin prometheus.Gauge
- rttMax prometheus.Gauge
timerRetries prometheus.Gauge
- receiveWindowSizeAve prometheus.Gauge
- sendWindowSizeAve prometheus.Gauge
- receiveWindowSizeMin prometheus.Gauge
- receiveWindowSizeMax prometheus.Gauge
- sendWindowSizeMin prometheus.Gauge
- sendWindowSizeMax prometheus.Gauge
responseByCode *prometheus.CounterVec
responseCodePerTunnel *prometheus.CounterVec
serverLocations *prometheus.GaugeVec
@@ -37,10 +47,189 @@ type TunnelMetrics struct {
locationLock sync.Mutex
// oldServerLocations stores the last server the tunnel was connected to
oldServerLocations map[string]string
+
+ muxerMetrics *muxerMetrics
+}
+
+func newMuxerMetrics() *muxerMetrics {
+ rtt := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "rtt",
+ Help: "Round-trip time in millisecond",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(rtt)
+
+ rttMin := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "rtt_min",
+ Help: "Shortest round-trip time in millisecond",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(rttMin)
+
+ rttMax := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "rtt_max",
+ Help: "Longest round-trip time in millisecond",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(rttMax)
+
+ receiveWindowAve := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "receive_window_ave",
+ Help: "Average receive window size in bytes",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(receiveWindowAve)
+
+ sendWindowAve := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "send_window_ave",
+ Help: "Average send window size in bytes",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(sendWindowAve)
+
+ receiveWindowMin := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "receive_window_min",
+ Help: "Smallest receive window size in bytes",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(receiveWindowMin)
+
+ receiveWindowMax := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "receive_window_max",
+ Help: "Largest receive window size in bytes",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(receiveWindowMax)
+
+ sendWindowMin := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "send_window_min",
+ Help: "Smallest send window size in bytes",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(sendWindowMin)
+
+ sendWindowMax := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "send_window_max",
+ Help: "Largest send window size in bytes",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(sendWindowMax)
+
+ inBoundRateCurr := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "inbound_bytes_per_sec_curr",
+ Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(inBoundRateCurr)
+
+ inBoundRateMin := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "inbound_bytes_per_sec_min",
+ Help: "Minimum non-zero inbounding bytes per second",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(inBoundRateMin)
+
+ inBoundRateMax := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "inbound_bytes_per_sec_max",
+ Help: "Maximum inbounding bytes per second",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(inBoundRateMax)
+
+ outBoundRateCurr := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "outbound_bytes_per_sec_curr",
+ Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(outBoundRateCurr)
+
+ outBoundRateMin := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "outbound_bytes_per_sec_min",
+ Help: "Minimum non-zero outbounding bytes per second",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(outBoundRateMin)
+
+ outBoundRateMax := prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Name: "outbound_bytes_per_sec_max",
+ Help: "Maximum outbounding bytes per second",
+ },
+ []string{"connection_id"},
+ )
+ prometheus.MustRegister(outBoundRateMax)
+
+ return &muxerMetrics{
+ rtt: rtt,
+ rttMin: rttMin,
+ rttMax: rttMax,
+ receiveWindowAve: receiveWindowAve,
+ sendWindowAve: sendWindowAve,
+ receiveWindowMin: receiveWindowMin,
+ receiveWindowMax: receiveWindowMax,
+ sendWindowMin: sendWindowMin,
+ sendWindowMax: sendWindowMax,
+ inBoundRateCurr: inBoundRateCurr,
+ inBoundRateMin: inBoundRateMin,
+ inBoundRateMax: inBoundRateMax,
+ outBoundRateCurr: outBoundRateCurr,
+ outBoundRateMin: outBoundRateMin,
+ outBoundRateMax: outBoundRateMax,
+ }
+}
+
+func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
+ m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
+ m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
+ m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
+ m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
+ m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
+ m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
+ m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
+ m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
+ m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
+ m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
+ m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
+ m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
+ m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
+ m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
+ m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
+}
+
+func convertRTTMilliSec(t time.Duration) float64 {
+ return float64(t / time.Millisecond)
}
// Metrics that can be collected without asking the edge
-func NewTunnelMetrics() *TunnelMetrics {
+func NewTunnelMetrics() *tunnelMetrics {
haConnections := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "ha_connections",
@@ -82,27 +271,6 @@ func NewTunnelMetrics() *TunnelMetrics {
)
prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
- rtt := prometheus.NewGauge(
- prometheus.GaugeOpts{
- Name: "rtt",
- Help: "Round-trip time",
- })
- prometheus.MustRegister(rtt)
-
- rttMin := prometheus.NewGauge(
- prometheus.GaugeOpts{
- Name: "rtt_min",
- Help: "Shortest round-trip time",
- })
- prometheus.MustRegister(rttMin)
-
- rttMax := prometheus.NewGauge(
- prometheus.GaugeOpts{
- Name: "rtt_max",
- Help: "Longest round-trip time",
- })
- prometheus.MustRegister(rttMax)
-
timerRetries := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "timer_retries",
@@ -110,48 +278,6 @@ func NewTunnelMetrics() *TunnelMetrics {
})
prometheus.MustRegister(timerRetries)
- receiveWindowSizeAve := prometheus.NewGauge(
- prometheus.GaugeOpts{
- Name: "receive_window_ave",
- Help: "Average receive window size",
- })
- prometheus.MustRegister(receiveWindowSizeAve)
-
- sendWindowSizeAve := prometheus.NewGauge(
- prometheus.GaugeOpts{
- Name: "send_window_ave",
- Help: "Average send window size",
- })
- prometheus.MustRegister(sendWindowSizeAve)
-
- receiveWindowSizeMin := prometheus.NewGauge(
- prometheus.GaugeOpts{
- Name: "receive_window_min",
- Help: "Smallest receive window size",
- })
- prometheus.MustRegister(receiveWindowSizeMin)
-
- receiveWindowSizeMax := prometheus.NewGauge(
- prometheus.GaugeOpts{
- Name: "receive_window_max",
- Help: "Largest receive window size",
- })
- prometheus.MustRegister(receiveWindowSizeMax)
-
- sendWindowSizeMin := prometheus.NewGauge(
- prometheus.GaugeOpts{
- Name: "send_window_min",
- Help: "Smallest send window size",
- })
- prometheus.MustRegister(sendWindowSizeMin)
-
- sendWindowSizeMax := prometheus.NewGauge(
- prometheus.GaugeOpts{
- Name: "send_window_max",
- Help: "Largest send window size",
- })
- prometheus.MustRegister(sendWindowSizeMax)
-
responseByCode := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "response_by_code",
@@ -179,7 +305,7 @@ func NewTunnelMetrics() *TunnelMetrics {
)
prometheus.MustRegister(serverLocations)
- return &TunnelMetrics{
+ return &tunnelMetrics{
haConnections: haConnections,
totalRequests: totalRequests,
requestsPerTunnel: requestsPerTunnel,
@@ -187,41 +313,28 @@ func NewTunnelMetrics() *TunnelMetrics {
concurrentRequests: make(map[string]uint64),
maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
maxConcurrentRequests: make(map[string]uint64),
- rtt: rtt,
- rttMin: rttMin,
- rttMax: rttMax,
- timerRetries: timerRetries,
- receiveWindowSizeAve: receiveWindowSizeAve,
- sendWindowSizeAve: sendWindowSizeAve,
- receiveWindowSizeMin: receiveWindowSizeMin,
- receiveWindowSizeMax: receiveWindowSizeMax,
- sendWindowSizeMin: sendWindowSizeMin,
- sendWindowSizeMax: sendWindowSizeMax,
- responseByCode: responseByCode,
- responseCodePerTunnel: responseCodePerTunnel,
- serverLocations: serverLocations,
- oldServerLocations: make(map[string]string),
+ timerRetries: timerRetries,
+ responseByCode: responseByCode,
+ responseCodePerTunnel: responseCodePerTunnel,
+ serverLocations: serverLocations,
+ oldServerLocations: make(map[string]string),
+ muxerMetrics: newMuxerMetrics(),
}
}
-func (t *TunnelMetrics) incrementHaConnections() {
+func (t *tunnelMetrics) incrementHaConnections() {
t.haConnections.Inc()
}
-func (t *TunnelMetrics) decrementHaConnections() {
+func (t *tunnelMetrics) decrementHaConnections() {
t.haConnections.Dec()
}
-func (t *TunnelMetrics) updateTunnelFlowControlMetrics(metrics *h2mux.FlowControlMetrics) {
- t.receiveWindowSizeAve.Set(float64(metrics.AverageReceiveWindowSize))
- t.sendWindowSizeAve.Set(float64(metrics.AverageSendWindowSize))
- t.receiveWindowSizeMin.Set(float64(metrics.MinReceiveWindowSize))
- t.receiveWindowSizeMax.Set(float64(metrics.MaxReceiveWindowSize))
- t.sendWindowSizeMin.Set(float64(metrics.MinSendWindowSize))
- t.sendWindowSizeMax.Set(float64(metrics.MaxSendWindowSize))
+func (t *tunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
+ t.muxerMetrics.update(connectionID, metrics)
}
-func (t *TunnelMetrics) incrementRequests(connectionID string) {
+func (t *tunnelMetrics) incrementRequests(connectionID string) {
t.concurrentRequestsLock.Lock()
var concurrentRequests uint64
var ok bool
@@ -243,7 +356,7 @@ func (t *TunnelMetrics) incrementRequests(connectionID string) {
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
}
-func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
+func (t *tunnelMetrics) decrementConcurrentRequests(connectionID string) {
t.concurrentRequestsLock.Lock()
if _, ok := t.concurrentRequests[connectionID]; ok {
t.concurrentRequests[connectionID] -= 1
@@ -255,13 +368,13 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
}
-func (t *TunnelMetrics) incrementResponses(connectionID, code string) {
+func (t *tunnelMetrics) incrementResponses(connectionID, code string) {
t.responseByCode.WithLabelValues(code).Inc()
t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
}
-func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) {
+func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) {
t.locationLock.Lock()
defer t.locationLock.Unlock()
if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
diff --git a/origin/tunnel.go b/origin/tunnel.go
index dfeb3ee8..e75988b4 100644
--- a/origin/tunnel.go
+++ b/origin/tunnel.go
@@ -15,11 +15,11 @@ import (
"golang.org/x/net/context"
- "github.com/cloudflare/cloudflare-warp/h2mux"
- "github.com/cloudflare/cloudflare-warp/tunnelrpc"
- tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
- "github.com/cloudflare/cloudflare-warp/validation"
- "github.com/cloudflare/cloudflare-warp/websocket"
+ "github.com/cloudflare/cloudflared/h2mux"
+ "github.com/cloudflare/cloudflared/tunnelrpc"
+ tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
+ "github.com/cloudflare/cloudflared/validation"
+ "github.com/cloudflare/cloudflared/websocket"
raven "github.com/getsentry/raven-go"
"github.com/pkg/errors"
@@ -53,7 +53,7 @@ type TunnelConfig struct {
Tags []tunnelpogs.Tag
HAConnections int
HTTPTransport http.RoundTripper
- Metrics *TunnelMetrics
+ Metrics *tunnelMetrics
MetricsUpdateFreq time.Duration
ProtocolLogger *logrus.Logger
Logger *logrus.Logger
@@ -185,6 +185,7 @@ func ServeTunnel(
serveCtx, serveCancel := context.WithCancel(ctx)
registerErrC := make(chan error, 1)
go func() {
+ defer wg.Done()
err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP)
if err == nil {
connectedFuse.Fuse(true)
@@ -193,18 +194,18 @@ func ServeTunnel(
serveCancel()
}
registerErrC <- err
- wg.Done()
}()
updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
go func() {
defer wg.Done()
+ connectionTag := uint8ToString(connectionID)
for {
select {
case <-serveCtx.Done():
handler.muxer.Shutdown()
return
case <-updateMetricsTickC:
- handler.UpdateMetrics()
+ handler.UpdateMetrics(connectionTag)
}
}
}()
@@ -303,7 +304,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
func LogServerInfo(logger *logrus.Entry,
promise tunnelrpc.ServerInfo_Promise,
connectionID uint8,
- metrics *TunnelMetrics,
+ metrics *tunnelMetrics,
) {
serverInfoMessage, err := promise.Struct()
if err != nil {
@@ -356,13 +357,17 @@ func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
return
}
+func FindCfRayHeader(h1 *http.Request) string {
+ return h1.Header.Get("Cf-Ray")
+}
+
type TunnelHandler struct {
originUrl string
muxer *h2mux.Muxer
httpClient http.RoundTripper
tlsConfig *tls.Config
tags []tunnelpogs.Tag
- metrics *TunnelMetrics
+ metrics *tunnelMetrics
// connectionID is only used by metrics, and prometheus requires labels to be string
connectionID string
}
@@ -435,7 +440,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
Log.WithError(err).Error("invalid request received")
}
h.AppendTagHeaders(req)
-
+ cfRay := FindCfRayHeader(req)
+ h.logRequest(req, cfRay)
if websocket.IsWebSocketUpgrade(req) {
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
if err != nil {
@@ -444,6 +450,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
stream.WriteHeaders(H1ResponseToH2Response(response))
defer conn.Close()
websocket.Stream(conn.UnderlyingConn(), stream)
+ h.metrics.incrementResponses(h.connectionID, "200")
+ h.logResponse(response, cfRay)
}
} else {
response, err := h.httpClient.RoundTrip(req)
@@ -454,6 +462,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
stream.WriteHeaders(H1ResponseToH2Response(response))
io.Copy(stream, response.Body)
h.metrics.incrementResponses(h.connectionID, "200")
+ h.logResponse(response, cfRay)
}
}
h.metrics.decrementConcurrentRequests(h.connectionID)
@@ -467,9 +476,27 @@ func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
h.metrics.incrementResponses(h.connectionID, "502")
}
-func (h *TunnelHandler) UpdateMetrics() {
- flowCtlMetrics := h.muxer.FlowControlMetrics()
- h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics)
+func (h *TunnelHandler) logRequest(req *http.Request, cfRay string) {
+ if cfRay != "" {
+ Log.WithField("CF-RAY", cfRay).Infof("%s %s %s", req.Method, req.URL, req.Proto)
+ } else {
+ Log.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto)
+ }
+ Log.Debugf("Request Headers %+v", req.Header)
+}
+
+func (h *TunnelHandler) logResponse(r *http.Response, cfRay string) {
+ if cfRay != "" {
+ Log.WithField("CF-RAY", cfRay).Infof("%s", r.Status)
+ } else {
+ Log.Infof("%s", r.Status)
+ }
+ Log.Debugf("Response Headers %+v", r.Header)
+}
+
+
+func (h *TunnelHandler) UpdateMetrics(connectionID string) {
+ h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics())
}
func uint8ToString(input uint8) string {
diff --git a/tlsconfig/hello_ca.go b/tlsconfig/hello_ca.go
index bb49093c..7ca15868 100644
--- a/tlsconfig/hello_ca.go
+++ b/tlsconfig/hello_ca.go
@@ -11,28 +11,28 @@ const (
BgUrgQQAIg==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
-MIGkAgEBBDAdyQBXfxTDCQSOT0HugmH9pVBtIw8t5dYvm6HxGlNq6P57v5GeN02Z
-dH9FRl7+VSWgBwYFK4EEACKhZANiAATqpFzTxxV7D+/oqhKCTR6BEM9elTfKaRQE
-FsLufcmaTMw/9tTwgpHKao/QsLKDTNbQhbSQLkcmpCQKlSGhl+pCrqNt/oYUAhav
-UIwpwGiLCqGH/R2AqWLKRPOa/Rufs/U=
+MIGkAgEBBDBGGfwhIJdiUiJUVIItqJjEIMmlXxsMa8TQeer47+g+cIZ466rgg8EK
++Mdn6BY48GCgBwYFK4EEACKhZANiAASW//A9iDbPKg3OLkn7yJqLer32g9I5lBKR
+tPc/zBubQLLz9lAaYI6AOQiJXhGr5JkKmQfi1sYHK5rJITPFy4W8Et4hHLdazDZH
+WnEd+TStQABFUjrhtqXPWmGKcly0pOE=
-----END EC PRIVATE KEY-----`
helloCRT = `
-----BEGIN CERTIFICATE-----
-MIICkDCCAhigAwIBAgIJAPtKfUjc2lwGMAkGByqGSM49BAEwgYoxCzAJBgNVBAYT
-AlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQKDBBD
-bG91ZGZsYXJlLCBJbmMuMT8wPQYDVQQDDDZDbG91ZGZsYXJlIEFyZ28gVHVubmVs
-IFNhbXBsZSBIZWxsbyBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMjE1MjAxNjU5
-WhcNMjgwMjEzMjAxNjU5WjCBijELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVRleGFz
-MQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENsb3VkZmxhcmUsIEluYy4xPzA9
-BgNVBAMMNkNsb3VkZmxhcmUgQXJnbyBUdW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZl
-ciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49AgEGBSuBBAAiA2IABOqkXNPHFXsP7+iq
-EoJNHoEQz16VN8ppFAQWwu59yZpMzD/21PCCkcpqj9CwsoNM1tCFtJAuRyakJAqV
-IaGX6kKuo23+hhQCFq9QjCnAaIsKoYf9HYCpYspE85r9G5+z9aNJMEcwRQYDVR0R
-BD4wPIIJbG9jYWxob3N0ggp3YXJwLWhlbGxvggt3YXJwMi1oZWxsb4cEfwAAAYcQ
-AAAAAAAAAAAAAAAAAAAAATAJBgcqhkjOPQQBA2cAMGQCMHyVPufXZ6vQo6XRWRa0
-dAwtfgesOdZVP2Wt+t5v8jOIQQh1IQXYk5GtyoZGSObjhQIwd1fRgAyKXaZt+1DV
-ZtHTdf8pMvESfJsSd8AB1eQ6q+pAiRUYyaxcE1Mlo2YY5o+g
+MIICiDCCAg6gAwIBAgIJAJ/FfkBTtbuIMAkGByqGSM49BAEwfzELMAkGA1UEBhMC
+VVMxDjAMBgNVBAgMBVRleGFzMQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENs
+b3VkZmxhcmUsIEluYy4xNDAyBgNVBAMMK0FyZ28gVHVubmVsIFNhbXBsZSBIZWxs
+byBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMzE5MjMwNTMyWhcNMjgwMzE2MjMw
+NTMyWjB/MQswCQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1
+c3RpbjEZMBcGA1UECgwQQ2xvdWRmbGFyZSwgSW5jLjE0MDIGA1UEAwwrQXJnbyBU
+dW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZlciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49
+AgEGBSuBBAAiA2IABJb/8D2INs8qDc4uSfvImot6vfaD0jmUEpG09z/MG5tAsvP2
+UBpgjoA5CIleEavkmQqZB+LWxgcrmskhM8XLhbwS3iEct1rMNkdacR35NK1AAEVS
+OuG2pc9aYYpyXLSk4aNXMFUwUwYDVR0RBEwwSoIJbG9jYWxob3N0ghFjbG91ZGZs
+YXJlZC1oZWxsb4ISY2xvdWRmbGFyZWQyLWhlbGxvhwR/AAABhxAAAAAAAAAAAAAA
+AAAAAAABMAkGByqGSM49BAEDaQAwZgIxAPxkdghH6y8xLMnY9Bom3Llf4NYM6yB9
+PD1YsaNUJTsxjTk3YY1Jsp+yzK0yUKtTZwIxAPcdvqCF2/iR9H288pCT1TgtO0a9
+cJL9RY1lq7DIGN37v1ZXReWaD+3hNokY8NriVg==
-----END CERTIFICATE-----`
)
diff --git a/tunneldns/https_proxy.go b/tunneldns/https_proxy.go
new file mode 100644
index 00000000..3f6140b8
--- /dev/null
+++ b/tunneldns/https_proxy.go
@@ -0,0 +1,38 @@
+package tunneldns
+
+import (
+ "github.com/coredns/coredns/plugin"
+ "github.com/miekg/dns"
+ "github.com/pkg/errors"
+ "golang.org/x/net/context"
+)
+
+// Upstream is a simplified interface for proxy destination
+type Upstream interface {
+ Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
+}
+
+// ProxyPlugin is a simplified DNS proxy using a generic upstream interface
+type ProxyPlugin struct {
+ Upstreams []Upstream
+ Next plugin.Handler
+}
+
+// ServeDNS implements interface for CoreDNS plugin
+func (p ProxyPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ var reply *dns.Msg
+ var backendErr error
+
+ for _, upstream := range p.Upstreams {
+ reply, backendErr = upstream.Exchange(ctx, r)
+ if backendErr == nil {
+ w.WriteMsg(reply)
+ return 0, nil
+ }
+ }
+
+ return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
+}
+
+// Name implements interface for CoreDNS plugin
+func (p ProxyPlugin) Name() string { return "proxy" }
diff --git a/tunneldns/https_upstream.go b/tunneldns/https_upstream.go
new file mode 100644
index 00000000..eee448c4
--- /dev/null
+++ b/tunneldns/https_upstream.go
@@ -0,0 +1,97 @@
+package tunneldns
+
+import (
+ "bytes"
+ "crypto/tls"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "time"
+
+ "github.com/miekg/dns"
+ "github.com/pkg/errors"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/net/context"
+)
+
+const (
+ defaultTimeout = 5 * time.Second
+)
+
+// UpstreamHTTPS is the upstream implementation for DNS over HTTPS service
+type UpstreamHTTPS struct {
+ client *http.Client
+ endpoint *url.URL
+}
+
+// NewUpstreamHTTPS creates a new DNS over HTTPS upstream from hostname
+func NewUpstreamHTTPS(endpoint string) (Upstream, error) {
+ u, err := url.Parse(endpoint)
+ if err != nil {
+ return nil, err
+ }
+
+ // Update TLS and HTTP client configuration
+ tls := &tls.Config{ServerName: u.Hostname()}
+ client := &http.Client{
+ Timeout: time.Second * defaultTimeout,
+ Transport: &http.Transport{TLSClientConfig: tls},
+ }
+
+ return &UpstreamHTTPS{client: client, endpoint: u}, nil
+}
+
+// Exchange provides an implementation for the Upstream interface
+func (u *UpstreamHTTPS) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
+ queryBuf, err := query.Pack()
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to pack DNS query")
+ }
+
+ // No content negotiation for now, use DNS wire format
+ buf, backendErr := u.exchangeWireformat(queryBuf)
+ if backendErr == nil {
+ response := &dns.Msg{}
+ if err := response.Unpack(buf); err != nil {
+ return nil, errors.Wrap(err, "failed to unpack DNS response from body")
+ }
+
+ response.Id = query.Id
+ return response, nil
+ }
+
+ log.WithError(backendErr).Errorf("failed to connect to an HTTPS backend %q", u.endpoint)
+ return nil, backendErr
+}
+
+// Perform message exchange with the default UDP wireformat defined in current draft
+// https://datatracker.ietf.org/doc/draft-ietf-doh-dns-over-https
+func (u *UpstreamHTTPS) exchangeWireformat(msg []byte) ([]byte, error) {
+ req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to create an HTTPS request")
+ }
+
+ req.Header.Add("Content-Type", "application/dns-udpwireformat")
+ req.Host = u.endpoint.Hostname()
+
+ resp, err := u.client.Do(req)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to perform an HTTPS request")
+ }
+
+ // Check response status code
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
+ }
+
+ // Read wireformat response from the body
+ buf, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to read the response body")
+ }
+
+ return buf, nil
+}
diff --git a/tunneldns/metrics.go b/tunneldns/metrics.go
new file mode 100644
index 00000000..5f688186
--- /dev/null
+++ b/tunneldns/metrics.go
@@ -0,0 +1,45 @@
+package tunneldns
+
+import (
+ "github.com/coredns/coredns/plugin"
+ "github.com/coredns/coredns/plugin/metrics/vars"
+ "github.com/coredns/coredns/plugin/pkg/dnstest"
+ "github.com/coredns/coredns/plugin/pkg/rcode"
+ "github.com/coredns/coredns/request"
+ "github.com/miekg/dns"
+ "github.com/prometheus/client_golang/prometheus"
+ "golang.org/x/net/context"
+)
+
+// MetricsPlugin is an adapter for CoreDNS and built-in metrics
+type MetricsPlugin struct {
+ Next plugin.Handler
+}
+
+// NewMetricsPlugin creates a plugin with configured metrics
+func NewMetricsPlugin(next plugin.Handler) *MetricsPlugin {
+ prometheus.MustRegister(vars.RequestCount)
+ prometheus.MustRegister(vars.RequestDuration)
+ prometheus.MustRegister(vars.RequestSize)
+ prometheus.MustRegister(vars.RequestDo)
+ prometheus.MustRegister(vars.RequestType)
+ prometheus.MustRegister(vars.ResponseSize)
+ prometheus.MustRegister(vars.ResponseRcode)
+ return &MetricsPlugin{Next: next}
+}
+
+// ServeDNS implements the CoreDNS plugin interface
+func (p MetricsPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ state := request.Request{W: w, Req: r}
+
+ rw := dnstest.NewRecorder(w)
+ status, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, rw, r)
+
+ // Update built-in metrics
+ vars.Report(state, ".", rcode.ToString(rw.Rcode), rw.Len, rw.Start)
+
+ return status, err
+}
+
+// Name implements the CoreDNS plugin interface
+func (p MetricsPlugin) Name() string { return "metrics" }
diff --git a/tunneldns/tunnel.go b/tunneldns/tunnel.go
new file mode 100644
index 00000000..f4b5eed7
--- /dev/null
+++ b/tunneldns/tunnel.go
@@ -0,0 +1,144 @@
+package tunneldns
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "os/signal"
+ "strconv"
+ "sync"
+ "syscall"
+
+ "gopkg.in/urfave/cli.v2"
+
+ "github.com/cloudflare/cloudflared/metrics"
+ "github.com/coredns/coredns/core/dnsserver"
+ "github.com/coredns/coredns/plugin"
+ "github.com/coredns/coredns/plugin/cache"
+ "github.com/pkg/errors"
+ log "github.com/sirupsen/logrus"
+)
+
+// Listener is an adapter between CoreDNS server and Warp runnable
+type Listener struct {
+ server *dnsserver.Server
+ wg sync.WaitGroup
+}
+
+// Run implements a foreground runner
+func Run(c *cli.Context) error {
+ metricsListener, err := net.Listen("tcp", c.String("metrics"))
+ if err != nil {
+ log.WithError(err).Fatal("Failed to open the metrics listener")
+ }
+
+ go metrics.ServeMetrics(metricsListener, nil)
+
+ listener, err := CreateListener(c.String("address"), uint16(c.Uint("port")), c.StringSlice("upstream"))
+ if err != nil {
+ log.WithError(err).Errorf("Failed to create the listeners")
+ return err
+ }
+
+ // Try to start the server
+ err = listener.Start()
+ if err != nil {
+ log.WithError(err).Errorf("Failed to start the listeners")
+ return listener.Stop()
+ }
+
+ // Wait for signal
+ signals := make(chan os.Signal, 10)
+ signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
+ defer signal.Stop(signals)
+ <-signals
+
+ // Shut down server
+ err = listener.Stop()
+ if err != nil {
+ log.WithError(err).Errorf("failed to stop")
+ }
+ return err
+}
+
+// Create a CoreDNS server plugin from configuration
+func createConfig(address string, port uint16, p plugin.Handler) *dnsserver.Config {
+ c := &dnsserver.Config{
+ Zone: ".",
+ Transport: "dns",
+ ListenHosts: []string{address},
+ Port: strconv.FormatUint(uint64(port), 10),
+ }
+
+ c.AddPlugin(func(next plugin.Handler) plugin.Handler { return p })
+ return c
+}
+
+// Start blocks for serving requests
+func (l *Listener) Start() error {
+ log.WithField("addr", l.server.Address()).Infof("Starting DNS over HTTPS proxy server")
+
+ // Start UDP listener
+ if udp, err := l.server.ListenPacket(); err == nil {
+ l.wg.Add(1)
+ go func() {
+ l.server.ServePacket(udp)
+ l.wg.Done()
+ }()
+ } else {
+ return errors.Wrap(err, "failed to create a UDP listener")
+ }
+
+ // Start TCP listener
+ tcp, err := l.server.Listen()
+ if err == nil {
+ l.wg.Add(1)
+ go func() {
+ l.server.Serve(tcp)
+ l.wg.Done()
+ }()
+ }
+
+ return errors.Wrap(err, "failed to create a TCP listener")
+}
+
+// Stop signals server shutdown and blocks until completed
+func (l *Listener) Stop() error {
+ if err := l.server.Stop(); err != nil {
+ return err
+ }
+
+ l.wg.Wait()
+ return nil
+}
+
+// CreateListener configures the server and bound sockets
+func CreateListener(address string, port uint16, upstreams []string) (*Listener, error) {
+ // Build the list of upstreams
+ upstreamList := make([]Upstream, 0)
+ for _, url := range upstreams {
+ log.WithField("url", url).Infof("Adding DNS upstream")
+ upstream, err := NewUpstreamHTTPS(url)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to create HTTPS upstream")
+ }
+ upstreamList = append(upstreamList, upstream)
+ }
+
+ // Create a local cache with HTTPS proxy plugin
+ chain := cache.New()
+ chain.Next = ProxyPlugin{
+ Upstreams: upstreamList,
+ }
+
+ // Format an endpoint
+ endpoint := fmt.Sprintf("dns://%s:%d", address, port)
+
+ // Create the actual middleware server
+ server, err := dnsserver.NewServer(endpoint, []*dnsserver.Config{createConfig(address, port, NewMetricsPlugin(chain))})
+ if err != nil {
+ return nil, err
+ }
+
+ return &Listener{server: server}, nil
+}
diff --git a/tunnelrpc/pogs/tunnelrpc.go b/tunnelrpc/pogs/tunnelrpc.go
index 4f5280e4..f70545b3 100644
--- a/tunnelrpc/pogs/tunnelrpc.go
+++ b/tunnelrpc/pogs/tunnelrpc.go
@@ -1,7 +1,7 @@
package pogs
import (
- "github.com/cloudflare/cloudflare-warp/tunnelrpc"
+ "github.com/cloudflare/cloudflared/tunnelrpc"
"golang.org/x/net/context"
"zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs"
diff --git a/tunnelrpc/tunnelrpc.capnp b/tunnelrpc/tunnelrpc.capnp
index 636f317a..3886f534 100644
--- a/tunnelrpc/tunnelrpc.capnp
+++ b/tunnelrpc/tunnelrpc.capnp
@@ -1,7 +1,7 @@
using Go = import "go.capnp";
@0xdb8274f9144abc7e;
$Go.package("tunnelrpc");
-$Go.import("github.com/cloudflare/cloudflare-warp/tunnelrpc");
+$Go.import("github.com/cloudflare/cloudflared/tunnelrpc");
struct Authentication {
key @0 :Text;
@@ -35,7 +35,7 @@ struct RegistrationOptions {
connectionId @6 :UInt8;
# origin LAN IP
originLocalIp @7 :Text;
- # whether Warp client has been autoupdated
+ # whether Argo Tunnel client has been autoupdated
isAutoupdated @8 :Bool;
}
diff --git a/validation/validation.go b/validation/validation.go
index eb4eb90b..746bf9c0 100644
--- a/validation/validation.go
+++ b/validation/validation.go
@@ -119,7 +119,7 @@ func validateScheme(scheme string) error {
return nil
}
}
- return fmt.Errorf("Currently Cloudflare-Warp does not support %s protocol.", scheme)
+ return fmt.Errorf("Currently Argo Tunnel does not support %s protocol.", scheme)
}
func validateIP(scheme, host, port string) (string, error) {
diff --git a/validation/validation_test.go b/validation/validation_test.go
index 0be8066b..49a2cf10 100644
--- a/validation/validation_test.go
+++ b/validation/validation_test.go
@@ -126,7 +126,7 @@ func TestValidateUrl(t *testing.T) {
assert.Equal(t, "https://hello.example.com", validUrl)
validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
- assert.Equal(t, "Currently Cloudflare-Warp does not support ftp protocol.", err.Error())
+ assert.Equal(t, "Currently Argo Tunnel does not support ftp protocol.", err.Error())
assert.Empty(t, validUrl)
validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")