Release Warp Client 2018.2.1

This commit is contained in:
cloudflare-warp-bot 2018-02-20 21:13:56 +00:00
parent e0ae598112
commit 3780e14f41
25 changed files with 713 additions and 223 deletions

View File

@ -2,17 +2,20 @@ package main
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"html/template"
"io/ioutil"
"net"
"net/http"
"os"
"time"
"github.com/pkg/errors"
"github.com/gorilla/websocket"
"gopkg.in/urfave/cli.v2"
log "github.com/Sirupsen/logrus"
cli "gopkg.in/urfave/cli.v2"
"github.com/cloudflare/cloudflare-warp/tlsconfig"
)
type templateData struct {
@ -21,6 +24,11 @@ type templateData struct {
Body string
}
type OriginUpTime struct {
StartTime time.Time `json:"startTime"`
UpTime string `json:"uptime"`
}
const defaultServerName = "the Cloudflare Warp test server"
const indexTemplate = `
<!DOCTYPE html>
@ -85,71 +93,113 @@ const indexTemplate = `
func hello(c *cli.Context) error {
address := fmt.Sprintf(":%d", c.Int("port"))
server := NewHelloWorldServer()
if hostname, err := os.Hostname(); err != nil {
server.serverName = hostname
listener, err := createListener(address)
if err != nil {
return err
}
err := server.ListenAndServe(address)
return errors.Wrap(err, "Fail to start Hello World Server")
defer listener.Close()
err = startHelloWorldServer(listener, nil)
return err
}
func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
server := NewHelloWorldServer()
if hostname, err := os.Hostname(); err != nil {
server.serverName = hostname
Log.Infof("Starting Hello World server at %s", listener.Addr())
serverName := defaultServerName
if hostname, err := os.Hostname(); err == nil {
serverName = hostname
}
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: server}
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
go func() {
<-shutdownC
httpServer.Close()
}()
http.HandleFunc("/uptime", uptimeHandler(time.Now()))
http.HandleFunc("/ws", websocketHandler(upgrader))
http.HandleFunc("/", rootHandler(serverName))
err := httpServer.Serve(listener)
return err
}
type HelloWorldServer struct {
responseTemplate *template.Template
serverName string
}
func NewHelloWorldServer() *HelloWorldServer {
return &HelloWorldServer{
responseTemplate: template.Must(template.New("index").Parse(indexTemplate)),
serverName: defaultServerName,
func createListener(address string) (net.Listener, error) {
certificate, err := tlsconfig.GetHelloCertificate()
if err != nil {
return nil, err
}
}
func findAvailablePort() (net.Listener, error) {
// If the port in address is empty, a port number is automatically chosen.
listener, err := net.Listen("tcp", "127.0.0.1:")
// If the port in address is empty, a port number is automatically chosen
listener, err := tls.Listen(
"tcp",
address,
&tls.Config{Certificates: []tls.Certificate{certificate}})
return listener, err
}
func (s *HelloWorldServer) ListenAndServe(address string) error {
log.Infof("Starting Hello World server on %s", address)
err := http.ListenAndServe(address, s)
return err
func uptimeHandler(startTime time.Time) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Note that if autoupdate is enabled, the uptime is reset when a new client
// release is available
resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()}
respJson, err := json.Marshal(resp)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
} else {
w.Header().Set("Content-Type", "application/json")
w.Write(respJson)
}
}
}
func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto)
var buffer bytes.Buffer
var body string
rawBody, err := ioutil.ReadAll(r.Body)
if err == nil {
body = string(rawBody)
} else {
body = ""
}
err = s.responseTemplate.Execute(&buffer, &templateData{
ServerName: s.serverName,
Request: r,
Body: body,
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "error: %v", err)
} else {
buffer.WriteTo(w)
func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
if err := conn.WriteMessage(mt, message); err != nil {
break
}
}
}
}
func rootHandler(serverName string) http.HandlerFunc {
responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
return func(w http.ResponseWriter, r *http.Request) {
Log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto)
var buffer bytes.Buffer
var body string
rawBody, err := ioutil.ReadAll(r.Body)
if err == nil {
body = string(rawBody)
} else {
body = ""
}
err = responseTemplate.Execute(&buffer, &templateData{
ServerName: serverName,
Request: r,
Body: body,
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "error: %v", err)
} else {
buffer.WriteTo(w)
}
}
}

View File

