Release Warp Client 2018.2.1
This commit is contained in:
parent
e0ae598112
commit
3780e14f41
|
@ -2,17 +2,20 @@ package main
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/gorilla/websocket"
|
||||
"gopkg.in/urfave/cli.v2"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
"github.com/cloudflare/cloudflare-warp/tlsconfig"
|
||||
)
|
||||
|
||||
type templateData struct {
|
||||
|
@ -21,6 +24,11 @@ type templateData struct {
|
|||
Body string
|
||||
}
|
||||
|
||||
type OriginUpTime struct {
|
||||
StartTime time.Time `json:"startTime"`
|
||||
UpTime string `json:"uptime"`
|
||||
}
|
||||
|
||||
const defaultServerName = "the Cloudflare Warp test server"
|
||||
const indexTemplate = `
|
||||
<!DOCTYPE html>
|
||||
|
@ -85,54 +93,95 @@ const indexTemplate = `
|
|||
|
||||
func hello(c *cli.Context) error {
|
||||
address := fmt.Sprintf(":%d", c.Int("port"))
|
||||
server := NewHelloWorldServer()
|
||||
if hostname, err := os.Hostname(); err != nil {
|
||||
server.serverName = hostname
|
||||
listener, err := createListener(address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := server.ListenAndServe(address)
|
||||
return errors.Wrap(err, "Fail to start Hello World Server")
|
||||
defer listener.Close()
|
||||
err = startHelloWorldServer(listener, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
|
||||
server := NewHelloWorldServer()
|
||||
if hostname, err := os.Hostname(); err != nil {
|
||||
server.serverName = hostname
|
||||
Log.Infof("Starting Hello World server at %s", listener.Addr())
|
||||
serverName := defaultServerName
|
||||
if hostname, err := os.Hostname(); err == nil {
|
||||
serverName = hostname
|
||||
}
|
||||
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: server}
|
||||
|
||||
upgrader := websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
|
||||
go func() {
|
||||
<-shutdownC
|
||||
httpServer.Close()
|
||||
}()
|
||||
|
||||
http.HandleFunc("/uptime", uptimeHandler(time.Now()))
|
||||
http.HandleFunc("/ws", websocketHandler(upgrader))
|
||||
http.HandleFunc("/", rootHandler(serverName))
|
||||
err := httpServer.Serve(listener)
|
||||
return err
|
||||
}
|
||||
|
||||
type HelloWorldServer struct {
|
||||
responseTemplate *template.Template
|
||||
serverName string
|
||||
func createListener(address string) (net.Listener, error) {
|
||||
certificate, err := tlsconfig.GetHelloCertificate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func NewHelloWorldServer() *HelloWorldServer {
|
||||
return &HelloWorldServer{
|
||||
responseTemplate: template.Must(template.New("index").Parse(indexTemplate)),
|
||||
serverName: defaultServerName,
|
||||
}
|
||||
}
|
||||
// If the port in address is empty, a port number is automatically chosen
|
||||
listener, err := tls.Listen(
|
||||
"tcp",
|
||||
address,
|
||||
&tls.Config{Certificates: []tls.Certificate{certificate}})
|
||||
|
||||
func findAvailablePort() (net.Listener, error) {
|
||||
// If the port in address is empty, a port number is automatically chosen.
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:")
|
||||
return listener, err
|
||||
}
|
||||
|
||||
func (s *HelloWorldServer) ListenAndServe(address string) error {
|
||||
log.Infof("Starting Hello World server on %s", address)
|
||||
err := http.ListenAndServe(address, s)
|
||||
return err
|
||||
func uptimeHandler(startTime time.Time) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Note that if autoupdate is enabled, the uptime is reset when a new client
|
||||
// release is available
|
||||
resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()}
|
||||
respJson, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(respJson)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto)
|
||||
func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if err := conn.WriteMessage(mt, message); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func rootHandler(serverName string) http.HandlerFunc {
|
||||
responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
Log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto)
|
||||
var buffer bytes.Buffer
|
||||
var body string
|
||||
rawBody, err := ioutil.ReadAll(r.Body)
|
||||
|
@ -141,8 +190,8 @@ func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
} else {
|
||||
body = ""
|
||||
}
|
||||
err = s.responseTemplate.Execute(&buffer, &templateData{
|
||||
ServerName: s.serverName,
|
||||
err = responseTemplate.Execute(&buffer, &templateData{
|
||||
ServerName: serverName,
|
||||
Request: r,
|
||||
Body: body,
|
||||
})
|
||||
|
@ -153,3 +202,4 @@ func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
buffer.WriteTo(w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,18 +4,30 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
const testPort = "8080"
|
||||
|
||||
func TestNewHelloWorldServer(t *testing.T) {
|
||||
if NewHelloWorldServer() == nil {
|
||||
t.Fatal("NewHelloWorldServer returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAvailablePort(t *testing.T) {
|
||||
listener, err := findAvailablePort()
|
||||
func TestCreateListenerHostAndPortSuccess(t *testing.T) {
|
||||
listener, err := createListener("localhost:1234")
|
||||
if err != nil {
|
||||
t.Fatal("Fail to find available port")
|
||||
t.Fatal(err)
|
||||
}
|
||||
if listener.Addr().String() == "" {
|
||||
t.Fatal("Fail to find available port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateListenerOnlyHostSuccess(t *testing.T) {
|
||||
listener, err := createListener("localhost:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if listener.Addr().String() == "" {
|
||||
t.Fatal("Fail to find available port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateListenerOnlyPortSuccess(t *testing.T) {
|
||||
listener, err := createListener(":8888")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if listener.Addr().String() == "" {
|
||||
t.Fatal("Fail to find available port")
|
||||
|
|
|
@ -42,6 +42,8 @@ After=network.target
|
|||
TimeoutStartSec=0
|
||||
Type=notify
|
||||
ExecStart={{ .Path }} --config /etc/cloudflare-warp/config.yml --origincert /etc/cloudflare-warp/cert.pem --no-autoupdate
|
||||
Restart=on-failure
|
||||
RestartSec=5s
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
@ -137,7 +136,7 @@ func download(certURL, filePath string) bool {
|
|||
return true
|
||||
}
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Error fetching certificate")
|
||||
Log.WithError(err).Error("Error fetching certificate")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
@ -180,16 +179,16 @@ func putSuccess(client *http.Client, certURL string) {
|
|||
// indicate success to the relay server
|
||||
req, err := http.NewRequest("PUT", certURL+"/ok", nil)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("HTTP request error")
|
||||
Log.WithError(err).Error("HTTP request error")
|
||||
return
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("HTTP error")
|
||||
Log.WithError(err).Error("HTTP error")
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
|
||||
Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,8 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
@ -21,11 +23,12 @@ import (
|
|||
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||
"github.com/cloudflare/cloudflare-warp/validation"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/facebookgo/grace/gracenet"
|
||||
raven "github.com/getsentry/raven-go"
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
"github.com/getsentry/raven-go"
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/rifflock/lfshook"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gopkg.in/urfave/cli.v2"
|
||||
"gopkg.in/urfave/cli.v2/altsrc"
|
||||
|
||||
"github.com/coreos/go-systemd/daemon"
|
||||
|
@ -40,11 +43,21 @@ const configFile = "config.yml"
|
|||
var listeners = gracenet.Net{}
|
||||
var Version = "DEV"
|
||||
var BuildTime = "unknown"
|
||||
var Log *logrus.Logger
|
||||
|
||||
// Shutdown channel used by the app. When closed, app must terminate.
|
||||
// May be closed by the Windows service runner.
|
||||
var shutdownC chan struct{}
|
||||
|
||||
type BuildAndRuntimeInfo struct {
|
||||
GoOS string `json:"go_os"`
|
||||
GoVersion string `json:"go_version"`
|
||||
GoArch string `json:"go_arch"`
|
||||
WarpVersion string `json:"warp_version"`
|
||||
WarpFlags map[string]interface{} `json:"warp_flags"`
|
||||
WarpEnvs map[string]string `json:"warp_envs"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
metrics.RegisterBuildInfo(BuildTime, Version)
|
||||
raven.SetDSN(sentryDSN)
|
||||
|
@ -84,6 +97,12 @@ WARNING:
|
|||
Usage: "Disable periodic check for updates, restarting the server with the new version.",
|
||||
Value: false,
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "is-autoupdated",
|
||||
Usage: "Signal the new process that Warp client has been autoupdated",
|
||||
Value: false,
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
|
||||
Name: "edge",
|
||||
Usage: "Address of the Cloudflare tunnel server.",
|
||||
|
@ -99,12 +118,12 @@ WARNING:
|
|||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "origincert",
|
||||
Usage: "Path to the certificate generated for your origin when you run cloudflare-warp login.",
|
||||
EnvVars: []string{"ORIGIN_CERT"},
|
||||
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
|
||||
Value: filepath.Join(defaultConfigDir, credentialFile),
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "url",
|
||||
Value: "http://localhost:8080",
|
||||
Value: "https://localhost:8080",
|
||||
Usage: "Connect to the local webserver at `URL`.",
|
||||
EnvVars: []string{"TUNNEL_URL"},
|
||||
}),
|
||||
|
@ -191,14 +210,20 @@ WARNING:
|
|||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "hello-world",
|
||||
Usage: "Run Hello World Server",
|
||||
Value: false,
|
||||
Usage: "Run Hello World Server",
|
||||
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "pidfile",
|
||||
Usage: "Write the application's PID to this file after first successful connection.",
|
||||
EnvVars: []string{"TUNNEL_PIDFILE"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "logfile",
|
||||
Usage: "Save application log to this file for reporting issues.",
|
||||
EnvVars: []string{"TUNNEL_LOGFILE"},
|
||||
}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{
|
||||
Name: "ha-connections",
|
||||
Value: 4,
|
||||
|
@ -239,6 +264,7 @@ WARNING:
|
|||
return nil
|
||||
}
|
||||
app.Before = func(context *cli.Context) error {
|
||||
Log = logrus.New()
|
||||
inputSource, err := findInputSourceContext(context)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -248,7 +274,7 @@ WARNING:
|
|||
return nil
|
||||
}
|
||||
app.Commands = []*cli.Command{
|
||||
&cli.Command{
|
||||
{
|
||||
Name: "update",
|
||||
Action: update,
|
||||
Usage: "Update the agent if a new version exists",
|
||||
|
@ -259,7 +285,7 @@ WARNING:
|
|||
|
||||
To determine if an update happened in a script, check for error code 64.`,
|
||||
},
|
||||
&cli.Command{
|
||||
{
|
||||
Name: "login",
|
||||
Action: login,
|
||||
Usage: "Generate a configuration file with your login details",
|
||||
|
@ -271,7 +297,7 @@ WARNING:
|
|||
},
|
||||
},
|
||||
},
|
||||
&cli.Command{
|
||||
{
|
||||
Name: "hello",
|
||||
Action: hello,
|
||||
Usage: "Run a simple \"Hello World\" server for testing Cloudflare Warp.",
|
||||
|
@ -293,27 +319,43 @@ func startServer(c *cli.Context) {
|
|||
errC := make(chan error)
|
||||
wg.Add(2)
|
||||
|
||||
if c.NumFlags() == 0 && c.NArg() == 0 {
|
||||
// If the user choose to supply all options through env variables,
|
||||
// c.NumFlags() == 0 && c.NArg() == 0. For warp to work, the user needs to at
|
||||
// least provide a hostname.
|
||||
if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" {
|
||||
cli.ShowAppHelp(c)
|
||||
return
|
||||
}
|
||||
|
||||
logLevel, err := log.ParseLevel(c.String("loglevel"))
|
||||
logLevel, err := logrus.ParseLevel(c.String("loglevel"))
|
||||
if err != nil {
|
||||
log.WithError(err).Fatal("Unknown logging level specified")
|
||||
Log.WithError(err).Fatal("Unknown logging level specified")
|
||||
}
|
||||
logrus.SetLevel(logLevel)
|
||||
|
||||
log.SetLevel(logLevel)
|
||||
protoLogLevel, err := log.ParseLevel(c.String("proto-loglevel"))
|
||||
protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
|
||||
if err != nil {
|
||||
log.WithError(err).Fatal("Unknown protocol logging level specified")
|
||||
Log.WithError(err).Fatal("Unknown protocol logging level specified")
|
||||
}
|
||||
protoLogger := log.New()
|
||||
protoLogger := logrus.New()
|
||||
protoLogger.Level = protoLogLevel
|
||||
|
||||
if c.String("logfile") != "" {
|
||||
if err := initLogFile(c, protoLogger); err != nil {
|
||||
Log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
if !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0 {
|
||||
if initUpdate() {
|
||||
return
|
||||
}
|
||||
Log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
|
||||
go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
|
||||
}
|
||||
|
||||
hostname, err := validation.ValidateHostname(c.String("hostname"))
|
||||
if err != nil {
|
||||
log.WithError(err).Fatal("Invalid hostname")
|
||||
Log.WithError(err).Fatal("Invalid hostname")
|
||||
|
||||
}
|
||||
clientID := c.String("id")
|
||||
|
@ -323,46 +365,44 @@ func startServer(c *cli.Context) {
|
|||
|
||||
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
|
||||
if err != nil {
|
||||
log.WithError(err).Fatal("Tag parse failure")
|
||||
Log.WithError(err).Fatal("Tag parse failure")
|
||||
}
|
||||
|
||||
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
|
||||
|
||||
if c.IsSet("hello-world") {
|
||||
wg.Add(1)
|
||||
listener, err := findAvailablePort()
|
||||
listener, err := createListener("127.0.0.1:")
|
||||
if err != nil {
|
||||
listener.Close()
|
||||
log.WithError(err).Fatal("Cannot start Hello World Server")
|
||||
Log.WithError(err).Fatal("Cannot start Hello World Server")
|
||||
}
|
||||
go func() {
|
||||
startHelloWorldServer(listener, shutdownC)
|
||||
wg.Done()
|
||||
listener.Close()
|
||||
}()
|
||||
c.Set("url", "http://"+listener.Addr().String())
|
||||
log.Infof("Starting Hello World Server at %s", c.String("url"))
|
||||
c.Set("url", "https://"+listener.Addr().String())
|
||||
}
|
||||
|
||||
url, err := validateUrl(c)
|
||||
if err != nil {
|
||||
log.WithError(err).Fatal("Error validating url")
|
||||
Log.WithError(err).Fatal("Error validating url")
|
||||
}
|
||||
log.Infof("Proxying tunnel requests to %s", url)
|
||||
Log.Infof("Proxying tunnel requests to %s", url)
|
||||
|
||||
// Fail if the user provided an old authentication method
|
||||
if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
|
||||
log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login")
|
||||
Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflare-warp login")
|
||||
}
|
||||
|
||||
// Check that the user has acquired a certificate using the log in command
|
||||
originCertPath, err := homedir.Expand(c.String("origincert"))
|
||||
if err != nil {
|
||||
log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
|
||||
Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
|
||||
}
|
||||
ok, err := fileExists(originCertPath)
|
||||
if !ok {
|
||||
log.Fatalf(`Cannot find a valid certificate for your origin at the path:
|
||||
Log.Fatalf(`Cannot find a valid certificate for your origin at the path:
|
||||
|
||||
%s
|
||||
|
||||
|
@ -375,8 +415,9 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
|||
// Easier to send the certificate as []byte via RPC than decoding it at this point
|
||||
originCert, err := ioutil.ReadFile(originCertPath)
|
||||
if err != nil {
|
||||
log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
|
||||
Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
|
||||
}
|
||||
|
||||
tunnelMetrics := origin.NewTunnelMetrics()
|
||||
httpTransport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
|
@ -389,13 +430,15 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
|||
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
|
||||
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
|
||||
}
|
||||
tunnelConfig := &origin.TunnelConfig{
|
||||
EdgeAddrs: c.StringSlice("edge"),
|
||||
OriginUrl: url,
|
||||
Hostname: hostname,
|
||||
OriginCert: originCert,
|
||||
TlsConfig: &tls.Config{},
|
||||
TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
|
||||
ClientTlsConfig: httpTransport.TLSClientConfig,
|
||||
Retries: c.Uint("retries"),
|
||||
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
||||
MaxHeartbeats: c.Uint64("heartbeat-count"),
|
||||
|
@ -408,18 +451,11 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
|||
Metrics: tunnelMetrics,
|
||||
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
|
||||
ProtocolLogger: protoLogger,
|
||||
Logger: Log,
|
||||
IsAutoupdated: c.Bool("is-autoupdated"),
|
||||
}
|
||||
connectedSignal := make(chan struct{})
|
||||
|
||||
tunnelConfig.TlsConfig = tlsconfig.CLIFlags{RootCA: "cacert"}.GetConfig(c)
|
||||
if tunnelConfig.TlsConfig.RootCAs == nil {
|
||||
tunnelConfig.TlsConfig.RootCAs = GetCloudflareRootCA()
|
||||
tunnelConfig.TlsConfig.ServerName = "cftunnel.com"
|
||||
} else if len(tunnelConfig.EdgeAddrs) > 0 {
|
||||
// Set for development environments and for testing specific origintunneld instances
|
||||
tunnelConfig.TlsConfig.ServerName, _, _ = net.SplitHostPort(tunnelConfig.EdgeAddrs[0])
|
||||
}
|
||||
|
||||
go writePidFile(connectedSignal, c.String("pidfile"))
|
||||
go func() {
|
||||
errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
|
||||
|
@ -428,24 +464,21 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
|||
|
||||
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
|
||||
if err != nil {
|
||||
log.WithError(err).Fatal("Error opening metrics server listener")
|
||||
Log.WithError(err).Fatal("Error opening metrics server listener")
|
||||
}
|
||||
go func() {
|
||||
errC <- metrics.ServeMetrics(metricsListener, shutdownC)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
if !c.Bool("no-autoupdate") {
|
||||
log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
|
||||
go autoupdate(c.Duration("autoupdate-period"), shutdownC)
|
||||
}
|
||||
|
||||
var errCode int
|
||||
err = WaitForSignal(errC, shutdownC)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Quitting due to error")
|
||||
Log.WithError(err).Error("Quitting due to error")
|
||||
raven.CaptureErrorAndWait(err, nil)
|
||||
errCode = 1
|
||||
} else {
|
||||
log.Info("Quitting...")
|
||||
Log.Info("Quitting...")
|
||||
}
|
||||
// Wait for clean exit, discarding all errors
|
||||
go func() {
|
||||
|
@ -453,6 +486,7 @@ If you don't have a certificate signed by Cloudflare, run the command:
|
|||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
os.Exit(errCode)
|
||||
}
|
||||
|
||||
func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
|
||||
|
@ -477,30 +511,40 @@ func update(c *cli.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func autoupdate(frequency time.Duration, shutdownC chan struct{}) {
|
||||
if int64(frequency) == 0 {
|
||||
return
|
||||
func initUpdate() bool {
|
||||
if updateApplied() {
|
||||
os.Args = append(os.Args, "--is-autoupdated=true")
|
||||
if _, err := listeners.StartProcess(); err != nil {
|
||||
Log.WithError(err).Error("Unable to restart server automatically")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func autoupdate(freq time.Duration, shutdownC chan struct{}) {
|
||||
for {
|
||||
if updateApplied() {
|
||||
os.Args = append(os.Args, "--is-autoupdated=true")
|
||||
if _, err := listeners.StartProcess(); err != nil {
|
||||
log.WithError(err).Error("Unable to restart server automatically")
|
||||
Log.WithError(err).Error("Unable to restart server automatically")
|
||||
}
|
||||
close(shutdownC)
|
||||
return
|
||||
}
|
||||
time.Sleep(frequency)
|
||||
time.Sleep(freq)
|
||||
}
|
||||
}
|
||||
|
||||
func updateApplied() bool {
|
||||
releaseInfo := checkForUpdates()
|
||||
if releaseInfo.Updated {
|
||||
log.Infof("Updated to version %s", releaseInfo.Version)
|
||||
Log.Infof("Updated to version %s", releaseInfo.Version)
|
||||
return true
|
||||
}
|
||||
if releaseInfo.Error != nil {
|
||||
log.WithError(releaseInfo.Error).Error("Update check failed")
|
||||
Log.WithError(releaseInfo.Error).Error("Update check failed")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -555,7 +599,7 @@ func writePidFile(waitForSignal chan struct{}, pidFile string) {
|
|||
}
|
||||
file, err := os.Create(pidFile)
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
|
||||
Log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
|
||||
}
|
||||
defer file.Close()
|
||||
fmt.Fprintf(file, "%d", os.Getpid())
|
||||
|
@ -573,3 +617,55 @@ func validateUrl(c *cli.Context) (string, error) {
|
|||
validUrl, err := validation.ValidateUrl(url)
|
||||
return validUrl, err
|
||||
}
|
||||
|
||||
func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error {
|
||||
fileMode := os.O_WRONLY|os.O_APPEND|os.O_CREATE|os.O_TRUNC
|
||||
// do not truncate log file if the client has been autoupdated
|
||||
if c.Bool("is-autoupdated") {
|
||||
fileMode = os.O_WRONLY|os.O_APPEND|os.O_CREATE
|
||||
}
|
||||
f, err := os.OpenFile(c.String("logfile"), fileMode, 0664)
|
||||
if err != nil {
|
||||
errors.Wrap(err, fmt.Sprintf("Cannot open file %s", c.String("logfile")))
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
pathMap := lfshook.PathMap{
|
||||
logrus.InfoLevel: c.String("logfile"),
|
||||
logrus.ErrorLevel: c.String("logfile"),
|
||||
logrus.FatalLevel: c.String("logfile"),
|
||||
logrus.PanicLevel: c.String("logfile"),
|
||||
}
|
||||
|
||||
Log.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
|
||||
protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
|
||||
|
||||
flags := make(map[string]interface{})
|
||||
envs := make(map[string]string)
|
||||
|
||||
for _, flag := range c.LocalFlagNames() {
|
||||
flags[flag] = c.Generic(flag)
|
||||
}
|
||||
|
||||
// Find env variables for Warp
|
||||
for _, env := range os.Environ() {
|
||||
// All Warp env variables start with TUNNEL_
|
||||
if strings.Contains(env, "TUNNEL_") {
|
||||
vars := strings.Split(env, "=")
|
||||
if len(vars) == 2 {
|
||||
envs[vars[0]] = vars[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Log.Infof("Warp build and runtime configuration: %+v", BuildAndRuntimeInfo{
|
||||
GoOS: runtime.GOOS,
|
||||
GoVersion: runtime.Version(),
|
||||
GoArch: runtime.GOARCH,
|
||||
WarpVersion: Version,
|
||||
WarpFlags: flags,
|
||||
WarpEnvs: envs,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
|
||||
"golang.org/x/sys/windows/svc"
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
|
|
@ -6,13 +6,14 @@ import (
|
|||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
|
@ -25,25 +26,20 @@ func TestMain(m *testing.M) {
|
|||
type DefaultMuxerPair struct {
|
||||
OriginMuxConfig MuxerConfig
|
||||
OriginMux *Muxer
|
||||
OriginWriter *io.PipeWriter
|
||||
OriginReader *io.PipeReader
|
||||
OriginConn net.Conn
|
||||
EdgeMuxConfig MuxerConfig
|
||||
EdgeMux *Muxer
|
||||
EdgeWriter *io.PipeWriter
|
||||
EdgeReader *io.PipeReader
|
||||
EdgeConn net.Conn
|
||||
doneC chan struct{}
|
||||
}
|
||||
|
||||
func NewDefaultMuxerPair() *DefaultMuxerPair {
|
||||
originReader, edgeWriter := io.Pipe()
|
||||
edgeReader, originWriter := io.Pipe()
|
||||
origin, edge := net.Pipe()
|
||||
return &DefaultMuxerPair{
|
||||
OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"},
|
||||
OriginWriter: originWriter,
|
||||
OriginReader: originReader,
|
||||
OriginConn: origin,
|
||||
EdgeMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"},
|
||||
EdgeWriter: edgeWriter,
|
||||
EdgeReader: edgeReader,
|
||||
EdgeConn: edge,
|
||||
doneC: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
@ -53,12 +49,12 @@ func (p *DefaultMuxerPair) Handshake(t *testing.T) {
|
|||
originErrC := make(chan error)
|
||||
go func() {
|
||||
var err error
|
||||
p.EdgeMux, err = Handshake(p.EdgeWriter, p.EdgeReader, p.EdgeMuxConfig)
|
||||
p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig)
|
||||
edgeErrC <- err
|
||||
}()
|
||||
go func() {
|
||||
var err error
|
||||
p.OriginMux, err = Handshake(p.OriginWriter, p.OriginReader, p.OriginMuxConfig)
|
||||
p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig)
|
||||
originErrC <- err
|
||||
}()
|
||||
|
||||
|
@ -120,8 +116,8 @@ func (p *DefaultMuxerPair) Wait(t *testing.T) {
|
|||
func TestHandshake(t *testing.T) {
|
||||
muxPair := NewDefaultMuxerPair()
|
||||
muxPair.Handshake(t)
|
||||
AssertIfPipeReadable(t, muxPair.OriginReader)
|
||||
AssertIfPipeReadable(t, muxPair.EdgeReader)
|
||||
AssertIfPipeReadable(t, muxPair.OriginConn)
|
||||
AssertIfPipeReadable(t, muxPair.EdgeConn)
|
||||
}
|
||||
|
||||
func TestSingleStream(t *testing.T) {
|
||||
|
@ -145,7 +141,7 @@ func TestSingleStream(t *testing.T) {
|
|||
stream.Write(buf)
|
||||
// after this receive, the edge closed the stream
|
||||
<-closeC
|
||||
n, err := stream.Read(buf)
|
||||
n, err := io.ReadFull(stream, buf)
|
||||
if n > 0 {
|
||||
t.Fatalf("read %d bytes after EOF", n)
|
||||
}
|
||||
|
@ -173,7 +169,7 @@ func TestSingleStream(t *testing.T) {
|
|||
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
||||
}
|
||||
responseBody := make([]byte, 11)
|
||||
n, err := stream.Read(responseBody)
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error from (*MuxedStream).Read: %s", err)
|
||||
}
|
||||
|
@ -243,7 +239,7 @@ func TestSingleStreamLargeResponseBody(t *testing.T) {
|
|||
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
||||
}
|
||||
responseBody := make([]byte, bodySize)
|
||||
n, err := stream.Read(responseBody)
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error from (*MuxedStream).Read: %s", err)
|
||||
}
|
||||
|
@ -302,7 +298,7 @@ func TestMultipleStreams(t *testing.T) {
|
|||
return
|
||||
}
|
||||
responseBody := make([]byte, 2)
|
||||
n, err := stream.Read(responseBody)
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != nil {
|
||||
errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err)
|
||||
return
|
||||
|
@ -392,7 +388,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) {
|
|||
}
|
||||
|
||||
responseBody := make([]byte, responseSizes[(stream.streamID-2)/2])
|
||||
n, err := stream.Read(responseBody)
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != nil {
|
||||
errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err)
|
||||
return
|
||||
|
@ -451,7 +447,7 @@ func TestGracefulShutdown(t *testing.T) {
|
|||
}
|
||||
responseBody := make([]byte, len(responseBuf))
|
||||
log.Debugf("Waiting for %d bytes", len(responseBuf))
|
||||
n, err := stream.Read(responseBody)
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err)
|
||||
}
|
||||
|
@ -498,13 +494,13 @@ func TestUnexpectedShutdown(t *testing.T) {
|
|||
nil,
|
||||
)
|
||||
// Close the underlying connection before telling the origin to write.
|
||||
muxPair.EdgeReader.Close()
|
||||
muxPair.EdgeConn.Close()
|
||||
close(sendC)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
responseBody := make([]byte, len(responseBuf))
|
||||
n, err := stream.Read(responseBody)
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != io.EOF {
|
||||
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
|
||||
}
|
||||
|
@ -545,14 +541,14 @@ func TestOpenAfterDisconnect(t *testing.T) {
|
|||
switch i {
|
||||
case 0:
|
||||
// Close both directions of the connection to cause EOF on both peers.
|
||||
muxPair.OriginReader.Close()
|
||||
muxPair.OriginWriter.Close()
|
||||
muxPair.OriginConn.Close()
|
||||
muxPair.EdgeConn.Close()
|
||||
case 1:
|
||||
// Close origin reader (edge writer) to cause EOF on origin only.
|
||||
muxPair.OriginReader.Close()
|
||||
// Close origin conn to cause EOF on origin first.
|
||||
muxPair.OriginConn.Close()
|
||||
case 2:
|
||||
// Close origin writer (edge reader) to cause EOF on edge only.
|
||||
muxPair.OriginWriter.Close()
|
||||
// Close edge conn to cause EOF on edge first.
|
||||
muxPair.EdgeConn.Close()
|
||||
}
|
||||
|
||||
_, err := muxPair.EdgeMux.OpenStream(
|
||||
|
@ -623,7 +619,7 @@ func TestHPACK(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) {
|
||||
func AssertIfPipeReadable(t *testing.T, pipe io.ReadCloser) {
|
||||
errC := make(chan error)
|
||||
go func() {
|
||||
b := []byte{0}
|
||||
|
@ -640,7 +636,5 @@ func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) {
|
|||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// nothing to read
|
||||
pipe.Close()
|
||||
<-errC
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -63,3 +64,27 @@ func TestFlowControlSingleStream(t *testing.T) {
|
|||
assert.Equal(t, testWindowSize<<2, stream.receiveWindow)
|
||||
assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax)
|
||||
}
|
||||
|
||||
func TestMuxedStreamEOF(t *testing.T) {
|
||||
for i := 0; i < 4096; i++ {
|
||||
readyList := NewReadyList()
|
||||
stream := &MuxedStream{
|
||||
streamID: 1,
|
||||
readBuffer: NewSharedBuffer(),
|
||||
receiveWindow: 65536,
|
||||
receiveWindowMax: 65536,
|
||||
sendWindow: 65536,
|
||||
readyList: readyList,
|
||||
}
|
||||
|
||||
go func() { stream.Close() }()
|
||||
n, err := stream.Read([]byte{0})
|
||||
assert.Equal(t, io.EOF, err)
|
||||
assert.Equal(t, 0, n)
|
||||
// Write comes after read, because write buffers data before it is flushed. It wouldn't know about EOF
|
||||
// until some time later. Calling read first forces it to know about EOF now.
|
||||
n, err = stream.Write([]byte{1})
|
||||
assert.Equal(t, io.EOF, err)
|
||||
assert.Equal(t, 0, n)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
"io"
|
||||
"time"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
|
|
@ -21,7 +21,7 @@ func NewSharedBuffer() *SharedBuffer {
|
|||
func (s *SharedBuffer) Read(p []byte) (n int, err error) {
|
||||
totalRead := 0
|
||||
s.cond.L.Lock()
|
||||
for totalRead < len(p) {
|
||||
for totalRead == 0 {
|
||||
n, err = s.buffer.Read(p[totalRead:])
|
||||
totalRead += n
|
||||
if err == io.EOF {
|
||||
|
@ -29,6 +29,9 @@ func (s *SharedBuffer) Read(p []byte) (n int, err error) {
|
|||
break
|
||||
}
|
||||
err = nil
|
||||
if n > 0 {
|
||||
break
|
||||
}
|
||||
s.cond.Wait()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) {
|
||||
|
@ -29,30 +31,35 @@ func TestSharedBuffer(t *testing.T) {
|
|||
|
||||
func TestSharedBufferBlockingRead(t *testing.T) {
|
||||
b := NewSharedBuffer()
|
||||
testData := []byte("Hello world")
|
||||
testData1 := []byte("Hello")
|
||||
testData2 := []byte(" world")
|
||||
result := make(chan []byte)
|
||||
go func() {
|
||||
bytesRead := make([]byte, len(testData))
|
||||
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
|
||||
result <- bytesRead
|
||||
bytesRead := make([]byte, len(testData1)+len(testData2))
|
||||
nRead, err := b.Read(bytesRead)
|
||||
AssertIOReturnIsGood(t, len(testData1))(nRead, err)
|
||||
result <- bytesRead[:nRead]
|
||||
nRead, err = b.Read(bytesRead)
|
||||
AssertIOReturnIsGood(t, len(testData2))(nRead, err)
|
||||
result <- bytesRead[:nRead]
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 250)
|
||||
select {
|
||||
case <-result:
|
||||
t.Fatalf("read returned early")
|
||||
default:
|
||||
}
|
||||
AssertIOReturnIsGood(t, 5)(b.Write(testData[:5]))
|
||||
select {
|
||||
case <-result:
|
||||
t.Fatalf("read returned early")
|
||||
default:
|
||||
}
|
||||
AssertIOReturnIsGood(t, len(testData)-5)(b.Write(testData[5:]))
|
||||
AssertIOReturnIsGood(t, len(testData1))(b.Write([]byte(testData1)))
|
||||
select {
|
||||
case r := <-result:
|
||||
if string(r) != string(testData) {
|
||||
t.Fatalf("expected read to return %s, got %s", testData, r)
|
||||
assert.Equal(t, testData1, r)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("read timed out")
|
||||
}
|
||||
AssertIOReturnIsGood(t, len(testData2))(b.Write([]byte(testData2)))
|
||||
select {
|
||||
case r := <-result:
|
||||
assert.Equal(t, testData2, r)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("read timed out")
|
||||
}
|
||||
|
@ -85,7 +92,7 @@ func TestSharedBufferConcurrentReadWrite(t *testing.T) {
|
|||
// Change block sizes in opposition to the write thread, to test blocking for new data.
|
||||
for blockSize := 256; blockSize > 0; blockSize-- {
|
||||
for i := 0; i < 256; i++ {
|
||||
n, err := b.Read(block[:blockSize])
|
||||
n, err := io.ReadFull(b, block[:blockSize])
|
||||
if n != blockSize || err != nil {
|
||||
t.Fatalf("read error: %d %s", n, err)
|
||||
}
|
||||
|
|
|
@ -11,9 +11,9 @@ import (
|
|||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/trace"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
|
||||
"github.com/cloudflare/cloudflare-warp/h2mux"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
|
@ -249,7 +248,7 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
|
|||
if _, ok := t.concurrentRequests[connectionID]; ok {
|
||||
t.concurrentRequests[connectionID] -= 1
|
||||
} else {
|
||||
log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.")
|
||||
Log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.")
|
||||
}
|
||||
t.concurrentRequestsLock.Unlock()
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
|
@ -73,7 +72,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
|
|||
case tunnelError := <-s.tunnelErrors:
|
||||
tunnelsActive--
|
||||
if tunnelError.err != nil {
|
||||
log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
|
||||
Log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
|
||||
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
||||
s.waitForNextTunnel(tunnelError.index)
|
||||
if backoffTimer == nil {
|
||||
|
@ -107,10 +106,10 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
|
|||
s.lastResolve = time.Now()
|
||||
s.resolverC = nil
|
||||
if result.err == nil {
|
||||
log.Debug("Service discovery refresh complete")
|
||||
Log.Debug("Service discovery refresh complete")
|
||||
s.edgeIPs = result.edgeIPs
|
||||
} else {
|
||||
log.WithError(result.err).Error("Service discovery error")
|
||||
Log.WithError(result.err).Error("Service discovery error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -120,12 +119,12 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) err
|
|||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
|
||||
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
|
||||
if err != nil {
|
||||
log.Infof("ResolveEdgeIPs err")
|
||||
Log.Infof("ResolveEdgeIPs err")
|
||||
return err
|
||||
}
|
||||
s.edgeIPs = edgeIPs
|
||||
if s.config.HAConnections > len(edgeIPs) {
|
||||
log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
|
||||
Log.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
|
||||
s.config.HAConnections = len(edgeIPs)
|
||||
}
|
||||
s.lastResolve = time.Now()
|
||||
|
|
|
@ -19,14 +19,17 @@ import (
|
|||
"github.com/cloudflare/cloudflare-warp/tunnelrpc"
|
||||
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||
"github.com/cloudflare/cloudflare-warp/validation"
|
||||
"github.com/cloudflare/cloudflare-warp/websocket"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
raven "github.com/getsentry/raven-go"
|
||||
"github.com/pkg/errors"
|
||||
_ "github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/sirupsen/logrus"
|
||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
||||
)
|
||||
|
||||
var Log *logrus.Logger
|
||||
|
||||
const (
|
||||
dialTimeout = 15 * time.Second
|
||||
|
||||
|
@ -40,6 +43,7 @@ type TunnelConfig struct {
|
|||
Hostname string
|
||||
OriginCert []byte
|
||||
TlsConfig *tls.Config
|
||||
ClientTlsConfig *tls.Config
|
||||
Retries uint
|
||||
HeartbeatInterval time.Duration
|
||||
MaxHeartbeats uint64
|
||||
|
@ -51,7 +55,9 @@ type TunnelConfig struct {
|
|||
HTTPTransport http.RoundTripper
|
||||
Metrics *TunnelMetrics
|
||||
MetricsUpdateFreq time.Duration
|
||||
ProtocolLogger *log.Logger
|
||||
ProtocolLogger *logrus.Logger
|
||||
Logger *logrus.Logger
|
||||
IsAutoupdated bool
|
||||
}
|
||||
|
||||
type dialError struct {
|
||||
|
@ -87,14 +93,16 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
|
|||
Version: c.ReportedVersion,
|
||||
OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH),
|
||||
ExistingTunnelPolicy: policy,
|
||||
PoolID: c.LBPool,
|
||||
PoolName: c.LBPool,
|
||||
Tags: c.Tags,
|
||||
ConnectionID: connectionID,
|
||||
OriginLocalIP: OriginLocalIP,
|
||||
IsAutoupdated: c.IsAutoupdated,
|
||||
}
|
||||
}
|
||||
|
||||
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
|
||||
Log = config.Logger
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
<-shutdownC
|
||||
|
@ -129,7 +137,7 @@ func ServeTunnelLoop(ctx context.Context, config *TunnelConfig, addr *net.TCPAdd
|
|||
err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff)
|
||||
if recoverable {
|
||||
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||
log.Infof("Retrying in %s seconds", duration)
|
||||
Log.Infof("Retrying in %s seconds", duration)
|
||||
backoff.Backoff(ctx)
|
||||
continue
|
||||
}
|
||||
|
@ -162,11 +170,10 @@ func ServeTunnel(
|
|||
// Returns error from parsing the origin URL or handshake errors
|
||||
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
|
||||
if err != nil {
|
||||
errLog := log.WithError(err)
|
||||
errLog := Log.WithError(err)
|
||||
switch err.(type) {
|
||||
case dialError:
|
||||
errLog.Error("Unable to dial edge")
|
||||
return err, false
|
||||
case h2mux.MuxerHandshakeError:
|
||||
errLog.Error("Handshake failed with edge server")
|
||||
default:
|
||||
|
@ -207,24 +214,21 @@ func ServeTunnel(
|
|||
registerErr := <-registerErrC
|
||||
wg.Wait()
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Tunnel error")
|
||||
Log.WithError(err).Error("Tunnel error")
|
||||
return err, true
|
||||
}
|
||||
if registerErr != nil {
|
||||
// Don't retry on errors like entitlement failure or version too old
|
||||
if e, ok := registerErr.(printableRegisterTunnelError); ok {
|
||||
log.Error(e)
|
||||
if e.permanent {
|
||||
return e, false
|
||||
}
|
||||
return e.cause, true
|
||||
Log.Error(e)
|
||||
return e.cause, !e.permanent
|
||||
} else if e, ok := registerErr.(dupConnRegisterTunnelError); ok {
|
||||
log.Info("Already connected to this server, selecting a different one")
|
||||
Log.Info("Already connected to this server, selecting a different one")
|
||||
return e, true
|
||||
}
|
||||
// Only log errors to Sentry that may have been caused by the client side, to reduce dupes
|
||||
raven.CaptureError(registerErr, nil)
|
||||
log.Error("Cannot register")
|
||||
Log.Error("Cannot register")
|
||||
return err, true
|
||||
}
|
||||
return nil, false
|
||||
|
@ -241,7 +245,7 @@ func IsRPCStreamResponse(headers []h2mux.Header) bool {
|
|||
}
|
||||
|
||||
func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error {
|
||||
logger := log.WithField("subsystem", "rpc")
|
||||
logger := Log.WithField("subsystem", "rpc")
|
||||
logger.Debug("initiating RPC stream")
|
||||
stream, err := muxer.OpenStream([]h2mux.Header{
|
||||
{Name: ":method", Value: "RPC"},
|
||||
|
@ -292,11 +296,11 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
|
|||
}
|
||||
}
|
||||
|
||||
log.Infof("Registered at %s", registration.Url)
|
||||
Log.Infof("Registered at %s", registration.Url)
|
||||
return nil
|
||||
}
|
||||
|
||||
func LogServerInfo(logger *log.Entry,
|
||||
func LogServerInfo(logger *logrus.Entry,
|
||||
promise tunnelrpc.ServerInfo_Promise,
|
||||
connectionID uint8,
|
||||
metrics *TunnelMetrics,
|
||||
|
@ -311,7 +315,7 @@ func LogServerInfo(logger *log.Entry,
|
|||
logger.WithError(err).Warn("Failed to retrieve server information")
|
||||
return
|
||||
}
|
||||
log.Infof("Connected to %s", serverInfo.LocationName)
|
||||
Log.Infof("Connected to %s", serverInfo.LocationName)
|
||||
metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName)
|
||||
}
|
||||
|
||||
|
@ -356,6 +360,7 @@ type TunnelHandler struct {
|
|||
originUrl string
|
||||
muxer *h2mux.Muxer
|
||||
httpClient http.RoundTripper
|
||||
tlsConfig *tls.Config
|
||||
tags []tunnelpogs.Tag
|
||||
metrics *TunnelMetrics
|
||||
// connectionID is only used by metrics, and prometheus requires labels to be string
|
||||
|
@ -373,6 +378,7 @@ func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr string, co
|
|||
h := &TunnelHandler{
|
||||
originUrl: url,
|
||||
httpClient: config.HTTPTransport,
|
||||
tlsConfig: config.ClientTlsConfig,
|
||||
tags: config.Tags,
|
||||
metrics: config.Metrics,
|
||||
connectionID: uint8ToString(connectionID),
|
||||
|
@ -422,29 +428,45 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
|||
h.metrics.incrementRequests(h.connectionID)
|
||||
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||
if err != nil {
|
||||
log.WithError(err).Panic("Unexpected error from http.NewRequest")
|
||||
Log.WithError(err).Panic("Unexpected error from http.NewRequest")
|
||||
}
|
||||
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("invalid request received")
|
||||
Log.WithError(err).Error("invalid request received")
|
||||
}
|
||||
h.AppendTagHeaders(req)
|
||||
|
||||
if websocket.IsWebSocketUpgrade(req) {
|
||||
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
|
||||
if err != nil {
|
||||
h.logError(stream, err)
|
||||
} else {
|
||||
stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||
defer conn.Close()
|
||||
websocket.Stream(conn.UnderlyingConn(), stream)
|
||||
}
|
||||
} else {
|
||||
response, err := h.httpClient.RoundTrip(req)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("HTTP request error")
|
||||
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
|
||||
stream.Write([]byte("502 Bad Gateway"))
|
||||
h.metrics.incrementResponses(h.connectionID, "502")
|
||||
h.logError(stream, err)
|
||||
} else {
|
||||
defer response.Body.Close()
|
||||
stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||
io.Copy(stream, response.Body)
|
||||
h.metrics.incrementResponses(h.connectionID, "200")
|
||||
}
|
||||
}
|
||||
h.metrics.decrementConcurrentRequests(h.connectionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
|
||||
Log.WithError(err).Error("HTTP request error")
|
||||
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
|
||||
stream.Write([]byte("502 Bad Gateway"))
|
||||
h.metrics.incrementResponses(h.connectionID, "502")
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) UpdateMetrics() {
|
||||
flowCtlMetrics := h.muxer.FlowControlMetrics()
|
||||
h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics)
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
package tlsconfig
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
// TODO: remove the Origin CA root certs when migrated to Authenticated Origin Pull certs
|
||||
const cloudflareRootCA = `
|
||||
Issuer: C=US, ST=California, L=San Francisco, O=CloudFlare, Inc., OU=CloudFlare Origin SSL ECC Certificate Authority
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICiDCCAi6gAwIBAgIUXZP3MWb8MKwBE1Qbawsp1sfA/Y4wCgYIKoZIzj0EAwIw
|
||||
gY8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1T
|
||||
YW4gRnJhbmNpc2NvMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTgwNgYDVQQL
|
||||
Ey9DbG91ZEZsYXJlIE9yaWdpbiBTU0wgRUNDIENlcnRpZmljYXRlIEF1dGhvcml0
|
||||
eTAeFw0xNjAyMjIxODI0MDBaFw0yMTAyMjIwMDI0MDBaMIGPMQswCQYDVQQGEwJV
|
||||
UzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEZ
|
||||
MBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE4MDYGA1UECxMvQ2xvdWRGbGFyZSBP
|
||||
cmlnaW4gU1NMIEVDQyBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkwWTATBgcqhkjOPQIB
|
||||
BggqhkjOPQMBBwNCAASR+sGALuaGshnUbcxKry+0LEXZ4NY6JUAtSeA6g87K3jaA
|
||||
xpIg9G50PokpfWkhbarLfpcZu0UAoYy2su0EhN7wo2YwZDAOBgNVHQ8BAf8EBAMC
|
||||
AQYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUhTBdOypw1O3VkmcH/es5
|
||||
tBoOOKcwHwYDVR0jBBgwFoAUhTBdOypw1O3VkmcH/es5tBoOOKcwCgYIKoZIzj0E
|
||||
AwIDSAAwRQIgEiIEHQr5UKma50D1WRMJBUSgjg24U8n8E2mfw/8UPz0CIQCr5V/e
|
||||
mcifak4CQsr+DH4pn5SJD7JxtCG3YGswW8QZsw==
|
||||
-----END CERTIFICATE-----
|
||||
Issuer: C=US, O=CloudFlare, Inc., OU=CloudFlare Origin SSL Certificate Authority, L=San Francisco, ST=California
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG
|
||||
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG
|
||||
bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN
|
||||
U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4
|
||||
NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv
|
||||
dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl
|
||||
cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG
|
||||
A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
|
||||
AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem
|
||||
ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ
|
||||
p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz
|
||||
/9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl
|
||||
yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK
|
||||
xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud
|
||||
EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G
|
||||
A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA
|
||||
cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD
|
||||
gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh
|
||||
QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz
|
||||
ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y
|
||||
4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX
|
||||
Fu6q54beR89jDc+oABmOgg==
|
||||
-----END CERTIFICATE-----
|
||||
Issuer: C=US, O=CloudFlare, Inc., OU=Origin Pull, L=San Francisco, ST=California, CN=origin-pull.cloudflare.net
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIGBjCCA/CgAwIBAgIIV5G6lVbCLmEwCwYJKoZIhvcNAQENMIGQMQswCQYDVQQG
|
||||
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjEUMBIGA1UECxMLT3JpZ2lu
|
||||
IFB1bGwxFjAUBgNVBAcTDVNhbiBGcmFuY2lzY28xEzARBgNVBAgTCkNhbGlmb3Ju
|
||||
aWExIzAhBgNVBAMTGm9yaWdpbi1wdWxsLmNsb3VkZmxhcmUubmV0MB4XDTE1MDEx
|
||||
MzAyNDc1M1oXDTIwMDExMjAyNTI1M1owgZAxCzAJBgNVBAYTAlVTMRkwFwYDVQQK
|
||||
ExBDbG91ZEZsYXJlLCBJbmMuMRQwEgYDVQQLEwtPcmlnaW4gUHVsbDEWMBQGA1UE
|
||||
BxMNU2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTEjMCEGA1UEAxMa
|
||||
b3JpZ2luLXB1bGwuY2xvdWRmbGFyZS5uZXQwggIiMA0GCSqGSIb3DQEBAQUAA4IC
|
||||
DwAwggIKAoICAQDdsts6I2H5dGyn4adACQRXlfo0KmwsN7B5rxD8C5qgy6spyONr
|
||||
WV0ecvdeGQfWa8Gy/yuTuOnsXfy7oyZ1dm93c3Mea7YkM7KNMc5Y6m520E9tHooc
|
||||
f1qxeDpGSsnWc7HWibFgD7qZQx+T+yfNqt63vPI0HYBOYao6hWd3JQhu5caAcIS2
|
||||
ms5tzSSZVH83ZPe6Lkb5xRgLl3eXEFcfI2DjnlOtLFqpjHuEB3Tr6agfdWyaGEEi
|
||||
lRY1IB3k6TfLTaSiX2/SyJ96bp92wvTSjR7USjDV9ypf7AD6u6vwJZ3bwNisNw5L
|
||||
ptph0FBnc1R6nDoHmvQRoyytoe0rl/d801i9Nru/fXa+l5K2nf1koR3IX440Z2i9
|
||||
+Z4iVA69NmCbT4MVjm7K3zlOtwfI7i1KYVv+ATo4ycgBuZfY9f/2lBhIv7BHuZal
|
||||
b9D+/EK8aMUfjDF4icEGm+RQfExv2nOpkR4BfQppF/dLmkYfjgtO1403X0ihkT6T
|
||||
PYQdmYS6Jf53/KpqC3aA+R7zg2birtvprinlR14MNvwOsDOzsK4p8WYsgZOR4Qr2
|
||||
gAx+z2aVOs/87+TVOR0r14irQsxbg7uP2X4t+EXx13glHxwG+CnzUVycDLMVGvuG
|
||||
aUgF9hukZxlOZnrl6VOf1fg0Caf3uvV8smOkVw6DMsGhBZSJVwao0UQNqQIDAQAB
|
||||
o2YwZDAOBgNVHQ8BAf8EBAMCAAYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4E
|
||||
FgQUQ1lLK2mLgOERM2pXzVc42p59xeswHwYDVR0jBBgwFoAUQ1lLK2mLgOERM2pX
|
||||
zVc42p59xeswCwYJKoZIhvcNAQENA4ICAQDKDQM1qPRVP/4Gltz0D6OU6xezFBKr
|
||||
LWtDoA1qW2F7pkiYawCP9MrDPDJsHy7dx+xw3bBZxOsK5PA/T7p1dqpEl6i8F692
|
||||
g//EuYOifLYw3ySPe3LRNhvPl/1f6Sn862VhPvLa8aQAAwR9e/CZvlY3fj+6G5ik
|
||||
3it7fikmKUsVnugNOkjmwI3hZqXfJNc7AtHDFw0mEOV0dSeAPTo95N9cxBbm9PKv
|
||||
qAEmTEXp2trQ/RjJ/AomJyfA1BQjsD0j++DI3a9/BbDwWmr1lJciKxiNKaa0BRLB
|
||||
dKMrYQD+PkPNCgEuojT+paLKRrMyFUzHSG1doYm46NE9/WARTh3sFUp1B7HZSBqA
|
||||
kHleoB/vQ/mDuW9C3/8Jk2uRUdZxR+LoNZItuOjU8oTy6zpN1+GgSj7bHjiy9rfA
|
||||
F+ehdrz+IOh80WIiqs763PGoaYUyzxLvVowLWNoxVVoc9G+PqFKqD988XlipHVB6
|
||||
Bz+1CD4D/bWrs3cC9+kk/jFmrrAymZlkFX8tDb5aXASSLJjUjcptci9SKqtI2h0J
|
||||
wUGkD7+bQAr+7vr8/R+CBmNMe7csE8NeEX6lVMF7Dh0a1YKQa6hUN18bBuYgTMuT
|
||||
QzMmZpRpIBB321ZBlcnlxiTJvWxvbCPHKHj20VwwAz7LONF59s84ZsOqfoBv8gKM
|
||||
s0s5dsq5zpLeaw==
|
||||
-----END CERTIFICATE-----`
|
||||
|
||||
func GetCloudflareRootCA() *x509.CertPool {
|
||||
ca := x509.NewCertPool()
|
||||
if !ca.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
|
||||
// should never happen
|
||||
panic("failure loading Cloudflare origin CA pem")
|
||||
}
|
||||
return ca
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
package tlsconfig
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
const (
|
||||
helloKey = `
|
||||
-----BEGIN EC PARAMETERS-----
|
||||
BgUrgQQAIg==
|
||||
-----END EC PARAMETERS-----
|
||||
-----BEGIN EC PRIVATE KEY-----
|
||||
MIGkAgEBBDAdyQBXfxTDCQSOT0HugmH9pVBtIw8t5dYvm6HxGlNq6P57v5GeN02Z
|
||||
dH9FRl7+VSWgBwYFK4EEACKhZANiAATqpFzTxxV7D+/oqhKCTR6BEM9elTfKaRQE
|
||||
FsLufcmaTMw/9tTwgpHKao/QsLKDTNbQhbSQLkcmpCQKlSGhl+pCrqNt/oYUAhav
|
||||
UIwpwGiLCqGH/R2AqWLKRPOa/Rufs/U=
|
||||
-----END EC PRIVATE KEY-----`
|
||||
|
||||
helloCRT = `
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICkDCCAhigAwIBAgIJAPtKfUjc2lwGMAkGByqGSM49BAEwgYoxCzAJBgNVBAYT
|
||||
AlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQKDBBD
|
||||
bG91ZGZsYXJlLCBJbmMuMT8wPQYDVQQDDDZDbG91ZGZsYXJlIEFyZ28gVHVubmVs
|
||||
IFNhbXBsZSBIZWxsbyBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMjE1MjAxNjU5
|
||||
WhcNMjgwMjEzMjAxNjU5WjCBijELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVRleGFz
|
||||
MQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENsb3VkZmxhcmUsIEluYy4xPzA9
|
||||
BgNVBAMMNkNsb3VkZmxhcmUgQXJnbyBUdW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZl
|
||||
ciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49AgEGBSuBBAAiA2IABOqkXNPHFXsP7+iq
|
||||
EoJNHoEQz16VN8ppFAQWwu59yZpMzD/21PCCkcpqj9CwsoNM1tCFtJAuRyakJAqV
|
||||
IaGX6kKuo23+hhQCFq9QjCnAaIsKoYf9HYCpYspE85r9G5+z9aNJMEcwRQYDVR0R
|
||||
BD4wPIIJbG9jYWxob3N0ggp3YXJwLWhlbGxvggt3YXJwMi1oZWxsb4cEfwAAAYcQ
|
||||
AAAAAAAAAAAAAAAAAAAAATAJBgcqhkjOPQQBA2cAMGQCMHyVPufXZ6vQo6XRWRa0
|
||||
dAwtfgesOdZVP2Wt+t5v8jOIQQh1IQXYk5GtyoZGSObjhQIwd1fRgAyKXaZt+1DV
|
||||
ZtHTdf8pMvESfJsSd8AB1eQ6q+pAiRUYyaxcE1Mlo2YY5o+g
|
||||
-----END CERTIFICATE-----`
|
||||
)
|
||||
|
||||
func GetHelloCertificate() (tls.Certificate, error) {
|
||||
return tls.X509KeyPair([]byte(helloCRT), []byte(helloKey))
|
||||
}
|
||||
|
||||
func GetHelloCertificateX509() (*x509.Certificate, error) {
|
||||
helloCertificate, err := GetHelloCertificate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return x509.ParseCertificate(helloCertificate.Certificate[0])
|
||||
}
|
|
@ -6,8 +6,9 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
|
@ -60,3 +61,43 @@ func LoadCert(certPath string) *x509.CertPool {
|
|||
}
|
||||
return ca
|
||||
}
|
||||
|
||||
func LoadOriginCertsPool() *x509.CertPool {
|
||||
// First, obtain the system certificate pool
|
||||
certPool, systemCertPoolErr := x509.SystemCertPool()
|
||||
if systemCertPoolErr != nil {
|
||||
log.Warn("error obtaining the system certificates: %s", systemCertPoolErr)
|
||||
certPool = x509.NewCertPool()
|
||||
}
|
||||
|
||||
// Next, append the Cloudflare CA pool into the system pool
|
||||
if !certPool.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
|
||||
log.Warn("could not append the CF certificate to the system certificate pool")
|
||||
|
||||
if systemCertPoolErr != nil { // Obtaining both certificates failed; this is a fatal error
|
||||
log.WithError(systemCertPoolErr).Fatalf("Error loading the certificate pool")
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, add the Hello certificate into the pool (since it's self-signed)
|
||||
helloCertificate, err := GetHelloCertificateX509()
|
||||
if err != nil {
|
||||
log.Warn("error obtaining the Hello server certificate")
|
||||
}
|
||||
|
||||
certPool.AddCert(helloCertificate)
|
||||
|
||||
return certPool
|
||||
}
|
||||
|
||||
func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config {
|
||||
tlsConfig := CLIFlags{RootCA: "cacert"}.GetConfig(c)
|
||||
if tlsConfig.RootCAs == nil {
|
||||
tlsConfig.RootCAs = GetCloudflareRootCA()
|
||||
tlsConfig.ServerName = "cftunnel.com"
|
||||
} else if len(addrs) > 0 {
|
||||
// Set for development environments and for testing specific origintunneld instances
|
||||
tlsConfig.ServerName, _, _ = net.SplitHostPort(addrs[0])
|
||||
}
|
||||
return tlsConfig
|
||||
}
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
package tunnelrpc
|
||||
|
||||
//go:generate capnp compile -ogo -I./tunnelrpc/ tunnelrpc.capnp
|
||||
|
||||
import (
|
||||
log "github.com/Sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/trace"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
)
|
||||
|
||||
|
@ -24,3 +23,20 @@ func (c ConnLogger) Errorf(ctx context.Context, format string, args ...interface
|
|||
func ConnLog(log *log.Entry) rpc.ConnOption {
|
||||
return rpc.ConnLog(ConnLogger{log})
|
||||
}
|
||||
|
||||
// ConnTracer wraps a trace.EventLog for a connection.
|
||||
type ConnTracer struct {
|
||||
Events trace.EventLog
|
||||
}
|
||||
|
||||
func (c ConnTracer) Infof(ctx context.Context, format string, args ...interface{}) {
|
||||
c.Events.Printf(format, args...)
|
||||
}
|
||||
|
||||
func (c ConnTracer) Errorf(ctx context.Context, format string, args ...interface{}) {
|
||||
c.Events.Errorf(format, args...)
|
||||
}
|
||||
|
||||
func ConnTrace(events trace.EventLog) rpc.ConnOption {
|
||||
return rpc.ConnLog(ConnTracer{events})
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ package tunnelrpc
|
|||
import (
|
||||
"bytes"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2/encoding/text"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
|
|
|
@ -47,10 +47,11 @@ type RegistrationOptions struct {
|
|||
Version string
|
||||
OS string `capnp:"os"`
|
||||
ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy
|
||||
PoolID string `capnp:"poolId"`
|
||||
PoolName string `capnp:"poolName"`
|
||||
Tags []Tag
|
||||
ConnectionID uint8 `capnp:"connectionId"`
|
||||
OriginLocalIP string `capnp:"originLocalIp"`
|
||||
IsAutoupdated bool `capnp:"isAutoupdated"`
|
||||
}
|
||||
|
||||
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {
|
||||
|
|
|
@ -28,13 +28,15 @@ struct RegistrationOptions {
|
|||
# What to do with existing tunnels for the given hostname.
|
||||
existingTunnelPolicy @3 :ExistingTunnelPolicy;
|
||||
# If using the balancing policy, identifies the LB pool to use.
|
||||
poolId @4 :Text;
|
||||
poolName @4 :Text;
|
||||
# Client-defined tags to associate with the tunnel
|
||||
tags @5 :List(Tag);
|
||||
# A unique identifier for a high-availability connection made by a single client.
|
||||
connectionId @6 :UInt8;
|
||||
# origin LAN IP
|
||||
originLocalIp @7 :Text;
|
||||
# whether Warp client has been autoupdated
|
||||
isAutoupdated @8 :Bool;
|
||||
}
|
||||
|
||||
struct Tag {
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/sha1"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
|
||||
func IsWebSocketUpgrade(req *http.Request) bool {
|
||||
return websocket.IsWebSocketUpgrade(req)
|
||||
}
|
||||
|
||||
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing.
|
||||
func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.Conn, *http.Response, error) {
|
||||
req.URL.Scheme = "wss"
|
||||
d := &websocket.Dialer{TLSClientConfig: tlsClientConfig}
|
||||
conn, response, err := d.Dial(req.URL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
|
||||
return conn, response, err
|
||||
}
|
||||
|
||||
// HijackConnection takes over an HTTP connection. Caller is responsible for closing connection.
|
||||
func HijackConnection(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("hijack error")
|
||||
}
|
||||
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, brw, nil
|
||||
}
|
||||
|
||||
// Stream copies copy data to & from provided io.ReadWriters.
|
||||
func Stream(conn, backendConn io.ReadWriter) {
|
||||
proxyDone := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
io.Copy(conn, backendConn)
|
||||
proxyDone <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
io.Copy(backendConn, conn)
|
||||
proxyDone <- struct{}{}
|
||||
}()
|
||||
|
||||
// If one side is done, we are done.
|
||||
<-proxyDone
|
||||
}
|
||||
|
||||
// sha1Base64 sha1 and then base64 encodes str.
|
||||
func sha1Base64(str string) string {
|
||||
hasher := sha1.New()
|
||||
io.WriteString(hasher, str)
|
||||
hash := hasher.Sum(nil)
|
||||
return base64.StdEncoding.EncodeToString(hash)
|
||||
}
|
||||
|
||||
// generateAcceptKey returns the string needed for the Sec-WebSocket-Accept header.
|
||||
// https://tools.ietf.org/html/rfc6455#section-1.3 describes this process in more detail.
|
||||
func generateAcceptKey(req *http.Request) string {
|
||||
return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
}
|
Loading…
Reference in New Issue