Release Argo Tunnel Client 2018.5.0

This commit is contained in:
cloudflare-warp-bot 2018-05-03 22:32:30 +00:00
parent a1a45b0f63
commit 9135a4837c
24 changed files with 1425 additions and 655 deletions

View File

@ -0,0 +1,297 @@
package main
import (
"crypto/tls"
"crypto/x509"
"encoding/hex"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/tlsconfig"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
"github.com/sirupsen/logrus"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
)
var (
defaultConfigFiles = []string{"config.yml", "config.yaml"}
// Launchd doesn't set root env variables, so there is default
// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp", "/usr/local/etc/cloudflared", "/etc/cloudflared"}
)
const defaultCredentialFile = "cert.pem"
func fileExists(path string) (bool, error) {
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
// ignore missing files
return false, nil
}
return false, err
}
f.Close()
return true, nil
}
// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs
// (differs by OS for legacy reasons) contains a cert.pem file, return empty string
func findDefaultOriginCertPath() string {
for _, defaultConfigDir := range defaultConfigDirs {
originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, defaultCredentialFile))
if ok, _ := fileExists(originCertPath); ok {
return originCertPath
}
}
return ""
}
// returns the first path that contains a config file. If none of the combination of
// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles
// contains a config file, return empty string
func findDefaultConfigPath() string {
for _, configDir := range defaultConfigDirs {
for _, configFile := range defaultConfigFiles {
dirPath, err := homedir.Expand(configDir)
if err != nil {
continue
}
path := filepath.Join(dirPath, configFile)
if ok, _ := fileExists(path); ok {
return path
}
}
}
return ""
}
func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
if context.String("config") != "" {
return altsrc.NewYamlSourceFromFile(context.String("config"))
}
return nil, nil
}
func generateRandomClientID() string {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
id := make([]byte, 32)
r.Read(id)
return hex.EncodeToString(id)
}
func enoughOptionsSet(c *cli.Context) bool {
// For cloudflared to work, the user needs to at least provide a hostname,
// or runs as stand alone DNS proxy .
// When using sudo, use -E flag to preserve env vars
if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" && os.Getenv("TUNNEL_DNS") == "" {
if isRunningFromTerminal() {
logger.Errorf("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl)
logger.Infof("If you want to run Argo Tunnel client as a stand alone DNS proxy, run with --proxy-dns option or set TUNNEL_DNS environment variable.")
} else {
logger.Errorf("You need to specify all the options in a configuration file, or use environment variables. See %s and %s", serviceUrl, argumentsUrl)
logger.Infof("If you want to run Argo Tunnel client as a stand alone DNS proxy, specify proxy-dns option in the configuration file, or set TUNNEL_DNS environment variable.")
}
cli.ShowAppHelp(c)
return false
}
return true
}
func handleDeprecatedOptions(c *cli.Context) {
// Fail if the user provided an old authentication method
if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
logger.Fatal("You don't need to give us your api-key anymore. Please use the new login method. Just run cloudflared login")
}
}
// validate url. It can be either from --url or argument
func validateUrl(c *cli.Context) (string, error) {
var url = c.String("url")
if c.NArg() > 0 {
if c.IsSet("url") {
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
}
url = c.Args().Get(0)
}
validUrl, err := validation.ValidateUrl(url)
return validUrl, err
}
func logClientOptions(c *cli.Context) {
flags := make(map[string]interface{})
for _, flag := range c.LocalFlagNames() {
flags[flag] = c.Generic(flag)
}
if len(flags) > 0 {
logger.Infof("Flags %v", flags)
}
envs := make(map[string]string)
// Find env variables for Argo Tunnel
for _, env := range os.Environ() {
// All Argo Tunnel env variables start with TUNNEL_
if strings.Contains(env, "TUNNEL_") {
vars := strings.Split(env, "=")
if len(vars) == 2 {
envs[vars[0]] = vars[1]
}
}
}
if len(envs) > 0 {
logger.Infof("Environmental variables %v", envs)
}
}
func dnsProxyStandAlone(c *cli.Context) bool {
return c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world"))
}
func getOriginCert(c *cli.Context) []byte {
if c.String("origincert") == "" {
logger.Warnf("Cannot determine default origin certificate path. No file %s in %v", defaultCredentialFile, defaultConfigDirs)
if isRunningFromTerminal() {
logger.Fatalf("You need to specify the origin certificate path with --origincert option, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", argumentsUrl)
} else {
logger.Fatalf("You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", serviceUrl)
}
}
// Check that the user has acquired a certificate using the login command
originCertPath, err := homedir.Expand(c.String("origincert"))
if err != nil {
logger.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
}
ok, err := fileExists(originCertPath)
if err != nil {
logger.Fatalf("Cannot check if origin cert exists at path %s", c.String("origincert"))
}
if !ok {
logger.Fatalf(`Cannot find a valid certificate for your origin at the path:
%s
If the path above is wrong, specify the path with the -origincert option.
If you don't have a certificate signed by Cloudflare, run the command:
%s login
`, originCertPath, os.Args[0])
}
// Easier to send the certificate as []byte via RPC than decoding it at this point
originCert, err := ioutil.ReadFile(originCertPath)
if err != nil {
logger.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
}
return originCert
}
func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, logger, protoLogger *logrus.Logger) *origin.TunnelConfig {
hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil {
logger.WithError(err).Fatal("Invalid hostname")
}
clientID := c.String("id")
if !c.IsSet("id") {
clientID = generateRandomClientID()
}
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil {
logger.WithError(err).Fatal("Tag parse failure")
}
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
url, err := validateUrl(c)
if err != nil {
logger.WithError(err).Fatal("Error validating url")
}
logger.Infof("Proxying tunnel requests to %s", url)
originCert := getOriginCert(c)
originCertPool, err := loadCertPool(c)
if err != nil {
logger.Fatal(err)
}
tunnelMetrics := origin.NewTunnelMetrics()
httpTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: c.Duration("proxy-connect-timeout"),
KeepAlive: c.Duration("proxy-tcp-keepalive"),
DualStack: !c.Bool("proxy-no-happy-eyeballs"),
}).DialContext,
MaxIdleConns: c.Int("proxy-keepalive-connections"),
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: originCertPool},
}
if !c.IsSet("hello-world") && c.IsSet("origin-server-name") {
httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
}
return &origin.TunnelConfig{
EdgeAddrs: c.StringSlice("edge"),
OriginUrl: url,
Hostname: hostname,
OriginCert: originCert,
TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
ClientTlsConfig: httpTransport.TLSClientConfig,
Retries: c.Uint("retries"),
HeartbeatInterval: c.Duration("heartbeat-interval"),
MaxHeartbeats: c.Uint64("heartbeat-count"),
ClientID: clientID,
BuildInfo: buildInfo,
ReportedVersion: Version,
LBPool: c.String("lb-pool"),
Tags: tags,
HAConnections: c.Int("ha-connections"),
HTTPTransport: httpTransport,
Metrics: tunnelMetrics,
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
ProtocolLogger: protoLogger,
Logger: logger,
IsAutoupdated: c.Bool("is-autoupdated"),
GracePeriod: c.Duration("grace-period"),
RunFromTerminal: isRunningFromTerminal(),
}
}
func loadCertPool(c *cli.Context) (*x509.CertPool, error) {
const originCAPoolFlag = "origin-ca-pool"
originCAPoolFilename := c.String(originCAPoolFlag)
var originCustomCAPool []byte
if originCAPoolFilename != "" {
var err error
originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, originCAPoolFlag))
}
}
originCertPool, err := tlsconfig.LoadOriginCertPool(originCustomCAPool)
if err != nil {
return nil, errors.Wrap(err, "error loading the certificate pool")
}
return originCertPool, nil
}

View File

