Release Argo Tunnel Client 2018.3.1

This commit is contained in:
cloudflare-warp-bot 2018-03-22 15:24:52 +00:00
parent 9f5cec8dbc
commit d0a6a2a829
34 changed files with 3296 additions and 293 deletions

View File

@ -0,0 +1,13 @@
// +build !windows,!darwin,!linux
package main
import (
"os"
cli "gopkg.in/urfave/cli.v2"
)
func runApp(app *cli.App) {
app.Run(os.Args)
}

204
cmd/cloudflared/hello.go Normal file
View File

@ -0,0 +1,204 @@
package main
import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"html/template"
"io/ioutil"
"net"
"net/http"
"os"
"time"
"github.com/gorilla/websocket"
"gopkg.in/urfave/cli.v2"
"github.com/cloudflare/cloudflared/tlsconfig"
)
type templateData struct {
ServerName string
Request *http.Request
Body string
}
type OriginUpTime struct {
StartTime time.Time `json:"startTime"`
UpTime string `json:"uptime"`
}
const defaultServerName = "the Argo Tunnel test server"
const indexTemplate = `
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=Edge">
<title>
Argo Tunnel Connection
</title>
<meta name="author" content="">
<meta name="description" content="Argo Tunnel Connection">
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>
html{line-height:1.15;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}body{margin:0}section{display:block}h1{font-size:2em;margin:.67em 0}a{background-color:transparent;-webkit-text-decoration-skip:objects}/* 1 */::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}/* 1 */a,body,dd,div,dl,dt,h1,h4,html,p,section{box-sizing:border-box}.bt{border-top-style:solid;border-top-width:1px}.bl{border-left-style:solid;border-left-width:1px}.b--orange{border-color:#f38020}.br1{border-radius:.125rem}.bw2{border-width:.25rem}.dib{display:inline-block}.sans-serif{font-family:open sans,-apple-system,BlinkMacSystemFont,avenir next,avenir,helvetica neue,helvetica,ubuntu,roboto,noto,segoe ui,arial,sans-serif}.code{font-family:Consolas,monaco,monospace}.b{font-weight:700}.fw3{font-weight:300}.fw4{font-weight:400}.fw5{font-weight:500}.fw6{font-weight:600}.lh-copy{line-height:1.5}.link{text-decoration:none}.link,.link:active,.link:focus,.link:hover,.link:link,.link:visited{transition:color .15s ease-in}.link:focus{outline:1px dotted currentColor}.mw-100{max-width:100%}.mw4{max-width:8rem}.mw7{max-width:48rem}.bg-light-gray{background-color:#f7f7f7}.link-hover:hover{background-color:#1f679e}.white{color:#fff}.bg-white{background-color:#fff}.bg-blue{background-color:#408bc9}.pb2{padding-bottom:.5rem}.pb6{padding-bottom:8rem}.pt3{padding-top:1rem}.pt5{padding-top:4rem}.pv2{padding-top:.5rem;padding-bottom:.5rem}.ph3{padding-left:1rem;padding-right:1rem}.ph4{padding-left:2rem;padding-right:2rem}.ml0{margin-left:0}.mb1{margin-bottom:.25rem}.mb2{margin-bottom:.5rem}.mb3{margin-bottom:1rem}.mt5{margin-top:4rem}.ttu{text-transform:uppercase}.f4{font-size:1.25rem}.f5{font-size:1rem}.f6{font-size:.875rem}.f7{font-size:.75rem}.measure{max-width:30em}.center{margin-left:auto}.center{margin-right:auto}@media screen and (min-width:30em){.f2-ns{font-size:2.25rem}}@media screen and (min-width:30em) and (max-width:60em){.f5-m{font-size:1rem}}@media screen and (min-width:60em){.f4-l{font-size:1.25rem}}
.st0{fill:#FFF}.st1{fill:#f48120}.st2{fill:#faad3f}.st3{fill:#404041}
</style>
</head>
<body class="sans-serif black">
<div class="bt bw2 b--orange bg-white pb6">
<div class="mw7 center ph4 pt3">
<svg id="Layer_2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 109 40.5" class="mw4">
<path class="st0" d="M98.6 14.2L93 12.9l-1-.4-25.7.2v12.4l32.3.1z"/>
<path class="st1" d="M88.1 24c.3-1 .2-2-.3-2.6-.5-.6-1.2-1-2.1-1.1l-17.4-.2c-.1 0-.2-.1-.3-.1-.1-.1-.1-.2 0-.3.1-.2.2-.3.4-.3l17.5-.2c2.1-.1 4.3-1.8 5.1-3.8l1-2.6c0-.1.1-.2 0-.3-1.1-5.1-5.7-8.9-11.1-8.9-5 0-9.3 3.2-10.8 7.7-1-.7-2.2-1.1-3.6-1-2.4.2-4.3 2.2-4.6 4.6-.1.6 0 1.2.1 1.8-3.9.1-7.1 3.3-7.1 7.3 0 .4 0 .7.1 1.1 0 .2.2.3.3.3h32.1c.2 0 .4-.1.4-.3l.3-1.1z"/>
<path class="st2" d="M93.6 12.8h-.5c-.1 0-.2.1-.3.2l-.7 2.4c-.3 1-.2 2 .3 2.6.5.6 1.2 1 2.1 1.1l3.7.2c.1 0 .2.1.3.1.1.1.1.2 0 .3-.1.2-.2.3-.4.3l-3.8.2c-2.1.1-4.3 1.8-5.1 3.8l-.2.9c-.1.1 0 .3.2.3h13.2c.2 0 .3-.1.3-.3.2-.8.4-1.7.4-2.6 0-5.2-4.3-9.5-9.5-9.5"/>
<path class="st3" d="M104.4 30.8c-.5 0-.9-.4-.9-.9s.4-.9.9-.9.9.4.9.9-.4.9-.9.9m0-1.6c-.4 0-.7.3-.7.7 0 .4.3.7.7.7.4 0 .7-.3.7-.7 0-.4-.3-.7-.7-.7m.4 1.2h-.2l-.2-.3h-.2v.3h-.2v-.9h.5c.2 0 .3.1.3.3 0 .1-.1.2-.2.3l.2.3zm-.3-.5c.1 0 .1 0 .1-.1s-.1-.1-.1-.1h-.3v.3h.3zM14.8 29H17v6h3.8v1.9h-6zM23.1 32.9c0-2.3 1.8-4.1 4.3-4.1s4.2 1.8 4.2 4.1-1.8 4.1-4.3 4.1c-2.4 0-4.2-1.8-4.2-4.1m6.3 0c0-1.2-.8-2.2-2-2.2s-2 1-2 2.1.8 2.1 2 2.1c1.2.2 2-.8 2-2M34.3 33.4V29h2.2v4.4c0 1.1.6 1.7 1.5 1.7s1.5-.5 1.5-1.6V29h2.2v4.4c0 2.6-1.5 3.7-3.7 3.7-2.3-.1-3.7-1.2-3.7-3.7M45 29h3.1c2.8 0 4.5 1.6 4.5 3.9s-1.7 4-4.5 4h-3V29zm3.1 5.9c1.3 0 2.2-.7 2.2-2s-.9-2-2.2-2h-.9v4h.9zM55.7 29H62v1.9h-4.1v1.3h3.7V34h-3.7v2.9h-2.2zM65.1 29h2.2v6h3.8v1.9h-6zM76.8 28.9H79l3.4 8H80l-.6-1.4h-3.1l-.6 1.4h-2.3l3.4-8zm2 4.9l-.9-2.2-.9 2.2h1.8zM85.2 29h3.7c1.2 0 2 .3 2.6.9.5.5.7 1.1.7 1.8 0 1.2-.6 2-1.6 2.4l1.9 2.8H90l-1.6-2.4h-1v2.4h-2.2V29zm3.6 3.8c.7 0 1.2-.4 1.2-.9 0-.6-.5-.9-1.2-.9h-1.4v1.9h1.4zM95.3 29h6.4v1.8h-4.2V32h3.8v1.8h-3.8V35h4.3v1.9h-6.5zM10 33.9c-.3.7-1 1.2-1.8 1.2-1.2 0-2-1-2-2.1s.8-2.1 2-2.1c.9 0 1.6.6 1.9 1.3h2.3c-.4-1.9-2-3.3-4.2-3.3-2.4 0-4.3 1.8-4.3 4.1s1.8 4.1 4.2 4.1c2.1 0 3.7-1.4 4.2-3.2H10z"/>
</svg>
<h1 class="f4 f2-ns mt5 fw5">Congrats! You created your first tunnel!</h1>
<p class="f6 f5-m f4-l measure lh-copy fw3">
Argo Tunnel exposes locally running applications to the internet by
running an encrypted, virtual tunnel from your laptop or server to
Cloudflare's edge network.
</p>
<p class="b f5 mt5 fw6">Ready for the next step?</p>
<a
class="fw6 link white bg-blue ph4 pv2 br1 dib f5 link-hover"
style="border-bottom: 1px solid #1f679e"
href="https://developers.cloudflare.com/argo-tunnel/">
Get started here
</a>
<section>
<h4 class="f6 fw4 pt5 mb2">Request</h4>
<dl class="bl bw2 b--orange ph3 pt3 pb2 bg-light-gray f7 code overflow-x-auto mw-100">
<dd class="ml0 mb3 f5">Method: {{.Request.Method}}</dd>
<dd class="ml0 mb3 f5">Protocol: {{.Request.Proto}}</dd>
<dd class="ml0 mb3 f5">Request URL: {{.Request.URL}}</dd>
<dd class="ml0 mb3 f5">Transfer encoding: {{.Request.TransferEncoding}}</dd>
<dd class="ml0 mb3 f5">Host: {{.Request.Host}}</dd>
<dd class="ml0 mb3 f5">Remote address: {{.Request.RemoteAddr}}</dd>
<dd class="ml0 mb3 f5">Request URI: {{.Request.RequestURI}}</dd>
{{range $key, $value := .Request.Header}}
<dd class="ml0 mb3 f5">Header: {{$key}}, Value: {{$value}}</dd>
{{end}}
<dd class="ml0 mb3 f5">Body: {{.Body}}</dd>
</dl>
</section>
</div>
</div>
</body>
</html>
`
func hello(c *cli.Context) error {
address := fmt.Sprintf(":%d", c.Int("port"))
listener, err := createListener(address)
if err != nil {
return err
}
defer listener.Close()
err = startHelloWorldServer(listener, nil)
return err
}
func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
Log.Infof("Starting Hello World server at %s", listener.Addr())
serverName := defaultServerName
if hostname, err := os.Hostname(); err == nil {
serverName = hostname
}
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
go func() {
<-shutdownC
httpServer.Close()
}()
http.HandleFunc("/uptime", uptimeHandler(time.Now()))
http.HandleFunc("/ws", websocketHandler(upgrader))
http.HandleFunc("/", rootHandler(serverName))
err := httpServer.Serve(listener)
return err
}
func createListener(address string) (net.Listener, error) {
certificate, err := tlsconfig.GetHelloCertificate()
if err != nil {
return nil, err
}
// If the port in address is empty, a port number is automatically chosen
listener, err := tls.Listen(
"tcp",
address,
&tls.Config{Certificates: []tls.Certificate{certificate}})
return listener, err
}
func uptimeHandler(startTime time.Time) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Note that if autoupdate is enabled, the uptime is reset when a new client
// release is available
resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()}
respJson, err := json.Marshal(resp)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
} else {
w.Header().Set("Content-Type", "application/json")
w.Write(respJson)
}
}
}
func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
if err := conn.WriteMessage(mt, message); err != nil {
break
}
}
}
}
func rootHandler(serverName string) http.HandlerFunc {
responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
return func(w http.ResponseWriter, r *http.Request) {
var buffer bytes.Buffer
var body string
rawBody, err := ioutil.ReadAll(r.Body)
if err == nil {
body = string(rawBody)
} else {
body = ""
}
err = responseTemplate.Execute(&buffer, &templateData{
ServerName: serverName,
Request: r,
Body: body,
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "error: %v", err)
} else {
buffer.WriteTo(w)
}
}
}

View File

@ -0,0 +1,35 @@
package main
import (
"testing"
)
func TestCreateListenerHostAndPortSuccess(t *testing.T) {
listener, err := createListener("localhost:1234")
if err != nil {
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}
func TestCreateListenerOnlyHostSuccess(t *testing.T) {
listener, err := createListener("localhost:")
if err != nil {
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}
func TestCreateListenerOnlyPortSuccess(t *testing.T) {
listener, err := createListener(":8888")
if err != nil {
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}

View File

@ -0,0 +1,292 @@
// +build linux
package main
import (
"fmt"
"os"
"path/filepath"
cli "gopkg.in/urfave/cli.v2"
)
func runApp(app *cli.App) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel system service",
Subcommands: []*cli.Command{
&cli.Command{
Name: "install",
Usage: "Install Argo Tunnel as a system service",
Action: installLinuxService,
},
&cli.Command{
Name: "uninstall",
Usage: "Uninstall the Argo Tunnel service",
Action: uninstallLinuxService,
},
},
})
app.Run(os.Args)
}
const serviceConfigDir = "/etc/cloudflared"
var systemdTemplates = []ServiceTemplate{
{
Path: "/etc/systemd/system/cloudflared.service",
Content: `[Unit]
Description=Argo Tunnel
After=network.target
[Service]
TimeoutStartSec=0
Type=notify
ExecStart={{ .Path }} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --no-autoupdate
Restart=on-failure
RestartSec=5s
[Install]
WantedBy=multi-user.target
`,
},
{
Path: "/etc/systemd/system/cloudflared-update.service",
Content: `[Unit]
Description=Update Argo Tunnel
After=network.target
[Service]
ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 64 ]; then systemctl restart cloudflared; exit 0; fi; exit $code'
`,
},
{
Path: "/etc/systemd/system/cloudflared-update.timer",
Content: `[Unit]
Description=Update Argo Tunnel
[Timer]
OnUnitActiveSec=1d
[Install]
WantedBy=timers.target
`,
},
}
var sysvTemplate = ServiceTemplate{
Path: "/etc/init.d/cloudflared",
FileMode: 0755,
Content: `# For RedHat and cousins:
# chkconfig: 2345 99 01
# description: Argo Tunnel agent
# processname: {{.Path}}
### BEGIN INIT INFO
# Provides: {{.Path}}
# Required-Start:
# Required-Stop:
# Default-Start: 2 3 4 5
# Default-Stop: 0 1 6
# Short-Description: Argo Tunnel
# Description: Argo Tunnel agent
### END INIT INFO
cmd="{{.Path}} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --pidfile /var/run/$name.pid --autoupdate-freq 24h0m0s"
name=$(basename $(readlink -f $0))
pid_file="/var/run/$name.pid"
stdout_log="/var/log/$name.log"
stderr_log="/var/log/$name.err"
[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
get_pid() {
cat "$pid_file"
}
is_running() {
[ -f "$pid_file" ] && ps $(get_pid) > /dev/null 2>&1
}
case "$1" in
start)
if is_running; then
echo "Already started"
else
echo "Starting $name"
$cmd >> "$stdout_log" 2>> "$stderr_log" &
echo $! > "$pid_file"
if ! is_running; then
echo "Unable to start, see $stdout_log and $stderr_log"
exit 1
fi
fi
;;
stop)
if is_running; then
echo -n "Stopping $name.."
kill $(get_pid)
for i in {1..10}
do
if ! is_running; then
break
fi
echo -n "."
sleep 1
done
echo
if is_running; then
echo "Not stopped; may still be shutting down or shutdown may have failed"
exit 1
else
echo "Stopped"
if [ -f "$pid_file" ]; then
rm "$pid_file"
fi
fi
else
echo "Not running"
fi
;;
restart)
$0 stop
if is_running; then
echo "Unable to stop, will not attempt to start"
exit 1
fi
$0 start
;;
status)
if is_running; then
echo "Running"
else
echo "Stopped"
exit 1
fi
;;
*)
echo "Usage: $0 {start|stop|restart|status}"
exit 1
;;
esac
exit 0
`,
}
func isSystemd() bool {
if _, err := os.Stat("/run/systemd/system"); err == nil {
return true
}
return false
}
func installLinuxService(c *cli.Context) error {
etPath, err := os.Executable()
if err != nil {
return fmt.Errorf("error determining executable path: %v", err)
}
templateArgs := ServiceTemplateArgs{Path: etPath}
defaultConfigDir := filepath.Dir(c.String("config"))
defaultConfigFile := filepath.Base(c.String("config"))
if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile); err != nil {
Log.WithError(err).Infof("Failed to copy user configuration. Before running the service, ensure that %s contains two files, %s and %s",
serviceConfigDir, credentialFile, defaultConfigFiles[0])
return err
}
switch {
case isSystemd():
Log.Infof("Using Systemd")
return installSystemd(&templateArgs)
default:
Log.Infof("Using Sysv")
return installSysv(&templateArgs)
}
}
func installSystemd(templateArgs *ServiceTemplateArgs) error {
for _, serviceTemplate := range systemdTemplates {
err := serviceTemplate.Generate(templateArgs)
if err != nil {
Log.WithError(err).Infof("error generating service template")
return err
}
}
if err := runCommand("systemctl", "enable", "cloudflared.service"); err != nil {
Log.WithError(err).Infof("systemctl enable cloudflared.service error")
return err
}
if err := runCommand("systemctl", "start", "cloudflared-update.timer"); err != nil {
Log.WithError(err).Infof("systemctl start cloudflared-update.timer error")
return err
}
Log.Infof("systemctl daemon-reload")
return runCommand("systemctl", "daemon-reload")
}
func installSysv(templateArgs *ServiceTemplateArgs) error {
confPath, err := sysvTemplate.ResolvePath()
if err != nil {
Log.WithError(err).Infof("error resolving system path")
return err
}
if err := sysvTemplate.Generate(templateArgs); err != nil {
Log.WithError(err).Infof("error generating system template")
return err
}
for _, i := range [...]string{"2", "3", "4", "5"} {
if err := os.Symlink(confPath, "/etc/rc"+i+".d/S50et"); err != nil {
continue
}
}
for _, i := range [...]string{"0", "1", "6"} {
if err := os.Symlink(confPath, "/etc/rc"+i+".d/K02et"); err != nil {
continue
}
}
return nil
}
func uninstallLinuxService(c *cli.Context) error {
switch {
case isSystemd():
Log.Infof("Using Systemd")
return uninstallSystemd()
default:
Log.Infof("Using Sysv")
return uninstallSysv()
}
}
func uninstallSystemd() error {
if err := runCommand("systemctl", "disable", "cloudflared.service"); err != nil {
Log.WithError(err).Infof("systemctl disable cloudflared.service error")
return err
}
if err := runCommand("systemctl", "stop", "cloudflared-update.timer"); err != nil {
Log.WithError(err).Infof("systemctl stop cloudflared-update.timer error")
return err
}
for _, serviceTemplate := range systemdTemplates {
if err := serviceTemplate.Remove(); err != nil {
Log.WithError(err).Infof("error removing service template")
return err
}
}
Log.Infof("Successfully uninstall cloudflared service")
return nil
}
func uninstallSysv() error {
if err := sysvTemplate.Remove(); err != nil {
Log.WithError(err).Infof("error removing service template")
return err
}
for _, i := range [...]string{"2", "3", "4", "5"} {
if err := os.Remove("/etc/rc" + i + ".d/S50et"); err != nil {
continue
}
}
for _, i := range [...]string{"0", "1", "6"} {
if err := os.Remove("/etc/rc" + i + ".d/K02et"); err != nil {
continue
}
}
Log.Infof("Successfully uninstall cloudflared service")
return nil
}

194
cmd/cloudflared/login.go Normal file
View File

@ -0,0 +1,194 @@
package main
import (
"crypto/rand"
"encoding/base32"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"syscall"
"time"
homedir "github.com/mitchellh/go-homedir"
cli "gopkg.in/urfave/cli.v2"
)
const baseLoginURL = "https://www.cloudflare.com/a/warp"
const baseCertStoreURL = "https://login.cloudflarewarp.com"
const clientTimeout = time.Minute * 20
func login(c *cli.Context) error {
configPath, err := homedir.Expand(defaultConfigDirs[0])
if err != nil {
return err
}
ok, err := fileExists(configPath)
if !ok && err == nil {
// create config directory if doesn't already exist
err = os.Mkdir(configPath, 0700)
}
if err != nil {
return err
}
path := filepath.Join(configPath, credentialFile)
fileInfo, err := os.Stat(path)
if err == nil && fileInfo.Size() > 0 {
fmt.Fprintf(os.Stderr, `You have an existing certificate at %s which login would overwrite.
If this is intentional, please move or delete that file then run this command again.
`, path)
return nil
}
if err != nil && err.(*os.PathError).Err != syscall.ENOENT {
return err
}
// for local debugging
baseURL := baseCertStoreURL
if c.IsSet("url") {
baseURL = c.String("url")
}
// Generate a random post URL
certURL := baseURL + generateRandomPath()
loginURL, err := url.Parse(baseLoginURL)
if err != nil {
// shouldn't happen, URL is hardcoded
return err
}
loginURL.RawQuery = "callback=" + url.QueryEscape(certURL)
err = open(loginURL.String())
if err != nil {
fmt.Fprintf(os.Stderr, `Please open the following URL and log in with your Cloudflare account:
%s
Leave cloudflared running to install the certificate automatically.
`, loginURL.String())
} else {
fmt.Fprintf(os.Stderr, `A browser window should have opened at the following URL:
%s
If the browser failed to open, open it yourself and visit the URL above.
`, loginURL.String())
}
if download(certURL, path) {
fmt.Fprintf(os.Stderr, `You have successfully logged in.
If you wish to copy your credentials to a server, they have been saved to:
%s
`, path)
} else {
fmt.Fprintf(os.Stderr, `Failed to write the certificate due to the following error:
%v
Your browser will download the certificate instead. You will have to manually
copy it to the following path:
%s
`, err, path)
}
return nil
}
// generateRandomPath generates a random URL to associate with the certificate.
func generateRandomPath() string {
randomBytes := make([]byte, 40)
_, err := rand.Read(randomBytes)
if err != nil {
panic(err)
}
return "/" + base32.StdEncoding.EncodeToString(randomBytes)
}
// open opens the specified URL in the default browser of the user.
func open(url string) error {
var cmd string
var args []string
switch runtime.GOOS {
case "windows":
cmd = "cmd"
args = []string{"/c", "start"}
case "darwin":
cmd = "open"
default: // "linux", "freebsd", "openbsd", "netbsd"
cmd = "xdg-open"
}
args = append(args, url)
return exec.Command(cmd, args...).Start()
}
func download(certURL, filePath string) bool {
client := &http.Client{Timeout: clientTimeout}
// attempt a (long-running) certificate get
for i := 0; i < 20; i++ {
ok, err := tryDownload(client, certURL, filePath)
if ok {
putSuccess(client, certURL)
return true
}
if err != nil {
Log.WithError(err).Error("Error fetching certificate")
return false
}
}
return false
}
func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err error) {
resp, err := client.Get(certURL)
if err != nil {
return false, err
}
defer resp.Body.Close()
if resp.StatusCode == 404 {
return false, nil
}
if resp.StatusCode != 200 {
return false, fmt.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
}
if resp.Header.Get("Content-Type") != "application/x-pem-file" {
return false, fmt.Errorf("Unexpected content type %s", resp.Header.Get("Content-Type"))
}
// write response
file, err := os.Create(filePath)
if err != nil {
return false, err
}
defer file.Close()
written, err := io.Copy(file, resp.Body)
switch {
case err != nil:
return false, err
case resp.ContentLength != written && resp.ContentLength != -1:
return false, fmt.Errorf("Short read (%d bytes) from server while writing certificate", written)
default:
return true, nil
}
}
func putSuccess(client *http.Client, certURL string) {
// indicate success to the relay server
req, err := http.NewRequest("PUT", certURL+"/ok", nil)
if err != nil {
Log.WithError(err).Error("HTTP request error")
return
}
resp, err := client.Do(req)
if err != nil {
Log.WithError(err).Error("HTTP error")
return
}
resp.Body.Close()
if resp.StatusCode != 200 {
Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
}
}

View File

@ -0,0 +1,97 @@
// +build darwin
package main
import (
"fmt"
"os"
"gopkg.in/urfave/cli.v2"
)
const launchAgentIdentifier = "com.cloudflare.cloudflared"
func runApp(app *cli.App) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel launch agent",
Subcommands: []*cli.Command{
{
Name: "install",
Usage: "Install Argo Tunnel as an user launch agent",
Action: installLaunchd,
},
{
Name: "uninstall",
Usage: "Uninstall the Argo Tunnel launch agent",
Action: uninstallLaunchd,
},
},
})
app.Run(os.Args)
}
var launchdTemplate = ServiceTemplate{
Path: fmt.Sprintf("~/Library/LaunchAgents/%s.plist", launchAgentIdentifier),
Content: fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>%s</string>
<key>Program</key>
<string>{{ .Path }}</string>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/tmp/%s.out.log</string>
<key>StandardErrorPath</key>
<string>/tmp/%s.err.log</string>
<key>KeepAlive</key>
<dict>
<key>NetworkState</key>
<true/>
</dict>
<key>ThrottleInterval</key>
<integer>20</integer>
</dict>
</plist>`, launchAgentIdentifier, launchAgentIdentifier, launchAgentIdentifier),
}
func installLaunchd(c *cli.Context) error {
Log.Infof("Installing Argo Tunnel as an user launch agent")
etPath, err := os.Executable()
if err != nil {
Log.WithError(err).Infof("error determining executable path")
return fmt.Errorf("error determining executable path: %v", err)
}
templateArgs := ServiceTemplateArgs{Path: etPath}
err = launchdTemplate.Generate(&templateArgs)
if err != nil {
Log.WithError(err).Infof("error generating launchd template")
return err
}
plistPath, err := launchdTemplate.ResolvePath()
if err != nil {
Log.WithError(err).Infof("error resolving launchd template path")
return err
}
Log.Infof("Outputs are logged in %s and %s", fmt.Sprintf("/tmp/%s.out.log", launchAgentIdentifier), fmt.Sprintf("/tmp/%s.err.log", launchAgentIdentifier))
return runCommand("launchctl", "load", plistPath)
}
func uninstallLaunchd(c *cli.Context) error {
Log.Infof("Uninstalling Argo Tunnel as an user launch agent")
plistPath, err := launchdTemplate.ResolvePath()
if err != nil {
Log.WithError(err).Infof("error resolving launchd template path")
return err
}
err = runCommand("launchctl", "unload", plistPath)
if err != nil {
Log.WithError(err).Infof("error unloading")
return err
}
Log.Infof("Outputs are logged in %s and %s", fmt.Sprintf("/tmp/%s.out.log", launchAgentIdentifier), fmt.Sprintf("/tmp/%s.err.log", launchAgentIdentifier))
return launchdTemplate.Remove()
}

794
cmd/cloudflared/main.go Normal file
View File

@ -0,0 +1,794 @@
package main
import (
"crypto/tls"
"encoding/hex"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"sync"
"syscall"
"time"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
"github.com/facebookgo/grace/gracenet"
"github.com/getsentry/raven-go"
"github.com/mitchellh/go-homedir"
"github.com/rifflock/lfshook"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh/terminal"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
"github.com/coreos/go-systemd/daemon"
"github.com/pkg/errors"
)
const (
sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
credentialFile = "cert.pem"
quickStartUrl = "https://developers.cloudflare.com/argo-tunnel/quickstart/quickstart/"
noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/"
licenseUrl = "https://developers.cloudflare.com/argo-tunnel/licence/"
)
var listeners = gracenet.Net{}
var Version = "DEV"
var BuildTime = "unknown"
var Log *logrus.Logger
var defaultConfigFiles = []string{"config.yml", "config.yaml"}
// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
var defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp"}
// Shutdown channel used by the app. When closed, app must terminate.
// May be closed by the Windows service runner.
var shutdownC chan struct{}
type BuildAndRuntimeInfo struct {
GoOS string `json:"go_os"`
GoVersion string `json:"go_version"`
GoArch string `json:"go_arch"`
WarpVersion string `json:"warp_version"`
WarpFlags map[string]interface{} `json:"warp_flags"`
WarpEnvs map[string]string `json:"warp_envs"`
}
func main() {
metrics.RegisterBuildInfo(BuildTime, Version)
raven.SetDSN(sentryDSN)
raven.SetRelease(Version)
shutdownC = make(chan struct{})
app := &cli.App{}
app.Name = "cloudflared"
app.Copyright = fmt.Sprintf(`(c) %d Cloudflare Inc.
Use is subject to the license agreement at %s`, time.Now().Year(), licenseUrl)
app.Usage = "Cloudflare reverse tunnelling proxy agent"
app.ArgsUsage = "origin-url"
app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime)
app.Description = `A reverse tunnel proxy agent that connects to Cloudflare's infrastructure.
Upon connecting, you are assigned a unique subdomain on cftunnel.com.
You need to specify a hostname on a zone you control.
A DNS record will be created to CNAME your hostname to the unique subdomain on cftunnel.com.
Requests made to Cloudflare's servers for your hostname will be proxied
through the tunnel to your local webserver.`
app.Flags = []cli.Flag{
&cli.StringFlag{
Name: "config",
Usage: "Specifies a config file in YAML format.",
Value: findDefaultConfigPath(),
},
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "autoupdate-freq",
Usage: "Autoupdate frequency. Default is 24h.",
Value: time.Hour * 24,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "no-autoupdate",
Usage: "Disable periodic check for updates, restarting the server with the new version.",
Value: false,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "is-autoupdated",
Usage: "Signal the new process that Argo Tunnel client has been autoupdated",
Value: false,
Hidden: true,
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "edge",
Usage: "Address of the Cloudflare tunnel server.",
EnvVars: []string{"TUNNEL_EDGE"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "cacert",
Usage: "Certificate Authority authenticating the Cloudflare tunnel connection.",
EnvVars: []string{"TUNNEL_CACERT"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "origincert",
Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: findDefaultOriginCertPath(),
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "url",
Value: "https://localhost:8080",
Usage: "Connect to the local webserver at `URL`.",
EnvVars: []string{"TUNNEL_URL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "hostname",
Usage: "Set a hostname on a Cloudflare zone to route traffic through this tunnel.",
EnvVars: []string{"TUNNEL_HOSTNAME"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "origin-server-name",
Usage: "Hostname on the origin server certificate.",
EnvVars: []string{"TUNNEL_ORIGIN_SERVER_NAME"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "id",
Usage: "A unique identifier used to tie connections to this tunnel instance.",
EnvVars: []string{"TUNNEL_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "lb-pool",
Usage: "The name of a (new/existing) load balancing pool to add this origin to.",
EnvVars: []string{"TUNNEL_LB_POOL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "api-key",
Usage: "This parameter has been deprecated since version 2017.10.1.",
EnvVars: []string{"TUNNEL_API_KEY"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "api-email",
Usage: "This parameter has been deprecated since version 2017.10.1.",
EnvVars: []string{"TUNNEL_API_EMAIL"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "api-ca-key",
Usage: "This parameter has been deprecated since version 2017.10.1.",
EnvVars: []string{"TUNNEL_API_CA_KEY"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "metrics",
Value: "localhost:",
Usage: "Listen address for metrics reporting.",
EnvVars: []string{"TUNNEL_METRICS"},
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "metrics-update-freq",
Usage: "Frequency to update tunnel metrics",
Value: time.Second * 5,
EnvVars: []string{"TUNNEL_METRICS_UPDATE_FREQ"},
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "tag",
Usage: "Custom tags used to identify this tunnel, in format `KEY=VALUE`. Multiple tags may be specified",
EnvVars: []string{"TUNNEL_TAG"},
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "heartbeat-interval",
Usage: "Minimum idle time before sending a heartbeat.",
Value: time.Second * 5,
Hidden: true,
}),
altsrc.NewUint64Flag(&cli.Uint64Flag{
Name: "heartbeat-count",
Usage: "Minimum number of unacked heartbeats to send before closing the connection.",
Value: 5,
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "loglevel",
Value: "info",
Usage: "Application logging level {panic, fatal, error, warn, info, debug}",
EnvVars: []string{"TUNNEL_LOGLEVEL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "proto-loglevel",
Value: "warn",
Usage: "Protocol logging level {panic, fatal, error, warn, info, debug}",
EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL"},
}),
altsrc.NewUintFlag(&cli.UintFlag{
Name: "retries",
Value: 5,
Usage: "Maximum number of retries for connection/protocol errors.",
EnvVars: []string{"TUNNEL_RETRIES"},
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "hello-world",
Value: false,
Usage: "Run Hello World Server",
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "pidfile",
Usage: "Write the application's PID to this file after first successful connection.",
EnvVars: []string{"TUNNEL_PIDFILE"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "logfile",
Usage: "Save application log to this file for reporting issues.",
EnvVars: []string{"TUNNEL_LOGFILE"},
}),
altsrc.NewIntFlag(&cli.IntFlag{
Name: "ha-connections",
Value: 4,
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-connect-timeout",
Usage: "HTTP proxy timeout for establishing a new connection",
Value: time.Second * 30,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-tls-timeout",
Usage: "HTTP proxy timeout for completing a TLS handshake",
Value: time.Second * 10,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-tcp-keepalive",
Usage: "HTTP proxy TCP keepalive duration",
Value: time.Second * 30,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "proxy-no-happy-eyeballs",
Usage: "HTTP proxy should disable \"happy eyeballs\" for IPv4/v6 fallback",
}),
altsrc.NewIntFlag(&cli.IntFlag{
Name: "proxy-keepalive-connections",
Usage: "HTTP proxy maximum keepalive connection pool size",
Value: 100,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-keepalive-timeout",
Usage: "HTTP proxy timeout for closing an idle connection",
Value: time.Second * 90,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "proxy-dns",
Usage: "Run a DNS over HTTPS proxy server.",
EnvVars: []string{"TUNNEL_DNS"},
}),
altsrc.NewUintFlag(&cli.UintFlag{
Name: "proxy-dns-port",
Value: 53,
Usage: "Listen on given port for the DNS over HTTPS proxy server.",
EnvVars: []string{"TUNNEL_DNS_PORT"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "proxy-dns-address",
Usage: "Listen address for the DNS over HTTPS proxy server.",
Value: "localhost",
EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "proxy-dns-upstream",
Usage: "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
Value: cli.NewStringSlice("https://dns.cloudflare.com/.well-known/dns-query"),
EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
}),
}
app.Action = func(c *cli.Context) error {
raven.CapturePanic(func() { startServer(c) }, nil)
return nil
}
app.Before = func(context *cli.Context) error {
Log = logrus.New()
inputSource, err := findInputSourceContext(context)
if err != nil {
Log.WithError(err).Infof("Cannot load configuration from %s", context.String("config"))
return err
} else if inputSource != nil {
err := altsrc.ApplyInputSourceValues(context, inputSource, app.Flags)
if err != nil {
Log.WithError(err).Infof("Cannot apply configuration from %s", context.String("config"))
return err
}
Log.Infof("Applied configuration from %s", context.String("config"))
}
return nil
}
app.Commands = []*cli.Command{
{
Name: "update",
Action: update,
Usage: "Update the agent if a new version exists",
ArgsUsage: " ",
Description: `Looks for a new version on the offical download server.
If a new version exists, updates the agent binary and quits.
Otherwise, does nothing.
To determine if an update happened in a script, check for error code 64.`,
},
{
Name: "login",
Action: login,
Usage: "Generate a configuration file with your login details",
ArgsUsage: " ",
Flags: []cli.Flag{
&cli.StringFlag{
Name: "url",
Hidden: true,
},
},
},
{
Name: "hello",
Action: hello,
Usage: "Run a simple \"Hello World\" server for testing Argo Tunnel.",
Flags: []cli.Flag{
&cli.IntFlag{
Name: "port",
Usage: "Listen on the selected port.",
Value: 8080,
},
},
ArgsUsage: " ", // can't be the empty string or we get the default output
},
{
Name: "proxy-dns",
Action: tunneldns.Run,
Usage: "Run a DNS over HTTPS proxy server.",
Flags: []cli.Flag{
&cli.StringFlag{
Name: "metrics",
Value: "localhost:",
Usage: "Listen address for metrics reporting.",
EnvVars: []string{"TUNNEL_METRICS"},
},
&cli.StringFlag{
Name: "address",
Usage: "Listen address for the DNS over HTTPS proxy server.",
Value: "localhost",
EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
},
&cli.IntFlag{
Name: "port",
Usage: "Listen on given port for the DNS over HTTPS proxy server.",
Value: 53,
EnvVars: []string{"TUNNEL_DNS_PORT"},
},
&cli.StringSliceFlag{
Name: "upstream",
Usage: "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
Value: cli.NewStringSlice("https://dns.cloudflare.com/.well-known/dns-query"),
EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
},
},
ArgsUsage: " ", // can't be the empty string or we get the default output
},
}
runApp(app)
}
func startServer(c *cli.Context) {
var wg sync.WaitGroup
errC := make(chan error)
wg.Add(2)
// If the user choose to supply all options through env variables,
// c.NumFlags() == 0 && c.NArg() == 0. For cloudflared to work, the user needs to at
// least provide a hostname.
if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" {
Log.Infof("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl)
cli.ShowAppHelp(c)
return
}
logLevel, err := logrus.ParseLevel(c.String("loglevel"))
if err != nil {
Log.WithError(err).Fatal("Unknown logging level specified")
}
Log.SetLevel(logLevel)
protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
if err != nil {
Log.WithError(err).Fatal("Unknown protocol logging level specified")
}
protoLogger := logrus.New()
protoLogger.Level = protoLogLevel
if c.String("logfile") != "" {
if err := initLogFile(c, protoLogger); err != nil {
Log.Error(err)
}
}
if isAutoupdateEnabled(c) {
if initUpdate() {
return
}
Log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
}
hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil {
Log.WithError(err).Fatal("Invalid hostname")
}
clientID := c.String("id")
if !c.IsSet("id") {
clientID = generateRandomClientID()
}
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil {
Log.WithError(err).Fatal("Tag parse failure")
}
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
if c.IsSet("hello-world") {
wg.Add(1)
listener, err := createListener("127.0.0.1:")
if err != nil {
listener.Close()
Log.WithError(err).Fatal("Cannot start Hello World Server")
}
go func() {
startHelloWorldServer(listener, shutdownC)
wg.Done()
listener.Close()
}()
c.Set("url", "https://"+listener.Addr().String())
}
if c.IsSet("proxy-dns") {
wg.Add(1)
listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(c.Uint("proxy-dns-port")), c.StringSlice("proxy-dns-upstream"))
if err != nil {
listener.Stop()
Log.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server")
}
go func() {
listener.Start()
<-shutdownC
listener.Stop()
wg.Done()
}()
}
url, err := validateUrl(c)
if err != nil {
Log.WithError(err).Fatal("Error validating url")
}
Log.Infof("Proxying tunnel requests to %s", url)
// Fail if the user provided an old authentication method
if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflared login")
}
// Check that the user has acquired a certificate using the log in command
originCertPath, err := homedir.Expand(c.String("origincert"))
if err != nil {
Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
}
ok, err := fileExists(originCertPath)
if err != nil {
Log.Fatalf("Cannot check if origin cert exists at path %s", c.String("origincert"))
}
if !ok {
Log.Fatalf(`Cannot find a valid certificate for your origin at the path:
%s
If the path above is wrong, specify the path with the -origincert option.
If you don't have a certificate signed by Cloudflare, run the command:
%s login
`, originCertPath, os.Args[0])
}
// Easier to send the certificate as []byte via RPC than decoding it at this point
originCert, err := ioutil.ReadFile(originCertPath)
if err != nil {
Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
}
tunnelMetrics := origin.NewTunnelMetrics()
httpTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: c.Duration("proxy-connect-timeout"),
KeepAlive: c.Duration("proxy-tcp-keepalive"),
DualStack: !c.Bool("proxy-no-happy-eyeballs"),
}).DialContext,
MaxIdleConns: c.Int("proxy-keepalive-connections"),
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
}
if !c.IsSet("hello-world") && c.IsSet("origin-server-name") {
httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
}
tunnelConfig := &origin.TunnelConfig{
EdgeAddrs: c.StringSlice("edge"),
OriginUrl: url,
Hostname: hostname,
OriginCert: originCert,
TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
ClientTlsConfig: httpTransport.TLSClientConfig,
Retries: c.Uint("retries"),
HeartbeatInterval: c.Duration("heartbeat-interval"),
MaxHeartbeats: c.Uint64("heartbeat-count"),
ClientID: clientID,
ReportedVersion: Version,
LBPool: c.String("lb-pool"),
Tags: tags,
HAConnections: c.Int("ha-connections"),
HTTPTransport: httpTransport,
Metrics: tunnelMetrics,
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
ProtocolLogger: protoLogger,
Logger: Log,
IsAutoupdated: c.Bool("is-autoupdated"),
}
connectedSignal := make(chan struct{})
go writePidFile(connectedSignal, c.String("pidfile"))
go func() {
errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
wg.Done()
}()
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
if err != nil {
Log.WithError(err).Fatal("Error opening metrics server listener")
}
go func() {
errC <- metrics.ServeMetrics(metricsListener, shutdownC)
wg.Done()
}()
var errCode int
err = WaitForSignal(errC, shutdownC)
if err != nil {
Log.WithError(err).Error("Quitting due to error")
raven.CaptureErrorAndWait(err, nil)
errCode = 1
} else {
Log.Info("Quitting...")
}
// Wait for clean exit, discarding all errors
go func() {
for range errC {
}
}()
wg.Wait()
os.Exit(errCode)
}
func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signals)
select {
case err := <-errC:
close(shutdownC)
return err
case <-signals:
close(shutdownC)
case <-shutdownC:
}
return nil
}
func update(_ *cli.Context) error {
if updateApplied() {
os.Exit(64)
}
return nil
}
func initUpdate() bool {
if updateApplied() {
os.Args = append(os.Args, "--is-autoupdated=true")
if _, err := listeners.StartProcess(); err != nil {
Log.WithError(err).Error("Unable to restart server automatically")
return false
}
return true
}
return false
}
func autoupdate(freq time.Duration, shutdownC chan struct{}) {
for {
if updateApplied() {
os.Args = append(os.Args, "--is-autoupdated=true")
if _, err := listeners.StartProcess(); err != nil {
Log.WithError(err).Error("Unable to restart server automatically")
}
close(shutdownC)
return
}
time.Sleep(freq)
}
}
func updateApplied() bool {
releaseInfo := checkForUpdates()
if releaseInfo.Updated {
Log.Infof("Updated to version %s", releaseInfo.Version)
return true
}
if releaseInfo.Error != nil {
Log.WithError(releaseInfo.Error).Error("Update check failed")
}
return false
}
func fileExists(path string) (bool, error) {
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
// ignore missing files
return false, nil
}
return false, err
}
f.Close()
return true, nil
}
// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs
// (differs by OS for legacy reasons) contains a cert.pem file, return empty string
func findDefaultOriginCertPath() string {
for _, defaultConfigDir := range defaultConfigDirs {
originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, credentialFile))
if ok, _ := fileExists(originCertPath); ok {
return originCertPath
}
}
return ""
}
// returns the firt path that contains a config file. If none of the combination of
// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles
// contains a config file, return empty string
func findDefaultConfigPath() string {
for _, configDir := range defaultConfigDirs {
for _, configFile := range defaultConfigFiles {
dirPath, err := homedir.Expand(configDir)
if err != nil {
return ""
}
path := filepath.Join(dirPath, configFile)
if ok, _ := fileExists(path); ok {
return path
}
}
}
return ""
}
func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
if context.String("config") != "" {
return altsrc.NewYamlSourceFromFile(context.String("config"))
}
return nil, nil
}
func generateRandomClientID() string {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
id := make([]byte, 32)
r.Read(id)
return hex.EncodeToString(id)
}
func writePidFile(waitForSignal chan struct{}, pidFile string) {
<-waitForSignal
daemon.SdNotify(false, "READY=1")
if pidFile == "" {
return
}
file, err := os.Create(pidFile)
if err != nil {
Log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
}
defer file.Close()
fmt.Fprintf(file, "%d", os.Getpid())
}
// validate url. It can be either from --url or argument
func validateUrl(c *cli.Context) (string, error) {
var url = c.String("url")
if c.NArg() > 0 {
if c.IsSet("url") {
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
}
url = c.Args().Get(0)
}
validUrl, err := validation.ValidateUrl(url)
return validUrl, err
}
func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error {
filePath, err := homedir.Expand(c.String("logfile"))
if err != nil {
return errors.Wrap(err, "Cannot resolve logfile path")
}
fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC
// do not truncate log file if the client has been autoupdated
if c.Bool("is-autoupdated") {
fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE
}
f, err := os.OpenFile(filePath, fileMode, 0664)
if err != nil {
errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath))
}
defer f.Close()
pathMap := lfshook.PathMap{
logrus.InfoLevel: filePath,
logrus.ErrorLevel: filePath,
logrus.FatalLevel: filePath,
logrus.PanicLevel: filePath,
}
Log.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
flags := make(map[string]interface{})
envs := make(map[string]string)
for _, flag := range c.LocalFlagNames() {
flags[flag] = c.Generic(flag)
}
// Find env variables for Argo Tunnel
for _, env := range os.Environ() {
// All Argo Tunnel env variables start with TUNNEL_
if strings.Contains(env, "TUNNEL_") {
vars := strings.Split(env, "=")
if len(vars) == 2 {
envs[vars[0]] = vars[1]
}
}
}
Log.Infof("Argo Tunnel build and runtime configuration: %+v", BuildAndRuntimeInfo{
GoOS: runtime.GOOS,
GoVersion: runtime.Version(),
GoArch: runtime.GOARCH,
WarpVersion: Version,
WarpFlags: flags,
WarpEnvs: envs,
})
return nil
}
func isAutoupdateEnabled(c *cli.Context) bool {
if terminal.IsTerminal(int(os.Stdout.Fd())) {
Log.Info(noAutoupdateMessage)
return false
}
return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0
}

View File

@ -0,0 +1,192 @@
package main
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"text/template"
"github.com/mitchellh/go-homedir"
)
type ServiceTemplate struct {
Path string
Content string
FileMode os.FileMode
}
type ServiceTemplateArgs struct {
Path string
}
func (st *ServiceTemplate) ResolvePath() (string, error) {
resolvedPath, err := homedir.Expand(st.Path)
if err != nil {
return "", fmt.Errorf("error resolving path %s: %v", st.Path, err)
}
return resolvedPath, nil
}
func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
tmpl, err := template.New(st.Path).Parse(st.Content)
if err != nil {
return fmt.Errorf("error generating %s template: %v", st.Path, err)
}
resolvedPath, err := st.ResolvePath()
if err != nil {
return err
}
var buffer bytes.Buffer
err = tmpl.Execute(&buffer, args)
if err != nil {
return fmt.Errorf("error generating %s: %v", st.Path, err)
}
fileMode := os.FileMode(0644)
if st.FileMode != 0 {
fileMode = st.FileMode
}
err = ioutil.WriteFile(resolvedPath, buffer.Bytes(), fileMode)
if err != nil {
return fmt.Errorf("error writing %s: %v", resolvedPath, err)
}
return nil
}
func (st *ServiceTemplate) Remove() error {
resolvedPath, err := st.ResolvePath()
if err != nil {
return err
}
err = os.Remove(resolvedPath)
if err != nil {
return fmt.Errorf("error deleting %s: %v", resolvedPath, err)
}
return nil
}
func runCommand(command string, args ...string) error {
cmd := exec.Command(command, args...)
stderr, err := cmd.StderrPipe()
if err != nil {
Log.WithError(err).Infof("error getting stderr pipe")
return fmt.Errorf("error getting stderr pipe: %v", err)
}
err = cmd.Start()
if err != nil {
Log.WithError(err).Infof("error starting %s", command)
return fmt.Errorf("error starting %s: %v", command, err)
}
commandErr, _ := ioutil.ReadAll(stderr)
if len(commandErr) > 0 {
Log.Errorf("%s: %s", command, commandErr)
}
err = cmd.Wait()
if err != nil {
Log.WithError(err).Infof("%s returned error", command)
return fmt.Errorf("%s returned with error: %v", command, err)
}
return nil
}
func ensureConfigDirExists(configDir string) error {
ok, err := fileExists(configDir)
if !ok && err == nil {
err = os.Mkdir(configDir, 0700)
}
return err
}
// openFile opens the file at path. If create is set and the file exists, returns nil, true, nil
func openFile(path string, create bool) (file *os.File, exists bool, err error) {
expandedPath, err := homedir.Expand(path)
if err != nil {
return nil, false, err
}
if create {
fileInfo, err := os.Stat(expandedPath)
if err == nil && fileInfo.Size() > 0 {
return nil, true, nil
}
file, err = os.OpenFile(expandedPath, os.O_RDWR|os.O_CREATE, 0600)
} else {
file, err = os.Open(expandedPath)
}
return file, false, err
}
func copyCertificate(srcConfigDir, destConfigDir string) error {
destCredentialPath := filepath.Join(destConfigDir, credentialFile)
destFile, exists, err := openFile(destCredentialPath, true)
if err != nil {
return err
} else if exists {
// credentials already exist, do nothing
return nil
}
defer destFile.Close()
srcCredentialPath := filepath.Join(srcConfigDir, credentialFile)
srcFile, _, err := openFile(srcCredentialPath, false)
if err != nil {
return err
}
defer srcFile.Close()
// Copy certificate
_, err = io.Copy(destFile, srcFile)
if err != nil {
return fmt.Errorf("unable to copy %s to %s: %v", srcCredentialPath, destCredentialPath, err)
}
return nil
}
func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile string) error {
if err := ensureConfigDirExists(serviceConfigDir); err != nil {
return err
}
if err := copyCertificate(defaultConfigDir, serviceConfigDir); err != nil {
return err
}
// Copy or create config
destConfigPath := filepath.Join(serviceConfigDir, defaultConfigFile)
destFile, exists, err := openFile(destConfigPath, true)
if err != nil {
Log.WithError(err).Infof("cannot open %s", destConfigPath)
return err
} else if exists {
// config already exists, do nothing
return nil
}
defer destFile.Close()
srcConfigPath := filepath.Join(defaultConfigDir, defaultConfigFile)
srcFile, _, err := openFile(srcConfigPath, false)
if err != nil {
fmt.Println("Your service needs a config file that at least specifies the hostname option.")
fmt.Println("Type in a hostname now, or leave it blank and create the config file later.")
fmt.Print("Hostname: ")
reader := bufio.NewReader(os.Stdin)
input, _ := reader.ReadString('\n')
if input == "" {
return err
}
fmt.Fprintf(destFile, "hostname: %s\n", input)
} else {
defer srcFile.Close()
_, err = io.Copy(destFile, srcFile)
if err != nil {
return fmt.Errorf("unable to copy %s to %s: %v", srcConfigPath, destConfigPath, err)
}
Log.Infof("Copied %s to %s", srcConfigPath, destConfigPath)
}
return nil
}

32
cmd/cloudflared/tag.go Normal file
View File

@ -0,0 +1,32 @@
package main
import (
"fmt"
"regexp"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// Restrict key names to characters allowed in an HTTP header name.
// Restrict key values to printable characters (what is recognised as data in an HTTP header value).
var tagRegexp = regexp.MustCompile("^([a-zA-Z0-9!#$%&'*+\\-.^_`|~]+)=([[:print:]]+)$")
func NewTagFromCLI(compoundTag string) (tunnelpogs.Tag, bool) {
matches := tagRegexp.FindStringSubmatch(compoundTag)
if len(matches) == 0 {
return tunnelpogs.Tag{}, false
}
return tunnelpogs.Tag{Name: matches[1], Value: matches[2]}, true
}
func NewTagSliceFromCLI(tags []string) ([]tunnelpogs.Tag, error) {
var tagSlice []tunnelpogs.Tag
for _, compoundTag := range tags {
if tag, ok := NewTagFromCLI(compoundTag); ok {
tagSlice = append(tagSlice, tag)
} else {
return nil, fmt.Errorf("Cannot parse tag value %s", compoundTag)
}
}
return tagSlice, nil
}

View File

@ -0,0 +1,46 @@
package main
import (
"testing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/stretchr/testify/assert"
)
func TestSingleTag(t *testing.T) {
testCases := []struct {
Input string
Output tunnelpogs.Tag
Fail bool
}{
{Input: "x=y", Output: tunnelpogs.Tag{Name: "x", Value: "y"}},
{Input: "More-Complex=Tag Values", Output: tunnelpogs.Tag{Name: "More-Complex", Value: "Tag Values"}},
{Input: "First=Equals=Wins", Output: tunnelpogs.Tag{Name: "First", Value: "Equals=Wins"}},
{Input: "x=", Fail: true},
{Input: "=y", Fail: true},
{Input: "=", Fail: true},
{Input: "No spaces allowed=in key names", Fail: true},
{Input: "omg\nwtf=bbq", Fail: true},
}
for i, testCase := range testCases {
tag, ok := NewTagFromCLI(testCase.Input)
assert.Equalf(t, !testCase.Fail, ok, "mismatched success for test case %d", i)
assert.Equalf(t, testCase.Output, tag, "mismatched output for test case %d", i)
}
}
func TestTagSlice(t *testing.T) {
tagSlice, err := NewTagSliceFromCLI([]string{"a=b", "c=d", "e=f"})
assert.NoError(t, err)
assert.Len(t, tagSlice, 3)
assert.Equal(t, "a", tagSlice[0].Name)
assert.Equal(t, "b", tagSlice[0].Value)
assert.Equal(t, "c", tagSlice[1].Name)
assert.Equal(t, "d", tagSlice[1].Value)
assert.Equal(t, "e", tagSlice[2].Name)
assert.Equal(t, "f", tagSlice[2].Value)
tagSlice, err = NewTagSliceFromCLI([]string{"a=b", "=", "e=f"})
assert.Error(t, err)
}

41
cmd/cloudflared/update.go Normal file
View File

@ -0,0 +1,41 @@
package main
import "github.com/equinox-io/equinox"
const appID = "app_cwbQae3Tpea"
var publicKey = []byte(`
-----BEGIN ECDSA PUBLIC KEY-----
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE4OWZocTVZ8Do/L6ScLdkV+9A0IYMHoOf
dsCmJ/QZ6aw0w9qkkwEpne1Lmo6+0pGexZzFZOH6w5amShn+RXt7qkSid9iWlzGq
EKx0BZogHSor9Wy5VztdFaAaVbsJiCbO
-----END ECDSA PUBLIC KEY-----
`)
type ReleaseInfo struct {
Updated bool
Version string
Error error
}
func checkForUpdates() ReleaseInfo {
var opts equinox.Options
if err := opts.SetPublicKeyPEM(publicKey); err != nil {
return ReleaseInfo{Error: err}
}
resp, err := equinox.Check(appID, opts)
switch {
case err == equinox.NotAvailableErr:
return ReleaseInfo{}
case err != nil:
return ReleaseInfo{Error: err}
}
err = resp.Apply()
if err != nil {
return ReleaseInfo{Error: err}
}
return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion}
}

View File

@ -0,0 +1,166 @@
// +build windows
package main
// Copypasta from the example files:
// https://github.com/golang/sys/blob/master/windows/svc/example
import (
"fmt"
"os"
cli "gopkg.in/urfave/cli.v2"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/eventlog"
"golang.org/x/sys/windows/svc/mgr"
)
const (
windowsServiceName = "Cloudflared"
windowsServiceDescription = "Argo Tunnel agent"
)
func runApp(app *cli.App) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel Windows service",
Subcommands: []*cli.Command{
&cli.Command{
Name: "install",
Usage: "Install Argo Tunnel as a Windows service",
Action: installWindowsService,
},
&cli.Command{
Name: "uninstall",
Usage: "Uninstall the Argo Tunnel service",
Action: uninstallWindowsService,
},
},
})
isIntSess, err := svc.IsAnInteractiveSession()
if err != nil {
Log.Fatalf("failed to determine if we are running in an interactive session: %v", err)
}
if isIntSess {
app.Run(os.Args)
return
}
elog, err := eventlog.Open(windowsServiceName)
if err != nil {
Log.WithError(err).Infof("Cannot open event log for %s", windowsServiceName)
return
}
defer elog.Close()
elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName))
// Run executes service name by calling windowsService which is a Handler
// interface that implements Execute method
err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog})
if err != nil {
elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
return
}
elog.Info(1, fmt.Sprintf("%s service stopped", windowsServiceName))
}
type windowsService struct {
app *cli.App
elog *eventlog.Log
}
// called by the package code at the start of the service
func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) {
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
changes <- svc.Status{State: svc.StartPending}
go s.app.Run(args)
changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
loop:
for {
select {
case c := <-r:
switch c.Cmd {
case svc.Interrogate:
s.elog.Info(1, fmt.Sprintf("control request 1 #%d", c))
changes <- c.CurrentStatus
case svc.Stop:
s.elog.Info(1, "received stop control request")
break loop
case svc.Shutdown:
s.elog.Info(1, "received shutdown control request")
break loop
default:
s.elog.Error(1, fmt.Sprintf("unexpected control request #%d", c))
}
}
}
close(shutdownC)
changes <- svc.Status{State: svc.StopPending}
return
}
func installWindowsService(c *cli.Context) error {
Log.Infof("Installing Argo Tunnel Windows service")
exepath, err := os.Executable()
if err != nil {
Log.Infof("Cannot find path name that start the process")
return err
}
m, err := mgr.Connect()
if err != nil {
Log.WithError(err).Infof("Cannot establish a connection to the service control manager")
return err
}
defer m.Disconnect()
s, err := m.OpenService(windowsServiceName)
if err == nil {
s.Close()
Log.Errorf("service %s already exists", windowsServiceName)
return fmt.Errorf("service %s already exists", windowsServiceName)
}
config := mgr.Config{StartType: mgr.StartAutomatic, DisplayName: windowsServiceDescription}
s, err = m.CreateService(windowsServiceName, exepath, config)
if err != nil {
Log.Infof("Cannot install service %s", windowsServiceName)
return err
}
defer s.Close()
err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info)
if err != nil {
s.Delete()
Log.WithError(err).Infof("Cannot install event logger")
return fmt.Errorf("SetupEventLogSource() failed: %s", err)
}
return nil
}
func uninstallWindowsService(c *cli.Context) error {
Log.Infof("Uninstalling Argo Tunnel Windows Service")
m, err := mgr.Connect()
if err != nil {
Log.Infof("Cannot establish a connection to the service control manager")
return err
}
defer m.Disconnect()
s, err := m.OpenService(windowsServiceName)
if err != nil {
Log.Infof("service %s is not installed", windowsServiceName)
return fmt.Errorf("service %s is not installed", windowsServiceName)
}
defer s.Close()
err = s.Delete()
if err != nil {
Log.Errorf("Cannot delete service %s", windowsServiceName)
return err
}
err = eventlog.Remove(windowsServiceName)
if err != nil {
Log.Infof("Cannot remove event logger")
return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
}
return nil
}

View File

@ -24,12 +24,6 @@ type activeStreamMap struct {
ignoreNewStreams bool ignoreNewStreams bool
} }
type FlowControlMetrics struct {
AverageReceiveWindowSize, AverageSendWindowSize float64
MinReceiveWindowSize, MaxReceiveWindowSize uint32
MinSendWindowSize, MaxSendWindowSize uint32
}
func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap { func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap {
m := &activeStreamMap{ m := &activeStreamMap{
streams: make(map[uint32]*MuxedStream), streams: make(map[uint32]*MuxedStream),
@ -169,45 +163,3 @@ func (m *activeStreamMap) Abort() {
} }
m.ignoreNewStreams = true m.ignoreNewStreams = true
} }
func (m *activeStreamMap) Metrics() *FlowControlMetrics {
m.Lock()
defer m.Unlock()
var averageReceiveWindowSize, averageSendWindowSize float64
var minReceiveWindowSize, maxReceiveWindowSize, minSendWindowSize, maxSendWindowSize uint32
i := 0
// The first variable in the range expression for map is the key, not index.
for _, stream := range m.streams {
// iterative mean: a(t+1) = a(t) + (a(t)-x)/(t+1)
windows := stream.FlowControlWindow()
averageReceiveWindowSize += (float64(windows.receiveWindow) - averageReceiveWindowSize) / float64(i+1)
averageSendWindowSize += (float64(windows.sendWindow) - averageSendWindowSize) / float64(i+1)
if i == 0 {
maxReceiveWindowSize = windows.receiveWindow
minReceiveWindowSize = windows.receiveWindow
maxSendWindowSize = windows.sendWindow
minSendWindowSize = windows.sendWindow
} else {
if windows.receiveWindow > maxReceiveWindowSize {
maxReceiveWindowSize = windows.receiveWindow
} else if windows.receiveWindow < minReceiveWindowSize {
minReceiveWindowSize = windows.receiveWindow
}
if windows.sendWindow > maxSendWindowSize {
maxSendWindowSize = windows.sendWindow
} else if windows.sendWindow < minSendWindowSize {
minSendWindowSize = windows.sendWindow
}
}
i++
}
return &FlowControlMetrics{
MinReceiveWindowSize: minReceiveWindowSize,
MaxReceiveWindowSize: maxReceiveWindowSize,
AverageReceiveWindowSize: averageReceiveWindowSize,
MinSendWindowSize: minSendWindowSize,
MaxSendWindowSize: maxSendWindowSize,
AverageSendWindowSize: averageSendWindowSize,
}
}

18
h2mux/bytes_counter.go Normal file
View File

@ -0,0 +1,18 @@
package h2mux
import (
"sync/atomic"
)
type AtomicCounter struct {
count uint64
}
func (c *AtomicCounter) IncrementBy(number uint64) {
atomic.AddUint64(&c.count, number)
}
// Get returns the current value of counter and reset it to 0
func (c *AtomicCounter) Count() uint64 {
return atomic.SwapUint64(&c.count, 0)
}

View File

@ -0,0 +1,23 @@
package h2mux
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCounter(t *testing.T) {
var wg sync.WaitGroup
wg.Add(dataPoints)
c := AtomicCounter{}
for i := 0; i < dataPoints; i++ {
go func() {
defer wg.Done()
c.IncrementBy(uint64(1))
}()
}
wg.Wait()
assert.Equal(t, uint64(dataPoints), c.Count())
assert.Equal(t, uint64(0), c.Count())
}

View File

@ -59,6 +59,8 @@ type Muxer struct {
muxReader *MuxReader muxReader *MuxReader
// muxWriter is the write process. // muxWriter is the write process.
muxWriter *MuxWriter muxWriter *MuxWriter
// muxMetricsUpdater is the process to update metrics
muxMetricsUpdater *muxMetricsUpdater
// newStreamChan is used to create new streams on the writer thread. // newStreamChan is used to create new streams on the writer thread.
// The writer will assign the next available stream ID. // The writer will assign the next available stream ID.
newStreamChan chan MuxedStreamRequest newStreamChan chan MuxedStreamRequest
@ -133,6 +135,11 @@ func Handshake(
// set up reader/writer pair ready for serve // set up reader/writer pair ready for serve
streamErrors := NewStreamErrorMap() streamErrors := NewStreamErrorMap()
goAwayChan := make(chan http2.ErrCode, 1) goAwayChan := make(chan http2.ErrCode, 1)
updateRTTChan := make(chan *roundTripMeasurement, 1)
updateReceiveWindowChan := make(chan uint32, 1)
updateSendWindowChan := make(chan uint32, 1)
updateInBoundBytesChan := make(chan uint64)
updateOutBoundBytesChan := make(chan uint64)
pingTimestamp := NewPingTimestamp() pingTimestamp := NewPingTimestamp()
connActive := NewSignal() connActive := NewSignal()
idleDuration := config.HeartbeatInterval idleDuration := config.HeartbeatInterval
@ -149,34 +156,48 @@ func Handshake(
m.explicitShutdown = NewBooleanFuse() m.explicitShutdown = NewBooleanFuse()
m.muxReader = &MuxReader{ m.muxReader = &MuxReader{
f: m.f, f: m.f,
handler: m.config.Handler, handler: m.config.Handler,
streams: m.streams, streams: m.streams,
readyList: m.readyList, readyList: m.readyList,
streamErrors: streamErrors, streamErrors: streamErrors,
goAwayChan: goAwayChan, goAwayChan: goAwayChan,
abortChan: m.abortChan, abortChan: m.abortChan,
pingTimestamp: pingTimestamp, pingTimestamp: pingTimestamp,
connActive: connActive, connActive: connActive,
initialStreamWindow: defaultWindowSize, initialStreamWindow: defaultWindowSize,
streamWindowMax: maxWindowSize, streamWindowMax: maxWindowSize,
r: m.r, r: m.r,
updateRTTChan: updateRTTChan,
updateReceiveWindowChan: updateReceiveWindowChan,
updateSendWindowChan: updateSendWindowChan,
updateInBoundBytesChan: updateInBoundBytesChan,
} }
m.muxWriter = &MuxWriter{ m.muxWriter = &MuxWriter{
f: m.f, f: m.f,
streams: m.streams, streams: m.streams,
streamErrors: streamErrors, streamErrors: streamErrors,
readyStreamChan: m.readyList.ReadyChannel(), readyStreamChan: m.readyList.ReadyChannel(),
newStreamChan: m.newStreamChan, newStreamChan: m.newStreamChan,
goAwayChan: goAwayChan, goAwayChan: goAwayChan,
abortChan: m.abortChan, abortChan: m.abortChan,
pingTimestamp: pingTimestamp, pingTimestamp: pingTimestamp,
idleTimer: NewIdleTimer(idleDuration, maxRetries), idleTimer: NewIdleTimer(idleDuration, maxRetries),
connActiveChan: connActive.WaitChannel(), connActiveChan: connActive.WaitChannel(),
maxFrameSize: defaultFrameSize, maxFrameSize: defaultFrameSize,
updateReceiveWindowChan: updateReceiveWindowChan,
updateSendWindowChan: updateSendWindowChan,
updateOutBoundBytesChan: updateOutBoundBytesChan,
} }
m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer) m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
m.muxMetricsUpdater = newMuxMetricsUpdater(
updateRTTChan,
updateReceiveWindowChan,
updateSendWindowChan,
updateInBoundBytesChan,
updateOutBoundBytesChan,
m.abortChan,
)
return m, nil return m, nil
} }
@ -246,9 +267,13 @@ func (m *Muxer) Serve() error {
m.w.Close() m.w.Close()
m.abort() m.abort()
}() }()
go func() {
errChan <- m.muxMetricsUpdater.run(logger)
}()
err := <-errChan err := <-errChan
go func() { go func() {
// discard error as other handler closes // discard errors as other handler and muxMetricsUpdater close
<-errChan
<-errChan <-errChan
close(errChan) close(errChan)
}() }()
@ -318,14 +343,8 @@ func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, erro
} }
} }
// Return the estimated round-trip time. func (m *Muxer) Metrics() *MuxerMetrics {
func (m *Muxer) RTT() RTTMeasurement { return m.muxMetricsUpdater.Metrics()
return m.muxReader.RTT()
}
// Return min/max/average of send/receive window for all streams on this connection
func (m *Muxer) FlowControlMetrics() *FlowControlMetrics {
return m.muxReader.FlowControlMetrics()
} }
func (m *Muxer) abort() { func (m *Muxer) abort() {

View File

@ -14,6 +14,7 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -134,9 +135,12 @@ func TestSingleStream(t *testing.T) {
if stream.Headers[0].Value != "headerValue" { if stream.Headers[0].Value != "headerValue" {
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value) t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
} }
stream.WriteHeaders([]Header{ headers := []Header{
Header{Name: "response-header", Value: "responseValue"}, Header{Name: "response-header", Value: "responseValue"},
}) }
stream.WriteHeaders(headers)
assert.Equal(t, headers, stream.writeHeaders)
assert.False(t, stream.headersSent)
buf := []byte("Hello world") buf := []byte("Hello world")
stream.Write(buf) stream.Write(buf)
// after this receive, the edge closed the stream // after this receive, the edge closed the stream

View File

@ -19,6 +19,7 @@ type MuxedStream struct {
receiveWindowCurrentMax uint32 receiveWindowCurrentMax uint32
// limit set in http2 spec. 2^31-1 // limit set in http2 spec. 2^31-1
receiveWindowMax uint32 receiveWindowMax uint32
// nonzero if a WINDOW_UPDATE frame for a stream needs to be sent // nonzero if a WINDOW_UPDATE frame for a stream needs to be sent
windowUpdate uint32 windowUpdate uint32
@ -39,10 +40,6 @@ type MuxedStream struct {
receivedEOF bool receivedEOF bool
} }
type flowControlWindow struct {
receiveWindow, sendWindow uint32
}
func (s *MuxedStream) Read(p []byte) (n int, err error) { func (s *MuxedStream) Read(p []byte) (n int, err error) {
return s.readBuffer.Read(p) return s.readBuffer.Read(p)
} }
@ -101,17 +98,21 @@ func (s *MuxedStream) WriteHeaders(headers []Header) error {
return ErrStreamHeadersSent return ErrStreamHeadersSent
} }
s.writeHeaders = headers s.writeHeaders = headers
s.headersSent = false
s.writeNotify() s.writeNotify()
return nil return nil
} }
func (s *MuxedStream) FlowControlWindow() *flowControlWindow { func (s *MuxedStream) getReceiveWindow() uint32 {
s.writeLock.Lock() s.writeLock.Lock()
defer s.writeLock.Unlock() defer s.writeLock.Unlock()
return &flowControlWindow{ return s.receiveWindow
receiveWindow: s.receiveWindow, }
sendWindow: s.sendWindow,
} func (s *MuxedStream) getSendWindow() uint32 {
s.writeLock.Lock()
defer s.writeLock.Unlock()
return s.sendWindow
} }
// writeNotify must happen while holding writeLock. // writeNotify must happen while holding writeLock.
@ -209,9 +210,7 @@ func (s *MuxedStream) getChunk() *streamChunk {
} }
// Copies at most s.sendWindow bytes // Copies at most s.sendWindow bytes
//log.Infof("writeBuffer len %d stream %d", s.writeBuffer.Len(), s.streamID)
writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow)) writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow))
//log.Infof("writeLen %d stream %d", writeLen, s.streamID)
s.sendWindow -= uint32(writeLen) s.sendWindow -= uint32(writeLen)
s.receiveWindow += s.windowUpdate s.receiveWindow += s.windowUpdate
s.windowUpdate = 0 s.windowUpdate = 0

232
h2mux/muxmetrics.go Normal file
View File

@ -0,0 +1,232 @@
package h2mux
import (
"sync"
"time"
"github.com/golang-collections/collections/queue"
log "github.com/sirupsen/logrus"
)
// data points used to compute average receive window and send window size
const (
// data points used to compute average receive window and send window size
dataPoints = 100
// updateFreq is set to 1 sec so we can get inbound & outbound byes/sec
updateFreq = time.Second
)
type muxMetricsUpdater struct {
// rttData keeps record of rtt, rttMin, rttMax and last measured time
rttData *rttData
// receiveWindowData keeps record of receive window measurement
receiveWindowData *flowControlData
// sendWindowData keeps record of send window measurement
sendWindowData *flowControlData
// inBoundRate is incoming bytes/sec
inBoundRate *rate
// outBoundRate is outgoing bytes/sec
outBoundRate *rate
// updateRTTChan is the channel to receive new RTT measurement from muxReader
updateRTTChan <-chan *roundTripMeasurement
//updateReceiveWindowChan is the channel to receive updated receiveWindow size from muxReader and muxWriter
updateReceiveWindowChan <-chan uint32
//updateSendWindowChan is the channel to receive updated sendWindow size from muxReader and muxWriter
updateSendWindowChan <-chan uint32
// updateInBoundBytesChan us the channel to receive bytesRead from muxReader
updateInBoundBytesChan <-chan uint64
// updateOutBoundBytesChan us the channel to receive bytesWrote from muxWriter
updateOutBoundBytesChan <-chan uint64
// shutdownC is to signal the muxerMetricsUpdater to shutdown
abortChan <-chan struct{}
}
type MuxerMetrics struct {
RTT, RTTMin, RTTMax time.Duration
ReceiveWindowAve, SendWindowAve float64
ReceiveWindowMin, ReceiveWindowMax, SendWindowMin, SendWindowMax uint32
InBoundRateCurr, InBoundRateMin, InBoundRateMax uint64
OutBoundRateCurr, OutBoundRateMin, OutBoundRateMax uint64
}
type roundTripMeasurement struct {
receiveTime, sendTime time.Time
}
type rttData struct {
rtt, rttMin, rttMax time.Duration
lastMeasurementTime time.Time
lock sync.RWMutex
}
type flowControlData struct {
sum uint64
min, max uint32
queue *queue.Queue
lock sync.RWMutex
}
type rate struct {
curr uint64
min, max uint64
lock sync.RWMutex
}
func newMuxMetricsUpdater(
updateRTTChan <-chan *roundTripMeasurement,
updateReceiveWindowChan <-chan uint32,
updateSendWindowChan <-chan uint32,
updateInBoundBytesChan <-chan uint64,
updateOutBoundBytesChan <-chan uint64,
abortChan <-chan struct{},
) *muxMetricsUpdater {
return &muxMetricsUpdater{
rttData: newRTTData(),
receiveWindowData: newFlowControlData(),
sendWindowData: newFlowControlData(),
inBoundRate: newRate(),
outBoundRate: newRate(),
updateRTTChan: updateRTTChan,
updateReceiveWindowChan: updateReceiveWindowChan,
updateSendWindowChan: updateSendWindowChan,
updateInBoundBytesChan: updateInBoundBytesChan,
updateOutBoundBytesChan: updateOutBoundBytesChan,
abortChan: abortChan,
}
}
func (updater *muxMetricsUpdater) Metrics() *MuxerMetrics {
m := &MuxerMetrics{}
m.RTT, m.RTTMin, m.RTTMax = updater.rttData.metrics()
m.ReceiveWindowAve, m.ReceiveWindowMin, m.ReceiveWindowMax = updater.receiveWindowData.metrics()
m.SendWindowAve, m.SendWindowMin, m.SendWindowMax = updater.sendWindowData.metrics()
m.InBoundRateCurr, m.InBoundRateMin, m.InBoundRateMax = updater.inBoundRate.get()
m.OutBoundRateCurr, m.OutBoundRateMin, m.OutBoundRateMax = updater.outBoundRate.get()
return m
}
func (updater *muxMetricsUpdater) run(parentLogger *log.Entry) error {
logger := parentLogger.WithFields(log.Fields{
"subsystem": "mux",
"dir": "metrics",
})
defer logger.Debug("event loop finished")
for {
select {
case <-updater.abortChan:
logger.Infof("Stopping mux metrics updater")
return nil
case roundTripMeasurement := <-updater.updateRTTChan:
go updater.rttData.update(roundTripMeasurement)
logger.Debug("Update rtt")
case receiveWindow := <-updater.updateReceiveWindowChan:
go updater.receiveWindowData.update(receiveWindow)
logger.Debug("Update receive window")
case sendWindow := <-updater.updateSendWindowChan:
go updater.sendWindowData.update(sendWindow)
logger.Debug("Update send window")
case inBoundBytes := <-updater.updateInBoundBytesChan:
// inBoundBytes is bytes/sec because the update interval is 1 sec
go updater.inBoundRate.update(inBoundBytes)
logger.Debugf("Inbound bytes %d", inBoundBytes)
case outBoundBytes := <-updater.updateOutBoundBytesChan:
// outBoundBytes is bytes/sec because the update interval is 1 sec
go updater.outBoundRate.update(outBoundBytes)
logger.Debugf("Outbound bytes %d", outBoundBytes)
}
}
}
func newRTTData() *rttData {
return &rttData{}
}
func (r *rttData) update(measurement *roundTripMeasurement) {
r.lock.Lock()
defer r.lock.Unlock()
// discard pings before lastMeasurementTime
if r.lastMeasurementTime.After(measurement.sendTime) {
return
}
r.lastMeasurementTime = measurement.sendTime
r.rtt = measurement.receiveTime.Sub(measurement.sendTime)
if r.rttMax < r.rtt {
r.rttMax = r.rtt
}
if r.rttMin == 0 || r.rttMin > r.rtt {
r.rttMin = r.rtt
}
}
func (r *rttData) metrics() (rtt, rttMin, rttMax time.Duration) {
r.lock.RLock()
defer r.lock.RUnlock()
return r.rtt, r.rttMin, r.rttMax
}
func newFlowControlData() *flowControlData {
return &flowControlData{queue: queue.New()}
}
func (f *flowControlData) update(measurement uint32) {
f.lock.Lock()
defer f.lock.Unlock()
var firstItem uint32
// store new data into queue, remove oldest data if queue is full
f.queue.Enqueue(measurement)
if f.queue.Len() > dataPoints {
// data type should always be uint32
firstItem = f.queue.Dequeue().(uint32)
}
// if (measurement - firstItem) < 0, uint64(measurement - firstItem)
// will overflow and become a large positive number
f.sum += uint64(measurement)
f.sum -= uint64(firstItem)
if measurement > f.max {
f.max = measurement
}
if f.min == 0 || measurement < f.min {
f.min = measurement
}
}
// caller of ave() should acquire lock first
func (f *flowControlData) ave() float64 {
if f.queue.Len() == 0 {
return 0
}
return float64(f.sum) / float64(f.queue.Len())
}
func (f *flowControlData) metrics() (ave float64, min, max uint32) {
f.lock.RLock()
defer f.lock.RUnlock()
return f.ave(), f.min, f.max
}
func newRate() *rate {
return &rate{}
}
func (r *rate) update(measurement uint64) {
r.lock.Lock()
defer r.lock.Unlock()
r.curr = measurement
// if measurement is 0, then there is no incoming/outgoing connection, don't update min/max
if r.curr == 0 {
return
}
if measurement > r.max {
r.max = measurement
}
if r.min == 0 || measurement < r.min {
r.min = measurement
}
}
func (r *rate) get() (curr, min, max uint64) {
r.lock.RLock()
defer r.lock.RUnlock()
return r.curr, r.min, r.max
}

176
h2mux/muxmetrics_test.go Normal file
View File

