From 82cb539fbe92f9fbf808adc8e634bd1d4ca5d814 Mon Sep 17 00:00:00 2001 From: Chris Branch Date: Mon, 16 Oct 2017 12:44:03 +0100 Subject: [PATCH] Initial import --- README.md | 9 + cmd/cloudflare-warp/cloudflareca.go | 39 + cmd/cloudflare-warp/generic_service.go | 13 + cmd/cloudflare-warp/hello.go | 155 +++ cmd/cloudflare-warp/hello_test.go | 23 + cmd/cloudflare-warp/linux_service.go | 264 ++++++ cmd/cloudflare-warp/login.go | 31 + cmd/cloudflare-warp/macos_service.go | 82 ++ cmd/cloudflare-warp/main.go | 473 ++++++++++ cmd/cloudflare-warp/service_template.go | 88 ++ cmd/cloudflare-warp/tag.go | 32 + cmd/cloudflare-warp/tag_test.go | 46 + cmd/cloudflare-warp/update.go | 41 + cmd/cloudflare-warp/windows_service.go | 145 +++ h2mux/activestreammap.go | 213 +++++ h2mux/activestreammap_test.go | 28 + h2mux/booleanfuse.go | 25 + h2mux/error.go | 61 ++ h2mux/h2mux.go | 335 +++++++ h2mux/h2mux_test.go | 646 +++++++++++++ h2mux/idletimer.go | 81 ++ h2mux/idletimer_test.go | 31 + h2mux/muxedstream.go | 250 +++++ h2mux/muxedstream_test.go | 65 ++ h2mux/muxreader.go | 326 +++++++ h2mux/muxwriter.go | 238 +++++ h2mux/readylist.go | 140 +++ h2mux/readylist_test.go | 115 +++ h2mux/rtt.go | 53 ++ h2mux/shared_buffer.go | 64 ++ h2mux/shared_buffer_test.go | 120 +++ h2mux/signal.go | 34 + h2mux/streamerrormap.go | 47 + metrics/metrics.go | 48 + origin/backoffhandler.go | 70 ++ origin/backoffhandler_test.go | 114 +++ origin/tunnel.go | 360 +++++++ tlsconfig/tlsconfig.go | 62 ++ tunnelrpc/go.capnp | 15 + tunnelrpc/log.go | 26 + tunnelrpc/logtransport.go | 45 + tunnelrpc/pogs/tunnelrpc.go | 194 ++++ tunnelrpc/tunnelrpc.capnp | 56 ++ tunnelrpc/tunnelrpc.capnp.go | 1145 +++++++++++++++++++++++ validation/validation.go | 136 +++ validation/validation_test.go | 136 +++ 46 files changed, 6720 insertions(+) create mode 100644 README.md create mode 100644 cmd/cloudflare-warp/cloudflareca.go create mode 100644 cmd/cloudflare-warp/generic_service.go create mode 100644 cmd/cloudflare-warp/hello.go create mode 100644 cmd/cloudflare-warp/hello_test.go create mode 100644 cmd/cloudflare-warp/linux_service.go create mode 100644 cmd/cloudflare-warp/login.go create mode 100644 cmd/cloudflare-warp/macos_service.go create mode 100644 cmd/cloudflare-warp/main.go create mode 100644 cmd/cloudflare-warp/service_template.go create mode 100644 cmd/cloudflare-warp/tag.go create mode 100644 cmd/cloudflare-warp/tag_test.go create mode 100644 cmd/cloudflare-warp/update.go create mode 100644 cmd/cloudflare-warp/windows_service.go create mode 100644 h2mux/activestreammap.go create mode 100644 h2mux/activestreammap_test.go create mode 100644 h2mux/booleanfuse.go create mode 100644 h2mux/error.go create mode 100644 h2mux/h2mux.go create mode 100644 h2mux/h2mux_test.go create mode 100644 h2mux/idletimer.go create mode 100644 h2mux/idletimer_test.go create mode 100644 h2mux/muxedstream.go create mode 100644 h2mux/muxedstream_test.go create mode 100644 h2mux/muxreader.go create mode 100644 h2mux/muxwriter.go create mode 100644 h2mux/readylist.go create mode 100644 h2mux/readylist_test.go create mode 100644 h2mux/rtt.go create mode 100644 h2mux/shared_buffer.go create mode 100644 h2mux/shared_buffer_test.go create mode 100644 h2mux/signal.go create mode 100644 h2mux/streamerrormap.go create mode 100644 metrics/metrics.go create mode 100644 origin/backoffhandler.go create mode 100644 origin/backoffhandler_test.go create mode 100644 origin/tunnel.go create mode 100644 tlsconfig/tlsconfig.go create mode 100644 tunnelrpc/go.capnp create mode 100644 
tunnelrpc/log.go create mode 100644 tunnelrpc/logtransport.go create mode 100644 tunnelrpc/pogs/tunnelrpc.go create mode 100644 tunnelrpc/tunnelrpc.capnp create mode 100644 tunnelrpc/tunnelrpc.capnp.go create mode 100644 validation/validation.go create mode 100644 validation/validation_test.go diff --git a/README.md b/README.md new file mode 100644 index 00000000..1ec515e8 --- /dev/null +++ b/README.md @@ -0,0 +1,9 @@ +# Cloudflare Warp client + +Contains the command-line client and its libraries for Cloudflare Warp, a tunneling daemon that proxies any local webserver through the Cloudflare network. + +## Getting started + + go install github.com/cloudflare/cloudflare-warp/cmd/cloudflare-warp + +User documentation for Warp can be found at https://warp.cloudflare.com diff --git a/cmd/cloudflare-warp/cloudflareca.go b/cmd/cloudflare-warp/cloudflareca.go new file mode 100644 index 00000000..ee88666a --- /dev/null +++ b/cmd/cloudflare-warp/cloudflareca.go @@ -0,0 +1,39 @@ +package main + +import ( + "crypto/x509" +) + +const cloudflareRootCA = `-----BEGIN CERTIFICATE----- +MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG +EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG +bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN +U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4 +NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv +dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl +cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG +A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem +ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ +p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz +/9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl +yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK +xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud +EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G +A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA +cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD +gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh +QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz +ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y +4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX +Fu6q54beR89jDc+oABmOgg== +-----END CERTIFICATE-----` + +func GetCloudflareRootCA() *x509.CertPool { + ca := x509.NewCertPool() + if !ca.AppendCertsFromPEM([]byte(cloudflareRootCA)) { + // should never happen + panic("failure loading Cloudflare origin CA pem") + } + return ca +} diff --git a/cmd/cloudflare-warp/generic_service.go b/cmd/cloudflare-warp/generic_service.go new file mode 100644 index 00000000..3eb4819e --- /dev/null +++ b/cmd/cloudflare-warp/generic_service.go @@ -0,0 +1,13 @@ +// +build !windows,!darwin,!linux + +package main + +import ( + "os" + + cli "gopkg.in/urfave/cli.v2" +) + +func runApp(app *cli.App) { + app.Run(os.Args) +} diff --git a/cmd/cloudflare-warp/hello.go b/cmd/cloudflare-warp/hello.go new file mode 100644 index 00000000..87d10868 --- /dev/null +++ b/cmd/cloudflare-warp/hello.go @@ -0,0 +1,155 @@ +package main + +import ( + "bytes" + "fmt" + "html/template" + "net" + "net/http" + "os" + "strings" + + "github.com/pkg/errors" + + log "github.com/Sirupsen/logrus" + "github.com/cloudflare/cloudflare-warp/origin" + tunnelpogs 
"github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" + cli "gopkg.in/urfave/cli.v2" +) + +type templateData struct { + ServerName string + Request *http.Request + Tags []tunnelpogs.Tag +} + +const defaultServerName = "the Cloudflare Warp test server" +const indexTemplate = ` + + + + + + + Cloudflare Warp Connection + + + + + + + +
+Congrats! You created your first tunnel!
+
+Cloudflare Warp exposes locally running applications to the internet by
+running an encrypted, virtual tunnel from your laptop or server to
+Cloudflare's edge network.
+
+Ready for the next step?
+
+Get started here
+{{if .Tags}}
+Connection
+{{range .Tags}}
+{{.Name}}
+{{.Value}}
+{{end}}
+{{end}}
+ + +` + +func hello(c *cli.Context) error { + address := fmt.Sprintf(":%d", c.Int("port")) + server := NewHelloWorldServer() + if hostname, err := os.Hostname(); err != nil { + server.serverName = hostname + } + err := server.ListenAndServe(address) + return errors.Wrap(err, "Fail to start Hello World Server") +} + +func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error { + server := NewHelloWorldServer() + if hostname, err := os.Hostname(); err != nil { + server.serverName = hostname + } + httpServer := &http.Server{Addr: listener.Addr().String(), Handler: server} + go func() { + <-shutdownC + httpServer.Close() + }() + err := httpServer.Serve(listener) + return err +} + +type HelloWorldServer struct { + responseTemplate *template.Template + serverName string +} + +func NewHelloWorldServer() *HelloWorldServer { + return &HelloWorldServer{ + responseTemplate: template.Must(template.New("index").Parse(indexTemplate)), + serverName: defaultServerName, + } +} + +func findAvailablePort() (net.Listener, error) { + // If the port in address is empty, a port number is automatically chosen. + listener, err := net.Listen("tcp", "127.0.0.1:") + return listener, err +} + +func (s *HelloWorldServer) ListenAndServe(address string) error { + log.Infof("Starting Hello World server on %s", address) + err := http.ListenAndServe(address, s) + return err +} + +func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto) + var buffer bytes.Buffer + err := s.responseTemplate.Execute(&buffer, &templateData{ + ServerName: s.serverName, + Request: r, + Tags: tagsFromHeaders(r.Header), + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "error: %v", err) + } else { + buffer.WriteTo(w) + } +} + +func tagsFromHeaders(header http.Header) []tunnelpogs.Tag { + var tags []tunnelpogs.Tag + for headerName, headerValues := range header { + trimmed := strings.TrimPrefix(headerName, origin.TagHeaderNamePrefix) + if trimmed == headerName { + continue + } + for _, value := range headerValues { + tags = append(tags, tunnelpogs.Tag{Name: trimmed, Value: value}) + } + } + return tags +} diff --git a/cmd/cloudflare-warp/hello_test.go b/cmd/cloudflare-warp/hello_test.go new file mode 100644 index 00000000..9586953d --- /dev/null +++ b/cmd/cloudflare-warp/hello_test.go @@ -0,0 +1,23 @@ +package main + +import ( + "testing" +) + +const testPort = "8080" + +func TestNewHelloWorldServer(t *testing.T) { + if NewHelloWorldServer() == nil { + t.Fatal("NewHelloWorldServer returned nil") + } +} + +func TestFindAvailablePort(t *testing.T) { + listener, err := findAvailablePort() + if err != nil { + t.Fatal("Fail to find available port") + } + if listener.Addr().String() == "" { + t.Fatal("Fail to find available port") + } +} diff --git a/cmd/cloudflare-warp/linux_service.go b/cmd/cloudflare-warp/linux_service.go new file mode 100644 index 00000000..99a1686c --- /dev/null +++ b/cmd/cloudflare-warp/linux_service.go @@ -0,0 +1,264 @@ +// +build linux + +package main + +import ( + "fmt" + "os" + + cli "gopkg.in/urfave/cli.v2" +) + +func runApp(app *cli.App) { + app.Commands = append(app.Commands, &cli.Command{ + Name: "service", + Usage: "Manages the Cloudflare Warp system service", + Subcommands: []*cli.Command{ + &cli.Command{ + Name: "install", + Usage: "Install Cloudflare Warp as a system service", + Action: installLinuxService, + }, + &cli.Command{ + Name: "uninstall", + 
Usage: "Uninstall the Cloudflare Warp service", + Action: uninstallLinuxService, + }, + }, + }) + app.Run(os.Args) +} + +var systemdTemplates = []ServiceTemplate{ + { + Path: "/etc/systemd/system/cloudflare-warp.service", + Content: `[Unit] +Description=Cloudflare Warp +After=network.target + +[Service] +TimeoutStartSec=0 +Type=notify +ExecStart={{ .Path }} --config /etc/cloudflare-warp.yml --autoupdate 0s +User=nobody + +[Install] +WantedBy=multi-user.target +`, + }, + { + Path: "/etc/systemd/system/cloudflare-warp-update.service", + Content: `[Unit] +Description=Update Cloudflare Warp +After=network.target + +[Service] +ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 64 ]; then systemctl restart cloudflare-warp; exit 0; fi; exit $code' +`, + }, + { + Path: "/etc/systemd/system/cloudflare-warp-update.timer", + Content: `[Unit] +Description=Update Cloudflare Warp + +[Timer] +OnUnitActiveSec=1d + +[Install] +WantedBy=timers.target +`, + }, +} + +var sysvTemplate = ServiceTemplate{ + Path: "/etc/init.d/cloudflare-warp", + FileMode: 0755, + Content: `# For RedHat and cousins: +# chkconfig: 2345 99 01 +# description: Cloudflare Warp agent +# processname: {{.Path}} +### BEGIN INIT INFO +# Provides: {{.Path}} +# Required-Start: +# Required-Stop: +# Default-Start: 2 3 4 5 +# Default-Stop: 0 1 6 +# Short-Description: Cloudflare Warp +# Description: Cloudflare Warp agent +### END INIT INFO +cmd="{{.Path}} --config /etc/cloudflare-warp.yml --pidfile /var/run/$name.pid" +name=$(basename $(readlink -f $0)) +pid_file="/var/run/$name.pid" +stdout_log="/var/log/$name.log" +stderr_log="/var/log/$name.err" +[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name +get_pid() { + cat "$pid_file" +} +is_running() { + [ -f "$pid_file" ] && ps $(get_pid) > /dev/null 2>&1 +} +case "$1" in + start) + if is_running; then + echo "Already started" + else + echo "Starting $name" + $cmd >> "$stdout_log" 2>> "$stderr_log" & + echo $! > "$pid_file" + if ! is_running; then + echo "Unable to start, see $stdout_log and $stderr_log" + exit 1 + fi + fi + ;; + stop) + if is_running; then + echo -n "Stopping $name.." + kill $(get_pid) + for i in {1..10} + do + if ! is_running; then + break + fi + echo -n "." 
+ sleep 1 + done + echo + if is_running; then + echo "Not stopped; may still be shutting down or shutdown may have failed" + exit 1 + else + echo "Stopped" + if [ -f "$pid_file" ]; then + rm "$pid_file" + fi + fi + else + echo "Not running" + fi + ;; + restart) + $0 stop + if is_running; then + echo "Unable to stop, will not attempt to start" + exit 1 + fi + $0 start + ;; + status) + if is_running; then + echo "Running" + else + echo "Stopped" + exit 1 + fi + ;; + *) + echo "Usage: $0 {start|stop|restart|status}" + exit 1 + ;; +esac +exit 0 +`, +} + +func isSystemd() bool { + if _, err := os.Stat("/run/systemd/system"); err == nil { + return true + } + return false +} + +func installLinuxService(c *cli.Context) error { + etPath, err := os.Executable() + if err != nil { + return fmt.Errorf("error determining executable path: %v", err) + } + templateArgs := ServiceTemplateArgs{Path: etPath} + + switch { + case isSystemd(): + return installSystemd(&templateArgs) + default: + return installSysv(&templateArgs) + } +} + +func installSystemd(templateArgs *ServiceTemplateArgs) error { + for _, serviceTemplate := range systemdTemplates { + err := serviceTemplate.Generate(templateArgs) + if err != nil { + return err + } + } + if err := runCommand("systemctl", "enable", "cloudflare-warp.service"); err != nil { + return err + } + if err := runCommand("systemctl", "start", "cloudflare-warp-update.timer"); err != nil { + return err + } + return runCommand("systemctl", "daemon-reload") +} + +func installSysv(templateArgs *ServiceTemplateArgs) error { + confPath, err := sysvTemplate.ResolvePath() + if err != nil { + return err + } + if err := sysvTemplate.Generate(templateArgs); err != nil { + return err + } + for _, i := range [...]string{"2", "3", "4", "5"} { + if err := os.Symlink(confPath, "/etc/rc"+i+".d/S50et"); err != nil { + continue + } + } + for _, i := range [...]string{"0", "1", "6"} { + if err := os.Symlink(confPath, "/etc/rc"+i+".d/K02et"); err != nil { + continue + } + } + return nil +} + +func uninstallLinuxService(c *cli.Context) error { + switch { + case isSystemd(): + return uninstallSystemd() + default: + return uninstallSysv() + } +} + +func uninstallSystemd() error { + if err := runCommand("systemctl", "disable", "cloudflare-warp.service"); err != nil { + return err + } + if err := runCommand("systemctl", "stop", "cloudflare-warp-update.timer"); err != nil { + return err + } + for _, serviceTemplate := range systemdTemplates { + if err := serviceTemplate.Remove(); err != nil { + return err + } + } + return nil +} + +func uninstallSysv() error { + if err := sysvTemplate.Remove(); err != nil { + return err + } + for _, i := range [...]string{"2", "3", "4", "5"} { + if err := os.Remove("/etc/rc" + i + ".d/S50et"); err != nil { + continue + } + } + for _, i := range [...]string{"0", "1", "6"} { + if err := os.Remove("/etc/rc" + i + ".d/K02et"); err != nil { + continue + } + } + return nil +} diff --git a/cmd/cloudflare-warp/login.go b/cmd/cloudflare-warp/login.go new file mode 100644 index 00000000..dbb452f4 --- /dev/null +++ b/cmd/cloudflare-warp/login.go @@ -0,0 +1,31 @@ +package main + +import ( + "fmt" + "os" + "syscall" + + homedir "github.com/mitchellh/go-homedir" + cli "gopkg.in/urfave/cli.v2" +) + +func login(c *cli.Context) error { + path, err := homedir.Expand(defaultConfigPath) + if err != nil { + return err + } + fileInfo, err := os.Stat(path) + if err == nil && fileInfo.Size() > 0 { + fmt.Fprintf(os.Stderr, `You have an existing config file at %s which login would overwrite. 
+If this is intentional, please move or delete that file then run this command again. +`, defaultConfigPath) + return nil + } + if err != nil && err.(*os.PathError).Err != syscall.ENOENT { + return err + } + + fmt.Fprintln(os.Stderr, "Please visit https://www.cloudflare.com/a/warp to obtain a certificate.") + + return nil +} diff --git a/cmd/cloudflare-warp/macos_service.go b/cmd/cloudflare-warp/macos_service.go new file mode 100644 index 00000000..9b50a5b3 --- /dev/null +++ b/cmd/cloudflare-warp/macos_service.go @@ -0,0 +1,82 @@ +// +build darwin + +package main + +import ( + "fmt" + "os" + + cli "gopkg.in/urfave/cli.v2" +) + +func runApp(app *cli.App) { + app.Commands = append(app.Commands, &cli.Command{ + Name: "service", + Usage: "Manages the Cloudflare Warp launch agent", + Subcommands: []*cli.Command{ + &cli.Command{ + Name: "install", + Usage: "Install Cloudflare Warp as an user launch agent", + Action: installLaunchd, + }, + &cli.Command{ + Name: "uninstall", + Usage: "Uninstall the Cloudflare Warp launch agent", + Action: uninstallLaunchd, + }, + }, + }) + app.Run(os.Args) +} + +var launchdTemplate = ServiceTemplate{ + Path: "~/Library/LaunchAgents/com.cloudflare.warp.plist", + Content: ` + + + + Label + com.cloudflare.warp + Program + {{ .Path }} + RunAtLoad + + KeepAlive + + NetworkState + + + ThrottleInterval + 20 + +`, +} + +func installLaunchd(c *cli.Context) error { + etPath, err := os.Executable() + if err != nil { + return fmt.Errorf("error determining executable path: %v", err) + } + templateArgs := ServiceTemplateArgs{Path: etPath} + err = launchdTemplate.Generate(&templateArgs) + if err != nil { + return err + } + plistPath, err := launchdTemplate.ResolvePath() + if err != nil { + return err + } + return runCommand("launchctl", "load", plistPath) +} + +func uninstallLaunchd(c *cli.Context) error { + plistPath, err := launchdTemplate.ResolvePath() + if err != nil { + return err + } + err = runCommand("launchctl", "unload", plistPath) + if err != nil { + return err + } + return launchdTemplate.Remove() +} diff --git a/cmd/cloudflare-warp/main.go b/cmd/cloudflare-warp/main.go new file mode 100644 index 00000000..8b7fda3c --- /dev/null +++ b/cmd/cloudflare-warp/main.go @@ -0,0 +1,473 @@ +package main + +import ( + "crypto/tls" + "encoding/hex" + "fmt" + "math/rand" + "net" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/cloudflare/cloudflare-warp/h2mux" + "github.com/cloudflare/cloudflare-warp/metrics" + "github.com/cloudflare/cloudflare-warp/origin" + "github.com/cloudflare/cloudflare-warp/tlsconfig" + tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" + "github.com/cloudflare/cloudflare-warp/validation" + + log "github.com/Sirupsen/logrus" + "github.com/facebookgo/grace/gracenet" + raven "github.com/getsentry/raven-go" + homedir "github.com/mitchellh/go-homedir" + cli "gopkg.in/urfave/cli.v2" + "gopkg.in/urfave/cli.v2/altsrc" + + "github.com/coreos/go-systemd/daemon" + "github.com/pkg/errors" +) + +const sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878" +const defaultConfigPath = "~/.cloudflare-warp.yml" + +var listeners = gracenet.Net{} +var Version = "DEV" +var BuildTime = "unknown" + +// Shutdown channel used by the app. When closed, app must terminate. +// May be closed by the Windows service runner. 
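+// A minimal sketch of the intended usage (workC and process below are
+// illustrative): long-running goroutines select on this channel, and closing
+// it unblocks every receiver at once, e.g.
+//
+//	select {
+//	case <-shutdownC:
+//		return // shutting down
+//	case work := <-workC:
+//		process(work)
+//	}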
+var shutdownC chan struct{} + +func main() { + metrics.RegisterBuildInfo(BuildTime, Version) + raven.SetDSN(sentryDSN) + raven.SetRelease(Version) + shutdownC = make(chan struct{}) + app := &cli.App{} + app.Name = "cloudflare-warp" + app.Copyright = `(c) 2017 Cloudflare Inc. + Use is subject to the license agreement at https://warp.cloudflare.com/licence/` + app.Usage = "Cloudflare reverse tunnelling proxy agent \033[1;31m*BETA*\033[0m" + app.ArgsUsage = "origin-url" + app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime) + app.Description = `A reverse tunnel proxy agent that connects to Cloudflare's infrastructure. + Upon connecting, you are assigned a unique subdomain on cftunnel.com. + Alternatively, you can specify a hostname on a zone you control. + + Requests made to Cloudflare's servers for your hostname will be proxied + through the tunnel to your local webserver. + +WARNING: + ` + "\033[1;31m*** THIS IS A BETA VERSION OF THE CLOUDFLARE WARP AGENT ***\033[0m" + ` + + At this time, do not use Cloudflare Warp for connecting production servers to Cloudflare. + Availability and reliability of this service is not guaranteed through the beta period.` + app.Flags = []cli.Flag{ + &cli.StringFlag{ + Name: "config", + Usage: "Specifies a config file in YAML format.", + }, + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "autoupdate", + Usage: "Periodically check for updates, restarting the server with the new version.", + Value: time.Hour * 24, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "edge", + Value: "cftunnel.com:7844", + Usage: "Address of the Cloudflare tunnel server.", + EnvVars: []string{"TUNNEL_EDGE"}, + Hidden: true, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "cacert", + Usage: "Certificate Authority authenticating the Cloudflare tunnel connection.", + EnvVars: []string{"TUNNEL_CACERT"}, + Hidden: true, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "url", + Value: "http://localhost:8080", + Usage: "Connect to the local webserver at `URL`.", + EnvVars: []string{"TUNNEL_URL"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "hostname", + Usage: "Set a hostname on a Cloudflare zone to route traffic through this tunnel.", + EnvVars: []string{"TUNNEL_HOSTNAME"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "id", + Usage: "A unique identifier used to tie connections to this tunnel instance.", + EnvVars: []string{"TUNNEL_ID"}, + Hidden: true, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "lb-pool", + Usage: "The name of a (new/existing) load balancing pool to add this origin to.", + EnvVars: []string{"TUNNEL_LB_POOL"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "api-key", + Usage: "A Cloudflare API key. Required(can be in the config file) unless you are only running the hello command or login command.", + EnvVars: []string{"TUNNEL_API_KEY"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "api-email", + Usage: "The Cloudflare user's email address associated with the API key. Required(can be in the config file) unless you are only running the hello command or login command.", + EnvVars: []string{"TUNNEL_API_EMAIL"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "api-ca-key", + Usage: "The Origin CA service key associated with the user. 
Required(can be in the config file) unless you are only running the hello command or login command.", + EnvVars: []string{"TUNNEL_API_CA_KEY"}, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "metrics", + Value: "localhost:", + Usage: "Listen address for metrics reporting.", + EnvVars: []string{"TUNNEL_METRICS"}, + }), + altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ + Name: "tag", + Usage: "Custom tags used to identify this tunnel, in format `KEY=VALUE`. Multiple tags may be specified", + EnvVars: []string{"TUNNEL_TAG"}, + }), + altsrc.NewDurationFlag(&cli.DurationFlag{ + Name: "heartbeat-interval", + Usage: "Minimum idle time before sending a heartbeat.", + Value: time.Second * 5, + Hidden: true, + }), + altsrc.NewUint64Flag(&cli.Uint64Flag{ + Name: "heartbeat-count", + Usage: "Minimum number of unacked heartbeats to send before closing the connection.", + Value: 5, + Hidden: true, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "loglevel", + Value: "info", + Usage: "Logging level {panic, fatal, error, warn, info, debug}", + EnvVars: []string{"TUNNEL_LOGLEVEL"}, + }), + altsrc.NewUintFlag(&cli.UintFlag{ + Name: "retries", + Value: 5, + Usage: "Maximum number of retries for connection/protocol errors.", + EnvVars: []string{"TUNNEL_RETRIES"}, + }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: "debug", + Value: false, + Usage: "Enable HTTP requests to the autogenerated cftunnel.com domain.", + EnvVars: []string{"TUNNEL_DEBUG"}, + }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: "hello-world", + Usage: "Run Hello World Server", + Value: false, + }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "pidfile", + Usage: "Write the application's PID to this file after first successful connection.", + EnvVars: []string{"TUNNEL_PIDFILE"}, + }), + } + app.Action = func(c *cli.Context) error { + raven.CapturePanic(func() { startServer(c) }, nil) + return nil + } + app.Before = func(context *cli.Context) error { + inputSource, err := findInputSourceContext(context) + if err != nil { + return err + } else if inputSource != nil { + return altsrc.ApplyInputSourceValues(context, inputSource, app.Flags) + } + return nil + } + app.Commands = []*cli.Command{ + &cli.Command{ + Name: "update", + Action: update, + Usage: "Update the agent if a new version exists", + ArgsUsage: " ", + Description: `Looks for a new version on the offical download server. + If a new version exists, updates the agent binary and quits. + Otherwise, does nothing. 
+ + To determine if an update happened in a script, check for error code 64.`, + }, + &cli.Command{ + Name: "login", + Action: login, + Usage: "Generate a configuration file with your login details", + ArgsUsage: " ", + }, + &cli.Command{ + Name: "hello", + Action: hello, + Usage: "Run a simple \"Hello World\" server for testing Cloudflare Warp.", + Flags: []cli.Flag{ + &cli.IntFlag{ + Name: "port", + Usage: "Listen on the selected port.", + Value: 8080, + }, + }, + ArgsUsage: " ", // can't be the empty string or we get the default output + }, + } + runApp(app) +} + +func startServer(c *cli.Context) { + var wg sync.WaitGroup + errC := make(chan error) + wg.Add(2) + + if c.NumFlags() == 0 && c.NArg() == 0 { + cli.ShowAppHelp(c) + return + } + + logLevel, err := log.ParseLevel(c.String("loglevel")) + if err != nil { + log.WithError(err).Fatal("Unknown logging level specified") + } + log.SetLevel(logLevel) + hostname, err := validation.ValidateHostname(c.String("hostname")) + if err != nil { + log.WithError(err).Fatal("Invalid hostname") + + } + clientID := c.String("id") + if !c.IsSet("id") { + clientID = generateRandomClientID() + } + + tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) + if err != nil { + log.WithError(err).Fatal("Tag parse failure") + } + + tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) + + if c.IsSet("hello-world") { + wg.Add(1) + listener, err := findAvailablePort() + + if err != nil { + listener.Close() + log.WithError(err).Fatal("Cannot start Hello World Server") + } + go func() { + startHelloWorldServer(listener, shutdownC) + wg.Done() + }() + c.Set("url", "http://"+listener.Addr().String()) + log.Infof("Starting Hello World Server at %s", c.String("url")) + } + url, err := validateUrl(c) + if err != nil { + log.WithError(err).Fatal("Error validating url") + } + // User must have api-key, api-email and api-ca-key + if !c.IsSet("api-key") { + log.Fatal("You need to give us your api-key either via the --api-key option or put it in the configuration file. You will also need to give us your api-email and api-ca-key.") + } + if !c.IsSet("api-email") { + log.Fatal("You need to give us your api-email either via the --api-email option or put it in the configuration file. 
You will also need to give us your api-ca-key.") + } + if !c.IsSet("api-ca-key") { + log.Fatal("You need to give us your api-ca-key either via the --api-ca-key option or put it in the configuration file.") + } + log.Infof("Proxying tunnel requests to %s", url) + tunnelConfig := &origin.TunnelConfig{ + EdgeAddr: c.String("edge"), + OriginUrl: url, + Hostname: hostname, + APIKey: c.String("api-key"), + APIEmail: c.String("api-email"), + APICAKey: c.String("api-ca-key"), + TlsConfig: &tls.Config{}, + Retries: c.Uint("retries"), + HeartbeatInterval: c.Duration("heartbeat-interval"), + MaxHeartbeats: c.Uint64("heartbeat-count"), + ClientID: clientID, + ReportedVersion: Version, + LBPool: c.String("lb-pool"), + Tags: tags, + AccessInternalIP: c.Bool("debug"), + ConnectedSignal: h2mux.NewSignal(), + } + + tunnelConfig.TlsConfig = tlsconfig.CLIFlags{RootCA: "cacert"}.GetConfig(c) + if tunnelConfig.TlsConfig.RootCAs == nil { + tunnelConfig.TlsConfig.RootCAs = GetCloudflareRootCA() + tunnelConfig.TlsConfig.ServerName = "cftunnel.com" + } else { + tunnelConfig.TlsConfig.ServerName, _, _ = net.SplitHostPort(tunnelConfig.EdgeAddr) + } + + go writePidFile(tunnelConfig.ConnectedSignal, c.String("pidfile")) + go func() { + errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC) + wg.Done() + }() + + metricsListener, err := listeners.Listen("tcp", c.String("metrics")) + if err != nil { + log.WithError(err).Fatal("Error opening metrics server listener") + } + go func() { + errC <- metrics.ServeMetrics(metricsListener, shutdownC) + wg.Done() + }() + + go autoupdate(c.Duration("autoupdate"), shutdownC) + + err = WaitForSignal(errC, shutdownC) + if err != nil { + log.WithError(err).Error("Quitting due to error") + raven.CaptureErrorAndWait(err, nil) + } else { + log.Info("Quitting...") + } + // Wait for clean exit, discarding all errors + go func() { + for range errC { + } + }() + wg.Wait() +} + +func WaitForSignal(errC chan error, shutdownC chan struct{}) error { + signals := make(chan os.Signal, 10) + signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) + defer signal.Stop(signals) + select { + case err := <-errC: + close(shutdownC) + return err + case <-signals: + close(shutdownC) + case <-shutdownC: + } + return nil +} + +func update(c *cli.Context) error { + if updateApplied() { + os.Exit(64) + } + return nil +} + +func autoupdate(frequency time.Duration, shutdownC chan struct{}) { + if int64(frequency) == 0 { + return + } + for { + if updateApplied() { + if _, err := listeners.StartProcess(); err != nil { + log.WithError(err).Error("Unable to restart server automatically") + } + close(shutdownC) + return + } + time.Sleep(frequency) + } +} + +func updateApplied() bool { + releaseInfo := checkForUpdates() + if releaseInfo.Updated { + log.Infof("Updated to version %s", releaseInfo.Version) + return true + } + if releaseInfo.Error != nil { + log.WithError(releaseInfo.Error).Error("Update check failed") + } + return false +} + +func fileExists(path string) (bool, error) { + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + // ignore missing files + return false, nil + } + return false, err + } + f.Close() + return true, nil +} + +func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) { + if context.IsSet("config") { + return altsrc.NewYamlSourceFromFile(context.String("config")) + } + for _, tryPath := range []string{ + defaultConfigPath, + "~/.cloudflare-warp.yaml", + "~/cloudflare-warp.yaml", + "~/cloudflare-warp.yml", + "~/.et.yaml", + "~/et.yml", + 
"~/et.yaml", + "~/.cftunnel.yaml", // for existing users + "~/cftunnel.yaml", + } { + path, err := homedir.Expand(tryPath) + if err != nil { + continue + } + ok, err := fileExists(path) + if ok { + return altsrc.NewYamlSourceFromFile(path) + } else if err != nil { + return nil, err + } + } + return nil, nil +} + +func generateRandomClientID() string { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + id := make([]byte, 32) + r.Read(id) + return hex.EncodeToString(id) +} + +func writePidFile(waitForSignal h2mux.Signal, pidFile string) { + waitForSignal.Wait() + daemon.SdNotify(false, "READY=1") + if pidFile == "" { + return + } + file, err := os.Create(pidFile) + if err != nil { + log.WithError(err).Errorf("Unable to write pid to %s", pidFile) + } + defer file.Close() + fmt.Fprintf(file, "%d", os.Getpid()) +} + +// validate url. It can be either from --url or argument +func validateUrl(c *cli.Context) (string, error) { + var url = c.String("url") + if c.NArg() > 0 { + if c.IsSet("url") { + return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.") + } + url = c.Args().Get(0) + } + validUrl, err := validation.ValidateUrl(url) + return validUrl, err +} diff --git a/cmd/cloudflare-warp/service_template.go b/cmd/cloudflare-warp/service_template.go new file mode 100644 index 00000000..28b5df30 --- /dev/null +++ b/cmd/cloudflare-warp/service_template.go @@ -0,0 +1,88 @@ +package main + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "os/exec" + "text/template" + + homedir "github.com/mitchellh/go-homedir" +) + +type ServiceTemplate struct { + Path string + Content string + FileMode os.FileMode +} + +type ServiceTemplateArgs struct { + Path string +} + +func (st *ServiceTemplate) ResolvePath() (string, error) { + resolvedPath, err := homedir.Expand(st.Path) + if err != nil { + return "", fmt.Errorf("error resolving path %s: %v", st.Path, err) + } + return resolvedPath, nil +} + +func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error { + tmpl, err := template.New(st.Path).Parse(st.Content) + if err != nil { + return fmt.Errorf("error generating %s template: %v", st.Path, err) + } + resolvedPath, err := st.ResolvePath() + if err != nil { + return err + } + var buffer bytes.Buffer + err = tmpl.Execute(&buffer, args) + if err != nil { + return fmt.Errorf("error generating %s: %v", st.Path, err) + } + fileMode := os.FileMode(0644) + if st.FileMode != 0 { + fileMode = st.FileMode + } + err = ioutil.WriteFile(resolvedPath, buffer.Bytes(), fileMode) + if err != nil { + return fmt.Errorf("error writing %s: %v", resolvedPath, err) + } + return nil +} + +func (st *ServiceTemplate) Remove() error { + resolvedPath, err := st.ResolvePath() + if err != nil { + return err + } + err = os.Remove(resolvedPath) + if err != nil { + return fmt.Errorf("error deleting %s: %v", resolvedPath, err) + } + return nil +} + +func runCommand(command string, args ...string) error { + cmd := exec.Command(command, args...) 
+ stderr, err := cmd.StderrPipe() + if err != nil { + return fmt.Errorf("error getting stderr pipe: %v", err) + } + err = cmd.Start() + if err != nil { + return fmt.Errorf("error starting %s: %v", command, err) + } + commandErr, _ := ioutil.ReadAll(stderr) + if len(commandErr) > 0 { + return fmt.Errorf("%s error: %s", command, commandErr) + } + err = cmd.Wait() + if err != nil { + return fmt.Errorf("%s returned with error: %v", command, err) + } + return nil +} diff --git a/cmd/cloudflare-warp/tag.go b/cmd/cloudflare-warp/tag.go new file mode 100644 index 00000000..ec9bb86f --- /dev/null +++ b/cmd/cloudflare-warp/tag.go @@ -0,0 +1,32 @@ +package main + +import ( + "fmt" + "regexp" + + tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" +) + +// Restrict key names to characters allowed in an HTTP header name. +// Restrict key values to printable characters (what is recognised as data in an HTTP header value). +var tagRegexp = regexp.MustCompile("^([a-zA-Z0-9!#$%&'*+\\-.^_`|~]+)=([[:print:]]+)$") + +func NewTagFromCLI(compoundTag string) (tunnelpogs.Tag, bool) { + matches := tagRegexp.FindStringSubmatch(compoundTag) + if len(matches) == 0 { + return tunnelpogs.Tag{}, false + } + return tunnelpogs.Tag{Name: matches[1], Value: matches[2]}, true +} + +func NewTagSliceFromCLI(tags []string) ([]tunnelpogs.Tag, error) { + var tagSlice []tunnelpogs.Tag + for _, compoundTag := range tags { + if tag, ok := NewTagFromCLI(compoundTag); ok { + tagSlice = append(tagSlice, tag) + } else { + return nil, fmt.Errorf("Cannot parse tag value %s", compoundTag) + } + } + return tagSlice, nil +} diff --git a/cmd/cloudflare-warp/tag_test.go b/cmd/cloudflare-warp/tag_test.go new file mode 100644 index 00000000..d0f3466b --- /dev/null +++ b/cmd/cloudflare-warp/tag_test.go @@ -0,0 +1,46 @@ +package main + +import ( + "testing" + + tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" + + "github.com/stretchr/testify/assert" +) + +func TestSingleTag(t *testing.T) { + testCases := []struct { + Input string + Output tunnelpogs.Tag + Fail bool + }{ + {Input: "x=y", Output: tunnelpogs.Tag{Name: "x", Value: "y"}}, + {Input: "More-Complex=Tag Values", Output: tunnelpogs.Tag{Name: "More-Complex", Value: "Tag Values"}}, + {Input: "First=Equals=Wins", Output: tunnelpogs.Tag{Name: "First", Value: "Equals=Wins"}}, + {Input: "x=", Fail: true}, + {Input: "=y", Fail: true}, + {Input: "=", Fail: true}, + {Input: "No spaces allowed=in key names", Fail: true}, + {Input: "omg\nwtf=bbq", Fail: true}, + } + for i, testCase := range testCases { + tag, ok := NewTagFromCLI(testCase.Input) + assert.Equalf(t, !testCase.Fail, ok, "mismatched success for test case %d", i) + assert.Equalf(t, testCase.Output, tag, "mismatched output for test case %d", i) + } +} + +func TestTagSlice(t *testing.T) { + tagSlice, err := NewTagSliceFromCLI([]string{"a=b", "c=d", "e=f"}) + assert.NoError(t, err) + assert.Len(t, tagSlice, 3) + assert.Equal(t, "a", tagSlice[0].Name) + assert.Equal(t, "b", tagSlice[0].Value) + assert.Equal(t, "c", tagSlice[1].Name) + assert.Equal(t, "d", tagSlice[1].Value) + assert.Equal(t, "e", tagSlice[2].Name) + assert.Equal(t, "f", tagSlice[2].Value) + + tagSlice, err = NewTagSliceFromCLI([]string{"a=b", "=", "e=f"}) + assert.Error(t, err) +} diff --git a/cmd/cloudflare-warp/update.go b/cmd/cloudflare-warp/update.go new file mode 100644 index 00000000..dc2f0cd0 --- /dev/null +++ b/cmd/cloudflare-warp/update.go @@ -0,0 +1,41 @@ +package main + +import "github.com/equinox-io/equinox" + +const appID = 
"app_cwbQae3Tpea" + +var publicKey = []byte(` +-----BEGIN ECDSA PUBLIC KEY----- +MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE4OWZocTVZ8Do/L6ScLdkV+9A0IYMHoOf +dsCmJ/QZ6aw0w9qkkwEpne1Lmo6+0pGexZzFZOH6w5amShn+RXt7qkSid9iWlzGq +EKx0BZogHSor9Wy5VztdFaAaVbsJiCbO +-----END ECDSA PUBLIC KEY----- +`) + +type ReleaseInfo struct { + Updated bool + Version string + Error error +} + +func checkForUpdates() ReleaseInfo { + var opts equinox.Options + if err := opts.SetPublicKeyPEM(publicKey); err != nil { + return ReleaseInfo{Error: err} + } + + resp, err := equinox.Check(appID, opts) + switch { + case err == equinox.NotAvailableErr: + return ReleaseInfo{} + case err != nil: + return ReleaseInfo{Error: err} + } + + err = resp.Apply() + if err != nil { + return ReleaseInfo{Error: err} + } + + return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion} +} diff --git a/cmd/cloudflare-warp/windows_service.go b/cmd/cloudflare-warp/windows_service.go new file mode 100644 index 00000000..0c0181f5 --- /dev/null +++ b/cmd/cloudflare-warp/windows_service.go @@ -0,0 +1,145 @@ +// +build windows + +package main + +// Copypasta from the example files: +// https://github.com/golang/sys/blob/master/windows/svc/example + +import ( + "fmt" + "os" + + log "github.com/Sirupsen/logrus" + cli "gopkg.in/urfave/cli.v2" + + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/eventlog" + "golang.org/x/sys/windows/svc/mgr" +) + +const ( + windowsServiceName = "CloudflareWarp" + windowsServiceDescription = "Cloudflare Warp agent" +) + +func runApp(app *cli.App) { + app.Commands = append(app.Commands, &cli.Command{ + Name: "service", + Usage: "Manages the Cloudflare Warp Windows service", + Subcommands: []*cli.Command{ + &cli.Command{ + Name: "install", + Usage: "Install Cloudflare Warp as a Windows service", + Action: installWindowsService, + }, + &cli.Command{ + Name: "uninstall", + Usage: "Uninstall the Cloudflare Warp service", + Action: uninstallWindowsService, + }, + }, + }) + + isIntSess, err := svc.IsAnInteractiveSession() + if err != nil { + log.Fatalf("failed to determine if we are running in an interactive session: %v", err) + } + + if isIntSess { + app.Run(os.Args) + return + } + + elog, err := eventlog.Open(windowsServiceName) + if err != nil { + return + } + defer elog.Close() + + elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName)) + err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog}) + if err != nil { + elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err)) + return + } + elog.Info(1, fmt.Sprintf("%s service stopped", windowsServiceName)) +} + +type windowsService struct { + app *cli.App + elog *eventlog.Log +} + +func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) { + const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown + changes <- svc.Status{State: svc.StartPending} + go s.app.Run(args) + changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} +loop: + for { + select { + case c := <-r: + switch c.Cmd { + case svc.Interrogate: + changes <- c.CurrentStatus + case svc.Stop, svc.Shutdown: + break loop + default: + s.elog.Error(1, fmt.Sprintf("unexpected control request #%d", c)) + } + } + } + close(shutdownC) + changes <- svc.Status{State: svc.StopPending} + return +} + +func installWindowsService(c *cli.Context) error { + exepath, err := os.Executable() + if err != nil { + return err + } + m, err := mgr.Connect() + if err != nil { + return err + } + defer 
m.Disconnect() + s, err := m.OpenService(windowsServiceName) + if err == nil { + s.Close() + return fmt.Errorf("service %s already exists", windowsServiceName) + } + s, err = m.CreateService(windowsServiceName, exepath, mgr.Config{DisplayName: windowsServiceDescription}, "is", "auto-started") + if err != nil { + return err + } + defer s.Close() + err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info) + if err != nil { + s.Delete() + return fmt.Errorf("SetupEventLogSource() failed: %s", err) + } + return nil +} + +func uninstallWindowsService(c *cli.Context) error { + m, err := mgr.Connect() + if err != nil { + return err + } + defer m.Disconnect() + s, err := m.OpenService(windowsServiceName) + if err != nil { + return fmt.Errorf("service %s is not installed", windowsServiceName) + } + defer s.Close() + err = s.Delete() + if err != nil { + return err + } + err = eventlog.Remove(windowsServiceName) + if err != nil { + return fmt.Errorf("RemoveEventLogSource() failed: %s", err) + } + return nil +} diff --git a/h2mux/activestreammap.go b/h2mux/activestreammap.go new file mode 100644 index 00000000..9059e972 --- /dev/null +++ b/h2mux/activestreammap.go @@ -0,0 +1,213 @@ +package h2mux + +import ( + "sync" + + "golang.org/x/net/http2" +) + +// activeStreamMap is used to moderate access to active streams between the read and write +// threads, and deny access to new peer streams while shutting down. +type activeStreamMap struct { + sync.RWMutex + // streams tracks open streams. + streams map[uint32]*MuxedStream + // streamsEmpty is a chan that should be closed when no more streams are open. + streamsEmpty chan struct{} + // nextStreamID is the next ID to use on our side of the connection. + // This is odd for clients, even for servers. + nextStreamID uint32 + // maxPeerStreamID is the ID of the most recent stream opened by the peer. + maxPeerStreamID uint32 + // ignoreNewStreams is true when the connection is being shut down. New streams + // cannot be registered. + ignoreNewStreams bool +} + +type FlowControlMetrics struct { + AverageReceiveWindowSize, AverageSendWindowSize float64 + MinReceiveWindowSize, MaxReceiveWindowSize uint32 + MinSendWindowSize, MaxSendWindowSize uint32 +} + +func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap { + m := &activeStreamMap{ + streams: make(map[uint32]*MuxedStream), + streamsEmpty: make(chan struct{}), + nextStreamID: 1, + } + // Client initiated stream uses odd stream ID, server initiated stream uses even stream ID + if !useClientStreamNumbers { + m.nextStreamID = 2 + } + return m +} + +// Len returns the number of active streams. +func (m *activeStreamMap) Len() int { + m.RLock() + defer m.RUnlock() + return len(m.streams) +} + +func (m *activeStreamMap) Get(streamID uint32) (*MuxedStream, bool) { + m.RLock() + defer m.RUnlock() + stream, ok := m.streams[streamID] + return stream, ok +} + +// Set returns true if the stream was assigned successfully. If a stream +// already existed with that ID or we are shutting down, return false. +func (m *activeStreamMap) Set(newStream *MuxedStream) bool { + m.Lock() + defer m.Unlock() + if _, ok := m.streams[newStream.streamID]; ok { + return false + } + if m.ignoreNewStreams { + return false + } + m.streams[newStream.streamID] = newStream + return true +} + +// Delete stops tracking the stream. It should be called only after it is closed and resetted. 
+func (m *activeStreamMap) Delete(streamID uint32) { + m.Lock() + defer m.Unlock() + delete(m.streams, streamID) + if len(m.streams) == 0 && m.streamsEmpty != nil { + close(m.streamsEmpty) + m.streamsEmpty = nil + } +} + +// Shutdown blocks new streams from being created. It returns a channel that receives an event +// once the last stream has closed, or nil if a shutdown is in progress. +func (m *activeStreamMap) Shutdown() <-chan struct{} { + m.Lock() + defer m.Unlock() + if m.ignoreNewStreams { + // already shutting down + return nil + } + m.ignoreNewStreams = true + done := make(chan struct{}) + if len(m.streams) == 0 { + // nothing to shut down + close(done) + return done + } + m.streamsEmpty = done + return done +} + +// AcquireLocalID acquires a new stream ID for a stream you're opening. +func (m *activeStreamMap) AcquireLocalID() uint32 { + m.Lock() + defer m.Unlock() + x := m.nextStreamID + m.nextStreamID += 2 + return x +} + +// ObservePeerID observes the ID of a stream opened by the peer. It returns true if we should accept +// the new stream, or false to reject it. The ErrCode gives the reason why. +func (m *activeStreamMap) AcquirePeerID(streamID uint32) (bool, http2.ErrCode) { + m.Lock() + defer m.Unlock() + switch { + case m.ignoreNewStreams: + return false, http2.ErrCodeStreamClosed + case streamID > m.maxPeerStreamID: + m.maxPeerStreamID = streamID + return true, http2.ErrCodeNo + default: + return false, http2.ErrCodeStreamClosed + } +} + +// IsPeerStreamID is true if the stream ID belongs to the peer. +func (m *activeStreamMap) IsPeerStreamID(streamID uint32) bool { + m.RLock() + defer m.RUnlock() + return (streamID % 2) != (m.nextStreamID % 2) +} + +// IsLocalStreamID is true if it is a stream we have opened, even if it is now closed. +func (m *activeStreamMap) IsLocalStreamID(streamID uint32) bool { + m.RLock() + defer m.RUnlock() + return (streamID%2) == (m.nextStreamID%2) && streamID < m.nextStreamID +} + +// LastPeerStreamID returns the most recently opened peer stream ID. +func (m *activeStreamMap) LastPeerStreamID() uint32 { + m.RLock() + defer m.RUnlock() + return m.maxPeerStreamID +} + +// LastLocalStreamID returns the most recently opened local stream ID. +func (m *activeStreamMap) LastLocalStreamID() uint32 { + m.RLock() + defer m.RUnlock() + if m.nextStreamID > 1 { + return m.nextStreamID - 2 + } + return 0 +} + +// Abort closes every active stream and prevents new ones being created. This should be used to +// return errors in pending read/writes when the underlying connection goes away. +func (m *activeStreamMap) Abort() { + m.Lock() + defer m.Unlock() + for _, stream := range m.streams { + stream.Close() + } + m.ignoreNewStreams = true +} + +func (m *activeStreamMap) Metrics() *FlowControlMetrics { + m.Lock() + defer m.Unlock() + var averageReceiveWindowSize, averageSendWindowSize float64 + var minReceiveWindowSize, maxReceiveWindowSize, minSendWindowSize, maxSendWindowSize uint32 + i := 0 + // The first variable in the range expression for map is the key, not index. 
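+	// The loop below keeps running means so the averages are computed in a
+	// single pass, without storing per-stream samples: with a(i) the mean of
+	// the first i samples and x the (i+1)-th sample,
+	//
+	//	a(i+1) = a(i) + (x - a(i)) / (i+1)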
+ for _, stream := range m.streams { + // iterative mean: a(t+1) = a(t) + (a(t)-x)/(t+1) + windows := stream.FlowControlWindow() + averageReceiveWindowSize += (float64(windows.receiveWindow) - averageReceiveWindowSize) / float64(i+1) + averageSendWindowSize += (float64(windows.sendWindow) - averageSendWindowSize) / float64(i+1) + if i == 0 { + maxReceiveWindowSize = windows.receiveWindow + minReceiveWindowSize = windows.receiveWindow + maxSendWindowSize = windows.sendWindow + minSendWindowSize = windows.sendWindow + } else { + if windows.receiveWindow > maxReceiveWindowSize { + maxReceiveWindowSize = windows.receiveWindow + } else if windows.receiveWindow < minReceiveWindowSize { + minReceiveWindowSize = windows.receiveWindow + } + + if windows.sendWindow > maxSendWindowSize { + maxSendWindowSize = windows.sendWindow + } else if windows.sendWindow < minSendWindowSize { + minSendWindowSize = windows.sendWindow + } + } + i++ + } + return &FlowControlMetrics{ + MinReceiveWindowSize: minReceiveWindowSize, + MaxReceiveWindowSize: maxReceiveWindowSize, + AverageReceiveWindowSize: averageReceiveWindowSize, + MinSendWindowSize: minSendWindowSize, + MaxSendWindowSize: maxSendWindowSize, + AverageSendWindowSize: averageSendWindowSize, + } +} diff --git a/h2mux/activestreammap_test.go b/h2mux/activestreammap_test.go new file mode 100644 index 00000000..f21e84a3 --- /dev/null +++ b/h2mux/activestreammap_test.go @@ -0,0 +1,28 @@ +package h2mux + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMetrics(t *testing.T) { + streamMap := newActiveStreamMap(false) + for i := 1; i <= 7; i++ { + stream := new(MuxedStream) + *stream = MuxedStream{ + streamID: defaultWindowSize * uint32(i), + receiveWindow: defaultWindowSize * uint32(i), + sendWindow: defaultWindowSize * uint32(i), + } + + assert.True(t, streamMap.Set(stream)) + } + metrics := streamMap.Metrics() + assert.Equal(t, float64(defaultWindowSize*4), metrics.AverageReceiveWindowSize) + assert.Equal(t, float64(defaultWindowSize*4), metrics.AverageSendWindowSize) + assert.Equal(t, defaultWindowSize, metrics.MinReceiveWindowSize) + assert.Equal(t, defaultWindowSize, metrics.MinSendWindowSize) + assert.Equal(t, defaultWindowSize*7, metrics.MaxReceiveWindowSize) + assert.Equal(t, defaultWindowSize*7, metrics.MaxSendWindowSize) +} diff --git a/h2mux/booleanfuse.go b/h2mux/booleanfuse.go new file mode 100644 index 00000000..b4cbe8d7 --- /dev/null +++ b/h2mux/booleanfuse.go @@ -0,0 +1,25 @@ +package h2mux + +import "sync/atomic" + +// BooleanFuse is a data structure that can be set once to a particular value using Fuse(value). +// Subsequent calls to Fuse() will have no effect. 
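+// A minimal usage sketch (the variable name is illustrative): concurrent
+// callers may race to record an outcome, and only the first call wins.
+//
+//	var fuse BooleanFuse
+//	fuse.Fuse(true)  // records true
+//	fuse.Fuse(false) // no effect; already fused
+//	_ = fuse.Value() // true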
+type BooleanFuse struct { + value int32 +} + +// Value gets the value +func (f *BooleanFuse) Value() bool { + // 0: unset + // 1: set true + // 2: set false + return atomic.LoadInt32(&f.value) == 1 +} + +func (f *BooleanFuse) Fuse(result bool) { + newValue := int32(2) + if result { + newValue = 1 + } + atomic.CompareAndSwapInt32(&f.value, 0, newValue) +} diff --git a/h2mux/error.go b/h2mux/error.go new file mode 100644 index 00000000..efb1f37d --- /dev/null +++ b/h2mux/error.go @@ -0,0 +1,61 @@ +package h2mux + +import ( + "fmt" + + "golang.org/x/net/http2" +) + +var ( + ErrHandshakeTimeout = MuxerHandshakeError{"1000 handshake timeout"} + ErrBadHandshakeNotSettings = MuxerHandshakeError{"1001 unexpected response"} + ErrBadHandshakeUnexpectedAck = MuxerHandshakeError{"1002 unexpected response"} + ErrBadHandshakeNoMagic = MuxerHandshakeError{"1003 unexpected response"} + ErrBadHandshakeWrongMagic = MuxerHandshakeError{"1004 connected to endpoint of wrong type"} + ErrBadHandshakeNotSettingsAck = MuxerHandshakeError{"1005 unexpected response"} + ErrBadHandshakeUnexpectedSettings = MuxerHandshakeError{"1006 unexpected response"} + + ErrUnexpectedFrameType = MuxerProtocolError{"2001 unexpected frame type", http2.ErrCodeProtocol} + ErrUnknownStream = MuxerProtocolError{"2002 unknown stream", http2.ErrCodeProtocol} + ErrInvalidStream = MuxerProtocolError{"2003 invalid stream", http2.ErrCodeProtocol} + + ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"} + ErrConnectionClosed = MuxerApplicationError{"3001 connection closed"} + ErrConnectionDropped = MuxerApplicationError{"3002 connection dropped"} + + ErrClosedStream = MuxerStreamError{"4000 stream closed", http2.ErrCodeStreamClosed} +) + +type MuxerHandshakeError struct { + cause string +} + +func (e MuxerHandshakeError) Error() string { + return fmt.Sprintf("Handshake error: %s", e.cause) +} + +type MuxerProtocolError struct { + cause string + h2code http2.ErrCode +} + +func (e MuxerProtocolError) Error() string { + return fmt.Sprintf("Protocol error: %s", e.cause) +} + +type MuxerApplicationError struct { + cause string +} + +func (e MuxerApplicationError) Error() string { + return fmt.Sprintf("Application error: %s", e.cause) +} + +type MuxerStreamError struct { + cause string + h2code http2.ErrCode +} + +func (e MuxerStreamError) Error() string { + return fmt.Sprintf("Stream error: %s", e.cause) +} diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go new file mode 100644 index 00000000..d47b4059 --- /dev/null +++ b/h2mux/h2mux.go @@ -0,0 +1,335 @@ +package h2mux + +import ( + "io" + "strings" + "sync" + "time" + + log "github.com/Sirupsen/logrus" + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" +) + +const ( + defaultFrameSize uint32 = 1 << 14 // Minimum frame size in http2 spec + defaultWindowSize uint32 = 65535 + maxWindowSize uint32 = (1 << 31) - 1 // 2^31-1 = 2147483647, max window size specified in http2 spec + defaultTimeout time.Duration = 5 * time.Second + defaultRetries uint64 = 5 + + SettingMuxerMagic http2.SettingID = 0x42db + MuxerMagicOrigin uint32 = 0xa2e43c8b + MuxerMagicEdge uint32 = 0x1088ebf9 +) + +type MuxedStreamHandler interface { + ServeStream(*MuxedStream) error +} + +type MuxedStreamFunc func(stream *MuxedStream) error + +func (f MuxedStreamFunc) ServeStream(stream *MuxedStream) error { + return f(stream) +} + +type MuxerConfig struct { + Timeout time.Duration + Handler MuxedStreamHandler + IsClient bool + // Name is used to identify this muxer instance when logging. 
+ Name string + // The minimum time this connection can be idle before sending a heartbeat. + HeartbeatInterval time.Duration + // The minimum number of heartbeats to send before terminating the connection. + MaxHeartbeats uint64 +} + +type Muxer struct { + // f is used to read and write HTTP2 frames on the wire. + f *http2.Framer + // config is the MuxerConfig given in Handshake. + config MuxerConfig + // w, r are references to the underlying connection used. + w io.WriteCloser + r io.ReadCloser + // muxReader is the read process. + muxReader *MuxReader + // muxWriter is the write process. + muxWriter *MuxWriter + // newStreamChan is used to create new streams on the writer thread. + // The writer will assign the next available stream ID. + newStreamChan chan MuxedStreamRequest + // abortChan is used to abort the writer event loop. + abortChan chan struct{} + // abortOnce is used to ensure abortChan is closed once only. + abortOnce sync.Once + // readyList is used to signal writable streams. + readyList *ReadyList + // streams tracks currently-open streams. + streams *activeStreamMap + // explicitShutdown records whether the Muxer is closing because Shutdown was called, or due to another + // error. + explicitShutdown BooleanFuse +} + +type Header struct { + Name, Value string +} + +// Handshake establishes a muxed connection with the peer. +// After the handshake completes, it is possible to open and accept streams. +func Handshake( + w io.WriteCloser, + r io.ReadCloser, + config MuxerConfig, +) (*Muxer, error) { + // Initialise connection state fields + m := &Muxer{ + f: http2.NewFramer(w, r), // A framer that writes to w and reads from r + config: config, + w: w, + r: r, + newStreamChan: make(chan MuxedStreamRequest), + abortChan: make(chan struct{}), + readyList: NewReadyList(), + streams: newActiveStreamMap(config.IsClient), + } + m.f.ReadMetaHeaders = hpack.NewDecoder(4096, func(hpack.HeaderField) {}) + + if config.Timeout == 0 { + config.Timeout = defaultTimeout + } + + // Initialise the settings to identify this connection and confirm the other end is sane. + handshakeSetting := http2.Setting{ID: SettingMuxerMagic, Val: MuxerMagicEdge} + expectedMagic := MuxerMagicOrigin + if config.IsClient { + handshakeSetting.Val = MuxerMagicOrigin + expectedMagic = MuxerMagicEdge + } + errChan := make(chan error, 2) + // Simultaneously send our settings and verify the peer's settings. + go func() { errChan <- m.f.WriteSettings(handshakeSetting) }() + go func() { errChan <- m.readPeerSettings(expectedMagic) }() + err := joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout) + if err != nil { + return nil, err + } + // Confirm sanity by ACKing the frame and expecting an ACK for our frame. + // Not strictly necessary, but let's pretend to be H2-like. 
+    go func() { errChan <- m.f.WriteSettingsAck() }()
+    go func() { errChan <- m.readPeerSettingsAck() }()
+    err = joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
+    if err != nil {
+        return nil, err
+    }
+
+    // set up reader/writer pair ready for serve
+    streamErrors := NewStreamErrorMap()
+    goAwayChan := make(chan http2.ErrCode, 1)
+    pingTimestamp := NewPingTimestamp()
+    connActive := NewSignal()
+    idleDuration := config.HeartbeatInterval
+    // Sanity check to ensure idleDuration is sane
+    if idleDuration == 0 || idleDuration < defaultTimeout {
+        idleDuration = defaultTimeout
+        log.Warn("Minimum idle time has been adjusted to ", defaultTimeout)
+    }
+    maxRetries := config.MaxHeartbeats
+    if maxRetries == 0 {
+        maxRetries = defaultRetries
+        log.Warn("Minimum number of unacked heartbeats to send before closing the connection has been adjusted to ", maxRetries)
+    }
+
+    m.muxReader = &MuxReader{
+        f: m.f,
+        handler: m.config.Handler,
+        streams: m.streams,
+        readyList: m.readyList,
+        streamErrors: streamErrors,
+        goAwayChan: goAwayChan,
+        abortChan: m.abortChan,
+        pingTimestamp: pingTimestamp,
+        connActive: connActive,
+        initialStreamWindow: defaultWindowSize,
+        streamWindowMax: maxWindowSize,
+        r: m.r,
+    }
+    m.muxWriter = &MuxWriter{
+        f: m.f,
+        streams: m.streams,
+        streamErrors: streamErrors,
+        readyStreamChan: m.readyList.ReadyChannel(),
+        newStreamChan: m.newStreamChan,
+        goAwayChan: goAwayChan,
+        abortChan: m.abortChan,
+        pingTimestamp: pingTimestamp,
+        idleTimer: NewIdleTimer(idleDuration, maxRetries),
+        connActiveChan: connActive.WaitChannel(),
+        maxFrameSize: defaultFrameSize,
+    }
+    m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
+
+    return m, nil
+}
+
+func (m *Muxer) readPeerSettings(magic uint32) error {
+    frame, err := m.f.ReadFrame()
+    if err != nil {
+        return err
+    }
+    settingsFrame, ok := frame.(*http2.SettingsFrame)
+    if !ok {
+        return ErrBadHandshakeNotSettings
+    }
+    if settingsFrame.Header().Flags != 0 {
+        return ErrBadHandshakeUnexpectedAck
+    }
+    peerMagic, ok := settingsFrame.Value(SettingMuxerMagic)
+    if !ok {
+        return ErrBadHandshakeNoMagic
+    }
+    if magic != peerMagic {
+        return ErrBadHandshakeWrongMagic
+    }
+    return nil
+}
+
+func (m *Muxer) readPeerSettingsAck() error {
+    frame, err := m.f.ReadFrame()
+    if err != nil {
+        return err
+    }
+    settingsFrame, ok := frame.(*http2.SettingsFrame)
+    if !ok {
+        return ErrBadHandshakeNotSettingsAck
+    }
+    if settingsFrame.Header().Flags != http2.FlagSettingsAck {
+        return ErrBadHandshakeUnexpectedSettings
+    }
+    return nil
+}
+
+func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time.Duration, timeoutError error) error {
+    for i := 0; i < receiveCount; i++ {
+        select {
+        case err := <-errChan:
+            if err != nil {
+                return err
+            }
+        case <-time.After(timeout):
+            return timeoutError
+        }
+    }
+    return nil
+}
+
+func (m *Muxer) Serve() error {
+    logger := log.WithField("name", m.config.Name)
+    errChan := make(chan error)
+    go func() {
+        errChan <- m.muxReader.run(logger)
+        m.explicitShutdown.Fuse(false)
+        m.r.Close()
+        m.abort()
+    }()
+    go func() {
+        errChan <- m.muxWriter.run(logger)
+        m.explicitShutdown.Fuse(false)
+        m.w.Close()
+        m.abort()
+    }()
+    err := <-errChan
+    go func() {
+        // discard error as other handler closes
+        <-errChan
+        close(errChan)
+    }()
+    if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) {
+        return err
+    }
+    return nil
+}
+
+func (m *Muxer) Shutdown() {
+    m.explicitShutdown.Fuse(true)
+    m.muxReader.Shutdown()
+}
+
+// isUnexpectedTunnelError
identifies errors that are expected when shutting down the h2mux tunnel. +// The set of expected errors change depending on whether we initiated shutdown or not. +func isUnexpectedTunnelError(err error, expectedShutdown bool) bool { + if err == nil { + return false + } + if !expectedShutdown { + return true + } + return !isConnectionClosedError(err) +} + +func isConnectionClosedError(err error) bool { + if err == io.EOF { + return true + } + if err == io.ErrClosedPipe { + return true + } + if err.Error() == "tls: use of closed connection" { + return true + } + if strings.HasSuffix(err.Error(), "use of closed network connection") { + return true + } + return false +} + +// OpenStream opens a new data stream with the given headers. +// Called by proxy server and tunnel +func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, error) { + stream := &MuxedStream{ + responseHeadersReceived: make(chan struct{}), + readBuffer: NewSharedBuffer(), + receiveWindow: defaultWindowSize, + receiveWindowCurrentMax: defaultWindowSize, // Initial window size limit. exponentially increase it when receiveWindow is exhausted + receiveWindowMax: maxWindowSize, + sendWindow: defaultWindowSize, + readyList: m.readyList, + writeHeaders: headers, + } + select { + // Will be received by mux writer + case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: body}: + case <-m.abortChan: + return nil, ErrConnectionClosed + } + select { + case <-stream.responseHeadersReceived: + return stream, nil + case <-m.abortChan: + return nil, ErrConnectionClosed + } +} + +// Return the estimated round-trip time. +func (m *Muxer) RTT() RTTMeasurement { + return m.muxReader.RTT() +} + +// Return min/max/average of send/receive window for all streams on this connection +func (m *Muxer) FlowControlMetrics() *FlowControlMetrics { + return m.muxReader.FlowControlMetrics() +} + +func (m *Muxer) abort() { + m.abortOnce.Do(func() { + close(m.abortChan) + m.streams.Abort() + }) +} + +// Return how many retries/ticks since the connection was last marked active +func (m *Muxer) TimerRetries() uint64 { + return m.muxWriter.idleTimer.RetryCount() +} diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go new file mode 100644 index 00000000..e80c04ee --- /dev/null +++ b/h2mux/h2mux_test.go @@ -0,0 +1,646 @@ +package h2mux + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "math/rand" + "os" + "strconv" + "sync" + "testing" + "time" + + log "github.com/Sirupsen/logrus" +) + +func TestMain(m *testing.M) { + if os.Getenv("VERBOSE") == "1" { + log.SetLevel(log.DebugLevel) + } + os.Exit(m.Run()) +} + +type DefaultMuxerPair struct { + OriginMuxConfig MuxerConfig + OriginMux *Muxer + OriginWriter *io.PipeWriter + OriginReader *io.PipeReader + EdgeMuxConfig MuxerConfig + EdgeMux *Muxer + EdgeWriter *io.PipeWriter + EdgeReader *io.PipeReader + doneC chan struct{} +} + +func NewDefaultMuxerPair() *DefaultMuxerPair { + originReader, edgeWriter := io.Pipe() + edgeReader, originWriter := io.Pipe() + return &DefaultMuxerPair{ + OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"}, + OriginWriter: originWriter, + OriginReader: originReader, + EdgeMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"}, + EdgeWriter: edgeWriter, + EdgeReader: edgeReader, + doneC: make(chan struct{}), + } +} + +func (p *DefaultMuxerPair) Handshake(t *testing.T) { + edgeErrC := make(chan error) + originErrC := make(chan error) + go func() { + var err error + p.EdgeMux, err = Handshake(p.EdgeWriter, 
p.EdgeReader, p.EdgeMuxConfig) + edgeErrC <- err + }() + go func() { + var err error + p.OriginMux, err = Handshake(p.OriginWriter, p.OriginReader, p.OriginMuxConfig) + originErrC <- err + }() + + select { + case err := <-edgeErrC: + if err != nil { + t.Fatalf("edge handshake failure: %s", err) + } + case <-time.After(time.Second * 5): + t.Fatalf("edge handshake timeout") + } + + select { + case err := <-originErrC: + if err != nil { + t.Fatalf("origin handshake failure: %s", err) + } + case <-time.After(time.Second * 5): + t.Fatalf("origin handshake timeout") + } +} + +func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) { + p.Handshake(t) + var wg sync.WaitGroup + wg.Add(2) + go func() { + err := p.EdgeMux.Serve() + if err != nil && err != io.EOF && err != io.ErrClosedPipe { + t.Errorf("error in edge muxer Serve(): %s", err) + } + p.OriginMux.Shutdown() + wg.Done() + }() + go func() { + err := p.OriginMux.Serve() + if err != nil && err != io.EOF && err != io.ErrClosedPipe { + t.Errorf("error in origin muxer Serve(): %s", err) + } + p.EdgeMux.Shutdown() + wg.Done() + }() + go func() { + // notify when both muxes have stopped serving + wg.Wait() + close(p.doneC) + }() +} + +func (p *DefaultMuxerPair) Wait(t *testing.T) { + select { + case <-p.doneC: + return + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for shutdown") + } +} + +func TestHandshake(t *testing.T) { + muxPair := NewDefaultMuxerPair() + muxPair.Handshake(t) + AssertIfPipeReadable(t, muxPair.OriginReader) + AssertIfPipeReadable(t, muxPair.EdgeReader) +} + +func TestSingleStream(t *testing.T) { + closeC := make(chan struct{}) + muxPair := NewDefaultMuxerPair() + muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { + defer close(closeC) + if len(stream.Headers) != 1 { + t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) + } + if stream.Headers[0].Name != "test-header" { + t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name) + } + if stream.Headers[0].Value != "headerValue" { + t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value) + } + stream.WriteHeaders([]Header{ + Header{Name: "response-header", Value: "responseValue"}, + }) + buf := []byte("Hello world") + stream.Write(buf) + // after this receive, the edge closed the stream + <-closeC + n, err := stream.Read(buf) + if n > 0 { + t.Fatalf("read %d bytes after EOF", n) + } + if err != io.EOF { + t.Fatalf("expected EOF, got %s", err) + } + return nil + }) + muxPair.HandshakeAndServe(t) + + stream, err := muxPair.EdgeMux.OpenStream( + []Header{Header{Name: "test-header", Value: "headerValue"}}, + nil, + ) + if err != nil { + t.Fatalf("error in OpenStream: %s", err) + } + if len(stream.Headers) != 1 { + t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) + } + if stream.Headers[0].Name != "response-header" { + t.Fatalf("expected header name %s, got %s", "response-header", stream.Headers[0].Name) + } + if stream.Headers[0].Value != "responseValue" { + t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value) + } + responseBody := make([]byte, 11) + n, err := stream.Read(responseBody) + if err != nil { + t.Fatalf("error from (*MuxedStream).Read: %s", err) + } + if n != len(responseBody) { + t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n) + } + if string(responseBody) != "Hello world" { + t.Fatalf("expected response body %s, got %s", "Hello world", responseBody) + } + stream.Close() 
+ closeC <- struct{}{} + n, err = stream.Write([]byte("aaaaa")) + if n > 0 { + t.Fatalf("wrote %d bytes after EOF", n) + } + if err != io.EOF { + t.Fatalf("expected EOF, got %s", err) + } + <-closeC +} + +func TestSingleStreamLargeResponseBody(t *testing.T) { + muxPair := NewDefaultMuxerPair() + bodySize := 1 << 24 + muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { + if len(stream.Headers) != 1 { + t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) + } + if stream.Headers[0].Name != "test-header" { + t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name) + } + if stream.Headers[0].Value != "headerValue" { + t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value) + } + stream.WriteHeaders([]Header{ + Header{Name: "response-header", Value: "responseValue"}, + }) + payload := make([]byte, bodySize) + for i := range payload { + payload[i] = byte(i % 256) + } + n, err := stream.Write(payload) + if err != nil { + t.Fatalf("origin write error: %s", err) + } + if n != len(payload) { + t.Fatalf("origin short write: %d/%d bytes", n, len(payload)) + } + return nil + }) + muxPair.HandshakeAndServe(t) + + stream, err := muxPair.EdgeMux.OpenStream( + []Header{Header{Name: "test-header", Value: "headerValue"}}, + nil, + ) + if err != nil { + t.Fatalf("error in OpenStream: %s", err) + } + if len(stream.Headers) != 1 { + t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) + } + if stream.Headers[0].Name != "response-header" { + t.Fatalf("expected header name %s, got %s", "response-header", stream.Headers[0].Name) + } + if stream.Headers[0].Value != "responseValue" { + t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value) + } + responseBody := make([]byte, bodySize) + n, err := stream.Read(responseBody) + if err != nil { + t.Fatalf("error from (*MuxedStream).Read: %s", err) + } + if n != len(responseBody) { + t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n) + } +} + +func TestMultipleStreams(t *testing.T) { + muxPair := NewDefaultMuxerPair() + maxStreams := 64 + errorsC := make(chan error, maxStreams) + muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { + if len(stream.Headers) != 1 { + t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) + } + if stream.Headers[0].Name != "client-token" { + t.Fatalf("expected header name %s, got %s", "client-token", stream.Headers[0].Name) + } + log.Debugf("Got request for stream %s", stream.Headers[0].Value) + stream.WriteHeaders([]Header{ + Header{Name: "response-token", Value: stream.Headers[0].Value}, + }) + log.Debugf("Wrote headers for stream %s", stream.Headers[0].Value) + stream.Write([]byte("OK")) + log.Debugf("Wrote body for stream %s", stream.Headers[0].Value) + return nil + }) + muxPair.HandshakeAndServe(t) + + var wg sync.WaitGroup + wg.Add(maxStreams) + for i := 0; i < maxStreams; i++ { + go func(tokenId int) { + defer wg.Done() + tokenString := fmt.Sprintf("%d", tokenId) + stream, err := muxPair.EdgeMux.OpenStream( + []Header{Header{Name: "client-token", Value: tokenString}}, + nil, + ) + log.Debugf("Got headers for stream %d", tokenId) + if err != nil { + errorsC <- err + return + } + if len(stream.Headers) != 1 { + errorsC <- fmt.Errorf("stream %d has error: expected %d headers, got %d", stream.streamID, 1, len(stream.Headers)) + return + } + if stream.Headers[0].Name != "response-token" { + errorsC <- fmt.Errorf("stream %d has 
error: expected header name %s, got %s", stream.streamID, "response-token", stream.Headers[0].Name) + return + } + if stream.Headers[0].Value != tokenString { + errorsC <- fmt.Errorf("stream %d has error: expected header value %s, got %s", stream.streamID, tokenString, stream.Headers[0].Value) + return + } + responseBody := make([]byte, 2) + n, err := stream.Read(responseBody) + if err != nil { + errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err) + return + } + if n != len(responseBody) { + errorsC <- fmt.Errorf("stream %d has error: expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n) + return + } + if string(responseBody) != "OK" { + errorsC <- fmt.Errorf("stream %d has error: expected response body %s, got %s", stream.streamID, "OK", responseBody) + return + } + }(i) + } + wg.Wait() + close(errorsC) + testFail := false + for err := range errorsC { + testFail = true + log.Error(err) + } + if testFail { + t.Fatalf("TestMultipleStreamsFlowControl failed") + } +} + +func TestMultipleStreamsFlowControl(t *testing.T) { + maxStreams := 32 + errorsC := make(chan error, maxStreams) + responseSizes := make([]int32, maxStreams) + for i := 0; i < maxStreams; i++ { + responseSizes[i] = rand.Int31n(int32(defaultWindowSize << 4)) + } + muxPair := NewDefaultMuxerPair() + muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { + if len(stream.Headers) != 1 { + t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) + } + if stream.Headers[0].Name != "test-header" { + t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name) + } + if stream.Headers[0].Value != "headerValue" { + t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value) + } + stream.WriteHeaders([]Header{ + Header{Name: "response-header", Value: "responseValue"}, + }) + payload := make([]byte, responseSizes[(stream.streamID-2)/2]) + for i := range payload { + payload[i] = byte(i % 256) + } + n, err := stream.Write(payload) + if err != nil { + t.Fatalf("origin write error: %s", err) + } + if n != len(payload) { + t.Fatalf("origin short write: %d/%d bytes", n, len(payload)) + } + return nil + }) + muxPair.HandshakeAndServe(t) + + var wg sync.WaitGroup + wg.Add(maxStreams) + for i := 0; i < maxStreams; i++ { + go func(tokenId int) { + defer wg.Done() + stream, err := muxPair.EdgeMux.OpenStream( + []Header{Header{Name: "test-header", Value: "headerValue"}}, + nil, + ) + if err != nil { + errorsC <- fmt.Errorf("stream %d error in OpenStream: %s", stream.streamID, err) + return + } + if len(stream.Headers) != 1 { + errorsC <- fmt.Errorf("stream %d expected %d headers, got %d", stream.streamID, 1, len(stream.Headers)) + return + } + if stream.Headers[0].Name != "response-header" { + errorsC <- fmt.Errorf("stream %d expected header name %s, got %s", stream.streamID, "response-header", stream.Headers[0].Name) + return + } + if stream.Headers[0].Value != "responseValue" { + errorsC <- fmt.Errorf("stream %d expected header value %s, got %s", stream.streamID, "responseValue", stream.Headers[0].Value) + return + } + + responseBody := make([]byte, responseSizes[(stream.streamID-2)/2]) + n, err := stream.Read(responseBody) + if err != nil { + errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err) + return + } + if n != len(responseBody) { + errorsC <- fmt.Errorf("stream %d expected response body to have %d bytes, got %d", stream.streamID, 
len(responseBody), n) + return + } + }(i) + } + wg.Wait() + close(errorsC) + testFail := false + for err := range errorsC { + testFail = true + log.Error(err) + } + if testFail { + t.Fatalf("TestMultipleStreamsFlowControl failed") + } +} + +func TestGracefulShutdown(t *testing.T) { + sendC := make(chan struct{}) + responseBuf := bytes.Repeat([]byte("Hello world"), 65536) + muxPair := NewDefaultMuxerPair() + muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { + stream.WriteHeaders([]Header{ + Header{Name: "response-header", Value: "responseValue"}, + }) + <-sendC + log.Debugf("Writing %d bytes", len(responseBuf)) + stream.Write(responseBuf) + stream.CloseWrite() + log.Debugf("Wrote %d bytes", len(responseBuf)) + // Reading from the stream will block until the edge closes its end of the stream. + // Otherwise, we'll close the whole connection before receiving the 'stream closed' + // message from the edge. + // Graceful shutdown works if you omit this, it just gives spurious errors for now - + // TODO ignore errors when writing 'stream closed' and we're shutting down. + stream.Read([]byte{0}) + log.Debugf("Handler ends") + return nil + }) + muxPair.HandshakeAndServe(t) + + stream, err := muxPair.EdgeMux.OpenStream( + []Header{Header{Name: "test-header", Value: "headerValue"}}, + nil, + ) + // Start graceful shutdown of the edge mux - this should also close the origin mux when done + muxPair.EdgeMux.Shutdown() + close(sendC) + if err != nil { + t.Fatalf("error in OpenStream: %s", err) + } + responseBody := make([]byte, len(responseBuf)) + log.Debugf("Waiting for %d bytes", len(responseBuf)) + n, err := stream.Read(responseBody) + if err != nil { + t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err) + } + if n != len(responseBody) { + t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n) + } + if !bytes.Equal(responseBuf, responseBody) { + t.Fatalf("response body mismatch") + } + stream.Close() + muxPair.Wait(t) +} + +func TestUnexpectedShutdown(t *testing.T) { + sendC := make(chan struct{}) + handlerFinishC := make(chan struct{}) + responseBuf := bytes.Repeat([]byte("Hello world"), 65536) + muxPair := NewDefaultMuxerPair() + muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { + defer close(handlerFinishC) + stream.WriteHeaders([]Header{ + Header{Name: "response-header", Value: "responseValue"}, + }) + <-sendC + n, err := stream.Read([]byte{0}) + if err != io.EOF { + t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err) + } + if n != 0 { + t.Fatalf("expected empty read, got %d bytes", n) + } + // Write comes after read, because write buffers data before it is flushed. It wouldn't know about EOF + // until some time later. Calling read first forces it to know about EOF now. + _, err = stream.Write(responseBuf) + if err != io.EOF { + t.Fatalf("unexpected error from (*MuxedStream).Write: %s", err) + } + return nil + }) + muxPair.HandshakeAndServe(t) + + stream, err := muxPair.EdgeMux.OpenStream( + []Header{Header{Name: "test-header", Value: "headerValue"}}, + nil, + ) + // Close the underlying connection before telling the origin to write. 
+ muxPair.EdgeReader.Close() + close(sendC) + if err != nil { + t.Fatalf("error in OpenStream: %s", err) + } + responseBody := make([]byte, len(responseBuf)) + n, err := stream.Read(responseBody) + if err != io.EOF { + t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err) + } + if n != 0 { + t.Fatalf("expected response body to have %d bytes, got %d", 0, n) + } + // The write ordering requirement explained in the origin handler applies here too. + _, err = stream.Write(responseBuf) + if err != io.EOF { + t.Fatalf("unexpected error from (*MuxedStream).Write: %s", err) + } + <-handlerFinishC +} + +func EchoHandler(stream *MuxedStream) error { + var buf bytes.Buffer + fmt.Fprintf(&buf, "Hello, world!\n\n# REQUEST HEADERS:\n\n") + for _, header := range stream.Headers { + fmt.Fprintf(&buf, "[%s] = %s\n", header.Name, header.Value) + } + stream.WriteHeaders([]Header{ + {Name: ":status", Value: "200"}, + {Name: "server", Value: "Echo-server/1.0"}, + {Name: "date", Value: time.Now().Format(time.RFC850)}, + {Name: "content-type", Value: "text/html; charset=utf-8"}, + {Name: "content-length", Value: strconv.Itoa(buf.Len())}, + }) + buf.WriteTo(stream) + return nil +} + +func TestOpenAfterDisconnect(t *testing.T) { + for i := 0; i < 3; i++ { + muxPair := NewDefaultMuxerPair() + muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(EchoHandler) + muxPair.HandshakeAndServe(t) + + switch i { + case 0: + // Close both directions of the connection to cause EOF on both peers. + muxPair.OriginReader.Close() + muxPair.OriginWriter.Close() + case 1: + // Close origin reader (edge writer) to cause EOF on origin only. + muxPair.OriginReader.Close() + case 2: + // Close origin writer (edge reader) to cause EOF on edge only. + muxPair.OriginWriter.Close() + } + + _, err := muxPair.EdgeMux.OpenStream( + []Header{Header{Name: "test-header", Value: "headerValue"}}, + nil, + ) + if err != ErrConnectionClosed { + t.Fatalf("unexpected error in OpenStream: %s", err) + } + } +} + +func TestHPACK(t *testing.T) { + muxPair := NewDefaultMuxerPair() + muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(EchoHandler) + muxPair.HandshakeAndServe(t) + + stream, err := muxPair.EdgeMux.OpenStream( + []Header{ + {Name: ":method", Value: "RPC"}, + {Name: ":scheme", Value: "capnp"}, + {Name: ":path", Value: "*"}, + }, + nil, + ) + if err != nil { + t.Fatalf("error in OpenStream: %s", err) + } + stream.Close() + + for i := 0; i < 3; i++ { + stream, err := muxPair.EdgeMux.OpenStream( + []Header{ + {Name: ":method", Value: "GET"}, + {Name: ":scheme", Value: "https"}, + {Name: ":authority", Value: "tunnel.otterlyadorable.co.uk"}, + {Name: ":path", Value: "/get"}, + {Name: "accept-encoding", Value: "gzip"}, + {Name: "cf-ray", Value: "378948953f044408-SFO-DOG"}, + {Name: "cf-visitor", Value: "{\"scheme\":\"https\"}"}, + {Name: "cf-connecting-ip", Value: "2400:cb00:0025:010d:0000:0000:0000:0001"}, + {Name: "x-forwarded-for", Value: "2400:cb00:0025:010d:0000:0000:0000:0001"}, + {Name: "x-forwarded-proto", Value: "https"}, + {Name: "accept-language", Value: "en-gb"}, + {Name: "referer", Value: "https://tunnel.otterlyadorable.co.uk/"}, + {Name: "cookie", Value: "__cfduid=d4555095065f92daedc059490771967d81493032162"}, + {Name: "connection", Value: "Keep-Alive"}, + {Name: "cf-ipcountry", Value: "US"}, + {Name: "accept", Value: "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, + {Name: "user-agent", Value: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/603.2.4 (KHTML, like Gecko) Version/10.1.1 Safari/603.2.4"}, 
+ }, + nil, + ) + if err != nil { + t.Fatalf("error in OpenStream: %s", err) + } + if len(stream.Headers) == 0 { + t.Fatal("response has no headers") + } + if stream.Headers[0].Name != ":status" { + t.Fatalf("first header should be status, found %s instead", stream.Headers[0].Name) + } + if stream.Headers[0].Value != "200" { + t.Fatalf("expected status 200, got %s", stream.Headers[0].Value) + } + ioutil.ReadAll(stream) + stream.Close() + } +} + +func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) { + errC := make(chan error) + go func() { + b := []byte{0} + n, err := pipe.Read(b) + if n > 0 { + t.Fatalf("read pipe was not empty") + } + errC <- err + }() + select { + case err := <-errC: + if err != nil { + t.Fatalf("read error: %s", err) + } + case <-time.After(100 * time.Millisecond): + // nothing to read + pipe.Close() + <-errC + } +} diff --git a/h2mux/idletimer.go b/h2mux/idletimer.go new file mode 100644 index 00000000..6e171801 --- /dev/null +++ b/h2mux/idletimer.go @@ -0,0 +1,81 @@ +package h2mux + +import ( + "math/rand" + "sync" + "time" +) + +// IdleTimer is a type of Timer designed for managing heartbeats on an idle connection. +// The timer ticks on an interval with added jitter to avoid accidental synchronisation +// between two endpoints. It tracks the number of retries/ticks since the connection was +// last marked active. +// +// The methods of IdleTimer must not be called while a goroutine is reading from C. +type IdleTimer struct { + // The channel on which ticks are delivered. + C <-chan time.Time + + // A timer used to measure idle connection time. Reset after sending data. + idleTimer *time.Timer + // The maximum length of time a connection is idle before sending a ping. + idleDuration time.Duration + // A pseudorandom source used to add jitter to the idle duration. + randomSource *rand.Rand + // The maximum number of retries allowed. + maxRetries uint64 + // The number of retries since the connection was last marked active. + retries uint64 + // A lock to prevent race condition while checking retries + stateLock sync.RWMutex +} + +func NewIdleTimer(idleDuration time.Duration, maxRetries uint64) *IdleTimer { + t := &IdleTimer{ + idleTimer: time.NewTimer(idleDuration), + idleDuration: idleDuration, + randomSource: rand.New(rand.NewSource(time.Now().Unix())), + maxRetries: maxRetries, + } + t.C = t.idleTimer.C + return t +} + +// Retry should be called when retrying the idle timeout. If the maximum number of retries +// has been met, returns false. +// After calling this function and sending a heartbeat, call ResetTimer. Since sending the +// heartbeat could be a blocking operation, we resetting the timer after the write completes +// to avoid it expiring during the write. +func (t *IdleTimer) Retry() bool { + t.stateLock.Lock() + defer t.stateLock.Unlock() + if t.retries >= t.maxRetries { + return false + } + t.retries++ + return true +} + +func (t *IdleTimer) RetryCount() uint64 { + t.stateLock.RLock() + defer t.stateLock.RUnlock() + return t.retries +} + +// MarkActive resets the idle connection timer and suppresses any outstanding idle events. +func (t *IdleTimer) MarkActive() { + if !t.idleTimer.Stop() { + // eat the timer event to prevent spurious pings + <-t.idleTimer.C + } + t.stateLock.Lock() + t.retries = 0 + t.stateLock.Unlock() + t.ResetTimer() +} + +// Reset the idle timer according to the configured duration, with some added jitter. 
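+// Since the jitter is drawn from [0, idleDuration), the timer fires somewhere in
+// [idleDuration, 2*idleDuration) after each reset.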
+func (t *IdleTimer) ResetTimer() { + jitter := time.Duration(t.randomSource.Int63n(int64(t.idleDuration))) + t.idleTimer.Reset(t.idleDuration + jitter) +} diff --git a/h2mux/idletimer_test.go b/h2mux/idletimer_test.go new file mode 100644 index 00000000..92f2b2a3 --- /dev/null +++ b/h2mux/idletimer_test.go @@ -0,0 +1,31 @@ +package h2mux + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRetry(t *testing.T) { + timer := NewIdleTimer(time.Second, 2) + assert.Equal(t, uint64(0), timer.RetryCount()) + ok := timer.Retry() + assert.True(t, ok) + assert.Equal(t, uint64(1), timer.RetryCount()) + ok = timer.Retry() + assert.True(t, ok) + assert.Equal(t, uint64(2), timer.RetryCount()) + ok = timer.Retry() + assert.False(t, ok) +} + +func TestMarkActive(t *testing.T) { + timer := NewIdleTimer(time.Second, 2) + assert.Equal(t, uint64(0), timer.RetryCount()) + ok := timer.Retry() + assert.True(t, ok) + assert.Equal(t, uint64(1), timer.RetryCount()) + timer.MarkActive() + assert.Equal(t, uint64(0), timer.RetryCount()) +} diff --git a/h2mux/muxedstream.go b/h2mux/muxedstream.go new file mode 100644 index 00000000..f0367ce9 --- /dev/null +++ b/h2mux/muxedstream.go @@ -0,0 +1,250 @@ +package h2mux + +import ( + "bytes" + "io" + "sync" +) + +type MuxedStream struct { + Headers []Header + + streamID uint32 + + responseHeadersReceived chan struct{} + + readBuffer *SharedBuffer + receiveWindow uint32 + // current window size limit. Exponentially increase it when it's exhausted + receiveWindowCurrentMax uint32 + // limit set in http2 spec. 2^31-1 + receiveWindowMax uint32 + // nonzero if a WINDOW_UPDATE frame for a stream needs to be sent + windowUpdate uint32 + + writeLock sync.Mutex + // The zero value for Buffer is an empty buffer ready to use. + writeBuffer bytes.Buffer + + sendWindow uint32 + + readyList *ReadyList + headersSent bool + writeHeaders []Header + // true if the write end of this stream has been closed + writeEOF bool + // true if we have sent EOF to the peer + sentEOF bool + // true if the peer sent us an EOF + receivedEOF bool +} + +type flowControlWindow struct { + receiveWindow, sendWindow uint32 +} + +func (s *MuxedStream) Read(p []byte) (n int, err error) { + return s.readBuffer.Read(p) +} + +func (s *MuxedStream) Write(p []byte) (n int, err error) { + s.writeLock.Lock() + defer s.writeLock.Unlock() + if s.writeEOF { + return 0, io.EOF + } + n, err = s.writeBuffer.Write(p) + if n != len(p) || err != nil { + return n, err + } + s.writeNotify() + return n, nil +} + +func (s *MuxedStream) Close() error { + // TUN-115: Close the write buffer before the read buffer. + // In the case of shutdown, read will not get new data, but the write buffer can still receive + // new data. Closing read before write allows application to race between a failed read and a + // successful write, even though this close should appear to be atomic. + // This can't happen the other way because reads may succeed after a failed write; if we read + // past EOF the application will block until we close the buffer. 
+ err := s.CloseWrite() + if err != nil { + if s.CloseRead() == nil { + // don't bother the caller with errors if at least one close succeeded + return nil + } + return err + } + return s.CloseRead() +} + +func (s *MuxedStream) CloseRead() error { + return s.readBuffer.Close() +} + +func (s *MuxedStream) CloseWrite() error { + s.writeLock.Lock() + defer s.writeLock.Unlock() + if s.writeEOF { + return io.EOF + } + s.writeEOF = true + s.writeNotify() + return nil +} + +func (s *MuxedStream) WriteHeaders(headers []Header) error { + s.writeLock.Lock() + defer s.writeLock.Unlock() + if s.writeHeaders != nil { + return ErrStreamHeadersSent + } + s.writeHeaders = headers + s.writeNotify() + return nil +} + +func (s *MuxedStream) FlowControlWindow() *flowControlWindow { + s.writeLock.Lock() + defer s.writeLock.Unlock() + return &flowControlWindow{ + receiveWindow: s.receiveWindow, + sendWindow: s.sendWindow, + } +} + +// writeNotify must happen while holding writeLock. +func (s *MuxedStream) writeNotify() { + s.readyList.Signal(s.streamID) +} + +// Call by muxreader when it gets a WindowUpdateFrame. This is an update of the peer's +// receive window (how much data we can send). +func (s *MuxedStream) replenishSendWindow(bytes uint32) { + s.writeLock.Lock() + s.sendWindow += bytes + s.writeNotify() + s.writeLock.Unlock() +} + +// Call by muxreader when it receives a data frame +func (s *MuxedStream) consumeReceiveWindow(bytes uint32) bool { + s.writeLock.Lock() + defer s.writeLock.Unlock() + // received data size is greater than receive window/buffer + if s.receiveWindow < bytes { + return false + } + s.receiveWindow -= bytes + if s.receiveWindow < s.receiveWindowCurrentMax/2 { + // exhausting client send window (how much data client can send) + if s.receiveWindowCurrentMax < s.receiveWindowMax { + s.receiveWindowCurrentMax <<= 1 + } + s.windowUpdate += s.receiveWindowCurrentMax - s.receiveWindow + s.writeNotify() + } + return true +} + +// receiveEOF should be called when the peer indicates no more data will be sent. +// Returns true if the socket is now closed (i.e. the write side is already closed). +func (s *MuxedStream) receiveEOF() (closed bool) { + s.writeLock.Lock() + defer s.writeLock.Unlock() + s.receivedEOF = true + s.CloseRead() + return s.writeEOF && s.writeBuffer.Len() == 0 +} + +func (s *MuxedStream) gotReceiveEOF() bool { + s.writeLock.Lock() + defer s.writeLock.Unlock() + return s.receivedEOF +} + +// MuxedStreamReader implements io.ReadCloser for the read end of the stream. +// This is useful for passing to functions that close the object after it is done reading, +// but you still want to be able to write data afterwards (e.g. http.Client). +type MuxedStreamReader struct { + *MuxedStream +} + +func (s MuxedStreamReader) Read(p []byte) (n int, err error) { + return s.MuxedStream.Read(p) +} + +func (s MuxedStreamReader) Close() error { + return s.MuxedStream.CloseRead() +} + +// streamChunk represents a chunk of data to be written. +type streamChunk struct { + streamID uint32 + // true if a HEADERS frame should be sent + sendHeaders bool + headers []Header + // nonzero if a WINDOW_UPDATE frame should be sent + windowUpdate uint32 + // true if data frames should be sent + sendData bool + eof bool + buffer bytes.Buffer +} + +// getChunk atomically extracts a chunk of data to be written by MuxWriter. +// The data returned will not exceed the send window for this stream. 
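+// As a side effect it consumes the pending window update, debits the send window by the
+// amount of data copied, and marks the headers (and EOF, once reached) as sent.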
+func (s *MuxedStream) getChunk() *streamChunk { + s.writeLock.Lock() + defer s.writeLock.Unlock() + + chunk := &streamChunk{ + streamID: s.streamID, + sendHeaders: !s.headersSent, + headers: s.writeHeaders, + windowUpdate: s.windowUpdate, + sendData: !s.sentEOF, + eof: s.writeEOF && uint32(s.writeBuffer.Len()) <= s.sendWindow, + } + + // Copies at most s.sendWindow bytes + //log.Infof("writeBuffer len %d stream %d", s.writeBuffer.Len(), s.streamID) + writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow)) + //log.Infof("writeLen %d stream %d", writeLen, s.streamID) + s.sendWindow -= uint32(writeLen) + s.receiveWindow += s.windowUpdate + s.windowUpdate = 0 + s.headersSent = true + + // if this chunk contains the end of the stream, close the stream now + if chunk.sendData && chunk.eof { + s.sentEOF = true + } + + return chunk +} + +func (c *streamChunk) sendHeadersFrame() bool { + return c.sendHeaders +} + +func (c *streamChunk) sendWindowUpdateFrame() bool { + return c.windowUpdate > 0 +} + +func (c *streamChunk) sendDataFrame() bool { + return c.sendData +} + +func (c *streamChunk) nextDataFrame(frameSize int) (payload []byte, endStream bool) { + payload = c.buffer.Next(frameSize) + if c.buffer.Len() == 0 { + // this is the last data frame in this chunk + c.sendData = false + if c.eof { + endStream = true + } + } + return +} diff --git a/h2mux/muxedstream_test.go b/h2mux/muxedstream_test.go new file mode 100644 index 00000000..0987221e --- /dev/null +++ b/h2mux/muxedstream_test.go @@ -0,0 +1,65 @@ +package h2mux + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +const testWindowSize uint32 = 65535 +const testMaxWindowSize uint32 = testWindowSize << 2 + +// Only sending WINDOW_UPDATE frame, so sendWindow should never change +func TestFlowControlSingleStream(t *testing.T) { + stream := &MuxedStream{ + responseHeadersReceived: make(chan struct{}), + readBuffer: NewSharedBuffer(), + receiveWindow: testWindowSize, + receiveWindowCurrentMax: testWindowSize, + receiveWindowMax: testMaxWindowSize, + sendWindow: testWindowSize, + readyList: NewReadyList(), + } + assert.True(t, stream.consumeReceiveWindow(testWindowSize/2)) + dataSent := testWindowSize / 2 + assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow) + assert.Equal(t, testWindowSize, stream.receiveWindowCurrentMax) + assert.Equal(t, uint32(0), stream.windowUpdate) + tempWindowUpdate := stream.windowUpdate + + streamChunk := stream.getChunk() + assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate) + assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow) + assert.Equal(t, uint32(0), stream.windowUpdate) + assert.Equal(t, testWindowSize, stream.sendWindow) + + assert.True(t, stream.consumeReceiveWindow(2)) + dataSent += 2 + assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow) + assert.Equal(t, testWindowSize<<1, stream.receiveWindowCurrentMax) + assert.Equal(t, (testWindowSize<<1)-stream.receiveWindow, stream.windowUpdate) + tempWindowUpdate = stream.windowUpdate + + streamChunk = stream.getChunk() + assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate) + assert.Equal(t, testWindowSize<<1, stream.receiveWindow) + assert.Equal(t, uint32(0), stream.windowUpdate) + assert.Equal(t, testWindowSize, stream.sendWindow) + + assert.True(t, stream.consumeReceiveWindow(testWindowSize+10)) + dataSent = testWindowSize + 10 + assert.Equal(t, (testWindowSize<<1)-dataSent, stream.receiveWindow) + assert.Equal(t, testWindowSize<<2, stream.receiveWindowCurrentMax) + 
assert.Equal(t, (testWindowSize<<2)-stream.receiveWindow, stream.windowUpdate) + tempWindowUpdate = stream.windowUpdate + + streamChunk = stream.getChunk() + assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate) + assert.Equal(t, testWindowSize<<2, stream.receiveWindow) + assert.Equal(t, uint32(0), stream.windowUpdate) + assert.Equal(t, testWindowSize, stream.sendWindow) + + assert.False(t, stream.consumeReceiveWindow(testMaxWindowSize+1)) + assert.Equal(t, testWindowSize<<2, stream.receiveWindow) + assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax) +} diff --git a/h2mux/muxreader.go b/h2mux/muxreader.go new file mode 100644 index 00000000..b916f21a --- /dev/null +++ b/h2mux/muxreader.go @@ -0,0 +1,326 @@ +package h2mux + +import ( + "encoding/binary" + "io" + "sync" + "time" + + log "github.com/Sirupsen/logrus" + "golang.org/x/net/http2" +) + +type MuxReader struct { + // f is used to read HTTP2 frames. + f *http2.Framer + // handler provides a callback to receive new streams. if nil, new streams cannot be accepted. + handler MuxedStreamHandler + // streams tracks currently-open streams. + streams *activeStreamMap + // readyList is used to signal writable streams. + readyList *ReadyList + // streamErrors lets us report stream errors to the MuxWriter. + streamErrors *StreamErrorMap + // goAwayChan is used to tell the writer to send a GOAWAY message. + goAwayChan chan<- http2.ErrCode + // abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop. + abortChan <-chan struct{} + // pingTimestamp is an atomic value containing the latest received ping timestamp. + pingTimestamp *PingTimestamp + // connActive is used to signal to the writer that something happened on the connection. + // This is used to clear idle timeout disconnection deadlines. + connActive Signal + // The initial value for the send and receive window of a new stream. + initialStreamWindow uint32 + // The max value for the send window of a stream. + streamWindowMax uint32 + // windowMetrics keeps track of min/max/average of send/receive windows for all streams + flowControlMetrics *FlowControlMetrics + metricsMutex sync.Mutex + // r is a reference to the underlying connection used when shutting down. + r io.Closer + // rttMeasurement measures RTT based on ping timestamps. 
+ rttMeasurement RTTMeasurement + rttMutex sync.Mutex +} + +func (r *MuxReader) Shutdown() { + done := r.streams.Shutdown() + if done == nil { + return + } + r.sendGoAway(http2.ErrCodeNo) + go func() { + // close reader side when last stream ends; this will cause the writer to abort + <-done + r.r.Close() + }() +} + +func (r *MuxReader) RTT() RTTMeasurement { + r.rttMutex.Lock() + defer r.rttMutex.Unlock() + return r.rttMeasurement +} + +func (r *MuxReader) FlowControlMetrics() *FlowControlMetrics { + r.metricsMutex.Lock() + defer r.metricsMutex.Unlock() + if r.flowControlMetrics != nil { + return r.flowControlMetrics + } + // No metrics available yet + return &FlowControlMetrics{} +} + +func (r *MuxReader) run(parentLogger *log.Entry) error { + logger := parentLogger.WithFields(log.Fields{ + "subsystem": "mux", + "dir": "read", + }) + defer logger.Debug("event loop finished") + for { + frame, err := r.f.ReadFrame() + if err != nil { + switch e := err.(type) { + case http2.StreamError: + logger.WithError(err).Warn("stream error") + r.streamError(e.StreamID, e.Code) + case http2.ConnectionError: + logger.WithError(err).Warn("connection error") + return r.connectionError(err) + default: + if isConnectionClosedError(err) { + if r.streams.Len() == 0 { + logger.Debug("shutting down") + return nil + } + logger.Warn("connection closed unexpectedly") + return err + } else { + logger.WithError(err).Warn("frame read error") + return r.connectionError(err) + } + } + } + r.connActive.Signal() + logger.WithField("data", frame).Debug("read frame") + switch f := frame.(type) { + case *http2.DataFrame: + err = r.receiveFrameData(f, logger) + case *http2.MetaHeadersFrame: + err = r.receiveHeaderData(f) + case *http2.RSTStreamFrame: + streamID := f.Header().StreamID + if streamID == 0 { + return ErrInvalidStream + } + r.streams.Delete(streamID) + case *http2.PingFrame: + r.receivePingData(f) + case *http2.GoAwayFrame: + err = r.receiveGoAway(f) + case *http2.WindowUpdateFrame: + err = r.updateStreamWindow(f) + default: + err = ErrUnexpectedFrameType + } + if err != nil { + logger.WithField("data", frame).WithError(err).Debug("frame error") + return r.connectionError(err) + } + } +} + +func (r *MuxReader) newMuxedStream(streamID uint32) *MuxedStream { + return &MuxedStream{ + streamID: streamID, + readBuffer: NewSharedBuffer(), + receiveWindow: r.initialStreamWindow, + receiveWindowCurrentMax: r.initialStreamWindow, + receiveWindowMax: r.streamWindowMax, + sendWindow: r.initialStreamWindow, + readyList: r.readyList, + } +} + +// getStreamForFrame returns a stream if valid, or an error describing why the stream could not be returned. +func (r *MuxReader) getStreamForFrame(frame http2.Frame) (*MuxedStream, error) { + sid := frame.Header().StreamID + if sid == 0 { + return nil, ErrUnexpectedFrameType + } + if stream, ok := r.streams.Get(sid); ok { + return stream, nil + } + if r.streams.IsLocalStreamID(sid) { + // no stream available, but no error + return nil, ErrClosedStream + } + if sid < r.streams.LastPeerStreamID() { + // no stream available, stream closed error + return nil, ErrClosedStream + } + return nil, ErrUnknownStream +} + +func (r *MuxReader) defaultStreamErrorHandler(err error, header http2.FrameHeader) error { + if header.Flags.Has(http2.FlagHeadersEndStream) { + return nil + } else if err == ErrUnknownStream || err == ErrClosedStream { + return r.streamError(header.StreamID, http2.ErrCodeStreamClosed) + } else { + return err + } +} + +// Receives header frames from a stream. 
A non-nil error is a connection error.
+func (r *MuxReader) receiveHeaderData(frame *http2.MetaHeadersFrame) error {
+    var stream *MuxedStream
+    sid := frame.Header().StreamID
+    if sid == 0 {
+        return ErrUnexpectedFrameType
+    }
+    newStream := r.streams.IsPeerStreamID(sid)
+    if newStream {
+        // header request
+        // TODO support trailers (if stream exists)
+        ok, err := r.streams.AcquirePeerID(sid)
+        if !ok {
+            // ignore new streams while shutting down
+            return r.streamError(sid, err)
+        }
+        stream = r.newMuxedStream(sid)
+        // Set the stream; this fails if a stream already exists with that ID or we are shutting down.
+        if !r.streams.Set(stream) {
+            // got HEADERS frame for an existing stream
+            // TODO support trailers
+            return r.streamError(sid, http2.ErrCodeInternal)
+        }
+    } else {
+        // header response
+        var err error
+        if stream, err = r.getStreamForFrame(frame); err != nil {
+            return r.defaultStreamErrorHandler(err, frame.Header())
+        }
+    }
+    headers := make([]Header, len(frame.Fields))
+    for i, header := range frame.Fields {
+        headers[i].Name = header.Name
+        headers[i].Value = header.Value
+    }
+    stream.Headers = headers
+    if frame.Header().Flags.Has(http2.FlagHeadersEndStream) {
+        stream.receiveEOF()
+        return nil
+    }
+    if newStream {
+        go r.handleStream(stream)
+    } else {
+        close(stream.responseHeadersReceived)
+    }
+    return nil
+}
+
+func (r *MuxReader) handleStream(stream *MuxedStream) {
+    defer stream.Close()
+    r.handler.ServeStream(stream)
+}
+
+// Receives a data frame from a stream. A non-nil error is a connection error.
+func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.Entry) error {
+    logger := parentLogger.WithField("stream", frame.Header().StreamID)
+    stream, err := r.getStreamForFrame(frame)
+    if err != nil {
+        return r.defaultStreamErrorHandler(err, frame.Header())
+    }
+    data := frame.Data()
+    if len(data) > 0 {
+        _, err = stream.readBuffer.Write(data)
+        if err != nil {
+            return r.streamError(stream.streamID, http2.ErrCodeInternal)
+        }
+    }
+    if frame.Header().Flags.Has(http2.FlagDataEndStream) {
+        if stream.receiveEOF() {
+            r.streams.Delete(stream.streamID)
+            logger.Debug("stream closed")
+        } else {
+            logger.Debug("shutdown receive side")
+        }
+        return nil
+    }
+    if !stream.consumeReceiveWindow(uint32(len(data))) {
+        return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
+    }
+    return nil
+}
+
+// Receive a PING from the peer. Update RTT and send/receive window metrics if it's an ACK.
+func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
+    ts := int64(binary.LittleEndian.Uint64(frame.Data[:]))
+    if !frame.IsAck() {
+        r.pingTimestamp.Set(ts)
+        return
+    }
+    r.rttMutex.Lock()
+    r.rttMeasurement.Update(time.Unix(0, ts))
+    r.rttMutex.Unlock()
+    // Serialise access to flowControlMetrics with FlowControlMetrics().
+    metrics := r.streams.Metrics()
+    r.metricsMutex.Lock()
+    r.flowControlMetrics = metrics
+    r.metricsMutex.Unlock()
+}
+
+// Receive a GOAWAY from the peer. Gracefully shut down our connection.
+func (r *MuxReader) receiveGoAway(frame *http2.GoAwayFrame) error {
+    r.Shutdown()
+    // Close all streams above the last processed stream
+    lastStream := r.streams.LastLocalStreamID()
+    for i := frame.LastStreamID + 2; i <= lastStream; i++ {
+        if stream, ok := r.streams.Get(i); ok {
+            stream.Close()
+        }
+    }
+    return nil
+}
+
+// Receives a WINDOW_UPDATE frame for a stream. A non-nil error is a connection error.
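+// Updates for unknown or already-closed streams are ignored rather than treated as errors,
+// since a WINDOW_UPDATE may arrive after the stream has been torn down locally.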
+func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error { + stream, err := r.getStreamForFrame(frame) + if err != nil && err != ErrUnknownStream && err != ErrClosedStream { + return err + } + if stream == nil { + // ignore window updates on closed streams + return nil + } + stream.replenishSendWindow(frame.Increment) + return nil +} + +// Raise a stream processing error, closing the stream. Runs on the write thread. +func (r *MuxReader) streamError(streamID uint32, e http2.ErrCode) error { + r.streamErrors.RaiseError(streamID, e) + return nil +} + +func (r *MuxReader) connectionError(err error) error { + http2Code := http2.ErrCodeInternal + switch e := err.(type) { + case http2.ConnectionError: + http2Code = http2.ErrCode(e) + case MuxerProtocolError: + http2Code = e.h2code + } + log.Warnf("Connection error %v", http2Code) + r.sendGoAway(http2Code) + return err +} + +// Instruct the writer to send a GOAWAY message if possible. This may fail in +// the case where an existing GOAWAY message is in flight or the writer event +// loop already ended. +func (r *MuxReader) sendGoAway(errCode http2.ErrCode) { + select { + case r.goAwayChan <- errCode: + default: + } +} diff --git a/h2mux/muxwriter.go b/h2mux/muxwriter.go new file mode 100644 index 00000000..0da90832 --- /dev/null +++ b/h2mux/muxwriter.go @@ -0,0 +1,238 @@ +package h2mux + +import ( + "bytes" + "encoding/binary" + "io" + "time" + + log "github.com/Sirupsen/logrus" + "golang.org/x/net/http2" + "golang.org/x/net/http2/hpack" +) + +type MuxWriter struct { + // f is used to write HTTP2 frames. + f *http2.Framer + // streams tracks currently-open streams. + streams *activeStreamMap + // streamErrors receives stream errors raised by the MuxReader. + streamErrors *StreamErrorMap + // readyStreamChan is used to multiplex writable streams onto the single connection. + // When a stream becomes writable its ID is sent on this channel. + readyStreamChan <-chan uint32 + // newStreamChan is used to create new streams with a given set of headers. + newStreamChan <-chan MuxedStreamRequest + // goAwayChan is used to send a single GOAWAY message to the peer. The element received + // is the HTTP/2 error code to send. + goAwayChan <-chan http2.ErrCode + // abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop. + abortChan <-chan struct{} + // pingTimestamp is an atomic value containing the latest received ping timestamp. + pingTimestamp *PingTimestamp + // A timer used to measure idle connection time. Reset after sending data. + idleTimer *IdleTimer + // connActiveChan receives a signal that the connection received some (read) activity. + connActiveChan <-chan struct{} + // Maximum size of all frames that can be sent on this connection. + maxFrameSize uint32 + // headerEncoder is the stateful header encoder for this connection + headerEncoder *hpack.Encoder + // headerBuffer is the temporary buffer used by headerEncoder. 
+ headerBuffer bytes.Buffer +} + +type MuxedStreamRequest struct { + stream *MuxedStream + body io.Reader +} + +func (r *MuxedStreamRequest) flushBody() { + io.Copy(r.stream, r.body) + r.stream.CloseWrite() +} + +func tsToPingData(ts int64) [8]byte { + pingData := [8]byte{} + binary.LittleEndian.PutUint64(pingData[:], uint64(ts)) + return pingData +} + +func (w *MuxWriter) run(parentLogger *log.Entry) error { + logger := parentLogger.WithFields(log.Fields{ + "subsystem": "mux", + "dir": "write", + }) + defer logger.Debug("event loop finished") + for { + select { + case <-w.abortChan: + logger.Debug("aborting writer thread") + return nil + case errCode := <-w.goAwayChan: + logger.Debug("sending GOAWAY code ", errCode) + err := w.f.WriteGoAway(w.streams.LastPeerStreamID(), errCode, []byte{}) + if err != nil { + return err + } + w.idleTimer.MarkActive() + case <-w.pingTimestamp.GetUpdateChan(): + logger.Debug("sending PING ACK") + err := w.f.WritePing(true, tsToPingData(w.pingTimestamp.Get())) + if err != nil { + return err + } + w.idleTimer.MarkActive() + case <-w.idleTimer.C: + if !w.idleTimer.Retry() { + return ErrConnectionDropped + } + logger.Debug("sending PING") + err := w.f.WritePing(false, tsToPingData(time.Now().UnixNano())) + if err != nil { + return err + } + w.idleTimer.ResetTimer() + case <-w.connActiveChan: + w.idleTimer.MarkActive() + case <-w.streamErrors.GetSignalChan(): + for streamID, errCode := range w.streamErrors.GetErrors() { + logger.WithField("stream", streamID).WithField("code", errCode).Debug("resetting stream") + err := w.f.WriteRSTStream(streamID, errCode) + if err != nil { + return err + } + } + w.idleTimer.MarkActive() + case streamRequest := <-w.newStreamChan: + streamID := w.streams.AcquireLocalID() + streamRequest.stream.streamID = streamID + if !w.streams.Set(streamRequest.stream) { + // Race between OpenStream and Shutdown, and Shutdown won. Let Shutdown (and the eventual abort) take + // care of this stream. Ideally we'd pass the error directly to the stream object somehow so the + // caller can be unblocked sooner, but the value of that optimisation is minimal for most of the + // reasons why you'd call Shutdown anyway. + continue + } + if streamRequest.body != nil { + go streamRequest.flushBody() + } + streamLogger := logger.WithField("stream", streamID) + err := w.writeStreamData(streamRequest.stream, streamLogger) + if err != nil { + return err + } + w.idleTimer.MarkActive() + case streamID := <-w.readyStreamChan: + streamLogger := logger.WithField("stream", streamID) + stream, ok := w.streams.Get(streamID) + if !ok { + continue + } + err := w.writeStreamData(stream, streamLogger) + if err != nil { + return err + } + w.idleTimer.MarkActive() + } + } +} + +func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error { + logger.Debug("writable") + chunk := stream.getChunk() + + if chunk.sendHeadersFrame() { + err := w.writeHeaders(chunk.streamID, chunk.headers) + if err != nil { + logger.WithError(err).Warn("error writing headers") + return err + } + logger.Debug("output headers") + } + + if chunk.sendWindowUpdateFrame() { + // Send a WINDOW_UPDATE frame to update our receive window. + // If the Stream ID is zero, the window update applies to the connection as a whole + // A WINDOW_UPDATE in a specific stream applies to the connection-level flow control as well. 
+ err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate) + if err != nil { + logger.WithError(err).Warn("error writing window update") + return err + } + logger.Debugf("increment receive window by %d", chunk.windowUpdate) + } + + for chunk.sendDataFrame() { + payload, sentEOF := chunk.nextDataFrame(int(w.maxFrameSize)) + err := w.f.WriteData(chunk.streamID, sentEOF, payload) + if err != nil { + logger.WithError(err).Warn("error writing data") + return err + } + logger.WithField("len", len(payload)).Debug("output data") + + if sentEOF { + if stream.readBuffer.Closed() { + // transition into closed state + if !stream.gotReceiveEOF() { + // the peer may send data that we no longer want to receive. Force them into the + // closed state. + logger.Debug("resetting stream") + w.f.WriteRSTStream(chunk.streamID, http2.ErrCodeNo) + } else { + // Half-open stream transitioned into closed + logger.Debug("closing stream") + } + w.streams.Delete(chunk.streamID) + } else { + logger.Debug("closing stream write side") + } + } + } + return nil +} + +func (w *MuxWriter) encodeHeaders(headers []Header) ([]byte, error) { + w.headerBuffer.Reset() + for _, header := range headers { + err := w.headerEncoder.WriteField(hpack.HeaderField{ + Name: header.Name, + Value: header.Value, + }) + if err != nil { + return nil, err + } + } + return w.headerBuffer.Bytes(), nil +} + +// writeHeaders writes a block of encoded headers, splitting it into multiple frames if necessary. +func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error { + encodedHeaders, err := w.encodeHeaders(headers) + if err != nil { + return err + } + blockSize := int(w.maxFrameSize) + continuation := false + endHeaders := len(encodedHeaders) == 0 + for !endHeaders && err == nil { + blockFragment := encodedHeaders + if len(encodedHeaders) > blockSize { + blockFragment = blockFragment[:blockSize] + encodedHeaders = encodedHeaders[blockSize:] + } else { + endHeaders = true + } + if continuation { + err = w.f.WriteContinuation(streamID, endHeaders, blockFragment) + } else { + err = w.f.WriteHeaders(http2.HeadersFrameParam{ + StreamID: streamID, + EndHeaders: endHeaders, + BlockFragment: blockFragment, + }) + } + } + return err +} diff --git a/h2mux/readylist.go b/h2mux/readylist.go new file mode 100644 index 00000000..5215e464 --- /dev/null +++ b/h2mux/readylist.go @@ -0,0 +1,140 @@ +package h2mux + +// ReadyList multiplexes several event signals onto a single channel. 
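+//
+// A minimal usage sketch (the stream ID is illustrative):
+//
+//    rl := NewReadyList()
+//    rl.Signal(5)              // mark stream 5 as ready to write
+//    id := <-rl.ReadyChannel() // receives 5
+//    rl.Close()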
+type ReadyList struct { + signalC chan uint32 + waitC chan uint32 +} + +func NewReadyList() *ReadyList { + rl := &ReadyList{ + signalC: make(chan uint32), + waitC: make(chan uint32), + } + go rl.run() + return rl +} + +// ID is the stream ID +func (r *ReadyList) Signal(ID uint32) { + r.signalC <- ID +} + +func (r *ReadyList) ReadyChannel() <-chan uint32 { + return r.waitC +} + +func (r *ReadyList) Close() { + close(r.signalC) +} + +func (r *ReadyList) run() { + defer close(r.waitC) + var queue readyDescriptorQueue + var firstReady *readyDescriptor + activeDescriptors := newReadyDescriptorMap() + for { + if firstReady == nil { + // Wait for first ready descriptor + i, ok := <-r.signalC + if !ok { + // closed + return + } + firstReady = activeDescriptors.SetIfMissing(i) + } + select { + case r.waitC <- firstReady.ID: + activeDescriptors.Delete(firstReady.ID) + firstReady = queue.Dequeue() + case i, ok := <-r.signalC: + if !ok { + // closed + return + } + newReady := activeDescriptors.SetIfMissing(i) + if newReady != nil { + // key doesn't exist + queue.Enqueue(newReady) + } + } + } +} + +type readyDescriptor struct { + ID uint32 + Next *readyDescriptor +} + +// readyDescriptorQueue is a queue of readyDescriptors in the form of a singly-linked list. +// The nil readyDescriptorQueue is an empty queue ready for use. +type readyDescriptorQueue struct { + Head *readyDescriptor + Tail *readyDescriptor +} + +func (q *readyDescriptorQueue) Empty() bool { + return q.Head == nil +} + +func (q *readyDescriptorQueue) Enqueue(x *readyDescriptor) { + if x.Next != nil { + panic("enqueued already queued item") + } + if q.Empty() { + q.Head = x + q.Tail = x + } else { + q.Tail.Next = x + q.Tail = x + } +} + +// Dequeue returns the first readyDescriptor in the queue, or nil if empty. +func (q *readyDescriptorQueue) Dequeue() *readyDescriptor { + if q.Empty() { + return nil + } + x := q.Head + q.Head = x.Next + x.Next = nil + return x +} + +// readyDescriptorQueue is a map of readyDescriptors keyed by ID. +// It maintains a free list of deleted ready descriptors. +type readyDescriptorMap struct { + descriptors map[uint32]*readyDescriptor + free []*readyDescriptor +} + +func newReadyDescriptorMap() *readyDescriptorMap { + return &readyDescriptorMap{descriptors: make(map[uint32]*readyDescriptor)} +} + +// create or reuse a readyDescriptor if the stream is not in the queue. 
+// This avoid stream starvation caused by a single high-bandwidth stream monopolising the writer goroutine +func (m *readyDescriptorMap) SetIfMissing(key uint32) *readyDescriptor { + if _, ok := m.descriptors[key]; ok { + return nil + } + + var newDescriptor *readyDescriptor + if len(m.free) > 0 { + // reuse deleted ready descriptors + newDescriptor = m.free[len(m.free)-1] + m.free = m.free[:len(m.free)-1] + } else { + newDescriptor = &readyDescriptor{} + } + newDescriptor.ID = key + m.descriptors[key] = newDescriptor + return newDescriptor +} + +func (m *readyDescriptorMap) Delete(key uint32) { + if descriptor, ok := m.descriptors[key]; ok { + m.free = append(m.free, descriptor) + delete(m.descriptors, key) + } +} diff --git a/h2mux/readylist_test.go b/h2mux/readylist_test.go new file mode 100644 index 00000000..1bf5f0bf --- /dev/null +++ b/h2mux/readylist_test.go @@ -0,0 +1,115 @@ +package h2mux + +import ( + "testing" + "time" +) + +func TestReadyList(t *testing.T) { + rl := NewReadyList() + c := rl.ReadyChannel() + // helper functions + assertEmpty := func() { + select { + case <-c: + t.Fatalf("Spurious wakeup") + default: + } + } + receiveWithTimeout := func() uint32 { + select { + case i := <-c: + return i + case <-time.After(100 * time.Millisecond): + t.Fatalf("Timeout") + return 0 + } + } + // no signals, receive should fail + assertEmpty() + rl.Signal(0) + if receiveWithTimeout() != 0 { + t.Fatalf("Received wrong ID of signalled event") + } + // no new signals, receive should fail + assertEmpty() + // Signals should not block; + // Duplicate unhandled signals should not cause multiple wakeups + signalled := [5]bool{} + for i := range signalled { + rl.Signal(uint32(i)) + rl.Signal(uint32(i)) + } + // All signals should be received once (in any order) + for range signalled { + i := receiveWithTimeout() + if signalled[i] { + t.Fatalf("Received signal %d more than once", i) + } + signalled[i] = true + } +} + +func TestReadyDescriptorQueue(t *testing.T) { + var queue readyDescriptorQueue + items := [4]readyDescriptor{} + for i := range items { + items[i].ID = uint32(i) + } + + if !queue.Empty() { + t.Fatalf("nil queue should be empty") + } + queue.Enqueue(&items[3]) + queue.Enqueue(&items[1]) + queue.Enqueue(&items[0]) + queue.Enqueue(&items[2]) + if queue.Empty() { + t.Fatalf("Empty should be false after enqueue") + } + i := queue.Dequeue().ID + if i != 3 { + t.Fatalf("item 3 should have been dequeued, got %d instead", i) + } + i = queue.Dequeue().ID + if i != 1 { + t.Fatalf("item 1 should have been dequeued, got %d instead", i) + } + i = queue.Dequeue().ID + if i != 0 { + t.Fatalf("item 0 should have been dequeued, got %d instead", i) + } + i = queue.Dequeue().ID + if i != 2 { + t.Fatalf("item 2 should have been dequeued, got %d instead", i) + } + if !queue.Empty() { + t.Fatal("queue should be empty after dequeuing all items") + } + if queue.Dequeue() != nil { + t.Fatal("dequeue on empty queue should return nil") + } +} + +func TestReadyDescriptorMap(t *testing.T) { + m := newReadyDescriptorMap() + m.Delete(42) + // (delete of missing key should be a noop) + x := m.SetIfMissing(42) + if x == nil { + t.Fatal("SetIfMissing for new key returned nil") + } + if m.SetIfMissing(42) != nil { + t.Fatal("SetIfMissing for existing key returned non-nil") + } + // this delete has effect + m.Delete(42) + // the next set should reuse the old object + y := m.SetIfMissing(666) + if y == nil { + t.Fatal("SetIfMissing for new key returned nil") + } + if x != y { + t.Fatal("SetIfMissing didn't reuse 
freed object") + } +} diff --git a/h2mux/rtt.go b/h2mux/rtt.go new file mode 100644 index 00000000..1c42ff82 --- /dev/null +++ b/h2mux/rtt.go @@ -0,0 +1,53 @@ +package h2mux + +import ( + "sync/atomic" + "time" +) + +// PingTimestamp is an atomic interface around ping timestamping and signalling. +type PingTimestamp struct { + ts int64 + signal Signal +} + +func NewPingTimestamp() *PingTimestamp { + return &PingTimestamp{signal: NewSignal()} +} + +func (pt *PingTimestamp) Set(v int64) { + if atomic.SwapInt64(&pt.ts, v) != 0 { + pt.signal.Signal() + } +} + +func (pt *PingTimestamp) Get() int64 { + return atomic.SwapInt64(&pt.ts, 0) +} + +func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} { + return pt.signal.WaitChannel() +} + +// RTTMeasurement encapsulates a continuous round trip time measurement. +type RTTMeasurement struct { + Current, Min, Max time.Duration + lastMeasurementTime time.Time +} + +// Update updates the computed values with a new measurement. +// outgoingTime is the time that the probe was sent. +// We assume that time.Now() is the time we received that probe. +func (r *RTTMeasurement) Update(outgoingTime time.Time) { + if !r.lastMeasurementTime.Before(outgoingTime) { + return + } + r.lastMeasurementTime = outgoingTime + r.Current = time.Since(outgoingTime) + if r.Max < r.Current { + r.Max = r.Current + } + if r.Min > r.Current { + r.Min = r.Current + } +} diff --git a/h2mux/shared_buffer.go b/h2mux/shared_buffer.go new file mode 100644 index 00000000..4c1b713e --- /dev/null +++ b/h2mux/shared_buffer.go @@ -0,0 +1,64 @@ +package h2mux + +import ( + "bytes" + "io" + "sync" +) + +type SharedBuffer struct { + cond *sync.Cond + buffer bytes.Buffer + eof bool +} + +func NewSharedBuffer() *SharedBuffer { + return &SharedBuffer{ + cond: sync.NewCond(&sync.Mutex{}), + } +} + +func (s *SharedBuffer) Read(p []byte) (n int, err error) { + totalRead := 0 + s.cond.L.Lock() + for totalRead < len(p) { + n, err = s.buffer.Read(p[totalRead:]) + totalRead += n + if err == io.EOF { + if s.eof { + break + } + err = nil + s.cond.Wait() + } + } + s.cond.L.Unlock() + return totalRead, err +} + +func (s *SharedBuffer) Write(p []byte) (n int, err error) { + s.cond.L.Lock() + defer s.cond.L.Unlock() + if s.eof { + return 0, io.EOF + } + n, err = s.buffer.Write(p) + s.cond.Signal() + return +} + +func (s *SharedBuffer) Close() error { + s.cond.L.Lock() + defer s.cond.L.Unlock() + if !s.eof { + s.eof = true + s.cond.Signal() + } + return nil +} + +func (s *SharedBuffer) Closed() bool { + s.cond.L.Lock() + defer s.cond.L.Unlock() + return s.eof +} diff --git a/h2mux/shared_buffer_test.go b/h2mux/shared_buffer_test.go new file mode 100644 index 00000000..939228e1 --- /dev/null +++ b/h2mux/shared_buffer_test.go @@ -0,0 +1,120 @@ +package h2mux + +import ( + "bytes" + "io" + "sync" + "testing" + "time" +) + +func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) { + return func(actual int, err error) { + if expected != actual { + t.Fatalf("Expected %d bytes, got %d", expected, actual) + } + if err != nil { + t.Fatalf("Unexpected error %s", err) + } + } +} + +func TestSharedBuffer(t *testing.T) { + b := NewSharedBuffer() + testData := []byte("Hello world") + AssertIOReturnIsGood(t, len(testData))(b.Write(testData)) + bytesRead := make([]byte, len(testData)) + AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead)) +} + +func TestSharedBufferBlockingRead(t *testing.T) { + b := NewSharedBuffer() + testData := []byte("Hello world") + result := make(chan []byte) + go func() { + 
bytesRead := make([]byte, len(testData)) + AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead)) + result <- bytesRead + }() + select { + case <-result: + t.Fatalf("read returned early") + default: + } + AssertIOReturnIsGood(t, 5)(b.Write(testData[:5])) + select { + case <-result: + t.Fatalf("read returned early") + default: + } + AssertIOReturnIsGood(t, len(testData)-5)(b.Write(testData[5:])) + select { + case r := <-result: + if string(r) != string(testData) { + t.Fatalf("expected read to return %s, got %s", testData, r) + } + case <-time.After(time.Second): + t.Fatalf("read timed out") + } +} + +// This is quite slow under the race detector +func TestSharedBufferConcurrentReadWrite(t *testing.T) { + b := NewSharedBuffer() + var expectedResult, actualResult bytes.Buffer + var wg sync.WaitGroup + wg.Add(2) + go func() { + block := make([]byte, 256) + for i := range block { + block[i] = byte(i) + } + for blockSize := 1; blockSize <= 256; blockSize++ { + for i := 0; i < 256; i++ { + expectedResult.Write(block[:blockSize]) + n, err := b.Write(block[:blockSize]) + if n != blockSize || err != nil { + t.Fatalf("write error: %d %s", n, err) + } + } + } + wg.Done() + }() + go func() { + block := make([]byte, 256) + // Change block sizes in opposition to the write thread, to test blocking for new data. + for blockSize := 256; blockSize > 0; blockSize-- { + for i := 0; i < 256; i++ { + n, err := b.Read(block[:blockSize]) + if n != blockSize || err != nil { + t.Fatalf("read error: %d %s", n, err) + } + actualResult.Write(block[:blockSize]) + } + } + wg.Done() + }() + wg.Wait() + if bytes.Compare(expectedResult.Bytes(), actualResult.Bytes()) != 0 { + t.Fatal("Result diverged") + } +} + +func TestSharedBufferClose(t *testing.T) { + b := NewSharedBuffer() + testData := []byte("Hello world") + AssertIOReturnIsGood(t, len(testData))(b.Write(testData)) + err := b.Close() + if err != nil { + t.Fatalf("unexpected error from Close: %s", err) + } + bytesRead := make([]byte, len(testData)) + AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead)) + n, err := b.Read(bytesRead) + if n != 0 { + t.Fatalf("extra bytes received: %d", n) + } + if err != io.EOF { + t.Fatalf("expected EOF, got %s", err) + } +} diff --git a/h2mux/signal.go b/h2mux/signal.go new file mode 100644 index 00000000..d716aed2 --- /dev/null +++ b/h2mux/signal.go @@ -0,0 +1,34 @@ +package h2mux + +// Signal describes an event that can be waited on for at least one signal. +// Signalling the event while it is in the signalled state is a noop. +// When the waiter wakes up, the signal is set to unsignalled. +// It is a way for any number of writers to inform a reader (without blocking) +// that an event has happened. +type Signal struct { + c chan struct{} +} + +// NewSignal creates a new Signal. +func NewSignal() Signal { + return Signal{c: make(chan struct{}, 1)} +} + +// Signal signals the event. +func (s Signal) Signal() { + // This channel is buffered, so the nonblocking send will always succeed if the buffer is empty. + select { + case s.c <- struct{}{}: + default: + } +} + +// Wait for the event to be signalled. +func (s Signal) Wait() { + <-s.c +} + +// WaitChannel returns a channel that is readable after Signal is called. 
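+// A typical pattern (a sketch using only the methods on Signal plus the
+// standard time package): the reader selects on WaitChannel alongside other
+// event sources, while any number of writers call Signal without blocking;
+// signals raised before the reader wakes up collapse into one receive.
+//
+//    s := NewSignal()
+//    go s.Signal() // writers never block
+//    select {
+//    case <-s.WaitChannel():
+//        // event observed; the signal is cleared until the next Signal call
+//    case <-time.After(time.Second):
+//        // timed out waiting for the event
+//    }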
+func (s Signal) WaitChannel() <-chan struct{} { + return s.c +} diff --git a/h2mux/streamerrormap.go b/h2mux/streamerrormap.go new file mode 100644 index 00000000..926b5ff2 --- /dev/null +++ b/h2mux/streamerrormap.go @@ -0,0 +1,47 @@ +package h2mux + +import ( + "sync" + + "golang.org/x/net/http2" +) + +// StreamErrorMap is used to track stream errors. This is a separate structure to ActiveStreamMap because +// errors can be raised against non-existent or closed streams. +type StreamErrorMap struct { + sync.RWMutex + // errors tracks per-stream errors + errors map[uint32]http2.ErrCode + // hasError is signaled whenever an error is raised. + hasError Signal +} + +// NewStreamErrorMap creates a new StreamErrorMap. +func NewStreamErrorMap() *StreamErrorMap { + return &StreamErrorMap{ + errors: make(map[uint32]http2.ErrCode), + hasError: NewSignal(), + } +} + +// RaiseError raises a stream error. +func (s *StreamErrorMap) RaiseError(streamID uint32, err http2.ErrCode) { + s.Lock() + s.errors[streamID] = err + s.Unlock() + s.hasError.Signal() +} + +// GetSignalChan returns a channel that is signalled when an error is raised. +func (s *StreamErrorMap) GetSignalChan() <-chan struct{} { + return s.hasError.WaitChannel() +} + +// GetErrors retrieves all errors currently raised. This resets the currently-tracked errors. +func (s *StreamErrorMap) GetErrors() map[uint32]http2.ErrCode { + s.Lock() + errors := s.errors + s.errors = make(map[uint32]http2.ErrCode) + s.Unlock() + return errors +} diff --git a/metrics/metrics.go b/metrics/metrics.go new file mode 100644 index 00000000..63036bd8 --- /dev/null +++ b/metrics/metrics.go @@ -0,0 +1,48 @@ +package metrics + +import ( + "net" + "net/http" + _ "net/http/pprof" + "runtime" + "time" + + "golang.org/x/net/context" + + log "github.com/Sirupsen/logrus" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +func ServeMetrics(l net.Listener, shutdownC <-chan struct{}) error { + server := &http.Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + } + go func() { + <-shutdownC + server.Shutdown(context.Background()) + }() + http.Handle("/metrics", promhttp.Handler()) + log.WithField("addr", l.Addr()).Info("Starting metrics server") + err := server.Serve(l) + if err == http.ErrServerClosed { + log.Info("Metrics server stopped") + return nil + } + log.WithError(err).Error("Metrics server quit with error") + return err +} + +func RegisterBuildInfo(buildTime string, version string) { + buildInfo := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + // Don't namespace build_info, since we want it to be consistent across all Cloudflare services + Name: "build_info", + Help: "Build and version information", + }, + []string{"goversion", "revision", "version"}, + ) + prometheus.MustRegister(buildInfo) + buildInfo.WithLabelValues(runtime.Version(), buildTime, version).Set(1) +} diff --git a/origin/backoffhandler.go b/origin/backoffhandler.go new file mode 100644 index 00000000..67c7235e --- /dev/null +++ b/origin/backoffhandler.go @@ -0,0 +1,70 @@ +package origin + +import ( + "time" + + "golang.org/x/net/context" +) + +// Redeclare time functions so they can be overridden in tests. +var ( + timeNow = time.Now + timeAfter = time.After +) + +// BackoffHandler manages exponential backoff and limits the maximum number of retries. +// The base time period is 1 second, doubling with each retry. 
+// After initial success, a grace period can be set to reset the backoff timer if +// a connection is maintained successfully for a long enough period. The base grace period +// is 2 seconds, doubling with each retry. +type BackoffHandler struct { + // MaxRetries sets the maximum number of retries to perform. The default value + // of 0 disables retry completely. + MaxRetries uint + + retries uint + resetDeadline time.Time +} + +func (b BackoffHandler) GetBackoffDuration(ctx context.Context) (time.Duration, bool) { + // Follows the same logic as Backoff, but without mutating the receiver. + // This select has to happen first to reflect the actual behaviour of the Backoff function. + select { + case <-ctx.Done(): + return time.Duration(0), false + default: + } + if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) { + // b.retries would be set to 0 at this point + return time.Second, true + } + if b.retries >= b.MaxRetries { + return time.Duration(0), false + } + return time.Duration(time.Second * 1 << b.retries), true +} + +// Backoff is used to wait according to exponential backoff. Returns false if the +// maximum number of retries have been used or if the underlying context has been cancelled. +func (b *BackoffHandler) Backoff(ctx context.Context) bool { + if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) { + b.retries = 0 + b.resetDeadline = time.Time{} + } + if b.retries >= b.MaxRetries { + return false + } + select { + case <-timeAfter(time.Duration(time.Second * 1 << b.retries)): + b.retries++ + return true + case <-ctx.Done(): + return false + } +} + +// Sets a grace period within which the the backoff timer is maintained. After the grace +// period expires, the number of retries & backoff duration is reset. +func (b *BackoffHandler) SetGracePeriod() { + b.resetDeadline = timeNow().Add(time.Duration(time.Second * 2 << b.retries)) +} diff --git a/origin/backoffhandler_test.go b/origin/backoffhandler_test.go new file mode 100644 index 00000000..4e598561 --- /dev/null +++ b/origin/backoffhandler_test.go @@ -0,0 +1,114 @@ +package origin + +import ( + "testing" + "time" + + "golang.org/x/net/context" +) + +func immediateTimeAfter(time.Duration) <-chan time.Time { + c := make(chan time.Time, 1) + c <- time.Now() + return c +} + +func TestBackoffRetries(t *testing.T) { + // make backoff return immediately + timeAfter = immediateTimeAfter + ctx := context.Background() + backoff := BackoffHandler{MaxRetries: 3} + if !backoff.Backoff(ctx) { + t.Fatalf("backoff failed immediately") + } + if !backoff.Backoff(ctx) { + t.Fatalf("backoff failed after 1 retry") + } + if !backoff.Backoff(ctx) { + t.Fatalf("backoff failed after 2 retry") + } + if backoff.Backoff(ctx) { + t.Fatalf("backoff allowed after 3 (max) retries") + } +} + +func TestBackoffCancel(t *testing.T) { + // prevent backoff from returning normally + timeAfter = func(time.Duration) <-chan time.Time { return make(chan time.Time) } + ctx, cancelFunc := context.WithCancel(context.Background()) + backoff := BackoffHandler{MaxRetries: 3} + cancelFunc() + if backoff.Backoff(ctx) { + t.Fatalf("backoff allowed after cancel") + } + if _, ok := backoff.GetBackoffDuration(ctx); ok { + t.Fatalf("backoff allowed after cancel") + } +} + +func TestBackoffGracePeriod(t *testing.T) { + currentTime := time.Now() + // make timeNow return whatever we like + timeNow = func() time.Time { return currentTime } + // make backoff return immediately + timeAfter = immediateTimeAfter + ctx := context.Background() + backoff := 
BackoffHandler{MaxRetries: 1} + if !backoff.Backoff(ctx) { + t.Fatalf("backoff failed immediately") + } + // the next call to Backoff would fail unless it's after the grace period + backoff.SetGracePeriod() + // advance time to after the grace period (~4 seconds) and see what happens + currentTime = currentTime.Add(time.Second * 5) + if !backoff.Backoff(ctx) { + t.Fatalf("backoff failed after the grace period expired") + } + // confirm we ignore grace period after backoff + if backoff.Backoff(ctx) { + t.Fatalf("backoff allowed after 1 (max) retry") + } +} + +func TestGetBackoffDurationRetries(t *testing.T) { + // make backoff return immediately + timeAfter = immediateTimeAfter + ctx := context.Background() + backoff := BackoffHandler{MaxRetries: 3} + if _, ok := backoff.GetBackoffDuration(ctx); !ok { + t.Fatalf("backoff failed immediately") + } + backoff.Backoff(ctx) // noop + if _, ok := backoff.GetBackoffDuration(ctx); !ok { + t.Fatalf("backoff failed after 1 retry") + } + backoff.Backoff(ctx) // noop + if _, ok := backoff.GetBackoffDuration(ctx); !ok { + t.Fatalf("backoff failed after 2 retry") + } + backoff.Backoff(ctx) // noop + if _, ok := backoff.GetBackoffDuration(ctx); ok { + t.Fatalf("backoff allowed after 3 (max) retries") + } + if backoff.Backoff(ctx) { + t.Fatalf("backoff allowed after 3 (max) retries") + } +} + +func TestGetBackoffDuration(t *testing.T) { + // make backoff return immediately + timeAfter = immediateTimeAfter + ctx := context.Background() + backoff := BackoffHandler{MaxRetries: 3} + if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second { + t.Fatalf("backoff didn't return 1 second on first retry") + } + backoff.Backoff(ctx) // noop + if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*2 { + t.Fatalf("backoff didn't return 2 seconds on second retry") + } + backoff.Backoff(ctx) // noop + if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*4 { + t.Fatalf("backoff didn't return 4 seconds on third retry") + } +} diff --git a/origin/tunnel.go b/origin/tunnel.go new file mode 100644 index 00000000..975cd5fc --- /dev/null +++ b/origin/tunnel.go @@ -0,0 +1,360 @@ +package origin + +import ( + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/url" + "runtime" + "strings" + "time" + + "golang.org/x/net/context" + + "github.com/cloudflare/cloudflare-warp/h2mux" + "github.com/cloudflare/cloudflare-warp/tunnelrpc" + tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" + "github.com/cloudflare/cloudflare-warp/validation" + + log "github.com/Sirupsen/logrus" + raven "github.com/getsentry/raven-go" + "github.com/pkg/errors" + rpc "zombiezen.com/go/capnproto2/rpc" +) + +const ( + dialTimeout = 15 * time.Second + + TagHeaderNamePrefix = "Cf-Warp-Tag-" +) + +type TunnelConfig struct { + EdgeAddr string + OriginUrl string + Hostname string + APIKey string + APIEmail string + APICAKey string + TlsConfig *tls.Config + Retries uint + HeartbeatInterval time.Duration + MaxHeartbeats uint64 + ClientID string + ReportedVersion string + LBPool string + Tags []tunnelpogs.Tag + AccessInternalIP bool + ConnectedSignal h2mux.Signal +} + +type dialError struct { + cause error +} + +func (e dialError) Error() string { + return e.cause.Error() +} + +type printableRegisterTunnelError struct { + cause error + permanent bool +} + +func (e printableRegisterTunnelError) Error() string { + return e.cause.Error() +} + +func (c *TunnelConfig) RegistrationOptions() *tunnelpogs.RegistrationOptions { 
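+ // When the tunnel is assigned to a load-balancer pool, ask the edge to
+ // balance across any existing tunnels for the hostname; otherwise the new
+ // tunnel disconnects any existing tunnel registered for it.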
+ policy := tunnelrpc.ExistingTunnelPolicy_disconnect + if c.LBPool != "" { + policy = tunnelrpc.ExistingTunnelPolicy_balance + } + return &tunnelpogs.RegistrationOptions{ + ClientID: c.ClientID, + Version: c.ReportedVersion, + OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH), + ExistingTunnelPolicy: policy, + PoolID: c.LBPool, + Tags: c.Tags, + ExposeInternalHostname: c.AccessInternalIP, + } +} + +func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}) error { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-shutdownC + cancel() + }() + backoff := BackoffHandler{MaxRetries: config.Retries} + for { + err, recoverable := ServeTunnel(ctx, config, &backoff) + if recoverable { + if duration, ok := backoff.GetBackoffDuration(ctx); ok { + log.Infof("Retrying in %s seconds", duration) + backoff.Backoff(ctx) + continue + } + } + return err + } +} + +func ServeTunnel( + ctx context.Context, + config *TunnelConfig, + backoff *BackoffHandler, +) (err error, recoverable bool) { + // Returns error from parsing the origin URL or handshake errors + handler, err := NewTunnelHandler(ctx, config) + if err != nil { + errLog := log.WithError(err) + switch err.(type) { + case dialError: + errLog.Error("Unable to dial edge") + case h2mux.MuxerHandshakeError: + errLog.Error("Handshake failed with edge server") + default: + errLog.Error("Tunnel creation failure") + return err, false + } + return err, true + } + serveCtx, serveCancel := context.WithCancel(ctx) + registerErrC := make(chan error, 1) + go func() { + err := RegisterTunnel(serveCtx, handler.muxer, config) + if err == nil { + config.ConnectedSignal.Signal() + backoff.SetGracePeriod() + } else { + serveCancel() + } + registerErrC <- err + }() + go func() { + <-serveCtx.Done() + handler.muxer.Shutdown() + }() + err = handler.muxer.Serve() + serveCancel() + registerErr := <-registerErrC + if err != nil { + log.WithError(err).Error("Tunnel error") + return err, true + } + if registerErr != nil { + raven.CaptureError(registerErr, nil) + // Don't retry on errors like entitlement failure or version too old + if e, ok := registerErr.(printableRegisterTunnelError); ok { + log.WithError(e).Error("Cannot register") + if e.permanent { + return nil, false + } + return e.cause, true + } + log.Error("Cannot register") + return err, true + } + return nil, false +} + +func IsRPCStreamResponse(headers []h2mux.Header) bool { + if len(headers) != 1 { + return false + } + if headers[0].Name != ":status" || headers[0].Value != "200" { + return false + } + return true +} + +func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig) error { + logger := log.WithField("subsystem", "rpc") + logger.Debug("initiating RPC stream") + stream, err := muxer.OpenStream([]h2mux.Header{ + {Name: ":method", Value: "RPC"}, + {Name: ":scheme", Value: "capnp"}, + {Name: ":path", Value: "*"}, + }, nil) + if err != nil { + // RPC stream open error + raven.CaptureError(err, nil) + return err + } + if !IsRPCStreamResponse(stream.Headers) { + // stream response error + raven.CaptureError(err, nil) + return err + } + conn := rpc.NewConn( + tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)), + tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")), + ) + defer conn.Close() + ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)} + // Request server info without blocking tunnel registration; must use capnp library directly. 
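+ // (The pogs wrapper's GetServerInfo blocks on promise.Result().Struct()
+ // before returning, which would serialize the two RPCs; the raw capnp call
+ // below only holds a promise, which LogServerInfo resolves after
+ // RegisterTunnel has been issued.)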
+ tsClient := tunnelrpc.TunnelServer{Client: ts.Client} + serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error { + return nil + }) + registration, err := ts.RegisterTunnel( + ctx, + &tunnelpogs.Authentication{Key: config.APIKey, Email: config.APIEmail, OriginCAKey: config.APICAKey}, + config.Hostname, + config.RegistrationOptions(), + ) + LogServerInfo(logger, serverInfoPromise.Result()) + if err != nil { + // RegisterTunnel RPC failure + return err + } + for _, logLine := range registration.LogLines { + logger.Info(logLine) + } + if registration.Err != "" { + return printableRegisterTunnelError{ + cause: fmt.Errorf("Server error: %s", registration.Err), + permanent: registration.PermanentFailure, + } + } + for _, url := range registration.Urls { + log.Infof("Registered at %s", url) + } + for _, logLine := range registration.LogLines { + log.Infof(logLine) + } + return nil +} + +func LogServerInfo(logger *log.Entry, promise tunnelrpc.ServerInfo_Promise) { + serverInfoMessage, err := promise.Struct() + if err != nil { + logger.WithError(err).Warn("Failed to retrieve server information") + return + } + serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage) + if err != nil { + logger.WithError(err).Warn("Failed to retrieve server information") + return + } + log.Infof("Connected to %s", serverInfo.LocationName) +} + +type TunnelHandler struct { + originUrl string + muxer *h2mux.Muxer + httpClient *http.Client + tags []tunnelpogs.Tag +} + +var dialer = net.Dialer{DualStack: true} + +func NewTunnelHandler(ctx context.Context, config *TunnelConfig) (*TunnelHandler, error) { + url, err := validation.ValidateUrl(config.OriginUrl) + if err != nil { + return nil, fmt.Errorf("Unable to parse origin url %#v", url) + } + h := &TunnelHandler{ + originUrl: url, + httpClient: &http.Client{Timeout: time.Minute}, + tags: config.Tags, + } + // Inherit from parent context so we can cancel (Ctrl-C) while dialing + dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout) + // TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one) + plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", config.EdgeAddr) + dialCancel() + if err != nil { + return nil, dialError{cause: err} + } + edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig) + edgeConn.SetDeadline(time.Now().Add(dialTimeout)) + err = edgeConn.Handshake() + if err != nil { + return nil, dialError{cause: err} + } + // clear the deadline on the conn; h2mux has its own timeouts + edgeConn.SetDeadline(time.Time{}) + // Establish a muxed connection with the edge + // Client mux handshake with agent server + h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{ + Timeout: 5 * time.Second, + Handler: h, + IsClient: true, + HeartbeatInterval: config.HeartbeatInterval, + MaxHeartbeats: config.MaxHeartbeats, + }) + if err != nil { + return h, errors.New("TLS handshake error") + } + return h, err +} + +func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error { + for _, header := range h2 { + switch header.Name { + case ":method": + h1.Method = header.Value + case ":scheme": + case ":authority": + // Otherwise the host header will be based on the origin URL + h1.Host = header.Value + case ":path": + u, err := url.Parse(header.Value) + if err != nil { + return fmt.Errorf("unparseable path") + } + resolved := h1.URL.ResolveReference(u) + // prevent escaping base URL + if !strings.HasPrefix(resolved.String(), h1.URL.String()) { + 
return fmt.Errorf("invalid path") + } + h1.URL = resolved + default: + h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value) + } + } + return nil +} + +func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) { + h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}} + for headerName, headerValues := range h1.Header { + for _, headerValue := range headerValues { + h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue}) + } + } + return +} + +func (h *TunnelHandler) AppendTagHeaders(r *http.Request) { + for _, tag := range h.tags { + r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value) + } +} + +func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { + req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) + if err != nil { + log.WithError(err).Panic("Unexpected error from http.NewRequest") + } + err = H2RequestHeadersToH1Request(stream.Headers, req) + if err != nil { + log.WithError(err).Error("invalid request received") + } + h.AppendTagHeaders(req) + response, err := h.httpClient.Do(req) + if err != nil { + log.WithError(err).Error("HTTP request error") + stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}}) + stream.Write([]byte("502 Bad Gateway")) + } else { + defer response.Body.Close() + stream.WriteHeaders(H1ResponseToH2Response(response)) + io.Copy(stream, response.Body) + } + return nil +} diff --git a/tlsconfig/tlsconfig.go b/tlsconfig/tlsconfig.go new file mode 100644 index 00000000..5b0c81a6 --- /dev/null +++ b/tlsconfig/tlsconfig.go @@ -0,0 +1,62 @@ +// Package tlsconfig provides convenience functions for configuring TLS connections from the +// command line. +package tlsconfig + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + + log "github.com/Sirupsen/logrus" + cli "gopkg.in/urfave/cli.v2" +) + +// CLIFlags names the flags used to configure TLS for a command or subsystem. +// The nil value for a field means the flag is ignored. +type CLIFlags struct { + Cert string + Key string + ClientCert string + RootCA string +} + +// GetConfig returns a TLS configuration according to the flags defined in f and +// set by the user. +func (f CLIFlags) GetConfig(c *cli.Context) *tls.Config { + config := &tls.Config{} + + if c.IsSet(f.Cert) && c.IsSet(f.Key) { + cert, err := tls.LoadX509KeyPair(c.String(f.Cert), c.String(f.Key)) + if err != nil { + log.WithError(err).Fatal("Error parsing X509 key pair") + } + config.Certificates = []tls.Certificate{cert} + config.BuildNameToCertificate() + } + if c.IsSet(f.ClientCert) { + // set of root certificate authorities that servers use if required to verify a client certificate + // by the policy in ClientAuth + config.ClientCAs = LoadCert(c.String(f.ClientCert)) + // server's policy for TLS Client Authentication. Default is no client cert + config.ClientAuth = tls.RequireAndVerifyClientCert + } + // set of root certificate authorities that clients use when verifying server certificates + if c.IsSet(f.RootCA) { + config.RootCAs = LoadCert(c.String(f.RootCA)) + } + + return config +} + +// LoadCert creates a CertPool containing all certificates in a PEM-format file. 
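+// LoadCert terminates the process (via log.Fatal) if the file cannot be read
+// or contains no parseable certificate. For example (the path below is
+// hypothetical):
+//
+//    config := &tls.Config{RootCAs: tlsconfig.LoadCert("/etc/ssl/origin-ca.pem")}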
+func LoadCert(certPath string) *x509.CertPool { + caCert, err := ioutil.ReadFile(certPath) + if err != nil { + log.WithError(err).Fatalf("Error reading certificate %s", certPath) + } + ca := x509.NewCertPool() + if !ca.AppendCertsFromPEM(caCert) { + log.WithError(err).Fatalf("Error parsing certificate %s", certPath) + } + return ca +} diff --git a/tunnelrpc/go.capnp b/tunnelrpc/go.capnp new file mode 100644 index 00000000..c12d70a4 --- /dev/null +++ b/tunnelrpc/go.capnp @@ -0,0 +1,15 @@ +# Generate go.capnp.out with: +# capnp compile -o- go.capnp > go.capnp.out +# Must run inside this directory to preserve paths. + +@0xd12a1c51fedd6c88; + +annotation package(file) :Text; +annotation import(file) :Text; +annotation doc(struct, field, enum) :Text; +annotation tag(enumerant) :Text; +annotation notag(enumerant) :Void; +annotation customtype(field) :Text; +annotation name(struct, field, union, enum, enumerant, interface, method, param, annotation, const, group) :Text; + +$package("capnp"); diff --git a/tunnelrpc/log.go b/tunnelrpc/log.go new file mode 100644 index 00000000..d5bb5698 --- /dev/null +++ b/tunnelrpc/log.go @@ -0,0 +1,26 @@ +package tunnelrpc + +//go:generate capnp compile -ogo -I./tunnelrpc/ tunnelrpc.capnp + +import ( + log "github.com/Sirupsen/logrus" + "golang.org/x/net/context" + "zombiezen.com/go/capnproto2/rpc" +) + +// ConnLogger wraps a logrus *log.Entry for a connection. +type ConnLogger struct { + Entry *log.Entry +} + +func (c ConnLogger) Infof(ctx context.Context, format string, args ...interface{}) { + c.Entry.Infof(format, args...) +} + +func (c ConnLogger) Errorf(ctx context.Context, format string, args ...interface{}) { + c.Entry.Errorf(format, args...) +} + +func ConnLog(log *log.Entry) rpc.ConnOption { + return rpc.ConnLog(ConnLogger{log}) +} diff --git a/tunnelrpc/logtransport.go b/tunnelrpc/logtransport.go new file mode 100644 index 00000000..b31dde9c --- /dev/null +++ b/tunnelrpc/logtransport.go @@ -0,0 +1,45 @@ +// Package logtransport provides a transport that logs all of its messages. +package tunnelrpc + +import ( + "bytes" + + log "github.com/Sirupsen/logrus" + "golang.org/x/net/context" + "zombiezen.com/go/capnproto2/encoding/text" + "zombiezen.com/go/capnproto2/rpc" + rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc" +) + +type transport struct { + rpc.Transport + l *log.Entry +} + +// New creates a new logger that proxies messages to and from t and +// logs them to l. If l is nil, then the log package's default +// logger is used. 
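+// Note that the implementation below does not substitute a default logger;
+// callers are expected to pass a non-nil *log.Entry, as RegisterTunnel in
+// origin/tunnel.go does.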
+func NewTransportLogger(l *log.Entry, t rpc.Transport) rpc.Transport { + return &transport{Transport: t, l: l} +} + +func (t *transport) SendMessage(ctx context.Context, msg rpccapnp.Message) error { + t.l.Debugf("tx %s", formatMsg(msg)) + return t.Transport.SendMessage(ctx, msg) +} + +func (t *transport) RecvMessage(ctx context.Context) (rpccapnp.Message, error) { + msg, err := t.Transport.RecvMessage(ctx) + if err != nil { + t.l.WithError(err).Debug("rx error") + return msg, err + } + t.l.Debugf("rx %s", formatMsg(msg)) + return msg, nil +} + +func formatMsg(m rpccapnp.Message) string { + var buf bytes.Buffer + text.NewEncoder(&buf).Encode(0x91b79f1f808db032, m.Struct) + return buf.String() +} diff --git a/tunnelrpc/pogs/tunnelrpc.go b/tunnelrpc/pogs/tunnelrpc.go new file mode 100644 index 00000000..da3d7836 --- /dev/null +++ b/tunnelrpc/pogs/tunnelrpc.go @@ -0,0 +1,194 @@ +package pogs + +import ( + "github.com/cloudflare/cloudflare-warp/tunnelrpc" + "golang.org/x/net/context" + "zombiezen.com/go/capnproto2" + "zombiezen.com/go/capnproto2/pogs" + "zombiezen.com/go/capnproto2/rpc" + "zombiezen.com/go/capnproto2/server" +) + +type Authentication struct { + Key string + Email string + OriginCAKey string +} + +func MarshalAuthentication(s tunnelrpc.Authentication, p *Authentication) error { + return pogs.Insert(tunnelrpc.Authentication_TypeID, s.Struct, p) +} + +func UnmarshalAuthentication(s tunnelrpc.Authentication) (*Authentication, error) { + p := new(Authentication) + err := pogs.Extract(p, tunnelrpc.Authentication_TypeID, s.Struct) + return p, err +} + +type TunnelRegistration struct { + Err string + Urls []string + LogLines []string + PermanentFailure bool +} + +func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error { + return pogs.Insert(tunnelrpc.TunnelRegistration_TypeID, s.Struct, p) +} + +func UnmarshalTunnelRegistration(s tunnelrpc.TunnelRegistration) (*TunnelRegistration, error) { + p := new(TunnelRegistration) + err := pogs.Extract(p, tunnelrpc.TunnelRegistration_TypeID, s.Struct) + return p, err +} + +type RegistrationOptions struct { + ClientID string `capnp:"clientId"` + Version string + OS string `capnp:"os"` + ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy + PoolID string `capnp:"poolId"` + ExposeInternalHostname bool + Tags []Tag +} + +func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error { + return pogs.Insert(tunnelrpc.RegistrationOptions_TypeID, s.Struct, p) +} + +func UnmarshalRegistrationOptions(s tunnelrpc.RegistrationOptions) (*RegistrationOptions, error) { + p := new(RegistrationOptions) + err := pogs.Extract(p, tunnelrpc.RegistrationOptions_TypeID, s.Struct) + return p, err +} + +type Tag struct { + Name string + Value string +} + +type ServerInfo struct { + LocationName string +} + +func MarshalServerInfo(s tunnelrpc.ServerInfo, p *ServerInfo) error { + return pogs.Insert(tunnelrpc.ServerInfo_TypeID, s.Struct, p) +} + +func UnmarshalServerInfo(s tunnelrpc.ServerInfo) (*ServerInfo, error) { + p := new(ServerInfo) + err := pogs.Extract(p, tunnelrpc.ServerInfo_TypeID, s.Struct) + return p, err +} + +type TunnelServer interface { + RegisterTunnel(ctx context.Context, auth *Authentication, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) + GetServerInfo(ctx context.Context) (*ServerInfo, error) +} + +func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer { + return tunnelrpc.TunnelServer_ServerToClient(TunnelServer_PogsImpl{s}) +} + +type 
TunnelServer_PogsImpl struct { + impl TunnelServer +} + +func (i TunnelServer_PogsImpl) RegisterTunnel(p tunnelrpc.TunnelServer_registerTunnel) error { + authentication, err := p.Params.Auth() + if err != nil { + return err + } + pogsAuthentication, err := UnmarshalAuthentication(authentication) + if err != nil { + return err + } + hostname, err := p.Params.Hostname() + if err != nil { + return err + } + options, err := p.Params.Options() + if err != nil { + return err + } + pogsOptions, err := UnmarshalRegistrationOptions(options) + if err != nil { + return err + } + server.Ack(p.Options) + registration, err := i.impl.RegisterTunnel(p.Ctx, pogsAuthentication, hostname, pogsOptions) + if err != nil { + return err + } + result, err := p.Results.NewResult() + if err != nil { + return err + } + return MarshalTunnelRegistration(result, registration) +} + +func (i TunnelServer_PogsImpl) GetServerInfo(p tunnelrpc.TunnelServer_getServerInfo) error { + server.Ack(p.Options) + serverInfo, err := i.impl.GetServerInfo(p.Ctx) + if err != nil { + return err + } + result, err := p.Results.NewResult() + if err != nil { + return err + } + return MarshalServerInfo(result, serverInfo) +} + +type TunnelServer_PogsClient struct { + Client capnp.Client + Conn *rpc.Conn +} + +func (c TunnelServer_PogsClient) Close() error { + return c.Conn.Close() +} + +func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, auth *Authentication, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) { + client := tunnelrpc.TunnelServer{Client: c.Client} + promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error { + authentication, err := p.NewAuth() + if err != nil { + return err + } + err = MarshalAuthentication(authentication, auth) + if err != nil { + return err + } + err = p.SetHostname(hostname) + if err != nil { + return err + } + registrationOptions, err := p.NewOptions() + if err != nil { + return err + } + err = MarshalRegistrationOptions(registrationOptions, options) + if err != nil { + return err + } + return nil + }) + retval, err := promise.Result().Struct() + if err != nil { + return nil, err + } + return UnmarshalTunnelRegistration(retval) +} + +func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) { + client := tunnelrpc.TunnelServer{Client: c.Client} + promise := client.GetServerInfo(ctx, func(p tunnelrpc.TunnelServer_getServerInfo_Params) error { + return nil + }) + retval, err := promise.Result().Struct() + if err != nil { + return nil, err + } + return UnmarshalServerInfo(retval) +} diff --git a/tunnelrpc/tunnelrpc.capnp b/tunnelrpc/tunnelrpc.capnp new file mode 100644 index 00000000..8c98c188 --- /dev/null +++ b/tunnelrpc/tunnelrpc.capnp @@ -0,0 +1,56 @@ +using Go = import "go.capnp"; +@0xdb8274f9144abc7e; +$Go.package("tunnelrpc"); +$Go.import("github.com/cloudflare/cloudflare-warp/tunnelrpc"); + +struct Authentication { + key @0 :Text; + email @1 :Text; + originCAKey @2 :Text; +} + +struct TunnelRegistration { + err @0 :Text; + # A list of URLs that the tunnel is accessible from. + urls @1 :List(Text); + # Used to inform the client of actions taken. + logLines @2 :List(Text); + # In case of error, whether the client should attempt to reconnect. + permanentFailure @3 :Bool; +} + +struct RegistrationOptions { + # The tunnel client's unique identifier, used to verify a reconnection. + clientId @0 :Text; + # Information about the running binary. 
+ version @1 :Text; + os @2 :Text; + # What to do with existing tunnels for the given hostname. + existingTunnelPolicy @3 :ExistingTunnelPolicy; + # If using the balancing policy, identifies the LB pool to use. + poolId @4 :Text; + # Prevents the tunnel from being accessed at .cftunnel.com + exposeInternalHostname @5 :Bool; + # Client-defined tags to associate with the tunnel + tags @6 :List(Tag); +} + +struct Tag { + name @0 :Text; + value @1 :Text; +} + +enum ExistingTunnelPolicy { + ignore @0; + disconnect @1; + balance @2; +} + +struct ServerInfo { + locationName @0 :Text; +} + +interface TunnelServer { + registerTunnel @0 (auth :Authentication, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration); + getServerInfo @1 () -> (result :ServerInfo); +} diff --git a/tunnelrpc/tunnelrpc.capnp.go b/tunnelrpc/tunnelrpc.capnp.go new file mode 100644 index 00000000..1352ece1 --- /dev/null +++ b/tunnelrpc/tunnelrpc.capnp.go @@ -0,0 +1,1145 @@ +// Code generated by capnpc-go. DO NOT EDIT. + +package tunnelrpc + +import ( + context "golang.org/x/net/context" + capnp "zombiezen.com/go/capnproto2" + text "zombiezen.com/go/capnproto2/encoding/text" + schemas "zombiezen.com/go/capnproto2/schemas" + server "zombiezen.com/go/capnproto2/server" +) + +type Authentication struct{ capnp.Struct } + +// Authentication_TypeID is the unique identifier for the type Authentication. +const Authentication_TypeID = 0xc082ef6e0d42ed1d + +func NewAuthentication(s *capnp.Segment) (Authentication, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}) + return Authentication{st}, err +} + +func NewRootAuthentication(s *capnp.Segment) (Authentication, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}) + return Authentication{st}, err +} + +func ReadRootAuthentication(msg *capnp.Message) (Authentication, error) { + root, err := msg.RootPtr() + return Authentication{root.Struct()}, err +} + +func (s Authentication) String() string { + str, _ := text.Marshal(0xc082ef6e0d42ed1d, s.Struct) + return str +} + +func (s Authentication) Key() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s Authentication) HasKey() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s Authentication) KeyBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s Authentication) SetKey(v string) error { + return s.Struct.SetText(0, v) +} + +func (s Authentication) Email() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s Authentication) HasEmail() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s Authentication) EmailBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s Authentication) SetEmail(v string) error { + return s.Struct.SetText(1, v) +} + +func (s Authentication) OriginCAKey() (string, error) { + p, err := s.Struct.Ptr(2) + return p.Text(), err +} + +func (s Authentication) HasOriginCAKey() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s Authentication) OriginCAKeyBytes() ([]byte, error) { + p, err := s.Struct.Ptr(2) + return p.TextBytes(), err +} + +func (s Authentication) SetOriginCAKey(v string) error { + return s.Struct.SetText(2, v) +} + +// Authentication_List is a list of Authentication. +type Authentication_List struct{ capnp.List } + +// NewAuthentication creates a new list of Authentication. 
+func NewAuthentication_List(s *capnp.Segment, sz int32) (Authentication_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}, sz) + return Authentication_List{l}, err +} + +func (s Authentication_List) At(i int) Authentication { return Authentication{s.List.Struct(i)} } + +func (s Authentication_List) Set(i int, v Authentication) error { return s.List.SetStruct(i, v.Struct) } + +// Authentication_Promise is a wrapper for a Authentication promised by a client call. +type Authentication_Promise struct{ *capnp.Pipeline } + +func (p Authentication_Promise) Struct() (Authentication, error) { + s, err := p.Pipeline.Struct() + return Authentication{s}, err +} + +type TunnelRegistration struct{ capnp.Struct } + +// TunnelRegistration_TypeID is the unique identifier for the type TunnelRegistration. +const TunnelRegistration_TypeID = 0xf41a0f001ad49e46 + +func NewTunnelRegistration(s *capnp.Segment) (TunnelRegistration, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 3}) + return TunnelRegistration{st}, err +} + +func NewRootTunnelRegistration(s *capnp.Segment) (TunnelRegistration, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 3}) + return TunnelRegistration{st}, err +} + +func ReadRootTunnelRegistration(msg *capnp.Message) (TunnelRegistration, error) { + root, err := msg.RootPtr() + return TunnelRegistration{root.Struct()}, err +} + +func (s TunnelRegistration) String() string { + str, _ := text.Marshal(0xf41a0f001ad49e46, s.Struct) + return str +} + +func (s TunnelRegistration) Err() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s TunnelRegistration) HasErr() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelRegistration) ErrBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s TunnelRegistration) SetErr(v string) error { + return s.Struct.SetText(0, v) +} + +func (s TunnelRegistration) Urls() (capnp.TextList, error) { + p, err := s.Struct.Ptr(1) + return capnp.TextList{List: p.List()}, err +} + +func (s TunnelRegistration) HasUrls() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s TunnelRegistration) SetUrls(v capnp.TextList) error { + return s.Struct.SetPtr(1, v.List.ToPtr()) +} + +// NewUrls sets the urls field to a newly +// allocated capnp.TextList, preferring placement in s's segment. +func (s TunnelRegistration) NewUrls(n int32) (capnp.TextList, error) { + l, err := capnp.NewTextList(s.Struct.Segment(), n) + if err != nil { + return capnp.TextList{}, err + } + err = s.Struct.SetPtr(1, l.List.ToPtr()) + return l, err +} + +func (s TunnelRegistration) LogLines() (capnp.TextList, error) { + p, err := s.Struct.Ptr(2) + return capnp.TextList{List: p.List()}, err +} + +func (s TunnelRegistration) HasLogLines() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s TunnelRegistration) SetLogLines(v capnp.TextList) error { + return s.Struct.SetPtr(2, v.List.ToPtr()) +} + +// NewLogLines sets the logLines field to a newly +// allocated capnp.TextList, preferring placement in s's segment. 
+func (s TunnelRegistration) NewLogLines(n int32) (capnp.TextList, error) { + l, err := capnp.NewTextList(s.Struct.Segment(), n) + if err != nil { + return capnp.TextList{}, err + } + err = s.Struct.SetPtr(2, l.List.ToPtr()) + return l, err +} + +func (s TunnelRegistration) PermanentFailure() bool { + return s.Struct.Bit(0) +} + +func (s TunnelRegistration) SetPermanentFailure(v bool) { + s.Struct.SetBit(0, v) +} + +// TunnelRegistration_List is a list of TunnelRegistration. +type TunnelRegistration_List struct{ capnp.List } + +// NewTunnelRegistration creates a new list of TunnelRegistration. +func NewTunnelRegistration_List(s *capnp.Segment, sz int32) (TunnelRegistration_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 3}, sz) + return TunnelRegistration_List{l}, err +} + +func (s TunnelRegistration_List) At(i int) TunnelRegistration { + return TunnelRegistration{s.List.Struct(i)} +} + +func (s TunnelRegistration_List) Set(i int, v TunnelRegistration) error { + return s.List.SetStruct(i, v.Struct) +} + +// TunnelRegistration_Promise is a wrapper for a TunnelRegistration promised by a client call. +type TunnelRegistration_Promise struct{ *capnp.Pipeline } + +func (p TunnelRegistration_Promise) Struct() (TunnelRegistration, error) { + s, err := p.Pipeline.Struct() + return TunnelRegistration{s}, err +} + +type RegistrationOptions struct{ capnp.Struct } + +// RegistrationOptions_TypeID is the unique identifier for the type RegistrationOptions. +const RegistrationOptions_TypeID = 0xc793e50592935b4a + +func NewRegistrationOptions(s *capnp.Segment) (RegistrationOptions, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 5}) + return RegistrationOptions{st}, err +} + +func NewRootRegistrationOptions(s *capnp.Segment) (RegistrationOptions, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 5}) + return RegistrationOptions{st}, err +} + +func ReadRootRegistrationOptions(msg *capnp.Message) (RegistrationOptions, error) { + root, err := msg.RootPtr() + return RegistrationOptions{root.Struct()}, err +} + +func (s RegistrationOptions) String() string { + str, _ := text.Marshal(0xc793e50592935b4a, s.Struct) + return str +} + +func (s RegistrationOptions) ClientId() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s RegistrationOptions) HasClientId() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) ClientIdBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s RegistrationOptions) SetClientId(v string) error { + return s.Struct.SetText(0, v) +} + +func (s RegistrationOptions) Version() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s RegistrationOptions) HasVersion() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) VersionBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s RegistrationOptions) SetVersion(v string) error { + return s.Struct.SetText(1, v) +} + +func (s RegistrationOptions) Os() (string, error) { + p, err := s.Struct.Ptr(2) + return p.Text(), err +} + +func (s RegistrationOptions) HasOs() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) OsBytes() ([]byte, error) { + p, err := s.Struct.Ptr(2) + return p.TextBytes(), err +} + +func (s RegistrationOptions) 
SetOs(v string) error { + return s.Struct.SetText(2, v) +} + +func (s RegistrationOptions) ExistingTunnelPolicy() ExistingTunnelPolicy { + return ExistingTunnelPolicy(s.Struct.Uint16(0)) +} + +func (s RegistrationOptions) SetExistingTunnelPolicy(v ExistingTunnelPolicy) { + s.Struct.SetUint16(0, uint16(v)) +} + +func (s RegistrationOptions) PoolId() (string, error) { + p, err := s.Struct.Ptr(3) + return p.Text(), err +} + +func (s RegistrationOptions) HasPoolId() bool { + p, err := s.Struct.Ptr(3) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) PoolIdBytes() ([]byte, error) { + p, err := s.Struct.Ptr(3) + return p.TextBytes(), err +} + +func (s RegistrationOptions) SetPoolId(v string) error { + return s.Struct.SetText(3, v) +} + +func (s RegistrationOptions) ExposeInternalHostname() bool { + return s.Struct.Bit(16) +} + +func (s RegistrationOptions) SetExposeInternalHostname(v bool) { + s.Struct.SetBit(16, v) +} + +func (s RegistrationOptions) Tags() (Tag_List, error) { + p, err := s.Struct.Ptr(4) + return Tag_List{List: p.List()}, err +} + +func (s RegistrationOptions) HasTags() bool { + p, err := s.Struct.Ptr(4) + return p.IsValid() || err != nil +} + +func (s RegistrationOptions) SetTags(v Tag_List) error { + return s.Struct.SetPtr(4, v.List.ToPtr()) +} + +// NewTags sets the tags field to a newly +// allocated Tag_List, preferring placement in s's segment. +func (s RegistrationOptions) NewTags(n int32) (Tag_List, error) { + l, err := NewTag_List(s.Struct.Segment(), n) + if err != nil { + return Tag_List{}, err + } + err = s.Struct.SetPtr(4, l.List.ToPtr()) + return l, err +} + +// RegistrationOptions_List is a list of RegistrationOptions. +type RegistrationOptions_List struct{ capnp.List } + +// NewRegistrationOptions creates a new list of RegistrationOptions. +func NewRegistrationOptions_List(s *capnp.Segment, sz int32) (RegistrationOptions_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 5}, sz) + return RegistrationOptions_List{l}, err +} + +func (s RegistrationOptions_List) At(i int) RegistrationOptions { + return RegistrationOptions{s.List.Struct(i)} +} + +func (s RegistrationOptions_List) Set(i int, v RegistrationOptions) error { + return s.List.SetStruct(i, v.Struct) +} + +// RegistrationOptions_Promise is a wrapper for a RegistrationOptions promised by a client call. +type RegistrationOptions_Promise struct{ *capnp.Pipeline } + +func (p RegistrationOptions_Promise) Struct() (RegistrationOptions, error) { + s, err := p.Pipeline.Struct() + return RegistrationOptions{s}, err +} + +type Tag struct{ capnp.Struct } + +// Tag_TypeID is the unique identifier for the type Tag. 
+const Tag_TypeID = 0xcbd96442ae3bb01a + +func NewTag(s *capnp.Segment) (Tag, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return Tag{st}, err +} + +func NewRootTag(s *capnp.Segment) (Tag, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}) + return Tag{st}, err +} + +func ReadRootTag(msg *capnp.Message) (Tag, error) { + root, err := msg.RootPtr() + return Tag{root.Struct()}, err +} + +func (s Tag) String() string { + str, _ := text.Marshal(0xcbd96442ae3bb01a, s.Struct) + return str +} + +func (s Tag) Name() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s Tag) HasName() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s Tag) NameBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s Tag) SetName(v string) error { + return s.Struct.SetText(0, v) +} + +func (s Tag) Value() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s Tag) HasValue() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s Tag) ValueBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s Tag) SetValue(v string) error { + return s.Struct.SetText(1, v) +} + +// Tag_List is a list of Tag. +type Tag_List struct{ capnp.List } + +// NewTag creates a new list of Tag. +func NewTag_List(s *capnp.Segment, sz int32) (Tag_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}, sz) + return Tag_List{l}, err +} + +func (s Tag_List) At(i int) Tag { return Tag{s.List.Struct(i)} } + +func (s Tag_List) Set(i int, v Tag) error { return s.List.SetStruct(i, v.Struct) } + +// Tag_Promise is a wrapper for a Tag promised by a client call. +type Tag_Promise struct{ *capnp.Pipeline } + +func (p Tag_Promise) Struct() (Tag, error) { + s, err := p.Pipeline.Struct() + return Tag{s}, err +} + +type ExistingTunnelPolicy uint16 + +// ExistingTunnelPolicy_TypeID is the unique identifier for the type ExistingTunnelPolicy. +const ExistingTunnelPolicy_TypeID = 0x84cb9536a2cf6d3c + +// Values of ExistingTunnelPolicy. +const ( + ExistingTunnelPolicy_ignore ExistingTunnelPolicy = 0 + ExistingTunnelPolicy_disconnect ExistingTunnelPolicy = 1 + ExistingTunnelPolicy_balance ExistingTunnelPolicy = 2 +) + +// String returns the enum's constant name. +func (c ExistingTunnelPolicy) String() string { + switch c { + case ExistingTunnelPolicy_ignore: + return "ignore" + case ExistingTunnelPolicy_disconnect: + return "disconnect" + case ExistingTunnelPolicy_balance: + return "balance" + + default: + return "" + } +} + +// ExistingTunnelPolicyFromString returns the enum value with a name, +// or the zero value if there's no such value. 
+func ExistingTunnelPolicyFromString(c string) ExistingTunnelPolicy { + switch c { + case "ignore": + return ExistingTunnelPolicy_ignore + case "disconnect": + return ExistingTunnelPolicy_disconnect + case "balance": + return ExistingTunnelPolicy_balance + + default: + return 0 + } +} + +type ExistingTunnelPolicy_List struct{ capnp.List } + +func NewExistingTunnelPolicy_List(s *capnp.Segment, sz int32) (ExistingTunnelPolicy_List, error) { + l, err := capnp.NewUInt16List(s, sz) + return ExistingTunnelPolicy_List{l.List}, err +} + +func (l ExistingTunnelPolicy_List) At(i int) ExistingTunnelPolicy { + ul := capnp.UInt16List{List: l.List} + return ExistingTunnelPolicy(ul.At(i)) +} + +func (l ExistingTunnelPolicy_List) Set(i int, v ExistingTunnelPolicy) { + ul := capnp.UInt16List{List: l.List} + ul.Set(i, uint16(v)) +} + +type ServerInfo struct{ capnp.Struct } + +// ServerInfo_TypeID is the unique identifier for the type ServerInfo. +const ServerInfo_TypeID = 0xf2c68e2547ec3866 + +func NewServerInfo(s *capnp.Segment) (ServerInfo, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return ServerInfo{st}, err +} + +func NewRootServerInfo(s *capnp.Segment) (ServerInfo, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return ServerInfo{st}, err +} + +func ReadRootServerInfo(msg *capnp.Message) (ServerInfo, error) { + root, err := msg.RootPtr() + return ServerInfo{root.Struct()}, err +} + +func (s ServerInfo) String() string { + str, _ := text.Marshal(0xf2c68e2547ec3866, s.Struct) + return str +} + +func (s ServerInfo) LocationName() (string, error) { + p, err := s.Struct.Ptr(0) + return p.Text(), err +} + +func (s ServerInfo) HasLocationName() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s ServerInfo) LocationNameBytes() ([]byte, error) { + p, err := s.Struct.Ptr(0) + return p.TextBytes(), err +} + +func (s ServerInfo) SetLocationName(v string) error { + return s.Struct.SetText(0, v) +} + +// ServerInfo_List is a list of ServerInfo. +type ServerInfo_List struct{ capnp.List } + +// NewServerInfo creates a new list of ServerInfo. +func NewServerInfo_List(s *capnp.Segment, sz int32) (ServerInfo_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return ServerInfo_List{l}, err +} + +func (s ServerInfo_List) At(i int) ServerInfo { return ServerInfo{s.List.Struct(i)} } + +func (s ServerInfo_List) Set(i int, v ServerInfo) error { return s.List.SetStruct(i, v.Struct) } + +// ServerInfo_Promise is a wrapper for a ServerInfo promised by a client call. +type ServerInfo_Promise struct{ *capnp.Pipeline } + +func (p ServerInfo_Promise) Struct() (ServerInfo, error) { + s, err := p.Pipeline.Struct() + return ServerInfo{s}, err +} + +type TunnelServer struct{ Client capnp.Client } + +// TunnelServer_TypeID is the unique identifier for the type TunnelServer. 
+const TunnelServer_TypeID = 0xea58385c65416035 + +func (c TunnelServer) RegisterTunnel(ctx context.Context, params func(TunnelServer_registerTunnel_Params) error, opts ...capnp.CallOption) TunnelServer_registerTunnel_Results_Promise { + if c.Client == nil { + return TunnelServer_registerTunnel_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "registerTunnel", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 3} + call.ParamsFunc = func(s capnp.Struct) error { return params(TunnelServer_registerTunnel_Params{Struct: s}) } + } + return TunnelServer_registerTunnel_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} +func (c TunnelServer) GetServerInfo(ctx context.Context, params func(TunnelServer_getServerInfo_Params) error, opts ...capnp.CallOption) TunnelServer_getServerInfo_Results_Promise { + if c.Client == nil { + return TunnelServer_getServerInfo_Results_Promise{Pipeline: capnp.NewPipeline(capnp.ErrorAnswer(capnp.ErrNullClient))} + } + call := &capnp.Call{ + Ctx: ctx, + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 1, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "getServerInfo", + }, + Options: capnp.NewCallOptions(opts), + } + if params != nil { + call.ParamsSize = capnp.ObjectSize{DataSize: 0, PointerCount: 0} + call.ParamsFunc = func(s capnp.Struct) error { return params(TunnelServer_getServerInfo_Params{Struct: s}) } + } + return TunnelServer_getServerInfo_Results_Promise{Pipeline: capnp.NewPipeline(c.Client.Call(call))} +} + +type TunnelServer_Server interface { + RegisterTunnel(TunnelServer_registerTunnel) error + + GetServerInfo(TunnelServer_getServerInfo) error +} + +func TunnelServer_ServerToClient(s TunnelServer_Server) TunnelServer { + c, _ := s.(server.Closer) + return TunnelServer{Client: server.New(TunnelServer_Methods(nil, s), c)} +} + +func TunnelServer_Methods(methods []server.Method, s TunnelServer_Server) []server.Method { + if cap(methods) == 0 { + methods = make([]server.Method, 0, 2) + } + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 0, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "registerTunnel", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := TunnelServer_registerTunnel{c, opts, TunnelServer_registerTunnel_Params{Struct: p}, TunnelServer_registerTunnel_Results{Struct: r}} + return s.RegisterTunnel(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + methods = append(methods, server.Method{ + Method: capnp.Method{ + InterfaceID: 0xea58385c65416035, + MethodID: 1, + InterfaceName: "tunnelrpc.capnp:TunnelServer", + MethodName: "getServerInfo", + }, + Impl: func(c context.Context, opts capnp.CallOptions, p, r capnp.Struct) error { + call := TunnelServer_getServerInfo{c, opts, TunnelServer_getServerInfo_Params{Struct: p}, TunnelServer_getServerInfo_Results{Struct: r}} + return s.GetServerInfo(call) + }, + ResultsSize: capnp.ObjectSize{DataSize: 0, PointerCount: 1}, + }) + + return methods +} + +// TunnelServer_registerTunnel holds the arguments for a server call to TunnelServer.registerTunnel. 
+type TunnelServer_registerTunnel struct { + Ctx context.Context + Options capnp.CallOptions + Params TunnelServer_registerTunnel_Params + Results TunnelServer_registerTunnel_Results +} + +// TunnelServer_getServerInfo holds the arguments for a server call to TunnelServer.getServerInfo. +type TunnelServer_getServerInfo struct { + Ctx context.Context + Options capnp.CallOptions + Params TunnelServer_getServerInfo_Params + Results TunnelServer_getServerInfo_Results +} + +type TunnelServer_registerTunnel_Params struct{ capnp.Struct } + +// TunnelServer_registerTunnel_Params_TypeID is the unique identifier for the type TunnelServer_registerTunnel_Params. +const TunnelServer_registerTunnel_Params_TypeID = 0xb70431c0dc014915 + +func NewTunnelServer_registerTunnel_Params(s *capnp.Segment) (TunnelServer_registerTunnel_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}) + return TunnelServer_registerTunnel_Params{st}, err +} + +func NewRootTunnelServer_registerTunnel_Params(s *capnp.Segment) (TunnelServer_registerTunnel_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}) + return TunnelServer_registerTunnel_Params{st}, err +} + +func ReadRootTunnelServer_registerTunnel_Params(msg *capnp.Message) (TunnelServer_registerTunnel_Params, error) { + root, err := msg.RootPtr() + return TunnelServer_registerTunnel_Params{root.Struct()}, err +} + +func (s TunnelServer_registerTunnel_Params) String() string { + str, _ := text.Marshal(0xb70431c0dc014915, s.Struct) + return str +} + +func (s TunnelServer_registerTunnel_Params) Auth() (Authentication, error) { + p, err := s.Struct.Ptr(0) + return Authentication{Struct: p.Struct()}, err +} + +func (s TunnelServer_registerTunnel_Params) HasAuth() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelServer_registerTunnel_Params) SetAuth(v Authentication) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewAuth sets the auth field to a newly +// allocated Authentication struct, preferring placement in s's segment. +func (s TunnelServer_registerTunnel_Params) NewAuth() (Authentication, error) { + ss, err := NewAuthentication(s.Struct.Segment()) + if err != nil { + return Authentication{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +func (s TunnelServer_registerTunnel_Params) Hostname() (string, error) { + p, err := s.Struct.Ptr(1) + return p.Text(), err +} + +func (s TunnelServer_registerTunnel_Params) HasHostname() bool { + p, err := s.Struct.Ptr(1) + return p.IsValid() || err != nil +} + +func (s TunnelServer_registerTunnel_Params) HostnameBytes() ([]byte, error) { + p, err := s.Struct.Ptr(1) + return p.TextBytes(), err +} + +func (s TunnelServer_registerTunnel_Params) SetHostname(v string) error { + return s.Struct.SetText(1, v) +} + +func (s TunnelServer_registerTunnel_Params) Options() (RegistrationOptions, error) { + p, err := s.Struct.Ptr(2) + return RegistrationOptions{Struct: p.Struct()}, err +} + +func (s TunnelServer_registerTunnel_Params) HasOptions() bool { + p, err := s.Struct.Ptr(2) + return p.IsValid() || err != nil +} + +func (s TunnelServer_registerTunnel_Params) SetOptions(v RegistrationOptions) error { + return s.Struct.SetPtr(2, v.Struct.ToPtr()) +} + +// NewOptions sets the options field to a newly +// allocated RegistrationOptions struct, preferring placement in s's segment. 
+func (s TunnelServer_registerTunnel_Params) NewOptions() (RegistrationOptions, error) { + ss, err := NewRegistrationOptions(s.Struct.Segment()) + if err != nil { + return RegistrationOptions{}, err + } + err = s.Struct.SetPtr(2, ss.Struct.ToPtr()) + return ss, err +} + +// TunnelServer_registerTunnel_Params_List is a list of TunnelServer_registerTunnel_Params. +type TunnelServer_registerTunnel_Params_List struct{ capnp.List } + +// NewTunnelServer_registerTunnel_Params creates a new list of TunnelServer_registerTunnel_Params. +func NewTunnelServer_registerTunnel_Params_List(s *capnp.Segment, sz int32) (TunnelServer_registerTunnel_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 3}, sz) + return TunnelServer_registerTunnel_Params_List{l}, err +} + +func (s TunnelServer_registerTunnel_Params_List) At(i int) TunnelServer_registerTunnel_Params { + return TunnelServer_registerTunnel_Params{s.List.Struct(i)} +} + +func (s TunnelServer_registerTunnel_Params_List) Set(i int, v TunnelServer_registerTunnel_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +// TunnelServer_registerTunnel_Params_Promise is a wrapper for a TunnelServer_registerTunnel_Params promised by a client call. +type TunnelServer_registerTunnel_Params_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_registerTunnel_Params_Promise) Struct() (TunnelServer_registerTunnel_Params, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_registerTunnel_Params{s}, err +} + +func (p TunnelServer_registerTunnel_Params_Promise) Auth() Authentication_Promise { + return Authentication_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +func (p TunnelServer_registerTunnel_Params_Promise) Options() RegistrationOptions_Promise { + return RegistrationOptions_Promise{Pipeline: p.Pipeline.GetPipeline(2)} +} + +type TunnelServer_registerTunnel_Results struct{ capnp.Struct } + +// TunnelServer_registerTunnel_Results_TypeID is the unique identifier for the type TunnelServer_registerTunnel_Results. +const TunnelServer_registerTunnel_Results_TypeID = 0xf2c122394f447e8e + +func NewTunnelServer_registerTunnel_Results(s *capnp.Segment) (TunnelServer_registerTunnel_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return TunnelServer_registerTunnel_Results{st}, err +} + +func NewRootTunnelServer_registerTunnel_Results(s *capnp.Segment) (TunnelServer_registerTunnel_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return TunnelServer_registerTunnel_Results{st}, err +} + +func ReadRootTunnelServer_registerTunnel_Results(msg *capnp.Message) (TunnelServer_registerTunnel_Results, error) { + root, err := msg.RootPtr() + return TunnelServer_registerTunnel_Results{root.Struct()}, err +} + +func (s TunnelServer_registerTunnel_Results) String() string { + str, _ := text.Marshal(0xf2c122394f447e8e, s.Struct) + return str +} + +func (s TunnelServer_registerTunnel_Results) Result() (TunnelRegistration, error) { + p, err := s.Struct.Ptr(0) + return TunnelRegistration{Struct: p.Struct()}, err +} + +func (s TunnelServer_registerTunnel_Results) HasResult() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelServer_registerTunnel_Results) SetResult(v TunnelRegistration) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewResult sets the result field to a newly +// allocated TunnelRegistration struct, preferring placement in s's segment. 
+func (s TunnelServer_registerTunnel_Results) NewResult() (TunnelRegistration, error) { + ss, err := NewTunnelRegistration(s.Struct.Segment()) + if err != nil { + return TunnelRegistration{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +// TunnelServer_registerTunnel_Results_List is a list of TunnelServer_registerTunnel_Results. +type TunnelServer_registerTunnel_Results_List struct{ capnp.List } + +// NewTunnelServer_registerTunnel_Results creates a new list of TunnelServer_registerTunnel_Results. +func NewTunnelServer_registerTunnel_Results_List(s *capnp.Segment, sz int32) (TunnelServer_registerTunnel_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return TunnelServer_registerTunnel_Results_List{l}, err +} + +func (s TunnelServer_registerTunnel_Results_List) At(i int) TunnelServer_registerTunnel_Results { + return TunnelServer_registerTunnel_Results{s.List.Struct(i)} +} + +func (s TunnelServer_registerTunnel_Results_List) Set(i int, v TunnelServer_registerTunnel_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +// TunnelServer_registerTunnel_Results_Promise is a wrapper for a TunnelServer_registerTunnel_Results promised by a client call. +type TunnelServer_registerTunnel_Results_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_registerTunnel_Results_Promise) Struct() (TunnelServer_registerTunnel_Results, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_registerTunnel_Results{s}, err +} + +func (p TunnelServer_registerTunnel_Results_Promise) Result() TunnelRegistration_Promise { + return TunnelRegistration_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +type TunnelServer_getServerInfo_Params struct{ capnp.Struct } + +// TunnelServer_getServerInfo_Params_TypeID is the unique identifier for the type TunnelServer_getServerInfo_Params. +const TunnelServer_getServerInfo_Params_TypeID = 0xdc3ed6801961e502 + +func NewTunnelServer_getServerInfo_Params(s *capnp.Segment) (TunnelServer_getServerInfo_Params, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return TunnelServer_getServerInfo_Params{st}, err +} + +func NewRootTunnelServer_getServerInfo_Params(s *capnp.Segment) (TunnelServer_getServerInfo_Params, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}) + return TunnelServer_getServerInfo_Params{st}, err +} + +func ReadRootTunnelServer_getServerInfo_Params(msg *capnp.Message) (TunnelServer_getServerInfo_Params, error) { + root, err := msg.RootPtr() + return TunnelServer_getServerInfo_Params{root.Struct()}, err +} + +func (s TunnelServer_getServerInfo_Params) String() string { + str, _ := text.Marshal(0xdc3ed6801961e502, s.Struct) + return str +} + +// TunnelServer_getServerInfo_Params_List is a list of TunnelServer_getServerInfo_Params. +type TunnelServer_getServerInfo_Params_List struct{ capnp.List } + +// NewTunnelServer_getServerInfo_Params creates a new list of TunnelServer_getServerInfo_Params. 
+func NewTunnelServer_getServerInfo_Params_List(s *capnp.Segment, sz int32) (TunnelServer_getServerInfo_Params_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 0}, sz) + return TunnelServer_getServerInfo_Params_List{l}, err +} + +func (s TunnelServer_getServerInfo_Params_List) At(i int) TunnelServer_getServerInfo_Params { + return TunnelServer_getServerInfo_Params{s.List.Struct(i)} +} + +func (s TunnelServer_getServerInfo_Params_List) Set(i int, v TunnelServer_getServerInfo_Params) error { + return s.List.SetStruct(i, v.Struct) +} + +// TunnelServer_getServerInfo_Params_Promise is a wrapper for a TunnelServer_getServerInfo_Params promised by a client call. +type TunnelServer_getServerInfo_Params_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_getServerInfo_Params_Promise) Struct() (TunnelServer_getServerInfo_Params, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_getServerInfo_Params{s}, err +} + +type TunnelServer_getServerInfo_Results struct{ capnp.Struct } + +// TunnelServer_getServerInfo_Results_TypeID is the unique identifier for the type TunnelServer_getServerInfo_Results. +const TunnelServer_getServerInfo_Results_TypeID = 0xe3e37d096a5b564e + +func NewTunnelServer_getServerInfo_Results(s *capnp.Segment) (TunnelServer_getServerInfo_Results, error) { + st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return TunnelServer_getServerInfo_Results{st}, err +} + +func NewRootTunnelServer_getServerInfo_Results(s *capnp.Segment) (TunnelServer_getServerInfo_Results, error) { + st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}) + return TunnelServer_getServerInfo_Results{st}, err +} + +func ReadRootTunnelServer_getServerInfo_Results(msg *capnp.Message) (TunnelServer_getServerInfo_Results, error) { + root, err := msg.RootPtr() + return TunnelServer_getServerInfo_Results{root.Struct()}, err +} + +func (s TunnelServer_getServerInfo_Results) String() string { + str, _ := text.Marshal(0xe3e37d096a5b564e, s.Struct) + return str +} + +func (s TunnelServer_getServerInfo_Results) Result() (ServerInfo, error) { + p, err := s.Struct.Ptr(0) + return ServerInfo{Struct: p.Struct()}, err +} + +func (s TunnelServer_getServerInfo_Results) HasResult() bool { + p, err := s.Struct.Ptr(0) + return p.IsValid() || err != nil +} + +func (s TunnelServer_getServerInfo_Results) SetResult(v ServerInfo) error { + return s.Struct.SetPtr(0, v.Struct.ToPtr()) +} + +// NewResult sets the result field to a newly +// allocated ServerInfo struct, preferring placement in s's segment. +func (s TunnelServer_getServerInfo_Results) NewResult() (ServerInfo, error) { + ss, err := NewServerInfo(s.Struct.Segment()) + if err != nil { + return ServerInfo{}, err + } + err = s.Struct.SetPtr(0, ss.Struct.ToPtr()) + return ss, err +} + +// TunnelServer_getServerInfo_Results_List is a list of TunnelServer_getServerInfo_Results. +type TunnelServer_getServerInfo_Results_List struct{ capnp.List } + +// NewTunnelServer_getServerInfo_Results creates a new list of TunnelServer_getServerInfo_Results. 
+func NewTunnelServer_getServerInfo_Results_List(s *capnp.Segment, sz int32) (TunnelServer_getServerInfo_Results_List, error) { + l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 1}, sz) + return TunnelServer_getServerInfo_Results_List{l}, err +} + +func (s TunnelServer_getServerInfo_Results_List) At(i int) TunnelServer_getServerInfo_Results { + return TunnelServer_getServerInfo_Results{s.List.Struct(i)} +} + +func (s TunnelServer_getServerInfo_Results_List) Set(i int, v TunnelServer_getServerInfo_Results) error { + return s.List.SetStruct(i, v.Struct) +} + +// TunnelServer_getServerInfo_Results_Promise is a wrapper for a TunnelServer_getServerInfo_Results promised by a client call. +type TunnelServer_getServerInfo_Results_Promise struct{ *capnp.Pipeline } + +func (p TunnelServer_getServerInfo_Results_Promise) Struct() (TunnelServer_getServerInfo_Results, error) { + s, err := p.Pipeline.Struct() + return TunnelServer_getServerInfo_Results{s}, err +} + +func (p TunnelServer_getServerInfo_Results_Promise) Result() ServerInfo_Promise { + return ServerInfo_Promise{Pipeline: p.Pipeline.GetPipeline(0)} +} + +const schema_db8274f9144abc7e = "x\xda\x94To\x88\x14e\x18\x7f~\xef\xbb3\xabp" + + "\xb6;\xcc\x0aut\x08b\x90\x82\xe6e\x86\x99\xb4\xe7" + + "\xa5\xe6^\xa7\xb7\xaf]a\xe9\x07\xc7\xbd\xf7\xf6\xc6f" + + "g\xb6\x99\xd9\xcb\x8b\xd4\x92 \x0c2R\xfb\xd2\x87\xc8" + + "\xfbf`%\x14E\x18\x9c\xd0\x1f\xc1\"\x02\x0b\xad\xeb" + + "C\x98\x04RH\xd2\x87\x0cb\xe2\x9d\xbd\xd9\x99\xee\x94" + + "\xf0\xdb\xfb\xe7y\x7f\xef\xef\xf9=\xcf\xf3[9\xc1\xfa" + + "X\xafv\x8fF$\xd6iz\xb4\xae\xf1\xcd\xe4\xfdo" + + "\x9c{\x89\x8c\"\x8b\xf6\x9f\x1e(]\x0f\x0f\xfeH\x84" + + "U\xfb\xd82\x98\xaf\xb2<\x91y\x88\x0d\x11\xa2\x85\x15" + + "LO\xf5\xe6>\"\xa3\x07D\x1a\xcf\x13\xad:\xce\xde" + + "\x04\xc1<\xc5\xde#D=\xbf\xf7/p\xaf\x1e\x9c\"" + + "\xa3\x88\x14*\x0e4\xb7\xf0\xbf\xcd'\xe3\xd5\xe3\\\xc5" + + "\x0e\xec8zD\xbb|\xf4K\x12Ed\x835\x85\xfa" + + "\x07_\x0c\x139\xb5\xfc\x87\xbf\x06B\xd4\xfd\xfe\x83\xef" + + "\xf6\x8f\\<7\x0b:fwB\x9b4O\xa9w\xe6" + + "I\xedYB\xc4.[w\xbc\xf0\xfdC\xd3m\x9e1" + + "\xca|\xfd\x08(\x17m}b\xc7\x9e\xf9\xfb.]\x9a" + + "\xc9\x00\xea\xea\xba\x16g0_/\x13\xa2\xd5\xbb\xd6\xcb" + + "\x9dk\xb6_!\xa3\xc8\xb3b\x98K\xf5+\xe6j]" + + "\xfd\xd1\xab\xbfl\x1eR\xab\xe8\xf0\xfe\x0dC\x0f,>" + + "s-\x8b\xf6\x8c>\xa9\xd0^\x8c\xd1F\xd7\xfc\xf6\xc8" + + "]\x87\xbf\xb86\x8b\xb4\x0a4\x8f\xeb?\x98'c\xc0" + + "\x13*\xf6\xea\xa6\xb7\xcew\x17\xba\xff\x9c\xa5F\xac\xf1" + + "\xd7z7\xcc\x9f\xe2\xd8\x8b\xfa\xaf\xb4<\x0a[\xae+" + + "\x1d\xbf\xc9k+jV\xd3m\xae\xdd\xb8\xd7\x0eB\xdb" + + "\xad\x0f\xc7\x17U\xaf\xe0\xd8\xb5\x89* \xba\xc0\x88\x8c" + + "\x9e\xb5D\x80\xb1\xf0)\"0\xc3\xe8'*\xdbu\xd7" + + "\xf3e4b\x075\xcfu%\xf1Zx`\xb7\xe5X" + + "nMv\xe0\xb5\x04\xbe\x0d\xfb\x98\xf4\xc7\xa5\xbf\xc2\x97" + + "u;\x08\xa5\xdf>\\R\xb5|\x8b7\x02\xd1\xc5s" + + "D9\x10\x19\x1b\x97\x11\x89>\x0e1\xc8`\x00%\xa8" + + "\xc3\xca\x00\x91\xd8\xcc!\x86\x19\x0c\xc6J1/\xd1O" + + "$\x069\xc4v\x86\x82\xd5\x0a\xc7PL{\x88\x80\"" + + "!\x1a\xf3\x82\xd0\xb5\x1a\x92\x88\xd0E\x0c]\x84\x03^" + + "3\xb4=7@1\xed\xa2\x99\xe8\x84:K\xa8\xafo" + + "\x85c\xd2\x0d\xedr\xcdRobMR\xa6\x8bo\xc4" + + "\xf4^\"\xb1\x81CT3L\xb7\xecN\x99\xe6\x9f\x96" + + "\x13\x09\x95E\xb2a\xd9N\xb2\x8b<\xdf\xae\xdb\xee\xc3" + + "\xeb)\xffh\x1a3\xb7\\\xdbb\x09\xfd\x98\xd1P3" + + "\xb4\xf3\x9e\x1b(fwv\x98}\xa8\xe4\xfa\x80CL" + + "e\x98}\xaa\xe4\xfa\x98C|\x96av\xa6\x9bH\x9c" + + "\xe6\x10g\x19\xc0K\xe0D\xc6\xe7\xef\x10\x89\xb3\x1c\xe2" + + "<\x83\x91\xe3%\xe4\x88\x8co\xd7\x12\x89\xaf8\xc4\x05" 
+ + "\x06C+\x96\xa0\x11\x19\xdf}B$.p\x88_\x18" + + "\x0c=W\x82Nd\xfc\xac\x0a8\xcd!\xfeb\x88j" + + "\x8e-\xdd\xb02\x92\xd5\x7f\\\xfa\x81\xed\xb9\xc9\x9e{" + + "A'W9\xd3\x89\xf8O+\xa2\x90\xda\x0c\x01\x05B" + + "\xb9\xe9yNe$\xf3\xae\xe9\x05\xb2\xe2\"\x94\xbek" + + "9\x9b\xbdr\xbb\xec\x001\x80P\x08\xadz\x80\xdb\x08" + + "U\x0e\x14S; \xa8\xc3\x8e\xc4H$\xce\x0f[u" + + "%\xe9\xbc\x8e\xa4KUVK8\xc4\xca\x8c\xa4\xcbU" + + "\xb1\xef\xe6\x10\xf71\x14\xe2\xff\x92\xc2\x8e[NK\xce" + + ")\xe1\x8dG\xa2.\xc3\xf6\xaa\xe2\x8ez\xf1D4\x10" + + "\xdc\xd2\x9bm2h9<\x0cD\xae\xc3w\x81\xaa\xd7" + + "<\x0eQb(\xfb\xea>D1\xb5\x94\x9b5|\xf2" + + "IAa\xb7\x15\xd0\x88:\xde\x8d\xc4\xb4\x8c\xde\xe7\x88" + + "\x19K\xf3H\xfd\x12\x89=\x1a=>1ca>J" + + "f\x9d\xcam\xd8>D\x09oZ\x143\xefC\x15\xb8" + + "5\xc7P\xb9\xe6\x9d\xff\xcf5\xb1\xc4\x9be\x9a\xc8\xc7" + + "G=\x95g\x06m\x0f\x91\xe8\xe2\x10\xb73D\x8e\xd7" + + "\x9e|*l\xcd\x94w\xeeL\xb6\xc9\xa5\x93\xc9\xdbf" + + "Q\xec\xa0Z\xca,vr\x88\xb1L\xffH\xd5T\xbb" + + "8\xc4\xf3\x99\x91\x9cP\xc3\xbb\x97C\x1cKG\xf2\xf5" + + "W\x88\xc41\x0e\xf16C^\xfa~B\xa4\xd0\xf2\x9d" + + "N_\xab3\xd5\xcd\x8eW\x1f\xb4]\x19\xa8\x99\x9bu" + + "\xd5\x94~\xc3r\xa5\x8bp\x93e;-_\xaaFh" + + "\x8f\xc8\xbf\x01\x00\x00\xff\xff\x80\xf4\x060" + +func init() { + schemas.Register(schema_db8274f9144abc7e, + 0x84cb9536a2cf6d3c, + 0xb70431c0dc014915, + 0xc082ef6e0d42ed1d, + 0xc793e50592935b4a, + 0xcbd96442ae3bb01a, + 0xdc3ed6801961e502, + 0xe3e37d096a5b564e, + 0xea58385c65416035, + 0xf2c122394f447e8e, + 0xf2c68e2547ec3866, + 0xf41a0f001ad49e46) +} diff --git a/validation/validation.go b/validation/validation.go new file mode 100644 index 00000000..eb4eb90b --- /dev/null +++ b/validation/validation.go @@ -0,0 +1,136 @@ +package validation + +import ( + "fmt" + "net" + "net/url" + "strings" + + "golang.org/x/net/idna" +) + +const defaultScheme = "http" + +var supportedProtocol = [2]string{"http", "https"} + +func ValidateHostname(hostname string) (string, error) { + if hostname == "" { + return "", fmt.Errorf("Hostname should not be empty") + } + // users gives url(contains schema) not just hostname + if strings.Contains(hostname, ":") || strings.Contains(hostname, "%3A") { + unescapeHostname, err := url.PathUnescape(hostname) + if err != nil { + return "", fmt.Errorf("Hostname(actually a URL) %s has invalid escape characters %s", hostname, unescapeHostname) + } + hostnameToURL, err := url.Parse(unescapeHostname) + if err != nil { + return "", fmt.Errorf("Hostname(actually a URL) %s has invalid format %s", hostname, hostnameToURL) + } + asciiHostname, err := idna.ToASCII(hostnameToURL.Hostname()) + if err != nil { + return "", fmt.Errorf("Hostname(actually a URL) %s has invalid ASCII encdoing %s", hostname, asciiHostname) + } + return asciiHostname, nil + } + + asciiHostname, err := idna.ToASCII(hostname) + if err != nil { + return "", fmt.Errorf("Hostname %s has invalid ASCII encdoing %s", hostname, asciiHostname) + } + hostnameToURL, err := url.Parse(asciiHostname) + if err != nil { + return "", fmt.Errorf("Hostname %s is not valid", hostnameToURL) + } + return hostnameToURL.RequestURI(), nil + +} + +func ValidateUrl(originUrl string) (string, error) { + if originUrl == "" { + return "", fmt.Errorf("Url should not be empty") + } + + if net.ParseIP(originUrl) != nil { + return validateIP("", originUrl, "") + } else if strings.HasPrefix(originUrl, "[") && strings.HasSuffix(originUrl, "]") { + // ParseIP doesn't recoginze [::1] + return validateIP("", originUrl[1:len(originUrl)-1], "") + } + + host, port, err := 
net.SplitHostPort(originUrl)
+	// the user might pass in an IP address like 127.0.0.1
+	if err == nil && net.ParseIP(host) != nil {
+		return validateIP("", host, port)
+	}
+
+	unescapedUrl, err := url.PathUnescape(originUrl)
+	if err != nil {
+		return "", fmt.Errorf("URL %s has invalid escape characters %s", originUrl, unescapedUrl)
+	}
+
+	parsedUrl, err := url.Parse(unescapedUrl)
+	if err != nil {
+		return "", fmt.Errorf("URL %s has invalid format", originUrl)
+	}
+
+	// if the URL is in the form host:port, IsAbs() will mistake the host for the scheme
+	var hostname string
+	hasScheme := parsedUrl.IsAbs() && parsedUrl.Host != ""
+	if hasScheme {
+		err := validateScheme(parsedUrl.Scheme)
+		if err != nil {
+			return "", err
+		}
+		// The earlier check for ip address will miss the case http://[::1]
+		// and http://[::1]:8080
+		if net.ParseIP(parsedUrl.Hostname()) != nil {
+			return validateIP(parsedUrl.Scheme, parsedUrl.Hostname(), parsedUrl.Port())
+		}
+		hostname, err = ValidateHostname(parsedUrl.Hostname())
+		if err != nil {
+			return "", fmt.Errorf("URL %s has invalid format", originUrl)
+		}
+		if parsedUrl.Port() != "" {
+			return fmt.Sprintf("%s://%s", parsedUrl.Scheme, net.JoinHostPort(hostname, parsedUrl.Port())), nil
+		}
+		return fmt.Sprintf("%s://%s", parsedUrl.Scheme, hostname), nil
+	} else {
+		if host == "" {
+			hostname, err = ValidateHostname(originUrl)
+			if err != nil {
+				return "", fmt.Errorf("URL %s has invalid format", originUrl)
+			}
+			return fmt.Sprintf("%s://%s", defaultScheme, hostname), nil
+		} else {
+			hostname, err = ValidateHostname(host)
+			if err != nil {
+				return "", fmt.Errorf("URL %s has invalid format", originUrl)
+			}
+			return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
+		}
+	}
+
+}
+
+func validateScheme(scheme string) error {
+	for _, protocol := range supportedProtocol {
+		if scheme == protocol {
+			return nil
+		}
+	}
+	return fmt.Errorf("Currently Cloudflare-Warp does not support %s protocol.", scheme)
+}
+
+func validateIP(scheme, host, port string) (string, error) {
+	if scheme == "" {
+		scheme = defaultScheme
+	}
+	if port != "" {
+		return fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(host, port)), nil
+	} else if strings.Contains(host, ":") {
+		// IPv6
+		return fmt.Sprintf("%s://[%s]", scheme, host), nil
+	}
+	return fmt.Sprintf("%s://%s", scheme, host), nil
+}
diff --git a/validation/validation_test.go b/validation/validation_test.go
new file mode 100644
index 00000000..0be8066b
--- /dev/null
+++ b/validation/validation_test.go
@@ -0,0 +1,136 @@
+package validation
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestValidateHostname(t *testing.T) {
+	var inputHostname string
+	hostname, err := ValidateHostname(inputHostname)
+	assert.Equal(t, err, fmt.Errorf("Hostname should not be empty"))
+	assert.Empty(t, hostname)
+
+	inputHostname = "hello.example.com"
+	hostname, err = ValidateHostname(inputHostname)
+	assert.Nil(t, err)
+	assert.Equal(t, "hello.example.com", hostname)
+
+	inputHostname = "http://hello.example.com"
+	hostname, err = ValidateHostname(inputHostname)
+	assert.Nil(t, err)
+	assert.Equal(t, "hello.example.com", hostname)
+
+	inputHostname = "bücher.example.com"
+	hostname, err = ValidateHostname(inputHostname)
+	assert.Nil(t, err)
+	assert.Equal(t, "xn--bcher-kva.example.com", hostname)
+
+	inputHostname = "http://bücher.example.com"
+	hostname, err = ValidateHostname(inputHostname)
+	assert.Nil(t, err)
+	assert.Equal(t, "xn--bcher-kva.example.com", hostname)
+
+	inputHostname 
= "http%3A%2F%2Fhello.example.com" + hostname, err = ValidateHostname(inputHostname) + assert.Nil(t, err) + assert.Equal(t, "hello.example.com", hostname) + +} + +func TestValidateUrl(t *testing.T) { + validUrl, err := ValidateUrl("") + assert.Equal(t, fmt.Errorf("Url should not be empty"), err) + assert.Empty(t, validUrl) + + validUrl, err = ValidateUrl("https://localhost:8080") + assert.Nil(t, err) + assert.Equal(t, "https://localhost:8080", validUrl) + + validUrl, err = ValidateUrl("localhost:8080") + assert.Nil(t, err) + assert.Equal(t, "http://localhost:8080", validUrl) + + validUrl, err = ValidateUrl("http://localhost") + assert.Nil(t, err) + assert.Equal(t, "http://localhost", validUrl) + + validUrl, err = ValidateUrl("http://127.0.0.1:8080") + assert.Nil(t, err) + assert.Equal(t, "http://127.0.0.1:8080", validUrl) + + validUrl, err = ValidateUrl("127.0.0.1:8080") + assert.Nil(t, err) + assert.Equal(t, "http://127.0.0.1:8080", validUrl) + + validUrl, err = ValidateUrl("127.0.0.1") + assert.Nil(t, err) + assert.Equal(t, "http://127.0.0.1", validUrl) + + validUrl, err = ValidateUrl("https://127.0.0.1:8080") + assert.Nil(t, err) + assert.Equal(t, "https://127.0.0.1:8080", validUrl) + + validUrl, err = ValidateUrl("[::1]:8080") + assert.Nil(t, err) + assert.Equal(t, "http://[::1]:8080", validUrl) + + validUrl, err = ValidateUrl("http://[::1]") + assert.Nil(t, err) + assert.Equal(t, "http://[::1]", validUrl) + + validUrl, err = ValidateUrl("http://[::1]:8080") + assert.Nil(t, err) + assert.Equal(t, "http://[::1]:8080", validUrl) + + validUrl, err = ValidateUrl("[::1]") + assert.Nil(t, err) + assert.Equal(t, "http://[::1]", validUrl) + + validUrl, err = ValidateUrl("https://example.com") + assert.Nil(t, err) + assert.Equal(t, "https://example.com", validUrl) + + validUrl, err = ValidateUrl("example.com") + assert.Nil(t, err) + assert.Equal(t, "http://example.com", validUrl) + + validUrl, err = ValidateUrl("http://hello.example.com") + assert.Nil(t, err) + assert.Equal(t, "http://hello.example.com", validUrl) + + validUrl, err = ValidateUrl("hello.example.com") + assert.Nil(t, err) + assert.Equal(t, "http://hello.example.com", validUrl) + + validUrl, err = ValidateUrl("hello.example.com:8080") + assert.Nil(t, err) + assert.Equal(t, "http://hello.example.com:8080", validUrl) + + validUrl, err = ValidateUrl("https://hello.example.com:8080") + assert.Nil(t, err) + assert.Equal(t, "https://hello.example.com:8080", validUrl) + + validUrl, err = ValidateUrl("https://bücher.example.com") + assert.Nil(t, err) + assert.Equal(t, "https://xn--bcher-kva.example.com", validUrl) + + validUrl, err = ValidateUrl("bücher.example.com") + assert.Nil(t, err) + assert.Equal(t, "http://xn--bcher-kva.example.com", validUrl) + + validUrl, err = ValidateUrl("https%3A%2F%2Fhello.example.com") + assert.Nil(t, err) + assert.Equal(t, "https://hello.example.com", validUrl) + + validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt") + assert.Equal(t, "Currently Cloudflare-Warp does not support ftp protocol.", err.Error()) + assert.Empty(t, validUrl) + + validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080") + assert.Nil(t, err) + assert.Equal(t, "https://hello.example.com:8080", validUrl) + +}