@ -4,18 +4,30 @@ import (
"testing"
)
const testPort = "8080"
func TestNewHelloWorldServer(t *testing.T) {
if NewHelloWorldServer() == nil {
t.Fatal("NewHelloWorldServer returned nil")
}
}
func TestFindAvailablePort(t *testing.T) {
listener, err := findAvailablePort()
func TestCreateListenerHostAndPortSuccess(t *testing.T) {
listener, err := createListener("localhost:1234")
if err != nil {
t.Fatal("Fail to find available port")
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}
func TestCreateListenerOnlyHostSuccess(t *testing.T) {
listener, err := createListener("localhost:")
if err != nil {
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}
func TestCreateListenerOnlyPortSuccess(t *testing.T) {
listener, err := createListener(":8888")
if err != nil {
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")

View File

@ -42,6 +42,8 @@ After=network.target
TimeoutStartSec=0
Type=notify
ExecStart={{ .Path }} --config /etc/cloudflare-warp/config.yml --origincert /etc/cloudflare-warp/cert.pem --no-autoupdate
Restart=on-failure
RestartSec=5s
[Install]
WantedBy=multi-user.target

View File

@ -14,7 +14,6 @@ import (
"syscall"
"time"
log "github.com/Sirupsen/logrus"
homedir "github.com/mitchellh/go-homedir"
cli "gopkg.in/urfave/cli.v2"
)
@ -137,7 +136,7 @@ func download(certURL, filePath string) bool {
return true
}
if err != nil {
log.WithError(err).Error("Error fetching certificate")
Log.WithError(err).Error("Error fetching certificate")
return false
}
}
@ -180,16 +179,16 @@ func putSuccess(client *http.Client, certURL string) {
// indicate success to the relay server
req, err := http.NewRequest("PUT", certURL+"/ok", nil)
if err != nil {
log.WithError(err).Error("HTTP request error")
Log.WithError(err).Error("HTTP request error")
return
}
resp, err := client.Do(req)
if err != nil {
log.WithError(err).Error("HTTP error")
Log.WithError(err).Error("HTTP error")
return
}
resp.Body.Close()
if resp.StatusCode != 200 {
log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
}
}

View File

@ -11,6 +11,8 @@ import (
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"sync"
"syscall"
"time"
@ -21,11 +23,12 @@ import (
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
"github.com/cloudflare/cloudflare-warp/validation"
log "github.com/Sirupsen/logrus"
"github.com/facebookgo/grace/gracenet"
raven "github.com/getsentry/raven-go"
homedir "github.com/mitchellh/go-homedir"
cli "gopkg.in/urfave/cli.v2"
"github.com/getsentry/raven-go"
"github.com/mitchellh/go-homedir"
"github.com/rifflock/lfshook"
"github.com/sirupsen/logrus"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
"github.com/coreos/go-systemd/daemon"
@ -40,11 +43,21 @@ const configFile = "config.yml"
var listeners = gracenet.Net{}
var Version = "DEV"
var BuildTime = "unknown"
var Log *logrus.Logger
// Shutdown channel used by the app. When closed, app must terminate.
// May be closed by the Windows service runner.
var shutdownC chan struct{}
type BuildAndRuntimeInfo struct {
GoOS string `json:"go_os"`
GoVersion string `json:"go_version"`
GoArch string `json:"go_arch"`
WarpVersion string `json:"warp_version"`
WarpFlags map[string]interface{} `json:"warp_flags"`
WarpEnvs map[string]string `json:"warp_envs"`
}
func main() {
metrics.RegisterBuildInfo(BuildTime, Version)
raven.SetDSN(sentryDSN)
@ -84,6 +97,12 @@ WARNING:
Usage: "Disable periodic check for updates, restarting the server with the new version.",
Value: false,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "is-autoupdated",
Usage: "Signal the new process that Warp client has been autoupdated",
Value: false,
Hidden: true,
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "edge",
Usage: "Address of the Cloudflare tunnel server.",
@ -99,12 +118,12 @@ WARNING:
altsrc.NewStringFlag(&cli.StringFlag{
Name: "origincert",
Usage: "Path to the certificate generated for your origin when you run cloudflare-warp login.",
EnvVars: []string{"ORIGIN_CERT"},
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: filepath.Join(defaultConfigDir, credentialFile),
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "url",
Value: "http://localhost:8080",
Value: "https://localhost:8080",
Usage: "Connect to the local webserver at `URL`.",
EnvVars: []string{"TUNNEL_URL"},
}),
@ -190,15 +209,21 @@ WARNING:
EnvVars: []string{"TUNNEL_RETRIES"},
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "hello-world",
Usage: "Run Hello World Server",
Value: false,
Name: "hello-world",
Value: false,
Usage: "Run Hello World Server",
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "pidfile",
Usage: "Write the application's PID to this file after first successful connection.",
EnvVars: []string{"TUNNEL_PIDFILE"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "logfile",
Usage: "Save application log to this file for reporting issues.",
EnvVars: []string{"TUNNEL_LOGFILE"},
}),
altsrc.NewIntFlag(&cli.IntFlag{
Name: "ha-connections",
Value: 4,
@ -239,6 +264,7 @@ WARNING:
return nil
}
app.Before = func(context *cli.Context) error {
Log = logrus.New()
inputSource, err := findInputSourceContext(context)
if err != nil {
return err
@ -248,7 +274,7 @@ WARNING:
return nil
}
app.Commands = []*cli.Command{
&cli.Command{
{
Name: "update",
Action: update,
Usage: "Update the agent if a new version exists",
@ -259,7 +285,7 @@ WARNING:
To determine if an update happened in a script, check for error code 64.`,
},
&cli.Command{
{
Name: "login",
Action: login,
Usage: "Generate a configuration file with your login details",
@ -271,7 +297,7 @@ WARNING:
},
},
},
&cli.Command{
{
Name: "hello",
Action: hello,
Usage: "Run a simple \"Hello World\" server for testing Cloudflare Warp.",
@ -293,27 +319,43 @@ func startServer(c *cli.Context) {
errC := make(chan error)
wg.Add(2)
if c.NumFlags() == 0 && c.NArg() == 0 {
// If the user choose to supply all options through env variables,
// c.NumFlags() == 0 && c.NArg() == 0. For warp to work, the user needs to at
// least provide a hostname.
if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" {
cli.ShowAppHelp(c)
return
}
logLevel, err := log.ParseLevel(c.String("loglevel"))
logLevel, err := logrus.ParseLevel(c.String("loglevel"))
if err != nil {
log.WithError(err).Fatal("Unknown logging level specified")
Log.WithError(err).Fatal("Unknown logging level specified")
}
logrus.SetLevel(logLevel)
log.SetLevel(logLevel)
protoLogLevel, err := log.ParseLevel(c.String("proto-loglevel"))
protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
if err != nil {
log.WithError(err).Fatal("Unknown protocol logging level specified")
Log.WithError(err).Fatal("Unknown protocol logging level specified")
}
protoLogger := log.New()
protoLogger := logrus.New()
protoLogger.Level = protoLogLevel
if c.String("logfile") != "" {
if err := initLogFile(c, protoLogger); err != nil {
Log.Error(err)
}
}
if !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 {
if initUpdate() {
return
}
Log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
}
hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil {
log.WithError(err).Fatal("Invalid hostname")
Log.WithError(err).Fatal("Invalid hostname")
}
clientID := c.String("id")
@ -323,46 +365,44 @@ func startServer(c *cli.Context) {
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil {
log.WithError(err).Fatal("Tag parse failure")
Log.WithError(err).Fatal("Tag parse failure")
}
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
if c.IsSet("hello-world") {
wg.Add(1)
listener, err := findAvailablePort()
listener, err := createListener("127.0.0.1:")
if err != nil {
listener.Close()
log.WithError(err).Fatal("Cannot start Hello World Server")
Log.WithError(err).Fatal("Cannot start Hello World Server")
}
go func() {
startHelloWorldServer(listener, shutdownC)
wg.Done()
listener.Close()
}()
c.Set("url", "http://"+listener.Addr().String())
log.Infof("Starting Hello World Server at %s", c.String("url"))
c.Set("url", "https://"+listener.Addr().String())
}
url, err := validateUrl(c)
if err != nil {
log.WithError(err).Fatal("Error validating url")
Log.WithError(err).Fatal("Error validating url")
}
log.Infof("Proxying tunnel requests to %s", url)
Log.Infof("Proxying tunnel requests to %s", url)
// Fail if the user provided an old authentication method
if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login")
Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login")
}
// Check that the user has acquired a certificate using the log in command
originCertPath, err := homedir.Expand(c.String("origincert"))
if err != nil {
log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
}
ok, err := fileExists(originCertPath)
if !ok {
log.Fatalf(`Cannot find a valid certificate for your origin at the path:
Log.Fatalf(`Cannot find a valid certificate for your origin at the path:
%s
@ -375,8 +415,9 @@ If you don't have a certificate signed by Cloudflare, run the command:
// Easier to send the certificate as []byte via RPC than decoding it at this point
originCert, err := ioutil.ReadFile(originCertPath)
if err != nil {
log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
}
tunnelMetrics := origin.NewTunnelMetrics()
httpTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
@ -389,13 +430,15 @@ If you don't have a certificate signed by Cloudflare, run the command:
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
}
tunnelConfig := &origin.TunnelConfig{
EdgeAddrs: c.StringSlice("edge"),
OriginUrl: url,
Hostname: hostname,
OriginCert: originCert,
TlsConfig: &tls.Config{},
TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
ClientTlsConfig: httpTransport.TLSClientConfig,
Retries: c.Uint("retries"),
HeartbeatInterval: c.Duration("heartbeat-interval"),
MaxHeartbeats: c.Uint64("heartbeat-count"),
@ -408,18 +451,11 @@ If you don't have a certificate signed by Cloudflare, run the command:
Metrics: tunnelMetrics,
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
ProtocolLogger: protoLogger,
Logger: Log,
IsAutoupdated: c.Bool("is-autoupdated"),
}
connectedSignal := make(chan struct{})
tunnelConfig.TlsConfig = tlsconfig.CLIFlags{RootCA: "cacert"}.GetConfig(c)
if tunnelConfig.TlsConfig.RootCAs == nil {
tunnelConfig.TlsConfig.RootCAs = GetCloudflareRootCA()
tunnelConfig.TlsConfig.ServerName = "cftunnel.com"
} else if len(tunnelConfig.EdgeAddrs) > 0 {
// Set for development environments and for testing specific origintunneld instances
tunnelConfig.TlsConfig.ServerName, _, _ = net.SplitHostPort(tunnelConfig.EdgeAddrs[0])
}
go writePidFile(connectedSignal, c.String("pidfile"))
go func() {
errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
@ -428,24 +464,21 @@ If you don't have a certificate signed by Cloudflare, run the command:
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
if err != nil {
log.WithError(err).Fatal("Error opening metrics server listener")
Log.WithError(err).Fatal("Error opening metrics server listener")
}
go func() {
errC <- metrics.ServeMetrics(metricsListener, shutdownC)
wg.Done()
}()
if !c.Bool("no-autoupdate") {
log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
go autoupdate(c.Duration("autoupdate-period"), shutdownC)
}
var errCode int
err = WaitForSignal(errC, shutdownC)
if err != nil {
log.WithError(err).Error("Quitting due to error")
Log.WithError(err).Error("Quitting due to error")
raven.CaptureErrorAndWait(err, nil)
errCode = 1
} else {
log.Info("Quitting...")
Log.Info("Quitting...")
}
// Wait for clean exit, discarding all errors
go func() {
@ -453,6 +486,7 @@ If you don't have a certificate signed by Cloudflare, run the command:
}
}()
wg.Wait()
os.Exit(errCode)
}
func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
@ -477,30 +511,40 @@ func update(c *cli.Context) error {
return nil
}
func autoupdate(frequency time.Duration, shutdownC chan struct{}) {
if int64(frequency) == 0 {
return
func initUpdate() bool {
if updateApplied() {
os.Args = append(os.Args, "--is-autoupdated=true")
if _, err := listeners.StartProcess(); err != nil {
Log.WithError(err).Error("Unable to restart server automatically")
return false
}
return true
}
return false
}
func autoupdate(freq time.Duration, shutdownC chan struct{}) {
for {
if updateApplied() {
os.Args = append(os.Args, "--is-autoupdated=true")
if _, err := listeners.StartProcess(); err != nil {
log.WithError(err).Error("Unable to restart server automatically")
Log.WithError(err).Error("Unable to restart server automatically")
}
close(shutdownC)
return
}
time.Sleep(frequency)
time.Sleep(freq)
}
}
func updateApplied() bool {
releaseInfo := checkForUpdates()
if releaseInfo.Updated {
log.Infof("Updated to version %s", releaseInfo.Version)
Log.Infof("Updated to version %s", releaseInfo.Version)
return true
}
if releaseInfo.Error != nil {
log.WithError(releaseInfo.Error).Error("Update check failed")
Log.WithError(releaseInfo.Error).Error("Update check failed")
}
return false
}
@ -555,7 +599,7 @@ func writePidFile(waitForSignal chan struct{}, pidFile string) {
}
file, err := os.Create(pidFile)
if err != nil {
log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
Log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
}
defer file.Close()
fmt.Fprintf(file, "%d", os.Getpid())
@ -573,3 +617,55 @@ func validateUrl(c *cli.Context) (string, error) {
validUrl, err := validation.ValidateUrl(url)
return validUrl, err
}
func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error {
fileMode := os.O_WRONLY|os.O_APPEND|os.O_CREATE|os.O_TRUNC
// do not truncate log file if the client has been autoupdated
if c.Bool("is-autoupdated") {
fileMode = os.O_WRONLY|os.O_APPEND|os.O_CREATE
}
f, err := os.OpenFile(c.String("logfile"), fileMode, 0664)
if err != nil {
errors.Wrap(err, fmt.Sprintf("Cannot open file %s", c.String("logfile")))
}
defer f.Close()
pathMap := lfshook.PathMap{
logrus.InfoLevel: c.String("logfile"),
logrus.ErrorLevel: c.String("logfile"),
logrus.FatalLevel: c.String("logfile"),
logrus.PanicLevel: c.String("logfile"),
}
Log.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
flags := make(map[string]interface{})
envs := make(map[string]string)
for _, flag := range c.LocalFlagNames() {
flags[flag] = c.Generic(flag)
}
// Find env variables for Warp
for _, env := range os.Environ() {
// All Warp env variables start with TUNNEL_
if strings.Contains(env, "TUNNEL_") {
vars := strings.Split(env, "=")
if len(vars) == 2 {
envs[vars[0]] = vars[1]
}
}
}
Log.Infof("Warp build and runtime configuration: %+v", BuildAndRuntimeInfo{
GoOS: runtime.GOOS,
GoVersion: runtime.Version(),
GoArch: runtime.GOARCH,
WarpVersion: Version,
WarpFlags: flags,
WarpEnvs: envs,
})
return nil
}

View File

@ -9,7 +9,7 @@ import (
"fmt"
"os"
log "github.com/Sirupsen/logrus"
log "github.com/sirupsen/logrus"
cli "gopkg.in/urfave/cli.v2"
"golang.org/x/sys/windows/svc"

View File

@ -6,7 +6,7 @@ import (
"sync"
"time"
log "github.com/Sirupsen/logrus"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)

View File

@ -6,13 +6,14 @@ import (
"io"
"io/ioutil"
"math/rand"
"net"
"os"
"strconv"
"sync"
"testing"
"time"
log "github.com/Sirupsen/logrus"
log "github.com/sirupsen/logrus"
)
func TestMain(m *testing.M) {
@ -25,25 +26,20 @@ func TestMain(m *testing.M) {
type DefaultMuxerPair struct {
OriginMuxConfig MuxerConfig
OriginMux *Muxer
OriginWriter *io.PipeWriter
OriginReader *io.PipeReader
OriginConn net.Conn
EdgeMuxConfig MuxerConfig
EdgeMux *Muxer
EdgeWriter *io.PipeWriter
EdgeReader *io.PipeReader
EdgeConn net.Conn
doneC chan struct{}
}
func NewDefaultMuxerPair() *DefaultMuxerPair {
originReader, edgeWriter := io.Pipe()
edgeReader, originWriter := io.Pipe()
origin, edge := net.Pipe()
return &DefaultMuxerPair{
OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"},
OriginWriter: originWriter,
OriginReader: originReader,
OriginConn: origin,
EdgeMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"},
EdgeWriter: edgeWriter,
EdgeReader: edgeReader,
EdgeConn: edge,
doneC: make(chan struct{}),
}
}
@ -53,12 +49,12 @@ func (p *DefaultMuxerPair) Handshake(t *testing.T) {
originErrC := make(chan error)
go func() {
var err error
p.EdgeMux, err = Handshake(p.EdgeWriter, p.EdgeReader, p.EdgeMuxConfig)
p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig)
edgeErrC <- err
}()
go func() {
var err error
p.OriginMux, err = Handshake(p.OriginWriter, p.OriginReader, p.OriginMuxConfig)
p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig)
originErrC <- err
}()
@ -120,8 +116,8 @@ func (p *DefaultMuxerPair) Wait(t *testing.T) {
func TestHandshake(t *testing.T) {
muxPair := NewDefaultMuxerPair()
muxPair.Handshake(t)
AssertIfPipeReadable(t, muxPair.OriginReader)
AssertIfPipeReadable(t, muxPair.EdgeReader)
AssertIfPipeReadable(t, muxPair.OriginConn)
AssertIfPipeReadable(t, muxPair.EdgeConn)
}
func TestSingleStream(t *testing.T) {
@ -145,7 +141,7 @@ func TestSingleStream(t *testing.T) {
stream.Write(buf)
// after this receive, the edge closed the stream
<-closeC
n, err := stream.Read(buf)
n, err := io.ReadFull(stream, buf)
if n > 0 {
t.Fatalf("read %d bytes after EOF", n)
}
@ -173,7 +169,7 @@ func TestSingleStream(t *testing.T) {
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
}
responseBody := make([]byte, 11)
n, err := stream.Read(responseBody)
n, err := io.ReadFull(stream, responseBody)
if err != nil {
t.Fatalf("error from (*MuxedStream).Read: %s", err)
}
@ -243,7 +239,7 @@ func TestSingleStreamLargeResponseBody(t *testing.T) {
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
}
responseBody := make([]byte, bodySize)
n, err := stream.Read(responseBody)
n, err := io.ReadFull(stream, responseBody)
if err != nil {
t.Fatalf("error from (*MuxedStream).Read: %s", err)
}
@ -302,7 +298,7 @@ func TestMultipleStreams(t *testing.T) {
return
}
responseBody := make([]byte, 2)
n, err := stream.Read(responseBody)
n, err := io.ReadFull(stream, responseBody)
if err != nil {
errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err)
return
@ -392,7 +388,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) {
}
responseBody := make([]byte, responseSizes[(stream.streamID-2)/2])
n, err := stream.Read(responseBody)
n, err := io.ReadFull(stream, responseBody)
if err != nil {
errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err)
return
@ -451,7 +447,7 @@ func TestGracefulShutdown(t *testing.T) {
}
responseBody := make([]byte, len(responseBuf))
log.Debugf("Waiting for %d bytes", len(responseBuf))
n, err := stream.Read(responseBody)
n, err := io.ReadFull(stream, responseBody)
if err != nil {
t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err)
}
@ -498,13 +494,13 @@ func TestUnexpectedShutdown(t *testing.T) {
nil,
)
// Close the underlying connection before telling the origin to write.
muxPair.EdgeReader.Close()
muxPair.EdgeConn.Close()
close(sendC)
if err != nil {
t.Fatalf("error in OpenStream: %s", err)
}
responseBody := make([]byte, len(responseBuf))
n, err := stream.Read(responseBody)
n, err := io.ReadFull(stream, responseBody)
if err != io.EOF {
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
}
@ -545,14 +541,14 @@ func TestOpenAfterDisconnect(t *testing.T) {
switch i {
case 0:
// Close both directions of the connection to cause EOF on both peers.
muxPair.OriginReader.Close()
muxPair.OriginWriter.Close()
muxPair.OriginConn.Close()
muxPair.EdgeConn.Close()
case 1:
// Close origin reader (edge writer) to cause EOF on origin only.
muxPair.OriginReader.Close()
// Close origin conn to cause EOF on origin first.
muxPair.OriginConn.Close()
case 2:
// Close origin writer (edge reader) to cause EOF on edge only.
muxPair.OriginWriter.Close()
// Close edge conn to cause EOF on edge first.
muxPair.EdgeConn.Close()
}
_, err := muxPair.EdgeMux.OpenStream(
@ -623,7 +619,7 @@ func TestHPACK(t *testing.T) {
}
}
func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) {
func AssertIfPipeReadable(t *testing.T, pipe io.ReadCloser) {
errC := make(chan error)
go func() {
b := []byte{0}
@ -640,7 +636,5 @@ func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) {
}
case <-time.After(100 * time.Millisecond):
// nothing to read
pipe.Close()
<-errC
}
}

View File

@ -1,6 +1,7 @@
package h2mux
import (
"io"
"testing"
"github.com/stretchr/testify/assert"
@ -63,3 +64,27 @@ func TestFlowControlSingleStream(t *testing.T) {
assert.Equal(t, testWindowSize<<2, stream.receiveWindow)
assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax)
}
func TestMuxedStreamEOF(t *testing.T) {
for i := 0; i < 4096; i++ {
readyList := NewReadyList()
stream := &MuxedStream{
streamID: 1,
readBuffer: NewSharedBuffer(),
receiveWindow: 65536,
receiveWindowMax: 65536,
sendWindow: 65536,
readyList: readyList,
}
go func() { stream.Close() }()
n, err := stream.Read([]byte{0})
assert.Equal(t, io.EOF, err)
assert.Equal(t, 0, n)
// Write comes after read, because write buffers data before it is flushed. It wouldn't know about EOF
// until some time later. Calling read first forces it to know about EOF now.
n, err = stream.Write([]byte{1})
assert.Equal(t, io.EOF, err)
assert.Equal(t, 0, n)
}
}

View File

@ -6,7 +6,7 @@ import (
"sync"
"time"
log "github.com/Sirupsen/logrus"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
)

View File

@ -6,7 +6,7 @@ import (
"io"
"time"
log "github.com/Sirupsen/logrus"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)

View File

@ -21,7 +21,7 @@ func NewSharedBuffer() *SharedBuffer {
func (s *SharedBuffer) Read(p []byte) (n int, err error) {
totalRead := 0
s.cond.L.Lock()
for totalRead < len(p) {
for totalRead == 0 {
n, err = s.buffer.Read(p[totalRead:])
totalRead += n
if err == io.EOF {
@ -29,6 +29,9 @@ func (s *SharedBuffer) Read(p []byte) (n int, err error) {
break
}
err = nil
if n > 0 {
break
}
s.cond.Wait()
}
}

View File

@ -6,6 +6,8 @@ import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) {
@ -29,30 +31,35 @@ func TestSharedBuffer(t *testing.T) {
func TestSharedBufferBlockingRead(t *testing.T) {
b := NewSharedBuffer()
testData := []byte("Hello world")
testData1 := []byte("Hello")
testData2 := []byte(" world")
result := make(chan []byte)
go func() {
bytesRead := make([]byte, len(testData))
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
result <- bytesRead
bytesRead := make([]byte, len(testData1)+len(testData2))
nRead, err := b.Read(bytesRead)
AssertIOReturnIsGood(t, len(testData1))(nRead, err)
result <- bytesRead[:nRead]
nRead, err = b.Read(bytesRead)
AssertIOReturnIsGood(t, len(testData2))(nRead, err)
result <- bytesRead[:nRead]
}()
time.Sleep(time.Millisecond * 250)
select {
case <-result:
t.Fatalf("read returned early")
default:
}
AssertIOReturnIsGood(t, 5)(b.Write(testData[:5]))
select {
case <-result:
t.Fatalf("read returned early")
default:
}
AssertIOReturnIsGood(t, len(testData)-5)(b.Write(testData[5:]))
AssertIOReturnIsGood(t, len(testData1))(b.Write([]byte(testData1)))
select {
case r := <-result:
if string(r) != string(testData) {
t.Fatalf("expected read to return %s, got %s", testData, r)
}
assert.Equal(t, testData1, r)
case <-time.After(time.Second):
t.Fatalf("read timed out")
}
AssertIOReturnIsGood(t, len(testData2))(b.Write([]byte(testData2)))
select {
case r := <-result:
assert.Equal(t, testData2, r)
case <-time.After(time.Second):
t.Fatalf("read timed out")
}
@ -85,7 +92,7 @@ func TestSharedBufferConcurrentReadWrite(t *testing.T) {
// Change block sizes in opposition to the write thread, to test blocking for new data.
for blockSize := 256; blockSize > 0; blockSize-- {
for i := 0; i < 256; i++ {
n, err := b.Read(block[:blockSize])
n, err := io.ReadFull(b, block[:blockSize])
if n != blockSize || err != nil {
t.Fatalf("read error: %d %s", n, err)
}

View File

@ -11,9 +11,9 @@ import (
"golang.org/x/net/context"
"golang.org/x/net/trace"
log "github.com/Sirupsen/logrus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
)
const (

View File

@ -5,7 +5,6 @@ import (
"github.com/cloudflare/cloudflare-warp/h2mux"
log "github.com/Sirupsen/logrus"
"github.com/prometheus/client_golang/prometheus"
)
@ -249,7 +248,7 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
if _, ok := t.concurrentRequests[connectionID]; ok {
t.concurrentRequests[connectionID] -= 1
} else {
log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.")
Log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.")
}
t.concurrentRequestsLock.Unlock()

View File

@ -5,7 +5,6 @@ import (
"net"
"time"
log "github.com/Sirupsen/logrus"
"golang.org/x/net/context"
)
@ -73,7 +72,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
case tunnelError := <-s.tunnelErrors:
tunnelsActive--
if tunnelError.err != nil {
log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
Log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
s.waitForNextTunnel(tunnelError.index)
if backoffTimer == nil {
@ -107,10 +106,10 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
s.lastResolve = time.Now()
s.resolverC = nil
if result.err == nil {
log.Debug("Service discovery refresh complete")
Log.Debug("Service discovery refresh complete")
s.edgeIPs = result.edgeIPs
} else {
log.WithError(result.err).Error("Service discovery error")
Log.WithError(result.err).Error("Service discovery error")
}
}
}
@ -120,12 +119,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
if err != nil {
log.Infof("ResolveEdgeIPs err")
Log.Infof("ResolveEdgeIPs err")
return err
}
s.edgeIPs = edgeIPs
if s.config.HAConnections > len(edgeIPs) {
log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
Log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
s.config.HAConnections = len(edgeIPs)
}
s.lastResolve = time.Now()

View File

@ -19,14 +19,17 @@ import (
"github.com/cloudflare/cloudflare-warp/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
"github.com/cloudflare/cloudflare-warp/validation"
"github.com/cloudflare/cloudflare-warp/websocket"
log "github.com/Sirupsen/logrus"
raven "github.com/getsentry/raven-go"
"github.com/pkg/errors"
_ "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
rpc "zombiezen.com/go/capnproto2/rpc"
)
var Log *logrus.Logger
const (
dialTimeout = 15 * time.Second
@ -40,6 +43,7 @@ type TunnelConfig struct {
Hostname string
OriginCert []byte
TlsConfig *tls.Config
ClientTlsConfig *tls.Config
Retries uint
HeartbeatInterval time.Duration
MaxHeartbeats uint64
@ -51,7 +55,9 @@ type TunnelConfig struct {
HTTPTransport http.RoundTripper
Metrics *TunnelMetrics
MetricsUpdateFreq time.Duration
ProtocolLogger *log.Logger
ProtocolLogger *logrus.Logger
Logger *logrus.Logger
IsAutoupdated bool
}
type dialError struct {
@ -87,14 +93,16 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
Version: c.ReportedVersion,
OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH),
ExistingTunnelPolicy: policy,
PoolID: c.LBPool,
PoolName: c.LBPool,
Tags: c.Tags,
ConnectionID: connectionID,
OriginLocalIP: OriginLocalIP,
IsAutoupdated: c.IsAutoupdated,
}
}
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
Log = config.Logger
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-shutdownC
@ -129,7 +137,7 @@ func ServeTunnelLoop(ctx context.Context, config *TunnelConfig, addr *net.TCPAdd
err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff)
if recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
log.Infof("Retrying in %s seconds", duration)
Log.Infof("Retrying in %s seconds", duration)
backoff.Backoff(ctx)
continue
}
@ -162,11 +170,10 @@ func ServeTunnel(
// Returns error from parsing the origin URL or handshake errors
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
if err != nil {
errLog := log.WithError(err)
errLog := Log.WithError(err)
switch err.(type) {
case dialError:
errLog.Error("Unable to dial edge")
return err, false
case h2mux.MuxerHandshakeError:
errLog.Error("Handshake failed with edge server")
default:
@ -207,24 +214,21 @@ func ServeTunnel(
registerErr := <-registerErrC
wg.Wait()
if err != nil {
log.WithError(err).Error("Tunnel error")
Log.WithError(err).Error("Tunnel error")
return err, true
}
if registerErr != nil {
// Don't retry on errors like entitlement failure or version too old
if e, ok := registerErr.(printableRegisterTunnelError); ok {
log.Error(e)
if e.permanent {
return e, false
}
return e.cause, true
Log.Error(e)
return e.cause, !e.permanent
} else if e, ok := registerErr.(dupConnRegisterTunnelError); ok {
log.Info("Already connected to this server, selecting a different one")
Log.Info("Already connected to this server, selecting a different one")
return e, true
}
// Only log errors to Sentry that may have been caused by the client side, to reduce dupes
raven.CaptureError(registerErr, nil)
log.Error("Cannot register")
Log.Error("Cannot register")
return err, true
}
return nil, false
@ -241,7 +245,7 @@ func IsRPCStreamResponse(headers []h2mux.Header) bool {
}
func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error {
logger := log.WithField("subsystem", "rpc")
logger := Log.WithField("subsystem", "rpc")
logger.Debug("initiating RPC stream")
stream, err := muxer.OpenStream([]h2mux.Header{
{Name: ":method", Value: "RPC"},
@ -292,11 +296,11 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
}
}
log.Infof("Registered at %s", registration.Url)
Log.Infof("Registered at %s", registration.Url)
return nil
}
func LogServerInfo(logger *log.Entry,
func LogServerInfo(logger *logrus.Entry,
promise tunnelrpc.ServerInfo_Promise,
connectionID uint8,
metrics *TunnelMetrics,
@ -311,7 +315,7 @@ func LogServerInfo(logger *log.Entry,
logger.WithError(err).Warn("Failed to retrieve server information")
return
}
log.Infof("Connected to %s", serverInfo.LocationName)
Log.Infof("Connected to %s", serverInfo.LocationName)
metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName)
}
@ -356,6 +360,7 @@ type TunnelHandler struct {
originUrl string
muxer *h2mux.Muxer
httpClient http.RoundTripper
tlsConfig *tls.Config
tags []tunnelpogs.Tag
metrics *TunnelMetrics
// connectionID is only used by metrics, and prometheus requires labels to be string
@ -373,6 +378,7 @@ func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, co
h := &TunnelHandler{
originUrl: url,
httpClient: config.HTTPTransport,
tlsConfig: config.ClientTlsConfig,
tags: config.Tags,
metrics: config.Metrics,
connectionID: uint8ToString(connectionID),
@ -422,29 +428,45 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
h.metrics.incrementRequests(h.connectionID)
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil {
log.WithError(err).Panic("Unexpected error from http.NewRequest")
Log.WithError(err).Panic("Unexpected error from http.NewRequest")
}
err = H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil {
log.WithError(err).Error("invalid request received")
Log.WithError(err).Error("invalid request received")
}
h.AppendTagHeaders(req)
response, err := h.httpClient.RoundTrip(req)
if err != nil {
log.WithError(err).Error("HTTP request error")
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
stream.Write([]byte("502 Bad Gateway"))
h.metrics.incrementResponses(h.connectionID, "502")
if websocket.IsWebSocketUpgrade(req) {
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
if err != nil {
h.logError(stream, err)
} else {
stream.WriteHeaders(H1ResponseToH2Response(response))
defer conn.Close()
websocket.Stream(conn.UnderlyingConn(), stream)
}
} else {
defer response.Body.Close()
stream.WriteHeaders(H1ResponseToH2Response(response))
io.Copy(stream, response.Body)
h.metrics.incrementResponses(h.connectionID, "200")
response, err := h.httpClient.RoundTrip(req)
if err != nil {
h.logError(stream, err)
} else {
defer response.Body.Close()
stream.WriteHeaders(H1ResponseToH2Response(response))
io.Copy(stream, response.Body)
h.metrics.incrementResponses(h.connectionID, "200")
}
}
h.metrics.decrementConcurrentRequests(h.connectionID)
return nil
}
func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
Log.WithError(err).Error("HTTP request error")
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
stream.Write([]byte("502 Bad Gateway"))
h.metrics.incrementResponses(h.connectionID, "502")
}
func (h *TunnelHandler) UpdateMetrics() {
flowCtlMetrics := h.muxer.FlowControlMetrics()
h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics)

View File

@ -0,0 +1,95 @@
package tlsconfig
import (
"crypto/x509"
)
// TODO: remove the Origin CA root certs when migrated to Authenticated Origin Pull certs
const cloudflareRootCA = `
Issuer: C=US, ST=California, L=San Francisco, O=CloudFlare, Inc., OU=CloudFlare Origin SSL ECC Certificate Authority
-----BEGIN CERTIFICATE-----
MIICiDCCAi6gAwIBAgIUXZP3MWb8MKwBE1Qbawsp1sfA/Y4wCgYIKoZIzj0EAwIw
gY8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1T
YW4gRnJhbmNpc2NvMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTgwNgYDVQQL
Ey9DbG91ZEZsYXJlIE9yaWdpbiBTU0wgRUNDIENlcnRpZmljYXRlIEF1dGhvcml0
eTAeFw0xNjAyMjIxODI0MDBaFw0yMTAyMjIwMDI0MDBaMIGPMQswCQYDVQQGEwJV
UzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEZ
MBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE4MDYGA1UECxMvQ2xvdWRGbGFyZSBP
cmlnaW4gU1NMIEVDQyBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkwWTATBgcqhkjOPQIB
BggqhkjOPQMBBwNCAASR+sGALuaGshnUbcxKry+0LEXZ4NY6JUAtSeA6g87K3jaA
xpIg9G50PokpfWkhbarLfpcZu0UAoYy2su0EhN7wo2YwZDAOBgNVHQ8BAf8EBAMC
AQYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUhTBdOypw1O3VkmcH/es5
tBoOOKcwHwYDVR0jBBgwFoAUhTBdOypw1O3VkmcH/es5tBoOOKcwCgYIKoZIzj0E
AwIDSAAwRQIgEiIEHQr5UKma50D1WRMJBUSgjg24U8n8E2mfw/8UPz0CIQCr5V/e
mcifak4CQsr+DH4pn5SJD7JxtCG3YGswW8QZsw==
-----END CERTIFICATE-----
Issuer: C=US, O=CloudFlare, Inc., OU=CloudFlare Origin SSL Certificate Authority, L=San Francisco, ST=California
-----BEGIN CERTIFICATE-----
MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG
bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN
U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4
NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv
dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl
cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG
A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem
ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ
p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz
/9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl
yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK
xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud
EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G
A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA
cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD
gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh
QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz
ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y
4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX
Fu6q54beR89jDc+oABmOgg==
-----END CERTIFICATE-----
Issuer: C=US, O=CloudFlare, Inc., OU=Origin Pull, L=San Francisco, ST=California, CN=origin-pull.cloudflare.net
-----BEGIN CERTIFICATE-----
MIIGBjCCA/CgAwIBAgIIV5G6lVbCLmEwCwYJKoZIhvcNAQENMIGQMQswCQYDVQQG
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjEUMBIGA1UECxMLT3JpZ2lu
IFB1bGwxFjAUBgNVBAcTDVNhbiBGcmFuY2lzY28xEzARBgNVBAgTCkNhbGlmb3Ju
aWExIzAhBgNVBAMTGm9yaWdpbi1wdWxsLmNsb3VkZmxhcmUubmV0MB4XDTE1MDEx
MzAyNDc1M1oXDTIwMDExMjAyNTI1M1owgZAxCzAJBgNVBAYTAlVTMRkwFwYDVQQK
ExBDbG91ZEZsYXJlLCBJbmMuMRQwEgYDVQQLEwtPcmlnaW4gUHVsbDEWMBQGA1UE
BxMNU2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTEjMCEGA1UEAxMa
b3JpZ2luLXB1bGwuY2xvdWRmbGFyZS5uZXQwggIiMA0GCSqGSIb3DQEBAQUAA4IC
DwAwggIKAoICAQDdsts6I2H5dGyn4adACQRXlfo0KmwsN7B5rxD8C5qgy6spyONr
WV0ecvdeGQfWa8Gy/yuTuOnsXfy7oyZ1dm93c3Mea7YkM7KNMc5Y6m520E9tHooc
f1qxeDpGSsnWc7HWibFgD7qZQx+T+yfNqt63vPI0HYBOYao6hWd3JQhu5caAcIS2
ms5tzSSZVH83ZPe6Lkb5xRgLl3eXEFcfI2DjnlOtLFqpjHuEB3Tr6agfdWyaGEEi
lRY1IB3k6TfLTaSiX2/SyJ96bp92wvTSjR7USjDV9ypf7AD6u6vwJZ3bwNisNw5L
ptph0FBnc1R6nDoHmvQRoyytoe0rl/d801i9Nru/fXa+l5K2nf1koR3IX440Z2i9
+Z4iVA69NmCbT4MVjm7K3zlOtwfI7i1KYVv+ATo4ycgBuZfY9f/2lBhIv7BHuZal
b9D+/EK8aMUfjDF4icEGm+RQfExv2nOpkR4BfQppF/dLmkYfjgtO1403X0ihkT6T
PYQdmYS6Jf53/KpqC3aA+R7zg2birtvprinlR14MNvwOsDOzsK4p8WYsgZOR4Qr2
gAx+z2aVOs/87+TVOR0r14irQsxbg7uP2X4t+EXx13glHxwG+CnzUVycDLMVGvuG
aUgF9hukZxlOZnrl6VOf1fg0Caf3uvV8smOkVw6DMsGhBZSJVwao0UQNqQIDAQAB
o2YwZDAOBgNVHQ8BAf8EBAMCAAYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4E
FgQUQ1lLK2mLgOERM2pXzVc42p59xeswHwYDVR0jBBgwFoAUQ1lLK2mLgOERM2pX
zVc42p59xeswCwYJKoZIhvcNAQENA4ICAQDKDQM1qPRVP/4Gltz0D6OU6xezFBKr
LWtDoA1qW2F7pkiYawCP9MrDPDJsHy7dx+xw3bBZxOsK5PA/T7p1dqpEl6i8F692
g//EuYOifLYw3ySPe3LRNhvPl/1f6Sn862VhPvLa8aQAAwR9e/CZvlY3fj+6G5ik
3it7fikmKUsVnugNOkjmwI3hZqXfJNc7AtHDFw0mEOV0dSeAPTo95N9cxBbm9PKv
qAEmTEXp2trQ/RjJ/AomJyfA1BQjsD0j++DI3a9/BbDwWmr1lJciKxiNKaa0BRLB
dKMrYQD+PkPNCgEuojT+paLKRrMyFUzHSG1doYm46NE9/WARTh3sFUp1B7HZSBqA
kHleoB/vQ/mDuW9C3/8Jk2uRUdZxR+LoNZItuOjU8oTy6zpN1+GgSj7bHjiy9rfA
F+ehdrz+IOh80WIiqs763PGoaYUyzxLvVowLWNoxVVoc9G+PqFKqD988XlipHVB6
Bz+1CD4D/bWrs3cC9+kk/jFmrrAymZlkFX8tDb5aXASSLJjUjcptci9SKqtI2h0J
wUGkD7+bQAr+7vr8/R+CBmNMe7csE8NeEX6lVMF7Dh0a1YKQa6hUN18bBuYgTMuT
QzMmZpRpIBB321ZBlcnlxiTJvWxvbCPHKHj20VwwAz7LONF59s84ZsOqfoBv8gKM
s0s5dsq5zpLeaw==
-----END CERTIFICATE-----`
func GetCloudflareRootCA() *x509.CertPool {
ca := x509.NewCertPool()
if !ca.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
// should never happen
panic("failure loading Cloudflare origin CA pem")
}
return ca
}

50
tlsconfig/hello_ca.go Normal file
View File

@ -0,0 +1,50 @@
package tlsconfig
import (
"crypto/tls"
"crypto/x509"
)
const (
helloKey = `
-----BEGIN EC PARAMETERS-----
BgUrgQQAIg==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDAdyQBXfxTDCQSOT0HugmH9pVBtIw8t5dYvm6HxGlNq6P57v5GeN02Z
dH9FRl7+VSWgBwYFK4EEACKhZANiAATqpFzTxxV7D+/oqhKCTR6BEM9elTfKaRQE
FsLufcmaTMw/9tTwgpHKao/QsLKDTNbQhbSQLkcmpCQKlSGhl+pCrqNt/oYUAhav
UIwpwGiLCqGH/R2AqWLKRPOa/Rufs/U=
-----END EC PRIVATE KEY-----`
helloCRT = `
-----BEGIN CERTIFICATE-----
MIICkDCCAhigAwIBAgIJAPtKfUjc2lwGMAkGByqGSM49BAEwgYoxCzAJBgNVBAYT
AlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQKDBBD
bG91ZGZsYXJlLCBJbmMuMT8wPQYDVQQDDDZDbG91ZGZsYXJlIEFyZ28gVHVubmVs
IFNhbXBsZSBIZWxsbyBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMjE1MjAxNjU5
WhcNMjgwMjEzMjAxNjU5WjCBijELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVRleGFz
MQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENsb3VkZmxhcmUsIEluYy4xPzA9
BgNVBAMMNkNsb3VkZmxhcmUgQXJnbyBUdW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZl
ciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49AgEGBSuBBAAiA2IABOqkXNPHFXsP7+iq
EoJNHoEQz16VN8ppFAQWwu59yZpMzD/21PCCkcpqj9CwsoNM1tCFtJAuRyakJAqV
IaGX6kKuo23+hhQCFq9QjCnAaIsKoYf9HYCpYspE85r9G5+z9aNJMEcwRQYDVR0R
BD4wPIIJbG9jYWxob3N0ggp3YXJwLWhlbGxvggt3YXJwMi1oZWxsb4cEfwAAAYcQ
AAAAAAAAAAAAAAAAAAAAATAJBgcqhkjOPQQBA2cAMGQCMHyVPufXZ6vQo6XRWRa0
dAwtfgesOdZVP2Wt+t5v8jOIQQh1IQXYk5GtyoZGSObjhQIwd1fRgAyKXaZt+1DV
ZtHTdf8pMvESfJsSd8AB1eQ6q+pAiRUYyaxcE1Mlo2YY5o+g
-----END CERTIFICATE-----`
)
func GetHelloCertificate() (tls.Certificate, error) {
return tls.X509KeyPair([]byte(helloCRT), []byte(helloKey))
}
func GetHelloCertificateX509() (*x509.Certificate, error) {
helloCertificate, err := GetHelloCertificate()
if err != nil {
return nil, err
}
return x509.ParseCertificate(helloCertificate.Certificate[0])
}

View File

@ -6,8 +6,9 @@ import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net"
log "github.com/Sirupsen/logrus"
log "github.com/sirupsen/logrus"
cli "gopkg.in/urfave/cli.v2"
)
@ -60,3 +61,43 @@ func LoadCert(certPath string) *x509.CertPool {
}
return ca
}
func LoadOriginCertsPool() *x509.CertPool {
// First, obtain the system certificate pool
certPool, systemCertPoolErr := x509.SystemCertPool()
if systemCertPoolErr != nil {
log.Warn("error obtaining the system certificates: %s", systemCertPoolErr)
certPool = x509.NewCertPool()
}
// Next, append the Cloudflare CA pool into the system pool
if !certPool.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
log.Warn("could not append the CF certificate to the system certificate pool")
if systemCertPoolErr != nil { // Obtaining both certificates failed; this is a fatal error
log.WithError(systemCertPoolErr).Fatalf("Error loading the certificate pool")
}
}
// Finally, add the Hello certificate into the pool (since it's self-signed)
helloCertificate, err := GetHelloCertificateX509()
if err != nil {
log.Warn("error obtaining the Hello server certificate")
}
certPool.AddCert(helloCertificate)
return certPool
}
func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config {
tlsConfig := CLIFlags{RootCA: "cacert"}.GetConfig(c)
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = GetCloudflareRootCA()
tlsConfig.ServerName = "cftunnel.com"
} else if len(addrs) > 0 {
// Set for development environments and for testing specific origintunneld instances
tlsConfig.ServerName, _, _ = net.SplitHostPort(addrs[0])
}
return tlsConfig
}

View File

@ -1,10 +1,9 @@
package tunnelrpc
//go:generate capnp compile -ogo -I./tunnelrpc/ tunnelrpc.capnp
import (
log "github.com/Sirupsen/logrus"
log "github.com/sirupsen/logrus"
"golang.org/x/net/context"
"golang.org/x/net/trace"
"zombiezen.com/go/capnproto2/rpc"
)
@ -24,3 +23,20 @@ func (c ConnLogger) Errorf(ctx context.Context, format string, args ...interface
func ConnLog(log *log.Entry) rpc.ConnOption {
return rpc.ConnLog(ConnLogger{log})
}
// ConnTracer wraps a trace.EventLog for a connection.
type ConnTracer struct {
Events trace.EventLog
}
func (c ConnTracer) Infof(ctx context.Context, format string, args ...interface{}) {
c.Events.Printf(format, args...)
}
func (c ConnTracer) Errorf(ctx context.Context, format string, args ...interface{}) {
c.Events.Errorf(format, args...)
}
func ConnTrace(events trace.EventLog) rpc.ConnOption {
return rpc.ConnLog(ConnTracer{events})
}

View File

@ -4,7 +4,7 @@ package tunnelrpc
import (
"bytes"
log "github.com/Sirupsen/logrus"
log "github.com/sirupsen/logrus"
"golang.org/x/net/context"
"zombiezen.com/go/capnproto2/encoding/text"
"zombiezen.com/go/capnproto2/rpc"

View File

@ -47,10 +47,11 @@ type RegistrationOptions struct {
Version string
OS string `capnp:"os"`
ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy
PoolID string `capnp:"poolId"`
PoolName string `capnp:"poolName"`
Tags []Tag
ConnectionID uint8 `capnp:"connectionId"`
ConnectionID uint8 `capnp:"connectionId"`
OriginLocalIP string `capnp:"originLocalIp"`
IsAutoupdated bool `capnp:"isAutoupdated"`
}
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {

View File

@ -28,13 +28,15 @@ struct RegistrationOptions {
# What to do with existing tunnels for the given hostname.
existingTunnelPolicy @3 :ExistingTunnelPolicy;
# If using the balancing policy, identifies the LB pool to use.
poolId @4 :Text;
poolName @4 :Text;
# Client-defined tags to associate with the tunnel
tags @5 :List(Tag);
# A unique identifier for a high-availability connection made by a single client.
connectionId @6 :UInt8;
# origin LAN IP
originLocalIp @7 :Text;
# whether Warp client has been autoupdated
isAutoupdated @8 :Bool;
}
struct Tag {

77
websocket/websocket.go Normal file
View File

@ -0,0 +1,77 @@
package websocket
import (
"bufio"
"crypto/sha1"
"crypto/tls"
"encoding/base64"
"errors"
"io"
"net"
"net/http"
"github.com/gorilla/websocket"
)
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
func IsWebSocketUpgrade(req *http.Request) bool {
return websocket.IsWebSocketUpgrade(req)
}
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing.
func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.Conn, *http.Response, error) {
req.URL.Scheme = "wss"
d := &websocket.Dialer{TLSClientConfig: tlsClientConfig}
conn, response, err := d.Dial(req.URL.String(), nil)
if err != nil {
return nil, nil, err
}
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
return conn, response, err
}
// HijackConnection takes over an HTTP connection. Caller is responsible for closing connection.
func HijackConnection(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
hj, ok := w.(http.Hijacker)
if !ok {
return nil, nil, errors.New("hijack error")
}
conn, brw, err := hj.Hijack()
if err != nil {
return nil, nil, err
}
return conn, brw, nil
}
// Stream copies copy data to & from provided io.ReadWriters.
func Stream(conn, backendConn io.ReadWriter) {
proxyDone := make(chan struct{}, 2)
go func() {
io.Copy(conn, backendConn)
proxyDone <- struct{}{}
}()
go func() {
io.Copy(backendConn, conn)
proxyDone <- struct{}{}
}()
// If one side is done, we are done.
<-proxyDone
}
// sha1Base64 sha1 and then base64 encodes str.
func sha1Base64(str string) string {
hasher := sha1.New()
io.WriteString(hasher, str)
hash := hasher.Sum(nil)
return base64.StdEncoding.EncodeToString(hash)
}
// generateAcceptKey returns the string needed for the Sec-WebSocket-Accept header.
// https://tools.ietf.org/html/rfc6455#section-1.3 describes this process in more detail.
func generateAcceptKey(req *http.Request) string {
return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
}