From 3780e14f41a5bf16b4aa5bad6c7d708083d24095 Mon Sep 17 00:00:00 2001 From: cloudflare-warp-bot Date: Tue, 20 Feb 2018 21:13:56 +0000 Subject: [PATCH] Release Warp Client 2018.2.1 --- cmd/cloudflare-warp/hello.go | 148 +++++++++++------ cmd/cloudflare-warp/hello_test.go | 34 ++-- cmd/cloudflare-warp/linux_service.go | 2 + cmd/cloudflare-warp/login.go | 9 +- cmd/cloudflare-warp/main.go | 214 ++++++++++++++++++------- cmd/cloudflare-warp/windows_service.go | 2 +- h2mux/h2mux.go | 2 +- h2mux/h2mux_test.go | 58 +++---- h2mux/muxedstream_test.go | 25 +++ h2mux/muxreader.go | 2 +- h2mux/muxwriter.go | 2 +- h2mux/shared_buffer.go | 5 +- h2mux/shared_buffer_test.go | 37 +++-- metrics/metrics.go | 2 +- origin/metrics.go | 3 +- origin/supervisor.go | 11 +- origin/tunnel.go | 82 ++++++---- tlsconfig/cloudflare_ca.go | 95 +++++++++++ tlsconfig/hello_ca.go | 50 ++++++ tlsconfig/tlsconfig.go | 43 ++++- tunnelrpc/log.go | 22 ++- tunnelrpc/logtransport.go | 2 +- tunnelrpc/pogs/tunnelrpc.go | 5 +- tunnelrpc/tunnelrpc.capnp | 4 +- websocket/websocket.go | 77 +++++++++ 25 files changed, 713 insertions(+), 223 deletions(-) create mode 100644 tlsconfig/cloudflare_ca.go create mode 100644 tlsconfig/hello_ca.go create mode 100644 websocket/websocket.go diff --git a/cmd/cloudflare-warp/hello.go b/cmd/cloudflare-warp/hello.go index 8bb823db..ff5f0ce7 100644 --- a/cmd/cloudflare-warp/hello.go +++ b/cmd/cloudflare-warp/hello.go @@ -2,17 +2,20 @@ package main import ( "bytes" + "crypto/tls" + "encoding/json" "fmt" "html/template" "io/ioutil" "net" "net/http" "os" + "time" - "github.com/pkg/errors" + "github.com/gorilla/websocket" + "gopkg.in/urfave/cli.v2" - log "github.com/Sirupsen/logrus" - cli "gopkg.in/urfave/cli.v2" + "github.com/cloudflare/cloudflare-warp/tlsconfig" ) type templateData struct { @@ -21,6 +24,11 @@ type templateData struct { Body string } +type OriginUpTime struct { + StartTime time.Time `json:"startTime"` + UpTime string `json:"uptime"` +} + const defaultServerName = "the Cloudflare Warp test server" const indexTemplate = ` @@ -85,71 +93,113 @@ const indexTemplate = ` func hello(c *cli.Context) error { address := fmt.Sprintf(":%d", c.Int("port")) - server := NewHelloWorldServer() - if hostname, err := os.Hostname(); err != nil { - server.serverName = hostname + listener, err := createListener(address) + if err != nil { + return err } - err := server.ListenAndServe(address) - return errors.Wrap(err, "Fail to start Hello World Server") + defer listener.Close() + err = startHelloWorldServer(listener, nil) + return err } func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error { - server := NewHelloWorldServer() - if hostname, err := os.Hostname(); err != nil { - server.serverName = hostname + Log.Infof("Starting Hello World server at %s", listener.Addr()) + serverName := defaultServerName + if hostname, err := os.Hostname(); err == nil { + serverName = hostname } - httpServer := &http.Server{Addr: listener.Addr().String(), Handler: server} + + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + + httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil} go func() { <-shutdownC httpServer.Close() }() + + http.HandleFunc("/uptime", uptimeHandler(time.Now())) + http.HandleFunc("/ws", websocketHandler(upgrader)) + http.HandleFunc("/", rootHandler(serverName)) err := httpServer.Serve(listener) return err } -type HelloWorldServer struct { - responseTemplate *template.Template - serverName string -} - -func NewHelloWorldServer() *HelloWorldServer { - return &HelloWorldServer{ - responseTemplate: template.Must(template.New("index").Parse(indexTemplate)), - serverName: defaultServerName, +func createListener(address string) (net.Listener, error) { + certificate, err := tlsconfig.GetHelloCertificate() + if err != nil { + return nil, err } -} -func findAvailablePort() (net.Listener, error) { - // If the port in address is empty, a port number is automatically chosen. - listener, err := net.Listen("tcp", "127.0.0.1:") + // If the port in address is empty, a port number is automatically chosen + listener, err := tls.Listen( + "tcp", + address, + &tls.Config{Certificates: []tls.Certificate{certificate}}) + return listener, err } -func (s *HelloWorldServer) ListenAndServe(address string) error { - log.Infof("Starting Hello World server on %s", address) - err := http.ListenAndServe(address, s) - return err +func uptimeHandler(startTime time.Time) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Note that if autoupdate is enabled, the uptime is reset when a new client + // release is available + resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()} + respJson, err := json.Marshal(resp) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } else { + w.Header().Set("Content-Type", "application/json") + w.Write(respJson) + } + } } -func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto) - var buffer bytes.Buffer - var body string - rawBody, err := ioutil.ReadAll(r.Body) - if err == nil { - body = string(rawBody) - } else { - body = "" - } - err = s.responseTemplate.Execute(&buffer, &templateData{ - ServerName: s.serverName, - Request: r, - Body: body, - }) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintf(w, "error: %v", err) - } else { - buffer.WriteTo(w) +func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + for { + mt, message, err := conn.ReadMessage() + if err != nil { + break + } + + if err := conn.WriteMessage(mt, message); err != nil { + break + } + } + } +} + +func rootHandler(serverName string) http.HandlerFunc { + responseTemplate := template.Must(template.New("index").Parse(indexTemplate)) + return func(w http.ResponseWriter, r *http.Request) { + Log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto) + var buffer bytes.Buffer + var body string + rawBody, err := ioutil.ReadAll(r.Body) + if err == nil { + body = string(rawBody) + } else { + body = "" + } + err = responseTemplate.Execute(&buffer, &templateData{ + ServerName: serverName, + Request: r, + Body: body, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "error: %v", err) + } else { + buffer.WriteTo(w) + } } } diff --git a/cmd/cloudflare-warp/hello_test.go b/cmd/cloudflare-warp/hello_test.go index 9586953d..f6e8842a 100644 --- a/cmd/cloudflare-warp/hello_test.go +++ b/cmd/cloudflare-warp/hello_test.go @@ -4,18 +4,30 @@ import ( "testing" ) -const testPort = "8080" - -func TestNewHelloWorldServer(t *testing.T) { - if NewHelloWorldServer() == nil { - t.Fatal("NewHelloWorldServer returned nil") - } -} - -func TestFindAvailablePort(t *testing.T) { - listener, err := findAvailablePort() +func TestCreateListenerHostAndPortSuccess(t *testing.T) { + listener, err := createListener("localhost:1234") if err != nil { - t.Fatal("Fail to find available port") + t.Fatal(err) + } + if listener.Addr().String() == "" { + t.Fatal("Fail to find available port") + } +} + +func TestCreateListenerOnlyHostSuccess(t *testing.T) { + listener, err := createListener("localhost:") + if err != nil { + t.Fatal(err) + } + if listener.Addr().String() == "" { + t.Fatal("Fail to find available port") + } +} + +func TestCreateListenerOnlyPortSuccess(t *testing.T) { + listener, err := createListener(":8888") + if err != nil { + t.Fatal(err) } if listener.Addr().String() == "" { t.Fatal("Fail to find available port") diff --git a/cmd/cloudflare-warp/linux_service.go b/cmd/cloudflare-warp/linux_service.go index 7d9816b9..755fd97c 100644 --- a/cmd/cloudflare-warp/linux_service.go +++ b/cmd/cloudflare-warp/linux_service.go @@ -42,6 +42,8 @@ After=network.target TimeoutStartSec=0 Type=notify ExecStart={{ .Path }} --config /etc/cloudflare-warp/config.yml --origincert /etc/cloudflare-warp/cert.pem --no-autoupdate +Restart=on-failure +RestartSec=5s [Install] WantedBy=multi-user.target diff --git a/cmd/cloudflare-warp/login.go b/cmd/cloudflare-warp/login.go index 32537275..3ede3892 100644 --- a/cmd/cloudflare-warp/login.go +++ b/cmd/cloudflare-warp/login.go @@ -14,7 +14,6 @@ import ( "syscall" "time" - log "github.com/Sirupsen/logrus" homedir "github.com/mitchellh/go-homedir" cli "gopkg.in/urfave/cli.v2" ) @@ -137,7 +136,7 @@ func download(certURL, filePath string) bool { return true } if err != nil { - log.WithError(err).Error("Error fetching certificate") + Log.WithError(err).Error("Error fetching certificate") return false } } @@ -180,16 +179,16 @@ func putSuccess(client *http.Client, certURL string) { // indicate success to the relay server req, err := http.NewRequest("PUT", certURL+"/ok", nil) if err != nil { - log.WithError(err).Error("HTTP request error") + Log.WithError(err).Error("HTTP request error") return } resp, err := client.Do(req) if err != nil { - log.WithError(err).Error("HTTP error") + Log.WithError(err).Error("HTTP error") return } resp.Body.Close() if resp.StatusCode != 200 { - log.Errorf("Unexpected HTTP error code %d", resp.StatusCode) + Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode) } } diff --git a/cmd/cloudflare-warp/main.go b/cmd/cloudflare-warp/main.go index 0fae594f..c03f756a 100644 --- a/cmd/cloudflare-warp/main.go +++ b/cmd/cloudflare-warp/main.go @@ -11,6 +11,8 @@ import ( "os" "os/signal" "path/filepath" + "runtime" + "strings" "sync" "syscall" "time" @@ -21,11 +23,12 @@ import ( tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" "github.com/cloudflare/cloudflare-warp/validation" - log "github.com/Sirupsen/logrus" "github.com/facebookgo/grace/gracenet" - raven "github.com/getsentry/raven-go" - homedir "github.com/mitchellh/go-homedir" - cli "gopkg.in/urfave/cli.v2" + "github.com/getsentry/raven-go" + "github.com/mitchellh/go-homedir" + "github.com/rifflock/lfshook" + "github.com/sirupsen/logrus" + "gopkg.in/urfave/cli.v2" "gopkg.in/urfave/cli.v2/altsrc" "github.com/coreos/go-systemd/daemon" @@ -40,11 +43,21 @@ const configFile = "config.yml" var listeners = gracenet.Net{} var Version = "DEV" var BuildTime = "unknown" +var Log *logrus.Logger // Shutdown channel used by the app. When closed, app must terminate. // May be closed by the Windows service runner. var shutdownC chan struct{} +type BuildAndRuntimeInfo struct { + GoOS string `json:"go_os"` + GoVersion string `json:"go_version"` + GoArch string `json:"go_arch"` + WarpVersion string `json:"warp_version"` + WarpFlags map[string]interface{} `json:"warp_flags"` + WarpEnvs map[string]string `json:"warp_envs"` +} + func main() { metrics.RegisterBuildInfo(BuildTime, Version) raven.SetDSN(sentryDSN) @@ -84,6 +97,12 @@ WARNING: Usage: "Disable periodic check for updates, restarting the server with the new version.", Value: false, }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: "is-autoupdated", + Usage: "Signal the new process that Warp client has been autoupdated", + Value: false, + Hidden: true, + }), altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ Name: "edge", Usage: "Address of the Cloudflare tunnel server.", @@ -99,12 +118,12 @@ WARNING: altsrc.NewStringFlag(&cli.StringFlag{ Name: "origincert", Usage: "Path to the certificate generated for your origin when you run cloudflare-warp login.", - EnvVars: []string{"ORIGIN_CERT"}, + EnvVars: []string{"TUNNEL_ORIGIN_CERT"}, Value: filepath.Join(defaultConfigDir, credentialFile), }), altsrc.NewStringFlag(&cli.StringFlag{ Name: "url", - Value: "http://localhost:8080", + Value: "https://localhost:8080", Usage: "Connect to the local webserver at `URL`.", EnvVars: []string{"TUNNEL_URL"}, }), @@ -190,15 +209,21 @@ WARNING: EnvVars: []string{"TUNNEL_RETRIES"}, }), altsrc.NewBoolFlag(&cli.BoolFlag{ - Name: "hello-world", - Usage: "Run Hello World Server", - Value: false, + Name: "hello-world", + Value: false, + Usage: "Run Hello World Server", + EnvVars: []string{"TUNNEL_HELLO_WORLD"}, }), altsrc.NewStringFlag(&cli.StringFlag{ Name: "pidfile", Usage: "Write the application's PID to this file after first successful connection.", EnvVars: []string{"TUNNEL_PIDFILE"}, }), + altsrc.NewStringFlag(&cli.StringFlag{ + Name: "logfile", + Usage: "Save application log to this file for reporting issues.", + EnvVars: []string{"TUNNEL_LOGFILE"}, + }), altsrc.NewIntFlag(&cli.IntFlag{ Name: "ha-connections", Value: 4, @@ -239,6 +264,7 @@ WARNING: return nil } app.Before = func(context *cli.Context) error { + Log = logrus.New() inputSource, err := findInputSourceContext(context) if err != nil { return err @@ -248,7 +274,7 @@ WARNING: return nil } app.Commands = []*cli.Command{ - &cli.Command{ + { Name: "update", Action: update, Usage: "Update the agent if a new version exists", @@ -259,7 +285,7 @@ WARNING: To determine if an update happened in a script, check for error code 64.`, }, - &cli.Command{ + { Name: "login", Action: login, Usage: "Generate a configuration file with your login details", @@ -271,7 +297,7 @@ WARNING: }, }, }, - &cli.Command{ + { Name: "hello", Action: hello, Usage: "Run a simple \"Hello World\" server for testing Cloudflare Warp.", @@ -293,27 +319,43 @@ func startServer(c *cli.Context) { errC := make(chan error) wg.Add(2) - if c.NumFlags() == 0 && c.NArg() == 0 { + // If the user choose to supply all options through env variables, + // c.NumFlags() == 0 && c.NArg() == 0. For warp to work, the user needs to at + // least provide a hostname. + if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" { cli.ShowAppHelp(c) return } - - logLevel, err := log.ParseLevel(c.String("loglevel")) + logLevel, err := logrus.ParseLevel(c.String("loglevel")) if err != nil { - log.WithError(err).Fatal("Unknown logging level specified") + Log.WithError(err).Fatal("Unknown logging level specified") } + logrus.SetLevel(logLevel) - log.SetLevel(logLevel) - protoLogLevel, err := log.ParseLevel(c.String("proto-loglevel")) + protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel")) if err != nil { - log.WithError(err).Fatal("Unknown protocol logging level specified") + Log.WithError(err).Fatal("Unknown protocol logging level specified") } - protoLogger := log.New() + protoLogger := logrus.New() protoLogger.Level = protoLogLevel + if c.String("logfile") != "" { + if err := initLogFile(c, protoLogger); err != nil { + Log.Error(err) + } + } + + if !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 { + if initUpdate() { + return + } + Log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq")) + go autoupdate(c.Duration("autoupdate-freq"), shutdownC) + } + hostname, err := validation.ValidateHostname(c.String("hostname")) if err != nil { - log.WithError(err).Fatal("Invalid hostname") + Log.WithError(err).Fatal("Invalid hostname") } clientID := c.String("id") @@ -323,46 +365,44 @@ func startServer(c *cli.Context) { tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) if err != nil { - log.WithError(err).Fatal("Tag parse failure") + Log.WithError(err).Fatal("Tag parse failure") } tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) - if c.IsSet("hello-world") { wg.Add(1) - listener, err := findAvailablePort() + listener, err := createListener("127.0.0.1:") if err != nil { listener.Close() - log.WithError(err).Fatal("Cannot start Hello World Server") + Log.WithError(err).Fatal("Cannot start Hello World Server") } go func() { startHelloWorldServer(listener, shutdownC) wg.Done() listener.Close() }() - c.Set("url", "http://"+listener.Addr().String()) - log.Infof("Starting Hello World Server at %s", c.String("url")) + c.Set("url", "https://"+listener.Addr().String()) } url, err := validateUrl(c) if err != nil { - log.WithError(err).Fatal("Error validating url") + Log.WithError(err).Fatal("Error validating url") } - log.Infof("Proxying tunnel requests to %s", url) + Log.Infof("Proxying tunnel requests to %s", url) // Fail if the user provided an old authentication method if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") { - log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login") + Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login") } // Check that the user has acquired a certificate using the log in command originCertPath, err := homedir.Expand(c.String("origincert")) if err != nil { - log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert")) + Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert")) } ok, err := fileExists(originCertPath) if !ok { - log.Fatalf(`Cannot find a valid certificate for your origin at the path: + Log.Fatalf(`Cannot find a valid certificate for your origin at the path: %s @@ -375,8 +415,9 @@ If you don't have a certificate signed by Cloudflare, run the command: // Easier to send the certificate as []byte via RPC than decoding it at this point originCert, err := ioutil.ReadFile(originCertPath) if err != nil { - log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath) + Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath) } + tunnelMetrics := origin.NewTunnelMetrics() httpTransport := &http.Transport{ Proxy: http.ProxyFromEnvironment, @@ -389,13 +430,15 @@ If you don't have a certificate signed by Cloudflare, run the command: IdleConnTimeout: c.Duration("proxy-keepalive-timeout"), TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"), ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()}, } tunnelConfig := &origin.TunnelConfig{ EdgeAddrs: c.StringSlice("edge"), OriginUrl: url, Hostname: hostname, OriginCert: originCert, - TlsConfig: &tls.Config{}, + TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")), + ClientTlsConfig: httpTransport.TLSClientConfig, Retries: c.Uint("retries"), HeartbeatInterval: c.Duration("heartbeat-interval"), MaxHeartbeats: c.Uint64("heartbeat-count"), @@ -408,18 +451,11 @@ If you don't have a certificate signed by Cloudflare, run the command: Metrics: tunnelMetrics, MetricsUpdateFreq: c.Duration("metrics-update-freq"), ProtocolLogger: protoLogger, + Logger: Log, + IsAutoupdated: c.Bool("is-autoupdated"), } connectedSignal := make(chan struct{}) - tunnelConfig.TlsConfig = tlsconfig.CLIFlags{RootCA: "cacert"}.GetConfig(c) - if tunnelConfig.TlsConfig.RootCAs == nil { - tunnelConfig.TlsConfig.RootCAs = GetCloudflareRootCA() - tunnelConfig.TlsConfig.ServerName = "cftunnel.com" - } else if len(tunnelConfig.EdgeAddrs) > 0 { - // Set for development environments and for testing specific origintunneld instances - tunnelConfig.TlsConfig.ServerName, _, _ = net.SplitHostPort(tunnelConfig.EdgeAddrs[0]) - } - go writePidFile(connectedSignal, c.String("pidfile")) go func() { errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal) @@ -428,24 +464,21 @@ If you don't have a certificate signed by Cloudflare, run the command: metricsListener, err := listeners.Listen("tcp", c.String("metrics")) if err != nil { - log.WithError(err).Fatal("Error opening metrics server listener") + Log.WithError(err).Fatal("Error opening metrics server listener") } go func() { errC <- metrics.ServeMetrics(metricsListener, shutdownC) wg.Done() }() - if !c.Bool("no-autoupdate") { - log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq")) - go autoupdate(c.Duration("autoupdate-period"), shutdownC) - } - + var errCode int err = WaitForSignal(errC, shutdownC) if err != nil { - log.WithError(err).Error("Quitting due to error") + Log.WithError(err).Error("Quitting due to error") raven.CaptureErrorAndWait(err, nil) + errCode = 1 } else { - log.Info("Quitting...") + Log.Info("Quitting...") } // Wait for clean exit, discarding all errors go func() { @@ -453,6 +486,7 @@ If you don't have a certificate signed by Cloudflare, run the command: } }() wg.Wait() + os.Exit(errCode) } func WaitForSignal(errC chan error, shutdownC chan struct{}) error { @@ -477,30 +511,40 @@ func update(c *cli.Context) error { return nil } -func autoupdate(frequency time.Duration, shutdownC chan struct{}) { - if int64(frequency) == 0 { - return +func initUpdate() bool { + if updateApplied() { + os.Args = append(os.Args, "--is-autoupdated=true") + if _, err := listeners.StartProcess(); err != nil { + Log.WithError(err).Error("Unable to restart server automatically") + return false + } + return true } + return false +} + +func autoupdate(freq time.Duration, shutdownC chan struct{}) { for { if updateApplied() { + os.Args = append(os.Args, "--is-autoupdated=true") if _, err := listeners.StartProcess(); err != nil { - log.WithError(err).Error("Unable to restart server automatically") + Log.WithError(err).Error("Unable to restart server automatically") } close(shutdownC) return } - time.Sleep(frequency) + time.Sleep(freq) } } func updateApplied() bool { releaseInfo := checkForUpdates() if releaseInfo.Updated { - log.Infof("Updated to version %s", releaseInfo.Version) + Log.Infof("Updated to version %s", releaseInfo.Version) return true } if releaseInfo.Error != nil { - log.WithError(releaseInfo.Error).Error("Update check failed") + Log.WithError(releaseInfo.Error).Error("Update check failed") } return false } @@ -555,7 +599,7 @@ func writePidFile(waitForSignal chan struct{}, pidFile string) { } file, err := os.Create(pidFile) if err != nil { - log.WithError(err).Errorf("Unable to write pid to %s", pidFile) + Log.WithError(err).Errorf("Unable to write pid to %s", pidFile) } defer file.Close() fmt.Fprintf(file, "%d", os.Getpid()) @@ -573,3 +617,55 @@ func validateUrl(c *cli.Context) (string, error) { validUrl, err := validation.ValidateUrl(url) return validUrl, err } + +func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error { + fileMode := os.O_WRONLY|os.O_APPEND|os.O_CREATE|os.O_TRUNC + // do not truncate log file if the client has been autoupdated + if c.Bool("is-autoupdated") { + fileMode = os.O_WRONLY|os.O_APPEND|os.O_CREATE + } + f, err := os.OpenFile(c.String("logfile"), fileMode, 0664) + if err != nil { + errors.Wrap(err, fmt.Sprintf("Cannot open file %s", c.String("logfile"))) + } + defer f.Close() + + pathMap := lfshook.PathMap{ + logrus.InfoLevel: c.String("logfile"), + logrus.ErrorLevel: c.String("logfile"), + logrus.FatalLevel: c.String("logfile"), + logrus.PanicLevel: c.String("logfile"), + } + + Log.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{})) + protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{})) + + flags := make(map[string]interface{}) + envs := make(map[string]string) + + for _, flag := range c.LocalFlagNames() { + flags[flag] = c.Generic(flag) + } + + // Find env variables for Warp + for _, env := range os.Environ() { + // All Warp env variables start with TUNNEL_ + if strings.Contains(env, "TUNNEL_") { + vars := strings.Split(env, "=") + if len(vars) == 2 { + envs[vars[0]] = vars[1] + } + } + } + + Log.Infof("Warp build and runtime configuration: %+v", BuildAndRuntimeInfo{ + GoOS: runtime.GOOS, + GoVersion: runtime.Version(), + GoArch: runtime.GOARCH, + WarpVersion: Version, + WarpFlags: flags, + WarpEnvs: envs, + }) + + return nil +} diff --git a/cmd/cloudflare-warp/windows_service.go b/cmd/cloudflare-warp/windows_service.go index 0c0181f5..2dea8120 100644 --- a/cmd/cloudflare-warp/windows_service.go +++ b/cmd/cloudflare-warp/windows_service.go @@ -9,7 +9,7 @@ import ( "fmt" "os" - log "github.com/Sirupsen/logrus" + log "github.com/sirupsen/logrus" cli "gopkg.in/urfave/cli.v2" "golang.org/x/sys/windows/svc" diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index 624123b5..928b162b 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -6,7 +6,7 @@ import ( "sync" "time" - log "github.com/Sirupsen/logrus" + log "github.com/sirupsen/logrus" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" ) diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go index e80c04ee..49480cf4 100644 --- a/h2mux/h2mux_test.go +++ b/h2mux/h2mux_test.go @@ -6,13 +6,14 @@ import ( "io" "io/ioutil" "math/rand" + "net" "os" "strconv" "sync" "testing" "time" - log "github.com/Sirupsen/logrus" + log "github.com/sirupsen/logrus" ) func TestMain(m *testing.M) { @@ -25,25 +26,20 @@ func TestMain(m *testing.M) { type DefaultMuxerPair struct { OriginMuxConfig MuxerConfig OriginMux *Muxer - OriginWriter *io.PipeWriter - OriginReader *io.PipeReader + OriginConn net.Conn EdgeMuxConfig MuxerConfig EdgeMux *Muxer - EdgeWriter *io.PipeWriter - EdgeReader *io.PipeReader + EdgeConn net.Conn doneC chan struct{} } func NewDefaultMuxerPair() *DefaultMuxerPair { - originReader, edgeWriter := io.Pipe() - edgeReader, originWriter := io.Pipe() + origin, edge := net.Pipe() return &DefaultMuxerPair{ OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"}, - OriginWriter: originWriter, - OriginReader: originReader, + OriginConn: origin, EdgeMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"}, - EdgeWriter: edgeWriter, - EdgeReader: edgeReader, + EdgeConn: edge, doneC: make(chan struct{}), } } @@ -53,12 +49,12 @@ func (p *DefaultMuxerPair) Handshake(t *testing.T) { originErrC := make(chan error) go func() { var err error - p.EdgeMux, err = Handshake(p.EdgeWriter, p.EdgeReader, p.EdgeMuxConfig) + p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig) edgeErrC <- err }() go func() { var err error - p.OriginMux, err = Handshake(p.OriginWriter, p.OriginReader, p.OriginMuxConfig) + p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig) originErrC <- err }() @@ -120,8 +116,8 @@ func (p *DefaultMuxerPair) Wait(t *testing.T) { func TestHandshake(t *testing.T) { muxPair := NewDefaultMuxerPair() muxPair.Handshake(t) - AssertIfPipeReadable(t, muxPair.OriginReader) - AssertIfPipeReadable(t, muxPair.EdgeReader) + AssertIfPipeReadable(t, muxPair.OriginConn) + AssertIfPipeReadable(t, muxPair.EdgeConn) } func TestSingleStream(t *testing.T) { @@ -145,7 +141,7 @@ func TestSingleStream(t *testing.T) { stream.Write(buf) // after this receive, the edge closed the stream <-closeC - n, err := stream.Read(buf) + n, err := io.ReadFull(stream, buf) if n > 0 { t.Fatalf("read %d bytes after EOF", n) } @@ -173,7 +169,7 @@ func TestSingleStream(t *testing.T) { t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value) } responseBody := make([]byte, 11) - n, err := stream.Read(responseBody) + n, err := io.ReadFull(stream, responseBody) if err != nil { t.Fatalf("error from (*MuxedStream).Read: %s", err) } @@ -243,7 +239,7 @@ func TestSingleStreamLargeResponseBody(t *testing.T) { t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value) } responseBody := make([]byte, bodySize) - n, err := stream.Read(responseBody) + n, err := io.ReadFull(stream, responseBody) if err != nil { t.Fatalf("error from (*MuxedStream).Read: %s", err) } @@ -302,7 +298,7 @@ func TestMultipleStreams(t *testing.T) { return } responseBody := make([]byte, 2) - n, err := stream.Read(responseBody) + n, err := io.ReadFull(stream, responseBody) if err != nil { errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err) return @@ -392,7 +388,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) { } responseBody := make([]byte, responseSizes[(stream.streamID-2)/2]) - n, err := stream.Read(responseBody) + n, err := io.ReadFull(stream, responseBody) if err != nil { errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err) return @@ -451,7 +447,7 @@ func TestGracefulShutdown(t *testing.T) { } responseBody := make([]byte, len(responseBuf)) log.Debugf("Waiting for %d bytes", len(responseBuf)) - n, err := stream.Read(responseBody) + n, err := io.ReadFull(stream, responseBody) if err != nil { t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err) } @@ -498,13 +494,13 @@ func TestUnexpectedShutdown(t *testing.T) { nil, ) // Close the underlying connection before telling the origin to write. - muxPair.EdgeReader.Close() + muxPair.EdgeConn.Close() close(sendC) if err != nil { t.Fatalf("error in OpenStream: %s", err) } responseBody := make([]byte, len(responseBuf)) - n, err := stream.Read(responseBody) + n, err := io.ReadFull(stream, responseBody) if err != io.EOF { t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err) } @@ -545,14 +541,14 @@ func TestOpenAfterDisconnect(t *testing.T) { switch i { case 0: // Close both directions of the connection to cause EOF on both peers. - muxPair.OriginReader.Close() - muxPair.OriginWriter.Close() + muxPair.OriginConn.Close() + muxPair.EdgeConn.Close() case 1: - // Close origin reader (edge writer) to cause EOF on origin only. - muxPair.OriginReader.Close() + // Close origin conn to cause EOF on origin first. + muxPair.OriginConn.Close() case 2: - // Close origin writer (edge reader) to cause EOF on edge only. - muxPair.OriginWriter.Close() + // Close edge conn to cause EOF on edge first. + muxPair.EdgeConn.Close() } _, err := muxPair.EdgeMux.OpenStream( @@ -623,7 +619,7 @@ func TestHPACK(t *testing.T) { } } -func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) { +func AssertIfPipeReadable(t *testing.T, pipe io.ReadCloser) { errC := make(chan error) go func() { b := []byte{0} @@ -640,7 +636,5 @@ func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) { } case <-time.After(100 * time.Millisecond): // nothing to read - pipe.Close() - <-errC } } diff --git a/h2mux/muxedstream_test.go b/h2mux/muxedstream_test.go index 0987221e..a7ac63b6 100644 --- a/h2mux/muxedstream_test.go +++ b/h2mux/muxedstream_test.go @@ -1,6 +1,7 @@ package h2mux import ( + "io" "testing" "github.com/stretchr/testify/assert" @@ -63,3 +64,27 @@ func TestFlowControlSingleStream(t *testing.T) { assert.Equal(t, testWindowSize<<2, stream.receiveWindow) assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax) } + +func TestMuxedStreamEOF(t *testing.T) { + for i := 0; i < 4096; i++ { + readyList := NewReadyList() + stream := &MuxedStream{ + streamID: 1, + readBuffer: NewSharedBuffer(), + receiveWindow: 65536, + receiveWindowMax: 65536, + sendWindow: 65536, + readyList: readyList, + } + + go func() { stream.Close() }() + n, err := stream.Read([]byte{0}) + assert.Equal(t, io.EOF, err) + assert.Equal(t, 0, n) + // Write comes after read, because write buffers data before it is flushed. It wouldn't know about EOF + // until some time later. Calling read first forces it to know about EOF now. + n, err = stream.Write([]byte{1}) + assert.Equal(t, io.EOF, err) + assert.Equal(t, 0, n) + } +} diff --git a/h2mux/muxreader.go b/h2mux/muxreader.go index 1e49b8b0..a0b56366 100644 --- a/h2mux/muxreader.go +++ b/h2mux/muxreader.go @@ -6,7 +6,7 @@ import ( "sync" "time" - log "github.com/Sirupsen/logrus" + log "github.com/sirupsen/logrus" "golang.org/x/net/http2" ) diff --git a/h2mux/muxwriter.go b/h2mux/muxwriter.go index 0da90832..1cfbf1f3 100644 --- a/h2mux/muxwriter.go +++ b/h2mux/muxwriter.go @@ -6,7 +6,7 @@ import ( "io" "time" - log "github.com/Sirupsen/logrus" + log "github.com/sirupsen/logrus" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" ) diff --git a/h2mux/shared_buffer.go b/h2mux/shared_buffer.go index 4c1b713e..31868c8c 100644 --- a/h2mux/shared_buffer.go +++ b/h2mux/shared_buffer.go @@ -21,7 +21,7 @@ func NewSharedBuffer() *SharedBuffer { func (s *SharedBuffer) Read(p []byte) (n int, err error) { totalRead := 0 s.cond.L.Lock() - for totalRead < len(p) { + for totalRead == 0 { n, err = s.buffer.Read(p[totalRead:]) totalRead += n if err == io.EOF { @@ -29,6 +29,9 @@ func (s *SharedBuffer) Read(p []byte) (n int, err error) { break } err = nil + if n > 0 { + break + } s.cond.Wait() } } diff --git a/h2mux/shared_buffer_test.go b/h2mux/shared_buffer_test.go index 939228e1..5de77fc9 100644 --- a/h2mux/shared_buffer_test.go +++ b/h2mux/shared_buffer_test.go @@ -6,6 +6,8 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) { @@ -29,30 +31,35 @@ func TestSharedBuffer(t *testing.T) { func TestSharedBufferBlockingRead(t *testing.T) { b := NewSharedBuffer() - testData := []byte("Hello world") + testData1 := []byte("Hello") + testData2 := []byte(" world") result := make(chan []byte) go func() { - bytesRead := make([]byte, len(testData)) - AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead)) - result <- bytesRead + bytesRead := make([]byte, len(testData1)+len(testData2)) + nRead, err := b.Read(bytesRead) + AssertIOReturnIsGood(t, len(testData1))(nRead, err) + result <- bytesRead[:nRead] + nRead, err = b.Read(bytesRead) + AssertIOReturnIsGood(t, len(testData2))(nRead, err) + result <- bytesRead[:nRead] }() + time.Sleep(time.Millisecond * 250) select { case <-result: t.Fatalf("read returned early") default: } - AssertIOReturnIsGood(t, 5)(b.Write(testData[:5])) - select { - case <-result: - t.Fatalf("read returned early") - default: - } - AssertIOReturnIsGood(t, len(testData)-5)(b.Write(testData[5:])) + AssertIOReturnIsGood(t, len(testData1))(b.Write([]byte(testData1))) select { case r := <-result: - if string(r) != string(testData) { - t.Fatalf("expected read to return %s, got %s", testData, r) - } + assert.Equal(t, testData1, r) + case <-time.After(time.Second): + t.Fatalf("read timed out") + } + AssertIOReturnIsGood(t, len(testData2))(b.Write([]byte(testData2))) + select { + case r := <-result: + assert.Equal(t, testData2, r) case <-time.After(time.Second): t.Fatalf("read timed out") } @@ -85,7 +92,7 @@ func TestSharedBufferConcurrentReadWrite(t *testing.T) { // Change block sizes in opposition to the write thread, to test blocking for new data. for blockSize := 256; blockSize > 0; blockSize-- { for i := 0; i < 256; i++ { - n, err := b.Read(block[:blockSize]) + n, err := io.ReadFull(b, block[:blockSize]) if n != blockSize || err != nil { t.Fatalf("read error: %d %s", n, err) } diff --git a/metrics/metrics.go b/metrics/metrics.go index b7b5743e..4707b4ed 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -11,9 +11,9 @@ import ( "golang.org/x/net/context" "golang.org/x/net/trace" - log "github.com/Sirupsen/logrus" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + log "github.com/sirupsen/logrus" ) const ( diff --git a/origin/metrics.go b/origin/metrics.go index f7eb9476..314d4060 100644 --- a/origin/metrics.go +++ b/origin/metrics.go @@ -5,7 +5,6 @@ import ( "github.com/cloudflare/cloudflare-warp/h2mux" - log "github.com/Sirupsen/logrus" "github.com/prometheus/client_golang/prometheus" ) @@ -249,7 +248,7 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) { if _, ok := t.concurrentRequests[connectionID]; ok { t.concurrentRequests[connectionID] -= 1 } else { - log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.") + Log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.") } t.concurrentRequestsLock.Unlock() diff --git a/origin/supervisor.go b/origin/supervisor.go index ea514262..35c73f63 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -5,7 +5,6 @@ import ( "net" "time" - log "github.com/Sirupsen/logrus" "golang.org/x/net/context" ) @@ -73,7 +72,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err case tunnelError := <-s.tunnelErrors: tunnelsActive-- if tunnelError.err != nil { - log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error") + Log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error") tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) s.waitForNextTunnel(tunnelError.index) if backoffTimer == nil { @@ -107,10 +106,10 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err s.lastResolve = time.Now() s.resolverC = nil if result.err == nil { - log.Debug("Service discovery refresh complete") + Log.Debug("Service discovery refresh complete") s.edgeIPs = result.edgeIPs } else { - log.WithError(result.err).Error("Service discovery error") + Log.WithError(result.err).Error("Service discovery error") } } } @@ -120,12 +119,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error { edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs) if err != nil { - log.Infof("ResolveEdgeIPs err") + Log.Infof("ResolveEdgeIPs err") return err } s.edgeIPs = edgeIPs if s.config.HAConnections > len(edgeIPs) { - log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs)) + Log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs)) s.config.HAConnections = len(edgeIPs) } s.lastResolve = time.Now() diff --git a/origin/tunnel.go b/origin/tunnel.go index cf0f446a..dfeb3ee8 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -19,14 +19,17 @@ import ( "github.com/cloudflare/cloudflare-warp/tunnelrpc" tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" "github.com/cloudflare/cloudflare-warp/validation" + "github.com/cloudflare/cloudflare-warp/websocket" - log "github.com/Sirupsen/logrus" raven "github.com/getsentry/raven-go" "github.com/pkg/errors" _ "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" rpc "zombiezen.com/go/capnproto2/rpc" ) +var Log *logrus.Logger + const ( dialTimeout = 15 * time.Second @@ -40,6 +43,7 @@ type TunnelConfig struct { Hostname string OriginCert []byte TlsConfig *tls.Config + ClientTlsConfig *tls.Config Retries uint HeartbeatInterval time.Duration MaxHeartbeats uint64 @@ -51,7 +55,9 @@ type TunnelConfig struct { HTTPTransport http.RoundTripper Metrics *TunnelMetrics MetricsUpdateFreq time.Duration - ProtocolLogger *log.Logger + ProtocolLogger *logrus.Logger + Logger *logrus.Logger + IsAutoupdated bool } type dialError struct { @@ -87,14 +93,16 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str Version: c.ReportedVersion, OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH), ExistingTunnelPolicy: policy, - PoolID: c.LBPool, + PoolName: c.LBPool, Tags: c.Tags, ConnectionID: connectionID, OriginLocalIP: OriginLocalIP, + IsAutoupdated: c.IsAutoupdated, } } func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error { + Log = config.Logger ctx, cancel := context.WithCancel(context.Background()) go func() { <-shutdownC @@ -129,7 +137,7 @@ func ServeTunnelLoop(ctx context.Context, config *TunnelConfig, addr *net.TCPAdd err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff) if recoverable { if duration, ok := backoff.GetBackoffDuration(ctx); ok { - log.Infof("Retrying in %s seconds", duration) + Log.Infof("Retrying in %s seconds", duration) backoff.Backoff(ctx) continue } @@ -162,11 +170,10 @@ func ServeTunnel( // Returns error from parsing the origin URL or handshake errors handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID) if err != nil { - errLog := log.WithError(err) + errLog := Log.WithError(err) switch err.(type) { case dialError: errLog.Error("Unable to dial edge") - return err, false case h2mux.MuxerHandshakeError: errLog.Error("Handshake failed with edge server") default: @@ -207,24 +214,21 @@ func ServeTunnel( registerErr := <-registerErrC wg.Wait() if err != nil { - log.WithError(err).Error("Tunnel error") + Log.WithError(err).Error("Tunnel error") return err, true } if registerErr != nil { // Don't retry on errors like entitlement failure or version too old if e, ok := registerErr.(printableRegisterTunnelError); ok { - log.Error(e) - if e.permanent { - return e, false - } - return e.cause, true + Log.Error(e) + return e.cause, !e.permanent } else if e, ok := registerErr.(dupConnRegisterTunnelError); ok { - log.Info("Already connected to this server, selecting a different one") + Log.Info("Already connected to this server, selecting a different one") return e, true } // Only log errors to Sentry that may have been caused by the client side, to reduce dupes raven.CaptureError(registerErr, nil) - log.Error("Cannot register") + Log.Error("Cannot register") return err, true } return nil, false @@ -241,7 +245,7 @@ func IsRPCStreamResponse(headers []h2mux.Header) bool { } func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error { - logger := log.WithField("subsystem", "rpc") + logger := Log.WithField("subsystem", "rpc") logger.Debug("initiating RPC stream") stream, err := muxer.OpenStream([]h2mux.Header{ {Name: ":method", Value: "RPC"}, @@ -292,11 +296,11 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi } } - log.Infof("Registered at %s", registration.Url) + Log.Infof("Registered at %s", registration.Url) return nil } -func LogServerInfo(logger *log.Entry, +func LogServerInfo(logger *logrus.Entry, promise tunnelrpc.ServerInfo_Promise, connectionID uint8, metrics *TunnelMetrics, @@ -311,7 +315,7 @@ func LogServerInfo(logger *log.Entry, logger.WithError(err).Warn("Failed to retrieve server information") return } - log.Infof("Connected to %s", serverInfo.LocationName) + Log.Infof("Connected to %s", serverInfo.LocationName) metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName) } @@ -356,6 +360,7 @@ type TunnelHandler struct { originUrl string muxer *h2mux.Muxer httpClient http.RoundTripper + tlsConfig *tls.Config tags []tunnelpogs.Tag metrics *TunnelMetrics // connectionID is only used by metrics, and prometheus requires labels to be string @@ -373,6 +378,7 @@ func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, co h := &TunnelHandler{ originUrl: url, httpClient: config.HTTPTransport, + tlsConfig: config.ClientTlsConfig, tags: config.Tags, metrics: config.Metrics, connectionID: uint8ToString(connectionID), @@ -422,29 +428,45 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { h.metrics.incrementRequests(h.connectionID) req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) if err != nil { - log.WithError(err).Panic("Unexpected error from http.NewRequest") + Log.WithError(err).Panic("Unexpected error from http.NewRequest") } err = H2RequestHeadersToH1Request(stream.Headers, req) if err != nil { - log.WithError(err).Error("invalid request received") + Log.WithError(err).Error("invalid request received") } h.AppendTagHeaders(req) - response, err := h.httpClient.RoundTrip(req) - if err != nil { - log.WithError(err).Error("HTTP request error") - stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}}) - stream.Write([]byte("502 Bad Gateway")) - h.metrics.incrementResponses(h.connectionID, "502") + + if websocket.IsWebSocketUpgrade(req) { + conn, response, err := websocket.ClientConnect(req, h.tlsConfig) + if err != nil { + h.logError(stream, err) + } else { + stream.WriteHeaders(H1ResponseToH2Response(response)) + defer conn.Close() + websocket.Stream(conn.UnderlyingConn(), stream) + } } else { - defer response.Body.Close() - stream.WriteHeaders(H1ResponseToH2Response(response)) - io.Copy(stream, response.Body) - h.metrics.incrementResponses(h.connectionID, "200") + response, err := h.httpClient.RoundTrip(req) + if err != nil { + h.logError(stream, err) + } else { + defer response.Body.Close() + stream.WriteHeaders(H1ResponseToH2Response(response)) + io.Copy(stream, response.Body) + h.metrics.incrementResponses(h.connectionID, "200") + } } h.metrics.decrementConcurrentRequests(h.connectionID) return nil } +func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) { + Log.WithError(err).Error("HTTP request error") + stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}}) + stream.Write([]byte("502 Bad Gateway")) + h.metrics.incrementResponses(h.connectionID, "502") +} + func (h *TunnelHandler) UpdateMetrics() { flowCtlMetrics := h.muxer.FlowControlMetrics() h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics) diff --git a/tlsconfig/cloudflare_ca.go b/tlsconfig/cloudflare_ca.go new file mode 100644 index 00000000..f1444a7d --- /dev/null +++ b/tlsconfig/cloudflare_ca.go @@ -0,0 +1,95 @@ +package tlsconfig + +import ( + "crypto/x509" +) + +// TODO: remove the Origin CA root certs when migrated to Authenticated Origin Pull certs +const cloudflareRootCA = ` +Issuer: C=US, ST=California, L=San Francisco, O=CloudFlare, Inc., OU=CloudFlare Origin SSL ECC Certificate Authority +-----BEGIN CERTIFICATE----- +MIICiDCCAi6gAwIBAgIUXZP3MWb8MKwBE1Qbawsp1sfA/Y4wCgYIKoZIzj0EAwIw +gY8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1T +YW4gRnJhbmNpc2NvMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTgwNgYDVQQL +Ey9DbG91ZEZsYXJlIE9yaWdpbiBTU0wgRUNDIENlcnRpZmljYXRlIEF1dGhvcml0 +eTAeFw0xNjAyMjIxODI0MDBaFw0yMTAyMjIwMDI0MDBaMIGPMQswCQYDVQQGEwJV +UzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEZ +MBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE4MDYGA1UECxMvQ2xvdWRGbGFyZSBP +cmlnaW4gU1NMIEVDQyBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkwWTATBgcqhkjOPQIB +BggqhkjOPQMBBwNCAASR+sGALuaGshnUbcxKry+0LEXZ4NY6JUAtSeA6g87K3jaA +xpIg9G50PokpfWkhbarLfpcZu0UAoYy2su0EhN7wo2YwZDAOBgNVHQ8BAf8EBAMC +AQYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUhTBdOypw1O3VkmcH/es5 +tBoOOKcwHwYDVR0jBBgwFoAUhTBdOypw1O3VkmcH/es5tBoOOKcwCgYIKoZIzj0E +AwIDSAAwRQIgEiIEHQr5UKma50D1WRMJBUSgjg24U8n8E2mfw/8UPz0CIQCr5V/e +mcifak4CQsr+DH4pn5SJD7JxtCG3YGswW8QZsw== +-----END CERTIFICATE----- +Issuer: C=US, O=CloudFlare, Inc., OU=CloudFlare Origin SSL Certificate Authority, L=San Francisco, ST=California +-----BEGIN CERTIFICATE----- +MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG +EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG +bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN +U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4 +NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv +dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl +cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG +A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB +AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem +ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ +p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz +/9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl +yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK +xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud +EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G +A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA +cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD +gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh +QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz +ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y +4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX +Fu6q54beR89jDc+oABmOgg== +-----END CERTIFICATE----- +Issuer: C=US, O=CloudFlare, Inc., OU=Origin Pull, L=San Francisco, ST=California, CN=origin-pull.cloudflare.net +-----BEGIN CERTIFICATE----- +MIIGBjCCA/CgAwIBAgIIV5G6lVbCLmEwCwYJKoZIhvcNAQENMIGQMQswCQYDVQQG +EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjEUMBIGA1UECxMLT3JpZ2lu +IFB1bGwxFjAUBgNVBAcTDVNhbiBGcmFuY2lzY28xEzARBgNVBAgTCkNhbGlmb3Ju +aWExIzAhBgNVBAMTGm9yaWdpbi1wdWxsLmNsb3VkZmxhcmUubmV0MB4XDTE1MDEx +MzAyNDc1M1oXDTIwMDExMjAyNTI1M1owgZAxCzAJBgNVBAYTAlVTMRkwFwYDVQQK +ExBDbG91ZEZsYXJlLCBJbmMuMRQwEgYDVQQLEwtPcmlnaW4gUHVsbDEWMBQGA1UE +BxMNU2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTEjMCEGA1UEAxMa +b3JpZ2luLXB1bGwuY2xvdWRmbGFyZS5uZXQwggIiMA0GCSqGSIb3DQEBAQUAA4IC +DwAwggIKAoICAQDdsts6I2H5dGyn4adACQRXlfo0KmwsN7B5rxD8C5qgy6spyONr +WV0ecvdeGQfWa8Gy/yuTuOnsXfy7oyZ1dm93c3Mea7YkM7KNMc5Y6m520E9tHooc +f1qxeDpGSsnWc7HWibFgD7qZQx+T+yfNqt63vPI0HYBOYao6hWd3JQhu5caAcIS2 +ms5tzSSZVH83ZPe6Lkb5xRgLl3eXEFcfI2DjnlOtLFqpjHuEB3Tr6agfdWyaGEEi +lRY1IB3k6TfLTaSiX2/SyJ96bp92wvTSjR7USjDV9ypf7AD6u6vwJZ3bwNisNw5L +ptph0FBnc1R6nDoHmvQRoyytoe0rl/d801i9Nru/fXa+l5K2nf1koR3IX440Z2i9 ++Z4iVA69NmCbT4MVjm7K3zlOtwfI7i1KYVv+ATo4ycgBuZfY9f/2lBhIv7BHuZal +b9D+/EK8aMUfjDF4icEGm+RQfExv2nOpkR4BfQppF/dLmkYfjgtO1403X0ihkT6T +PYQdmYS6Jf53/KpqC3aA+R7zg2birtvprinlR14MNvwOsDOzsK4p8WYsgZOR4Qr2 +gAx+z2aVOs/87+TVOR0r14irQsxbg7uP2X4t+EXx13glHxwG+CnzUVycDLMVGvuG +aUgF9hukZxlOZnrl6VOf1fg0Caf3uvV8smOkVw6DMsGhBZSJVwao0UQNqQIDAQAB +o2YwZDAOBgNVHQ8BAf8EBAMCAAYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4E +FgQUQ1lLK2mLgOERM2pXzVc42p59xeswHwYDVR0jBBgwFoAUQ1lLK2mLgOERM2pX +zVc42p59xeswCwYJKoZIhvcNAQENA4ICAQDKDQM1qPRVP/4Gltz0D6OU6xezFBKr +LWtDoA1qW2F7pkiYawCP9MrDPDJsHy7dx+xw3bBZxOsK5PA/T7p1dqpEl6i8F692 +g//EuYOifLYw3ySPe3LRNhvPl/1f6Sn862VhPvLa8aQAAwR9e/CZvlY3fj+6G5ik +3it7fikmKUsVnugNOkjmwI3hZqXfJNc7AtHDFw0mEOV0dSeAPTo95N9cxBbm9PKv +qAEmTEXp2trQ/RjJ/AomJyfA1BQjsD0j++DI3a9/BbDwWmr1lJciKxiNKaa0BRLB +dKMrYQD+PkPNCgEuojT+paLKRrMyFUzHSG1doYm46NE9/WARTh3sFUp1B7HZSBqA +kHleoB/vQ/mDuW9C3/8Jk2uRUdZxR+LoNZItuOjU8oTy6zpN1+GgSj7bHjiy9rfA +F+ehdrz+IOh80WIiqs763PGoaYUyzxLvVowLWNoxVVoc9G+PqFKqD988XlipHVB6 +Bz+1CD4D/bWrs3cC9+kk/jFmrrAymZlkFX8tDb5aXASSLJjUjcptci9SKqtI2h0J +wUGkD7+bQAr+7vr8/R+CBmNMe7csE8NeEX6lVMF7Dh0a1YKQa6hUN18bBuYgTMuT +QzMmZpRpIBB321ZBlcnlxiTJvWxvbCPHKHj20VwwAz7LONF59s84ZsOqfoBv8gKM +s0s5dsq5zpLeaw== +-----END CERTIFICATE-----` + +func GetCloudflareRootCA() *x509.CertPool { + ca := x509.NewCertPool() + if !ca.AppendCertsFromPEM([]byte(cloudflareRootCA)) { + // should never happen + panic("failure loading Cloudflare origin CA pem") + } + return ca +} diff --git a/tlsconfig/hello_ca.go b/tlsconfig/hello_ca.go new file mode 100644 index 00000000..bb49093c --- /dev/null +++ b/tlsconfig/hello_ca.go @@ -0,0 +1,50 @@ +package tlsconfig + +import ( + "crypto/tls" + "crypto/x509" +) + +const ( + helloKey = ` +-----BEGIN EC PARAMETERS----- +BgUrgQQAIg== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDAdyQBXfxTDCQSOT0HugmH9pVBtIw8t5dYvm6HxGlNq6P57v5GeN02Z +dH9FRl7+VSWgBwYFK4EEACKhZANiAATqpFzTxxV7D+/oqhKCTR6BEM9elTfKaRQE +FsLufcmaTMw/9tTwgpHKao/QsLKDTNbQhbSQLkcmpCQKlSGhl+pCrqNt/oYUAhav +UIwpwGiLCqGH/R2AqWLKRPOa/Rufs/U= +-----END EC PRIVATE KEY-----` + + helloCRT = ` +-----BEGIN CERTIFICATE----- +MIICkDCCAhigAwIBAgIJAPtKfUjc2lwGMAkGByqGSM49BAEwgYoxCzAJBgNVBAYT +AlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQKDBBD +bG91ZGZsYXJlLCBJbmMuMT8wPQYDVQQDDDZDbG91ZGZsYXJlIEFyZ28gVHVubmVs +IFNhbXBsZSBIZWxsbyBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMjE1MjAxNjU5 +WhcNMjgwMjEzMjAxNjU5WjCBijELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVRleGFz +MQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENsb3VkZmxhcmUsIEluYy4xPzA9 +BgNVBAMMNkNsb3VkZmxhcmUgQXJnbyBUdW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZl +ciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49AgEGBSuBBAAiA2IABOqkXNPHFXsP7+iq +EoJNHoEQz16VN8ppFAQWwu59yZpMzD/21PCCkcpqj9CwsoNM1tCFtJAuRyakJAqV +IaGX6kKuo23+hhQCFq9QjCnAaIsKoYf9HYCpYspE85r9G5+z9aNJMEcwRQYDVR0R +BD4wPIIJbG9jYWxob3N0ggp3YXJwLWhlbGxvggt3YXJwMi1oZWxsb4cEfwAAAYcQ +AAAAAAAAAAAAAAAAAAAAATAJBgcqhkjOPQQBA2cAMGQCMHyVPufXZ6vQo6XRWRa0 +dAwtfgesOdZVP2Wt+t5v8jOIQQh1IQXYk5GtyoZGSObjhQIwd1fRgAyKXaZt+1DV +ZtHTdf8pMvESfJsSd8AB1eQ6q+pAiRUYyaxcE1Mlo2YY5o+g +-----END CERTIFICATE-----` +) + +func GetHelloCertificate() (tls.Certificate, error) { + return tls.X509KeyPair([]byte(helloCRT), []byte(helloKey)) +} + +func GetHelloCertificateX509() (*x509.Certificate, error) { + helloCertificate, err := GetHelloCertificate() + if err != nil { + return nil, err + } + + return x509.ParseCertificate(helloCertificate.Certificate[0]) +} diff --git a/tlsconfig/tlsconfig.go b/tlsconfig/tlsconfig.go index 5b0c81a6..d81e94f2 100644 --- a/tlsconfig/tlsconfig.go +++ b/tlsconfig/tlsconfig.go @@ -6,8 +6,9 @@ import ( "crypto/tls" "crypto/x509" "io/ioutil" + "net" - log "github.com/Sirupsen/logrus" + log "github.com/sirupsen/logrus" cli "gopkg.in/urfave/cli.v2" ) @@ -60,3 +61,43 @@ func LoadCert(certPath string) *x509.CertPool { } return ca } + +func LoadOriginCertsPool() *x509.CertPool { + // First, obtain the system certificate pool + certPool, systemCertPoolErr := x509.SystemCertPool() + if systemCertPoolErr != nil { + log.Warn("error obtaining the system certificates: %s", systemCertPoolErr) + certPool = x509.NewCertPool() + } + + // Next, append the Cloudflare CA pool into the system pool + if !certPool.AppendCertsFromPEM([]byte(cloudflareRootCA)) { + log.Warn("could not append the CF certificate to the system certificate pool") + + if systemCertPoolErr != nil { // Obtaining both certificates failed; this is a fatal error + log.WithError(systemCertPoolErr).Fatalf("Error loading the certificate pool") + } + } + + // Finally, add the Hello certificate into the pool (since it's self-signed) + helloCertificate, err := GetHelloCertificateX509() + if err != nil { + log.Warn("error obtaining the Hello server certificate") + } + + certPool.AddCert(helloCertificate) + + return certPool +} + +func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config { + tlsConfig := CLIFlags{RootCA: "cacert"}.GetConfig(c) + if tlsConfig.RootCAs == nil { + tlsConfig.RootCAs = GetCloudflareRootCA() + tlsConfig.ServerName = "cftunnel.com" + } else if len(addrs) > 0 { + // Set for development environments and for testing specific origintunneld instances + tlsConfig.ServerName, _, _ = net.SplitHostPort(addrs[0]) + } + return tlsConfig +} diff --git a/tunnelrpc/log.go b/tunnelrpc/log.go index d5bb5698..d69f167a 100644 --- a/tunnelrpc/log.go +++ b/tunnelrpc/log.go @@ -1,10 +1,9 @@ package tunnelrpc -//go:generate capnp compile -ogo -I./tunnelrpc/ tunnelrpc.capnp - import ( - log "github.com/Sirupsen/logrus" + log "github.com/sirupsen/logrus" "golang.org/x/net/context" + "golang.org/x/net/trace" "zombiezen.com/go/capnproto2/rpc" ) @@ -24,3 +23,20 @@ func (c ConnLogger) Errorf(ctx context.Context, format string, args ...interface func ConnLog(log *log.Entry) rpc.ConnOption { return rpc.ConnLog(ConnLogger{log}) } + +// ConnTracer wraps a trace.EventLog for a connection. +type ConnTracer struct { + Events trace.EventLog +} + +func (c ConnTracer) Infof(ctx context.Context, format string, args ...interface{}) { + c.Events.Printf(format, args...) +} + +func (c ConnTracer) Errorf(ctx context.Context, format string, args ...interface{}) { + c.Events.Errorf(format, args...) +} + +func ConnTrace(events trace.EventLog) rpc.ConnOption { + return rpc.ConnLog(ConnTracer{events}) +} diff --git a/tunnelrpc/logtransport.go b/tunnelrpc/logtransport.go index b31dde9c..6893ba03 100644 --- a/tunnelrpc/logtransport.go +++ b/tunnelrpc/logtransport.go @@ -4,7 +4,7 @@ package tunnelrpc import ( "bytes" - log "github.com/Sirupsen/logrus" + log "github.com/sirupsen/logrus" "golang.org/x/net/context" "zombiezen.com/go/capnproto2/encoding/text" "zombiezen.com/go/capnproto2/rpc" diff --git a/tunnelrpc/pogs/tunnelrpc.go b/tunnelrpc/pogs/tunnelrpc.go index cc2636f8..4f5280e4 100644 --- a/tunnelrpc/pogs/tunnelrpc.go +++ b/tunnelrpc/pogs/tunnelrpc.go @@ -47,10 +47,11 @@ type RegistrationOptions struct { Version string OS string `capnp:"os"` ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy - PoolID string `capnp:"poolId"` + PoolName string `capnp:"poolName"` Tags []Tag - ConnectionID uint8 `capnp:"connectionId"` + ConnectionID uint8 `capnp:"connectionId"` OriginLocalIP string `capnp:"originLocalIp"` + IsAutoupdated bool `capnp:"isAutoupdated"` } func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error { diff --git a/tunnelrpc/tunnelrpc.capnp b/tunnelrpc/tunnelrpc.capnp index a2924a9c..636f317a 100644 --- a/tunnelrpc/tunnelrpc.capnp +++ b/tunnelrpc/tunnelrpc.capnp @@ -28,13 +28,15 @@ struct RegistrationOptions { # What to do with existing tunnels for the given hostname. existingTunnelPolicy @3 :ExistingTunnelPolicy; # If using the balancing policy, identifies the LB pool to use. - poolId @4 :Text; + poolName @4 :Text; # Client-defined tags to associate with the tunnel tags @5 :List(Tag); # A unique identifier for a high-availability connection made by a single client. connectionId @6 :UInt8; # origin LAN IP originLocalIp @7 :Text; + # whether Warp client has been autoupdated + isAutoupdated @8 :Bool; } struct Tag { diff --git a/websocket/websocket.go b/websocket/websocket.go new file mode 100644 index 00000000..d821c706 --- /dev/null +++ b/websocket/websocket.go @@ -0,0 +1,77 @@ +package websocket + +import ( + "bufio" + "crypto/sha1" + "crypto/tls" + "encoding/base64" + "errors" + "io" + "net" + "net/http" + + "github.com/gorilla/websocket" +) + +// IsWebSocketUpgrade checks to see if the request is a WebSocket connection. +func IsWebSocketUpgrade(req *http.Request) bool { + return websocket.IsWebSocketUpgrade(req) +} + +// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing. +func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.Conn, *http.Response, error) { + req.URL.Scheme = "wss" + d := &websocket.Dialer{TLSClientConfig: tlsClientConfig} + conn, response, err := d.Dial(req.URL.String(), nil) + if err != nil { + return nil, nil, err + } + response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req)) + return conn, response, err +} + +// HijackConnection takes over an HTTP connection. Caller is responsible for closing connection. +func HijackConnection(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { + hj, ok := w.(http.Hijacker) + if !ok { + return nil, nil, errors.New("hijack error") + } + + conn, brw, err := hj.Hijack() + if err != nil { + return nil, nil, err + } + return conn, brw, nil +} + +// Stream copies copy data to & from provided io.ReadWriters. +func Stream(conn, backendConn io.ReadWriter) { + proxyDone := make(chan struct{}, 2) + + go func() { + io.Copy(conn, backendConn) + proxyDone <- struct{}{} + }() + + go func() { + io.Copy(backendConn, conn) + proxyDone <- struct{}{} + }() + + // If one side is done, we are done. + <-proxyDone +} + +// sha1Base64 sha1 and then base64 encodes str. +func sha1Base64(str string) string { + hasher := sha1.New() + io.WriteString(hasher, str) + hash := hasher.Sum(nil) + return base64.StdEncoding.EncodeToString(hash) +} + +// generateAcceptKey returns the string needed for the Sec-WebSocket-Accept header. +// https://tools.ietf.org/html/rfc6455#section-1.3 describes this process in more detail. +func generateAcceptKey(req *http.Request) string { + return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") +}