Release Warp Client 2018.2.1
This commit is contained in:
		
							parent
							
								
									e0ae598112
								
							
						
					
					
						commit
						3780e14f41
					
				|  | @ -2,17 +2,20 @@ package main | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"html/template" | 	"html/template" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"os" | 	"os" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/pkg/errors" | 	"github.com/gorilla/websocket" | ||||||
|  | 	"gopkg.in/urfave/cli.v2" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" | 	"github.com/cloudflare/cloudflare-warp/tlsconfig" | ||||||
| 	cli "gopkg.in/urfave/cli.v2" |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type templateData struct { | type templateData struct { | ||||||
|  | @ -21,6 +24,11 @@ type templateData struct { | ||||||
| 	Body       string | 	Body       string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | type OriginUpTime struct { | ||||||
|  | 	StartTime time.Time `json:"startTime"` | ||||||
|  | 	UpTime    string    `json:"uptime"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
| const defaultServerName = "the Cloudflare Warp test server" | const defaultServerName = "the Cloudflare Warp test server" | ||||||
| const indexTemplate = ` | const indexTemplate = ` | ||||||
| <!DOCTYPE html> | <!DOCTYPE html> | ||||||
|  | @ -85,54 +93,95 @@ const indexTemplate = ` | ||||||
| 
 | 
 | ||||||
| func hello(c *cli.Context) error { | func hello(c *cli.Context) error { | ||||||
| 	address := fmt.Sprintf(":%d", c.Int("port")) | 	address := fmt.Sprintf(":%d", c.Int("port")) | ||||||
| 	server := NewHelloWorldServer() | 	listener, err := createListener(address) | ||||||
| 	if hostname, err := os.Hostname(); err != nil { | 	if err != nil { | ||||||
| 		server.serverName = hostname | 		return err | ||||||
| 	} | 	} | ||||||
| 	err := server.ListenAndServe(address) | 	defer listener.Close() | ||||||
| 	return errors.Wrap(err, "Fail to start Hello World Server") | 	err = startHelloWorldServer(listener, nil) | ||||||
|  | 	return err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error { | func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error { | ||||||
| 	server := NewHelloWorldServer() | 	Log.Infof("Starting Hello World server at %s", listener.Addr()) | ||||||
| 	if hostname, err := os.Hostname(); err != nil { | 	serverName := defaultServerName | ||||||
| 		server.serverName = hostname | 	if hostname, err := os.Hostname(); err == nil { | ||||||
|  | 		serverName = hostname | ||||||
| 	} | 	} | ||||||
| 	httpServer := &http.Server{Addr: listener.Addr().String(), Handler: server} | 
 | ||||||
|  | 	upgrader := websocket.Upgrader{ | ||||||
|  | 		ReadBufferSize:  1024, | ||||||
|  | 		WriteBufferSize: 1024, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil} | ||||||
| 	go func() { | 	go func() { | ||||||
| 		<-shutdownC | 		<-shutdownC | ||||||
| 		httpServer.Close() | 		httpServer.Close() | ||||||
| 	}() | 	}() | ||||||
|  | 
 | ||||||
|  | 	http.HandleFunc("/uptime", uptimeHandler(time.Now())) | ||||||
|  | 	http.HandleFunc("/ws", websocketHandler(upgrader)) | ||||||
|  | 	http.HandleFunc("/", rootHandler(serverName)) | ||||||
| 	err := httpServer.Serve(listener) | 	err := httpServer.Serve(listener) | ||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type HelloWorldServer struct { | func createListener(address string) (net.Listener, error) { | ||||||
| 	responseTemplate *template.Template | 	certificate, err := tlsconfig.GetHelloCertificate() | ||||||
| 	serverName       string | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| func NewHelloWorldServer() *HelloWorldServer { | 	// If the port in address is empty, a port number is automatically chosen
 | ||||||
| 	return &HelloWorldServer{ | 	listener, err := tls.Listen( | ||||||
| 		responseTemplate: template.Must(template.New("index").Parse(indexTemplate)), | 		"tcp", | ||||||
| 		serverName:       defaultServerName, | 		address, | ||||||
| 	} | 		&tls.Config{Certificates: []tls.Certificate{certificate}}) | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| func findAvailablePort() (net.Listener, error) { |  | ||||||
| 	// If the port in address is empty, a port number is automatically chosen.
 |  | ||||||
| 	listener, err := net.Listen("tcp", "127.0.0.1:") |  | ||||||
| 	return listener, err | 	return listener, err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *HelloWorldServer) ListenAndServe(address string) error { | func uptimeHandler(startTime time.Time) http.HandlerFunc { | ||||||
| 	log.Infof("Starting Hello World server on %s", address) | 	return func(w http.ResponseWriter, r *http.Request) { | ||||||
| 	err := http.ListenAndServe(address, s) | 		// Note that if autoupdate is enabled, the uptime is reset when a new client
 | ||||||
| 	return err | 		// release is available
 | ||||||
|  | 		resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()} | ||||||
|  | 		respJson, err := json.Marshal(resp) | ||||||
|  | 		if err != nil { | ||||||
|  | 			w.WriteHeader(http.StatusInternalServerError) | ||||||
|  | 		} else { | ||||||
|  | 			w.Header().Set("Content-Type", "application/json") | ||||||
|  | 			w.Write(respJson) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { | func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc { | ||||||
| 	log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto) | 	return func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 		conn, err := upgrader.Upgrade(w, r, nil) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		defer conn.Close() | ||||||
|  | 
 | ||||||
|  | 		for { | ||||||
|  | 			mt, message, err := conn.ReadMessage() | ||||||
|  | 			if err != nil { | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			if err := conn.WriteMessage(mt, message); err != nil { | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func rootHandler(serverName string) http.HandlerFunc { | ||||||
|  | 	responseTemplate := template.Must(template.New("index").Parse(indexTemplate)) | ||||||
|  | 	return func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 		Log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto) | ||||||
| 		var buffer bytes.Buffer | 		var buffer bytes.Buffer | ||||||
| 		var body string | 		var body string | ||||||
| 		rawBody, err := ioutil.ReadAll(r.Body) | 		rawBody, err := ioutil.ReadAll(r.Body) | ||||||
|  | @ -141,8 +190,8 @@ func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||||
| 		} else { | 		} else { | ||||||
| 			body = "" | 			body = "" | ||||||
| 		} | 		} | ||||||
| 	err = s.responseTemplate.Execute(&buffer, &templateData{ | 		err = responseTemplate.Execute(&buffer, &templateData{ | ||||||
| 		ServerName: s.serverName, | 			ServerName: serverName, | ||||||
| 			Request:    r, | 			Request:    r, | ||||||
| 			Body:       body, | 			Body:       body, | ||||||
| 		}) | 		}) | ||||||
|  | @ -153,3 +202,4 @@ func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||||
| 			buffer.WriteTo(w) | 			buffer.WriteTo(w) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -4,18 +4,30 @@ import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const testPort = "8080" | func TestCreateListenerHostAndPortSuccess(t *testing.T) { | ||||||
| 
 | 	listener, err := createListener("localhost:1234") | ||||||
| func TestNewHelloWorldServer(t *testing.T) { |  | ||||||
| 	if NewHelloWorldServer() == nil { |  | ||||||
| 		t.Fatal("NewHelloWorldServer returned nil") |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestFindAvailablePort(t *testing.T) { |  | ||||||
| 	listener, err := findAvailablePort() |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal("Fail to find available port") | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if listener.Addr().String() == "" { | ||||||
|  | 		t.Fatal("Fail to find available port") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestCreateListenerOnlyHostSuccess(t *testing.T) { | ||||||
|  | 	listener, err := createListener("localhost:") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if listener.Addr().String() == "" { | ||||||
|  | 		t.Fatal("Fail to find available port") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestCreateListenerOnlyPortSuccess(t *testing.T) { | ||||||
|  | 	listener, err := createListener(":8888") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	if listener.Addr().String() == "" { | 	if listener.Addr().String() == "" { | ||||||
| 		t.Fatal("Fail to find available port") | 		t.Fatal("Fail to find available port") | ||||||
|  |  | ||||||
|  | @ -42,6 +42,8 @@ After=network.target | ||||||
| TimeoutStartSec=0 | TimeoutStartSec=0 | ||||||
| Type=notify | Type=notify | ||||||
| ExecStart={{ .Path }} --config /etc/cloudflare-warp/config.yml --origincert /etc/cloudflare-warp/cert.pem --no-autoupdate | ExecStart={{ .Path }} --config /etc/cloudflare-warp/config.yml --origincert /etc/cloudflare-warp/cert.pem --no-autoupdate | ||||||
|  | Restart=on-failure | ||||||
|  | RestartSec=5s | ||||||
| 
 | 
 | ||||||
| [Install] | [Install] | ||||||
| WantedBy=multi-user.target | WantedBy=multi-user.target | ||||||
|  |  | ||||||
|  | @ -14,7 +14,6 @@ import ( | ||||||
| 	"syscall" | 	"syscall" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" |  | ||||||
| 	homedir "github.com/mitchellh/go-homedir" | 	homedir "github.com/mitchellh/go-homedir" | ||||||
| 	cli "gopkg.in/urfave/cli.v2" | 	cli "gopkg.in/urfave/cli.v2" | ||||||
| ) | ) | ||||||
|  | @ -137,7 +136,7 @@ func download(certURL, filePath string) bool { | ||||||
| 			return true | 			return true | ||||||
| 		} | 		} | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			log.WithError(err).Error("Error fetching certificate") | 			Log.WithError(err).Error("Error fetching certificate") | ||||||
| 			return false | 			return false | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | @ -180,16 +179,16 @@ func putSuccess(client *http.Client, certURL string) { | ||||||
| 	// indicate success to the relay server
 | 	// indicate success to the relay server
 | ||||||
| 	req, err := http.NewRequest("PUT", certURL+"/ok", nil) | 	req, err := http.NewRequest("PUT", certURL+"/ok", nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Error("HTTP request error") | 		Log.WithError(err).Error("HTTP request error") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	resp, err := client.Do(req) | 	resp, err := client.Do(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Error("HTTP error") | 		Log.WithError(err).Error("HTTP error") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	resp.Body.Close() | 	resp.Body.Close() | ||||||
| 	if resp.StatusCode != 200 { | 	if resp.StatusCode != 200 { | ||||||
| 		log.Errorf("Unexpected HTTP error code %d", resp.StatusCode) | 		Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -11,6 +11,8 @@ import ( | ||||||
| 	"os" | 	"os" | ||||||
| 	"os/signal" | 	"os/signal" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
|  | 	"runtime" | ||||||
|  | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"syscall" | 	"syscall" | ||||||
| 	"time" | 	"time" | ||||||
|  | @ -21,11 +23,12 @@ import ( | ||||||
| 	tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" | 	tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" | ||||||
| 	"github.com/cloudflare/cloudflare-warp/validation" | 	"github.com/cloudflare/cloudflare-warp/validation" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" |  | ||||||
| 	"github.com/facebookgo/grace/gracenet" | 	"github.com/facebookgo/grace/gracenet" | ||||||
| 	raven "github.com/getsentry/raven-go" | 	"github.com/getsentry/raven-go" | ||||||
| 	homedir "github.com/mitchellh/go-homedir" | 	"github.com/mitchellh/go-homedir" | ||||||
| 	cli "gopkg.in/urfave/cli.v2" | 	"github.com/rifflock/lfshook" | ||||||
|  | 	"github.com/sirupsen/logrus" | ||||||
|  | 	"gopkg.in/urfave/cli.v2" | ||||||
| 	"gopkg.in/urfave/cli.v2/altsrc" | 	"gopkg.in/urfave/cli.v2/altsrc" | ||||||
| 
 | 
 | ||||||
| 	"github.com/coreos/go-systemd/daemon" | 	"github.com/coreos/go-systemd/daemon" | ||||||
|  | @ -40,11 +43,21 @@ const configFile = "config.yml" | ||||||
| var listeners = gracenet.Net{} | var listeners = gracenet.Net{} | ||||||
| var Version = "DEV" | var Version = "DEV" | ||||||
| var BuildTime = "unknown" | var BuildTime = "unknown" | ||||||
|  | var Log *logrus.Logger | ||||||
| 
 | 
 | ||||||
| // Shutdown channel used by the app. When closed, app must terminate.
 | // Shutdown channel used by the app. When closed, app must terminate.
 | ||||||
| // May be closed by the Windows service runner.
 | // May be closed by the Windows service runner.
 | ||||||
| var shutdownC chan struct{} | var shutdownC chan struct{} | ||||||
| 
 | 
 | ||||||
|  | type BuildAndRuntimeInfo struct { | ||||||
|  | 	GoOS        string                 `json:"go_os"` | ||||||
|  | 	GoVersion   string                 `json:"go_version"` | ||||||
|  | 	GoArch      string                 `json:"go_arch"` | ||||||
|  | 	WarpVersion string                 `json:"warp_version"` | ||||||
|  | 	WarpFlags   map[string]interface{} `json:"warp_flags"` | ||||||
|  | 	WarpEnvs    map[string]string      `json:"warp_envs"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func main() { | func main() { | ||||||
| 	metrics.RegisterBuildInfo(BuildTime, Version) | 	metrics.RegisterBuildInfo(BuildTime, Version) | ||||||
| 	raven.SetDSN(sentryDSN) | 	raven.SetDSN(sentryDSN) | ||||||
|  | @ -84,6 +97,12 @@ WARNING: | ||||||
| 			Usage: "Disable periodic check for updates, restarting the server with the new version.", | 			Usage: "Disable periodic check for updates, restarting the server with the new version.", | ||||||
| 			Value: false, | 			Value: false, | ||||||
| 		}), | 		}), | ||||||
|  | 		altsrc.NewBoolFlag(&cli.BoolFlag{ | ||||||
|  | 			Name: "is-autoupdated", | ||||||
|  | 			Usage: "Signal the new process that Warp client has been autoupdated", | ||||||
|  | 			Value: false, | ||||||
|  | 			Hidden: true, | ||||||
|  | 		}), | ||||||
| 		altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ | 		altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ | ||||||
| 			Name:    "edge", | 			Name:    "edge", | ||||||
| 			Usage:   "Address of the Cloudflare tunnel server.", | 			Usage:   "Address of the Cloudflare tunnel server.", | ||||||
|  | @ -99,12 +118,12 @@ WARNING: | ||||||
| 		altsrc.NewStringFlag(&cli.StringFlag{ | 		altsrc.NewStringFlag(&cli.StringFlag{ | ||||||
| 			Name:    "origincert", | 			Name:    "origincert", | ||||||
| 			Usage:   "Path to the certificate generated for your origin when you run cloudflare-warp login.", | 			Usage:   "Path to the certificate generated for your origin when you run cloudflare-warp login.", | ||||||
| 			EnvVars: []string{"ORIGIN_CERT"}, | 			EnvVars: []string{"TUNNEL_ORIGIN_CERT"}, | ||||||
| 			Value:   filepath.Join(defaultConfigDir, credentialFile), | 			Value:   filepath.Join(defaultConfigDir, credentialFile), | ||||||
| 		}), | 		}), | ||||||
| 		altsrc.NewStringFlag(&cli.StringFlag{ | 		altsrc.NewStringFlag(&cli.StringFlag{ | ||||||
| 			Name:    "url", | 			Name:    "url", | ||||||
| 			Value:   "http://localhost:8080", | 			Value:   "https://localhost:8080", | ||||||
| 			Usage:   "Connect to the local webserver at `URL`.", | 			Usage:   "Connect to the local webserver at `URL`.", | ||||||
| 			EnvVars: []string{"TUNNEL_URL"}, | 			EnvVars: []string{"TUNNEL_URL"}, | ||||||
| 		}), | 		}), | ||||||
|  | @ -191,14 +210,20 @@ WARNING: | ||||||
| 		}), | 		}), | ||||||
| 		altsrc.NewBoolFlag(&cli.BoolFlag{ | 		altsrc.NewBoolFlag(&cli.BoolFlag{ | ||||||
| 			Name:    "hello-world", | 			Name:    "hello-world", | ||||||
| 			Usage: "Run Hello World Server", |  | ||||||
| 			Value:   false, | 			Value:   false, | ||||||
|  | 			Usage:   "Run Hello World Server", | ||||||
|  | 			EnvVars: []string{"TUNNEL_HELLO_WORLD"}, | ||||||
| 		}), | 		}), | ||||||
| 		altsrc.NewStringFlag(&cli.StringFlag{ | 		altsrc.NewStringFlag(&cli.StringFlag{ | ||||||
| 			Name:    "pidfile", | 			Name:    "pidfile", | ||||||
| 			Usage:   "Write the application's PID to this file after first successful connection.", | 			Usage:   "Write the application's PID to this file after first successful connection.", | ||||||
| 			EnvVars: []string{"TUNNEL_PIDFILE"}, | 			EnvVars: []string{"TUNNEL_PIDFILE"}, | ||||||
| 		}), | 		}), | ||||||
|  | 		altsrc.NewStringFlag(&cli.StringFlag{ | ||||||
|  | 			Name:    "logfile", | ||||||
|  | 			Usage:   "Save application log to this file for reporting issues.", | ||||||
|  | 			EnvVars: []string{"TUNNEL_LOGFILE"}, | ||||||
|  | 		}), | ||||||
| 		altsrc.NewIntFlag(&cli.IntFlag{ | 		altsrc.NewIntFlag(&cli.IntFlag{ | ||||||
| 			Name:   "ha-connections", | 			Name:   "ha-connections", | ||||||
| 			Value:  4, | 			Value:  4, | ||||||
|  | @ -239,6 +264,7 @@ WARNING: | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	app.Before = func(context *cli.Context) error { | 	app.Before = func(context *cli.Context) error { | ||||||
|  | 		Log = logrus.New() | ||||||
| 		inputSource, err := findInputSourceContext(context) | 		inputSource, err := findInputSourceContext(context) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
|  | @ -248,7 +274,7 @@ WARNING: | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	app.Commands = []*cli.Command{ | 	app.Commands = []*cli.Command{ | ||||||
| 		&cli.Command{ | 		{ | ||||||
| 			Name:      "update", | 			Name:      "update", | ||||||
| 			Action:    update, | 			Action:    update, | ||||||
| 			Usage:     "Update the agent if a new version exists", | 			Usage:     "Update the agent if a new version exists", | ||||||
|  | @ -259,7 +285,7 @@ WARNING: | ||||||
| 
 | 
 | ||||||
|    To determine if an update happened in a script, check for error code 64.`, |    To determine if an update happened in a script, check for error code 64.`, | ||||||
| 		}, | 		}, | ||||||
| 		&cli.Command{ | 		{ | ||||||
| 			Name:      "login", | 			Name:      "login", | ||||||
| 			Action:    login, | 			Action:    login, | ||||||
| 			Usage:     "Generate a configuration file with your login details", | 			Usage:     "Generate a configuration file with your login details", | ||||||
|  | @ -271,7 +297,7 @@ WARNING: | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		&cli.Command{ | 		{ | ||||||
| 			Name:   "hello", | 			Name:   "hello", | ||||||
| 			Action: hello, | 			Action: hello, | ||||||
| 			Usage:  "Run a simple \"Hello World\" server for testing Cloudflare Warp.", | 			Usage:  "Run a simple \"Hello World\" server for testing Cloudflare Warp.", | ||||||
|  | @ -293,27 +319,43 @@ func startServer(c *cli.Context) { | ||||||
| 	errC := make(chan error) | 	errC := make(chan error) | ||||||
| 	wg.Add(2) | 	wg.Add(2) | ||||||
| 
 | 
 | ||||||
| 	if c.NumFlags() == 0 && c.NArg() == 0 { | 	// If the user choose to supply all options through env variables,
 | ||||||
|  | 	// c.NumFlags() == 0 && c.NArg() == 0. For warp to work, the user needs to at
 | ||||||
|  | 	// least provide a hostname.
 | ||||||
|  | 	if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" { | ||||||
| 		cli.ShowAppHelp(c) | 		cli.ShowAppHelp(c) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 	logLevel, err := logrus.ParseLevel(c.String("loglevel")) | ||||||
| 	logLevel, err := log.ParseLevel(c.String("loglevel")) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Fatal("Unknown logging level specified") | 		Log.WithError(err).Fatal("Unknown logging level specified") | ||||||
| 	} | 	} | ||||||
|  | 	logrus.SetLevel(logLevel) | ||||||
| 
 | 
 | ||||||
| 	log.SetLevel(logLevel) | 	protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel")) | ||||||
| 	protoLogLevel, err := log.ParseLevel(c.String("proto-loglevel")) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Fatal("Unknown protocol logging level specified") | 		Log.WithError(err).Fatal("Unknown protocol logging level specified") | ||||||
| 	} | 	} | ||||||
| 	protoLogger := log.New() | 	protoLogger := logrus.New() | ||||||
| 	protoLogger.Level = protoLogLevel | 	protoLogger.Level = protoLogLevel | ||||||
| 
 | 
 | ||||||
|  | 	if c.String("logfile") != "" { | ||||||
|  | 		if err := initLogFile(c, protoLogger); err != nil { | ||||||
|  | 			Log.Error(err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 { | ||||||
|  | 		if initUpdate() { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		Log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq")) | ||||||
|  | 		go autoupdate(c.Duration("autoupdate-freq"), shutdownC) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	hostname, err := validation.ValidateHostname(c.String("hostname")) | 	hostname, err := validation.ValidateHostname(c.String("hostname")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Fatal("Invalid hostname") | 		Log.WithError(err).Fatal("Invalid hostname") | ||||||
| 
 | 
 | ||||||
| 	} | 	} | ||||||
| 	clientID := c.String("id") | 	clientID := c.String("id") | ||||||
|  | @ -323,46 +365,44 @@ func startServer(c *cli.Context) { | ||||||
| 
 | 
 | ||||||
| 	tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) | 	tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Fatal("Tag parse failure") | 		Log.WithError(err).Fatal("Tag parse failure") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) | 	tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) | ||||||
| 
 |  | ||||||
| 	if c.IsSet("hello-world") { | 	if c.IsSet("hello-world") { | ||||||
| 		wg.Add(1) | 		wg.Add(1) | ||||||
| 		listener, err := findAvailablePort() | 		listener, err := createListener("127.0.0.1:") | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			listener.Close() | 			listener.Close() | ||||||
| 			log.WithError(err).Fatal("Cannot start Hello World Server") | 			Log.WithError(err).Fatal("Cannot start Hello World Server") | ||||||
| 		} | 		} | ||||||
| 		go func() { | 		go func() { | ||||||
| 			startHelloWorldServer(listener, shutdownC) | 			startHelloWorldServer(listener, shutdownC) | ||||||
| 			wg.Done() | 			wg.Done() | ||||||
| 			listener.Close() | 			listener.Close() | ||||||
| 		}() | 		}() | ||||||
| 		c.Set("url", "http://"+listener.Addr().String()) | 		c.Set("url", "https://"+listener.Addr().String()) | ||||||
| 		log.Infof("Starting Hello World Server at %s", c.String("url")) |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	url, err := validateUrl(c) | 	url, err := validateUrl(c) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Fatal("Error validating url") | 		Log.WithError(err).Fatal("Error validating url") | ||||||
| 	} | 	} | ||||||
| 	log.Infof("Proxying tunnel requests to %s", url) | 	Log.Infof("Proxying tunnel requests to %s", url) | ||||||
| 
 | 
 | ||||||
| 	// Fail if the user provided an old authentication method
 | 	// Fail if the user provided an old authentication method
 | ||||||
| 	if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") { | 	if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") { | ||||||
| 		log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login") | 		Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Check that the user has acquired a certificate using the log in command
 | 	// Check that the user has acquired a certificate using the log in command
 | ||||||
| 	originCertPath, err := homedir.Expand(c.String("origincert")) | 	originCertPath, err := homedir.Expand(c.String("origincert")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert")) | 		Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert")) | ||||||
| 	} | 	} | ||||||
| 	ok, err := fileExists(originCertPath) | 	ok, err := fileExists(originCertPath) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		log.Fatalf(`Cannot find a valid certificate for your origin at the path: | 		Log.Fatalf(`Cannot find a valid certificate for your origin at the path: | ||||||
| 
 | 
 | ||||||
|     %s |     %s | ||||||
| 
 | 
 | ||||||
|  | @ -375,8 +415,9 @@ If you don't have a certificate signed by Cloudflare, run the command: | ||||||
| 	// Easier to send the certificate as []byte via RPC than decoding it at this point
 | 	// Easier to send the certificate as []byte via RPC than decoding it at this point
 | ||||||
| 	originCert, err := ioutil.ReadFile(originCertPath) | 	originCert, err := ioutil.ReadFile(originCertPath) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath) | 		Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath) | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
| 	tunnelMetrics := origin.NewTunnelMetrics() | 	tunnelMetrics := origin.NewTunnelMetrics() | ||||||
| 	httpTransport := &http.Transport{ | 	httpTransport := &http.Transport{ | ||||||
| 		Proxy: http.ProxyFromEnvironment, | 		Proxy: http.ProxyFromEnvironment, | ||||||
|  | @ -389,13 +430,15 @@ If you don't have a certificate signed by Cloudflare, run the command: | ||||||
| 		IdleConnTimeout:       c.Duration("proxy-keepalive-timeout"), | 		IdleConnTimeout:       c.Duration("proxy-keepalive-timeout"), | ||||||
| 		TLSHandshakeTimeout:   c.Duration("proxy-tls-timeout"), | 		TLSHandshakeTimeout:   c.Duration("proxy-tls-timeout"), | ||||||
| 		ExpectContinueTimeout: 1 * time.Second, | 		ExpectContinueTimeout: 1 * time.Second, | ||||||
|  | 		TLSClientConfig:       &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()}, | ||||||
| 	} | 	} | ||||||
| 	tunnelConfig := &origin.TunnelConfig{ | 	tunnelConfig := &origin.TunnelConfig{ | ||||||
| 		EdgeAddrs:         c.StringSlice("edge"), | 		EdgeAddrs:         c.StringSlice("edge"), | ||||||
| 		OriginUrl:         url, | 		OriginUrl:         url, | ||||||
| 		Hostname:          hostname, | 		Hostname:          hostname, | ||||||
| 		OriginCert:        originCert, | 		OriginCert:        originCert, | ||||||
| 		TlsConfig:         &tls.Config{}, | 		TlsConfig:         tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")), | ||||||
|  | 		ClientTlsConfig:   httpTransport.TLSClientConfig, | ||||||
| 		Retries:           c.Uint("retries"), | 		Retries:           c.Uint("retries"), | ||||||
| 		HeartbeatInterval: c.Duration("heartbeat-interval"), | 		HeartbeatInterval: c.Duration("heartbeat-interval"), | ||||||
| 		MaxHeartbeats:     c.Uint64("heartbeat-count"), | 		MaxHeartbeats:     c.Uint64("heartbeat-count"), | ||||||
|  | @ -408,18 +451,11 @@ If you don't have a certificate signed by Cloudflare, run the command: | ||||||
| 		Metrics:           tunnelMetrics, | 		Metrics:           tunnelMetrics, | ||||||
| 		MetricsUpdateFreq: c.Duration("metrics-update-freq"), | 		MetricsUpdateFreq: c.Duration("metrics-update-freq"), | ||||||
| 		ProtocolLogger:    protoLogger, | 		ProtocolLogger:    protoLogger, | ||||||
|  | 		Logger:            Log, | ||||||
|  | 		IsAutoupdated:     c.Bool("is-autoupdated"), | ||||||
| 	} | 	} | ||||||
| 	connectedSignal := make(chan struct{}) | 	connectedSignal := make(chan struct{}) | ||||||
| 
 | 
 | ||||||
| 	tunnelConfig.TlsConfig = tlsconfig.CLIFlags{RootCA: "cacert"}.GetConfig(c) |  | ||||||
| 	if tunnelConfig.TlsConfig.RootCAs == nil { |  | ||||||
| 		tunnelConfig.TlsConfig.RootCAs = GetCloudflareRootCA() |  | ||||||
| 		tunnelConfig.TlsConfig.ServerName = "cftunnel.com" |  | ||||||
| 	} else if len(tunnelConfig.EdgeAddrs) > 0 { |  | ||||||
| 		// Set for development environments and for testing specific origintunneld instances
 |  | ||||||
| 		tunnelConfig.TlsConfig.ServerName, _, _ = net.SplitHostPort(tunnelConfig.EdgeAddrs[0]) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	go writePidFile(connectedSignal, c.String("pidfile")) | 	go writePidFile(connectedSignal, c.String("pidfile")) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal) | 		errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal) | ||||||
|  | @ -428,24 +464,21 @@ If you don't have a certificate signed by Cloudflare, run the command: | ||||||
| 
 | 
 | ||||||
| 	metricsListener, err := listeners.Listen("tcp", c.String("metrics")) | 	metricsListener, err := listeners.Listen("tcp", c.String("metrics")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Fatal("Error opening metrics server listener") | 		Log.WithError(err).Fatal("Error opening metrics server listener") | ||||||
| 	} | 	} | ||||||
| 	go func() { | 	go func() { | ||||||
| 		errC <- metrics.ServeMetrics(metricsListener, shutdownC) | 		errC <- metrics.ServeMetrics(metricsListener, shutdownC) | ||||||
| 		wg.Done() | 		wg.Done() | ||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
| 	if !c.Bool("no-autoupdate") { | 	var errCode int | ||||||
| 		log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq")) |  | ||||||
| 		go autoupdate(c.Duration("autoupdate-period"), shutdownC) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	err = WaitForSignal(errC, shutdownC) | 	err = WaitForSignal(errC, shutdownC) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Error("Quitting due to error") | 		Log.WithError(err).Error("Quitting due to error") | ||||||
| 		raven.CaptureErrorAndWait(err, nil) | 		raven.CaptureErrorAndWait(err, nil) | ||||||
|  | 		errCode = 1 | ||||||
| 	} else { | 	} else { | ||||||
| 		log.Info("Quitting...") | 		Log.Info("Quitting...") | ||||||
| 	} | 	} | ||||||
| 	// Wait for clean exit, discarding all errors
 | 	// Wait for clean exit, discarding all errors
 | ||||||
| 	go func() { | 	go func() { | ||||||
|  | @ -453,6 +486,7 @@ If you don't have a certificate signed by Cloudflare, run the command: | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| 	wg.Wait() | 	wg.Wait() | ||||||
|  | 	os.Exit(errCode) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func WaitForSignal(errC chan error, shutdownC chan struct{}) error { | func WaitForSignal(errC chan error, shutdownC chan struct{}) error { | ||||||
|  | @ -477,30 +511,40 @@ func update(c *cli.Context) error { | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func autoupdate(frequency time.Duration, shutdownC chan struct{}) { | func initUpdate() bool { | ||||||
| 	if int64(frequency) == 0 { | 	if updateApplied() { | ||||||
| 		return | 		os.Args = append(os.Args, "--is-autoupdated=true") | ||||||
|  | 		if _, err := listeners.StartProcess(); err != nil { | ||||||
|  | 			Log.WithError(err).Error("Unable to restart server automatically") | ||||||
|  | 			return false | ||||||
| 		} | 		} | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func autoupdate(freq time.Duration, shutdownC chan struct{}) { | ||||||
| 	for { | 	for { | ||||||
| 		if updateApplied() { | 		if updateApplied() { | ||||||
|  | 			os.Args = append(os.Args, "--is-autoupdated=true") | ||||||
| 			if _, err := listeners.StartProcess(); err != nil { | 			if _, err := listeners.StartProcess(); err != nil { | ||||||
| 				log.WithError(err).Error("Unable to restart server automatically") | 				Log.WithError(err).Error("Unable to restart server automatically") | ||||||
| 			} | 			} | ||||||
| 			close(shutdownC) | 			close(shutdownC) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		time.Sleep(frequency) | 		time.Sleep(freq) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func updateApplied() bool { | func updateApplied() bool { | ||||||
| 	releaseInfo := checkForUpdates() | 	releaseInfo := checkForUpdates() | ||||||
| 	if releaseInfo.Updated { | 	if releaseInfo.Updated { | ||||||
| 		log.Infof("Updated to version %s", releaseInfo.Version) | 		Log.Infof("Updated to version %s", releaseInfo.Version) | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| 	if releaseInfo.Error != nil { | 	if releaseInfo.Error != nil { | ||||||
| 		log.WithError(releaseInfo.Error).Error("Update check failed") | 		Log.WithError(releaseInfo.Error).Error("Update check failed") | ||||||
| 	} | 	} | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
|  | @ -555,7 +599,7 @@ func writePidFile(waitForSignal chan struct{}, pidFile string) { | ||||||
| 	} | 	} | ||||||
| 	file, err := os.Create(pidFile) | 	file, err := os.Create(pidFile) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Errorf("Unable to write pid to %s", pidFile) | 		Log.WithError(err).Errorf("Unable to write pid to %s", pidFile) | ||||||
| 	} | 	} | ||||||
| 	defer file.Close() | 	defer file.Close() | ||||||
| 	fmt.Fprintf(file, "%d", os.Getpid()) | 	fmt.Fprintf(file, "%d", os.Getpid()) | ||||||
|  | @ -573,3 +617,55 @@ func validateUrl(c *cli.Context) (string, error) { | ||||||
| 	validUrl, err := validation.ValidateUrl(url) | 	validUrl, err := validation.ValidateUrl(url) | ||||||
| 	return validUrl, err | 	return validUrl, err | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error { | ||||||
|  | 	fileMode := os.O_WRONLY|os.O_APPEND|os.O_CREATE|os.O_TRUNC | ||||||
|  | 	// do not truncate log file if the client has been autoupdated
 | ||||||
|  | 	if c.Bool("is-autoupdated") { | ||||||
|  | 		fileMode = os.O_WRONLY|os.O_APPEND|os.O_CREATE | ||||||
|  | 	} | ||||||
|  | 	f, err := os.OpenFile(c.String("logfile"), fileMode, 0664) | ||||||
|  | 	if err != nil { | ||||||
|  | 		errors.Wrap(err, fmt.Sprintf("Cannot open file %s", c.String("logfile"))) | ||||||
|  | 	} | ||||||
|  | 	defer f.Close() | ||||||
|  | 
 | ||||||
|  | 	pathMap := lfshook.PathMap{ | ||||||
|  | 		logrus.InfoLevel:  c.String("logfile"), | ||||||
|  | 		logrus.ErrorLevel: c.String("logfile"), | ||||||
|  | 		logrus.FatalLevel: c.String("logfile"), | ||||||
|  | 		logrus.PanicLevel: c.String("logfile"), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	Log.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{})) | ||||||
|  | 	protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{})) | ||||||
|  | 
 | ||||||
|  | 	flags := make(map[string]interface{}) | ||||||
|  | 	envs := make(map[string]string) | ||||||
|  | 
 | ||||||
|  | 	for _, flag := range c.LocalFlagNames() { | ||||||
|  | 		flags[flag] = c.Generic(flag) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Find env variables for Warp
 | ||||||
|  | 	for _, env := range os.Environ() { | ||||||
|  | 		// All Warp env variables start with TUNNEL_
 | ||||||
|  | 		if strings.Contains(env, "TUNNEL_") { | ||||||
|  | 			vars := strings.Split(env, "=") | ||||||
|  | 			if len(vars) == 2 { | ||||||
|  | 				envs[vars[0]] = vars[1] | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	Log.Infof("Warp build and runtime configuration: %+v", BuildAndRuntimeInfo{ | ||||||
|  | 		GoOS:        runtime.GOOS, | ||||||
|  | 		GoVersion:   runtime.Version(), | ||||||
|  | 		GoArch:      runtime.GOARCH, | ||||||
|  | 		WarpVersion: Version, | ||||||
|  | 		WarpFlags:   flags, | ||||||
|  | 		WarpEnvs:    envs, | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -9,7 +9,7 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"os" | 	"os" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| 	cli "gopkg.in/urfave/cli.v2" | 	cli "gopkg.in/urfave/cli.v2" | ||||||
| 
 | 
 | ||||||
| 	"golang.org/x/sys/windows/svc" | 	"golang.org/x/sys/windows/svc" | ||||||
|  |  | ||||||
|  | @ -6,7 +6,7 @@ import ( | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| 	"golang.org/x/net/http2" | 	"golang.org/x/net/http2" | ||||||
| 	"golang.org/x/net/http2/hpack" | 	"golang.org/x/net/http2/hpack" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | @ -6,13 +6,14 @@ import ( | ||||||
| 	"io" | 	"io" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
|  | 	"net" | ||||||
| 	"os" | 	"os" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestMain(m *testing.M) { | func TestMain(m *testing.M) { | ||||||
|  | @ -25,25 +26,20 @@ func TestMain(m *testing.M) { | ||||||
| type DefaultMuxerPair struct { | type DefaultMuxerPair struct { | ||||||
| 	OriginMuxConfig MuxerConfig | 	OriginMuxConfig MuxerConfig | ||||||
| 	OriginMux       *Muxer | 	OriginMux       *Muxer | ||||||
| 	OriginWriter    *io.PipeWriter | 	OriginConn      net.Conn | ||||||
| 	OriginReader    *io.PipeReader |  | ||||||
| 	EdgeMuxConfig   MuxerConfig | 	EdgeMuxConfig   MuxerConfig | ||||||
| 	EdgeMux         *Muxer | 	EdgeMux         *Muxer | ||||||
| 	EdgeWriter      *io.PipeWriter | 	EdgeConn        net.Conn | ||||||
| 	EdgeReader      *io.PipeReader |  | ||||||
| 	doneC           chan struct{} | 	doneC           chan struct{} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewDefaultMuxerPair() *DefaultMuxerPair { | func NewDefaultMuxerPair() *DefaultMuxerPair { | ||||||
| 	originReader, edgeWriter := io.Pipe() | 	origin, edge := net.Pipe() | ||||||
| 	edgeReader, originWriter := io.Pipe() |  | ||||||
| 	return &DefaultMuxerPair{ | 	return &DefaultMuxerPair{ | ||||||
| 		OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"}, | 		OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"}, | ||||||
| 		OriginWriter:    originWriter, | 		OriginConn:      origin, | ||||||
| 		OriginReader:    originReader, |  | ||||||
| 		EdgeMuxConfig:   MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"}, | 		EdgeMuxConfig:   MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"}, | ||||||
| 		EdgeWriter:      edgeWriter, | 		EdgeConn:        edge, | ||||||
| 		EdgeReader:      edgeReader, |  | ||||||
| 		doneC:           make(chan struct{}), | 		doneC:           make(chan struct{}), | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | @ -53,12 +49,12 @@ func (p *DefaultMuxerPair) Handshake(t *testing.T) { | ||||||
| 	originErrC := make(chan error) | 	originErrC := make(chan error) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		var err error | 		var err error | ||||||
| 		p.EdgeMux, err = Handshake(p.EdgeWriter, p.EdgeReader, p.EdgeMuxConfig) | 		p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig) | ||||||
| 		edgeErrC <- err | 		edgeErrC <- err | ||||||
| 	}() | 	}() | ||||||
| 	go func() { | 	go func() { | ||||||
| 		var err error | 		var err error | ||||||
| 		p.OriginMux, err = Handshake(p.OriginWriter, p.OriginReader, p.OriginMuxConfig) | 		p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig) | ||||||
| 		originErrC <- err | 		originErrC <- err | ||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
|  | @ -120,8 +116,8 @@ func (p *DefaultMuxerPair) Wait(t *testing.T) { | ||||||
| func TestHandshake(t *testing.T) { | func TestHandshake(t *testing.T) { | ||||||
| 	muxPair := NewDefaultMuxerPair() | 	muxPair := NewDefaultMuxerPair() | ||||||
| 	muxPair.Handshake(t) | 	muxPair.Handshake(t) | ||||||
| 	AssertIfPipeReadable(t, muxPair.OriginReader) | 	AssertIfPipeReadable(t, muxPair.OriginConn) | ||||||
| 	AssertIfPipeReadable(t, muxPair.EdgeReader) | 	AssertIfPipeReadable(t, muxPair.EdgeConn) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSingleStream(t *testing.T) { | func TestSingleStream(t *testing.T) { | ||||||
|  | @ -145,7 +141,7 @@ func TestSingleStream(t *testing.T) { | ||||||
| 		stream.Write(buf) | 		stream.Write(buf) | ||||||
| 		// after this receive, the edge closed the stream
 | 		// after this receive, the edge closed the stream
 | ||||||
| 		<-closeC | 		<-closeC | ||||||
| 		n, err := stream.Read(buf) | 		n, err := io.ReadFull(stream, buf) | ||||||
| 		if n > 0 { | 		if n > 0 { | ||||||
| 			t.Fatalf("read %d bytes after EOF", n) | 			t.Fatalf("read %d bytes after EOF", n) | ||||||
| 		} | 		} | ||||||
|  | @ -173,7 +169,7 @@ func TestSingleStream(t *testing.T) { | ||||||
| 		t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value) | 		t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value) | ||||||
| 	} | 	} | ||||||
| 	responseBody := make([]byte, 11) | 	responseBody := make([]byte, 11) | ||||||
| 	n, err := stream.Read(responseBody) | 	n, err := io.ReadFull(stream, responseBody) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("error from (*MuxedStream).Read: %s", err) | 		t.Fatalf("error from (*MuxedStream).Read: %s", err) | ||||||
| 	} | 	} | ||||||
|  | @ -243,7 +239,7 @@ func TestSingleStreamLargeResponseBody(t *testing.T) { | ||||||
| 		t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value) | 		t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value) | ||||||
| 	} | 	} | ||||||
| 	responseBody := make([]byte, bodySize) | 	responseBody := make([]byte, bodySize) | ||||||
| 	n, err := stream.Read(responseBody) | 	n, err := io.ReadFull(stream, responseBody) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("error from (*MuxedStream).Read: %s", err) | 		t.Fatalf("error from (*MuxedStream).Read: %s", err) | ||||||
| 	} | 	} | ||||||
|  | @ -302,7 +298,7 @@ func TestMultipleStreams(t *testing.T) { | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 			responseBody := make([]byte, 2) | 			responseBody := make([]byte, 2) | ||||||
| 			n, err := stream.Read(responseBody) | 			n, err := io.ReadFull(stream, responseBody) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err) | 				errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err) | ||||||
| 				return | 				return | ||||||
|  | @ -392,7 +388,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) { | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			responseBody := make([]byte, responseSizes[(stream.streamID-2)/2]) | 			responseBody := make([]byte, responseSizes[(stream.streamID-2)/2]) | ||||||
| 			n, err := stream.Read(responseBody) | 			n, err := io.ReadFull(stream, responseBody) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err) | 				errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err) | ||||||
| 				return | 				return | ||||||
|  | @ -451,7 +447,7 @@ func TestGracefulShutdown(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| 	responseBody := make([]byte, len(responseBuf)) | 	responseBody := make([]byte, len(responseBuf)) | ||||||
| 	log.Debugf("Waiting for %d bytes", len(responseBuf)) | 	log.Debugf("Waiting for %d bytes", len(responseBuf)) | ||||||
| 	n, err := stream.Read(responseBody) | 	n, err := io.ReadFull(stream, responseBody) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err) | 		t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err) | ||||||
| 	} | 	} | ||||||
|  | @ -498,13 +494,13 @@ func TestUnexpectedShutdown(t *testing.T) { | ||||||
| 		nil, | 		nil, | ||||||
| 	) | 	) | ||||||
| 	// Close the underlying connection before telling the origin to write.
 | 	// Close the underlying connection before telling the origin to write.
 | ||||||
| 	muxPair.EdgeReader.Close() | 	muxPair.EdgeConn.Close() | ||||||
| 	close(sendC) | 	close(sendC) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("error in OpenStream: %s", err) | 		t.Fatalf("error in OpenStream: %s", err) | ||||||
| 	} | 	} | ||||||
| 	responseBody := make([]byte, len(responseBuf)) | 	responseBody := make([]byte, len(responseBuf)) | ||||||
| 	n, err := stream.Read(responseBody) | 	n, err := io.ReadFull(stream, responseBody) | ||||||
| 	if err != io.EOF { | 	if err != io.EOF { | ||||||
| 		t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err) | 		t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err) | ||||||
| 	} | 	} | ||||||
|  | @ -545,14 +541,14 @@ func TestOpenAfterDisconnect(t *testing.T) { | ||||||
| 		switch i { | 		switch i { | ||||||
| 		case 0: | 		case 0: | ||||||
| 			// Close both directions of the connection to cause EOF on both peers.
 | 			// Close both directions of the connection to cause EOF on both peers.
 | ||||||
| 			muxPair.OriginReader.Close() | 			muxPair.OriginConn.Close() | ||||||
| 			muxPair.OriginWriter.Close() | 			muxPair.EdgeConn.Close() | ||||||
| 		case 1: | 		case 1: | ||||||
| 			// Close origin reader (edge writer) to cause EOF on origin only.
 | 			// Close origin conn to cause EOF on origin first.
 | ||||||
| 			muxPair.OriginReader.Close() | 			muxPair.OriginConn.Close() | ||||||
| 		case 2: | 		case 2: | ||||||
| 			// Close origin writer (edge reader) to cause EOF on edge only.
 | 			// Close edge conn to cause EOF on edge first.
 | ||||||
| 			muxPair.OriginWriter.Close() | 			muxPair.EdgeConn.Close() | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		_, err := muxPair.EdgeMux.OpenStream( | 		_, err := muxPair.EdgeMux.OpenStream( | ||||||
|  | @ -623,7 +619,7 @@ func TestHPACK(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) { | func AssertIfPipeReadable(t *testing.T, pipe io.ReadCloser) { | ||||||
| 	errC := make(chan error) | 	errC := make(chan error) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		b := []byte{0} | 		b := []byte{0} | ||||||
|  | @ -640,7 +636,5 @@ func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) { | ||||||
| 		} | 		} | ||||||
| 	case <-time.After(100 * time.Millisecond): | 	case <-time.After(100 * time.Millisecond): | ||||||
| 		// nothing to read
 | 		// nothing to read
 | ||||||
| 		pipe.Close() |  | ||||||
| 		<-errC |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,6 +1,7 @@ | ||||||
| package h2mux | package h2mux | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"io" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
|  | @ -63,3 +64,27 @@ func TestFlowControlSingleStream(t *testing.T) { | ||||||
| 	assert.Equal(t, testWindowSize<<2, stream.receiveWindow) | 	assert.Equal(t, testWindowSize<<2, stream.receiveWindow) | ||||||
| 	assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax) | 	assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestMuxedStreamEOF(t *testing.T) { | ||||||
|  | 	for i := 0; i < 4096; i++ { | ||||||
|  | 		readyList := NewReadyList() | ||||||
|  | 		stream := &MuxedStream{ | ||||||
|  | 			streamID:         1, | ||||||
|  | 			readBuffer:       NewSharedBuffer(), | ||||||
|  | 			receiveWindow:    65536, | ||||||
|  | 			receiveWindowMax: 65536, | ||||||
|  | 			sendWindow:       65536, | ||||||
|  | 			readyList:        readyList, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		go func() { stream.Close() }() | ||||||
|  | 		n, err := stream.Read([]byte{0}) | ||||||
|  | 		assert.Equal(t, io.EOF, err) | ||||||
|  | 		assert.Equal(t, 0, n) | ||||||
|  | 		// Write comes after read, because write buffers data before it is flushed. It wouldn't know about EOF
 | ||||||
|  | 		// until some time later. Calling read first forces it to know about EOF now.
 | ||||||
|  | 		n, err = stream.Write([]byte{1}) | ||||||
|  | 		assert.Equal(t, io.EOF, err) | ||||||
|  | 		assert.Equal(t, 0, n) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -6,7 +6,7 @@ import ( | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| 	"golang.org/x/net/http2" | 	"golang.org/x/net/http2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -6,7 +6,7 @@ import ( | ||||||
| 	"io" | 	"io" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| 	"golang.org/x/net/http2" | 	"golang.org/x/net/http2" | ||||||
| 	"golang.org/x/net/http2/hpack" | 	"golang.org/x/net/http2/hpack" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | @ -21,7 +21,7 @@ func NewSharedBuffer() *SharedBuffer { | ||||||
| func (s *SharedBuffer) Read(p []byte) (n int, err error) { | func (s *SharedBuffer) Read(p []byte) (n int, err error) { | ||||||
| 	totalRead := 0 | 	totalRead := 0 | ||||||
| 	s.cond.L.Lock() | 	s.cond.L.Lock() | ||||||
| 	for totalRead < len(p) { | 	for totalRead == 0 { | ||||||
| 		n, err = s.buffer.Read(p[totalRead:]) | 		n, err = s.buffer.Read(p[totalRead:]) | ||||||
| 		totalRead += n | 		totalRead += n | ||||||
| 		if err == io.EOF { | 		if err == io.EOF { | ||||||
|  | @ -29,6 +29,9 @@ func (s *SharedBuffer) Read(p []byte) (n int, err error) { | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			err = nil | 			err = nil | ||||||
|  | 			if n > 0 { | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
| 			s.cond.Wait() | 			s.cond.Wait() | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -6,6 +6,8 @@ import ( | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) { | func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) { | ||||||
|  | @ -29,30 +31,35 @@ func TestSharedBuffer(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| func TestSharedBufferBlockingRead(t *testing.T) { | func TestSharedBufferBlockingRead(t *testing.T) { | ||||||
| 	b := NewSharedBuffer() | 	b := NewSharedBuffer() | ||||||
| 	testData := []byte("Hello world") | 	testData1 := []byte("Hello") | ||||||
|  | 	testData2 := []byte(" world") | ||||||
| 	result := make(chan []byte) | 	result := make(chan []byte) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		bytesRead := make([]byte, len(testData)) | 		bytesRead := make([]byte, len(testData1)+len(testData2)) | ||||||
| 		AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead)) | 		nRead, err := b.Read(bytesRead) | ||||||
| 		result <- bytesRead | 		AssertIOReturnIsGood(t, len(testData1))(nRead, err) | ||||||
|  | 		result <- bytesRead[:nRead] | ||||||
|  | 		nRead, err = b.Read(bytesRead) | ||||||
|  | 		AssertIOReturnIsGood(t, len(testData2))(nRead, err) | ||||||
|  | 		result <- bytesRead[:nRead] | ||||||
| 	}() | 	}() | ||||||
|  | 	time.Sleep(time.Millisecond * 250) | ||||||
| 	select { | 	select { | ||||||
| 	case <-result: | 	case <-result: | ||||||
| 		t.Fatalf("read returned early") | 		t.Fatalf("read returned early") | ||||||
| 	default: | 	default: | ||||||
| 	} | 	} | ||||||
| 	AssertIOReturnIsGood(t, 5)(b.Write(testData[:5])) | 	AssertIOReturnIsGood(t, len(testData1))(b.Write([]byte(testData1))) | ||||||
| 	select { |  | ||||||
| 	case <-result: |  | ||||||
| 		t.Fatalf("read returned early") |  | ||||||
| 	default: |  | ||||||
| 	} |  | ||||||
| 	AssertIOReturnIsGood(t, len(testData)-5)(b.Write(testData[5:])) |  | ||||||
| 	select { | 	select { | ||||||
| 	case r := <-result: | 	case r := <-result: | ||||||
| 		if string(r) != string(testData) { | 		assert.Equal(t, testData1, r) | ||||||
| 			t.Fatalf("expected read to return %s, got %s", testData, r) | 	case <-time.After(time.Second): | ||||||
|  | 		t.Fatalf("read timed out") | ||||||
| 	} | 	} | ||||||
|  | 	AssertIOReturnIsGood(t, len(testData2))(b.Write([]byte(testData2))) | ||||||
|  | 	select { | ||||||
|  | 	case r := <-result: | ||||||
|  | 		assert.Equal(t, testData2, r) | ||||||
| 	case <-time.After(time.Second): | 	case <-time.After(time.Second): | ||||||
| 		t.Fatalf("read timed out") | 		t.Fatalf("read timed out") | ||||||
| 	} | 	} | ||||||
|  | @ -85,7 +92,7 @@ func TestSharedBufferConcurrentReadWrite(t *testing.T) { | ||||||
| 		// Change block sizes in opposition to the write thread, to test blocking for new data.
 | 		// Change block sizes in opposition to the write thread, to test blocking for new data.
 | ||||||
| 		for blockSize := 256; blockSize > 0; blockSize-- { | 		for blockSize := 256; blockSize > 0; blockSize-- { | ||||||
| 			for i := 0; i < 256; i++ { | 			for i := 0; i < 256; i++ { | ||||||
| 				n, err := b.Read(block[:blockSize]) | 				n, err := io.ReadFull(b, block[:blockSize]) | ||||||
| 				if n != blockSize || err != nil { | 				if n != blockSize || err != nil { | ||||||
| 					t.Fatalf("read error: %d %s", n, err) | 					t.Fatalf("read error: %d %s", n, err) | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
|  | @ -11,9 +11,9 @@ import ( | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| 	"golang.org/x/net/trace" | 	"golang.org/x/net/trace" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" |  | ||||||
| 	"github.com/prometheus/client_golang/prometheus" | 	"github.com/prometheus/client_golang/prometheus" | ||||||
| 	"github.com/prometheus/client_golang/prometheus/promhttp" | 	"github.com/prometheus/client_golang/prometheus/promhttp" | ||||||
|  | 	log "github.com/sirupsen/logrus" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| const ( | const ( | ||||||
|  |  | ||||||
|  | @ -5,7 +5,6 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/cloudflare/cloudflare-warp/h2mux" | 	"github.com/cloudflare/cloudflare-warp/h2mux" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" |  | ||||||
| 	"github.com/prometheus/client_golang/prometheus" | 	"github.com/prometheus/client_golang/prometheus" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -249,7 +248,7 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) { | ||||||
| 	if _, ok := t.concurrentRequests[connectionID]; ok { | 	if _, ok := t.concurrentRequests[connectionID]; ok { | ||||||
| 		t.concurrentRequests[connectionID] -= 1 | 		t.concurrentRequests[connectionID] -= 1 | ||||||
| 	} else { | 	} else { | ||||||
| 		log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.") | 		Log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.") | ||||||
| 	} | 	} | ||||||
| 	t.concurrentRequestsLock.Unlock() | 	t.concurrentRequestsLock.Unlock() | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -5,7 +5,6 @@ import ( | ||||||
| 	"net" | 	"net" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" |  | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -73,7 +72,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err | ||||||
| 		case tunnelError := <-s.tunnelErrors: | 		case tunnelError := <-s.tunnelErrors: | ||||||
| 			tunnelsActive-- | 			tunnelsActive-- | ||||||
| 			if tunnelError.err != nil { | 			if tunnelError.err != nil { | ||||||
| 				log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error") | 				Log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error") | ||||||
| 				tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) | 				tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) | ||||||
| 				s.waitForNextTunnel(tunnelError.index) | 				s.waitForNextTunnel(tunnelError.index) | ||||||
| 				if backoffTimer == nil { | 				if backoffTimer == nil { | ||||||
|  | @ -107,10 +106,10 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err | ||||||
| 			s.lastResolve = time.Now() | 			s.lastResolve = time.Now() | ||||||
| 			s.resolverC = nil | 			s.resolverC = nil | ||||||
| 			if result.err == nil { | 			if result.err == nil { | ||||||
| 				log.Debug("Service discovery refresh complete") | 				Log.Debug("Service discovery refresh complete") | ||||||
| 				s.edgeIPs = result.edgeIPs | 				s.edgeIPs = result.edgeIPs | ||||||
| 			} else { | 			} else { | ||||||
| 				log.WithError(result.err).Error("Service discovery error") | 				Log.WithError(result.err).Error("Service discovery error") | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | @ -120,12 +119,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err | ||||||
| func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error { | func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error { | ||||||
| 	edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs) | 	edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Infof("ResolveEdgeIPs err") | 		Log.Infof("ResolveEdgeIPs err") | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	s.edgeIPs = edgeIPs | 	s.edgeIPs = edgeIPs | ||||||
| 	if s.config.HAConnections > len(edgeIPs) { | 	if s.config.HAConnections > len(edgeIPs) { | ||||||
| 		log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs)) | 		Log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs)) | ||||||
| 		s.config.HAConnections = len(edgeIPs) | 		s.config.HAConnections = len(edgeIPs) | ||||||
| 	} | 	} | ||||||
| 	s.lastResolve = time.Now() | 	s.lastResolve = time.Now() | ||||||
|  |  | ||||||
|  | @ -19,14 +19,17 @@ import ( | ||||||
| 	"github.com/cloudflare/cloudflare-warp/tunnelrpc" | 	"github.com/cloudflare/cloudflare-warp/tunnelrpc" | ||||||
| 	tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" | 	tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" | ||||||
| 	"github.com/cloudflare/cloudflare-warp/validation" | 	"github.com/cloudflare/cloudflare-warp/validation" | ||||||
|  | 	"github.com/cloudflare/cloudflare-warp/websocket" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" |  | ||||||
| 	raven "github.com/getsentry/raven-go" | 	raven "github.com/getsentry/raven-go" | ||||||
| 	"github.com/pkg/errors" | 	"github.com/pkg/errors" | ||||||
| 	_ "github.com/prometheus/client_golang/prometheus" | 	_ "github.com/prometheus/client_golang/prometheus" | ||||||
|  | 	"github.com/sirupsen/logrus" | ||||||
| 	rpc "zombiezen.com/go/capnproto2/rpc" | 	rpc "zombiezen.com/go/capnproto2/rpc" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | var Log *logrus.Logger | ||||||
|  | 
 | ||||||
| const ( | const ( | ||||||
| 	dialTimeout = 15 * time.Second | 	dialTimeout = 15 * time.Second | ||||||
| 
 | 
 | ||||||
|  | @ -40,6 +43,7 @@ type TunnelConfig struct { | ||||||
| 	Hostname          string | 	Hostname          string | ||||||
| 	OriginCert        []byte | 	OriginCert        []byte | ||||||
| 	TlsConfig         *tls.Config | 	TlsConfig         *tls.Config | ||||||
|  | 	ClientTlsConfig   *tls.Config | ||||||
| 	Retries           uint | 	Retries           uint | ||||||
| 	HeartbeatInterval time.Duration | 	HeartbeatInterval time.Duration | ||||||
| 	MaxHeartbeats     uint64 | 	MaxHeartbeats     uint64 | ||||||
|  | @ -51,7 +55,9 @@ type TunnelConfig struct { | ||||||
| 	HTTPTransport     http.RoundTripper | 	HTTPTransport     http.RoundTripper | ||||||
| 	Metrics           *TunnelMetrics | 	Metrics           *TunnelMetrics | ||||||
| 	MetricsUpdateFreq time.Duration | 	MetricsUpdateFreq time.Duration | ||||||
| 	ProtocolLogger    *log.Logger | 	ProtocolLogger    *logrus.Logger | ||||||
|  | 	Logger            *logrus.Logger | ||||||
|  | 	IsAutoupdated     bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type dialError struct { | type dialError struct { | ||||||
|  | @ -87,14 +93,16 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str | ||||||
| 		Version:              c.ReportedVersion, | 		Version:              c.ReportedVersion, | ||||||
| 		OS:                   fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH), | 		OS:                   fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH), | ||||||
| 		ExistingTunnelPolicy: policy, | 		ExistingTunnelPolicy: policy, | ||||||
| 		PoolID:               c.LBPool, | 		PoolName:             c.LBPool, | ||||||
| 		Tags:                 c.Tags, | 		Tags:                 c.Tags, | ||||||
| 		ConnectionID:         connectionID, | 		ConnectionID:         connectionID, | ||||||
| 		OriginLocalIP:        OriginLocalIP, | 		OriginLocalIP:        OriginLocalIP, | ||||||
|  | 		IsAutoupdated:        c.IsAutoupdated, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error { | func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error { | ||||||
|  | 	Log = config.Logger | ||||||
| 	ctx, cancel := context.WithCancel(context.Background()) | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		<-shutdownC | 		<-shutdownC | ||||||
|  | @ -129,7 +137,7 @@ func ServeTunnelLoop(ctx context.Context, config *TunnelConfig, addr *net.TCPAdd | ||||||
| 		err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff) | 		err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff) | ||||||
| 		if recoverable { | 		if recoverable { | ||||||
| 			if duration, ok := backoff.GetBackoffDuration(ctx); ok { | 			if duration, ok := backoff.GetBackoffDuration(ctx); ok { | ||||||
| 				log.Infof("Retrying in %s seconds", duration) | 				Log.Infof("Retrying in %s seconds", duration) | ||||||
| 				backoff.Backoff(ctx) | 				backoff.Backoff(ctx) | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
|  | @ -162,11 +170,10 @@ func ServeTunnel( | ||||||
| 	// Returns error from parsing the origin URL or handshake errors
 | 	// Returns error from parsing the origin URL or handshake errors
 | ||||||
| 	handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID) | 	handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		errLog := log.WithError(err) | 		errLog := Log.WithError(err) | ||||||
| 		switch err.(type) { | 		switch err.(type) { | ||||||
| 		case dialError: | 		case dialError: | ||||||
| 			errLog.Error("Unable to dial edge") | 			errLog.Error("Unable to dial edge") | ||||||
| 			return err, false |  | ||||||
| 		case h2mux.MuxerHandshakeError: | 		case h2mux.MuxerHandshakeError: | ||||||
| 			errLog.Error("Handshake failed with edge server") | 			errLog.Error("Handshake failed with edge server") | ||||||
| 		default: | 		default: | ||||||
|  | @ -207,24 +214,21 @@ func ServeTunnel( | ||||||
| 	registerErr := <-registerErrC | 	registerErr := <-registerErrC | ||||||
| 	wg.Wait() | 	wg.Wait() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Error("Tunnel error") | 		Log.WithError(err).Error("Tunnel error") | ||||||
| 		return err, true | 		return err, true | ||||||
| 	} | 	} | ||||||
| 	if registerErr != nil { | 	if registerErr != nil { | ||||||
| 		// Don't retry on errors like entitlement failure or version too old
 | 		// Don't retry on errors like entitlement failure or version too old
 | ||||||
| 		if e, ok := registerErr.(printableRegisterTunnelError); ok { | 		if e, ok := registerErr.(printableRegisterTunnelError); ok { | ||||||
| 			log.Error(e) | 			Log.Error(e) | ||||||
| 			if e.permanent { | 			return e.cause, !e.permanent | ||||||
| 				return e, false |  | ||||||
| 			} |  | ||||||
| 			return e.cause, true |  | ||||||
| 		} else if e, ok := registerErr.(dupConnRegisterTunnelError); ok { | 		} else if e, ok := registerErr.(dupConnRegisterTunnelError); ok { | ||||||
| 			log.Info("Already connected to this server, selecting a different one") | 			Log.Info("Already connected to this server, selecting a different one") | ||||||
| 			return e, true | 			return e, true | ||||||
| 		} | 		} | ||||||
| 		// Only log errors to Sentry that may have been caused by the client side, to reduce dupes
 | 		// Only log errors to Sentry that may have been caused by the client side, to reduce dupes
 | ||||||
| 		raven.CaptureError(registerErr, nil) | 		raven.CaptureError(registerErr, nil) | ||||||
| 		log.Error("Cannot register") | 		Log.Error("Cannot register") | ||||||
| 		return err, true | 		return err, true | ||||||
| 	} | 	} | ||||||
| 	return nil, false | 	return nil, false | ||||||
|  | @ -241,7 +245,7 @@ func IsRPCStreamResponse(headers []h2mux.Header) bool { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error { | func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error { | ||||||
| 	logger := log.WithField("subsystem", "rpc") | 	logger := Log.WithField("subsystem", "rpc") | ||||||
| 	logger.Debug("initiating RPC stream") | 	logger.Debug("initiating RPC stream") | ||||||
| 	stream, err := muxer.OpenStream([]h2mux.Header{ | 	stream, err := muxer.OpenStream([]h2mux.Header{ | ||||||
| 		{Name: ":method", Value: "RPC"}, | 		{Name: ":method", Value: "RPC"}, | ||||||
|  | @ -292,11 +296,11 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Infof("Registered at %s", registration.Url) | 	Log.Infof("Registered at %s", registration.Url) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func LogServerInfo(logger *log.Entry, | func LogServerInfo(logger *logrus.Entry, | ||||||
| 	promise tunnelrpc.ServerInfo_Promise, | 	promise tunnelrpc.ServerInfo_Promise, | ||||||
| 	connectionID uint8, | 	connectionID uint8, | ||||||
| 	metrics *TunnelMetrics, | 	metrics *TunnelMetrics, | ||||||
|  | @ -311,7 +315,7 @@ func LogServerInfo(logger *log.Entry, | ||||||
| 		logger.WithError(err).Warn("Failed to retrieve server information") | 		logger.WithError(err).Warn("Failed to retrieve server information") | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	log.Infof("Connected to %s", serverInfo.LocationName) | 	Log.Infof("Connected to %s", serverInfo.LocationName) | ||||||
| 	metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName) | 	metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -356,6 +360,7 @@ type TunnelHandler struct { | ||||||
| 	originUrl  string | 	originUrl  string | ||||||
| 	muxer      *h2mux.Muxer | 	muxer      *h2mux.Muxer | ||||||
| 	httpClient http.RoundTripper | 	httpClient http.RoundTripper | ||||||
|  | 	tlsConfig  *tls.Config | ||||||
| 	tags       []tunnelpogs.Tag | 	tags       []tunnelpogs.Tag | ||||||
| 	metrics    *TunnelMetrics | 	metrics    *TunnelMetrics | ||||||
| 	// connectionID is only used by metrics, and prometheus requires labels to be string
 | 	// connectionID is only used by metrics, and prometheus requires labels to be string
 | ||||||
|  | @ -373,6 +378,7 @@ func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, co | ||||||
| 	h := &TunnelHandler{ | 	h := &TunnelHandler{ | ||||||
| 		originUrl:    url, | 		originUrl:    url, | ||||||
| 		httpClient:   config.HTTPTransport, | 		httpClient:   config.HTTPTransport, | ||||||
|  | 		tlsConfig:    config.ClientTlsConfig, | ||||||
| 		tags:         config.Tags, | 		tags:         config.Tags, | ||||||
| 		metrics:      config.Metrics, | 		metrics:      config.Metrics, | ||||||
| 		connectionID: uint8ToString(connectionID), | 		connectionID: uint8ToString(connectionID), | ||||||
|  | @ -422,29 +428,45 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { | ||||||
| 	h.metrics.incrementRequests(h.connectionID) | 	h.metrics.incrementRequests(h.connectionID) | ||||||
| 	req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) | 	req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Panic("Unexpected error from http.NewRequest") | 		Log.WithError(err).Panic("Unexpected error from http.NewRequest") | ||||||
| 	} | 	} | ||||||
| 	err = H2RequestHeadersToH1Request(stream.Headers, req) | 	err = H2RequestHeadersToH1Request(stream.Headers, req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.WithError(err).Error("invalid request received") | 		Log.WithError(err).Error("invalid request received") | ||||||
| 	} | 	} | ||||||
| 	h.AppendTagHeaders(req) | 	h.AppendTagHeaders(req) | ||||||
|  | 
 | ||||||
|  | 	if websocket.IsWebSocketUpgrade(req) { | ||||||
|  | 		conn, response, err := websocket.ClientConnect(req, h.tlsConfig) | ||||||
|  | 		if err != nil { | ||||||
|  | 			h.logError(stream, err) | ||||||
|  | 		} else { | ||||||
|  | 			stream.WriteHeaders(H1ResponseToH2Response(response)) | ||||||
|  | 			defer conn.Close() | ||||||
|  | 			websocket.Stream(conn.UnderlyingConn(), stream) | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
| 		response, err := h.httpClient.RoundTrip(req) | 		response, err := h.httpClient.RoundTrip(req) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 		log.WithError(err).Error("HTTP request error") | 			h.logError(stream, err) | ||||||
| 		stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}}) |  | ||||||
| 		stream.Write([]byte("502 Bad Gateway")) |  | ||||||
| 		h.metrics.incrementResponses(h.connectionID, "502") |  | ||||||
| 		} else { | 		} else { | ||||||
| 			defer response.Body.Close() | 			defer response.Body.Close() | ||||||
| 			stream.WriteHeaders(H1ResponseToH2Response(response)) | 			stream.WriteHeaders(H1ResponseToH2Response(response)) | ||||||
| 			io.Copy(stream, response.Body) | 			io.Copy(stream, response.Body) | ||||||
| 			h.metrics.incrementResponses(h.connectionID, "200") | 			h.metrics.incrementResponses(h.connectionID, "200") | ||||||
| 		} | 		} | ||||||
|  | 	} | ||||||
| 	h.metrics.decrementConcurrentRequests(h.connectionID) | 	h.metrics.decrementConcurrentRequests(h.connectionID) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) { | ||||||
|  | 	Log.WithError(err).Error("HTTP request error") | ||||||
|  | 	stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}}) | ||||||
|  | 	stream.Write([]byte("502 Bad Gateway")) | ||||||
|  | 	h.metrics.incrementResponses(h.connectionID, "502") | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (h *TunnelHandler) UpdateMetrics() { | func (h *TunnelHandler) UpdateMetrics() { | ||||||
| 	flowCtlMetrics := h.muxer.FlowControlMetrics() | 	flowCtlMetrics := h.muxer.FlowControlMetrics() | ||||||
| 	h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics) | 	h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics) | ||||||
|  |  | ||||||
|  | @ -0,0 +1,95 @@ | ||||||
|  | package tlsconfig | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/x509" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // TODO: remove the Origin CA root certs when migrated to Authenticated Origin Pull certs
 | ||||||
|  | const cloudflareRootCA = ` | ||||||
|  | Issuer: C=US, ST=California, L=San Francisco, O=CloudFlare, Inc., OU=CloudFlare Origin SSL ECC Certificate Authority | ||||||
|  | -----BEGIN CERTIFICATE----- | ||||||
|  | MIICiDCCAi6gAwIBAgIUXZP3MWb8MKwBE1Qbawsp1sfA/Y4wCgYIKoZIzj0EAwIw | ||||||
|  | gY8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1T | ||||||
|  | YW4gRnJhbmNpc2NvMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTgwNgYDVQQL | ||||||
|  | Ey9DbG91ZEZsYXJlIE9yaWdpbiBTU0wgRUNDIENlcnRpZmljYXRlIEF1dGhvcml0 | ||||||
|  | eTAeFw0xNjAyMjIxODI0MDBaFw0yMTAyMjIwMDI0MDBaMIGPMQswCQYDVQQGEwJV | ||||||
|  | UzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEZ | ||||||
|  | MBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE4MDYGA1UECxMvQ2xvdWRGbGFyZSBP | ||||||
|  | cmlnaW4gU1NMIEVDQyBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkwWTATBgcqhkjOPQIB | ||||||
|  | BggqhkjOPQMBBwNCAASR+sGALuaGshnUbcxKry+0LEXZ4NY6JUAtSeA6g87K3jaA | ||||||
|  | xpIg9G50PokpfWkhbarLfpcZu0UAoYy2su0EhN7wo2YwZDAOBgNVHQ8BAf8EBAMC | ||||||
|  | AQYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUhTBdOypw1O3VkmcH/es5 | ||||||
|  | tBoOOKcwHwYDVR0jBBgwFoAUhTBdOypw1O3VkmcH/es5tBoOOKcwCgYIKoZIzj0E | ||||||
|  | AwIDSAAwRQIgEiIEHQr5UKma50D1WRMJBUSgjg24U8n8E2mfw/8UPz0CIQCr5V/e | ||||||
|  | mcifak4CQsr+DH4pn5SJD7JxtCG3YGswW8QZsw== | ||||||
|  | -----END CERTIFICATE----- | ||||||
|  | Issuer: C=US, O=CloudFlare, Inc., OU=CloudFlare Origin SSL Certificate Authority, L=San Francisco, ST=California | ||||||
|  | -----BEGIN CERTIFICATE----- | ||||||
|  | MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG | ||||||
|  | EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG | ||||||
|  | bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN | ||||||
|  | U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4 | ||||||
|  | NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv | ||||||
|  | dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl | ||||||
|  | cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG | ||||||
|  | A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB | ||||||
|  | AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem | ||||||
|  | ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ | ||||||
|  | p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz | ||||||
|  | /9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl | ||||||
|  | yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK | ||||||
|  | xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud | ||||||
|  | EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G | ||||||
|  | A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA | ||||||
|  | cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD | ||||||
|  | gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh | ||||||
|  | QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz | ||||||
|  | ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y | ||||||
|  | 4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX | ||||||
|  | Fu6q54beR89jDc+oABmOgg== | ||||||
|  | -----END CERTIFICATE----- | ||||||
|  | Issuer: C=US, O=CloudFlare, Inc., OU=Origin Pull, L=San Francisco, ST=California, CN=origin-pull.cloudflare.net | ||||||
|  | -----BEGIN CERTIFICATE----- | ||||||
|  | MIIGBjCCA/CgAwIBAgIIV5G6lVbCLmEwCwYJKoZIhvcNAQENMIGQMQswCQYDVQQG | ||||||
|  | EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjEUMBIGA1UECxMLT3JpZ2lu | ||||||
|  | IFB1bGwxFjAUBgNVBAcTDVNhbiBGcmFuY2lzY28xEzARBgNVBAgTCkNhbGlmb3Ju | ||||||
|  | aWExIzAhBgNVBAMTGm9yaWdpbi1wdWxsLmNsb3VkZmxhcmUubmV0MB4XDTE1MDEx | ||||||
|  | MzAyNDc1M1oXDTIwMDExMjAyNTI1M1owgZAxCzAJBgNVBAYTAlVTMRkwFwYDVQQK | ||||||
|  | ExBDbG91ZEZsYXJlLCBJbmMuMRQwEgYDVQQLEwtPcmlnaW4gUHVsbDEWMBQGA1UE | ||||||
|  | BxMNU2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTEjMCEGA1UEAxMa | ||||||
|  | b3JpZ2luLXB1bGwuY2xvdWRmbGFyZS5uZXQwggIiMA0GCSqGSIb3DQEBAQUAA4IC | ||||||
|  | DwAwggIKAoICAQDdsts6I2H5dGyn4adACQRXlfo0KmwsN7B5rxD8C5qgy6spyONr | ||||||
|  | WV0ecvdeGQfWa8Gy/yuTuOnsXfy7oyZ1dm93c3Mea7YkM7KNMc5Y6m520E9tHooc | ||||||
|  | f1qxeDpGSsnWc7HWibFgD7qZQx+T+yfNqt63vPI0HYBOYao6hWd3JQhu5caAcIS2 | ||||||
|  | ms5tzSSZVH83ZPe6Lkb5xRgLl3eXEFcfI2DjnlOtLFqpjHuEB3Tr6agfdWyaGEEi | ||||||
|  | lRY1IB3k6TfLTaSiX2/SyJ96bp92wvTSjR7USjDV9ypf7AD6u6vwJZ3bwNisNw5L | ||||||
|  | ptph0FBnc1R6nDoHmvQRoyytoe0rl/d801i9Nru/fXa+l5K2nf1koR3IX440Z2i9 | ||||||
|  | +Z4iVA69NmCbT4MVjm7K3zlOtwfI7i1KYVv+ATo4ycgBuZfY9f/2lBhIv7BHuZal | ||||||
|  | b9D+/EK8aMUfjDF4icEGm+RQfExv2nOpkR4BfQppF/dLmkYfjgtO1403X0ihkT6T | ||||||
|  | PYQdmYS6Jf53/KpqC3aA+R7zg2birtvprinlR14MNvwOsDOzsK4p8WYsgZOR4Qr2 | ||||||
|  | gAx+z2aVOs/87+TVOR0r14irQsxbg7uP2X4t+EXx13glHxwG+CnzUVycDLMVGvuG | ||||||
|  | aUgF9hukZxlOZnrl6VOf1fg0Caf3uvV8smOkVw6DMsGhBZSJVwao0UQNqQIDAQAB | ||||||
|  | o2YwZDAOBgNVHQ8BAf8EBAMCAAYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4E | ||||||
|  | FgQUQ1lLK2mLgOERM2pXzVc42p59xeswHwYDVR0jBBgwFoAUQ1lLK2mLgOERM2pX | ||||||
|  | zVc42p59xeswCwYJKoZIhvcNAQENA4ICAQDKDQM1qPRVP/4Gltz0D6OU6xezFBKr | ||||||
|  | LWtDoA1qW2F7pkiYawCP9MrDPDJsHy7dx+xw3bBZxOsK5PA/T7p1dqpEl6i8F692 | ||||||
|  | g//EuYOifLYw3ySPe3LRNhvPl/1f6Sn862VhPvLa8aQAAwR9e/CZvlY3fj+6G5ik
 | ||||||
|  | 3it7fikmKUsVnugNOkjmwI3hZqXfJNc7AtHDFw0mEOV0dSeAPTo95N9cxBbm9PKv | ||||||
|  | qAEmTEXp2trQ/RjJ/AomJyfA1BQjsD0j++DI3a9/BbDwWmr1lJciKxiNKaa0BRLB | ||||||
|  | dKMrYQD+PkPNCgEuojT+paLKRrMyFUzHSG1doYm46NE9/WARTh3sFUp1B7HZSBqA | ||||||
|  | kHleoB/vQ/mDuW9C3/8Jk2uRUdZxR+LoNZItuOjU8oTy6zpN1+GgSj7bHjiy9rfA | ||||||
|  | F+ehdrz+IOh80WIiqs763PGoaYUyzxLvVowLWNoxVVoc9G+PqFKqD988XlipHVB6 | ||||||
|  | Bz+1CD4D/bWrs3cC9+kk/jFmrrAymZlkFX8tDb5aXASSLJjUjcptci9SKqtI2h0J | ||||||
|  | wUGkD7+bQAr+7vr8/R+CBmNMe7csE8NeEX6lVMF7Dh0a1YKQa6hUN18bBuYgTMuT | ||||||
|  | QzMmZpRpIBB321ZBlcnlxiTJvWxvbCPHKHj20VwwAz7LONF59s84ZsOqfoBv8gKM | ||||||
|  | s0s5dsq5zpLeaw== | ||||||
|  | -----END CERTIFICATE-----` | ||||||
|  | 
 | ||||||
|  | func GetCloudflareRootCA() *x509.CertPool { | ||||||
|  | 	ca := x509.NewCertPool() | ||||||
|  | 	if !ca.AppendCertsFromPEM([]byte(cloudflareRootCA)) { | ||||||
|  | 		// should never happen
 | ||||||
|  | 		panic("failure loading Cloudflare origin CA pem") | ||||||
|  | 	} | ||||||
|  | 	return ca | ||||||
|  | } | ||||||
|  | @ -0,0 +1,50 @@ | ||||||
|  | package tlsconfig | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"crypto/x509" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	helloKey = ` | ||||||
|  | -----BEGIN EC PARAMETERS----- | ||||||
|  | BgUrgQQAIg== | ||||||
|  | -----END EC PARAMETERS----- | ||||||
|  | -----BEGIN EC PRIVATE KEY----- | ||||||
|  | MIGkAgEBBDAdyQBXfxTDCQSOT0HugmH9pVBtIw8t5dYvm6HxGlNq6P57v5GeN02Z | ||||||
|  | dH9FRl7+VSWgBwYFK4EEACKhZANiAATqpFzTxxV7D+/oqhKCTR6BEM9elTfKaRQE | ||||||
|  | FsLufcmaTMw/9tTwgpHKao/QsLKDTNbQhbSQLkcmpCQKlSGhl+pCrqNt/oYUAhav | ||||||
|  | UIwpwGiLCqGH/R2AqWLKRPOa/Rufs/U= | ||||||
|  | -----END EC PRIVATE KEY-----` | ||||||
|  | 
 | ||||||
|  | 	helloCRT = ` | ||||||
|  | -----BEGIN CERTIFICATE----- | ||||||
|  | MIICkDCCAhigAwIBAgIJAPtKfUjc2lwGMAkGByqGSM49BAEwgYoxCzAJBgNVBAYT | ||||||
|  | AlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQKDBBD | ||||||
|  | bG91ZGZsYXJlLCBJbmMuMT8wPQYDVQQDDDZDbG91ZGZsYXJlIEFyZ28gVHVubmVs | ||||||
|  | IFNhbXBsZSBIZWxsbyBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMjE1MjAxNjU5 | ||||||
|  | WhcNMjgwMjEzMjAxNjU5WjCBijELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVRleGFz | ||||||
|  | MQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENsb3VkZmxhcmUsIEluYy4xPzA9 | ||||||
|  | BgNVBAMMNkNsb3VkZmxhcmUgQXJnbyBUdW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZl | ||||||
|  | ciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49AgEGBSuBBAAiA2IABOqkXNPHFXsP7+iq | ||||||
|  | EoJNHoEQz16VN8ppFAQWwu59yZpMzD/21PCCkcpqj9CwsoNM1tCFtJAuRyakJAqV | ||||||
|  | IaGX6kKuo23+hhQCFq9QjCnAaIsKoYf9HYCpYspE85r9G5+z9aNJMEcwRQYDVR0R | ||||||
|  | BD4wPIIJbG9jYWxob3N0ggp3YXJwLWhlbGxvggt3YXJwMi1oZWxsb4cEfwAAAYcQ | ||||||
|  | AAAAAAAAAAAAAAAAAAAAATAJBgcqhkjOPQQBA2cAMGQCMHyVPufXZ6vQo6XRWRa0 | ||||||
|  | dAwtfgesOdZVP2Wt+t5v8jOIQQh1IQXYk5GtyoZGSObjhQIwd1fRgAyKXaZt+1DV | ||||||
|  | ZtHTdf8pMvESfJsSd8AB1eQ6q+pAiRUYyaxcE1Mlo2YY5o+g | ||||||
|  | -----END CERTIFICATE-----` | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func GetHelloCertificate() (tls.Certificate, error) { | ||||||
|  | 	return tls.X509KeyPair([]byte(helloCRT), []byte(helloKey)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func GetHelloCertificateX509() (*x509.Certificate, error) { | ||||||
|  | 	helloCertificate, err := GetHelloCertificate() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return x509.ParseCertificate(helloCertificate.Certificate[0]) | ||||||
|  | } | ||||||
|  | @ -6,8 +6,9 @@ import ( | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"crypto/x509" | 	"crypto/x509" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
|  | 	"net" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| 	cli "gopkg.in/urfave/cli.v2" | 	cli "gopkg.in/urfave/cli.v2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -60,3 +61,43 @@ func LoadCert(certPath string) *x509.CertPool { | ||||||
| 	} | 	} | ||||||
| 	return ca | 	return ca | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func LoadOriginCertsPool() *x509.CertPool { | ||||||
|  | 	// First, obtain the system certificate pool
 | ||||||
|  | 	certPool, systemCertPoolErr := x509.SystemCertPool() | ||||||
|  | 	if systemCertPoolErr != nil { | ||||||
|  | 		log.Warn("error obtaining the system certificates: %s", systemCertPoolErr) | ||||||
|  | 		certPool = x509.NewCertPool() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Next, append the Cloudflare CA pool into the system pool
 | ||||||
|  | 	if !certPool.AppendCertsFromPEM([]byte(cloudflareRootCA)) { | ||||||
|  | 		log.Warn("could not append the CF certificate to the system certificate pool") | ||||||
|  | 
 | ||||||
|  | 		if systemCertPoolErr != nil { // Obtaining both certificates failed; this is a fatal error
 | ||||||
|  | 			log.WithError(systemCertPoolErr).Fatalf("Error loading the certificate pool") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Finally, add the Hello certificate into the pool (since it's self-signed)
 | ||||||
|  | 	helloCertificate, err := GetHelloCertificateX509() | ||||||
|  | 	if err != nil { | ||||||
|  | 		log.Warn("error obtaining the Hello server certificate") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	certPool.AddCert(helloCertificate) | ||||||
|  | 
 | ||||||
|  | 	return certPool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config { | ||||||
|  | 	tlsConfig := CLIFlags{RootCA: "cacert"}.GetConfig(c) | ||||||
|  | 	if tlsConfig.RootCAs == nil { | ||||||
|  | 		tlsConfig.RootCAs = GetCloudflareRootCA() | ||||||
|  | 		tlsConfig.ServerName = "cftunnel.com" | ||||||
|  | 	} else if len(addrs) > 0 { | ||||||
|  | 		// Set for development environments and for testing specific origintunneld instances
 | ||||||
|  | 		tlsConfig.ServerName, _, _ = net.SplitHostPort(addrs[0]) | ||||||
|  | 	} | ||||||
|  | 	return tlsConfig | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -1,10 +1,9 @@ | ||||||
| package tunnelrpc | package tunnelrpc | ||||||
| 
 | 
 | ||||||
| //go:generate capnp compile -ogo -I./tunnelrpc/ tunnelrpc.capnp
 |  | ||||||
| 
 |  | ||||||
| import ( | import ( | ||||||
| 	log "github.com/Sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
|  | 	"golang.org/x/net/trace" | ||||||
| 	"zombiezen.com/go/capnproto2/rpc" | 	"zombiezen.com/go/capnproto2/rpc" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -24,3 +23,20 @@ func (c ConnLogger) Errorf(ctx context.Context, format string, args ...interface | ||||||
| func ConnLog(log *log.Entry) rpc.ConnOption { | func ConnLog(log *log.Entry) rpc.ConnOption { | ||||||
| 	return rpc.ConnLog(ConnLogger{log}) | 	return rpc.ConnLog(ConnLogger{log}) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // ConnTracer wraps a trace.EventLog for a connection.
 | ||||||
|  | type ConnTracer struct { | ||||||
|  | 	Events trace.EventLog | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c ConnTracer) Infof(ctx context.Context, format string, args ...interface{}) { | ||||||
|  | 	c.Events.Printf(format, args...) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (c ConnTracer) Errorf(ctx context.Context, format string, args ...interface{}) { | ||||||
|  | 	c.Events.Errorf(format, args...) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func ConnTrace(events trace.EventLog) rpc.ConnOption { | ||||||
|  | 	return rpc.ConnLog(ConnTracer{events}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -4,7 +4,7 @@ package tunnelrpc | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 
 | 
 | ||||||
| 	log "github.com/Sirupsen/logrus" | 	log "github.com/sirupsen/logrus" | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| 	"zombiezen.com/go/capnproto2/encoding/text" | 	"zombiezen.com/go/capnproto2/encoding/text" | ||||||
| 	"zombiezen.com/go/capnproto2/rpc" | 	"zombiezen.com/go/capnproto2/rpc" | ||||||
|  |  | ||||||
|  | @ -47,10 +47,11 @@ type RegistrationOptions struct { | ||||||
| 	Version              string | 	Version              string | ||||||
| 	OS                   string `capnp:"os"` | 	OS                   string `capnp:"os"` | ||||||
| 	ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy | 	ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy | ||||||
| 	PoolID               string `capnp:"poolId"` | 	PoolName             string `capnp:"poolName"` | ||||||
| 	Tags                 []Tag | 	Tags                 []Tag | ||||||
| 	ConnectionID         uint8  `capnp:"connectionId"` | 	ConnectionID         uint8  `capnp:"connectionId"` | ||||||
| 	OriginLocalIP        string `capnp:"originLocalIp"` | 	OriginLocalIP        string `capnp:"originLocalIp"` | ||||||
|  | 	IsAutoupdated        bool   `capnp:"isAutoupdated"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error { | func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error { | ||||||
|  |  | ||||||
|  | @ -28,13 +28,15 @@ struct RegistrationOptions { | ||||||
|     # What to do with existing tunnels for the given hostname. |     # What to do with existing tunnels for the given hostname. | ||||||
|     existingTunnelPolicy @3 :ExistingTunnelPolicy; |     existingTunnelPolicy @3 :ExistingTunnelPolicy; | ||||||
|     # If using the balancing policy, identifies the LB pool to use. |     # If using the balancing policy, identifies the LB pool to use. | ||||||
|     poolId @4 :Text; |     poolName @4 :Text; | ||||||
|     # Client-defined tags to associate with the tunnel |     # Client-defined tags to associate with the tunnel | ||||||
|     tags @5 :List(Tag); |     tags @5 :List(Tag); | ||||||
|     # A unique identifier for a high-availability connection made by a single client. |     # A unique identifier for a high-availability connection made by a single client. | ||||||
|     connectionId @6 :UInt8; |     connectionId @6 :UInt8; | ||||||
|     # origin LAN IP |     # origin LAN IP | ||||||
|     originLocalIp @7 :Text; |     originLocalIp @7 :Text; | ||||||
|  |     # whether Warp client has been autoupdated | ||||||
|  |     isAutoupdated @8 :Bool; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| struct Tag { | struct Tag { | ||||||
|  |  | ||||||
|  | @ -0,0 +1,77 @@ | ||||||
|  | package websocket | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"crypto/sha1" | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"errors" | ||||||
|  | 	"io" | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 
 | ||||||
|  | 	"github.com/gorilla/websocket" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
 | ||||||
|  | func IsWebSocketUpgrade(req *http.Request) bool { | ||||||
|  | 	return websocket.IsWebSocketUpgrade(req) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing.
 | ||||||
|  | func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.Conn, *http.Response, error) { | ||||||
|  | 	req.URL.Scheme = "wss" | ||||||
|  | 	d := &websocket.Dialer{TLSClientConfig: tlsClientConfig} | ||||||
|  | 	conn, response, err := d.Dial(req.URL.String(), nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, nil, err | ||||||
|  | 	} | ||||||
|  | 	response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req)) | ||||||
|  | 	return conn, response, err | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // HijackConnection takes over an HTTP connection. Caller is responsible for closing connection.
 | ||||||
|  | func HijackConnection(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) { | ||||||
|  | 	hj, ok := w.(http.Hijacker) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, nil, errors.New("hijack error") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	conn, brw, err := hj.Hijack() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, nil, err | ||||||
|  | 	} | ||||||
|  | 	return conn, brw, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Stream copies copy data to & from provided io.ReadWriters.
 | ||||||
|  | func Stream(conn, backendConn io.ReadWriter) { | ||||||
|  | 	proxyDone := make(chan struct{}, 2) | ||||||
|  | 
 | ||||||
|  | 	go func() { | ||||||
|  | 		io.Copy(conn, backendConn) | ||||||
|  | 		proxyDone <- struct{}{} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	go func() { | ||||||
|  | 		io.Copy(backendConn, conn) | ||||||
|  | 		proxyDone <- struct{}{} | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	// If one side is done, we are done.
 | ||||||
|  | 	<-proxyDone | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // sha1Base64 sha1 and then base64 encodes str.
 | ||||||
|  | func sha1Base64(str string) string { | ||||||
|  | 	hasher := sha1.New() | ||||||
|  | 	io.WriteString(hasher, str) | ||||||
|  | 	hash := hasher.Sum(nil) | ||||||
|  | 	return base64.StdEncoding.EncodeToString(hash) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // generateAcceptKey returns the string needed for the Sec-WebSocket-Accept header.
 | ||||||
|  | // https://tools.ietf.org/html/rfc6455#section-1.3 describes this process in more detail.
 | ||||||
|  | func generateAcceptKey(req *http.Request) string { | ||||||
|  | 	return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue