Release Warp Client 2018.2.1

This commit is contained in:
cloudflare-warp-bot 2018-02-20 21:13:56 +00:00
parent e0ae598112
commit 3780e14f41
25 changed files with 713 additions and 223 deletions

View File

@ -2,17 +2,20 @@ package main
import ( import (
"bytes" "bytes"
"crypto/tls"
"encoding/json"
"fmt" "fmt"
"html/template" "html/template"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"os" "os"
"time"
"github.com/pkg/errors" "github.com/gorilla/websocket"
"gopkg.in/urfave/cli.v2"
log "github.com/Sirupsen/logrus" "github.com/cloudflare/cloudflare-warp/tlsconfig"
cli "gopkg.in/urfave/cli.v2"
) )
type templateData struct { type templateData struct {
@ -21,6 +24,11 @@ type templateData struct {
Body string Body string
} }
type OriginUpTime struct {
StartTime time.Time `json:"startTime"`
UpTime string `json:"uptime"`
}
const defaultServerName = "the Cloudflare Warp test server" const defaultServerName = "the Cloudflare Warp test server"
const indexTemplate = ` const indexTemplate = `
<!DOCTYPE html> <!DOCTYPE html>
@ -85,54 +93,95 @@ const indexTemplate = `
func hello(c *cli.Context) error { func hello(c *cli.Context) error {
address := fmt.Sprintf(":%d", c.Int("port")) address := fmt.Sprintf(":%d", c.Int("port"))
server := NewHelloWorldServer() listener, err := createListener(address)
if hostname, err := os.Hostname(); err != nil { if err != nil {
server.serverName = hostname return err
} }
err := server.ListenAndServe(address) defer listener.Close()
return errors.Wrap(err, "Fail to start Hello World Server") err = startHelloWorldServer(listener, nil)
return err
} }
func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error { func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
server := NewHelloWorldServer() Log.Infof("Starting Hello World server at %s", listener.Addr())
if hostname, err := os.Hostname(); err != nil { serverName := defaultServerName
server.serverName = hostname if hostname, err := os.Hostname(); err == nil {
serverName = hostname
} }
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: server}
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
go func() { go func() {
<-shutdownC <-shutdownC
httpServer.Close() httpServer.Close()
}() }()
http.HandleFunc("/uptime", uptimeHandler(time.Now()))
http.HandleFunc("/ws", websocketHandler(upgrader))
http.HandleFunc("/", rootHandler(serverName))
err := httpServer.Serve(listener) err := httpServer.Serve(listener)
return err return err
} }
type HelloWorldServer struct { func createListener(address string) (net.Listener, error) {
responseTemplate *template.Template certificate, err := tlsconfig.GetHelloCertificate()
serverName string if err != nil {
return nil, err
} }
func NewHelloWorldServer() *HelloWorldServer { // If the port in address is empty, a port number is automatically chosen
return &HelloWorldServer{ listener, err := tls.Listen(
responseTemplate: template.Must(template.New("index").Parse(indexTemplate)), "tcp",
serverName: defaultServerName, address,
} &tls.Config{Certificates: []tls.Certificate{certificate}})
}
func findAvailablePort() (net.Listener, error) {
// If the port in address is empty, a port number is automatically chosen.
listener, err := net.Listen("tcp", "127.0.0.1:")
return listener, err return listener, err
} }
func (s *HelloWorldServer) ListenAndServe(address string) error { func uptimeHandler(startTime time.Time) http.HandlerFunc {
log.Infof("Starting Hello World server on %s", address) return func(w http.ResponseWriter, r *http.Request) {
err := http.ListenAndServe(address, s) // Note that if autoupdate is enabled, the uptime is reset when a new client
return err // release is available
resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()}
respJson, err := json.Marshal(resp)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
} else {
w.Header().Set("Content-Type", "application/json")
w.Write(respJson)
}
}
} }
func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc {
log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto) return func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
if err := conn.WriteMessage(mt, message); err != nil {
break
}
}
}
}
func rootHandler(serverName string) http.HandlerFunc {
responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
return func(w http.ResponseWriter, r *http.Request) {
Log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto)
var buffer bytes.Buffer var buffer bytes.Buffer
var body string var body string
rawBody, err := ioutil.ReadAll(r.Body) rawBody, err := ioutil.ReadAll(r.Body)
@ -141,8 +190,8 @@ func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} else { } else {
body = "" body = ""
} }
err = s.responseTemplate.Execute(&buffer, &templateData{ err = responseTemplate.Execute(&buffer, &templateData{
ServerName: s.serverName, ServerName: serverName,
Request: r, Request: r,
Body: body, Body: body,
}) })
@ -153,3 +202,4 @@ func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
buffer.WriteTo(w) buffer.WriteTo(w)
} }
} }
}

View File

@ -4,18 +4,30 @@ import (
"testing" "testing"
) )
const testPort = "8080" func TestCreateListenerHostAndPortSuccess(t *testing.T) {
listener, err := createListener("localhost:1234")
func TestNewHelloWorldServer(t *testing.T) {
if NewHelloWorldServer() == nil {
t.Fatal("NewHelloWorldServer returned nil")
}
}
func TestFindAvailablePort(t *testing.T) {
listener, err := findAvailablePort()
if err != nil { if err != nil {
t.Fatal("Fail to find available port") t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}
func TestCreateListenerOnlyHostSuccess(t *testing.T) {
listener, err := createListener("localhost:")
if err != nil {
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}
func TestCreateListenerOnlyPortSuccess(t *testing.T) {
listener, err := createListener(":8888")
if err != nil {
t.Fatal(err)
} }
if listener.Addr().String() == "" { if listener.Addr().String() == "" {
t.Fatal("Fail to find available port") t.Fatal("Fail to find available port")

View File

@ -42,6 +42,8 @@ After=network.target
TimeoutStartSec=0 TimeoutStartSec=0
Type=notify Type=notify
ExecStart={{ .Path }} --config /etc/cloudflare-warp/config.yml --origincert /etc/cloudflare-warp/cert.pem --no-autoupdate ExecStart={{ .Path }} --config /etc/cloudflare-warp/config.yml --origincert /etc/cloudflare-warp/cert.pem --no-autoupdate
Restart=on-failure
RestartSec=5s
[Install] [Install]
WantedBy=multi-user.target WantedBy=multi-user.target

View File

@ -14,7 +14,6 @@ import (
"syscall" "syscall"
"time" "time"
log "github.com/Sirupsen/logrus"
homedir "github.com/mitchellh/go-homedir" homedir "github.com/mitchellh/go-homedir"
cli "gopkg.in/urfave/cli.v2" cli "gopkg.in/urfave/cli.v2"
) )
@ -137,7 +136,7 @@ func download(certURL, filePath string) bool {
return true return true
} }
if err != nil { if err != nil {
log.WithError(err).Error("Error fetching certificate") Log.WithError(err).Error("Error fetching certificate")
return false return false
} }
} }
@ -180,16 +179,16 @@ func putSuccess(client *http.Client, certURL string) {
// indicate success to the relay server // indicate success to the relay server
req, err := http.NewRequest("PUT", certURL+"/ok", nil) req, err := http.NewRequest("PUT", certURL+"/ok", nil)
if err != nil { if err != nil {
log.WithError(err).Error("HTTP request error") Log.WithError(err).Error("HTTP request error")
return return
} }
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
log.WithError(err).Error("HTTP error") Log.WithError(err).Error("HTTP error")
return return
} }
resp.Body.Close() resp.Body.Close()
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
log.Errorf("Unexpected HTTP error code %d", resp.StatusCode) Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
} }
} }

View File

@ -11,6 +11,8 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"runtime"
"strings"
"sync" "sync"
"syscall" "syscall"
"time" "time"
@ -21,11 +23,12 @@ import (
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
"github.com/cloudflare/cloudflare-warp/validation" "github.com/cloudflare/cloudflare-warp/validation"
log "github.com/Sirupsen/logrus"
"github.com/facebookgo/grace/gracenet" "github.com/facebookgo/grace/gracenet"
raven "github.com/getsentry/raven-go" "github.com/getsentry/raven-go"
homedir "github.com/mitchellh/go-homedir" "github.com/mitchellh/go-homedir"
cli "gopkg.in/urfave/cli.v2" "github.com/rifflock/lfshook"
"github.com/sirupsen/logrus"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc" "gopkg.in/urfave/cli.v2/altsrc"
"github.com/coreos/go-systemd/daemon" "github.com/coreos/go-systemd/daemon"
@ -40,11 +43,21 @@ const configFile = "config.yml"
var listeners = gracenet.Net{} var listeners = gracenet.Net{}
var Version = "DEV" var Version = "DEV"
var BuildTime = "unknown" var BuildTime = "unknown"
var Log *logrus.Logger
// Shutdown channel used by the app. When closed, app must terminate. // Shutdown channel used by the app. When closed, app must terminate.
// May be closed by the Windows service runner. // May be closed by the Windows service runner.
var shutdownC chan struct{} var shutdownC chan struct{}
type BuildAndRuntimeInfo struct {
GoOS string `json:"go_os"`
GoVersion string `json:"go_version"`
GoArch string `json:"go_arch"`
WarpVersion string `json:"warp_version"`
WarpFlags map[string]interface{} `json:"warp_flags"`
WarpEnvs map[string]string `json:"warp_envs"`
}
func main() { func main() {
metrics.RegisterBuildInfo(BuildTime, Version) metrics.RegisterBuildInfo(BuildTime, Version)
raven.SetDSN(sentryDSN) raven.SetDSN(sentryDSN)
@ -84,6 +97,12 @@ WARNING:
Usage: "Disable periodic check for updates, restarting the server with the new version.", Usage: "Disable periodic check for updates, restarting the server with the new version.",
Value: false, Value: false,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "is-autoupdated",
Usage: "Signal the new process that Warp client has been autoupdated",
Value: false,
Hidden: true,
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "edge", Name: "edge",
Usage: "Address of the Cloudflare tunnel server.", Usage: "Address of the Cloudflare tunnel server.",
@ -99,12 +118,12 @@ WARNING:
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "origincert", Name: "origincert",
Usage: "Path to the certificate generated for your origin when you run cloudflare-warp login.", Usage: "Path to the certificate generated for your origin when you run cloudflare-warp login.",
EnvVars: []string{"ORIGIN_CERT"}, EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: filepath.Join(defaultConfigDir, credentialFile), Value: filepath.Join(defaultConfigDir, credentialFile),
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "url", Name: "url",
Value: "http://localhost:8080", Value: "https://localhost:8080",
Usage: "Connect to the local webserver at `URL`.", Usage: "Connect to the local webserver at `URL`.",
EnvVars: []string{"TUNNEL_URL"}, EnvVars: []string{"TUNNEL_URL"},
}), }),
@ -191,14 +210,20 @@ WARNING:
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "hello-world", Name: "hello-world",
Usage: "Run Hello World Server",
Value: false, Value: false,
Usage: "Run Hello World Server",
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "pidfile", Name: "pidfile",
Usage: "Write the application's PID to this file after first successful connection.", Usage: "Write the application's PID to this file after first successful connection.",
EnvVars: []string{"TUNNEL_PIDFILE"}, EnvVars: []string{"TUNNEL_PIDFILE"},
}), }),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "logfile",
Usage: "Save application log to this file for reporting issues.",
EnvVars: []string{"TUNNEL_LOGFILE"},
}),
altsrc.NewIntFlag(&cli.IntFlag{ altsrc.NewIntFlag(&cli.IntFlag{
Name: "ha-connections", Name: "ha-connections",
Value: 4, Value: 4,
@ -239,6 +264,7 @@ WARNING:
return nil return nil
} }
app.Before = func(context *cli.Context) error { app.Before = func(context *cli.Context) error {
Log = logrus.New()
inputSource, err := findInputSourceContext(context) inputSource, err := findInputSourceContext(context)
if err != nil { if err != nil {
return err return err
@ -248,7 +274,7 @@ WARNING:
return nil return nil
} }
app.Commands = []*cli.Command{ app.Commands = []*cli.Command{
&cli.Command{ {
Name: "update", Name: "update",
Action: update, Action: update,
Usage: "Update the agent if a new version exists", Usage: "Update the agent if a new version exists",
@ -259,7 +285,7 @@ WARNING:
To determine if an update happened in a script, check for error code 64.`, To determine if an update happened in a script, check for error code 64.`,
}, },
&cli.Command{ {
Name: "login", Name: "login",
Action: login, Action: login,
Usage: "Generate a configuration file with your login details", Usage: "Generate a configuration file with your login details",
@ -271,7 +297,7 @@ WARNING:
}, },
}, },
}, },
&cli.Command{ {
Name: "hello", Name: "hello",
Action: hello, Action: hello,
Usage: "Run a simple \"Hello World\" server for testing Cloudflare Warp.", Usage: "Run a simple \"Hello World\" server for testing Cloudflare Warp.",
@ -293,27 +319,43 @@ func startServer(c *cli.Context) {
errC := make(chan error) errC := make(chan error)
wg.Add(2) wg.Add(2)
if c.NumFlags() == 0 && c.NArg() == 0 { // If the user choose to supply all options through env variables,
// c.NumFlags() == 0 && c.NArg() == 0. For warp to work, the user needs to at
// least provide a hostname.
if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" {
cli.ShowAppHelp(c) cli.ShowAppHelp(c)
return return
} }
logLevel, err := logrus.ParseLevel(c.String("loglevel"))
logLevel, err := log.ParseLevel(c.String("loglevel"))
if err != nil { if err != nil {
log.WithError(err).Fatal("Unknown logging level specified") Log.WithError(err).Fatal("Unknown logging level specified")
} }
logrus.SetLevel(logLevel)
log.SetLevel(logLevel) protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
protoLogLevel, err := log.ParseLevel(c.String("proto-loglevel"))
if err != nil { if err != nil {
log.WithError(err).Fatal("Unknown protocol logging level specified") Log.WithError(err).Fatal("Unknown protocol logging level specified")
} }
protoLogger := log.New() protoLogger := logrus.New()
protoLogger.Level = protoLogLevel protoLogger.Level = protoLogLevel
if c.String("logfile") != "" {
if err := initLogFile(c, protoLogger); err != nil {
Log.Error(err)
}
}
if !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 {
if initUpdate() {
return
}
Log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
}
hostname, err := validation.ValidateHostname(c.String("hostname")) hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil { if err != nil {
log.WithError(err).Fatal("Invalid hostname") Log.WithError(err).Fatal("Invalid hostname")
} }
clientID := c.String("id") clientID := c.String("id")
@ -323,46 +365,44 @@ func startServer(c *cli.Context) {
tags, err := NewTagSliceFromCLI(c.StringSlice("tag")) tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil { if err != nil {
log.WithError(err).Fatal("Tag parse failure") Log.WithError(err).Fatal("Tag parse failure")
} }
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID}) tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
if c.IsSet("hello-world") { if c.IsSet("hello-world") {
wg.Add(1) wg.Add(1)
listener, err := findAvailablePort() listener, err := createListener("127.0.0.1:")
if err != nil { if err != nil {
listener.Close() listener.Close()
log.WithError(err).Fatal("Cannot start Hello World Server") Log.WithError(err).Fatal("Cannot start Hello World Server")
} }
go func() { go func() {
startHelloWorldServer(listener, shutdownC) startHelloWorldServer(listener, shutdownC)
wg.Done() wg.Done()
listener.Close() listener.Close()
}() }()
c.Set("url", "http://"+listener.Addr().String()) c.Set("url", "https://"+listener.Addr().String())
log.Infof("Starting Hello World Server at %s", c.String("url"))
} }
url, err := validateUrl(c) url, err := validateUrl(c)
if err != nil { if err != nil {
log.WithError(err).Fatal("Error validating url") Log.WithError(err).Fatal("Error validating url")
} }
log.Infof("Proxying tunnel requests to %s", url) Log.Infof("Proxying tunnel requests to %s", url)
// Fail if the user provided an old authentication method // Fail if the user provided an old authentication method
if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") { if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login") Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login")
} }
// Check that the user has acquired a certificate using the log in command // Check that the user has acquired a certificate using the log in command
originCertPath, err := homedir.Expand(c.String("origincert")) originCertPath, err := homedir.Expand(c.String("origincert"))
if err != nil { if err != nil {
log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert")) Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
} }
ok, err := fileExists(originCertPath) ok, err := fileExists(originCertPath)
if !ok { if !ok {
log.Fatalf(`Cannot find a valid certificate for your origin at the path: Log.Fatalf(`Cannot find a valid certificate for your origin at the path:
%s %s
@ -375,8 +415,9 @@ If you don't have a certificate signed by Cloudflare, run the command:
// Easier to send the certificate as []byte via RPC than decoding it at this point // Easier to send the certificate as []byte via RPC than decoding it at this point
originCert, err := ioutil.ReadFile(originCertPath) originCert, err := ioutil.ReadFile(originCertPath)
if err != nil { if err != nil {
log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath) Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
} }
tunnelMetrics := origin.NewTunnelMetrics() tunnelMetrics := origin.NewTunnelMetrics()
httpTransport := &http.Transport{ httpTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
@ -389,13 +430,15 @@ If you don't have a certificate signed by Cloudflare, run the command:
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"), IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"), TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
} }
tunnelConfig := &origin.TunnelConfig{ tunnelConfig := &origin.TunnelConfig{
EdgeAddrs: c.StringSlice("edge"), EdgeAddrs: c.StringSlice("edge"),
OriginUrl: url, OriginUrl: url,
Hostname: hostname, Hostname: hostname,
OriginCert: originCert, OriginCert: originCert,
TlsConfig: &tls.Config{}, TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
ClientTlsConfig: httpTransport.TLSClientConfig,
Retries: c.Uint("retries"), Retries: c.Uint("retries"),
HeartbeatInterval: c.Duration("heartbeat-interval"), HeartbeatInterval: c.Duration("heartbeat-interval"),
MaxHeartbeats: c.Uint64("heartbeat-count"), MaxHeartbeats: c.Uint64("heartbeat-count"),
@ -408,18 +451,11 @@ If you don't have a certificate signed by Cloudflare, run the command:
Metrics: tunnelMetrics, Metrics: tunnelMetrics,
MetricsUpdateFreq: c.Duration("metrics-update-freq"), MetricsUpdateFreq: c.Duration("metrics-update-freq"),
ProtocolLogger: protoLogger, ProtocolLogger: protoLogger,
Logger: Log,
IsAutoupdated: c.Bool("is-autoupdated"),
} }
connectedSignal := make(chan struct{}) connectedSignal := make(chan struct{})
tunnelConfig.TlsConfig = tlsconfig.CLIFlags{RootCA: "cacert"}.GetConfig(c)
if tunnelConfig.TlsConfig.RootCAs == nil {
tunnelConfig.TlsConfig.RootCAs = GetCloudflareRootCA()
tunnelConfig.TlsConfig.ServerName = "cftunnel.com"
} else if len(tunnelConfig.EdgeAddrs) > 0 {
// Set for development environments and for testing specific origintunneld instances
tunnelConfig.TlsConfig.ServerName, _, _ = net.SplitHostPort(tunnelConfig.EdgeAddrs[0])
}
go writePidFile(connectedSignal, c.String("pidfile")) go writePidFile(connectedSignal, c.String("pidfile"))
go func() { go func() {
errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal) errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
@ -428,24 +464,21 @@ If you don't have a certificate signed by Cloudflare, run the command:
metricsListener, err := listeners.Listen("tcp", c.String("metrics")) metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
if err != nil { if err != nil {
log.WithError(err).Fatal("Error opening metrics server listener") Log.WithError(err).Fatal("Error opening metrics server listener")
} }
go func() { go func() {
errC <- metrics.ServeMetrics(metricsListener, shutdownC) errC <- metrics.ServeMetrics(metricsListener, shutdownC)
wg.Done() wg.Done()
}() }()
if !c.Bool("no-autoupdate") { var errCode int
log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
go autoupdate(c.Duration("autoupdate-period"), shutdownC)
}
err = WaitForSignal(errC, shutdownC) err = WaitForSignal(errC, shutdownC)
if err != nil { if err != nil {
log.WithError(err).Error("Quitting due to error") Log.WithError(err).Error("Quitting due to error")
raven.CaptureErrorAndWait(err, nil) raven.CaptureErrorAndWait(err, nil)
errCode = 1
} else { } else {
log.Info("Quitting...") Log.Info("Quitting...")
} }
// Wait for clean exit, discarding all errors // Wait for clean exit, discarding all errors
go func() { go func() {
@ -453,6 +486,7 @@ If you don't have a certificate signed by Cloudflare, run the command:
} }
}() }()
wg.Wait() wg.Wait()
os.Exit(errCode)
} }
func WaitForSignal(errC chan error, shutdownC chan struct{}) error { func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
@ -477,30 +511,40 @@ func update(c *cli.Context) error {
return nil return nil
} }
func autoupdate(frequency time.Duration, shutdownC chan struct{}) { func initUpdate() bool {
if int64(frequency) == 0 { if updateApplied() {
return os.Args = append(os.Args, "--is-autoupdated=true")
if _, err := listeners.StartProcess(); err != nil {
Log.WithError(err).Error("Unable to restart server automatically")
return false
} }
return true
}
return false
}
func autoupdate(freq time.Duration, shutdownC chan struct{}) {
for { for {
if updateApplied() { if updateApplied() {
os.Args = append(os.Args, "--is-autoupdated=true")
if _, err := listeners.StartProcess(); err != nil { if _, err := listeners.StartProcess(); err != nil {
log.WithError(err).Error("Unable to restart server automatically") Log.WithError(err).Error("Unable to restart server automatically")
} }
close(shutdownC) close(shutdownC)
return return
} }
time.Sleep(frequency) time.Sleep(freq)
} }
} }
func updateApplied() bool { func updateApplied() bool {
releaseInfo := checkForUpdates() releaseInfo := checkForUpdates()
if releaseInfo.Updated { if releaseInfo.Updated {
log.Infof("Updated to version %s", releaseInfo.Version) Log.Infof("Updated to version %s", releaseInfo.Version)
return true return true
} }
if releaseInfo.Error != nil { if releaseInfo.Error != nil {
log.WithError(releaseInfo.Error).Error("Update check failed") Log.WithError(releaseInfo.Error).Error("Update check failed")
} }
return false return false
} }
@ -555,7 +599,7 @@ func writePidFile(waitForSignal chan struct{}, pidFile string) {
} }
file, err := os.Create(pidFile) file, err := os.Create(pidFile)
if err != nil { if err != nil {
log.WithError(err).Errorf("Unable to write pid to %s", pidFile) Log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
} }
defer file.Close() defer file.Close()
fmt.Fprintf(file, "%d", os.Getpid()) fmt.Fprintf(file, "%d", os.Getpid())
@ -573,3 +617,55 @@ func validateUrl(c *cli.Context) (string, error) {
validUrl, err := validation.ValidateUrl(url) validUrl, err := validation.ValidateUrl(url)
return validUrl, err return validUrl, err
} }
func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error {
fileMode := os.O_WRONLY|os.O_APPEND|os.O_CREATE|os.O_TRUNC
// do not truncate log file if the client has been autoupdated
if c.Bool("is-autoupdated") {
fileMode = os.O_WRONLY|os.O_APPEND|os.O_CREATE
}
f, err := os.OpenFile(c.String("logfile"), fileMode, 0664)
if err != nil {
errors.Wrap(err, fmt.Sprintf("Cannot open file %s", c.String("logfile")))
}
defer f.Close()
pathMap := lfshook.PathMap{
logrus.InfoLevel: c.String("logfile"),
logrus.ErrorLevel: c.String("logfile"),
logrus.FatalLevel: c.String("logfile"),
logrus.PanicLevel: c.String("logfile"),
}
Log.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
flags := make(map[string]interface{})
envs := make(map[string]string)
for _, flag := range c.LocalFlagNames() {
flags[flag] = c.Generic(flag)
}
// Find env variables for Warp
for _, env := range os.Environ() {
// All Warp env variables start with TUNNEL_
if strings.Contains(env, "TUNNEL_") {
vars := strings.Split(env, "=")
if len(vars) == 2 {
envs[vars[0]] = vars[1]
}
}
}
Log.Infof("Warp build and runtime configuration: %+v", BuildAndRuntimeInfo{
GoOS: runtime.GOOS,
GoVersion: runtime.Version(),
GoArch: runtime.GOARCH,
WarpVersion: Version,
WarpFlags: flags,
WarpEnvs: envs,
})
return nil
}

View File

@ -9,7 +9,7 @@ import (
"fmt" "fmt"
"os" "os"
log "github.com/Sirupsen/logrus" log "github.com/sirupsen/logrus"
cli "gopkg.in/urfave/cli.v2" cli "gopkg.in/urfave/cli.v2"
"golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc"

View File

@ -6,7 +6,7 @@ import (
"sync" "sync"
"time" "time"
log "github.com/Sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
) )

View File

@ -6,13 +6,14 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"math/rand" "math/rand"
"net"
"os" "os"
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
log "github.com/Sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -25,25 +26,20 @@ func TestMain(m *testing.M) {
type DefaultMuxerPair struct { type DefaultMuxerPair struct {
OriginMuxConfig MuxerConfig OriginMuxConfig MuxerConfig
OriginMux *Muxer OriginMux *Muxer
OriginWriter *io.PipeWriter OriginConn net.Conn
OriginReader *io.PipeReader
EdgeMuxConfig MuxerConfig EdgeMuxConfig MuxerConfig
EdgeMux *Muxer EdgeMux *Muxer
EdgeWriter *io.PipeWriter EdgeConn net.Conn
EdgeReader *io.PipeReader
doneC chan struct{} doneC chan struct{}
} }
func NewDefaultMuxerPair() *DefaultMuxerPair { func NewDefaultMuxerPair() *DefaultMuxerPair {
originReader, edgeWriter := io.Pipe() origin, edge := net.Pipe()
edgeReader, originWriter := io.Pipe()
return &DefaultMuxerPair{ return &DefaultMuxerPair{
OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"}, OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"},
OriginWriter: originWriter, OriginConn: origin,
OriginReader: originReader,
EdgeMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"}, EdgeMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"},
EdgeWriter: edgeWriter, EdgeConn: edge,
EdgeReader: edgeReader,
doneC: make(chan struct{}), doneC: make(chan struct{}),
} }
} }
@ -53,12 +49,12 @@ func (p *DefaultMuxerPair) Handshake(t *testing.T) {
originErrC := make(chan error) originErrC := make(chan error)
go func() { go func() {
var err error var err error
p.EdgeMux, err = Handshake(p.EdgeWriter, p.EdgeReader, p.EdgeMuxConfig) p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig)
edgeErrC <- err edgeErrC <- err
}() }()
go func() { go func() {
var err error var err error
p.OriginMux, err = Handshake(p.OriginWriter, p.OriginReader, p.OriginMuxConfig) p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig)
originErrC <- err originErrC <- err
}() }()
@ -120,8 +116,8 @@ func (p *DefaultMuxerPair) Wait(t *testing.T) {
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
muxPair := NewDefaultMuxerPair() muxPair := NewDefaultMuxerPair()
muxPair.Handshake(t) muxPair.Handshake(t)
AssertIfPipeReadable(t, muxPair.OriginReader) AssertIfPipeReadable(t, muxPair.OriginConn)
AssertIfPipeReadable(t, muxPair.EdgeReader) AssertIfPipeReadable(t, muxPair.EdgeConn)
} }
func TestSingleStream(t *testing.T) { func TestSingleStream(t *testing.T) {
@ -145,7 +141,7 @@ func TestSingleStream(t *testing.T) {
stream.Write(buf) stream.Write(buf)
// after this receive, the edge closed the stream // after this receive, the edge closed the stream
<-closeC <-closeC
n, err := stream.Read(buf) n, err := io.ReadFull(stream, buf)
if n > 0 { if n > 0 {
t.Fatalf("read %d bytes after EOF", n) t.Fatalf("read %d bytes after EOF", n)
} }
@ -173,7 +169,7 @@ func TestSingleStream(t *testing.T) {
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value) t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
} }
responseBody := make([]byte, 11) responseBody := make([]byte, 11)
n, err := stream.Read(responseBody) n, err := io.ReadFull(stream, responseBody)
if err != nil { if err != nil {
t.Fatalf("error from (*MuxedStream).Read: %s", err) t.Fatalf("error from (*MuxedStream).Read: %s", err)
} }
@ -243,7 +239,7 @@ func TestSingleStreamLargeResponseBody(t *testing.T) {
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value) t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
} }
responseBody := make([]byte, bodySize) responseBody := make([]byte, bodySize)
n, err := stream.Read(responseBody) n, err := io.ReadFull(stream, responseBody)
if err != nil { if err != nil {
t.Fatalf("error from (*MuxedStream).Read: %s", err) t.Fatalf("error from (*MuxedStream).Read: %s", err)
} }
@ -302,7 +298,7 @@ func TestMultipleStreams(t *testing.T) {
return return
} }
responseBody := make([]byte, 2) responseBody := make([]byte, 2)
n, err := stream.Read(responseBody) n, err := io.ReadFull(stream, responseBody)
if err != nil { if err != nil {
errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err) errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err)
return return
@ -392,7 +388,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) {
} }
responseBody := make([]byte, responseSizes[(stream.streamID-2)/2]) responseBody := make([]byte, responseSizes[(stream.streamID-2)/2])
n, err := stream.Read(responseBody) n, err := io.ReadFull(stream, responseBody)
if err != nil { if err != nil {
errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err) errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err)
return return
@ -451,7 +447,7 @@ func TestGracefulShutdown(t *testing.T) {
} }
responseBody := make([]byte, len(responseBuf)) responseBody := make([]byte, len(responseBuf))
log.Debugf("Waiting for %d bytes", len(responseBuf)) log.Debugf("Waiting for %d bytes", len(responseBuf))
n, err := stream.Read(responseBody) n, err := io.ReadFull(stream, responseBody)
if err != nil { if err != nil {
t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err) t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err)
} }
@ -498,13 +494,13 @@ func TestUnexpectedShutdown(t *testing.T) {
nil, nil,
) )
// Close the underlying connection before telling the origin to write. // Close the underlying connection before telling the origin to write.
muxPair.EdgeReader.Close() muxPair.EdgeConn.Close()
close(sendC) close(sendC)
if err != nil { if err != nil {
t.Fatalf("error in OpenStream: %s", err) t.Fatalf("error in OpenStream: %s", err)
} }
responseBody := make([]byte, len(responseBuf)) responseBody := make([]byte, len(responseBuf))
n, err := stream.Read(responseBody) n, err := io.ReadFull(stream, responseBody)
if err != io.EOF { if err != io.EOF {
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err) t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
} }
@ -545,14 +541,14 @@ func TestOpenAfterDisconnect(t *testing.T) {
switch i { switch i {
case 0: case 0:
// Close both directions of the connection to cause EOF on both peers. // Close both directions of the connection to cause EOF on both peers.
muxPair.OriginReader.Close() muxPair.OriginConn.Close()
muxPair.OriginWriter.Close() muxPair.EdgeConn.Close()
case 1: case 1:
// Close origin reader (edge writer) to cause EOF on origin only. // Close origin conn to cause EOF on origin first.
muxPair.OriginReader.Close() muxPair.OriginConn.Close()
case 2: case 2:
// Close origin writer (edge reader) to cause EOF on edge only. // Close edge conn to cause EOF on edge first.
muxPair.OriginWriter.Close() muxPair.EdgeConn.Close()
} }
_, err := muxPair.EdgeMux.OpenStream( _, err := muxPair.EdgeMux.OpenStream(
@ -623,7 +619,7 @@ func TestHPACK(t *testing.T) {
} }
} }
func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) { func AssertIfPipeReadable(t *testing.T, pipe io.ReadCloser) {
errC := make(chan error) errC := make(chan error)
go func() { go func() {
b := []byte{0} b := []byte{0}
@ -640,7 +636,5 @@ func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) {
} }
case <-time.After(100 * time.Millisecond): case <-time.After(100 * time.Millisecond):
// nothing to read // nothing to read
pipe.Close()
<-errC
} }
} }

View File

@ -1,6 +1,7 @@
package h2mux package h2mux
import ( import (
"io"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -63,3 +64,27 @@ func TestFlowControlSingleStream(t *testing.T) {
assert.Equal(t, testWindowSize<<2, stream.receiveWindow) assert.Equal(t, testWindowSize<<2, stream.receiveWindow)
assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax) assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax)
} }
func TestMuxedStreamEOF(t *testing.T) {
for i := 0; i < 4096; i++ {
readyList := NewReadyList()
stream := &MuxedStream{
streamID: 1,
readBuffer: NewSharedBuffer(),
receiveWindow: 65536,
receiveWindowMax: 65536,
sendWindow: 65536,
readyList: readyList,
}
go func() { stream.Close() }()
n, err := stream.Read([]byte{0})
assert.Equal(t, io.EOF, err)
assert.Equal(t, 0, n)
// Write comes after read, because write buffers data before it is flushed. It wouldn't know about EOF
// until some time later. Calling read first forces it to know about EOF now.
n, err = stream.Write([]byte{1})
assert.Equal(t, io.EOF, err)
assert.Equal(t, 0, n)
}
}

View File

@ -6,7 +6,7 @@ import (
"sync" "sync"
"time" "time"
log "github.com/Sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )

View File

@ -6,7 +6,7 @@ import (
"io" "io"
"time" "time"
log "github.com/Sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
) )

View File

@ -21,7 +21,7 @@ func NewSharedBuffer() *SharedBuffer {
func (s *SharedBuffer) Read(p []byte) (n int, err error) { func (s *SharedBuffer) Read(p []byte) (n int, err error) {
totalRead := 0 totalRead := 0
s.cond.L.Lock() s.cond.L.Lock()
for totalRead < len(p) { for totalRead == 0 {
n, err = s.buffer.Read(p[totalRead:]) n, err = s.buffer.Read(p[totalRead:])
totalRead += n totalRead += n
if err == io.EOF { if err == io.EOF {
@ -29,6 +29,9 @@ func (s *SharedBuffer) Read(p []byte) (n int, err error) {
break break
} }
err = nil err = nil
if n > 0 {
break
}
s.cond.Wait() s.cond.Wait()
} }
} }

View File

@ -6,6 +6,8 @@ import (
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) { func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) {
@ -29,30 +31,35 @@ func TestSharedBuffer(t *testing.T) {
func TestSharedBufferBlockingRead(t *testing.T) { func TestSharedBufferBlockingRead(t *testing.T) {
b := NewSharedBuffer() b := NewSharedBuffer()
testData := []byte("Hello world") testData1 := []byte("Hello")
testData2 := []byte(" world")
result := make(chan []byte) result := make(chan []byte)
go func() { go func() {
bytesRead := make([]byte, len(testData)) bytesRead := make([]byte, len(testData1)+len(testData2))
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead)) nRead, err := b.Read(bytesRead)
result <- bytesRead AssertIOReturnIsGood(t, len(testData1))(nRead, err)
result <- bytesRead[:nRead]
nRead, err = b.Read(bytesRead)
AssertIOReturnIsGood(t, len(testData2))(nRead, err)
result <- bytesRead[:nRead]
}() }()
time.Sleep(time.Millisecond * 250)
select { select {
case <-result: case <-result:
t.Fatalf("read returned early") t.Fatalf("read returned early")
default: default:
} }
AssertIOReturnIsGood(t, 5)(b.Write(testData[:5])) AssertIOReturnIsGood(t, len(testData1))(b.Write([]byte(testData1)))
select {
case <-result:
t.Fatalf("read returned early")
default:
}
AssertIOReturnIsGood(t, len(testData)-5)(b.Write(testData[5:]))
select { select {
case r := <-result: case r := <-result:
if string(r) != string(testData) { assert.Equal(t, testData1, r)
t.Fatalf("expected read to return %s, got %s", testData, r) case <-time.After(time.Second):
t.Fatalf("read timed out")
} }
AssertIOReturnIsGood(t, len(testData2))(b.Write([]byte(testData2)))
select {
case r := <-result:
assert.Equal(t, testData2, r)
case <-time.After(time.Second): case <-time.After(time.Second):
t.Fatalf("read timed out") t.Fatalf("read timed out")
} }
@ -85,7 +92,7 @@ func TestSharedBufferConcurrentReadWrite(t *testing.T) {
// Change block sizes in opposition to the write thread, to test blocking for new data. // Change block sizes in opposition to the write thread, to test blocking for new data.
for blockSize := 256; blockSize > 0; blockSize-- { for blockSize := 256; blockSize > 0; blockSize-- {
for i := 0; i < 256; i++ { for i := 0; i < 256; i++ {
n, err := b.Read(block[:blockSize]) n, err := io.ReadFull(b, block[:blockSize])
if n != blockSize || err != nil { if n != blockSize || err != nil {
t.Fatalf("read error: %d %s", n, err) t.Fatalf("read error: %d %s", n, err)
} }

View File

@ -11,9 +11,9 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/trace" "golang.org/x/net/trace"
log "github.com/Sirupsen/logrus"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
) )
const ( const (

View File

@ -5,7 +5,6 @@ import (
"github.com/cloudflare/cloudflare-warp/h2mux" "github.com/cloudflare/cloudflare-warp/h2mux"
log "github.com/Sirupsen/logrus"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
@ -249,7 +248,7 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
if _, ok := t.concurrentRequests[connectionID]; ok { if _, ok := t.concurrentRequests[connectionID]; ok {
t.concurrentRequests[connectionID] -= 1 t.concurrentRequests[connectionID] -= 1
} else { } else {
log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.") Log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.")
} }
t.concurrentRequestsLock.Unlock() t.concurrentRequestsLock.Unlock()

View File

@ -5,7 +5,6 @@ import (
"net" "net"
"time" "time"
log "github.com/Sirupsen/logrus"
"golang.org/x/net/context" "golang.org/x/net/context"
) )
@ -73,7 +72,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
case tunnelError := <-s.tunnelErrors: case tunnelError := <-s.tunnelErrors:
tunnelsActive-- tunnelsActive--
if tunnelError.err != nil { if tunnelError.err != nil {
log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error") Log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index) tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
s.waitForNextTunnel(tunnelError.index) s.waitForNextTunnel(tunnelError.index)
if backoffTimer == nil { if backoffTimer == nil {
@ -107,10 +106,10 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
s.lastResolve = time.Now() s.lastResolve = time.Now()
s.resolverC = nil s.resolverC = nil
if result.err == nil { if result.err == nil {
log.Debug("Service discovery refresh complete") Log.Debug("Service discovery refresh complete")
s.edgeIPs = result.edgeIPs s.edgeIPs = result.edgeIPs
} else { } else {
log.WithError(result.err).Error("Service discovery error") Log.WithError(result.err).Error("Service discovery error")
} }
} }
} }
@ -120,12 +119,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error { func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs) edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
if err != nil { if err != nil {
log.Infof("ResolveEdgeIPs err") Log.Infof("ResolveEdgeIPs err")
return err return err
} }
s.edgeIPs = edgeIPs s.edgeIPs = edgeIPs
if s.config.HAConnections > len(edgeIPs) { if s.config.HAConnections > len(edgeIPs) {
log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs)) Log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
s.config.HAConnections = len(edgeIPs) s.config.HAConnections = len(edgeIPs)
} }
s.lastResolve = time.Now() s.lastResolve = time.Now()

View File

@ -19,14 +19,17 @@ import (
"github.com/cloudflare/cloudflare-warp/tunnelrpc" "github.com/cloudflare/cloudflare-warp/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
"github.com/cloudflare/cloudflare-warp/validation" "github.com/cloudflare/cloudflare-warp/validation"
"github.com/cloudflare/cloudflare-warp/websocket"
log "github.com/Sirupsen/logrus"
raven "github.com/getsentry/raven-go" raven "github.com/getsentry/raven-go"
"github.com/pkg/errors" "github.com/pkg/errors"
_ "github.com/prometheus/client_golang/prometheus" _ "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
rpc "zombiezen.com/go/capnproto2/rpc" rpc "zombiezen.com/go/capnproto2/rpc"
) )
var Log *logrus.Logger
const ( const (
dialTimeout = 15 * time.Second dialTimeout = 15 * time.Second
@ -40,6 +43,7 @@ type TunnelConfig struct {
Hostname string Hostname string
OriginCert []byte OriginCert []byte
TlsConfig *tls.Config TlsConfig *tls.Config
ClientTlsConfig *tls.Config
Retries uint Retries uint
HeartbeatInterval time.Duration HeartbeatInterval time.Duration
MaxHeartbeats uint64 MaxHeartbeats uint64
@ -51,7 +55,9 @@ type TunnelConfig struct {
HTTPTransport http.RoundTripper HTTPTransport http.RoundTripper
Metrics *TunnelMetrics Metrics *TunnelMetrics
MetricsUpdateFreq time.Duration MetricsUpdateFreq time.Duration
ProtocolLogger *log.Logger ProtocolLogger *logrus.Logger
Logger *logrus.Logger
IsAutoupdated bool
} }
type dialError struct { type dialError struct {
@ -87,14 +93,16 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
Version: c.ReportedVersion, Version: c.ReportedVersion,
OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH), OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH),
ExistingTunnelPolicy: policy, ExistingTunnelPolicy: policy,
PoolID: c.LBPool, PoolName: c.LBPool,
Tags: c.Tags, Tags: c.Tags,
ConnectionID: connectionID, ConnectionID: connectionID,
OriginLocalIP: OriginLocalIP, OriginLocalIP: OriginLocalIP,
IsAutoupdated: c.IsAutoupdated,
} }
} }
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error { func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
Log = config.Logger
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go func() { go func() {
<-shutdownC <-shutdownC
@ -129,7 +137,7 @@ func ServeTunnelLoop(ctx context.Context, config *TunnelConfig, addr *net.TCPAdd
err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff) err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff)
if recoverable { if recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok { if duration, ok := backoff.GetBackoffDuration(ctx); ok {
log.Infof("Retrying in %s seconds", duration) Log.Infof("Retrying in %s seconds", duration)
backoff.Backoff(ctx) backoff.Backoff(ctx)
continue continue
} }
@ -162,11 +170,10 @@ func ServeTunnel(
// Returns error from parsing the origin URL or handshake errors // Returns error from parsing the origin URL or handshake errors
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID) handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
if err != nil { if err != nil {
errLog := log.WithError(err) errLog := Log.WithError(err)
switch err.(type) { switch err.(type) {
case dialError: case dialError:
errLog.Error("Unable to dial edge") errLog.Error("Unable to dial edge")
return err, false
case h2mux.MuxerHandshakeError: case h2mux.MuxerHandshakeError:
errLog.Error("Handshake failed with edge server") errLog.Error("Handshake failed with edge server")
default: default:
@ -207,24 +214,21 @@ func ServeTunnel(
registerErr := <-registerErrC registerErr := <-registerErrC
wg.Wait() wg.Wait()
if err != nil { if err != nil {
log.WithError(err).Error("Tunnel error") Log.WithError(err).Error("Tunnel error")
return err, true return err, true
} }
if registerErr != nil { if registerErr != nil {
// Don't retry on errors like entitlement failure or version too old // Don't retry on errors like entitlement failure or version too old
if e, ok := registerErr.(printableRegisterTunnelError); ok { if e, ok := registerErr.(printableRegisterTunnelError); ok {
log.Error(e) Log.Error(e)
if e.permanent { return e.cause, !e.permanent
return e, false
}
return e.cause, true
} else if e, ok := registerErr.(dupConnRegisterTunnelError); ok { } else if e, ok := registerErr.(dupConnRegisterTunnelError); ok {
log.Info("Already connected to this server, selecting a different one") Log.Info("Already connected to this server, selecting a different one")
return e, true return e, true
} }
// Only log errors to Sentry that may have been caused by the client side, to reduce dupes // Only log errors to Sentry that may have been caused by the client side, to reduce dupes
raven.CaptureError(registerErr, nil) raven.CaptureError(registerErr, nil)
log.Error("Cannot register") Log.Error("Cannot register")
return err, true return err, true
} }
return nil, false return nil, false
@ -241,7 +245,7 @@ func IsRPCStreamResponse(headers []h2mux.Header) bool {
} }
func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error { func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error {
logger := log.WithField("subsystem", "rpc") logger := Log.WithField("subsystem", "rpc")
logger.Debug("initiating RPC stream") logger.Debug("initiating RPC stream")
stream, err := muxer.OpenStream([]h2mux.Header{ stream, err := muxer.OpenStream([]h2mux.Header{
{Name: ":method", Value: "RPC"}, {Name: ":method", Value: "RPC"},
@ -292,11 +296,11 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
} }
} }
log.Infof("Registered at %s", registration.Url) Log.Infof("Registered at %s", registration.Url)
return nil return nil
} }
func LogServerInfo(logger *log.Entry, func LogServerInfo(logger *logrus.Entry,
promise tunnelrpc.ServerInfo_Promise, promise tunnelrpc.ServerInfo_Promise,
connectionID uint8, connectionID uint8,
metrics *TunnelMetrics, metrics *TunnelMetrics,
@ -311,7 +315,7 @@ func LogServerInfo(logger *log.Entry,
logger.WithError(err).Warn("Failed to retrieve server information") logger.WithError(err).Warn("Failed to retrieve server information")
return return
} }
log.Infof("Connected to %s", serverInfo.LocationName) Log.Infof("Connected to %s", serverInfo.LocationName)
metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName) metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName)
} }
@ -356,6 +360,7 @@ type TunnelHandler struct {
originUrl string originUrl string
muxer *h2mux.Muxer muxer *h2mux.Muxer
httpClient http.RoundTripper httpClient http.RoundTripper
tlsConfig *tls.Config
tags []tunnelpogs.Tag tags []tunnelpogs.Tag
metrics *TunnelMetrics metrics *TunnelMetrics
// connectionID is only used by metrics, and prometheus requires labels to be string // connectionID is only used by metrics, and prometheus requires labels to be string
@ -373,6 +378,7 @@ func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, co
h := &TunnelHandler{ h := &TunnelHandler{
originUrl: url, originUrl: url,
httpClient: config.HTTPTransport, httpClient: config.HTTPTransport,
tlsConfig: config.ClientTlsConfig,
tags: config.Tags, tags: config.Tags,
metrics: config.Metrics, metrics: config.Metrics,
connectionID: uint8ToString(connectionID), connectionID: uint8ToString(connectionID),
@ -422,29 +428,45 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
h.metrics.incrementRequests(h.connectionID) h.metrics.incrementRequests(h.connectionID)
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil { if err != nil {
log.WithError(err).Panic("Unexpected error from http.NewRequest") Log.WithError(err).Panic("Unexpected error from http.NewRequest")
} }
err = H2RequestHeadersToH1Request(stream.Headers, req) err = H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil { if err != nil {
log.WithError(err).Error("invalid request received") Log.WithError(err).Error("invalid request received")
} }
h.AppendTagHeaders(req) h.AppendTagHeaders(req)
if websocket.IsWebSocketUpgrade(req) {
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
if err != nil {
h.logError(stream, err)
} else {
stream.WriteHeaders(H1ResponseToH2Response(response))
defer conn.Close()
websocket.Stream(conn.UnderlyingConn(), stream)
}
} else {
response, err := h.httpClient.RoundTrip(req) response, err := h.httpClient.RoundTrip(req)
if err != nil { if err != nil {
log.WithError(err).Error("HTTP request error") h.logError(stream, err)
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
stream.Write([]byte("502 Bad Gateway"))
h.metrics.incrementResponses(h.connectionID, "502")
} else { } else {
defer response.Body.Close() defer response.Body.Close()
stream.WriteHeaders(H1ResponseToH2Response(response)) stream.WriteHeaders(H1ResponseToH2Response(response))
io.Copy(stream, response.Body) io.Copy(stream, response.Body)
h.metrics.incrementResponses(h.connectionID, "200") h.metrics.incrementResponses(h.connectionID, "200")
} }
}
h.metrics.decrementConcurrentRequests(h.connectionID) h.metrics.decrementConcurrentRequests(h.connectionID)
return nil return nil
} }
func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
Log.WithError(err).Error("HTTP request error")
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
stream.Write([]byte("502 Bad Gateway"))
h.metrics.incrementResponses(h.connectionID, "502")
}
func (h *TunnelHandler) UpdateMetrics() { func (h *TunnelHandler) UpdateMetrics() {
flowCtlMetrics := h.muxer.FlowControlMetrics() flowCtlMetrics := h.muxer.FlowControlMetrics()
h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics) h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics)

View File

@ -0,0 +1,95 @@
package tlsconfig
import (
"crypto/x509"
)
// TODO: remove the Origin CA root certs when migrated to Authenticated Origin Pull certs
const cloudflareRootCA = `
Issuer: C=US, ST=California, L=San Francisco, O=CloudFlare, Inc., OU=CloudFlare Origin SSL ECC Certificate Authority
-----BEGIN CERTIFICATE-----
MIICiDCCAi6gAwIBAgIUXZP3MWb8MKwBE1Qbawsp1sfA/Y4wCgYIKoZIzj0EAwIw
gY8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1T
YW4gRnJhbmNpc2NvMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTgwNgYDVQQL
Ey9DbG91ZEZsYXJlIE9yaWdpbiBTU0wgRUNDIENlcnRpZmljYXRlIEF1dGhvcml0
eTAeFw0xNjAyMjIxODI0MDBaFw0yMTAyMjIwMDI0MDBaMIGPMQswCQYDVQQGEwJV
UzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEZ
MBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE4MDYGA1UECxMvQ2xvdWRGbGFyZSBP
cmlnaW4gU1NMIEVDQyBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkwWTATBgcqhkjOPQIB
BggqhkjOPQMBBwNCAASR+sGALuaGshnUbcxKry+0LEXZ4NY6JUAtSeA6g87K3jaA
xpIg9G50PokpfWkhbarLfpcZu0UAoYy2su0EhN7wo2YwZDAOBgNVHQ8BAf8EBAMC
AQYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUhTBdOypw1O3VkmcH/es5
tBoOOKcwHwYDVR0jBBgwFoAUhTBdOypw1O3VkmcH/es5tBoOOKcwCgYIKoZIzj0E
AwIDSAAwRQIgEiIEHQr5UKma50D1WRMJBUSgjg24U8n8E2mfw/8UPz0CIQCr5V/e
mcifak4CQsr+DH4pn5SJD7JxtCG3YGswW8QZsw==
-----END CERTIFICATE-----
Issuer: C=US, O=CloudFlare, Inc., OU=CloudFlare Origin SSL Certificate Authority, L=San Francisco, ST=California
-----BEGIN CERTIFICATE-----
MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG
bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN
U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4
NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv
dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl
cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG
A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem
ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ
p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz
/9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl
yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK
xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud
EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G
A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA
cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD
gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh
QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz
ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y
4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX
Fu6q54beR89jDc+oABmOgg==
-----END CERTIFICATE-----
Issuer: C=US, O=CloudFlare, Inc., OU=Origin Pull, L=San Francisco, ST=California, CN=origin-pull.cloudflare.net
-----BEGIN CERTIFICATE-----
MIIGBjCCA/CgAwIBAgIIV5G6lVbCLmEwCwYJKoZIhvcNAQENMIGQMQswCQYDVQQG
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjEUMBIGA1UECxMLT3JpZ2lu
IFB1bGwxFjAUBgNVBAcTDVNhbiBGcmFuY2lzY28xEzARBgNVBAgTCkNhbGlmb3Ju
aWExIzAhBgNVBAMTGm9yaWdpbi1wdWxsLmNsb3VkZmxhcmUubmV0MB4XDTE1MDEx
MzAyNDc1M1oXDTIwMDExMjAyNTI1M1owgZAxCzAJBgNVBAYTAlVTMRkwFwYDVQQK
ExBDbG91ZEZsYXJlLCBJbmMuMRQwEgYDVQQLEwtPcmlnaW4gUHVsbDEWMBQGA1UE
BxMNU2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTEjMCEGA1UEAxMa
b3JpZ2luLXB1bGwuY2xvdWRmbGFyZS5uZXQwggIiMA0GCSqGSIb3DQEBAQUAA4IC
DwAwggIKAoICAQDdsts6I2H5dGyn4adACQRXlfo0KmwsN7B5rxD8C5qgy6spyONr
WV0ecvdeGQfWa8Gy/yuTuOnsXfy7oyZ1dm93c3Mea7YkM7KNMc5Y6m520E9tHooc
f1qxeDpGSsnWc7HWibFgD7qZQx+T+yfNqt63vPI0HYBOYao6hWd3JQhu5caAcIS2
ms5tzSSZVH83ZPe6Lkb5xRgLl3eXEFcfI2DjnlOtLFqpjHuEB3Tr6agfdWyaGEEi
lRY1IB3k6TfLTaSiX2/SyJ96bp92wvTSjR7USjDV9ypf7AD6u6vwJZ3bwNisNw5L
ptph0FBnc1R6nDoHmvQRoyytoe0rl/d801i9Nru/fXa+l5K2nf1koR3IX440Z2i9
+Z4iVA69NmCbT4MVjm7K3zlOtwfI7i1KYVv+ATo4ycgBuZfY9f/2lBhIv7BHuZal
b9D+/EK8aMUfjDF4icEGm+RQfExv2nOpkR4BfQppF/dLmkYfjgtO1403X0ihkT6T
PYQdmYS6Jf53/KpqC3aA+R7zg2birtvprinlR14MNvwOsDOzsK4p8WYsgZOR4Qr2
gAx+z2aVOs/87+TVOR0r14irQsxbg7uP2X4t+EXx13glHxwG+CnzUVycDLMVGvuG
aUgF9hukZxlOZnrl6VOf1fg0Caf3uvV8smOkVw6DMsGhBZSJVwao0UQNqQIDAQAB
o2YwZDAOBgNVHQ8BAf8EBAMCAAYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4E
FgQUQ1lLK2mLgOERM2pXzVc42p59xeswHwYDVR0jBBgwFoAUQ1lLK2mLgOERM2pX
zVc42p59xeswCwYJKoZIhvcNAQENA4ICAQDKDQM1qPRVP/4Gltz0D6OU6xezFBKr
LWtDoA1qW2F7pkiYawCP9MrDPDJsHy7dx+xw3bBZxOsK5PA/T7p1dqpEl6i8F692
g//EuYOifLYw3ySPe3LRNhvPl/1f6Sn862VhPvLa8aQAAwR9e/CZvlY3fj+6G5ik
3it7fikmKUsVnugNOkjmwI3hZqXfJNc7AtHDFw0mEOV0dSeAPTo95N9cxBbm9PKv
qAEmTEXp2trQ/RjJ/AomJyfA1BQjsD0j++DI3a9/BbDwWmr1lJciKxiNKaa0BRLB
dKMrYQD+PkPNCgEuojT+paLKRrMyFUzHSG1doYm46NE9/WARTh3sFUp1B7HZSBqA
kHleoB/vQ/mDuW9C3/8Jk2uRUdZxR+LoNZItuOjU8oTy6zpN1+GgSj7bHjiy9rfA
F+ehdrz+IOh80WIiqs763PGoaYUyzxLvVowLWNoxVVoc9G+PqFKqD988XlipHVB6
Bz+1CD4D/bWrs3cC9+kk/jFmrrAymZlkFX8tDb5aXASSLJjUjcptci9SKqtI2h0J
wUGkD7+bQAr+7vr8/R+CBmNMe7csE8NeEX6lVMF7Dh0a1YKQa6hUN18bBuYgTMuT
QzMmZpRpIBB321ZBlcnlxiTJvWxvbCPHKHj20VwwAz7LONF59s84ZsOqfoBv8gKM
s0s5dsq5zpLeaw==
-----END CERTIFICATE-----`
func GetCloudflareRootCA() *x509.CertPool {
ca := x509.NewCertPool()
if !ca.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
// should never happen
panic("failure loading Cloudflare origin CA pem")
}
return ca
}

50
tlsconfig/hello_ca.go Normal file
View File

@ -0,0 +1,50 @@
package tlsconfig
import (
"crypto/tls"
"crypto/x509"
)
const (
helloKey = `
-----BEGIN EC PARAMETERS-----
BgUrgQQAIg==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDAdyQBXfxTDCQSOT0HugmH9pVBtIw8t5dYvm6HxGlNq6P57v5GeN02Z
dH9FRl7+VSWgBwYFK4EEACKhZANiAATqpFzTxxV7D+/oqhKCTR6BEM9elTfKaRQE
FsLufcmaTMw/9tTwgpHKao/QsLKDTNbQhbSQLkcmpCQKlSGhl+pCrqNt/oYUAhav
UIwpwGiLCqGH/R2AqWLKRPOa/Rufs/U=
-----END EC PRIVATE KEY-----`
helloCRT = `
-----BEGIN CERTIFICATE-----
MIICkDCCAhigAwIBAgIJAPtKfUjc2lwGMAkGByqGSM49BAEwgYoxCzAJBgNVBAYT
AlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQKDBBD
bG91ZGZsYXJlLCBJbmMuMT8wPQYDVQQDDDZDbG91ZGZsYXJlIEFyZ28gVHVubmVs
IFNhbXBsZSBIZWxsbyBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMjE1MjAxNjU5
WhcNMjgwMjEzMjAxNjU5WjCBijELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVRleGFz
MQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENsb3VkZmxhcmUsIEluYy4xPzA9
BgNVBAMMNkNsb3VkZmxhcmUgQXJnbyBUdW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZl
ciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49AgEGBSuBBAAiA2IABOqkXNPHFXsP7+iq
EoJNHoEQz16VN8ppFAQWwu59yZpMzD/21PCCkcpqj9CwsoNM1tCFtJAuRyakJAqV
IaGX6kKuo23+hhQCFq9QjCnAaIsKoYf9HYCpYspE85r9G5+z9aNJMEcwRQYDVR0R
BD4wPIIJbG9jYWxob3N0ggp3YXJwLWhlbGxvggt3YXJwMi1oZWxsb4cEfwAAAYcQ
AAAAAAAAAAAAAAAAAAAAATAJBgcqhkjOPQQBA2cAMGQCMHyVPufXZ6vQo6XRWRa0
dAwtfgesOdZVP2Wt+t5v8jOIQQh1IQXYk5GtyoZGSObjhQIwd1fRgAyKXaZt+1DV
ZtHTdf8pMvESfJsSd8AB1eQ6q+pAiRUYyaxcE1Mlo2YY5o+g
-----END CERTIFICATE-----`
)
func GetHelloCertificate() (tls.Certificate, error) {
return tls.X509KeyPair([]byte(helloCRT), []byte(helloKey))
}
func GetHelloCertificateX509() (*x509.Certificate, error) {
helloCertificate, err := GetHelloCertificate()
if err != nil {
return nil, err
}
return x509.ParseCertificate(helloCertificate.Certificate[0])
}

View File

@ -6,8 +6,9 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"io/ioutil" "io/ioutil"
"net"
log "github.com/Sirupsen/logrus" log "github.com/sirupsen/logrus"
cli "gopkg.in/urfave/cli.v2" cli "gopkg.in/urfave/cli.v2"
) )
@ -60,3 +61,43 @@ func LoadCert(certPath string) *x509.CertPool {
} }
return ca return ca
} }
func LoadOriginCertsPool() *x509.CertPool {
// First, obtain the system certificate pool
certPool, systemCertPoolErr := x509.SystemCertPool()
if systemCertPoolErr != nil {
log.Warn("error obtaining the system certificates: %s", systemCertPoolErr)
certPool = x509.NewCertPool()
}
// Next, append the Cloudflare CA pool into the system pool
if !certPool.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
log.Warn("could not append the CF certificate to the system certificate pool")
if systemCertPoolErr != nil { // Obtaining both certificates failed; this is a fatal error
log.WithError(systemCertPoolErr).Fatalf("Error loading the certificate pool")
}
}
// Finally, add the Hello certificate into the pool (since it's self-signed)
helloCertificate, err := GetHelloCertificateX509()
if err != nil {
log.Warn("error obtaining the Hello server certificate")
}
certPool.AddCert(helloCertificate)
return certPool
}
func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config {
tlsConfig := CLIFlags{RootCA: "cacert"}.GetConfig(c)
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = GetCloudflareRootCA()
tlsConfig.ServerName = "cftunnel.com"
} else if len(addrs) > 0 {
// Set for development environments and for testing specific origintunneld instances
tlsConfig.ServerName, _, _ = net.SplitHostPort(addrs[0])
}
return tlsConfig
}

View File

@ -1,10 +1,9 @@
package tunnelrpc package tunnelrpc
//go:generate capnp compile -ogo -I./tunnelrpc/ tunnelrpc.capnp
import ( import (
log "github.com/Sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/context" "golang.org/x/net/context"
"golang.org/x/net/trace"
"zombiezen.com/go/capnproto2/rpc" "zombiezen.com/go/capnproto2/rpc"
) )
@ -24,3 +23,20 @@ func (c ConnLogger) Errorf(ctx context.Context, format string, args ...interface
func ConnLog(log *log.Entry) rpc.ConnOption { func ConnLog(log *log.Entry) rpc.ConnOption {
return rpc.ConnLog(ConnLogger{log}) return rpc.ConnLog(ConnLogger{log})
} }
// ConnTracer wraps a trace.EventLog for a connection.
type ConnTracer struct {
Events trace.EventLog
}
func (c ConnTracer) Infof(ctx context.Context, format string, args ...interface{}) {
c.Events.Printf(format, args...)
}
func (c ConnTracer) Errorf(ctx context.Context, format string, args ...interface{}) {
c.Events.Errorf(format, args...)
}
func ConnTrace(events trace.EventLog) rpc.ConnOption {
return rpc.ConnLog(ConnTracer{events})
}

View File

@ -4,7 +4,7 @@ package tunnelrpc
import ( import (
"bytes" "bytes"
log "github.com/Sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/context" "golang.org/x/net/context"
"zombiezen.com/go/capnproto2/encoding/text" "zombiezen.com/go/capnproto2/encoding/text"
"zombiezen.com/go/capnproto2/rpc" "zombiezen.com/go/capnproto2/rpc"

View File

@ -47,10 +47,11 @@ type RegistrationOptions struct {
Version string Version string
OS string `capnp:"os"` OS string `capnp:"os"`
ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy
PoolID string `capnp:"poolId"` PoolName string `capnp:"poolName"`
Tags []Tag Tags []Tag
ConnectionID uint8 `capnp:"connectionId"` ConnectionID uint8 `capnp:"connectionId"`
OriginLocalIP string `capnp:"originLocalIp"` OriginLocalIP string `capnp:"originLocalIp"`
IsAutoupdated bool `capnp:"isAutoupdated"`
} }
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error { func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {

View File

@ -28,13 +28,15 @@ struct RegistrationOptions {
# What to do with existing tunnels for the given hostname. # What to do with existing tunnels for the given hostname.
existingTunnelPolicy @3 :ExistingTunnelPolicy; existingTunnelPolicy @3 :ExistingTunnelPolicy;
# If using the balancing policy, identifies the LB pool to use. # If using the balancing policy, identifies the LB pool to use.
poolId @4 :Text; poolName @4 :Text;
# Client-defined tags to associate with the tunnel # Client-defined tags to associate with the tunnel
tags @5 :List(Tag); tags @5 :List(Tag);
# A unique identifier for a high-availability connection made by a single client. # A unique identifier for a high-availability connection made by a single client.
connectionId @6 :UInt8; connectionId @6 :UInt8;
# origin LAN IP # origin LAN IP
originLocalIp @7 :Text; originLocalIp @7 :Text;
# whether Warp client has been autoupdated
isAutoupdated @8 :Bool;
} }
struct Tag { struct Tag {

77
websocket/websocket.go Normal file
View File

@ -0,0 +1,77 @@
package websocket
import (
"bufio"
"crypto/sha1"
"crypto/tls"
"encoding/base64"
"errors"
"io"
"net"
"net/http"
"github.com/gorilla/websocket"
)
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
func IsWebSocketUpgrade(req *http.Request) bool {
return websocket.IsWebSocketUpgrade(req)
}
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing.
func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.Conn, *http.Response, error) {
req.URL.Scheme = "wss"
d := &websocket.Dialer{TLSClientConfig: tlsClientConfig}
conn, response, err := d.Dial(req.URL.String(), nil)
if err != nil {
return nil, nil, err
}
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
return conn, response, err
}
// HijackConnection takes over an HTTP connection. Caller is responsible for closing connection.
func HijackConnection(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
hj, ok := w.(http.Hijacker)
if !ok {
return nil, nil, errors.New("hijack error")
}
conn, brw, err := hj.Hijack()
if err != nil {
return nil, nil, err
}
return conn, brw, nil
}
// Stream copies copy data to & from provided io.ReadWriters.
func Stream(conn, backendConn io.ReadWriter) {
proxyDone := make(chan struct{}, 2)
go func() {
io.Copy(conn, backendConn)
proxyDone <- struct{}{}
}()
go func() {
io.Copy(backendConn, conn)
proxyDone <- struct{}{}
}()
// If one side is done, we are done.
<-proxyDone
}
// sha1Base64 sha1 and then base64 encodes str.
func sha1Base64(str string) string {
hasher := sha1.New()
io.WriteString(hasher, str)
hash := hasher.Sum(nil)
return base64.StdEncoding.EncodeToString(hash)
}
// generateAcceptKey returns the string needed for the Sec-WebSocket-Accept header.
// https://tools.ietf.org/html/rfc6455#section-1.3 describes this process in more detail.
func generateAcceptKey(req *http.Request) string {
return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
}