@ -8,6 +8,6 @@ import (
cli "gopkg.in/urfave/cli.v2"
)
func runApp(app *cli.App) {
func runApp(app *cli.App, shutdownC chan struct{}) {
app.Run(os.Args)
}

View File

@ -1,204 +1,21 @@
package main
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"html/template"
"io/ioutil"
"net"
"net/http"
"os"
"time"
"github.com/gorilla/websocket"
"gopkg.in/urfave/cli.v2"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/hello"
)
type templateData struct {
ServerName string
Request *http.Request
Body string
}
type OriginUpTime struct {
StartTime time.Time `json:"startTime"`
UpTime string `json:"uptime"`
}
const defaultServerName = "the Argo Tunnel test server"
const indexTemplate = `
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=Edge">
<title>
Argo Tunnel Connection
</title>
<meta name="author" content="">
<meta name="description" content="Argo Tunnel Connection">
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>
html{line-height:1.15;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}body{margin:0}section{display:block}h1{font-size:2em;margin:.67em 0}a{background-color:transparent;-webkit-text-decoration-skip:objects}/* 1 */::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}/* 1 */a,body,dd,div,dl,dt,h1,h4,html,p,section{box-sizing:border-box}.bt{border-top-style:solid;border-top-width:1px}.bl{border-left-style:solid;border-left-width:1px}.b--orange{border-color:#f38020}.br1{border-radius:.125rem}.bw2{border-width:.25rem}.dib{display:inline-block}.sans-serif{font-family:open sans,-apple-system,BlinkMacSystemFont,avenir next,avenir,helvetica neue,helvetica,ubuntu,roboto,noto,segoe ui,arial,sans-serif}.code{font-family:Consolas,monaco,monospace}.b{font-weight:700}.fw3{font-weight:300}.fw4{font-weight:400}.fw5{font-weight:500}.fw6{font-weight:600}.lh-copy{line-height:1.5}.link{text-decoration:none}.link,.link:active,.link:focus,.link:hover,.link:link,.link:visited{transition:color .15s ease-in}.link:focus{outline:1px dotted currentColor}.mw-100{max-width:100%}.mw4{max-width:8rem}.mw7{max-width:48rem}.bg-light-gray{background-color:#f7f7f7}.link-hover:hover{background-color:#1f679e}.white{color:#fff}.bg-white{background-color:#fff}.bg-blue{background-color:#408bc9}.pb2{padding-bottom:.5rem}.pb6{padding-bottom:8rem}.pt3{padding-top:1rem}.pt5{padding-top:4rem}.pv2{padding-top:.5rem;padding-bottom:.5rem}.ph3{padding-left:1rem;padding-right:1rem}.ph4{padding-left:2rem;padding-right:2rem}.ml0{margin-left:0}.mb1{margin-bottom:.25rem}.mb2{margin-bottom:.5rem}.mb3{margin-bottom:1rem}.mt5{margin-top:4rem}.ttu{text-transform:uppercase}.f4{font-size:1.25rem}.f5{font-size:1rem}.f6{font-size:.875rem}.f7{font-size:.75rem}.measure{max-width:30em}.center{margin-left:auto}.center{margin-right:auto}@media screen and (min-width:30em){.f2-ns{font-size:2.25rem}}@media screen and (min-width:30em) and (max-width:60em){.f5-m{font-size:1rem}}@media screen and (min-width:60em){.f4-l{font-size:1.25rem}}
.st0{fill:#FFF}.st1{fill:#f48120}.st2{fill:#faad3f}.st3{fill:#404041}
</style>
</head>
<body class="sans-serif black">
<div class="bt bw2 b--orange bg-white pb6">
<div class="mw7 center ph4 pt3">
<svg id="Layer_2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 109 40.5" class="mw4">
<path class="st0" d="M98.6 14.2L93 12.9l-1-.4-25.7.2v12.4l32.3.1z"/>
<path class="st1" d="M88.1 24c.3-1 .2-2-.3-2.6-.5-.6-1.2-1-2.1-1.1l-17.4-.2c-.1 0-.2-.1-.3-.1-.1-.1-.1-.2 0-.3.1-.2.2-.3.4-.3l17.5-.2c2.1-.1 4.3-1.8 5.1-3.8l1-2.6c0-.1.1-.2 0-.3-1.1-5.1-5.7-8.9-11.1-8.9-5 0-9.3 3.2-10.8 7.7-1-.7-2.2-1.1-3.6-1-2.4.2-4.3 2.2-4.6 4.6-.1.6 0 1.2.1 1.8-3.9.1-7.1 3.3-7.1 7.3 0 .4 0 .7.1 1.1 0 .2.2.3.3.3h32.1c.2 0 .4-.1.4-.3l.3-1.1z"/>
<path class="st2" d="M93.6 12.8h-.5c-.1 0-.2.1-.3.2l-.7 2.4c-.3 1-.2 2 .3 2.6.5.6 1.2 1 2.1 1.1l3.7.2c.1 0 .2.1.3.1.1.1.1.2 0 .3-.1.2-.2.3-.4.3l-3.8.2c-2.1.1-4.3 1.8-5.1 3.8l-.2.9c-.1.1 0 .3.2.3h13.2c.2 0 .3-.1.3-.3.2-.8.4-1.7.4-2.6 0-5.2-4.3-9.5-9.5-9.5"/>
<path class="st3" d="M104.4 30.8c-.5 0-.9-.4-.9-.9s.4-.9.9-.9.9.4.9.9-.4.9-.9.9m0-1.6c-.4 0-.7.3-.7.7 0 .4.3.7.7.7.4 0 .7-.3.7-.7 0-.4-.3-.7-.7-.7m.4 1.2h-.2l-.2-.3h-.2v.3h-.2v-.9h.5c.2 0 .3.1.3.3 0 .1-.1.2-.2.3l.2.3zm-.3-.5c.1 0 .1 0 .1-.1s-.1-.1-.1-.1h-.3v.3h.3zM14.8 29H17v6h3.8v1.9h-6zM23.1 32.9c0-2.3 1.8-4.1 4.3-4.1s4.2 1.8 4.2 4.1-1.8 4.1-4.3 4.1c-2.4 0-4.2-1.8-4.2-4.1m6.3 0c0-1.2-.8-2.2-2-2.2s-2 1-2 2.1.8 2.1 2 2.1c1.2.2 2-.8 2-2M34.3 33.4V29h2.2v4.4c0 1.1.6 1.7 1.5 1.7s1.5-.5 1.5-1.6V29h2.2v4.4c0 2.6-1.5 3.7-3.7 3.7-2.3-.1-3.7-1.2-3.7-3.7M45 29h3.1c2.8 0 4.5 1.6 4.5 3.9s-1.7 4-4.5 4h-3V29zm3.1 5.9c1.3 0 2.2-.7 2.2-2s-.9-2-2.2-2h-.9v4h.9zM55.7 29H62v1.9h-4.1v1.3h3.7V34h-3.7v2.9h-2.2zM65.1 29h2.2v6h3.8v1.9h-6zM76.8 28.9H79l3.4 8H80l-.6-1.4h-3.1l-.6 1.4h-2.3l3.4-8zm2 4.9l-.9-2.2-.9 2.2h1.8zM85.2 29h3.7c1.2 0 2 .3 2.6.9.5.5.7 1.1.7 1.8 0 1.2-.6 2-1.6 2.4l1.9 2.8H90l-1.6-2.4h-1v2.4h-2.2V29zm3.6 3.8c.7 0 1.2-.4 1.2-.9 0-.6-.5-.9-1.2-.9h-1.4v1.9h1.4zM95.3 29h6.4v1.8h-4.2V32h3.8v1.8h-3.8V35h4.3v1.9h-6.5zM10 33.9c-.3.7-1 1.2-1.8 1.2-1.2 0-2-1-2-2.1s.8-2.1 2-2.1c.9 0 1.6.6 1.9 1.3h2.3c-.4-1.9-2-3.3-4.2-3.3-2.4 0-4.3 1.8-4.3 4.1s1.8 4.1 4.2 4.1c2.1 0 3.7-1.4 4.2-3.2H10z"/>
</svg>
<h1 class="f4 f2-ns mt5 fw5">Congrats! You created your first tunnel!</h1>
<p class="f6 f5-m f4-l measure lh-copy fw3">
Argo Tunnel exposes locally running applications to the internet by
running an encrypted, virtual tunnel from your laptop or server to
Cloudflare's edge network.
</p>
<p class="b f5 mt5 fw6">Ready for the next step?</p>
<a
class="fw6 link white bg-blue ph4 pv2 br1 dib f5 link-hover"
style="border-bottom: 1px solid #1f679e"
href="https://developers.cloudflare.com/argo-tunnel/">
Get started here
</a>
<section>
<h4 class="f6 fw4 pt5 mb2">Request</h4>
<dl class="bl bw2 b--orange ph3 pt3 pb2 bg-light-gray f7 code overflow-x-auto mw-100">
<dd class="ml0 mb3 f5">Method: {{.Request.Method}}</dd>
<dd class="ml0 mb3 f5">Protocol: {{.Request.Proto}}</dd>
<dd class="ml0 mb3 f5">Request URL: {{.Request.URL}}</dd>
<dd class="ml0 mb3 f5">Transfer encoding: {{.Request.TransferEncoding}}</dd>
<dd class="ml0 mb3 f5">Host: {{.Request.Host}}</dd>
<dd class="ml0 mb3 f5">Remote address: {{.Request.RemoteAddr}}</dd>
<dd class="ml0 mb3 f5">Request URI: {{.Request.RequestURI}}</dd>
{{range $key, $value := .Request.Header}}
<dd class="ml0 mb3 f5">Header: {{$key}}, Value: {{$value}}</dd>
{{end}}
<dd class="ml0 mb3 f5">Body: {{.Body}}</dd>
</dl>
</section>
</div>
</div>
</body>
</html>
`
func hello(c *cli.Context) error {
func helloWorld(c *cli.Context) error {
address := fmt.Sprintf(":%d", c.Int("port"))
listener, err := createListener(address)
listener, err := hello.CreateTLSListener(address)
if err != nil {
return err
}
defer listener.Close()
err = startHelloWorldServer(listener, nil)
err = hello.StartHelloWorldServer(logger, listener, nil)
return err
}
func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
logger.Infof("Starting Hello World server at %s", listener.Addr())
serverName := defaultServerName
if hostname, err := os.Hostname(); err == nil {
serverName = hostname
}
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
go func() {
<-shutdownC
httpServer.Close()
}()
http.HandleFunc("/uptime", uptimeHandler(time.Now()))
http.HandleFunc("/ws", websocketHandler(upgrader))
http.HandleFunc("/", rootHandler(serverName))
err := httpServer.Serve(listener)
return err
}
func createListener(address string) (net.Listener, error) {
certificate, err := tlsconfig.GetHelloCertificate()
if err != nil {
return nil, err
}
// If the port in address is empty, a port number is automatically chosen
listener, err := tls.Listen(
"tcp",
address,
&tls.Config{Certificates: []tls.Certificate{certificate}})
return listener, err
}
func uptimeHandler(startTime time.Time) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Note that if autoupdate is enabled, the uptime is reset when a new client
// release is available
resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()}
respJson, err := json.Marshal(resp)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
} else {
w.Header().Set("Content-Type", "application/json")
w.Write(respJson)
}
}
}
func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
if err := conn.WriteMessage(mt, message); err != nil {
break
}
}
}
}
func rootHandler(serverName string) http.HandlerFunc {
responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
return func(w http.ResponseWriter, r *http.Request) {
var buffer bytes.Buffer
var body string
rawBody, err := ioutil.ReadAll(r.Body)
if err == nil {
body = string(rawBody)
} else {
body = ""
}
err = responseTemplate.Execute(&buffer, &templateData{
ServerName: serverName,
Request: r,
Body: body,
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "error: %v", err)
} else {
buffer.WriteTo(w)
}
}
}

View File

@ -10,7 +10,7 @@ import (
cli "gopkg.in/urfave/cli.v2"
)
func runApp(app *cli.App) {
func runApp(app *cli.App, shutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel system service",
@ -183,9 +183,9 @@ func installLinuxService(c *cli.Context) error {
defaultConfigDir := filepath.Dir(c.String("config"))
defaultConfigFile := filepath.Base(c.String("config"))
if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile); err != nil {
if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile, defaultCredentialFile); err != nil {
logger.WithError(err).Infof("Failed to copy user configuration. Before running the service, ensure that %s contains two files, %s and %s",
serviceConfigDir, credentialFile, defaultConfigFiles[0])
serviceConfigDir, defaultCredentialFile, defaultConfigFiles[0])
return err
}

65
cmd/cloudflared/logger.go Normal file
View File

@ -0,0 +1,65 @@
package main
import (
"fmt"
"os"
"github.com/cloudflare/cloudflared/log"
"github.com/rifflock/lfshook"
"github.com/sirupsen/logrus"
"gopkg.in/urfave/cli.v2"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
)
var logger = log.CreateLogger()
func configMainLogger(c *cli.Context) {
logLevel, err := logrus.ParseLevel(c.String("loglevel"))
if err != nil {
logger.WithError(err).Fatal("Unknown logging level specified")
}
logger.SetLevel(logLevel)
}
func configProtoLogger(c *cli.Context) *logrus.Logger {
protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
if err != nil {
logger.WithError(err).Fatal("Unknown protocol logging level specified")
}
protoLogger := logrus.New()
protoLogger.Level = protoLogLevel
return protoLogger
}
func initLogFile(c *cli.Context, loggers ...*logrus.Logger) error {
filePath, err := homedir.Expand(c.String("logfile"))
if err != nil {
return errors.Wrap(err, "Cannot resolve logfile path")
}
fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC
// do not truncate log file if the client has been autoupdated
if c.Bool("is-autoupdated") {
fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE
}
f, err := os.OpenFile(filePath, fileMode, 0664)
if err != nil {
errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath))
}
defer f.Close()
pathMap := lfshook.PathMap{
logrus.InfoLevel: filePath,
logrus.ErrorLevel: filePath,
logrus.FatalLevel: filePath,
logrus.PanicLevel: filePath,
}
for _, l := range loggers {
l.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
}
return nil
}

View File