@ -0,0 +1,176 @@
package h2mux
import (
"sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func ave(sum uint64, len int) float64 {
return float64(sum) / float64(len)
}
func TestRTTUpdate(t *testing.T) {
r := newRTTData()
start := time.Now()
// send at 0 ms, receive at 2 ms, RTT = 2ms
m := &roundTripMeasurement{receiveTime: start.Add(2 * time.Millisecond), sendTime: start}
r.update(m)
assert.Equal(t, start, r.lastMeasurementTime)
assert.Equal(t, 2*time.Millisecond, r.rtt)
assert.Equal(t, 2*time.Millisecond, r.rttMin)
assert.Equal(t, 2*time.Millisecond, r.rttMax)
// send at 3 ms, receive at 6 ms, RTT = 3ms
m = &roundTripMeasurement{receiveTime: start.Add(6 * time.Millisecond), sendTime: start.Add(3 * time.Millisecond)}
r.update(m)
assert.Equal(t, start.Add(3*time.Millisecond), r.lastMeasurementTime)
assert.Equal(t, 3*time.Millisecond, r.rtt)
assert.Equal(t, 2*time.Millisecond, r.rttMin)
assert.Equal(t, 3*time.Millisecond, r.rttMax)
// send at 7 ms, receive at 8 ms, RTT = 1ms
m = &roundTripMeasurement{receiveTime: start.Add(8 * time.Millisecond), sendTime: start.Add(7 * time.Millisecond)}
r.update(m)
assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime)
assert.Equal(t, 1*time.Millisecond, r.rtt)
assert.Equal(t, 1*time.Millisecond, r.rttMin)
assert.Equal(t, 3*time.Millisecond, r.rttMax)
// send at -4 ms, receive at 0 ms, RTT = 4ms, but this ping is before last measurement
// so it will be discarded
m = &roundTripMeasurement{receiveTime: start, sendTime: start.Add(-2 * time.Millisecond)}
r.update(m)
assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime)
assert.Equal(t, 1*time.Millisecond, r.rtt)
assert.Equal(t, 1*time.Millisecond, r.rttMin)
assert.Equal(t, 3*time.Millisecond, r.rttMax)
}
func TestFlowControlDataUpdate(t *testing.T) {
f := newFlowControlData()
assert.Equal(t, 0, f.queue.Len())
assert.Equal(t, float64(0), f.ave())
var sum uint64
min := maxWindowSize - dataPoints
max := maxWindowSize
for i := 1; i <= dataPoints; i++ {
size := maxWindowSize - uint32(i)
f.update(size)
assert.Equal(t, max - uint32(1), f.max)
assert.Equal(t, size, f.min)
assert.Equal(t, i, f.queue.Len())
sum += uint64(size)
assert.Equal(t, sum, f.sum)
assert.Equal(t, ave(sum, f.queue.Len()), f.ave())
}
// queue is full, should start to dequeue first element
for i := 1; i <= dataPoints; i++ {
f.update(max)
assert.Equal(t, max, f.max)
assert.Equal(t, min, f.min)
assert.Equal(t, dataPoints, f.queue.Len())
sum += uint64(i)
assert.Equal(t, sum, f.sum)
assert.Equal(t, ave(sum, dataPoints), f.ave())
}
}
func TestMuxMetricsUpdater(t *testing.T) {
updateRTTChan := make(chan *roundTripMeasurement)
updateReceiveWindowChan := make(chan uint32)
updateSendWindowChan := make(chan uint32)
updateInBoundBytesChan := make(chan uint64)
updateOutBoundBytesChan := make(chan uint64)
abortChan := make(chan struct{})
errChan := make(chan error)
m := newMuxMetricsUpdater(updateRTTChan,
updateReceiveWindowChan,
updateSendWindowChan,
updateInBoundBytesChan,
updateOutBoundBytesChan,
abortChan,
)
logger := log.NewEntry(log.New())
go func() {
errChan <- m.run(logger)
}()
var wg sync.WaitGroup
wg.Add(2)
// mock muxReader
readerStart := time.Now()
rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart}
updateRTTChan <- rm
go func() {
defer wg.Done()
// Becareful if dataPoints is not divisibile by 4
readerSend := readerStart.Add(time.Millisecond)
for i := 1; i <= dataPoints/4; i++ {
readerReceive := readerSend.Add(time.Duration(i) * time.Millisecond)
rm := &roundTripMeasurement{receiveTime: readerReceive, sendTime: readerSend}
updateRTTChan <- rm
readerSend = readerReceive.Add(time.Millisecond)
updateReceiveWindowChan <- uint32(i)
updateSendWindowChan <- uint32(i)
updateInBoundBytesChan <- uint64(i)
}
}()
// mock muxWriter
go func() {
defer wg.Done()
for j := dataPoints/4 + 1; j <= dataPoints/2; j++ {
updateReceiveWindowChan <- uint32(j)
updateSendWindowChan <- uint32(j)
// should always be disgard since the send time is before readerSend
rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart.Add(-time.Duration(j*dataPoints) * time.Millisecond)}
updateRTTChan <- rm
updateOutBoundBytesChan <- uint64(j)
}
}()
wg.Wait()
metrics := m.Metrics()
points := dataPoints / 2
assert.Equal(t, time.Millisecond, metrics.RTTMin)
assert.Equal(t, time.Duration(dataPoints/4)*time.Millisecond, metrics.RTTMax)
// sum(1..i) = i*(i+1)/2, ave(1..i) = i*(i+1)/2/i = (i+1)/2
assert.Equal(t, float64(points+1)/float64(2), metrics.ReceiveWindowAve)
assert.Equal(t, uint32(1), metrics.ReceiveWindowMin)
assert.Equal(t, uint32(points), metrics.ReceiveWindowMax)
assert.Equal(t, float64(points+1)/float64(2), metrics.SendWindowAve)
assert.Equal(t, uint32(1), metrics.SendWindowMin)
assert.Equal(t, uint32(points), metrics.SendWindowMax)
assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateCurr)
assert.Equal(t, uint64(1), metrics.InBoundRateMin)
assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateMax)
assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateCurr)
assert.Equal(t, uint64(dataPoints/4+1), metrics.OutBoundRateMin)
assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateMax)
close(abortChan)
assert.Nil(t, <-errChan)
close(errChan)
}

View File

@ -3,7 +3,6 @@ package h2mux
import ( import (
"encoding/binary" "encoding/binary"
"io" "io"
"sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -34,14 +33,18 @@ type MuxReader struct {
initialStreamWindow uint32 initialStreamWindow uint32
// The max value for the send window of a stream. // The max value for the send window of a stream.
streamWindowMax uint32 streamWindowMax uint32
// windowMetrics keeps track of min/max/average of send/receive windows for all streams
flowControlMetrics *FlowControlMetrics
metricsMutex sync.Mutex
// r is a reference to the underlying connection used when shutting down. // r is a reference to the underlying connection used when shutting down.
r io.Closer r io.Closer
// rttMeasurement measures RTT based on ping timestamps. // updateRTTChan is the channel to send new RTT measurement to muxerMetricsUpdater
rttMeasurement RTTMeasurement updateRTTChan chan<- *roundTripMeasurement
rttMutex sync.Mutex // updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
updateReceiveWindowChan chan<- uint32
// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
updateSendWindowChan chan<- uint32
// bytesRead is the amount of bytes read from data frame since the last time we send bytes read to metrics
bytesRead AtomicCounter
// updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
updateInBoundBytesChan chan<- uint64
} }
func (r *MuxReader) Shutdown() { func (r *MuxReader) Shutdown() {
@ -57,28 +60,26 @@ func (r *MuxReader) Shutdown() {
}() }()
} }
func (r *MuxReader) RTT() RTTMeasurement {
r.rttMutex.Lock()
defer r.rttMutex.Unlock()
return r.rttMeasurement
}
func (r *MuxReader) FlowControlMetrics() *FlowControlMetrics {
r.metricsMutex.Lock()
defer r.metricsMutex.Unlock()
if r.flowControlMetrics != nil {
return r.flowControlMetrics
}
// No metrics available yet
return &FlowControlMetrics{}
}
func (r *MuxReader) run(parentLogger *log.Entry) error { func (r *MuxReader) run(parentLogger *log.Entry) error {
logger := parentLogger.WithFields(log.Fields{ logger := parentLogger.WithFields(log.Fields{
"subsystem": "mux", "subsystem": "mux",
"dir": "read", "dir": "read",
}) })
defer logger.Debug("event loop finished") defer logger.Debug("event loop finished")
// routine to periodically update bytesRead
go func() {
tickC := time.Tick(updateFreq)
for {
select {
case <-r.abortChan:
return
case <-tickC:
r.updateInBoundBytesChan <- r.bytesRead.Count()
}
}
}()
for { for {
frame, err := r.f.ReadFrame() frame, err := r.f.ReadFrame()
if err != nil { if err != nil {
@ -120,6 +121,8 @@ func (r *MuxReader) run(parentLogger *log.Entry) error {
r.receivePingData(f) r.receivePingData(f)
case *http2.GoAwayFrame: case *http2.GoAwayFrame:
err = r.receiveGoAway(f) err = r.receiveGoAway(f)
// The receiver of a flow-controlled frame sends a WINDOW_UPDATE frame as it
// consumes data and frees up space in flow-control windows
case *http2.WindowUpdateFrame: case *http2.WindowUpdateFrame:
err = r.updateStreamWindow(f) err = r.updateStreamWindow(f)
default: default:
@ -236,10 +239,11 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.E
} }
data := frame.Data() data := frame.Data()
if len(data) > 0 { if len(data) > 0 {
_, err = stream.readBuffer.Write(data) n, err := stream.readBuffer.Write(data)
if err != nil { if err != nil {
return r.streamError(stream.streamID, http2.ErrCodeInternal) return r.streamError(stream.streamID, http2.ErrCodeInternal)
} }
r.bytesRead.IncrementBy(uint64(n))
} }
if frame.Header().Flags.Has(http2.FlagDataEndStream) { if frame.Header().Flags.Has(http2.FlagDataEndStream) {
if stream.receiveEOF() { if stream.receiveEOF() {
@ -253,6 +257,7 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.E
if !stream.consumeReceiveWindow(uint32(len(data))) { if !stream.consumeReceiveWindow(uint32(len(data))) {
return r.streamError(stream.streamID, http2.ErrCodeFlowControl) return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
} }
r.updateReceiveWindowChan <- stream.getReceiveWindow()
return nil return nil
} }
@ -263,10 +268,14 @@ func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
r.pingTimestamp.Set(ts) r.pingTimestamp.Set(ts)
return return
} }
r.rttMutex.Lock()
r.rttMeasurement.Update(time.Unix(0, ts)) // Update updates the computed values with a new measurement.
r.rttMutex.Unlock() // outgoingTime is the time that the probe was sent.
r.flowControlMetrics = r.streams.Metrics() // We assume that time.Now() is the time we received that probe.
r.updateRTTChan <- &roundTripMeasurement{
receiveTime: time.Now(),
sendTime: time.Unix(0, ts),
}
} }
// Receive a GOAWAY from the peer. Gracefully shut down our connection. // Receive a GOAWAY from the peer. Gracefully shut down our connection.
@ -293,6 +302,7 @@ func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
return nil return nil
} }
stream.replenishSendWindow(frame.Increment) stream.replenishSendWindow(frame.Increment)
r.updateSendWindowChan <- stream.getSendWindow()
return nil return nil
} }

View File

@ -40,6 +40,14 @@ type MuxWriter struct {
headerEncoder *hpack.Encoder headerEncoder *hpack.Encoder
// headerBuffer is the temporary buffer used by headerEncoder. // headerBuffer is the temporary buffer used by headerEncoder.
headerBuffer bytes.Buffer headerBuffer bytes.Buffer
// updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
updateReceiveWindowChan chan<- uint32
// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
updateSendWindowChan chan<- uint32
// bytesWrote is the amount of bytes wrote to data frame since the last time we send bytes wrote to metrics
bytesWrote AtomicCounter
// updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
updateOutBoundBytesChan chan<- uint64
} }
type MuxedStreamRequest struct { type MuxedStreamRequest struct {
@ -64,6 +72,20 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error {
"dir": "write", "dir": "write",
}) })
defer logger.Debug("event loop finished") defer logger.Debug("event loop finished")
// routine to periodically communicate bytesWrote
go func() {
tickC := time.Tick(updateFreq)
for {
select {
case <-w.abortChan:
return
case <-tickC:
w.updateOutBoundBytesChan <- w.bytesWrote.Count()
}
}
}()
for { for {
select { select {
case <-w.abortChan: case <-w.abortChan:
@ -141,7 +163,8 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error {
func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error { func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
logger.Debug("writable") logger.Debug("writable")
chunk := stream.getChunk() chunk := stream.getChunk()
w.updateReceiveWindowChan <- stream.getReceiveWindow()
w.updateSendWindowChan <- stream.getSendWindow()
if chunk.sendHeadersFrame() { if chunk.sendHeadersFrame() {
err := w.writeHeaders(chunk.streamID, chunk.headers) err := w.writeHeaders(chunk.streamID, chunk.headers)
if err != nil { if err != nil {
@ -154,7 +177,9 @@ func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) erro
if chunk.sendWindowUpdateFrame() { if chunk.sendWindowUpdateFrame() {
// Send a WINDOW_UPDATE frame to update our receive window. // Send a WINDOW_UPDATE frame to update our receive window.
// If the Stream ID is zero, the window update applies to the connection as a whole // If the Stream ID is zero, the window update applies to the connection as a whole
// A WINDOW_UPDATE in a specific stream applies to the connection-level flow control as well. // RFC7540 section-6.9.1 "A receiver that receives a flow-controlled frame MUST
// always account for its contribution against the connection flow-control
// window, unless the receiver treats this as a connection error"
err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate) err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
if err != nil { if err != nil {
logger.WithError(err).Warn("error writing window update") logger.WithError(err).Warn("error writing window update")
@ -170,6 +195,8 @@ func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) erro
logger.WithError(err).Warn("error writing data") logger.WithError(err).Warn("error writing data")
return err return err
} }
// update the amount of data wrote
w.bytesWrote.IncrementBy(uint64(len(payload)))
logger.WithField("len", len(payload)).Debug("output data") logger.WithField("len", len(payload)).Debug("output data")
if sentEOF { if sentEOF {
@ -214,19 +241,16 @@ func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
return err return err
} }
blockSize := int(w.maxFrameSize) blockSize := int(w.maxFrameSize)
continuation := false
endHeaders := len(encodedHeaders) == 0 endHeaders := len(encodedHeaders) == 0
for !endHeaders && err == nil { for !endHeaders && err == nil {
blockFragment := encodedHeaders blockFragment := encodedHeaders
if len(encodedHeaders) > blockSize { if len(encodedHeaders) > blockSize {
blockFragment = blockFragment[:blockSize] blockFragment = blockFragment[:blockSize]
encodedHeaders = encodedHeaders[blockSize:] encodedHeaders = encodedHeaders[blockSize:]
} else { // Send CONTINUATION frame if the headers can't be fit into 1 frame
endHeaders = true
}
if continuation {
err = w.f.WriteContinuation(streamID, endHeaders, blockFragment) err = w.f.WriteContinuation(streamID, endHeaders, blockFragment)
} else { } else {
endHeaders = true
err = w.f.WriteHeaders(http2.HeadersFrameParam{ err = w.f.WriteHeaders(http2.HeadersFrameParam{
StreamID: streamID, StreamID: streamID,
EndHeaders: endHeaders, EndHeaders: endHeaders,

View File

@ -2,7 +2,6 @@ package h2mux
import ( import (
"sync/atomic" "sync/atomic"
"time"
) )
// PingTimestamp is an atomic interface around ping timestamping and signalling. // PingTimestamp is an atomic interface around ping timestamping and signalling.
@ -28,26 +27,3 @@ func (pt *PingTimestamp) Get() int64 {
func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} { func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} {
return pt.signal.WaitChannel() return pt.signal.WaitChannel()
} }
// RTTMeasurement encapsulates a continuous round trip time measurement.
type RTTMeasurement struct {
Current, Min, Max time.Duration
lastMeasurementTime time.Time
}
// Update updates the computed values with a new measurement.
// outgoingTime is the time that the probe was sent.
// We assume that time.Now() is the time we received that probe.
func (r *RTTMeasurement) Update(outgoingTime time.Time) {
if !r.lastMeasurementTime.Before(outgoingTime) {
return
}
r.lastMeasurementTime = outgoingTime
r.Current = time.Since(outgoingTime)
if r.Max < r.Current {
r.Max = r.Current
}
if r.Min > r.Current {
r.Min = r.Current
}
}

View File

@ -2,13 +2,32 @@ package origin
import ( import (
"sync" "sync"
"time"
"github.com/cloudflare/cloudflare-warp/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
type TunnelMetrics struct { type muxerMetrics struct {
rtt *prometheus.GaugeVec
rttMin *prometheus.GaugeVec
rttMax *prometheus.GaugeVec
receiveWindowAve *prometheus.GaugeVec
sendWindowAve *prometheus.GaugeVec
receiveWindowMin *prometheus.GaugeVec
receiveWindowMax *prometheus.GaugeVec
sendWindowMin *prometheus.GaugeVec
sendWindowMax *prometheus.GaugeVec
inBoundRateCurr *prometheus.GaugeVec
inBoundRateMin *prometheus.GaugeVec
inBoundRateMax *prometheus.GaugeVec
outBoundRateCurr *prometheus.GaugeVec
outBoundRateMin *prometheus.GaugeVec
outBoundRateMax *prometheus.GaugeVec
}
type tunnelMetrics struct {
haConnections prometheus.Gauge haConnections prometheus.Gauge
totalRequests prometheus.Counter totalRequests prometheus.Counter
requestsPerTunnel *prometheus.CounterVec requestsPerTunnel *prometheus.CounterVec
@ -20,16 +39,7 @@ type TunnelMetrics struct {
maxConcurrentRequestsPerTunnel *prometheus.GaugeVec maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
// concurrentRequests records max count of concurrent requests for each tunnel // concurrentRequests records max count of concurrent requests for each tunnel
maxConcurrentRequests map[string]uint64 maxConcurrentRequests map[string]uint64
rtt prometheus.Gauge
rttMin prometheus.Gauge
rttMax prometheus.Gauge
timerRetries prometheus.Gauge timerRetries prometheus.Gauge
receiveWindowSizeAve prometheus.Gauge
sendWindowSizeAve prometheus.Gauge
receiveWindowSizeMin prometheus.Gauge
receiveWindowSizeMax prometheus.Gauge
sendWindowSizeMin prometheus.Gauge
sendWindowSizeMax prometheus.Gauge
responseByCode *prometheus.CounterVec responseByCode *prometheus.CounterVec
responseCodePerTunnel *prometheus.CounterVec responseCodePerTunnel *prometheus.CounterVec
serverLocations *prometheus.GaugeVec serverLocations *prometheus.GaugeVec
@ -37,10 +47,189 @@ type TunnelMetrics struct {
locationLock sync.Mutex locationLock sync.Mutex
// oldServerLocations stores the last server the tunnel was connected to // oldServerLocations stores the last server the tunnel was connected to
oldServerLocations map[string]string oldServerLocations map[string]string
muxerMetrics *muxerMetrics
}
func newMuxerMetrics() *muxerMetrics {
rtt := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "rtt",
Help: "Round-trip time in millisecond",
},
[]string{"connection_id"},
)
prometheus.MustRegister(rtt)
rttMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "rtt_min",
Help: "Shortest round-trip time in millisecond",
},
[]string{"connection_id"},
)
prometheus.MustRegister(rttMin)
rttMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "rtt_max",
Help: "Longest round-trip time in millisecond",
},
[]string{"connection_id"},
)
prometheus.MustRegister(rttMax)
receiveWindowAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "receive_window_ave",
Help: "Average receive window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(receiveWindowAve)
sendWindowAve := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "send_window_ave",
Help: "Average send window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(sendWindowAve)
receiveWindowMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "receive_window_min",
Help: "Smallest receive window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(receiveWindowMin)
receiveWindowMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "receive_window_max",
Help: "Largest receive window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(receiveWindowMax)
sendWindowMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "send_window_min",
Help: "Smallest send window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(sendWindowMin)
sendWindowMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "send_window_max",
Help: "Largest send window size in bytes",
},
[]string{"connection_id"},
)
prometheus.MustRegister(sendWindowMax)
inBoundRateCurr := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "inbound_bytes_per_sec_curr",
Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
},
[]string{"connection_id"},
)
prometheus.MustRegister(inBoundRateCurr)
inBoundRateMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "inbound_bytes_per_sec_min",
Help: "Minimum non-zero inbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(inBoundRateMin)
inBoundRateMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "inbound_bytes_per_sec_max",
Help: "Maximum inbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(inBoundRateMax)
outBoundRateCurr := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "outbound_bytes_per_sec_curr",
Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
},
[]string{"connection_id"},
)
prometheus.MustRegister(outBoundRateCurr)
outBoundRateMin := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "outbound_bytes_per_sec_min",
Help: "Minimum non-zero outbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(outBoundRateMin)
outBoundRateMax := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "outbound_bytes_per_sec_max",
Help: "Maximum outbounding bytes per second",
},
[]string{"connection_id"},
)
prometheus.MustRegister(outBoundRateMax)
return &muxerMetrics{
rtt: rtt,
rttMin: rttMin,
rttMax: rttMax,
receiveWindowAve: receiveWindowAve,
sendWindowAve: sendWindowAve,
receiveWindowMin: receiveWindowMin,
receiveWindowMax: receiveWindowMax,
sendWindowMin: sendWindowMin,
sendWindowMax: sendWindowMax,
inBoundRateCurr: inBoundRateCurr,
inBoundRateMin: inBoundRateMin,
inBoundRateMax: inBoundRateMax,
outBoundRateCurr: outBoundRateCurr,
outBoundRateMin: outBoundRateMin,
outBoundRateMax: outBoundRateMax,
}
}
func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
}
func convertRTTMilliSec(t time.Duration) float64 {
return float64(t / time.Millisecond)
} }
// Metrics that can be collected without asking the edge // Metrics that can be collected without asking the edge
func NewTunnelMetrics() *TunnelMetrics { func NewTunnelMetrics() *tunnelMetrics {
haConnections := prometheus.NewGauge( haConnections := prometheus.NewGauge(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Name: "ha_connections", Name: "ha_connections",
@ -82,27 +271,6 @@ func NewTunnelMetrics() *TunnelMetrics {
) )
prometheus.MustRegister(maxConcurrentRequestsPerTunnel) prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
rtt := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "rtt",
Help: "Round-trip time",
})
prometheus.MustRegister(rtt)
rttMin := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "rtt_min",
Help: "Shortest round-trip time",
})
prometheus.MustRegister(rttMin)
rttMax := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "rtt_max",
Help: "Longest round-trip time",
})
prometheus.MustRegister(rttMax)
timerRetries := prometheus.NewGauge( timerRetries := prometheus.NewGauge(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Name: "timer_retries", Name: "timer_retries",
@ -110,48 +278,6 @@ func NewTunnelMetrics() *TunnelMetrics {
}) })
prometheus.MustRegister(timerRetries) prometheus.MustRegister(timerRetries)
receiveWindowSizeAve := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "receive_window_ave",
Help: "Average receive window size",
})
prometheus.MustRegister(receiveWindowSizeAve)
sendWindowSizeAve := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "send_window_ave",
Help: "Average send window size",
})
prometheus.MustRegister(sendWindowSizeAve)
receiveWindowSizeMin := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "receive_window_min",
Help: "Smallest receive window size",
})
prometheus.MustRegister(receiveWindowSizeMin)
receiveWindowSizeMax := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "receive_window_max",
Help: "Largest receive window size",
})
prometheus.MustRegister(receiveWindowSizeMax)
sendWindowSizeMin := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "send_window_min",
Help: "Smallest send window size",
})
prometheus.MustRegister(sendWindowSizeMin)
sendWindowSizeMax := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "send_window_max",
Help: "Largest send window size",
})
prometheus.MustRegister(sendWindowSizeMax)
responseByCode := prometheus.NewCounterVec( responseByCode := prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "response_by_code", Name: "response_by_code",
@ -179,7 +305,7 @@ func NewTunnelMetrics() *TunnelMetrics {
) )
prometheus.MustRegister(serverLocations) prometheus.MustRegister(serverLocations)
return &TunnelMetrics{ return &tunnelMetrics{
haConnections: haConnections, haConnections: haConnections,
totalRequests: totalRequests, totalRequests: totalRequests,
requestsPerTunnel: requestsPerTunnel, requestsPerTunnel: requestsPerTunnel,
@ -187,41 +313,28 @@ func NewTunnelMetrics() *TunnelMetrics {
concurrentRequests: make(map[string]uint64), concurrentRequests: make(map[string]uint64),
maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel, maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
maxConcurrentRequests: make(map[string]uint64), maxConcurrentRequests: make(map[string]uint64),
rtt: rtt, timerRetries: timerRetries,
rttMin: rttMin, responseByCode: responseByCode,
rttMax: rttMax, responseCodePerTunnel: responseCodePerTunnel,
timerRetries: timerRetries, serverLocations: serverLocations,
receiveWindowSizeAve: receiveWindowSizeAve, oldServerLocations: make(map[string]string),
sendWindowSizeAve: sendWindowSizeAve, muxerMetrics: newMuxerMetrics(),
receiveWindowSizeMin: receiveWindowSizeMin,
receiveWindowSizeMax: receiveWindowSizeMax,
sendWindowSizeMin: sendWindowSizeMin,
sendWindowSizeMax: sendWindowSizeMax,
responseByCode: responseByCode,
responseCodePerTunnel: responseCodePerTunnel,
serverLocations: serverLocations,
oldServerLocations: make(map[string]string),
} }
} }
func (t *TunnelMetrics) incrementHaConnections() { func (t *tunnelMetrics) incrementHaConnections() {
t.haConnections.Inc() t.haConnections.Inc()
} }
func (t *TunnelMetrics) decrementHaConnections() { func (t *tunnelMetrics) decrementHaConnections() {
t.haConnections.Dec() t.haConnections.Dec()
} }
func (t *TunnelMetrics) updateTunnelFlowControlMetrics(metrics *h2mux.FlowControlMetrics) { func (t *tunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
t.receiveWindowSizeAve.Set(float64(metrics.AverageReceiveWindowSize)) t.muxerMetrics.update(connectionID, metrics)
t.sendWindowSizeAve.Set(float64(metrics.AverageSendWindowSize))
t.receiveWindowSizeMin.Set(float64(metrics.MinReceiveWindowSize))
t.receiveWindowSizeMax.Set(float64(metrics.MaxReceiveWindowSize))
t.sendWindowSizeMin.Set(float64(metrics.MinSendWindowSize))
t.sendWindowSizeMax.Set(float64(metrics.MaxSendWindowSize))
} }
func (t *TunnelMetrics) incrementRequests(connectionID string) { func (t *tunnelMetrics) incrementRequests(connectionID string) {
t.concurrentRequestsLock.Lock() t.concurrentRequestsLock.Lock()
var concurrentRequests uint64 var concurrentRequests uint64
var ok bool var ok bool
@ -243,7 +356,7 @@ func (t *TunnelMetrics) incrementRequests(connectionID string) {
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc() t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
} }
func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) { func (t *tunnelMetrics) decrementConcurrentRequests(connectionID string) {
t.concurrentRequestsLock.Lock() t.concurrentRequestsLock.Lock()
if _, ok := t.concurrentRequests[connectionID]; ok { if _, ok := t.concurrentRequests[connectionID]; ok {
t.concurrentRequests[connectionID] -= 1 t.concurrentRequests[connectionID] -= 1
@ -255,13 +368,13 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec() t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
} }
func (t *TunnelMetrics) incrementResponses(connectionID, code string) { func (t *tunnelMetrics) incrementResponses(connectionID, code string) {
t.responseByCode.WithLabelValues(code).Inc() t.responseByCode.WithLabelValues(code).Inc()
t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc() t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
} }
func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) { func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) {
t.locationLock.Lock() t.locationLock.Lock()
defer t.locationLock.Unlock() defer t.locationLock.Unlock()
if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc { if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {

View File

@ -15,11 +15,11 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"github.com/cloudflare/cloudflare-warp/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflare-warp/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflare-warp/validation" "github.com/cloudflare/cloudflared/validation"
"github.com/cloudflare/cloudflare-warp/websocket" "github.com/cloudflare/cloudflared/websocket"
raven "github.com/getsentry/raven-go" raven "github.com/getsentry/raven-go"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -53,7 +53,7 @@ type TunnelConfig struct {
Tags []tunnelpogs.Tag Tags []tunnelpogs.Tag
HAConnections int HAConnections int
HTTPTransport http.RoundTripper HTTPTransport http.RoundTripper
Metrics *TunnelMetrics Metrics *tunnelMetrics
MetricsUpdateFreq time.Duration MetricsUpdateFreq time.Duration
ProtocolLogger *logrus.Logger ProtocolLogger *logrus.Logger
Logger *logrus.Logger Logger *logrus.Logger
@ -185,6 +185,7 @@ func ServeTunnel(
serveCtx, serveCancel := context.WithCancel(ctx) serveCtx, serveCancel := context.WithCancel(ctx)
registerErrC := make(chan error, 1) registerErrC := make(chan error, 1)
go func() { go func() {
defer wg.Done()
err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP) err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP)
if err == nil { if err == nil {
connectedFuse.Fuse(true) connectedFuse.Fuse(true)
@ -193,18 +194,18 @@ func ServeTunnel(
serveCancel() serveCancel()
} }
registerErrC <- err registerErrC <- err
wg.Done()
}() }()
updateMetricsTickC := time.Tick(config.MetricsUpdateFreq) updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
go func() { go func() {
defer wg.Done() defer wg.Done()
connectionTag := uint8ToString(connectionID)
for { for {
select { select {
case <-serveCtx.Done(): case <-serveCtx.Done():
handler.muxer.Shutdown() handler.muxer.Shutdown()
return return
case <-updateMetricsTickC: case <-updateMetricsTickC:
handler.UpdateMetrics() handler.UpdateMetrics(connectionTag)
} }
} }
}() }()
@ -303,7 +304,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
func LogServerInfo(logger *logrus.Entry, func LogServerInfo(logger *logrus.Entry,
promise tunnelrpc.ServerInfo_Promise, promise tunnelrpc.ServerInfo_Promise,
connectionID uint8, connectionID uint8,
metrics *TunnelMetrics, metrics *tunnelMetrics,
) { ) {
serverInfoMessage, err := promise.Struct() serverInfoMessage, err := promise.Struct()
if err != nil { if err != nil {
@ -356,13 +357,17 @@ func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
return return
} }
func FindCfRayHeader(h1 *http.Request) string {
return h1.Header.Get("Cf-Ray")
}
type TunnelHandler struct { type TunnelHandler struct {
originUrl string originUrl string
muxer *h2mux.Muxer muxer *h2mux.Muxer
httpClient http.RoundTripper httpClient http.RoundTripper
tlsConfig *tls.Config tlsConfig *tls.Config
tags []tunnelpogs.Tag tags []tunnelpogs.Tag
metrics *TunnelMetrics metrics *tunnelMetrics
// connectionID is only used by metrics, and prometheus requires labels to be string // connectionID is only used by metrics, and prometheus requires labels to be string
connectionID string connectionID string
} }
@ -435,7 +440,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
Log.WithError(err).Error("invalid request received") Log.WithError(err).Error("invalid request received")
} }
h.AppendTagHeaders(req) h.AppendTagHeaders(req)
cfRay := FindCfRayHeader(req)
h.logRequest(req, cfRay)
if websocket.IsWebSocketUpgrade(req) { if websocket.IsWebSocketUpgrade(req) {
conn, response, err := websocket.ClientConnect(req, h.tlsConfig) conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
if err != nil { if err != nil {
@ -444,6 +450,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
stream.WriteHeaders(H1ResponseToH2Response(response)) stream.WriteHeaders(H1ResponseToH2Response(response))
defer conn.Close() defer conn.Close()
websocket.Stream(conn.UnderlyingConn(), stream) websocket.Stream(conn.UnderlyingConn(), stream)
h.metrics.incrementResponses(h.connectionID, "200")
h.logResponse(response, cfRay)
} }
} else { } else {
response, err := h.httpClient.RoundTrip(req) response, err := h.httpClient.RoundTrip(req)
@ -454,6 +462,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
stream.WriteHeaders(H1ResponseToH2Response(response)) stream.WriteHeaders(H1ResponseToH2Response(response))
io.Copy(stream, response.Body) io.Copy(stream, response.Body)
h.metrics.incrementResponses(h.connectionID, "200") h.metrics.incrementResponses(h.connectionID, "200")
h.logResponse(response, cfRay)
} }
} }
h.metrics.decrementConcurrentRequests(h.connectionID) h.metrics.decrementConcurrentRequests(h.connectionID)
@ -467,9 +476,27 @@ func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
h.metrics.incrementResponses(h.connectionID, "502") h.metrics.incrementResponses(h.connectionID, "502")
} }
func (h *TunnelHandler) UpdateMetrics() { func (h *TunnelHandler) logRequest(req *http.Request, cfRay string) {
flowCtlMetrics := h.muxer.FlowControlMetrics() if cfRay != "" {
h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics) Log.WithField("CF-RAY", cfRay).Infof("%s %s %s", req.Method, req.URL, req.Proto)
} else {
Log.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto)
}
Log.Debugf("Request Headers %+v", req.Header)
}
func (h *TunnelHandler) logResponse(r *http.Response, cfRay string) {
if cfRay != "" {
Log.WithField("CF-RAY", cfRay).Infof("%s", r.Status)
} else {
Log.Infof("%s", r.Status)
}
Log.Debugf("Response Headers %+v", r.Header)
}
func (h *TunnelHandler) UpdateMetrics(connectionID string) {
h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics())
} }
func uint8ToString(input uint8) string { func uint8ToString(input uint8) string {

View File

@ -11,28 +11,28 @@ const (
BgUrgQQAIg== BgUrgQQAIg==
-----END EC PARAMETERS----- -----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY----- -----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDAdyQBXfxTDCQSOT0HugmH9pVBtIw8t5dYvm6HxGlNq6P57v5GeN02Z MIGkAgEBBDBGGfwhIJdiUiJUVIItqJjEIMmlXxsMa8TQeer47+g+cIZ466rgg8EK
dH9FRl7+VSWgBwYFK4EEACKhZANiAATqpFzTxxV7D+/oqhKCTR6BEM9elTfKaRQE +Mdn6BY48GCgBwYFK4EEACKhZANiAASW//A9iDbPKg3OLkn7yJqLer32g9I5lBKR
FsLufcmaTMw/9tTwgpHKao/QsLKDTNbQhbSQLkcmpCQKlSGhl+pCrqNt/oYUAhav tPc/zBubQLLz9lAaYI6AOQiJXhGr5JkKmQfi1sYHK5rJITPFy4W8Et4hHLdazDZH
UIwpwGiLCqGH/R2AqWLKRPOa/Rufs/U= WnEd+TStQABFUjrhtqXPWmGKcly0pOE=
-----END EC PRIVATE KEY-----` -----END EC PRIVATE KEY-----`
helloCRT = ` helloCRT = `
-----BEGIN CERTIFICATE----- -----BEGIN CERTIFICATE-----
MIICkDCCAhigAwIBAgIJAPtKfUjc2lwGMAkGByqGSM49BAEwgYoxCzAJBgNVBAYT MIICiDCCAg6gAwIBAgIJAJ/FfkBTtbuIMAkGByqGSM49BAEwfzELMAkGA1UEBhMC
AlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQKDBBD VVMxDjAMBgNVBAgMBVRleGFzMQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENs
bG91ZGZsYXJlLCBJbmMuMT8wPQYDVQQDDDZDbG91ZGZsYXJlIEFyZ28gVHVubmVs b3VkZmxhcmUsIEluYy4xNDAyBgNVBAMMK0FyZ28gVHVubmVsIFNhbXBsZSBIZWxs
IFNhbXBsZSBIZWxsbyBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMjE1MjAxNjU5 byBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMzE5MjMwNTMyWhcNMjgwMzE2MjMw
WhcNMjgwMjEzMjAxNjU5WjCBijELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVRleGFz NTMyWjB/MQswCQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1
MQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENsb3VkZmxhcmUsIEluYy4xPzA9 c3RpbjEZMBcGA1UECgwQQ2xvdWRmbGFyZSwgSW5jLjE0MDIGA1UEAwwrQXJnbyBU
BgNVBAMMNkNsb3VkZmxhcmUgQXJnbyBUdW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZl dW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZlciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49
ciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49AgEGBSuBBAAiA2IABOqkXNPHFXsP7+iq AgEGBSuBBAAiA2IABJb/8D2INs8qDc4uSfvImot6vfaD0jmUEpG09z/MG5tAsvP2
EoJNHoEQz16VN8ppFAQWwu59yZpMzD/21PCCkcpqj9CwsoNM1tCFtJAuRyakJAqV UBpgjoA5CIleEavkmQqZB+LWxgcrmskhM8XLhbwS3iEct1rMNkdacR35NK1AAEVS
IaGX6kKuo23+hhQCFq9QjCnAaIsKoYf9HYCpYspE85r9G5+z9aNJMEcwRQYDVR0R OuG2pc9aYYpyXLSk4aNXMFUwUwYDVR0RBEwwSoIJbG9jYWxob3N0ghFjbG91ZGZs
BD4wPIIJbG9jYWxob3N0ggp3YXJwLWhlbGxvggt3YXJwMi1oZWxsb4cEfwAAAYcQ YXJlZC1oZWxsb4ISY2xvdWRmbGFyZWQyLWhlbGxvhwR/AAABhxAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAATAJBgcqhkjOPQQBA2cAMGQCMHyVPufXZ6vQo6XRWRa0 AAAAAAABMAkGByqGSM49BAEDaQAwZgIxAPxkdghH6y8xLMnY9Bom3Llf4NYM6yB9
dAwtfgesOdZVP2Wt+t5v8jOIQQh1IQXYk5GtyoZGSObjhQIwd1fRgAyKXaZt+1DV PD1YsaNUJTsxjTk3YY1Jsp+yzK0yUKtTZwIxAPcdvqCF2/iR9H288pCT1TgtO0a9
ZtHTdf8pMvESfJsSd8AB1eQ6q+pAiRUYyaxcE1Mlo2YY5o+g cJL9RY1lq7DIGN37v1ZXReWaD+3hNokY8NriVg==
-----END CERTIFICATE-----` -----END CERTIFICATE-----`
) )

38
tunneldns/https_proxy.go Normal file
View File

@ -0,0 +1,38 @@
package tunneldns
import (
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
"github.com/pkg/errors"
"golang.org/x/net/context"
)
// Upstream is a simplified interface for proxy destination
type Upstream interface {
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
}
// ProxyPlugin is a simplified DNS proxy using a generic upstream interface
type ProxyPlugin struct {
Upstreams []Upstream
Next plugin.Handler
}
// ServeDNS implements interface for CoreDNS plugin
func (p ProxyPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
var reply *dns.Msg
var backendErr error
for _, upstream := range p.Upstreams {
reply, backendErr = upstream.Exchange(ctx, r)
if backendErr == nil {
w.WriteMsg(reply)
return 0, nil
}
}
return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
}
// Name implements interface for CoreDNS plugin
func (p ProxyPlugin) Name() string { return "proxy" }

View File

@ -0,0 +1,97 @@
package tunneldns
import (
"bytes"
"crypto/tls"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
"github.com/miekg/dns"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/net/context"
)
const (
defaultTimeout = 5 * time.Second
)
// UpstreamHTTPS is the upstream implementation for DNS over HTTPS service
type UpstreamHTTPS struct {
client *http.Client
endpoint *url.URL
}
// NewUpstreamHTTPS creates a new DNS over HTTPS upstream from hostname
func NewUpstreamHTTPS(endpoint string) (Upstream, error) {
u, err := url.Parse(endpoint)
if err != nil {
return nil, err
}
// Update TLS and HTTP client configuration
tls := &tls.Config{ServerName: u.Hostname()}
client := &http.Client{
Timeout: time.Second * defaultTimeout,
Transport: &http.Transport{TLSClientConfig: tls},
}
return &UpstreamHTTPS{client: client, endpoint: u}, nil
}
// Exchange provides an implementation for the Upstream interface
func (u *UpstreamHTTPS) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
queryBuf, err := query.Pack()
if err != nil {
return nil, errors.Wrap(err, "failed to pack DNS query")
}
// No content negotiation for now, use DNS wire format
buf, backendErr := u.exchangeWireformat(queryBuf)
if backendErr == nil {
response := &dns.Msg{}
if err := response.Unpack(buf); err != nil {
return nil, errors.Wrap(err, "failed to unpack DNS response from body")
}
response.Id = query.Id
return response, nil
}
log.WithError(backendErr).Errorf("failed to connect to an HTTPS backend %q", u.endpoint)
return nil, backendErr
}
// Perform message exchange with the default UDP wireformat defined in current draft
// https://datatracker.ietf.org/doc/draft-ietf-doh-dns-over-https
func (u *UpstreamHTTPS) exchangeWireformat(msg []byte) ([]byte, error) {
req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
if err != nil {
return nil, errors.Wrap(err, "failed to create an HTTPS request")
}
req.Header.Add("Content-Type", "application/dns-udpwireformat")
req.Host = u.endpoint.Hostname()
resp, err := u.client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to perform an HTTPS request")
}
// Check response status code
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
}
// Read wireformat response from the body
buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read the response body")
}
return buf, nil
}

