Release Warp Client 2018.2.1
This commit is contained in:
parent
e0ae598112
commit
3780e14f41
|
@ -2,17 +2,20 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/gorilla/websocket"
|
||||||
|
"gopkg.in/urfave/cli.v2"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
"github.com/cloudflare/cloudflare-warp/tlsconfig"
|
||||||
cli "gopkg.in/urfave/cli.v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type templateData struct {
|
type templateData struct {
|
||||||
|
@ -21,6 +24,11 @@ type templateData struct {
|
||||||
Body string
|
Body string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OriginUpTime struct {
|
||||||
|
StartTime time.Time `json:"startTime"`
|
||||||
|
UpTime string `json:"uptime"`
|
||||||
|
}
|
||||||
|
|
||||||
const defaultServerName = "the Cloudflare Warp test server"
|
const defaultServerName = "the Cloudflare Warp test server"
|
||||||
const indexTemplate = `
|
const indexTemplate = `
|
||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
|
@ -85,54 +93,95 @@ const indexTemplate = `
|
||||||
|
|
||||||
func hello(c *cli.Context) error {
|
func hello(c *cli.Context) error {
|
||||||
address := fmt.Sprintf(":%d", c.Int("port"))
|
address := fmt.Sprintf(":%d", c.Int("port"))
|
||||||
server := NewHelloWorldServer()
|
listener, err := createListener(address)
|
||||||
if hostname, err := os.Hostname(); err != nil {
|
if err != nil {
|
||||||
server.serverName = hostname
|
return err
|
||||||
}
|
}
|
||||||
err := server.ListenAndServe(address)
|
defer listener.Close()
|
||||||
return errors.Wrap(err, "Fail to start Hello World Server")
|
err = startHelloWorldServer(listener, nil)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
|
func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
|
||||||
server := NewHelloWorldServer()
|
Log.Infof("Starting Hello World server at %s", listener.Addr())
|
||||||
if hostname, err := os.Hostname(); err != nil {
|
serverName := defaultServerName
|
||||||
server.serverName = hostname
|
if hostname, err := os.Hostname(); err == nil {
|
||||||
|
serverName = hostname
|
||||||
}
|
}
|
||||||
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: server}
|
|
||||||
|
upgrader := websocket.Upgrader{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
|
||||||
go func() {
|
go func() {
|
||||||
<-shutdownC
|
<-shutdownC
|
||||||
httpServer.Close()
|
httpServer.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
http.HandleFunc("/uptime", uptimeHandler(time.Now()))
|
||||||
|
http.HandleFunc("/ws", websocketHandler(upgrader))
|
||||||
|
http.HandleFunc("/", rootHandler(serverName))
|
||||||
err := httpServer.Serve(listener)
|
err := httpServer.Serve(listener)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
type HelloWorldServer struct {
|
func createListener(address string) (net.Listener, error) {
|
||||||
responseTemplate *template.Template
|
certificate, err := tlsconfig.GetHelloCertificate()
|
||||||
serverName string
|
if err != nil {
|
||||||
}
|
return nil, err
|
||||||
|
|
||||||
func NewHelloWorldServer() *HelloWorldServer {
|
|
||||||
return &HelloWorldServer{
|
|
||||||
responseTemplate: template.Must(template.New("index").Parse(indexTemplate)),
|
|
||||||
serverName: defaultServerName,
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func findAvailablePort() (net.Listener, error) {
|
// If the port in address is empty, a port number is automatically chosen
|
||||||
// If the port in address is empty, a port number is automatically chosen.
|
listener, err := tls.Listen(
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:")
|
"tcp",
|
||||||
|
address,
|
||||||
|
&tls.Config{Certificates: []tls.Certificate{certificate}})
|
||||||
|
|
||||||
return listener, err
|
return listener, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HelloWorldServer) ListenAndServe(address string) error {
|
func uptimeHandler(startTime time.Time) http.HandlerFunc {
|
||||||
log.Infof("Starting Hello World server on %s", address)
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
err := http.ListenAndServe(address, s)
|
// Note that if autoupdate is enabled, the uptime is reset when a new client
|
||||||
return err
|
// release is available
|
||||||
|
resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()}
|
||||||
|
respJson, err := json.Marshal(resp)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
} else {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write(respJson)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc {
|
||||||
log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto)
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
mt, message, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.WriteMessage(mt, message); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func rootHandler(serverName string) http.HandlerFunc {
|
||||||
|
responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
Log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto)
|
||||||
var buffer bytes.Buffer
|
var buffer bytes.Buffer
|
||||||
var body string
|
var body string
|
||||||
rawBody, err := ioutil.ReadAll(r.Body)
|
rawBody, err := ioutil.ReadAll(r.Body)
|
||||||
|
@ -141,8 +190,8 @@ func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
} else {
|
} else {
|
||||||
body = ""
|
body = ""
|
||||||
}
|
}
|
||||||
err = s.responseTemplate.Execute(&buffer, &templateData{
|
err = responseTemplate.Execute(&buffer, &templateData{
|
||||||
ServerName: s.serverName,
|
ServerName: serverName,
|
||||||
Request: r,
|
Request: r,
|
||||||
Body: body,
|
Body: body,
|
||||||
})
|
})
|
||||||
|
@ -152,4 +201,5 @@ func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
} else {
|
} else {
|
||||||
buffer.WriteTo(w)
|
buffer.WriteTo(w)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,18 +4,30 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
const testPort = "8080"
|
func TestCreateListenerHostAndPortSuccess(t *testing.T) {
|
||||||
|
listener, err := createListener("localhost:1234")
|
||||||
func TestNewHelloWorldServer(t *testing.T) {
|
|
||||||
if NewHelloWorldServer() == nil {
|
|
||||||
t.Fatal("NewHelloWorldServer returned nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFindAvailablePort(t *testing.T) {
|
|
||||||
listener, err := findAvailablePort()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("Fail to find available port")
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if listener.Addr().String() == "" {
|
||||||
|
t.Fatal("Fail to find available port")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateListenerOnlyHostSuccess(t *testing.T) {
|
||||||
|
listener, err := createListener("localhost:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if listener.Addr().String() == "" {
|
||||||
|
t.Fatal("Fail to find available port")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateListenerOnlyPortSuccess(t *testing.T) {
|
||||||
|
listener, err := createListener(":8888")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if listener.Addr().String() == "" {
|
if listener.Addr().String() == "" {
|
||||||
t.Fatal("Fail to find available port")
|
t.Fatal("Fail to find available port")
|
||||||
|
|
|
@ -42,6 +42,8 @@ After=network.target
|
||||||
TimeoutStartSec=0
|
TimeoutStartSec=0
|
||||||
Type=notify
|
Type=notify
|
||||||
ExecStart={{ .Path }} --config /etc/cloudflare-warp/config.yml --origincert /etc/cloudflare-warp/cert.pem --no-autoupdate
|
ExecStart={{ .Path }} --config /etc/cloudflare-warp/config.yml --origincert /etc/cloudflare-warp/cert.pem --no-autoupdate
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=5s
|
||||||
|
|
||||||
[Install]
|
[Install]
|
||||||
WantedBy=multi-user.target
|
WantedBy=multi-user.target
|
||||||
|
|
|
@ -14,7 +14,6 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
|
||||||
homedir "github.com/mitchellh/go-homedir"
|
homedir "github.com/mitchellh/go-homedir"
|
||||||
cli "gopkg.in/urfave/cli.v2"
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
)
|
)
|
||||||
|
@ -137,7 +136,7 @@ func download(certURL, filePath string) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Error fetching certificate")
|
Log.WithError(err).Error("Error fetching certificate")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -180,16 +179,16 @@ func putSuccess(client *http.Client, certURL string) {
|
||||||
// indicate success to the relay server
|
// indicate success to the relay server
|
||||||
req, err := http.NewRequest("PUT", certURL+"/ok", nil)
|
req, err := http.NewRequest("PUT", certURL+"/ok", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("HTTP request error")
|
Log.WithError(err).Error("HTTP request error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("HTTP error")
|
Log.WithError(err).Error("HTTP error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
|
Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,6 +11,8 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
@ -21,11 +23,12 @@ import (
|
||||||
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||||
"github.com/cloudflare/cloudflare-warp/validation"
|
"github.com/cloudflare/cloudflare-warp/validation"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
|
||||||
"github.com/facebookgo/grace/gracenet"
|
"github.com/facebookgo/grace/gracenet"
|
||||||
raven "github.com/getsentry/raven-go"
|
"github.com/getsentry/raven-go"
|
||||||
homedir "github.com/mitchellh/go-homedir"
|
"github.com/mitchellh/go-homedir"
|
||||||
cli "gopkg.in/urfave/cli.v2"
|
"github.com/rifflock/lfshook"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"gopkg.in/urfave/cli.v2"
|
||||||
"gopkg.in/urfave/cli.v2/altsrc"
|
"gopkg.in/urfave/cli.v2/altsrc"
|
||||||
|
|
||||||
"github.com/coreos/go-systemd/daemon"
|
"github.com/coreos/go-systemd/daemon"
|
||||||
|
@ -40,11 +43,21 @@ const configFile = "config.yml"
|
||||||
var listeners = gracenet.Net{}
|
var listeners = gracenet.Net{}
|
||||||
var Version = "DEV"
|
var Version = "DEV"
|
||||||
var BuildTime = "unknown"
|
var BuildTime = "unknown"
|
||||||
|
var Log *logrus.Logger
|
||||||
|
|
||||||
// Shutdown channel used by the app. When closed, app must terminate.
|
// Shutdown channel used by the app. When closed, app must terminate.
|
||||||
// May be closed by the Windows service runner.
|
// May be closed by the Windows service runner.
|
||||||
var shutdownC chan struct{}
|
var shutdownC chan struct{}
|
||||||
|
|
||||||
|
type BuildAndRuntimeInfo struct {
|
||||||
|
GoOS string `json:"go_os"`
|
||||||
|
GoVersion string `json:"go_version"`
|
||||||
|
GoArch string `json:"go_arch"`
|
||||||
|
WarpVersion string `json:"warp_version"`
|
||||||
|
WarpFlags map[string]interface{} `json:"warp_flags"`
|
||||||
|
WarpEnvs map[string]string `json:"warp_envs"`
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
metrics.RegisterBuildInfo(BuildTime, Version)
|
metrics.RegisterBuildInfo(BuildTime, Version)
|
||||||
raven.SetDSN(sentryDSN)
|
raven.SetDSN(sentryDSN)
|
||||||
|
@ -84,6 +97,12 @@ WARNING:
|
||||||
Usage: "Disable periodic check for updates, restarting the server with the new version.",
|
Usage: "Disable periodic check for updates, restarting the server with the new version.",
|
||||||
Value: false,
|
Value: false,
|
||||||
}),
|
}),
|
||||||
|
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||||
|
Name: "is-autoupdated",
|
||||||
|
Usage: "Signal the new process that Warp client has been autoupdated",
|
||||||
|
Value: false,
|
||||||
|
Hidden: true,
|
||||||
|
}),
|
||||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
|
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
|
||||||
Name: "edge",
|
Name: "edge",
|
||||||
Usage: "Address of the Cloudflare tunnel server.",
|
Usage: "Address of the Cloudflare tunnel server.",
|
||||||
|
@ -99,12 +118,12 @@ WARNING:
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "origincert",
|
Name: "origincert",
|
||||||
Usage: "Path to the certificate generated for your origin when you run cloudflare-warp login.",
|
Usage: "Path to the certificate generated for your origin when you run cloudflare-warp login.",
|
||||||
EnvVars: []string{"ORIGIN_CERT"},
|
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
|
||||||
Value: filepath.Join(defaultConfigDir, credentialFile),
|
Value: filepath.Join(defaultConfigDir, credentialFile),
|
||||||
}),
|
}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "url",
|
Name: "url",
|
||||||
Value: "http://localhost:8080",
|
Value: "https://localhost:8080",
|
||||||
Usage: "Connect to the local webserver at `URL`.",
|
Usage: "Connect to the local webserver at `URL`.",
|
||||||
EnvVars: []string{"TUNNEL_URL"},
|
EnvVars: []string{"TUNNEL_URL"},
|
||||||
}),
|
}),
|
||||||
|
@ -191,14 +210,20 @@ WARNING:
|
||||||
}),
|
}),
|
||||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||||
Name: "hello-world",
|
Name: "hello-world",
|
||||||
Usage: "Run Hello World Server",
|
|
||||||
Value: false,
|
Value: false,
|
||||||
|
Usage: "Run Hello World Server",
|
||||||
|
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
|
||||||
}),
|
}),
|
||||||
altsrc.NewStringFlag(&cli.StringFlag{
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
Name: "pidfile",
|
Name: "pidfile",
|
||||||
Usage: "Write the application's PID to this file after first successful connection.",
|
Usage: "Write the application's PID to this file after first successful connection.",
|
||||||
EnvVars: []string{"TUNNEL_PIDFILE"},
|
EnvVars: []string{"TUNNEL_PIDFILE"},
|
||||||
}),
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "logfile",
|
||||||
|
Usage: "Save application log to this file for reporting issues.",
|
||||||
|
EnvVars: []string{"TUNNEL_LOGFILE"},
|
||||||
|
}),
|
||||||
altsrc.NewIntFlag(&cli.IntFlag{
|
altsrc.NewIntFlag(&cli.IntFlag{
|
||||||
Name: "ha-connections",
|
Name: "ha-connections",
|
||||||
Value: 4,
|
Value: 4,
|
||||||
|
@ -239,6 +264,7 @@ WARNING:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
app.Before = func(context *cli.Context) error {
|
app.Before = func(context *cli.Context) error {
|
||||||
|
Log = logrus.New()
|
||||||
inputSource, err := findInputSourceContext(context)
|
inputSource, err := findInputSourceContext(context)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -248,7 +274,7 @@ WARNING:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
app.Commands = []*cli.Command{
|
app.Commands = []*cli.Command{
|
||||||
&cli.Command{
|
{
|
||||||
Name: "update",
|
Name: "update",
|
||||||
Action: update,
|
Action: update,
|
||||||
Usage: "Update the agent if a new version exists",
|
Usage: "Update the agent if a new version exists",
|
||||||
|
@ -259,7 +285,7 @@ WARNING:
|
||||||
|
|
||||||
To determine if an update happened in a script, check for error code 64.`,
|
To determine if an update happened in a script, check for error code 64.`,
|
||||||
},
|
},
|
||||||
&cli.Command{
|
{
|
||||||
Name: "login",
|
Name: "login",
|
||||||
Action: login,
|
Action: login,
|
||||||
Usage: "Generate a configuration file with your login details",
|
Usage: "Generate a configuration file with your login details",
|
||||||
|
@ -271,7 +297,7 @@ WARNING:
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
&cli.Command{
|
{
|
||||||
Name: "hello",
|
Name: "hello",
|
||||||
Action: hello,
|
Action: hello,
|
||||||
Usage: "Run a simple \"Hello World\" server for testing Cloudflare Warp.",
|
Usage: "Run a simple \"Hello World\" server for testing Cloudflare Warp.",
|
||||||
|
@ -293,27 +319,43 @@ func startServer(c *cli.Context) {
|
||||||
errC := make(chan error)
|
errC := make(chan error)
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
|
|
||||||
if c.NumFlags() == 0 && c.NArg() == 0 {
|
// If the user choose to supply all options through env variables,
|
||||||
|
// c.NumFlags() == 0 && c.NArg() == 0. For warp to work, the user needs to at
|
||||||
|
// least provide a hostname.
|
||||||
|
if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" {
|
||||||
cli.ShowAppHelp(c)
|
cli.ShowAppHelp(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
logLevel, err := logrus.ParseLevel(c.String("loglevel"))
|
||||||
logLevel, err := log.ParseLevel(c.String("loglevel"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Fatal("Unknown logging level specified")
|
Log.WithError(err).Fatal("Unknown logging level specified")
|
||||||
}
|
}
|
||||||
|
logrus.SetLevel(logLevel)
|
||||||
|
|
||||||
log.SetLevel(logLevel)
|
protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
|
||||||
protoLogLevel, err := log.ParseLevel(c.String("proto-loglevel"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Fatal("Unknown protocol logging level specified")
|
Log.WithError(err).Fatal("Unknown protocol logging level specified")
|
||||||
}
|
}
|
||||||
protoLogger := log.New()
|
protoLogger := logrus.New()
|
||||||
protoLogger.Level = protoLogLevel
|
protoLogger.Level = protoLogLevel
|
||||||
|
|
||||||
|
if c.String("logfile") != "" {
|
||||||
|
if err := initLogFile(c, protoLogger); err != nil {
|
||||||
|
Log.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 {
|
||||||
|
if initUpdate() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
|
||||||
|
go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
|
||||||
|
}
|
||||||
|
|
||||||
hostname, err := validation.ValidateHostname(c.String("hostname"))
|
hostname, err := validation.ValidateHostname(c.String("hostname"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Fatal("Invalid hostname")
|
Log.WithError(err).Fatal("Invalid hostname")
|
||||||
|
|
||||||
}
|
}
|
||||||
clientID := c.String("id")
|
clientID := c.String("id")
|
||||||
|
@ -323,46 +365,44 @@ func startServer(c *cli.Context) {
|
||||||
|
|
||||||
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
|
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Fatal("Tag parse failure")
|
Log.WithError(err).Fatal("Tag parse failure")
|
||||||
}
|
}
|
||||||
|
|
||||||
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
|
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
|
||||||
|
|
||||||
if c.IsSet("hello-world") {
|
if c.IsSet("hello-world") {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
listener, err := findAvailablePort()
|
listener, err := createListener("127.0.0.1:")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
listener.Close()
|
listener.Close()
|
||||||
log.WithError(err).Fatal("Cannot start Hello World Server")
|
Log.WithError(err).Fatal("Cannot start Hello World Server")
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
startHelloWorldServer(listener, shutdownC)
|
startHelloWorldServer(listener, shutdownC)
|
||||||
wg.Done()
|
wg.Done()
|
||||||
listener.Close()
|
listener.Close()
|
||||||
}()
|
}()
|
||||||
c.Set("url", "http://"+listener.Addr().String())
|
c.Set("url", "https://"+listener.Addr().String())
|
||||||
log.Infof("Starting Hello World Server at %s", c.String("url"))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
url, err := validateUrl(c)
|
url, err := validateUrl(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Fatal("Error validating url")
|
Log.WithError(err).Fatal("Error validating url")
|
||||||
}
|
}
|
||||||
log.Infof("Proxying tunnel requests to %s", url)
|
Log.Infof("Proxying tunnel requests to %s", url)
|
||||||
|
|
||||||
// Fail if the user provided an old authentication method
|
// Fail if the user provided an old authentication method
|
||||||
if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
|
if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
|
||||||
log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login")
|
Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that the user has acquired a certificate using the log in command
|
// Check that the user has acquired a certificate using the log in command
|
||||||
originCertPath, err := homedir.Expand(c.String("origincert"))
|
originCertPath, err := homedir.Expand(c.String("origincert"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
|
Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
|
||||||
}
|
}
|
||||||
ok, err := fileExists(originCertPath)
|
ok, err := fileExists(originCertPath)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Fatalf(`Cannot find a valid certificate for your origin at the path:
|
Log.Fatalf(`Cannot find a valid certificate for your origin at the path:
|
||||||
|
|
||||||
%s
|
%s
|
||||||
|
|
||||||
|
@ -375,8 +415,9 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
||||||
// Easier to send the certificate as []byte via RPC than decoding it at this point
|
// Easier to send the certificate as []byte via RPC than decoding it at this point
|
||||||
originCert, err := ioutil.ReadFile(originCertPath)
|
originCert, err := ioutil.ReadFile(originCertPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
|
Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
tunnelMetrics := origin.NewTunnelMetrics()
|
tunnelMetrics := origin.NewTunnelMetrics()
|
||||||
httpTransport := &http.Transport{
|
httpTransport := &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
@ -389,13 +430,15 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
||||||
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
|
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
|
||||||
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
|
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
|
||||||
}
|
}
|
||||||
tunnelConfig := &origin.TunnelConfig{
|
tunnelConfig := &origin.TunnelConfig{
|
||||||
EdgeAddrs: c.StringSlice("edge"),
|
EdgeAddrs: c.StringSlice("edge"),
|
||||||
OriginUrl: url,
|
OriginUrl: url,
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
OriginCert: originCert,
|
OriginCert: originCert,
|
||||||
TlsConfig: &tls.Config{},
|
TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
|
||||||
|
ClientTlsConfig: httpTransport.TLSClientConfig,
|
||||||
Retries: c.Uint("retries"),
|
Retries: c.Uint("retries"),
|
||||||
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
||||||
MaxHeartbeats: c.Uint64("heartbeat-count"),
|
MaxHeartbeats: c.Uint64("heartbeat-count"),
|
||||||
|
@ -408,18 +451,11 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
||||||
Metrics: tunnelMetrics,
|
Metrics: tunnelMetrics,
|
||||||
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
|
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
|
||||||
ProtocolLogger: protoLogger,
|
ProtocolLogger: protoLogger,
|
||||||
|
Logger: Log,
|
||||||
|
IsAutoupdated: c.Bool("is-autoupdated"),
|
||||||
}
|
}
|
||||||
connectedSignal := make(chan struct{})
|
connectedSignal := make(chan struct{})
|
||||||
|
|
||||||
tunnelConfig.TlsConfig = tlsconfig.CLIFlags{RootCA: "cacert"}.GetConfig(c)
|
|
||||||
if tunnelConfig.TlsConfig.RootCAs == nil {
|
|
||||||
tunnelConfig.TlsConfig.RootCAs = GetCloudflareRootCA()
|
|
||||||
tunnelConfig.TlsConfig.ServerName = "cftunnel.com"
|
|
||||||
} else if len(tunnelConfig.EdgeAddrs) > 0 {
|
|
||||||
// Set for development environments and for testing specific origintunneld instances
|
|
||||||
tunnelConfig.TlsConfig.ServerName, _, _ = net.SplitHostPort(tunnelConfig.EdgeAddrs[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
go writePidFile(connectedSignal, c.String("pidfile"))
|
go writePidFile(connectedSignal, c.String("pidfile"))
|
||||||
go func() {
|
go func() {
|
||||||
errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
|
errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
|
||||||
|
@ -428,24 +464,21 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
||||||
|
|
||||||
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
|
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Fatal("Error opening metrics server listener")
|
Log.WithError(err).Fatal("Error opening metrics server listener")
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
errC <- metrics.ServeMetrics(metricsListener, shutdownC)
|
errC <- metrics.ServeMetrics(metricsListener, shutdownC)
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if !c.Bool("no-autoupdate") {
|
var errCode int
|
||||||
log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
|
|
||||||
go autoupdate(c.Duration("autoupdate-period"), shutdownC)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = WaitForSignal(errC, shutdownC)
|
err = WaitForSignal(errC, shutdownC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Quitting due to error")
|
Log.WithError(err).Error("Quitting due to error")
|
||||||
raven.CaptureErrorAndWait(err, nil)
|
raven.CaptureErrorAndWait(err, nil)
|
||||||
|
errCode = 1
|
||||||
} else {
|
} else {
|
||||||
log.Info("Quitting...")
|
Log.Info("Quitting...")
|
||||||
}
|
}
|
||||||
// Wait for clean exit, discarding all errors
|
// Wait for clean exit, discarding all errors
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -453,6 +486,7 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
os.Exit(errCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
|
func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
|
||||||
|
@ -477,30 +511,40 @@ func update(c *cli.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func autoupdate(frequency time.Duration, shutdownC chan struct{}) {
|
func initUpdate() bool {
|
||||||
if int64(frequency) == 0 {
|
if updateApplied() {
|
||||||
return
|
os.Args = append(os.Args, "--is-autoupdated=true")
|
||||||
|
if _, err := listeners.StartProcess(); err != nil {
|
||||||
|
Log.WithError(err).Error("Unable to restart server automatically")
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func autoupdate(freq time.Duration, shutdownC chan struct{}) {
|
||||||
for {
|
for {
|
||||||
if updateApplied() {
|
if updateApplied() {
|
||||||
|
os.Args = append(os.Args, "--is-autoupdated=true")
|
||||||
if _, err := listeners.StartProcess(); err != nil {
|
if _, err := listeners.StartProcess(); err != nil {
|
||||||
log.WithError(err).Error("Unable to restart server automatically")
|
Log.WithError(err).Error("Unable to restart server automatically")
|
||||||
}
|
}
|
||||||
close(shutdownC)
|
close(shutdownC)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
time.Sleep(frequency)
|
time.Sleep(freq)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateApplied() bool {
|
func updateApplied() bool {
|
||||||
releaseInfo := checkForUpdates()
|
releaseInfo := checkForUpdates()
|
||||||
if releaseInfo.Updated {
|
if releaseInfo.Updated {
|
||||||
log.Infof("Updated to version %s", releaseInfo.Version)
|
Log.Infof("Updated to version %s", releaseInfo.Version)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if releaseInfo.Error != nil {
|
if releaseInfo.Error != nil {
|
||||||
log.WithError(releaseInfo.Error).Error("Update check failed")
|
Log.WithError(releaseInfo.Error).Error("Update check failed")
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -555,7 +599,7 @@ func writePidFile(waitForSignal chan struct{}, pidFile string) {
|
||||||
}
|
}
|
||||||
file, err := os.Create(pidFile)
|
file, err := os.Create(pidFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
|
Log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
fmt.Fprintf(file, "%d", os.Getpid())
|
fmt.Fprintf(file, "%d", os.Getpid())
|
||||||
|
@ -573,3 +617,55 @@ func validateUrl(c *cli.Context) (string, error) {
|
||||||
validUrl, err := validation.ValidateUrl(url)
|
validUrl, err := validation.ValidateUrl(url)
|
||||||
return validUrl, err
|
return validUrl, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error {
|
||||||
|
fileMode := os.O_WRONLY|os.O_APPEND|os.O_CREATE|os.O_TRUNC
|
||||||
|
// do not truncate log file if the client has been autoupdated
|
||||||
|
if c.Bool("is-autoupdated") {
|
||||||
|
fileMode = os.O_WRONLY|os.O_APPEND|os.O_CREATE
|
||||||
|
}
|
||||||
|
f, err := os.OpenFile(c.String("logfile"), fileMode, 0664)
|
||||||
|
if err != nil {
|
||||||
|
errors.Wrap(err, fmt.Sprintf("Cannot open file %s", c.String("logfile")))
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
pathMap := lfshook.PathMap{
|
||||||
|
logrus.InfoLevel: c.String("logfile"),
|
||||||
|
logrus.ErrorLevel: c.String("logfile"),
|
||||||
|
logrus.FatalLevel: c.String("logfile"),
|
||||||
|
logrus.PanicLevel: c.String("logfile"),
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
|
||||||
|
protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
|
||||||
|
|
||||||
|
flags := make(map[string]interface{})
|
||||||
|
envs := make(map[string]string)
|
||||||
|
|
||||||
|
for _, flag := range c.LocalFlagNames() {
|
||||||
|
flags[flag] = c.Generic(flag)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find env variables for Warp
|
||||||
|
for _, env := range os.Environ() {
|
||||||
|
// All Warp env variables start with TUNNEL_
|
||||||
|
if strings.Contains(env, "TUNNEL_") {
|
||||||
|
vars := strings.Split(env, "=")
|
||||||
|
if len(vars) == 2 {
|
||||||
|
envs[vars[0]] = vars[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.Infof("Warp build and runtime configuration: %+v", BuildAndRuntimeInfo{
|
||||||
|
GoOS: runtime.GOOS,
|
||||||
|
GoVersion: runtime.Version(),
|
||||||
|
GoArch: runtime.GOARCH,
|
||||||
|
WarpVersion: Version,
|
||||||
|
WarpFlags: flags,
|
||||||
|
WarpEnvs: envs,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
cli "gopkg.in/urfave/cli.v2"
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
|
|
||||||
"golang.org/x/sys/windows/svc"
|
"golang.org/x/sys/windows/svc"
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/hpack"
|
"golang.org/x/net/http2/hpack"
|
||||||
)
|
)
|
||||||
|
|
|
@ -6,13 +6,14 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
@ -25,25 +26,20 @@ func TestMain(m *testing.M) {
|
||||||
type DefaultMuxerPair struct {
|
type DefaultMuxerPair struct {
|
||||||
OriginMuxConfig MuxerConfig
|
OriginMuxConfig MuxerConfig
|
||||||
OriginMux *Muxer
|
OriginMux *Muxer
|
||||||
OriginWriter *io.PipeWriter
|
OriginConn net.Conn
|
||||||
OriginReader *io.PipeReader
|
|
||||||
EdgeMuxConfig MuxerConfig
|
EdgeMuxConfig MuxerConfig
|
||||||
EdgeMux *Muxer
|
EdgeMux *Muxer
|
||||||
EdgeWriter *io.PipeWriter
|
EdgeConn net.Conn
|
||||||
EdgeReader *io.PipeReader
|
|
||||||
doneC chan struct{}
|
doneC chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDefaultMuxerPair() *DefaultMuxerPair {
|
func NewDefaultMuxerPair() *DefaultMuxerPair {
|
||||||
originReader, edgeWriter := io.Pipe()
|
origin, edge := net.Pipe()
|
||||||
edgeReader, originWriter := io.Pipe()
|
|
||||||
return &DefaultMuxerPair{
|
return &DefaultMuxerPair{
|
||||||
OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"},
|
OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"},
|
||||||
OriginWriter: originWriter,
|
OriginConn: origin,
|
||||||
OriginReader: originReader,
|
|
||||||
EdgeMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"},
|
EdgeMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"},
|
||||||
EdgeWriter: edgeWriter,
|
EdgeConn: edge,
|
||||||
EdgeReader: edgeReader,
|
|
||||||
doneC: make(chan struct{}),
|
doneC: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -53,12 +49,12 @@ func (p *DefaultMuxerPair) Handshake(t *testing.T) {
|
||||||
originErrC := make(chan error)
|
originErrC := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
var err error
|
var err error
|
||||||
p.EdgeMux, err = Handshake(p.EdgeWriter, p.EdgeReader, p.EdgeMuxConfig)
|
p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig)
|
||||||
edgeErrC <- err
|
edgeErrC <- err
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
var err error
|
var err error
|
||||||
p.OriginMux, err = Handshake(p.OriginWriter, p.OriginReader, p.OriginMuxConfig)
|
p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig)
|
||||||
originErrC <- err
|
originErrC <- err
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -120,8 +116,8 @@ func (p *DefaultMuxerPair) Wait(t *testing.T) {
|
||||||
func TestHandshake(t *testing.T) {
|
func TestHandshake(t *testing.T) {
|
||||||
muxPair := NewDefaultMuxerPair()
|
muxPair := NewDefaultMuxerPair()
|
||||||
muxPair.Handshake(t)
|
muxPair.Handshake(t)
|
||||||
AssertIfPipeReadable(t, muxPair.OriginReader)
|
AssertIfPipeReadable(t, muxPair.OriginConn)
|
||||||
AssertIfPipeReadable(t, muxPair.EdgeReader)
|
AssertIfPipeReadable(t, muxPair.EdgeConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSingleStream(t *testing.T) {
|
func TestSingleStream(t *testing.T) {
|
||||||
|
@ -145,7 +141,7 @@ func TestSingleStream(t *testing.T) {
|
||||||
stream.Write(buf)
|
stream.Write(buf)
|
||||||
// after this receive, the edge closed the stream
|
// after this receive, the edge closed the stream
|
||||||
<-closeC
|
<-closeC
|
||||||
n, err := stream.Read(buf)
|
n, err := io.ReadFull(stream, buf)
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
t.Fatalf("read %d bytes after EOF", n)
|
t.Fatalf("read %d bytes after EOF", n)
|
||||||
}
|
}
|
||||||
|
@ -173,7 +169,7 @@ func TestSingleStream(t *testing.T) {
|
||||||
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
||||||
}
|
}
|
||||||
responseBody := make([]byte, 11)
|
responseBody := make([]byte, 11)
|
||||||
n, err := stream.Read(responseBody)
|
n, err := io.ReadFull(stream, responseBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error from (*MuxedStream).Read: %s", err)
|
t.Fatalf("error from (*MuxedStream).Read: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -243,7 +239,7 @@ func TestSingleStreamLargeResponseBody(t *testing.T) {
|
||||||
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
||||||
}
|
}
|
||||||
responseBody := make([]byte, bodySize)
|
responseBody := make([]byte, bodySize)
|
||||||
n, err := stream.Read(responseBody)
|
n, err := io.ReadFull(stream, responseBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error from (*MuxedStream).Read: %s", err)
|
t.Fatalf("error from (*MuxedStream).Read: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -302,7 +298,7 @@ func TestMultipleStreams(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
responseBody := make([]byte, 2)
|
responseBody := make([]byte, 2)
|
||||||
n, err := stream.Read(responseBody)
|
n, err := io.ReadFull(stream, responseBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err)
|
errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err)
|
||||||
return
|
return
|
||||||
|
@ -392,7 +388,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
responseBody := make([]byte, responseSizes[(stream.streamID-2)/2])
|
responseBody := make([]byte, responseSizes[(stream.streamID-2)/2])
|
||||||
n, err := stream.Read(responseBody)
|
n, err := io.ReadFull(stream, responseBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err)
|
errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err)
|
||||||
return
|
return
|
||||||
|
@ -451,7 +447,7 @@ func TestGracefulShutdown(t *testing.T) {
|
||||||
}
|
}
|
||||||
responseBody := make([]byte, len(responseBuf))
|
responseBody := make([]byte, len(responseBuf))
|
||||||
log.Debugf("Waiting for %d bytes", len(responseBuf))
|
log.Debugf("Waiting for %d bytes", len(responseBuf))
|
||||||
n, err := stream.Read(responseBody)
|
n, err := io.ReadFull(stream, responseBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err)
|
t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err)
|
||||||
}
|
}
|
||||||
|
@ -498,13 +494,13 @@ func TestUnexpectedShutdown(t *testing.T) {
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
// Close the underlying connection before telling the origin to write.
|
// Close the underlying connection before telling the origin to write.
|
||||||
muxPair.EdgeReader.Close()
|
muxPair.EdgeConn.Close()
|
||||||
close(sendC)
|
close(sendC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("error in OpenStream: %s", err)
|
t.Fatalf("error in OpenStream: %s", err)
|
||||||
}
|
}
|
||||||
responseBody := make([]byte, len(responseBuf))
|
responseBody := make([]byte, len(responseBuf))
|
||||||
n, err := stream.Read(responseBody)
|
n, err := io.ReadFull(stream, responseBody)
|
||||||
if err != io.EOF {
|
if err != io.EOF {
|
||||||
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
|
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -545,14 +541,14 @@ func TestOpenAfterDisconnect(t *testing.T) {
|
||||||
switch i {
|
switch i {
|
||||||
case 0:
|
case 0:
|
||||||
// Close both directions of the connection to cause EOF on both peers.
|
// Close both directions of the connection to cause EOF on both peers.
|
||||||
muxPair.OriginReader.Close()
|
muxPair.OriginConn.Close()
|
||||||
muxPair.OriginWriter.Close()
|
muxPair.EdgeConn.Close()
|
||||||
case 1:
|
case 1:
|
||||||
// Close origin reader (edge writer) to cause EOF on origin only.
|
// Close origin conn to cause EOF on origin first.
|
||||||
muxPair.OriginReader.Close()
|
muxPair.OriginConn.Close()
|
||||||
case 2:
|
case 2:
|
||||||
// Close origin writer (edge reader) to cause EOF on edge only.
|
// Close edge conn to cause EOF on edge first.
|
||||||
muxPair.OriginWriter.Close()
|
muxPair.EdgeConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := muxPair.EdgeMux.OpenStream(
|
_, err := muxPair.EdgeMux.OpenStream(
|
||||||
|
@ -623,7 +619,7 @@ func TestHPACK(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) {
|
func AssertIfPipeReadable(t *testing.T, pipe io.ReadCloser) {
|
||||||
errC := make(chan error)
|
errC := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
b := []byte{0}
|
b := []byte{0}
|
||||||
|
@ -640,7 +636,5 @@ func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) {
|
||||||
}
|
}
|
||||||
case <-time.After(100 * time.Millisecond):
|
case <-time.After(100 * time.Millisecond):
|
||||||
// nothing to read
|
// nothing to read
|
||||||
pipe.Close()
|
|
||||||
<-errC
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package h2mux
|
package h2mux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -63,3 +64,27 @@ func TestFlowControlSingleStream(t *testing.T) {
|
||||||
assert.Equal(t, testWindowSize<<2, stream.receiveWindow)
|
assert.Equal(t, testWindowSize<<2, stream.receiveWindow)
|
||||||
assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax)
|
assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMuxedStreamEOF(t *testing.T) {
|
||||||
|
for i := 0; i < 4096; i++ {
|
||||||
|
readyList := NewReadyList()
|
||||||
|
stream := &MuxedStream{
|
||||||
|
streamID: 1,
|
||||||
|
readBuffer: NewSharedBuffer(),
|
||||||
|
receiveWindow: 65536,
|
||||||
|
receiveWindowMax: 65536,
|
||||||
|
sendWindow: 65536,
|
||||||
|
readyList: readyList,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() { stream.Close() }()
|
||||||
|
n, err := stream.Read([]byte{0})
|
||||||
|
assert.Equal(t, io.EOF, err)
|
||||||
|
assert.Equal(t, 0, n)
|
||||||
|
// Write comes after read, because write buffers data before it is flushed. It wouldn't know about EOF
|
||||||
|
// until some time later. Calling read first forces it to know about EOF now.
|
||||||
|
n, err = stream.Write([]byte{1})
|
||||||
|
assert.Equal(t, io.EOF, err)
|
||||||
|
assert.Equal(t, 0, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/hpack"
|
"golang.org/x/net/http2/hpack"
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,7 +21,7 @@ func NewSharedBuffer() *SharedBuffer {
|
||||||
func (s *SharedBuffer) Read(p []byte) (n int, err error) {
|
func (s *SharedBuffer) Read(p []byte) (n int, err error) {
|
||||||
totalRead := 0
|
totalRead := 0
|
||||||
s.cond.L.Lock()
|
s.cond.L.Lock()
|
||||||
for totalRead < len(p) {
|
for totalRead == 0 {
|
||||||
n, err = s.buffer.Read(p[totalRead:])
|
n, err = s.buffer.Read(p[totalRead:])
|
||||||
totalRead += n
|
totalRead += n
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
|
@ -29,6 +29,9 @@ func (s *SharedBuffer) Read(p []byte) (n int, err error) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
err = nil
|
err = nil
|
||||||
|
if n > 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
s.cond.Wait()
|
s.cond.Wait()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) {
|
func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) {
|
||||||
|
@ -29,30 +31,35 @@ func TestSharedBuffer(t *testing.T) {
|
||||||
|
|
||||||
func TestSharedBufferBlockingRead(t *testing.T) {
|
func TestSharedBufferBlockingRead(t *testing.T) {
|
||||||
b := NewSharedBuffer()
|
b := NewSharedBuffer()
|
||||||
testData := []byte("Hello world")
|
testData1 := []byte("Hello")
|
||||||
|
testData2 := []byte(" world")
|
||||||
result := make(chan []byte)
|
result := make(chan []byte)
|
||||||
go func() {
|
go func() {
|
||||||
bytesRead := make([]byte, len(testData))
|
bytesRead := make([]byte, len(testData1)+len(testData2))
|
||||||
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
|
nRead, err := b.Read(bytesRead)
|
||||||
result <- bytesRead
|
AssertIOReturnIsGood(t, len(testData1))(nRead, err)
|
||||||
|
result <- bytesRead[:nRead]
|
||||||
|
nRead, err = b.Read(bytesRead)
|
||||||
|
AssertIOReturnIsGood(t, len(testData2))(nRead, err)
|
||||||
|
result <- bytesRead[:nRead]
|
||||||
}()
|
}()
|
||||||
|
time.Sleep(time.Millisecond * 250)
|
||||||
select {
|
select {
|
||||||
case <-result:
|
case <-result:
|
||||||
t.Fatalf("read returned early")
|
t.Fatalf("read returned early")
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
AssertIOReturnIsGood(t, 5)(b.Write(testData[:5]))
|
AssertIOReturnIsGood(t, len(testData1))(b.Write([]byte(testData1)))
|
||||||
select {
|
|
||||||
case <-result:
|
|
||||||
t.Fatalf("read returned early")
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
AssertIOReturnIsGood(t, len(testData)-5)(b.Write(testData[5:]))
|
|
||||||
select {
|
select {
|
||||||
case r := <-result:
|
case r := <-result:
|
||||||
if string(r) != string(testData) {
|
assert.Equal(t, testData1, r)
|
||||||
t.Fatalf("expected read to return %s, got %s", testData, r)
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("read timed out")
|
||||||
}
|
}
|
||||||
|
AssertIOReturnIsGood(t, len(testData2))(b.Write([]byte(testData2)))
|
||||||
|
select {
|
||||||
|
case r := <-result:
|
||||||
|
assert.Equal(t, testData2, r)
|
||||||
case <-time.After(time.Second):
|
case <-time.After(time.Second):
|
||||||
t.Fatalf("read timed out")
|
t.Fatalf("read timed out")
|
||||||
}
|
}
|
||||||
|
@ -85,7 +92,7 @@ func TestSharedBufferConcurrentReadWrite(t *testing.T) {
|
||||||
// Change block sizes in opposition to the write thread, to test blocking for new data.
|
// Change block sizes in opposition to the write thread, to test blocking for new data.
|
||||||
for blockSize := 256; blockSize > 0; blockSize-- {
|
for blockSize := 256; blockSize > 0; blockSize-- {
|
||||||
for i := 0; i < 256; i++ {
|
for i := 0; i < 256; i++ {
|
||||||
n, err := b.Read(block[:blockSize])
|
n, err := io.ReadFull(b, block[:blockSize])
|
||||||
if n != blockSize || err != nil {
|
if n != blockSize || err != nil {
|
||||||
t.Fatalf("read error: %d %s", n, err)
|
t.Fatalf("read error: %d %s", n, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,9 +11,9 @@ import (
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"golang.org/x/net/trace"
|
"golang.org/x/net/trace"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflare-warp/h2mux"
|
"github.com/cloudflare/cloudflare-warp/h2mux"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -249,7 +248,7 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
|
||||||
if _, ok := t.concurrentRequests[connectionID]; ok {
|
if _, ok := t.concurrentRequests[connectionID]; ok {
|
||||||
t.concurrentRequests[connectionID] -= 1
|
t.concurrentRequests[connectionID] -= 1
|
||||||
} else {
|
} else {
|
||||||
log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.")
|
Log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.")
|
||||||
}
|
}
|
||||||
t.concurrentRequestsLock.Unlock()
|
t.concurrentRequestsLock.Unlock()
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -73,7 +72,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
|
||||||
case tunnelError := <-s.tunnelErrors:
|
case tunnelError := <-s.tunnelErrors:
|
||||||
tunnelsActive--
|
tunnelsActive--
|
||||||
if tunnelError.err != nil {
|
if tunnelError.err != nil {
|
||||||
log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
|
Log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
|
||||||
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
||||||
s.waitForNextTunnel(tunnelError.index)
|
s.waitForNextTunnel(tunnelError.index)
|
||||||
if backoffTimer == nil {
|
if backoffTimer == nil {
|
||||||
|
@ -107,10 +106,10 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
|
||||||
s.lastResolve = time.Now()
|
s.lastResolve = time.Now()
|
||||||
s.resolverC = nil
|
s.resolverC = nil
|
||||||
if result.err == nil {
|
if result.err == nil {
|
||||||
log.Debug("Service discovery refresh complete")
|
Log.Debug("Service discovery refresh complete")
|
||||||
s.edgeIPs = result.edgeIPs
|
s.edgeIPs = result.edgeIPs
|
||||||
} else {
|
} else {
|
||||||
log.WithError(result.err).Error("Service discovery error")
|
Log.WithError(result.err).Error("Service discovery error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -120,12 +119,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
|
||||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
|
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
|
||||||
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
|
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Infof("ResolveEdgeIPs err")
|
Log.Infof("ResolveEdgeIPs err")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.edgeIPs = edgeIPs
|
s.edgeIPs = edgeIPs
|
||||||
if s.config.HAConnections > len(edgeIPs) {
|
if s.config.HAConnections > len(edgeIPs) {
|
||||||
log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
|
Log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
|
||||||
s.config.HAConnections = len(edgeIPs)
|
s.config.HAConnections = len(edgeIPs)
|
||||||
}
|
}
|
||||||
s.lastResolve = time.Now()
|
s.lastResolve = time.Now()
|
||||||
|
|
|
@ -19,14 +19,17 @@ import (
|
||||||
"github.com/cloudflare/cloudflare-warp/tunnelrpc"
|
"github.com/cloudflare/cloudflare-warp/tunnelrpc"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||||
"github.com/cloudflare/cloudflare-warp/validation"
|
"github.com/cloudflare/cloudflare-warp/validation"
|
||||||
|
"github.com/cloudflare/cloudflare-warp/websocket"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
|
||||||
raven "github.com/getsentry/raven-go"
|
raven "github.com/getsentry/raven-go"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
_ "github.com/prometheus/client_golang/prometheus"
|
_ "github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
rpc "zombiezen.com/go/capnproto2/rpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var Log *logrus.Logger
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dialTimeout = 15 * time.Second
|
dialTimeout = 15 * time.Second
|
||||||
|
|
||||||
|
@ -40,6 +43,7 @@ type TunnelConfig struct {
|
||||||
Hostname string
|
Hostname string
|
||||||
OriginCert []byte
|
OriginCert []byte
|
||||||
TlsConfig *tls.Config
|
TlsConfig *tls.Config
|
||||||
|
ClientTlsConfig *tls.Config
|
||||||
Retries uint
|
Retries uint
|
||||||
HeartbeatInterval time.Duration
|
HeartbeatInterval time.Duration
|
||||||
MaxHeartbeats uint64
|
MaxHeartbeats uint64
|
||||||
|
@ -51,7 +55,9 @@ type TunnelConfig struct {
|
||||||
HTTPTransport http.RoundTripper
|
HTTPTransport http.RoundTripper
|
||||||
Metrics *TunnelMetrics
|
Metrics *TunnelMetrics
|
||||||
MetricsUpdateFreq time.Duration
|
MetricsUpdateFreq time.Duration
|
||||||
ProtocolLogger *log.Logger
|
ProtocolLogger *logrus.Logger
|
||||||
|
Logger *logrus.Logger
|
||||||
|
IsAutoupdated bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type dialError struct {
|
type dialError struct {
|
||||||
|
@ -87,14 +93,16 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
|
||||||
Version: c.ReportedVersion,
|
Version: c.ReportedVersion,
|
||||||
OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH),
|
OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH),
|
||||||
ExistingTunnelPolicy: policy,
|
ExistingTunnelPolicy: policy,
|
||||||
PoolID: c.LBPool,
|
PoolName: c.LBPool,
|
||||||
Tags: c.Tags,
|
Tags: c.Tags,
|
||||||
ConnectionID: connectionID,
|
ConnectionID: connectionID,
|
||||||
OriginLocalIP: OriginLocalIP,
|
OriginLocalIP: OriginLocalIP,
|
||||||
|
IsAutoupdated: c.IsAutoupdated,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
|
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
|
||||||
|
Log = config.Logger
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
go func() {
|
go func() {
|
||||||
<-shutdownC
|
<-shutdownC
|
||||||
|
@ -129,7 +137,7 @@ func ServeTunnelLoop(ctx context.Context, config *TunnelConfig, addr *net.TCPAdd
|
||||||
err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff)
|
err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff)
|
||||||
if recoverable {
|
if recoverable {
|
||||||
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||||
log.Infof("Retrying in %s seconds", duration)
|
Log.Infof("Retrying in %s seconds", duration)
|
||||||
backoff.Backoff(ctx)
|
backoff.Backoff(ctx)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -162,11 +170,10 @@ func ServeTunnel(
|
||||||
// Returns error from parsing the origin URL or handshake errors
|
// Returns error from parsing the origin URL or handshake errors
|
||||||
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
|
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errLog := log.WithError(err)
|
errLog := Log.WithError(err)
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
case dialError:
|
case dialError:
|
||||||
errLog.Error("Unable to dial edge")
|
errLog.Error("Unable to dial edge")
|
||||||
return err, false
|
|
||||||
case h2mux.MuxerHandshakeError:
|
case h2mux.MuxerHandshakeError:
|
||||||
errLog.Error("Handshake failed with edge server")
|
errLog.Error("Handshake failed with edge server")
|
||||||
default:
|
default:
|
||||||
|
@ -207,24 +214,21 @@ func ServeTunnel(
|
||||||
registerErr := <-registerErrC
|
registerErr := <-registerErrC
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("Tunnel error")
|
Log.WithError(err).Error("Tunnel error")
|
||||||
return err, true
|
return err, true
|
||||||
}
|
}
|
||||||
if registerErr != nil {
|
if registerErr != nil {
|
||||||
// Don't retry on errors like entitlement failure or version too old
|
// Don't retry on errors like entitlement failure or version too old
|
||||||
if e, ok := registerErr.(printableRegisterTunnelError); ok {
|
if e, ok := registerErr.(printableRegisterTunnelError); ok {
|
||||||
log.Error(e)
|
Log.Error(e)
|
||||||
if e.permanent {
|
return e.cause, !e.permanent
|
||||||
return e, false
|
|
||||||
}
|
|
||||||
return e.cause, true
|
|
||||||
} else if e, ok := registerErr.(dupConnRegisterTunnelError); ok {
|
} else if e, ok := registerErr.(dupConnRegisterTunnelError); ok {
|
||||||
log.Info("Already connected to this server, selecting a different one")
|
Log.Info("Already connected to this server, selecting a different one")
|
||||||
return e, true
|
return e, true
|
||||||
}
|
}
|
||||||
// Only log errors to Sentry that may have been caused by the client side, to reduce dupes
|
// Only log errors to Sentry that may have been caused by the client side, to reduce dupes
|
||||||
raven.CaptureError(registerErr, nil)
|
raven.CaptureError(registerErr, nil)
|
||||||
log.Error("Cannot register")
|
Log.Error("Cannot register")
|
||||||
return err, true
|
return err, true
|
||||||
}
|
}
|
||||||
return nil, false
|
return nil, false
|
||||||
|
@ -241,7 +245,7 @@ func IsRPCStreamResponse(headers []h2mux.Header) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error {
|
func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error {
|
||||||
logger := log.WithField("subsystem", "rpc")
|
logger := Log.WithField("subsystem", "rpc")
|
||||||
logger.Debug("initiating RPC stream")
|
logger.Debug("initiating RPC stream")
|
||||||
stream, err := muxer.OpenStream([]h2mux.Header{
|
stream, err := muxer.OpenStream([]h2mux.Header{
|
||||||
{Name: ":method", Value: "RPC"},
|
{Name: ":method", Value: "RPC"},
|
||||||
|
@ -292,11 +296,11 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("Registered at %s", registration.Url)
|
Log.Infof("Registered at %s", registration.Url)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func LogServerInfo(logger *log.Entry,
|
func LogServerInfo(logger *logrus.Entry,
|
||||||
promise tunnelrpc.ServerInfo_Promise,
|
promise tunnelrpc.ServerInfo_Promise,
|
||||||
connectionID uint8,
|
connectionID uint8,
|
||||||
metrics *TunnelMetrics,
|
metrics *TunnelMetrics,
|
||||||
|
@ -311,7 +315,7 @@ func LogServerInfo(logger *log.Entry,
|
||||||
logger.WithError(err).Warn("Failed to retrieve server information")
|
logger.WithError(err).Warn("Failed to retrieve server information")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Infof("Connected to %s", serverInfo.LocationName)
|
Log.Infof("Connected to %s", serverInfo.LocationName)
|
||||||
metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName)
|
metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -356,6 +360,7 @@ type TunnelHandler struct {
|
||||||
originUrl string
|
originUrl string
|
||||||
muxer *h2mux.Muxer
|
muxer *h2mux.Muxer
|
||||||
httpClient http.RoundTripper
|
httpClient http.RoundTripper
|
||||||
|
tlsConfig *tls.Config
|
||||||
tags []tunnelpogs.Tag
|
tags []tunnelpogs.Tag
|
||||||
metrics *TunnelMetrics
|
metrics *TunnelMetrics
|
||||||
// connectionID is only used by metrics, and prometheus requires labels to be string
|
// connectionID is only used by metrics, and prometheus requires labels to be string
|
||||||
|
@ -373,6 +378,7 @@ func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, co
|
||||||
h := &TunnelHandler{
|
h := &TunnelHandler{
|
||||||
originUrl: url,
|
originUrl: url,
|
||||||
httpClient: config.HTTPTransport,
|
httpClient: config.HTTPTransport,
|
||||||
|
tlsConfig: config.ClientTlsConfig,
|
||||||
tags: config.Tags,
|
tags: config.Tags,
|
||||||
metrics: config.Metrics,
|
metrics: config.Metrics,
|
||||||
connectionID: uint8ToString(connectionID),
|
connectionID: uint8ToString(connectionID),
|
||||||
|
@ -422,29 +428,45 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||||
h.metrics.incrementRequests(h.connectionID)
|
h.metrics.incrementRequests(h.connectionID)
|
||||||
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Panic("Unexpected error from http.NewRequest")
|
Log.WithError(err).Panic("Unexpected error from http.NewRequest")
|
||||||
}
|
}
|
||||||
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("invalid request received")
|
Log.WithError(err).Error("invalid request received")
|
||||||
}
|
}
|
||||||
h.AppendTagHeaders(req)
|
h.AppendTagHeaders(req)
|
||||||
|
|
||||||
|
if websocket.IsWebSocketUpgrade(req) {
|
||||||
|
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
h.logError(stream, err)
|
||||||
|
} else {
|
||||||
|
stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||||
|
defer conn.Close()
|
||||||
|
websocket.Stream(conn.UnderlyingConn(), stream)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
response, err := h.httpClient.RoundTrip(req)
|
response, err := h.httpClient.RoundTrip(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("HTTP request error")
|
h.logError(stream, err)
|
||||||
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
|
|
||||||
stream.Write([]byte("502 Bad Gateway"))
|
|
||||||
h.metrics.incrementResponses(h.connectionID, "502")
|
|
||||||
} else {
|
} else {
|
||||||
defer response.Body.Close()
|
defer response.Body.Close()
|
||||||
stream.WriteHeaders(H1ResponseToH2Response(response))
|
stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||||
io.Copy(stream, response.Body)
|
io.Copy(stream, response.Body)
|
||||||
h.metrics.incrementResponses(h.connectionID, "200")
|
h.metrics.incrementResponses(h.connectionID, "200")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
h.metrics.decrementConcurrentRequests(h.connectionID)
|
h.metrics.decrementConcurrentRequests(h.connectionID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
|
||||||
|
Log.WithError(err).Error("HTTP request error")
|
||||||
|
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
|
||||||
|
stream.Write([]byte("502 Bad Gateway"))
|
||||||
|
h.metrics.incrementResponses(h.connectionID, "502")
|
||||||
|
}
|
||||||
|
|
||||||
func (h *TunnelHandler) UpdateMetrics() {
|
func (h *TunnelHandler) UpdateMetrics() {
|
||||||
flowCtlMetrics := h.muxer.FlowControlMetrics()
|
flowCtlMetrics := h.muxer.FlowControlMetrics()
|
||||||
h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics)
|
h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics)
|
||||||
|
|
|
@ -0,0 +1,95 @@
|
||||||
|
package tlsconfig
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: remove the Origin CA root certs when migrated to Authenticated Origin Pull certs
|
||||||
|
const cloudflareRootCA = `
|
||||||
|
Issuer: C=US, ST=California, L=San Francisco, O=CloudFlare, Inc., OU=CloudFlare Origin SSL ECC Certificate Authority
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIICiDCCAi6gAwIBAgIUXZP3MWb8MKwBE1Qbawsp1sfA/Y4wCgYIKoZIzj0EAwIw
|
||||||
|
gY8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1T
|
||||||
|
YW4gRnJhbmNpc2NvMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTgwNgYDVQQL
|
||||||
|
Ey9DbG91ZEZsYXJlIE9yaWdpbiBTU0wgRUNDIENlcnRpZmljYXRlIEF1dGhvcml0
|
||||||
|
eTAeFw0xNjAyMjIxODI0MDBaFw0yMTAyMjIwMDI0MDBaMIGPMQswCQYDVQQGEwJV
|
||||||
|
UzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEZ
|
||||||
|
MBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE4MDYGA1UECxMvQ2xvdWRGbGFyZSBP
|
||||||
|
cmlnaW4gU1NMIEVDQyBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkwWTATBgcqhkjOPQIB
|
||||||
|
BggqhkjOPQMBBwNCAASR+sGALuaGshnUbcxKry+0LEXZ4NY6JUAtSeA6g87K3jaA
|
||||||
|
xpIg9G50PokpfWkhbarLfpcZu0UAoYy2su0EhN7wo2YwZDAOBgNVHQ8BAf8EBAMC
|
||||||
|
AQYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUhTBdOypw1O3VkmcH/es5
|
||||||
|
tBoOOKcwHwYDVR0jBBgwFoAUhTBdOypw1O3VkmcH/es5tBoOOKcwCgYIKoZIzj0E
|
||||||
|
AwIDSAAwRQIgEiIEHQr5UKma50D1WRMJBUSgjg24U8n8E2mfw/8UPz0CIQCr5V/e
|
||||||
|
mcifak4CQsr+DH4pn5SJD7JxtCG3YGswW8QZsw==
|
||||||
|
-----END CERTIFICATE-----
|
||||||
|
Issuer: C=US, O=CloudFlare, Inc., OU=CloudFlare Origin SSL Certificate Authority, L=San Francisco, ST=California
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG
|
||||||
|
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG
|
||||||
|
bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN
|
||||||
|
U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4
|
||||||
|
NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv
|
||||||
|
dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl
|
||||||
|
cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG
|
||||||
|
A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
|
||||||
|
AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem
|
||||||
|
ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ
|
||||||
|
p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz
|
||||||
|
/9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl
|
||||||
|
yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK
|
||||||
|
xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud
|
||||||
|
EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G
|
||||||
|
A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA
|
||||||
|
cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD
|
||||||
|
gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh
|
||||||
|
QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz
|
||||||
|
ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y
|
||||||
|
4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX
|
||||||
|
Fu6q54beR89jDc+oABmOgg==
|
||||||
|
-----END CERTIFICATE-----
|
||||||
|
Issuer: C=US, O=CloudFlare, Inc., OU=Origin Pull, L=San Francisco, ST=California, CN=origin-pull.cloudflare.net
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIIGBjCCA/CgAwIBAgIIV5G6lVbCLmEwCwYJKoZIhvcNAQENMIGQMQswCQYDVQQG
|
||||||
|
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjEUMBIGA1UECxMLT3JpZ2lu
|
||||||
|
IFB1bGwxFjAUBgNVBAcTDVNhbiBGcmFuY2lzY28xEzARBgNVBAgTCkNhbGlmb3Ju
|
||||||
|
aWExIzAhBgNVBAMTGm9yaWdpbi1wdWxsLmNsb3VkZmxhcmUubmV0MB4XDTE1MDEx
|
||||||
|
MzAyNDc1M1oXDTIwMDExMjAyNTI1M1owgZAxCzAJBgNVBAYTAlVTMRkwFwYDVQQK
|
||||||
|
ExBDbG91ZEZsYXJlLCBJbmMuMRQwEgYDVQQLEwtPcmlnaW4gUHVsbDEWMBQGA1UE
|
||||||
|
BxMNU2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTEjMCEGA1UEAxMa
|
||||||
|
b3JpZ2luLXB1bGwuY2xvdWRmbGFyZS5uZXQwggIiMA0GCSqGSIb3DQEBAQUAA4IC
|
||||||
|
DwAwggIKAoICAQDdsts6I2H5dGyn4adACQRXlfo0KmwsN7B5rxD8C5qgy6spyONr
|
||||||
|
WV0ecvdeGQfWa8Gy/yuTuOnsXfy7oyZ1dm93c3Mea7YkM7KNMc5Y6m520E9tHooc
|
||||||
|
f1qxeDpGSsnWc7HWibFgD7qZQx+T+yfNqt63vPI0HYBOYao6hWd3JQhu5caAcIS2
|
||||||
|
ms5tzSSZVH83ZPe6Lkb5xRgLl3eXEFcfI2DjnlOtLFqpjHuEB3Tr6agfdWyaGEEi
|
||||||
|
lRY1IB3k6TfLTaSiX2/SyJ96bp92wvTSjR7USjDV9ypf7AD6u6vwJZ3bwNisNw5L
|
||||||
|
ptph0FBnc1R6nDoHmvQRoyytoe0rl/d801i9Nru/fXa+l5K2nf1koR3IX440Z2i9
|
||||||
|
+Z4iVA69NmCbT4MVjm7K3zlOtwfI7i1KYVv+ATo4ycgBuZfY9f/2lBhIv7BHuZal
|
||||||
|
b9D+/EK8aMUfjDF4icEGm+RQfExv2nOpkR4BfQppF/dLmkYfjgtO1403X0ihkT6T
|
||||||
|
PYQdmYS6Jf53/KpqC3aA+R7zg2birtvprinlR14MNvwOsDOzsK4p8WYsgZOR4Qr2
|
||||||
|
gAx+z2aVOs/87+TVOR0r14irQsxbg7uP2X4t+EXx13glHxwG+CnzUVycDLMVGvuG
|
||||||
|
aUgF9hukZxlOZnrl6VOf1fg0Caf3uvV8smOkVw6DMsGhBZSJVwao0UQNqQIDAQAB
|
||||||
|
o2YwZDAOBgNVHQ8BAf8EBAMCAAYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4E
|
||||||
|
FgQUQ1lLK2mLgOERM2pXzVc42p59xeswHwYDVR0jBBgwFoAUQ1lLK2mLgOERM2pX
|
||||||
|
zVc42p59xeswCwYJKoZIhvcNAQENA4ICAQDKDQM1qPRVP/4Gltz0D6OU6xezFBKr
|
||||||
|
LWtDoA1qW2F7pkiYawCP9MrDPDJsHy7dx+xw3bBZxOsK5PA/T7p1dqpEl6i8F692
|
||||||
|
g//EuYOifLYw3ySPe3LRNhvPl/1f6Sn862VhPvLa8aQAAwR9e/CZvlY3fj+6G5ik
|
||||||
|
3it7fikmKUsVnugNOkjmwI3hZqXfJNc7AtHDFw0mEOV0dSeAPTo95N9cxBbm9PKv
|
||||||
|
qAEmTEXp2trQ/RjJ/AomJyfA1BQjsD0j++DI3a9/BbDwWmr1lJciKxiNKaa0BRLB
|
||||||
|
dKMrYQD+PkPNCgEuojT+paLKRrMyFUzHSG1doYm46NE9/WARTh3sFUp1B7HZSBqA
|
||||||
|
kHleoB/vQ/mDuW9C3/8Jk2uRUdZxR+LoNZItuOjU8oTy6zpN1+GgSj7bHjiy9rfA
|
||||||
|
F+ehdrz+IOh80WIiqs763PGoaYUyzxLvVowLWNoxVVoc9G+PqFKqD988XlipHVB6
|
||||||
|
Bz+1CD4D/bWrs3cC9+kk/jFmrrAymZlkFX8tDb5aXASSLJjUjcptci9SKqtI2h0J
|
||||||
|
wUGkD7+bQAr+7vr8/R+CBmNMe7csE8NeEX6lVMF7Dh0a1YKQa6hUN18bBuYgTMuT
|
||||||
|
QzMmZpRpIBB321ZBlcnlxiTJvWxvbCPHKHj20VwwAz7LONF59s84ZsOqfoBv8gKM
|
||||||
|
s0s5dsq5zpLeaw==
|
||||||
|
-----END CERTIFICATE-----`
|
||||||
|
|
||||||
|
func GetCloudflareRootCA() *x509.CertPool {
|
||||||
|
ca := x509.NewCertPool()
|
||||||
|
if !ca.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
|
||||||
|
// should never happen
|
||||||
|
panic("failure loading Cloudflare origin CA pem")
|
||||||
|
}
|
||||||
|
return ca
|
||||||
|
}
|
|
@ -0,0 +1,50 @@
|
||||||
|
package tlsconfig
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
helloKey = `
|
||||||
|
-----BEGIN EC PARAMETERS-----
|
||||||
|
BgUrgQQAIg==
|
||||||
|
-----END EC PARAMETERS-----
|
||||||
|
-----BEGIN EC PRIVATE KEY-----
|
||||||
|
MIGkAgEBBDAdyQBXfxTDCQSOT0HugmH9pVBtIw8t5dYvm6HxGlNq6P57v5GeN02Z
|
||||||
|
dH9FRl7+VSWgBwYFK4EEACKhZANiAATqpFzTxxV7D+/oqhKCTR6BEM9elTfKaRQE
|
||||||
|
FsLufcmaTMw/9tTwgpHKao/QsLKDTNbQhbSQLkcmpCQKlSGhl+pCrqNt/oYUAhav
|
||||||
|
UIwpwGiLCqGH/R2AqWLKRPOa/Rufs/U=
|
||||||
|
-----END EC PRIVATE KEY-----`
|
||||||
|
|
||||||
|
helloCRT = `
|
||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIICkDCCAhigAwIBAgIJAPtKfUjc2lwGMAkGByqGSM49BAEwgYoxCzAJBgNVBAYT
|
||||||
|
AlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQKDBBD
|
||||||
|
bG91ZGZsYXJlLCBJbmMuMT8wPQYDVQQDDDZDbG91ZGZsYXJlIEFyZ28gVHVubmVs
|
||||||
|
IFNhbXBsZSBIZWxsbyBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMjE1MjAxNjU5
|
||||||
|
WhcNMjgwMjEzMjAxNjU5WjCBijELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVRleGFz
|
||||||
|
MQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENsb3VkZmxhcmUsIEluYy4xPzA9
|
||||||
|
BgNVBAMMNkNsb3VkZmxhcmUgQXJnbyBUdW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZl
|
||||||
|
ciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49AgEGBSuBBAAiA2IABOqkXNPHFXsP7+iq
|
||||||
|
EoJNHoEQz16VN8ppFAQWwu59yZpMzD/21PCCkcpqj9CwsoNM1tCFtJAuRyakJAqV
|
||||||
|
IaGX6kKuo23+hhQCFq9QjCnAaIsKoYf9HYCpYspE85r9G5+z9aNJMEcwRQYDVR0R
|
||||||
|
BD4wPIIJbG9jYWxob3N0ggp3YXJwLWhlbGxvggt3YXJwMi1oZWxsb4cEfwAAAYcQ
|
||||||
|
AAAAAAAAAAAAAAAAAAAAATAJBgcqhkjOPQQBA2cAMGQCMHyVPufXZ6vQo6XRWRa0
|
||||||
|
dAwtfgesOdZVP2Wt+t5v8jOIQQh1IQXYk5GtyoZGSObjhQIwd1fRgAyKXaZt+1DV
|
||||||
|
ZtHTdf8pMvESfJsSd8AB1eQ6q+pAiRUYyaxcE1Mlo2YY5o+g
|
||||||
|
-----END CERTIFICATE-----`
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetHelloCertificate() (tls.Certificate, error) {
|
||||||
|
return tls.X509KeyPair([]byte(helloCRT), []byte(helloKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetHelloCertificateX509() (*x509.Certificate, error) {
|
||||||
|
helloCertificate, err := GetHelloCertificate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return x509.ParseCertificate(helloCertificate.Certificate[0])
|
||||||
|
}
|
|
@ -6,8 +6,9 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
cli "gopkg.in/urfave/cli.v2"
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -60,3 +61,43 @@ func LoadCert(certPath string) *x509.CertPool {
|
||||||
}
|
}
|
||||||
return ca
|
return ca
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LoadOriginCertsPool() *x509.CertPool {
|
||||||
|
// First, obtain the system certificate pool
|
||||||
|
certPool, systemCertPoolErr := x509.SystemCertPool()
|
||||||
|
if systemCertPoolErr != nil {
|
||||||
|
log.Warn("error obtaining the system certificates: %s", systemCertPoolErr)
|
||||||
|
certPool = x509.NewCertPool()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, append the Cloudflare CA pool into the system pool
|
||||||
|
if !certPool.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
|
||||||
|
log.Warn("could not append the CF certificate to the system certificate pool")
|
||||||
|
|
||||||
|
if systemCertPoolErr != nil { // Obtaining both certificates failed; this is a fatal error
|
||||||
|
log.WithError(systemCertPoolErr).Fatalf("Error loading the certificate pool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, add the Hello certificate into the pool (since it's self-signed)
|
||||||
|
helloCertificate, err := GetHelloCertificateX509()
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("error obtaining the Hello server certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
certPool.AddCert(helloCertificate)
|
||||||
|
|
||||||
|
return certPool
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config {
|
||||||
|
tlsConfig := CLIFlags{RootCA: "cacert"}.GetConfig(c)
|
||||||
|
if tlsConfig.RootCAs == nil {
|
||||||
|
tlsConfig.RootCAs = GetCloudflareRootCA()
|
||||||
|
tlsConfig.ServerName = "cftunnel.com"
|
||||||
|
} else if len(addrs) > 0 {
|
||||||
|
// Set for development environments and for testing specific origintunneld instances
|
||||||
|
tlsConfig.ServerName, _, _ = net.SplitHostPort(addrs[0])
|
||||||
|
}
|
||||||
|
return tlsConfig
|
||||||
|
}
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
package tunnelrpc
|
package tunnelrpc
|
||||||
|
|
||||||
//go:generate capnp compile -ogo -I./tunnelrpc/ tunnelrpc.capnp
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
|
"golang.org/x/net/trace"
|
||||||
"zombiezen.com/go/capnproto2/rpc"
|
"zombiezen.com/go/capnproto2/rpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -24,3 +23,20 @@ func (c ConnLogger) Errorf(ctx context.Context, format string, args ...interface
|
||||||
func ConnLog(log *log.Entry) rpc.ConnOption {
|
func ConnLog(log *log.Entry) rpc.ConnOption {
|
||||||
return rpc.ConnLog(ConnLogger{log})
|
return rpc.ConnLog(ConnLogger{log})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConnTracer wraps a trace.EventLog for a connection.
|
||||||
|
type ConnTracer struct {
|
||||||
|
Events trace.EventLog
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ConnTracer) Infof(ctx context.Context, format string, args ...interface{}) {
|
||||||
|
c.Events.Printf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ConnTracer) Errorf(ctx context.Context, format string, args ...interface{}) {
|
||||||
|
c.Events.Errorf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConnTrace(events trace.EventLog) rpc.ConnOption {
|
||||||
|
return rpc.ConnLog(ConnTracer{events})
|
||||||
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ package tunnelrpc
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/context"
|
"golang.org/x/net/context"
|
||||||
"zombiezen.com/go/capnproto2/encoding/text"
|
"zombiezen.com/go/capnproto2/encoding/text"
|
||||||
"zombiezen.com/go/capnproto2/rpc"
|
"zombiezen.com/go/capnproto2/rpc"
|
||||||
|
|
|
@ -47,10 +47,11 @@ type RegistrationOptions struct {
|
||||||
Version string
|
Version string
|
||||||
OS string `capnp:"os"`
|
OS string `capnp:"os"`
|
||||||
ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy
|
ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy
|
||||||
PoolID string `capnp:"poolId"`
|
PoolName string `capnp:"poolName"`
|
||||||
Tags []Tag
|
Tags []Tag
|
||||||
ConnectionID uint8 `capnp:"connectionId"`
|
ConnectionID uint8 `capnp:"connectionId"`
|
||||||
OriginLocalIP string `capnp:"originLocalIp"`
|
OriginLocalIP string `capnp:"originLocalIp"`
|
||||||
|
IsAutoupdated bool `capnp:"isAutoupdated"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {
|
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {
|
||||||
|
|
|
@ -28,13 +28,15 @@ struct RegistrationOptions {
|
||||||
# What to do with existing tunnels for the given hostname.
|
# What to do with existing tunnels for the given hostname.
|
||||||
existingTunnelPolicy @3 :ExistingTunnelPolicy;
|
existingTunnelPolicy @3 :ExistingTunnelPolicy;
|
||||||
# If using the balancing policy, identifies the LB pool to use.
|
# If using the balancing policy, identifies the LB pool to use.
|
||||||
poolId @4 :Text;
|
poolName @4 :Text;
|
||||||
# Client-defined tags to associate with the tunnel
|
# Client-defined tags to associate with the tunnel
|
||||||
tags @5 :List(Tag);
|
tags @5 :List(Tag);
|
||||||
# A unique identifier for a high-availability connection made by a single client.
|
# A unique identifier for a high-availability connection made by a single client.
|
||||||
connectionId @6 :UInt8;
|
connectionId @6 :UInt8;
|
||||||
# origin LAN IP
|
# origin LAN IP
|
||||||
originLocalIp @7 :Text;
|
originLocalIp @7 :Text;
|
||||||
|
# whether Warp client has been autoupdated
|
||||||
|
isAutoupdated @8 :Bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Tag {
|
struct Tag {
|
||||||
|
|
|
@ -0,0 +1,77 @@
|
||||||
|
package websocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
|
||||||
|
func IsWebSocketUpgrade(req *http.Request) bool {
|
||||||
|
return websocket.IsWebSocketUpgrade(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing.
|
||||||
|
func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.Conn, *http.Response, error) {
|
||||||
|
req.URL.Scheme = "wss"
|
||||||
|
d := &websocket.Dialer{TLSClientConfig: tlsClientConfig}
|
||||||
|
conn, response, err := d.Dial(req.URL.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
|
||||||
|
return conn, response, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// HijackConnection takes over an HTTP connection. Caller is responsible for closing connection.
|
||||||
|
func HijackConnection(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, errors.New("hijack error")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, brw, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return conn, brw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream copies copy data to & from provided io.ReadWriters.
|
||||||
|
func Stream(conn, backendConn io.ReadWriter) {
|
||||||
|
proxyDone := make(chan struct{}, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
io.Copy(conn, backendConn)
|
||||||
|
proxyDone <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
io.Copy(backendConn, conn)
|
||||||
|
proxyDone <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// If one side is done, we are done.
|
||||||
|
<-proxyDone
|
||||||
|
}
|
||||||
|
|
||||||
|
// sha1Base64 sha1 and then base64 encodes str.
|
||||||
|
func sha1Base64(str string) string {
|
||||||
|
hasher := sha1.New()
|
||||||
|
io.WriteString(hasher, str)
|
||||||
|
hash := hasher.Sum(nil)
|
||||||
|
return base64.StdEncoding.EncodeToString(hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateAcceptKey returns the string needed for the Sec-WebSocket-Accept header.
|
||||||
|
// https://tools.ietf.org/html/rfc6455#section-1.3 describes this process in more detail.
|
||||||
|
func generateAcceptKey(req *http.Request) string {
|
||||||
|
return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||||
|
}
|
Loading…
Reference in New Issue