@ -35,7 +35,7 @@ func login(c *cli.Context) error {
if err != nil {
return err
}
path := filepath.Join(configPath, credentialFile)
path := filepath.Join(configPath, defaultCredentialFile)
fileInfo, err := os.Stat(path)
if err == nil && fileInfo.Size() > 0 {
fmt.Fprintf(os.Stderr, `You have an existing certificate at %s which login would overwrite.

View File

@ -13,7 +13,7 @@ const (
launchdIdentifier = "com.cloudflare.cloudflared"
)
func runApp(app *cli.App) {
func runApp(app *cli.App, shutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel launch agent",
@ -91,12 +91,12 @@ func stderrPath() string {
func installLaunchd(c *cli.Context) error {
if isRootUser() {
logger.Infof("Installing Argo Tunnel client as a system launch daemon. " +
"Argo Tunnel client will run at boot")
"Argo Tunnel client will run at boot")
} else {
logger.Infof("Installing Argo Tunnel client as an user launch agent. " +
"Note that Argo Tunnel client will only run when the user is logged in. " +
"If you want to run Argo Tunnel client at boot, install with root permission. " +
"For more information, visit https://developers.cloudflare.com/argo-tunnel/reference/service/")
"Note that Argo Tunnel client will only run when the user is logged in. " +
"If you want to run Argo Tunnel client at boot, install with root permission. " +
"For more information, visit https://developers.cloudflare.com/argo-tunnel/reference/service/")
}
etPath, err := os.Executable()
if err != nil {
@ -120,7 +120,6 @@ func installLaunchd(c *cli.Context) error {
}
func uninstallLaunchd(c *cli.Context) error {
if isRootUser() {
logger.Infof("Uninstalling Argo Tunnel as a system launch daemon")
} else {

View File

@ -1,69 +1,47 @@
package main
import (
"crypto/tls"
"encoding/hex"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
"github.com/cloudflare/cloudflared/log"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
"github.com/facebookgo/grace/gracenet"
"github.com/getsentry/raven-go"
"github.com/mitchellh/go-homedir"
"github.com/rifflock/lfshook"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh/terminal"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
"github.com/coreos/go-systemd/daemon"
"github.com/pkg/errors"
"github.com/facebookgo/grace/gracenet"
)
const (
sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
credentialFile = "cert.pem"
quickStartUrl = "https://developers.cloudflare.com/argo-tunnel/quickstart/quickstart/"
noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/"
licenseUrl = "https://developers.cloudflare.com/argo-tunnel/licence/"
sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
developerPortal = "https://developers.cloudflare.com/argo-tunnel"
quickStartUrl = developerPortal + "/quickstart/quickstart/"
serviceUrl = developerPortal + "/reference/service/"
argumentsUrl = developerPortal + "/reference/arguments/"
licenseUrl = developerPortal + "/licence/"
)
var listeners = gracenet.Net{}
var Version = "DEV"
var BuildTime = "unknown"
var logger = log.CreateLogger()
var defaultConfigFiles = []string{"config.yml", "config.yaml"}
// Launchd doesn't set root env variables, so there is default
// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
var defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp", "/usr/local/etc/cloudflared", "/etc/cloudflared"}
// Shutdown channel used by the app. When closed, app must terminate.
// May be closed by the Windows service runner.
var shutdownC chan struct{}
var (
Version = "DEV"
BuildTime = "unknown"
)
func main() {
metrics.RegisterBuildInfo(BuildTime, Version)
raven.SetDSN(sentryDSN)
raven.SetRelease(Version)
shutdownC = make(chan struct{})
// Shutdown channel used by the app. When closed, app must terminate.
// May be closed by the Windows service runner.
shutdownC := make(chan struct{})
app := &cli.App{}
app.Name = "cloudflared"
@ -119,6 +97,11 @@ func main() {
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: findDefaultOriginCertPath(),
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "origin-ca-pool",
Usage: "Path to the CA for the certificate of your origin. This option should be used only if your certificate is not signed by Cloudflare.",
EnvVars: []string{"TUNNEL_ORIGIN_CA_POOL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "url",
Value: "https://localhost:8080",
@ -293,10 +276,13 @@ func main() {
}),
}
app.Action = func(c *cli.Context) error {
raven.CapturePanic(func() { startServer(c) }, nil)
raven.CapturePanic(func() { startServer(c, shutdownC) }, nil)
return nil
}
app.Before = func(context *cli.Context) error {
if context.String("config") == "" {
logger.Warnf("Cannot determine default configuration path. No file %v in %v", defaultConfigFiles, defaultConfigDirs)
}
inputSource, err := findInputSourceContext(context)
if err != nil {
logger.WithError(err).Infof("Cannot load configuration from %s", context.String("config"))
@ -337,7 +323,7 @@ func main() {
},
{
Name: "hello",
Action: hello,
Action: helloWorld,
Usage: "Run a simple \"Hello World\" server for testing Argo Tunnel.",
Flags: []cli.Flag{
&cli.IntFlag{
@ -381,233 +367,130 @@ func main() {
ArgsUsage: " ", // can't be the empty string or we get the default output
},
}
runApp(app)
runApp(app, shutdownC)
}
func startServer(c *cli.Context) {
func startServer(c *cli.Context, shutdownC chan struct{}) {
var wg sync.WaitGroup
listeners := gracenet.Net{}
errC := make(chan error)
connectedSignal := make(chan struct{})
dnsReadySignal := make(chan struct{})
graceShutdownSignal := make(chan struct{})
// If the user choose to supply all options through env variables,
// c.NumFlags() == 0 && c.NArg() == 0. For cloudflared to work, the user needs to at
// least provide a hostname.
if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" {
logger.Infof("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl)
cli.ShowAppHelp(c)
// check whether client provides enough flags or env variables. If not, print help.
if ok := enoughOptionsSet(c); !ok {
return
}
logLevel, err := logrus.ParseLevel(c.String("loglevel"))
if err != nil {
logger.WithError(err).Fatal("Unknown logging level specified")
}
logger.SetLevel(logLevel)
protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
if err != nil {
logger.WithError(err).Fatal("Unknown protocol logging level specified")
}
protoLogger := logrus.New()
protoLogger.Level = protoLogLevel
configMainLogger(c)
protoLogger := configProtoLogger(c)
if c.String("logfile") != "" {
if err := initLogFile(c, protoLogger); err != nil {
if err := initLogFile(c, logger, protoLogger); err != nil {
logger.Error(err)
}
}
handleDeprecatedOptions(c)
buildInfo := origin.GetBuildInfo()
logger.Infof("Build info: %+v", *buildInfo)
logger.Infof("Version %s", Version)
logClientOptions(c)
if c.IsSet("proxy-dns") {
port := c.Int("proxy-dns-port")
if port <= 0 || port > 65535 {
logger.Fatal("The 'proxy-dns-port' must be a valid port number in <1, 65535> range.")
}
wg.Add(1)
listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(port), c.StringSlice("proxy-dns-upstream"))
if err != nil {
close(dnsReadySignal)
listener.Stop()
logger.WithError(err).Fatal("Cannot create the DNS over HTTPS proxy server")
}
go func() {
err := listener.Start(dnsReadySignal)
if err != nil {
logger.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server")
} else {
<-shutdownC
}
listener.Stop()
wg.Done()
defer wg.Done()
runDNSProxyServer(c, dnsReadySignal, shutdownC)
}()
} else {
close(dnsReadySignal)
}
isRunningFromTerminal := isRunningFromTerminal()
if isAutoupdateEnabled(c, isRunningFromTerminal) {
// Wait for proxy-dns to come up (if used)
<-dnsReadySignal
if initUpdate() {
// Wait for proxy-dns to come up (if used)
<-dnsReadySignal
// update needs to be after DNS proxy is up to resolve equinox server address
if isAutoupdateEnabled(c) {
if initUpdate(&listeners) {
return
}
logger.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
go autoupdate(c.Duration("autoupdate-freq"), &listeners, shutdownC)
}
// Serve DNS proxy stand-alone if no hostname or tag or app is going to run
if c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world")) {
go writePidFile(connectedSignal, c.String("pidfile"))
close(connectedSignal)
runServer(c, &wg, errC, shutdownC)
return
}
hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil {
logger.WithError(err).Fatal("Invalid hostname")
}
clientID := c.String("id")
if !c.IsSet("id") {
clientID = generateRandomClientID()
}
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil {
logger.WithError(err).Fatal("Tag parse failure")
}
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
if c.IsSet("hello-world") {
wg.Add(1)
listener, err := createListener("127.0.0.1:")
if err != nil {
listener.Close()
logger.WithError(err).Fatal("Cannot start Hello World Server")
}
go func() {
startHelloWorldServer(listener, shutdownC)
wg.Done()
listener.Close()
}()
c.Set("url", "https://"+listener.Addr().String())
}
url, err := validateUrl(c)
if err != nil {
logger.WithError(err).Fatal("Error validating url")
}
logger.Infof("Proxying tunnel requests to %s", url)
// Fail if the user provided an old authentication method
if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
logger.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflared login")
}
// Check that the user has acquired a certificate using the log in command
originCertPath, err := homedir.Expand(c.String("origincert"))
if err != nil {
logger.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
}
ok, err := fileExists(originCertPath)
if err != nil {
logger.Fatalf("Cannot check if origin cert exists at path %s", c.String("origincert"))
}
if !ok {
logger.Fatalf(`Cannot find a valid certificate for your origin at the path:
%s
If the path above is wrong, specify the path with the -origincert option.
If you don't have a certificate signed by Cloudflare, run the command:
%s login
`, originCertPath, os.Args[0])
}
// Easier to send the certificate as []byte via RPC than decoding it at this point
originCert, err := ioutil.ReadFile(originCertPath)
if err != nil {
logger.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
}
tunnelMetrics := origin.NewTunnelMetrics()
httpTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: c.Duration("proxy-connect-timeout"),
KeepAlive: c.Duration("proxy-tcp-keepalive"),
DualStack: !c.Bool("proxy-no-happy-eyeballs"),
}).DialContext,
MaxIdleConns: c.Int("proxy-keepalive-connections"),
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
}
if !c.IsSet("hello-world") && c.IsSet("origin-server-name") {
httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
}
tunnelConfig := &origin.TunnelConfig{
EdgeAddrs: c.StringSlice("edge"),
OriginUrl: url,
Hostname: hostname,
OriginCert: originCert,
TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
ClientTlsConfig: httpTransport.TLSClientConfig,
Retries: c.Uint("retries"),
HeartbeatInterval: c.Duration("heartbeat-interval"),
MaxHeartbeats: c.Uint64("heartbeat-count"),
ClientID: clientID,
BuildInfo: buildInfo,
ReportedVersion: Version,
LBPool: c.String("lb-pool"),
Tags: tags,
HAConnections: c.Int("ha-connections"),
HTTPTransport: httpTransport,
Metrics: tunnelMetrics,
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
ProtocolLogger: protoLogger,
Logger: logger,
IsAutoupdated: c.Bool("is-autoupdated"),
GracePeriod: c.Duration("grace-period"),
RunFromTerminal: isRunningFromTerminal,
}
go writePidFile(connectedSignal, c.String("pidfile"))
wg.Add(1)
go func() {
errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
wg.Done()
}()
runServer(c, &wg, errC, shutdownC)
}
func runServer(c *cli.Context, wg *sync.WaitGroup, errC chan error, shutdownC chan struct{}) {
wg.Add(1)
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
if err != nil {
logger.WithError(err).Fatal("Error opening metrics server listener")
}
defer metricsListener.Close()
wg.Add(1)
go func() {
defer wg.Done()
errC <- metrics.ServeMetrics(metricsListener, shutdownC, logger)
wg.Done()
}()
// Serve DNS proxy stand-alone if no hostname or tag or app is going to run
if dnsProxyStandAlone(c) {
if c.IsSet("pidfile") {
go writePidFile(connectedSignal, c.String("pidfile"))
close(connectedSignal)
}
// no grace period, handle SIGINT/SIGTERM immediately
waitToShutdown(&wg, errC, shutdownC, graceShutdownSignal, 0)
return
}
if c.IsSet("hello-world") {
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
if err != nil {
logger.WithError(err).Fatal("Cannot start Hello World Server")
}
defer helloListener.Close()
wg.Add(1)
go func() {
defer wg.Done()
hello.StartHelloWorldServer(logger, helloListener, shutdownC)
}()
c.Set("url", "https://"+helloListener.Addr().String())
}
tunnelConfig := prepareTunnelConfig(c, buildInfo, logger, protoLogger)
if c.IsSet("pidFile") {
go writePidFile(connectedSignal, c.String("pidfile"))
}
wg.Add(1)
go func() {
defer wg.Done()
errC <- origin.StartTunnelDaemon(tunnelConfig, graceShutdownSignal, connectedSignal)
}()
waitToShutdown(&wg, errC, shutdownC, graceShutdownSignal, c.Duration("grace-period"))
}
func waitToShutdown(wg *sync.WaitGroup,
errC chan error,
shutdownC, graceShutdownSignal chan struct{},
gracePeriod time.Duration,
) {
var err error
if gracePeriod > 0 {
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceShutdownSignal, gracePeriod)
} else {
err = waitForSignal(errC, shutdownC)
close(graceShutdownSignal)
}
var errCode int
err = WaitForSignal(errC, shutdownC)
if err != nil {
logger.WithError(err).Fatal("Quitting due to error")
raven.CaptureErrorAndWait(err, nil)
errCode = 1
} else {
logger.Info("Graceful shutdown...")
logger.Info("Quitting...")
}
// Wait for clean exit, discarding all errors
go func() {
@ -618,126 +501,6 @@ func runServer(c *cli.Context, wg *sync.WaitGroup, errC chan error, shutdownC ch
os.Exit(errCode)
}
func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signals)
select {
case err := <-errC:
close(shutdownC)
return err
case <-signals:
close(shutdownC)
case <-shutdownC:
}
return nil
}
func update(_ *cli.Context) error {
if updateApplied() {
os.Exit(64)
}
return nil
}
func initUpdate() bool {
if updateApplied() {
os.Args = append(os.Args, "--is-autoupdated=true")
if _, err := listeners.StartProcess(); err != nil {
logger.WithError(err).Error("Unable to restart server automatically")
return false
}
return true
}
return false
}
func autoupdate(freq time.Duration, shutdownC chan struct{}) {
for {
if updateApplied() {
os.Args = append(os.Args, "--is-autoupdated=true")
if _, err := listeners.StartProcess(); err != nil {
logger.WithError(err).Error("Unable to restart server automatically")
}
close(shutdownC)
return
}
time.Sleep(freq)
}
}
func updateApplied() bool {
releaseInfo := checkForUpdates()
if releaseInfo.Updated {
logger.Infof("Updated to version %s", releaseInfo.Version)
return true
}
if releaseInfo.Error != nil {
logger.WithError(releaseInfo.Error).Error("Update check failed")
}
return false
}
func fileExists(path string) (bool, error) {
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
// ignore missing files
return false, nil
}
return false, err
}
f.Close()
return true, nil
}
// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs
// (differs by OS for legacy reasons) contains a cert.pem file, return empty string
func findDefaultOriginCertPath() string {
for _, defaultConfigDir := range defaultConfigDirs {
originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, credentialFile))
if ok, _ := fileExists(originCertPath); ok {
return originCertPath
}
}
return ""
}
// returns the firt path that contains a config file. If none of the combination of
// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles
// contains a config file, return empty string
func findDefaultConfigPath() string {
for _, configDir := range defaultConfigDirs {
for _, configFile := range defaultConfigFiles {
dirPath, err := homedir.Expand(configDir)
if err != nil {
continue
}
path := filepath.Join(dirPath, configFile)
if ok, _ := fileExists(path); ok {
return path
}
}
}
return ""
}
func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
if context.String("config") != "" {
return altsrc.NewYamlSourceFromFile(context.String("config"))
}
return nil, nil
}
func generateRandomClientID() string {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
id := make([]byte, 32)
r.Read(id)
return hex.EncodeToString(id)
}
func writePidFile(waitForSignal chan struct{}, pidFile string) {
<-waitForSignal
daemon.SdNotify(false, "READY=1")
@ -752,87 +515,6 @@ func writePidFile(waitForSignal chan struct{}, pidFile string) {
fmt.Fprintf(file, "%d", os.Getpid())
}
// validate url. It can be either from --url or argument
func validateUrl(c *cli.Context) (string, error) {
var url = c.String("url")
if c.NArg() > 0 {
if c.IsSet("url") {
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
}
url = c.Args().Get(0)
}
validUrl, err := validation.ValidateUrl(url)
return validUrl, err
}
func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error {
filePath, err := homedir.Expand(c.String("logfile"))
if err != nil {
return errors.Wrap(err, "Cannot resolve logfile path")
}
fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC
// do not truncate log file if the client has been autoupdated
if c.Bool("is-autoupdated") {
fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE
}
f, err := os.OpenFile(filePath, fileMode, 0664)
if err != nil {
errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath))
}
defer f.Close()
pathMap := lfshook.PathMap{
logrus.InfoLevel: filePath,
logrus.ErrorLevel: filePath,
logrus.FatalLevel: filePath,
logrus.PanicLevel: filePath,
}
logger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
return nil
}
func logClientOptions(c *cli.Context) {
flags := make(map[string]interface{})
for _, flag := range c.LocalFlagNames() {
flags[flag] = c.Generic(flag)
}
if len(flags) > 0 {
logger.Infof("Flags %v", flags)
}
envs := make(map[string]string)
// Find env variables for Argo Tunnel
for _, env := range os.Environ() {
// All Argo Tunnel env variables start with TUNNEL_
if strings.Contains(env, "TUNNEL_") {
vars := strings.Split(env, "=")
if len(vars) == 2 {
envs[vars[0]] = vars[1]
}
}
}
if len(envs) > 0 {
logger.Infof("Environmental variables %v", envs)
}
}
func isAutoupdateEnabled(c *cli.Context, isRunningFromTerminal bool) bool {
if isRunningFromTerminal {
logger.Info(noAutoupdateMessage)
return false
}
return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0
}
func isRunningFromTerminal() bool {
return terminal.IsTerminal(int(os.Stdout.Fd()))
}
func userHomeDir() string {
// This returns the home dir of the executing user using OS-specific method
// for discovering the home dir. It's not recommended to call this function

27
cmd/cloudflared/server.go Normal file
View File

@ -0,0 +1,27 @@
package main
import (
"github.com/cloudflare/cloudflared/tunneldns"
"gopkg.in/urfave/cli.v2"
)
func runDNSProxyServer(c *cli.Context, dnsReadySignal, shutdownC chan struct{}) {
port := c.Int("proxy-dns-port")
if port <= 0 || port > 65535 {
logger.Fatal("The 'proxy-dns-port' must be a valid port number in <1, 65535> range.")
}
listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(port), c.StringSlice("proxy-dns-upstream"))
if err != nil {
close(dnsReadySignal)
listener.Stop()
logger.WithError(err).Fatal("Cannot create the DNS over HTTPS proxy server")
}
err = listener.Start(dnsReadySignal)
if err != nil {
logger.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server")
}
<-shutdownC
listener.Stop()
}

View File

@ -119,7 +119,7 @@ func openFile(path string, create bool) (file *os.File, exists bool, err error)
return file, false, err
}
func copyCertificate(srcConfigDir, destConfigDir string) error {
func copyCertificate(srcConfigDir, destConfigDir, credentialFile string) error {
destCredentialPath := filepath.Join(destConfigDir, credentialFile)
destFile, exists, err := openFile(destCredentialPath, true)
if err != nil {
@ -146,12 +146,12 @@ func copyCertificate(srcConfigDir, destConfigDir string) error {
return nil
}
func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile string) error {
func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile, defaultCredentialFile string) error {
if err := ensureConfigDirExists(serviceConfigDir); err != nil {
return err
}
if err := copyCertificate(defaultConfigDir, serviceConfigDir); err != nil {
if err := copyCertificate(defaultConfigDir, serviceConfigDir, defaultCredentialFile); err != nil {
return err
}

54
cmd/cloudflared/signal.go Normal file
View File

@ -0,0 +1,54 @@
package main
import (
"os"
"os/signal"
"syscall"
"time"
)
func waitForSignal(errC chan error, shutdownC chan struct{}) error {
signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signals)
select {
case err := <-errC:
close(shutdownC)
return err
case <-signals:
close(shutdownC)
case <-shutdownC:
}
return nil
}
func waitForSignalWithGraceShutdown(errC chan error, shutdownC, graceShutdownSignal chan struct{}, gracePeriod time.Duration) error {
signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signals)
select {
case err := <-errC:
close(graceShutdownSignal)
close(shutdownC)
return err
case <-signals:
close(graceShutdownSignal)
logger.Infof("Initiating graceful shutdown...")
// Unregister signal handler early, so the client can send a second SIGTERM/SIGINT
// to force shutdown cloudflared
signal.Stop(signals)
graceTimerTick := time.Tick(gracePeriod)
// send close signal via shutdownC when grace period expires or when an
// error is encountered.
select {
case <-graceTimerTick:
case <-errC:
}
close(shutdownC)
case <-shutdownC:
}
return nil
}

View File

@ -0,0 +1,131 @@
package main
import (
"fmt"
"syscall"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const tick = 100 * time.Millisecond
var (
serverErr = fmt.Errorf("server error")
shutdownErr = fmt.Errorf("receive shutdown")
graceShutdownErr = fmt.Errorf("receive grace shutdown")
)
func testChannelClosed(t *testing.T, c chan struct{}) {
select {
case <-c:
return
default:
t.Fatal("Channel should be readable")
}
}
func TestWaitForSignal(t *testing.T) {
// Test handling server error
errC := make(chan error)
shutdownC := make(chan struct{})
go func() {
errC <- serverErr
}()
err := waitForSignal(errC, shutdownC)
assert.Equal(t, serverErr, err)
testChannelClosed(t, shutdownC)
// Test handling SIGTERM & SIGINT
for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
errC = make(chan error)
shutdownC = make(chan struct{})
go func(shutdownC chan struct{}) {
<-shutdownC
errC <- shutdownErr
}(shutdownC)
go func(sig syscall.Signal) {
// sleep for a tick to prevent sending signal before calling waitForSignal
time.Sleep(tick)
syscall.Kill(syscall.Getpid(), sig)
}(sig)
err = waitForSignal(errC, shutdownC)
assert.Equal(t, nil, err)
assert.Equal(t, shutdownErr, <-errC)
testChannelClosed(t, shutdownC)
}
}
func TestWaitForSignalWithGraceShutdown(t *testing.T) {
// Test server returning error
errC := make(chan error)
shutdownC := make(chan struct{})
graceshutdownC := make(chan struct{})
go func() {
errC <- serverErr
}()
err := waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
assert.Equal(t, serverErr, err)
testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC)
// Test handling SIGTERM & SIGINT
for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
//var wg sync.WaitGroup
errC := make(chan error)
shutdownC = make(chan struct{})
graceshutdownC = make(chan struct{})
go func(shutdownC, graceshutdownC chan struct{}) {
<-graceshutdownC
<-shutdownC
errC <- graceShutdownErr
}(shutdownC, graceshutdownC)
go func(sig syscall.Signal) {
// sleep for a tick to prevent sending signal before calling waitForSignalWithGraceShutdown
time.Sleep(tick)
syscall.Kill(syscall.Getpid(), sig)
}(sig)
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
assert.Equal(t, nil, err)
assert.Equal(t, graceShutdownErr, <-errC)
testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC)
}
// Test handling SIGTERM & SIGINT, server send error before end of grace period
for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
errC := make(chan error)
shutdownC = make(chan struct{})
graceshutdownC = make(chan struct{})
go func(shutdownC, graceshutdownC chan struct{}) {
<-graceshutdownC
errC <- graceShutdownErr
<-shutdownC
errC <- shutdownErr
}(shutdownC, graceshutdownC)
go func(sig syscall.Signal) {
// sleep for a tick to prevent sending signal before calling waitForSignalWithGraceShutdown
time.Sleep(tick)
syscall.Kill(syscall.Getpid(), sig)
}(sig)
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
assert.Equal(t, nil, err)
assert.Equal(t, shutdownErr, <-errC)
testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC)
}
}

View File

@ -1,8 +1,20 @@
package main
import "github.com/equinox-io/equinox"
import (
"os"
"time"
const appID = "app_idCzgxYerVD"
"golang.org/x/crypto/ssh/terminal"
"gopkg.in/urfave/cli.v2"
"github.com/equinox-io/equinox"
"github.com/facebookgo/grace/gracenet"
)
const (
appID = "app_idCzgxYerVD"
noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/"
)
var publicKey = []byte(`
-----BEGIN ECDSA PUBLIC KEY-----
@ -39,3 +51,61 @@ func checkForUpdates() ReleaseInfo {
return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion}
}
func update(_ *cli.Context) error {
if updateApplied() {
os.Exit(64)
}
return nil
}
func initUpdate(listeners *gracenet.Net) bool {
if updateApplied() {
os.Args = append(os.Args, "--is-autoupdated=true")
if _, err := listeners.StartProcess(); err != nil {
logger.WithError(err).Error("Unable to restart server automatically")
return false
}
return true
}
return false
}
func autoupdate(freq time.Duration, listeners *gracenet.Net, shutdownC chan struct{}) {
for {
if updateApplied() {
os.Args = append(os.Args, "--is-autoupdated=true")
if _, err := listeners.StartProcess(); err != nil {
logger.WithError(err).Error("Unable to restart server automatically")
}
close(shutdownC)
return
}
time.Sleep(freq)
}
}
func updateApplied() bool {
releaseInfo := checkForUpdates()
if releaseInfo.Updated {
logger.Infof("Updated to version %s", releaseInfo.Version)
return true
}
if releaseInfo.Error != nil {
logger.WithError(releaseInfo.Error).Error("Update check failed")
}
return false
}
func isAutoupdateEnabled(c *cli.Context) bool {
if isRunningFromTerminal() {
logger.Info(noAutoupdateMessage)
return false
}
return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0
}
func isRunningFromTerminal() bool {
return terminal.IsTerminal(int(os.Stdout.Fd()))
}

View File

@ -21,7 +21,7 @@ const (
windowsServiceDescription = "Argo Tunnel agent"
)
func runApp(app *cli.App) {
func runApp(app *cli.App, shutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel Windows service",
@ -59,7 +59,7 @@ func runApp(app *cli.App) {
elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName))
// Run executes service name by calling windowsService which is a Handler
// interface that implements Execute method
err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog})
err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog, shutdownC: shutdownC})
if err != nil {
elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
return
@ -68,8 +68,9 @@ func runApp(app *cli.App) {
}
type windowsService struct {
app *cli.App
elog *eventlog.Log
app *cli.App
elog *eventlog.Log
shutdownC chan struct{}
}
// called by the package code at the start of the service
@ -98,7 +99,7 @@ loop:
}
}
}
close(shutdownC)
close(s.shutdownC)
changes <- svc.Status{State: svc.StopPending}
return
}

View File

@ -135,7 +135,7 @@ func TestSingleStream(t *testing.T) {
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
}
stream.WriteHeaders([]Header{
Header{Name: "response-header", Value: "responseValue"},
{Name: "response-header", Value: "responseValue"},
})
buf := []byte("Hello world")
stream.Write(buf)
@ -153,7 +153,7 @@ func TestSingleStream(t *testing.T) {
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "test-header", Value: "headerValue"}},
[]Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != nil {
@ -194,6 +194,7 @@ func TestSingleStream(t *testing.T) {
func TestSingleStreamLargeResponseBody(t *testing.T) {
muxPair := NewDefaultMuxerPair()
bodySize := 1 << 24
streamReady := make(chan struct{})
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
if len(stream.Headers) != 1 {
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
@ -205,25 +206,30 @@ func TestSingleStreamLargeResponseBody(t *testing.T) {
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
}
stream.WriteHeaders([]Header{
Header{Name: "response-header", Value: "responseValue"},
{Name: "response-header", Value: "responseValue"},
})
payload := make([]byte, bodySize)
for i := range payload {
payload[i] = byte(i % 256)
}
t.Log("Writing payload...")
n, err := stream.Write(payload)
t.Logf("Wrote %d bytes into the stream", n)
if err != nil {
t.Fatalf("origin write error: %s", err)
}
if n != len(payload) {
t.Fatalf("origin short write: %d/%d bytes", n, len(payload))
}
t.Log("Payload written; signaling that the stream is ready")
streamReady <- struct{}{}
return nil
})
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "test-header", Value: "headerValue"}},
[]Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != nil {
@ -239,6 +245,10 @@ func TestSingleStreamLargeResponseBody(t *testing.T) {
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
}
responseBody := make([]byte, bodySize)
<-streamReady
t.Log("Received stream ready signal; resuming the test")
n, err := io.ReadFull(stream, responseBody)
if err != nil {
t.Fatalf("error from (*MuxedStream).Read: %s", err)
@ -261,7 +271,7 @@ func TestMultipleStreams(t *testing.T) {
}
log.Debugf("Got request for stream %s", stream.Headers[0].Value)
stream.WriteHeaders([]Header{
Header{Name: "response-token", Value: stream.Headers[0].Value},
{Name: "response-token", Value: stream.Headers[0].Value},
})
log.Debugf("Wrote headers for stream %s", stream.Headers[0].Value)
stream.Write([]byte("OK"))
@ -277,7 +287,7 @@ func TestMultipleStreams(t *testing.T) {
defer wg.Done()
tokenString := fmt.Sprintf("%d", tokenId)
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "client-token", Value: tokenString}},
[]Header{{Name: "client-token", Value: tokenString}},
nil,
)
log.Debugf("Got headers for stream %d", tokenId)
@ -328,6 +338,7 @@ func TestMultipleStreams(t *testing.T) {
func TestMultipleStreamsFlowControl(t *testing.T) {
maxStreams := 32
errorsC := make(chan error, maxStreams)
streamReady := make(chan struct{})
responseSizes := make([]int32, maxStreams)
for i := 0; i < maxStreams; i++ {
responseSizes[i] = rand.Int31n(int32(defaultWindowSize << 4))
@ -344,13 +355,14 @@ func TestMultipleStreamsFlowControl(t *testing.T) {
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
}
stream.WriteHeaders([]Header{
Header{Name: "response-header", Value: "responseValue"},
{Name: "response-header", Value: "responseValue"},
})
payload := make([]byte, responseSizes[(stream.streamID-2)/2])
for i := range payload {
payload[i] = byte(i % 256)
}
n, err := stream.Write(payload)
streamReady <- struct{}{}
if err != nil {
t.Fatalf("origin write error: %s", err)
}
@ -367,7 +379,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) {
go func(tokenId int) {
defer wg.Done()
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "test-header", Value: "headerValue"}},
[]Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != nil {
@ -387,6 +399,7 @@ func TestMultipleStreamsFlowControl(t *testing.T) {
return
}
<-streamReady
responseBody := make([]byte, responseSizes[(stream.streamID-2)/2])
n, err := io.ReadFull(stream, responseBody)
if err != nil {
@ -417,7 +430,7 @@ func TestGracefulShutdown(t *testing.T) {
muxPair := NewDefaultMuxerPair()
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
stream.WriteHeaders([]Header{
Header{Name: "response-header", Value: "responseValue"},
{Name: "response-header", Value: "responseValue"},
})
<-sendC
log.Debugf("Writing %d bytes", len(responseBuf))
@ -436,7 +449,7 @@ func TestGracefulShutdown(t *testing.T) {
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "test-header", Value: "headerValue"}},
[]Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
// Start graceful shutdown of the edge mux - this should also close the origin mux when done
@ -469,7 +482,7 @@ func TestUnexpectedShutdown(t *testing.T) {
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
defer close(handlerFinishC)
stream.WriteHeaders([]Header{
Header{Name: "response-header", Value: "responseValue"},
{Name: "response-header", Value: "responseValue"},
})
<-sendC
n, err := stream.Read([]byte{0})
@ -490,7 +503,7 @@ func TestUnexpectedShutdown(t *testing.T) {
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "test-header", Value: "headerValue"}},
[]Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
// Close the underlying connection before telling the origin to write.
@ -552,7 +565,7 @@ func TestOpenAfterDisconnect(t *testing.T) {
}
_, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "test-header", Value: "headerValue"}},
[]Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != ErrConnectionClosed {

197
hello/hello.go Normal file
View File

@ -0,0 +1,197 @@
package hello
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"html/template"
"io/ioutil"
"net"
"net/http"
"os"
"time"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"github.com/cloudflare/cloudflared/tlsconfig"
)
type templateData struct {
ServerName string
Request *http.Request
Body string
}
type OriginUpTime struct {
StartTime time.Time `json:"startTime"`
UpTime string `json:"uptime"`
}
const defaultServerName = "the Argo Tunnel test server"
const indexTemplate = `
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=Edge">
<title>
Argo Tunnel Connection
</title>
<meta name="author" content="">
<meta name="description" content="Argo Tunnel Connection">
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>
html{line-height:1.15;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}body{margin:0}section{display:block}h1{font-size:2em;margin:.67em 0}a{background-color:transparent;-webkit-text-decoration-skip:objects}/* 1 */::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}/* 1 */a,body,dd,div,dl,dt,h1,h4,html,p,section{box-sizing:border-box}.bt{border-top-style:solid;border-top-width:1px}.bl{border-left-style:solid;border-left-width:1px}.b--orange{border-color:#f38020}.br1{border-radius:.125rem}.bw2{border-width:.25rem}.dib{display:inline-block}.sans-serif{font-family:open sans,-apple-system,BlinkMacSystemFont,avenir next,avenir,helvetica neue,helvetica,ubuntu,roboto,noto,segoe ui,arial,sans-serif}.code{font-family:Consolas,monaco,monospace}.b{font-weight:700}.fw3{font-weight:300}.fw4{font-weight:400}.fw5{font-weight:500}.fw6{font-weight:600}.lh-copy{line-height:1.5}.link{text-decoration:none}.link,.link:active,.link:focus,.link:hover,.link:link,.link:visited{transition:color .15s ease-in}.link:focus{outline:1px dotted currentColor}.mw-100{max-width:100%}.mw4{max-width:8rem}.mw7{max-width:48rem}.bg-light-gray{background-color:#f7f7f7}.link-hover:hover{background-color:#1f679e}.white{color:#fff}.bg-white{background-color:#fff}.bg-blue{background-color:#408bc9}.pb2{padding-bottom:.5rem}.pb6{padding-bottom:8rem}.pt3{padding-top:1rem}.pt5{padding-top:4rem}.pv2{padding-top:.5rem;padding-bottom:.5rem}.ph3{padding-left:1rem;padding-right:1rem}.ph4{padding-left:2rem;padding-right:2rem}.ml0{margin-left:0}.mb1{margin-bottom:.25rem}.mb2{margin-bottom:.5rem}.mb3{margin-bottom:1rem}.mt5{margin-top:4rem}.ttu{text-transform:uppercase}.f4{font-size:1.25rem}.f5{font-size:1rem}.f6{font-size:.875rem}.f7{font-size:.75rem}.measure{max-width:30em}.center{margin-left:auto}.center{margin-right:auto}@media screen and (min-width:30em){.f2-ns{font-size:2.25rem}}@media screen and (min-width:30em) and (max-width:60em){.f5-m{font-size:1rem}}@media screen and (min-width:60em){.f4-l{font-size:1.25rem}}
.st0{fill:#FFF}.st1{fill:#f48120}.st2{fill:#faad3f}.st3{fill:#404041}
</style>
</head>
<body class="sans-serif black">
<div class="bt bw2 b--orange bg-white pb6">
<div class="mw7 center ph4 pt3">
<svg id="Layer_2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 109 40.5" class="mw4">
<path class="st0" d="M98.6 14.2L93 12.9l-1-.4-25.7.2v12.4l32.3.1z"/>
<path class="st1" d="M88.1 24c.3-1 .2-2-.3-2.6-.5-.6-1.2-1-2.1-1.1l-17.4-.2c-.1 0-.2-.1-.3-.1-.1-.1-.1-.2 0-.3.1-.2.2-.3.4-.3l17.5-.2c2.1-.1 4.3-1.8 5.1-3.8l1-2.6c0-.1.1-.2 0-.3-1.1-5.1-5.7-8.9-11.1-8.9-5 0-9.3 3.2-10.8 7.7-1-.7-2.2-1.1-3.6-1-2.4.2-4.3 2.2-4.6 4.6-.1.6 0 1.2.1 1.8-3.9.1-7.1 3.3-7.1 7.3 0 .4 0 .7.1 1.1 0 .2.2.3.3.3h32.1c.2 0 .4-.1.4-.3l.3-1.1z"/>
<path class="st2" d="M93.6 12.8h-.5c-.1 0-.2.1-.3.2l-.7 2.4c-.3 1-.2 2 .3 2.6.5.6 1.2 1 2.1 1.1l3.7.2c.1 0 .2.1.3.1.1.1.1.2 0 .3-.1.2-.2.3-.4.3l-3.8.2c-2.1.1-4.3 1.8-5.1 3.8l-.2.9c-.1.1 0 .3.2.3h13.2c.2 0 .3-.1.3-.3.2-.8.4-1.7.4-2.6 0-5.2-4.3-9.5-9.5-9.5"/>
<path class="st3" d="M104.4 30.8c-.5 0-.9-.4-.9-.9s.4-.9.9-.9.9.4.9.9-.4.9-.9.9m0-1.6c-.4 0-.7.3-.7.7 0 .4.3.7.7.7.4 0 .7-.3.7-.7 0-.4-.3-.7-.7-.7m.4 1.2h-.2l-.2-.3h-.2v.3h-.2v-.9h.5c.2 0 .3.1.3.3 0 .1-.1.2-.2.3l.2.3zm-.3-.5c.1 0 .1 0 .1-.1s-.1-.1-.1-.1h-.3v.3h.3zM14.8 29H17v6h3.8v1.9h-6zM23.1 32.9c0-2.3 1.8-4.1 4.3-4.1s4.2 1.8 4.2 4.1-1.8 4.1-4.3 4.1c-2.4 0-4.2-1.8-4.2-4.1m6.3 0c0-1.2-.8-2.2-2-2.2s-2 1-2 2.1.8 2.1 2 2.1c1.2.2 2-.8 2-2M34.3 33.4V29h2.2v4.4c0 1.1.6 1.7 1.5 1.7s1.5-.5 1.5-1.6V29h2.2v4.4c0 2.6-1.5 3.7-3.7 3.7-2.3-.1-3.7-1.2-3.7-3.7M45 29h3.1c2.8 0 4.5 1.6 4.5 3.9s-1.7 4-4.5 4h-3V29zm3.1 5.9c1.3 0 2.2-.7 2.2-2s-.9-2-2.2-2h-.9v4h.9zM55.7 29H62v1.9h-4.1v1.3h3.7V34h-3.7v2.9h-2.2zM65.1 29h2.2v6h3.8v1.9h-6zM76.8 28.9H79l3.4 8H80l-.6-1.4h-3.1l-.6 1.4h-2.3l3.4-8zm2 4.9l-.9-2.2-.9 2.2h1.8zM85.2 29h3.7c1.2 0 2 .3 2.6.9.5.5.7 1.1.7 1.8 0 1.2-.6 2-1.6 2.4l1.9 2.8H90l-1.6-2.4h-1v2.4h-2.2V29zm3.6 3.8c.7 0 1.2-.4 1.2-.9 0-.6-.5-.9-1.2-.9h-1.4v1.9h1.4zM95.3 29h6.4v1.8h-4.2V32h3.8v1.8h-3.8V35h4.3v1.9h-6.5zM10 33.9c-.3.7-1 1.2-1.8 1.2-1.2 0-2-1-2-2.1s.8-2.1 2-2.1c.9 0 1.6.6 1.9 1.3h2.3c-.4-1.9-2-3.3-4.2-3.3-2.4 0-4.3 1.8-4.3 4.1s1.8 4.1 4.2 4.1c2.1 0 3.7-1.4 4.2-3.2H10z"/>
</svg>
<h1 class="f4 f2-ns mt5 fw5">Congrats! You created your first tunnel!</h1>
<p class="f6 f5-m f4-l measure lh-copy fw3">
Argo Tunnel exposes locally running applications to the internet by
running an encrypted, virtual tunnel from your laptop or server to
Cloudflare's edge network.
</p>
<p class="b f5 mt5 fw6">Ready for the next step?</p>
<a
class="fw6 link white bg-blue ph4 pv2 br1 dib f5 link-hover"
style="border-bottom: 1px solid #1f679e"
href="https://developers.cloudflare.com/argo-tunnel/">
Get started here
</a>
<section>
<h4 class="f6 fw4 pt5 mb2">Request</h4>
<dl class="bl bw2 b--orange ph3 pt3 pb2 bg-light-gray f7 code overflow-x-auto mw-100">
<dd class="ml0 mb3 f5">Method: {{.Request.Method}}</dd>
<dd class="ml0 mb3 f5">Protocol: {{.Request.Proto}}</dd>
<dd class="ml0 mb3 f5">Request URL: {{.Request.URL}}</dd>
<dd class="ml0 mb3 f5">Transfer encoding: {{.Request.TransferEncoding}}</dd>
<dd class="ml0 mb3 f5">Host: {{.Request.Host}}</dd>
<dd class="ml0 mb3 f5">Remote address: {{.Request.RemoteAddr}}</dd>
<dd class="ml0 mb3 f5">Request URI: {{.Request.RequestURI}}</dd>
{{range $key, $value := .Request.Header}}
<dd class="ml0 mb3 f5">Header: {{$key}}, Value: {{$value}}</dd>
{{end}}
<dd class="ml0 mb3 f5">Body: {{.Body}}</dd>
</dl>
</section>
</div>
</div>
</body>
</html>
`
func StartHelloWorldServer(logger *logrus.Logger, listener net.Listener, shutdownC <-chan struct{}) error {
logger.Infof("Starting Hello World server at %s", listener.Addr())
serverName := defaultServerName
if hostname, err := os.Hostname(); err == nil {
serverName = hostname
}
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
go func() {
<-shutdownC
httpServer.Close()
}()
http.HandleFunc("/uptime", uptimeHandler(time.Now()))
http.HandleFunc("/ws", websocketHandler(logger, upgrader))
http.HandleFunc("/", rootHandler(serverName))
err := httpServer.Serve(listener)
return err
}
func CreateTLSListener(address string) (net.Listener, error) {
certificate, err := tlsconfig.GetHelloCertificate()
if err != nil {
return nil, err
}
// If the port in address is empty, a port number is automatically chosen
listener, err := tls.Listen(
"tcp",
address,
&tls.Config{Certificates: []tls.Certificate{certificate}})
return listener, err
}
func uptimeHandler(startTime time.Time) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Note that if autoupdate is enabled, the uptime is reset when a new client
// release is available
resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()}
respJson, err := json.Marshal(resp)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
} else {
w.Header().Set("Content-Type", "application/json")
w.Write(respJson)
}
}
}
// This handler will echo message
func websocketHandler(logger *logrus.Logger, upgrader websocket.Upgrader) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
logger.WithError(err).Error("websocket read message error")
break
}
if err := conn.WriteMessage(mt, message); err != nil {
logger.WithError(err).Error("websocket write message error")
break
}
}
}
}
func rootHandler(serverName string) http.HandlerFunc {
responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
return func(w http.ResponseWriter, r *http.Request) {
var buffer bytes.Buffer
var body string
rawBody, err := ioutil.ReadAll(r.Body)
if err == nil {
body = string(rawBody)
} else {
body = ""
}
err = responseTemplate.Execute(&buffer, &templateData{
ServerName: serverName,
Request: r,
Body: body,
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "error: %v", err)
} else {
buffer.WriteTo(w)
}
}
}

38
hello/hello_test.go Normal file
View File

@ -0,0 +1,38 @@
package hello
import (
"testing"
)
func TestCreateTLSListenerHostAndPortSuccess(t *testing.T) {
listener, err := CreateTLSListener("localhost:1234")
defer listener.Close()
if err != nil {
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}
func TestCreateTLSListenerOnlyHostSuccess(t *testing.T) {
listener, err := CreateTLSListener("localhost:")
defer listener.Close()
if err != nil {
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}
func TestCreateTLSListenerOnlyPortSuccess(t *testing.T) {
listener, err := CreateTLSListener(":8888")
defer listener.Close()
if err != nil {
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}

View File

@ -123,7 +123,12 @@ func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connecte
}
}
func ServeTunnelLoop(ctx context.Context, config *TunnelConfig, addr *net.TCPAddr, connectionID uint8, connectedSignal chan struct{}) error {
func ServeTunnelLoop(ctx context.Context,
config *TunnelConfig,
addr *net.TCPAddr,
connectionID uint8,
connectedSignal chan struct{},
) error {
config.Metrics.incrementHaConnections()
defer config.Metrics.decrementHaConnections()
backoff := BackoffHandler{MaxRetries: config.Retries}
@ -482,6 +487,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
} else {
stream.WriteHeaders(H1ResponseToH2Response(response))
defer conn.Close()
// Copy to/from stream to the undelying connection. Use the underlying
// connection because cloudflared doesn't operate on the message themselves
websocket.Stream(conn.UnderlyingConn(), stream)
h.metrics.incrementResponses(h.connectionID, "200")
h.logResponse(response, cfRay)

View File

@ -5,7 +5,7 @@ import (
)
// TODO: remove the Origin CA root certs when migrated to Authenticated Origin Pull certs
const cloudflareRootCA = `
var cloudflareRootCA = []byte(`
Issuer: C=US, ST=California, L=San Francisco, O=CloudFlare, Inc., OU=CloudFlare Origin SSL ECC Certificate Authority
-----BEGIN CERTIFICATE-----
MIICiDCCAi6gAwIBAgIUXZP3MWb8MKwBE1Qbawsp1sfA/Y4wCgYIKoZIzj0EAwIw
@ -83,7 +83,7 @@ Bz+1CD4D/bWrs3cC9+kk/jFmrrAymZlkFX8tDb5aXASSLJjUjcptci9SKqtI2h0J
wUGkD7+bQAr+7vr8/R+CBmNMe7csE8NeEX6lVMF7Dh0a1YKQa6hUN18bBuYgTMuT
QzMmZpRpIBB321ZBlcnlxiTJvWxvbCPHKHj20VwwAz7LONF59s84ZsOqfoBv8gKM
s0s5dsq5zpLeaw==
-----END CERTIFICATE-----`
-----END CERTIFICATE-----`)
func GetCloudflareRootCA() *x509.CertPool {
ca := x509.NewCertPool()

View File

@ -9,6 +9,7 @@ import (
"net"
"github.com/cloudflare/cloudflared/log"
"github.com/pkg/errors"
"gopkg.in/urfave/cli.v2"
)
@ -64,21 +65,27 @@ func LoadCert(certPath string) *x509.CertPool {
return ca
}
func LoadOriginCertsPool() *x509.CertPool {
func LoadGlobalCertPool() (*x509.CertPool, error) {
success := false
// First, obtain the system certificate pool
certPool, systemCertPoolErr := x509.SystemCertPool()
if systemCertPoolErr != nil {
logger.Warnf("error obtaining the system certificates: %s", systemCertPoolErr)
certPool = x509.NewCertPool()
} else {
success = true
}
// Next, append the Cloudflare CA pool into the system pool
if !certPool.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
logger.Warn("could not append the CF certificate to the system certificate pool")
if !certPool.AppendCertsFromPEM(cloudflareRootCA) {
logger.Warn("could not append the CF certificate to the cloudflared certificate pool")
} else {
success = true
}
if systemCertPoolErr != nil { // Obtaining both certificates failed; this is a fatal error
logger.WithError(systemCertPoolErr).Fatalf("Error loading the certificate pool")
}
if success != true { // Obtaining any of the CAs has failed; this is a fatal error
return nil, errors.New("error loading any of the CAs into the global certificate pool")
}
// Finally, add the Hello certificate into the pool (since it's self-signed)
@ -89,7 +96,34 @@ func LoadOriginCertsPool() *x509.CertPool {
certPool.AddCert(helloCertificate)
return certPool
return certPool, nil
}
func LoadOriginCertPool(originCAPoolPEM []byte) (*x509.CertPool, error) {
success := false
// Get the global pool
certPool, globalPoolErr := LoadGlobalCertPool()
if globalPoolErr != nil {
certPool = x509.NewCertPool()
} else {
success = true
}
// Then, add any custom origin CA pool the user may have passed
if originCAPoolPEM != nil {
if !certPool.AppendCertsFromPEM(originCAPoolPEM) {
logger.Warn("could not append the provided origin CA to the cloudflared certificate pool")
} else {
success = true
}
}
if success != true {
return nil, errors.New("error loading any of the CAs into the origin certificate pool")
}
return certPool, nil
}
func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config {

211
tlsconfig/tlsconfig_test.go Normal file
View File

@ -0,0 +1,211 @@
package tlsconfig
import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
// Generated using `openssl req -newkey rsa:512 -nodes -x509 -days 3650`
var samplePEM = []byte(`
-----BEGIN CERTIFICATE-----
MIIB4DCCAYoCCQCb/H0EUrdXEjANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV
UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv
dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE
AwwIVGVzdCBPbmUwHhcNMTgwNDI2MTYxMDUxWhcNMjgwNDIzMTYxMDUxWjB3MQsw
CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG
A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn
eTERMA8GA1UEAwwIVGVzdCBPbmUwXDANBgkqhkiG9w0BAQEFAANLADBIAkEAwVQD
K0SJ25UFLznm2pU3zhzMEvpDEofHVNnCjk4mlDrtVop7PkKZ8pDEmuQANltUrxC8
yHBE2wXMv+GlH+bDtwIDAQABMA0GCSqGSIb3DQEBCwUAA0EAjVYQzozIFPkt/HRY
uUoZ8zEHIDICb0syFf5VAjm9AgTwIPzUmD+c5vl6LWDnxq7L45nLCzhhQ6YmiwDz
X7Wcyg==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIB4DCCAYoCCQDZfCdAJ+mwzDANBgkqhkiG9w0BAQsFADB3MQswCQYDVQQGEwJV
UzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcGA1UECgwQQ2xv
dWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVneTERMA8GA1UE
AwwIVGVzdCBUd28wHhcNMTgwNDI2MTYxMTIwWhcNMjgwNDIzMTYxMTIwWjB3MQsw
CQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1c3RpbjEZMBcG
A1UECgwQQ2xvdWRmbGFyZSwgSW5jLjEZMBcGA1UECwwQUHJvZHVjdCBTdHJhdGVn
eTERMA8GA1UEAwwIVGVzdCBUd28wXDANBgkqhkiG9w0BAQEFAANLADBIAkEAoHKp
ROVK3zCSsH7ocYeyRAML4V7SFAbZcb4WIwDnE08oMBVRkQVcW5tqEkvG3RiClfzV
wZIJ3CfqKIeSNSDU9wIDAQABMA0GCSqGSIb3DQEBCwUAA0EAJw2gUbnPiq4C2p5b
iWzlA9Q7aKo+VQ4H7IZS7tTccr59nVjvH/TG3eWujpnocr4TOqW9M3CK1DF9mUGP
3pQ3Jg==
-----END CERTIFICATE-----
`)
var systemCertPoolSubjects []*pkix.Name
type certificateFixture struct {
ou string
cn string
}
func TestMain(m *testing.M) {
systemCertPool, err := x509.SystemCertPool()
if isUnrecoverableError(err) {
os.Exit(1)
}
if systemCertPool == nil {
// On Windows, let's just assume the system cert pool was empty
systemCertPool = x509.NewCertPool()
}
systemCertPoolSubjects, err = getCertPoolSubjects(systemCertPool)
if err != nil {
os.Exit(1)
}
os.Exit(m.Run())
}
func TestLoadOriginCertPoolJustSystemPool(t *testing.T) {
certPoolSubjects := loadCertPoolSubjects(t, nil)
extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects)
// Remove extra subjects from the cert pool
var filteredSystemCertPoolSubjects []*pkix.Name
t.Log(extraSubjects)
OUTER:
for _, subject := range certPoolSubjects {
for _, extraSubject := range extraSubjects {
if subject == extraSubject {
t.Log(extraSubject)
continue OUTER
}
}
filteredSystemCertPoolSubjects = append(filteredSystemCertPoolSubjects, subject)
}
assert.Equal(t, len(filteredSystemCertPoolSubjects), len(systemCertPoolSubjects))
difference := subjectSubtract(systemCertPoolSubjects, filteredSystemCertPoolSubjects)
assert.Equal(t, 0, len(difference))
}
func TestLoadOriginCertPoolCFCertificates(t *testing.T) {
certPoolSubjects := loadCertPoolSubjects(t, nil)
extraSubjects := subjectSubtract(systemCertPoolSubjects, certPoolSubjects)
expected := []*certificateFixture{
{ou: "CloudFlare Origin SSL ECC Certificate Authority"},
{ou: "CloudFlare Origin SSL Certificate Authority"},
{cn: "origin-pull.cloudflare.net"},
{cn: "Argo Tunnel Sample Hello Server Certificate"},
}
assertFixturesMatchSubjects(t, expected, extraSubjects)
}
func TestLoadOriginCertPoolWithExtraPEMs(t *testing.T) {
certPoolWithoutPEMSubjects := loadCertPoolSubjects(t, nil)
certPoolWithPEMSubjects := loadCertPoolSubjects(t, samplePEM)
difference := subjectSubtract(certPoolWithoutPEMSubjects, certPoolWithPEMSubjects)
assert.Equal(t, 2, len(difference))
expected := []*certificateFixture{
{cn: "Test One"},
{cn: "Test Two"},
}
assertFixturesMatchSubjects(t, expected, difference)
}
func loadCertPoolSubjects(t *testing.T, originCAPoolPEM []byte) []*pkix.Name {
certPool, err := LoadOriginCertPool(originCAPoolPEM)
if isUnrecoverableError(err) {
t.Fatal(err)
}
assert.NotEmpty(t, certPool.Subjects())
certPoolSubjects, err := getCertPoolSubjects(certPool)
if err != nil {
t.Fatal(err)
}
return certPoolSubjects
}
func assertFixturesMatchSubjects(t *testing.T, fixtures []*certificateFixture, subjects []*pkix.Name) {
assert.Equal(t, len(fixtures), len(subjects))
for _, fixture := range fixtures {
found := false
for _, subject := range subjects {
found = found || fixtureMatchesSubjectPredicate(fixture, subject)
}
if !found {
t.Fail()
}
}
}
func fixtureMatchesSubjectPredicate(fixture *certificateFixture, subject *pkix.Name) bool {
cnMatch := true
if fixture.cn != "" {
cnMatch = fixture.cn == subject.CommonName
}
ouMatch := true
if fixture.ou != "" {
ouMatch = len(subject.OrganizationalUnit) > 0 && fixture.ou == subject.OrganizationalUnit[0]
}
return cnMatch && ouMatch
}
func subjectSubtract(left []*pkix.Name, right []*pkix.Name) []*pkix.Name {
var difference []*pkix.Name
var found bool
for _, r := range right {
found = false
for _, l := range left {
if (*l).String() == (*r).String() {
found = true
}
}
if !found {
difference = append(difference, r)
}
}
return difference
}
func getCertPoolSubjects(certPool *x509.CertPool) ([]*pkix.Name, error) {
var subjects []*pkix.Name
for _, subject := range certPool.Subjects() {
var sequence pkix.RDNSequence
_, err := asn1.Unmarshal(subject, &sequence)
if err != nil {
return nil, err
}
name := pkix.Name{}
name.FillFromRDNSequence(&sequence)
subjects = append(subjects, &name)
}
return subjects, nil
}
func isUnrecoverableError(err error) bool {
return err != nil && err.Error() != "crypto/x509: system root pool is not available on Windows"
}

View File

@ -43,7 +43,7 @@ func NewUpstreamHTTPS(endpoint string) (Upstream, error) {
http2.ConfigureTransport(transport)
client := &http.Client{
Timeout: time.Second * defaultTimeout,
Timeout: defaultTimeout,
Transport: transport,
}

View File

@ -13,16 +13,28 @@ import (
"github.com/gorilla/websocket"
)
var stripWebsocketHeaders = []string {
"Upgrade",
"Connection",
"Sec-Websocket-Key",
"Sec-Websocket-Version",
"Sec-Websocket-Extensions",
}
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
func IsWebSocketUpgrade(req *http.Request) bool {
return websocket.IsWebSocketUpgrade(req)
}
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing.
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
// the connection. The response body may not contain the entire response and does
// not need to be closed by the application.
func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.Conn, *http.Response, error) {
req.URL.Scheme = changeRequestScheme(req)
wsHeaders := websocketHeaders(req)
d := &websocket.Dialer{TLSClientConfig: tlsClientConfig}
conn, response, err := d.Dial(req.URL.String(), nil)
conn, response, err := d.Dial(req.URL.String(), wsHeaders)
if err != nil {
return nil, nil, err
}
@ -62,6 +74,21 @@ func Stream(conn, backendConn io.ReadWriter) {
<-proxyDone
}
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
// Sec-WebSocket-Version and Sec-Websocket-Extensions headers.
// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194.
func websocketHeaders(req *http.Request) http.Header {
wsHeaders := make(http.Header)
for key, val := range req.Header {
wsHeaders[key] = val
}
// Assume the header keys are in canonical format.
for _, header := range stripWebsocketHeaders {
wsHeaders.Del(header)
}
return wsHeaders
}
// sha1Base64 sha1 and then base64 encodes str.
func sha1Base64(str string) string {
hasher := sha1.New()

100
websocket/websocket_test.go Normal file
View File

@ -0,0 +1,100 @@
package websocket
import (
"crypto/tls"
"io"
"math/rand"
"net/http"
"testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.org/x/net/websocket"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/tlsconfig"
)
const (
// example in Sec-Websocket-Key in rfc6455
testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ=="
// example Sec-Websocket-Accept in rfc6455
testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
)
func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
req, err := http.NewRequest("GET", url, stream)
if err != nil {
t.Fatalf("testRequestHeader error")
}
req.Header.Add("Connection", "Upgrade")
req.Header.Add("Upgrade", "WebSocket")
req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
req.Header.Add("Sec-Websocket-Version", "13")
req.Header.Add("User-Agent", "curl/7.59.0")
return req
}
func websocketClientTLSConfig(t *testing.T) *tls.Config {
certPool, err := tlsconfig.LoadOriginCertPool(nil)
assert.NoError(t, err)
assert.NotNil(t, certPool)
return &tls.Config{RootCAs: certPool}
}
func TestWebsocketHeaders(t *testing.T) {
req := testRequest(t, "http://example.com", nil)
wsHeaders := websocketHeaders(req)
for _, header := range stripWebsocketHeaders {
assert.Empty(t, wsHeaders[header])
}
assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
}
func TestGenerateAcceptKey(t *testing.T) {
req := testRequest(t, "http://example.com", nil)
assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(req))
}
func TestServe(t *testing.T) {
logger := logrus.New()
shutdownC := make(chan struct{})
errC := make(chan error)
listener, err := hello.CreateTLSListener("localhost:1111")
assert.NoError(t, err)
defer listener.Close()
go func() {
errC <- hello.StartHelloWorldServer(logger, listener, shutdownC)
}()
req := testRequest(t, "https://localhost:1111/ws", nil)
tlsConfig := websocketClientTLSConfig(t)
assert.NotNil(t, tlsConfig)
conn, resp, err := ClientConnect(req, tlsConfig)
assert.NoError(t, err)
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
for i := 0; i < 1000; i++ {
messageSize := rand.Int() % 2048 + 1
clientMessage := make([]byte, messageSize)
// rand.Read always returns len(clientMessage) and a nil error
rand.Read(clientMessage)
err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
assert.NoError(t, err)
messageType, message, err := conn.ReadMessage()
assert.NoError(t, err)
assert.Equal(t, websocket.BinaryFrame, messageType)
assert.Equal(t, clientMessage, message)
}
conn.Close()
close(shutdownC)
<-errC
}