45
tunneldns/metrics.go Normal file
View File

@ -0,0 +1,45 @@
package tunneldns
import (
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics/vars"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/rcode"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/net/context"
)
// MetricsPlugin is an adapter for CoreDNS and built-in metrics
type MetricsPlugin struct {
Next plugin.Handler
}
// NewMetricsPlugin creates a plugin with configured metrics
func NewMetricsPlugin(next plugin.Handler) *MetricsPlugin {
prometheus.MustRegister(vars.RequestCount)
prometheus.MustRegister(vars.RequestDuration)
prometheus.MustRegister(vars.RequestSize)
prometheus.MustRegister(vars.RequestDo)
prometheus.MustRegister(vars.RequestType)
prometheus.MustRegister(vars.ResponseSize)
prometheus.MustRegister(vars.ResponseRcode)
return &MetricsPlugin{Next: next}
}
// ServeDNS implements the CoreDNS plugin interface
func (p MetricsPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := request.Request{W: w, Req: r}
rw := dnstest.NewRecorder(w)
status, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, rw, r)
// Update built-in metrics
vars.Report(state, ".", rcode.ToString(rw.Rcode), rw.Len, rw.Start)
return status, err
}
// Name implements the CoreDNS plugin interface
func (p MetricsPlugin) Name() string { return "metrics" }

144
tunneldns/tunnel.go Normal file
View File

@ -0,0 +1,144 @@
package tunneldns
import (
"fmt"
"net"
"os"
"os/signal"
"strconv"
"sync"
"syscall"
"gopkg.in/urfave/cli.v2"
"github.com/cloudflare/cloudflared/metrics"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/cache"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
// Listener is an adapter between CoreDNS server and Warp runnable
type Listener struct {
server *dnsserver.Server
wg sync.WaitGroup
}
// Run implements a foreground runner
func Run(c *cli.Context) error {
metricsListener, err := net.Listen("tcp", c.String("metrics"))
if err != nil {
log.WithError(err).Fatal("Failed to open the metrics listener")
}
go metrics.ServeMetrics(metricsListener, nil)
listener, err := CreateListener(c.String("address"), uint16(c.Uint("port")), c.StringSlice("upstream"))
if err != nil {
log.WithError(err).Errorf("Failed to create the listeners")
return err
}
// Try to start the server
err = listener.Start()
if err != nil {
log.WithError(err).Errorf("Failed to start the listeners")
return listener.Stop()
}
// Wait for signal
signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signals)
<-signals
// Shut down server
err = listener.Stop()
if err != nil {
log.WithError(err).Errorf("failed to stop")
}
return err
}
// Create a CoreDNS server plugin from configuration
func createConfig(address string, port uint16, p plugin.Handler) *dnsserver.Config {
c := &dnsserver.Config{
Zone: ".",
Transport: "dns",
ListenHosts: []string{address},
Port: strconv.FormatUint(uint64(port), 10),
}
c.AddPlugin(func(next plugin.Handler) plugin.Handler { return p })
return c
}
// Start blocks for serving requests
func (l *Listener) Start() error {
log.WithField("addr", l.server.Address()).Infof("Starting DNS over HTTPS proxy server")
// Start UDP listener
if udp, err := l.server.ListenPacket(); err == nil {
l.wg.Add(1)
go func() {
l.server.ServePacket(udp)
l.wg.Done()
}()
} else {
return errors.Wrap(err, "failed to create a UDP listener")
}
// Start TCP listener
tcp, err := l.server.Listen()
if err == nil {
l.wg.Add(1)
go func() {
l.server.Serve(tcp)
l.wg.Done()
}()
}
return errors.Wrap(err, "failed to create a TCP listener")
}
// Stop signals server shutdown and blocks until completed
func (l *Listener) Stop() error {
if err := l.server.Stop(); err != nil {
return err
}
l.wg.Wait()
return nil
}
// CreateListener configures the server and bound sockets
func CreateListener(address string, port uint16, upstreams []string) (*Listener, error) {
// Build the list of upstreams
upstreamList := make([]Upstream, 0)
for _, url := range upstreams {
log.WithField("url", url).Infof("Adding DNS upstream")
upstream, err := NewUpstreamHTTPS(url)
if err != nil {
return nil, errors.Wrap(err, "failed to create HTTPS upstream")
}
upstreamList = append(upstreamList, upstream)
}
// Create a local cache with HTTPS proxy plugin
chain := cache.New()
chain.Next = ProxyPlugin{
Upstreams: upstreamList,
}
// Format an endpoint
endpoint := fmt.Sprintf("dns://%s:%d", address, port)
// Create the actual middleware server
server, err := dnsserver.NewServer(endpoint, []*dnsserver.Config{createConfig(address, port, NewMetricsPlugin(chain))})
if err != nil {
return nil, err
}
return &Listener{server: server}, nil
}

View File

@ -1,7 +1,7 @@
package pogs package pogs
import ( import (
"github.com/cloudflare/cloudflare-warp/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
"golang.org/x/net/context" "golang.org/x/net/context"
"zombiezen.com/go/capnproto2" "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs" "zombiezen.com/go/capnproto2/pogs"

View File

@ -1,7 +1,7 @@
using Go = import "go.capnp"; using Go = import "go.capnp";
@0xdb8274f9144abc7e; @0xdb8274f9144abc7e;
$Go.package("tunnelrpc"); $Go.package("tunnelrpc");
$Go.import("github.com/cloudflare/cloudflare-warp/tunnelrpc"); $Go.import("github.com/cloudflare/cloudflared/tunnelrpc");
struct Authentication { struct Authentication {
key @0 :Text; key @0 :Text;
@ -35,7 +35,7 @@ struct RegistrationOptions {
connectionId @6 :UInt8; connectionId @6 :UInt8;
# origin LAN IP # origin LAN IP
originLocalIp @7 :Text; originLocalIp @7 :Text;
# whether Warp client has been autoupdated # whether Argo Tunnel client has been autoupdated
isAutoupdated @8 :Bool; isAutoupdated @8 :Bool;
} }

View File

@ -119,7 +119,7 @@ func validateScheme(scheme string) error {
return nil return nil
} }
} }
return fmt.Errorf("Currently Cloudflare-Warp does not support %s protocol.", scheme) return fmt.Errorf("Currently Argo Tunnel does not support %s protocol.", scheme)
} }
func validateIP(scheme, host, port string) (string, error) { func validateIP(scheme, host, port string) (string, error) {

View File

@ -126,7 +126,7 @@ func TestValidateUrl(t *testing.T) {
assert.Equal(t, "https://hello.example.com", validUrl) assert.Equal(t, "https://hello.example.com", validUrl)
validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt") validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
assert.Equal(t, "Currently Cloudflare-Warp does not support ftp protocol.", err.Error()) assert.Equal(t, "Currently Argo Tunnel does not support ftp protocol.", err.Error())
assert.Empty(t, validUrl) assert.Empty(t, validUrl)
validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080") validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")