Release Argo Tunnel Client 2018.3.1

parent 9f5cec8dbc
commit d0a6a2a829

@@ -0,0 +1,13 @@
// +build !windows,!darwin,!linux

package main

import (
	"os"

	cli "gopkg.in/urfave/cli.v2"
)

func runApp(app *cli.App) {
	app.Run(os.Args)
}

@@ -0,0 +1,204 @@
package main

import (
	"bytes"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"html/template"
	"io/ioutil"
	"net"
	"net/http"
	"os"
	"time"

	"github.com/gorilla/websocket"
	"gopkg.in/urfave/cli.v2"

	"github.com/cloudflare/cloudflared/tlsconfig"
)

type templateData struct {
	ServerName string
	Request    *http.Request
	Body       string
}

type OriginUpTime struct {
	StartTime time.Time `json:"startTime"`
	UpTime    string    `json:"uptime"`
}

const defaultServerName = "the Argo Tunnel test server"
const indexTemplate = `
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=Edge">
<title>
Argo Tunnel Connection
</title>
<meta name="author" content="">
<meta name="description" content="Argo Tunnel Connection">
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>
html{line-height:1.15;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}body{margin:0}section{display:block}h1{font-size:2em;margin:.67em 0}a{background-color:transparent;-webkit-text-decoration-skip:objects}/* 1 */::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}/* 1 */a,body,dd,div,dl,dt,h1,h4,html,p,section{box-sizing:border-box}.bt{border-top-style:solid;border-top-width:1px}.bl{border-left-style:solid;border-left-width:1px}.b--orange{border-color:#f38020}.br1{border-radius:.125rem}.bw2{border-width:.25rem}.dib{display:inline-block}.sans-serif{font-family:open sans,-apple-system,BlinkMacSystemFont,avenir next,avenir,helvetica neue,helvetica,ubuntu,roboto,noto,segoe ui,arial,sans-serif}.code{font-family:Consolas,monaco,monospace}.b{font-weight:700}.fw3{font-weight:300}.fw4{font-weight:400}.fw5{font-weight:500}.fw6{font-weight:600}.lh-copy{line-height:1.5}.link{text-decoration:none}.link,.link:active,.link:focus,.link:hover,.link:link,.link:visited{transition:color .15s ease-in}.link:focus{outline:1px dotted currentColor}.mw-100{max-width:100%}.mw4{max-width:8rem}.mw7{max-width:48rem}.bg-light-gray{background-color:#f7f7f7}.link-hover:hover{background-color:#1f679e}.white{color:#fff}.bg-white{background-color:#fff}.bg-blue{background-color:#408bc9}.pb2{padding-bottom:.5rem}.pb6{padding-bottom:8rem}.pt3{padding-top:1rem}.pt5{padding-top:4rem}.pv2{padding-top:.5rem;padding-bottom:.5rem}.ph3{padding-left:1rem;padding-right:1rem}.ph4{padding-left:2rem;padding-right:2rem}.ml0{margin-left:0}.mb1{margin-bottom:.25rem}.mb2{margin-bottom:.5rem}.mb3{margin-bottom:1rem}.mt5{margin-top:4rem}.ttu{text-transform:uppercase}.f4{font-size:1.25rem}.f5{font-size:1rem}.f6{font-size:.875rem}.f7{font-size:.75rem}.measure{max-width:30em}.center{margin-left:auto}.center{margin-right:auto}@media screen and (min-width:30em){.f2-ns{font-size:2.25rem}}@media screen and (min-width:30em) and (max-width:60em){.f5-m{font-size:1rem}}@media screen and (min-width:60em){.f4-l{font-size:1.25rem}}
.st0{fill:#FFF}.st1{fill:#f48120}.st2{fill:#faad3f}.st3{fill:#404041}
</style>
</head>
<body class="sans-serif black">
<div class="bt bw2 b--orange bg-white pb6">
<div class="mw7 center ph4 pt3">
<svg id="Layer_2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 109 40.5" class="mw4">
<path class="st0" d="M98.6 14.2L93 12.9l-1-.4-25.7.2v12.4l32.3.1z"/>
<path class="st1" d="M88.1 24c.3-1 .2-2-.3-2.6-.5-.6-1.2-1-2.1-1.1l-17.4-.2c-.1 0-.2-.1-.3-.1-.1-.1-.1-.2 0-.3.1-.2.2-.3.4-.3l17.5-.2c2.1-.1 4.3-1.8 5.1-3.8l1-2.6c0-.1.1-.2 0-.3-1.1-5.1-5.7-8.9-11.1-8.9-5 0-9.3 3.2-10.8 7.7-1-.7-2.2-1.1-3.6-1-2.4.2-4.3 2.2-4.6 4.6-.1.6 0 1.2.1 1.8-3.9.1-7.1 3.3-7.1 7.3 0 .4 0 .7.1 1.1 0 .2.2.3.3.3h32.1c.2 0 .4-.1.4-.3l.3-1.1z"/>
<path class="st2" d="M93.6 12.8h-.5c-.1 0-.2.1-.3.2l-.7 2.4c-.3 1-.2 2 .3 2.6.5.6 1.2 1 2.1 1.1l3.7.2c.1 0 .2.1.3.1.1.1.1.2 0 .3-.1.2-.2.3-.4.3l-3.8.2c-2.1.1-4.3 1.8-5.1 3.8l-.2.9c-.1.1 0 .3.2.3h13.2c.2 0 .3-.1.3-.3.2-.8.4-1.7.4-2.6 0-5.2-4.3-9.5-9.5-9.5"/>
<path class="st3" d="M104.4 30.8c-.5 0-.9-.4-.9-.9s.4-.9.9-.9.9.4.9.9-.4.9-.9.9m0-1.6c-.4 0-.7.3-.7.7 0 .4.3.7.7.7.4 0 .7-.3.7-.7 0-.4-.3-.7-.7-.7m.4 1.2h-.2l-.2-.3h-.2v.3h-.2v-.9h.5c.2 0 .3.1.3.3 0 .1-.1.2-.2.3l.2.3zm-.3-.5c.1 0 .1 0 .1-.1s-.1-.1-.1-.1h-.3v.3h.3zM14.8 29H17v6h3.8v1.9h-6zM23.1 32.9c0-2.3 1.8-4.1 4.3-4.1s4.2 1.8 4.2 4.1-1.8 4.1-4.3 4.1c-2.4 0-4.2-1.8-4.2-4.1m6.3 0c0-1.2-.8-2.2-2-2.2s-2 1-2 2.1.8 2.1 2 2.1c1.2.2 2-.8 2-2M34.3 33.4V29h2.2v4.4c0 1.1.6 1.7 1.5 1.7s1.5-.5 1.5-1.6V29h2.2v4.4c0 2.6-1.5 3.7-3.7 3.7-2.3-.1-3.7-1.2-3.7-3.7M45 29h3.1c2.8 0 4.5 1.6 4.5 3.9s-1.7 4-4.5 4h-3V29zm3.1 5.9c1.3 0 2.2-.7 2.2-2s-.9-2-2.2-2h-.9v4h.9zM55.7 29H62v1.9h-4.1v1.3h3.7V34h-3.7v2.9h-2.2zM65.1 29h2.2v6h3.8v1.9h-6zM76.8 28.9H79l3.4 8H80l-.6-1.4h-3.1l-.6 1.4h-2.3l3.4-8zm2 4.9l-.9-2.2-.9 2.2h1.8zM85.2 29h3.7c1.2 0 2 .3 2.6.9.5.5.7 1.1.7 1.8 0 1.2-.6 2-1.6 2.4l1.9 2.8H90l-1.6-2.4h-1v2.4h-2.2V29zm3.6 3.8c.7 0 1.2-.4 1.2-.9 0-.6-.5-.9-1.2-.9h-1.4v1.9h1.4zM95.3 29h6.4v1.8h-4.2V32h3.8v1.8h-3.8V35h4.3v1.9h-6.5zM10 33.9c-.3.7-1 1.2-1.8 1.2-1.2 0-2-1-2-2.1s.8-2.1 2-2.1c.9 0 1.6.6 1.9 1.3h2.3c-.4-1.9-2-3.3-4.2-3.3-2.4 0-4.3 1.8-4.3 4.1s1.8 4.1 4.2 4.1c2.1 0 3.7-1.4 4.2-3.2H10z"/>
</svg>
<h1 class="f4 f2-ns mt5 fw5">Congrats! You created your first tunnel!</h1>
<p class="f6 f5-m f4-l measure lh-copy fw3">
Argo Tunnel exposes locally running applications to the internet by
running an encrypted, virtual tunnel from your laptop or server to
Cloudflare's edge network.
</p>
<p class="b f5 mt5 fw6">Ready for the next step?</p>
<a
class="fw6 link white bg-blue ph4 pv2 br1 dib f5 link-hover"
style="border-bottom: 1px solid #1f679e"
href="https://developers.cloudflare.com/argo-tunnel/">
Get started here
</a>
<section>
<h4 class="f6 fw4 pt5 mb2">Request</h4>
<dl class="bl bw2 b--orange ph3 pt3 pb2 bg-light-gray f7 code overflow-x-auto mw-100">
<dd class="ml0 mb3 f5">Method: {{.Request.Method}}</dd>
<dd class="ml0 mb3 f5">Protocol: {{.Request.Proto}}</dd>
<dd class="ml0 mb3 f5">Request URL: {{.Request.URL}}</dd>
<dd class="ml0 mb3 f5">Transfer encoding: {{.Request.TransferEncoding}}</dd>
<dd class="ml0 mb3 f5">Host: {{.Request.Host}}</dd>
<dd class="ml0 mb3 f5">Remote address: {{.Request.RemoteAddr}}</dd>
<dd class="ml0 mb3 f5">Request URI: {{.Request.RequestURI}}</dd>
{{range $key, $value := .Request.Header}}
<dd class="ml0 mb3 f5">Header: {{$key}}, Value: {{$value}}</dd>
{{end}}
<dd class="ml0 mb3 f5">Body: {{.Body}}</dd>
</dl>
</section>
</div>
</div>
</body>
</html>
`

func hello(c *cli.Context) error {
	address := fmt.Sprintf(":%d", c.Int("port"))
	listener, err := createListener(address)
	if err != nil {
		return err
	}
	defer listener.Close()
	err = startHelloWorldServer(listener, nil)
	return err
}

func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
	Log.Infof("Starting Hello World server at %s", listener.Addr())
	serverName := defaultServerName
	if hostname, err := os.Hostname(); err == nil {
		serverName = hostname
	}

	upgrader := websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
	}

	httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
	go func() {
		<-shutdownC
		httpServer.Close()
	}()

	http.HandleFunc("/uptime", uptimeHandler(time.Now()))
	http.HandleFunc("/ws", websocketHandler(upgrader))
	http.HandleFunc("/", rootHandler(serverName))
	err := httpServer.Serve(listener)
	return err
}

func createListener(address string) (net.Listener, error) {
	certificate, err := tlsconfig.GetHelloCertificate()
	if err != nil {
		return nil, err
	}

	// If the port in address is empty, a port number is automatically chosen
	listener, err := tls.Listen(
		"tcp",
		address,
		&tls.Config{Certificates: []tls.Certificate{certificate}})

	return listener, err
}

func uptimeHandler(startTime time.Time) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Note that if autoupdate is enabled, the uptime is reset when a new client
		// release is available
		resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()}
		respJson, err := json.Marshal(resp)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
		} else {
			w.Header().Set("Content-Type", "application/json")
			w.Write(respJson)
		}
	}
}

func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			return
		}
		defer conn.Close()

		for {
			mt, message, err := conn.ReadMessage()
			if err != nil {
				break
			}

			if err := conn.WriteMessage(mt, message); err != nil {
				break
			}
		}
	}
}

func rootHandler(serverName string) http.HandlerFunc {
	responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
	return func(w http.ResponseWriter, r *http.Request) {
		var buffer bytes.Buffer
		var body string
		rawBody, err := ioutil.ReadAll(r.Body)
		if err == nil {
			body = string(rawBody)
		} else {
			body = ""
		}
		err = responseTemplate.Execute(&buffer, &templateData{
			ServerName: serverName,
			Request:    r,
			Body:       body,
		})
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			fmt.Fprintf(w, "error: %v", err)
		} else {
			buffer.WriteTo(w)
		}
	}
}
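
For reference, a minimal client-side check of the hello server's /uptime endpoint might look like the sketch below. It is not part of this commit; it assumes the server was started on the default port 8080 (the "hello" command's default) and that its TLS certificate is the bundled self-signed test certificate, which is why verification is skipped for this local check only.

package main

import (
	"crypto/tls"
	"fmt"
	"io/ioutil"
	"net/http"
)

func main() {
	// Skip certificate verification only because the hello server presents
	// a self-signed test certificate on localhost.
	client := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
		},
	}
	resp, err := client.Get("https://127.0.0.1:8080/uptime")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := ioutil.ReadAll(resp.Body)
	fmt.Println(string(body)) // e.g. {"startTime":"...","uptime":"1m2.3s"}
}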

@@ -0,0 +1,35 @@
package main

import (
	"testing"
)

func TestCreateListenerHostAndPortSuccess(t *testing.T) {
	listener, err := createListener("localhost:1234")
	if err != nil {
		t.Fatal(err)
	}
	if listener.Addr().String() == "" {
		t.Fatal("Failed to find an available port")
	}
}

func TestCreateListenerOnlyHostSuccess(t *testing.T) {
	listener, err := createListener("localhost:")
	if err != nil {
		t.Fatal(err)
	}
	if listener.Addr().String() == "" {
		t.Fatal("Failed to find an available port")
	}
}

func TestCreateListenerOnlyPortSuccess(t *testing.T) {
	listener, err := createListener(":8888")
	if err != nil {
		t.Fatal(err)
	}
	if listener.Addr().String() == "" {
		t.Fatal("Failed to find an available port")
	}
}
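
The handlers added in this release can also be exercised without opening a socket. The sketch below, which is illustrative only and not part of the commit, drives uptimeHandler (defined above) through net/http/httptest; the test name and expectations are the author's assumptions about reasonable checks, not existing tests.

package main

import (
	"net/http"
	"net/http/httptest"
	"testing"
	"time"
)

func TestUptimeHandlerReturnsJSON(t *testing.T) {
	handler := uptimeHandler(time.Now())

	req := httptest.NewRequest("GET", "/uptime", nil)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	// On success the handler writes a JSON body with a Content-Type header.
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d", rec.Code)
	}
	if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
		t.Fatalf("expected application/json, got %q", ct)
	}
}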

@@ -0,0 +1,292 @@
// +build linux

package main

import (
	"fmt"
	"os"
	"path/filepath"

	cli "gopkg.in/urfave/cli.v2"
)

func runApp(app *cli.App) {
	app.Commands = append(app.Commands, &cli.Command{
		Name:  "service",
		Usage: "Manages the Argo Tunnel system service",
		Subcommands: []*cli.Command{
			&cli.Command{
				Name:   "install",
				Usage:  "Install Argo Tunnel as a system service",
				Action: installLinuxService,
			},
			&cli.Command{
				Name:   "uninstall",
				Usage:  "Uninstall the Argo Tunnel service",
				Action: uninstallLinuxService,
			},
		},
	})
	app.Run(os.Args)
}

const serviceConfigDir = "/etc/cloudflared"

var systemdTemplates = []ServiceTemplate{
	{
		Path: "/etc/systemd/system/cloudflared.service",
		Content: `[Unit]
Description=Argo Tunnel
After=network.target

[Service]
TimeoutStartSec=0
Type=notify
ExecStart={{ .Path }} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --no-autoupdate
Restart=on-failure
RestartSec=5s

[Install]
WantedBy=multi-user.target
`,
	},
	{
		Path: "/etc/systemd/system/cloudflared-update.service",
		Content: `[Unit]
Description=Update Argo Tunnel
After=network.target

[Service]
ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 64 ]; then systemctl restart cloudflared; exit 0; fi; exit $code'
`,
	},
	{
		Path: "/etc/systemd/system/cloudflared-update.timer",
		Content: `[Unit]
Description=Update Argo Tunnel

[Timer]
OnUnitActiveSec=1d

[Install]
WantedBy=timers.target
`,
	},
}

var sysvTemplate = ServiceTemplate{
	Path:     "/etc/init.d/cloudflared",
	FileMode: 0755,
	Content: `# For RedHat and cousins:
# chkconfig: 2345 99 01
# description: Argo Tunnel agent
# processname: {{.Path}}
### BEGIN INIT INFO
# Provides:          {{.Path}}
# Required-Start:
# Required-Stop:
# Default-Start:     2 3 4 5
# Default-Stop:      0 1 6
# Short-Description: Argo Tunnel
# Description:       Argo Tunnel agent
### END INIT INFO
cmd="{{.Path}} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --pidfile /var/run/$name.pid --autoupdate-freq 24h0m0s"
name=$(basename $(readlink -f $0))
pid_file="/var/run/$name.pid"
stdout_log="/var/log/$name.log"
stderr_log="/var/log/$name.err"
[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
get_pid() {
    cat "$pid_file"
}
is_running() {
    [ -f "$pid_file" ] && ps $(get_pid) > /dev/null 2>&1
}
case "$1" in
    start)
        if is_running; then
            echo "Already started"
        else
            echo "Starting $name"
            $cmd >> "$stdout_log" 2>> "$stderr_log" &
            echo $! > "$pid_file"
            if ! is_running; then
                echo "Unable to start, see $stdout_log and $stderr_log"
                exit 1
            fi
        fi
    ;;
    stop)
        if is_running; then
            echo -n "Stopping $name.."
            kill $(get_pid)
            for i in {1..10}
            do
                if ! is_running; then
                    break
                fi
                echo -n "."
                sleep 1
            done
            echo
            if is_running; then
                echo "Not stopped; may still be shutting down or shutdown may have failed"
                exit 1
            else
                echo "Stopped"
                if [ -f "$pid_file" ]; then
                    rm "$pid_file"
                fi
            fi
        else
            echo "Not running"
        fi
    ;;
    restart)
        $0 stop
        if is_running; then
            echo "Unable to stop, will not attempt to start"
            exit 1
        fi
        $0 start
    ;;
    status)
        if is_running; then
            echo "Running"
        else
            echo "Stopped"
            exit 1
        fi
    ;;
    *)
        echo "Usage: $0 {start|stop|restart|status}"
        exit 1
    ;;
esac
exit 0
`,
}

func isSystemd() bool {
	if _, err := os.Stat("/run/systemd/system"); err == nil {
		return true
	}
	return false
}

func installLinuxService(c *cli.Context) error {
	etPath, err := os.Executable()
	if err != nil {
		return fmt.Errorf("error determining executable path: %v", err)
	}
	templateArgs := ServiceTemplateArgs{Path: etPath}

	defaultConfigDir := filepath.Dir(c.String("config"))
	defaultConfigFile := filepath.Base(c.String("config"))
	if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile); err != nil {
		Log.WithError(err).Infof("Failed to copy user configuration. Before running the service, ensure that %s contains two files, %s and %s",
			serviceConfigDir, credentialFile, defaultConfigFiles[0])
		return err
	}

	switch {
	case isSystemd():
		Log.Infof("Using Systemd")
		return installSystemd(&templateArgs)
	default:
		Log.Infof("Using Sysv")
		return installSysv(&templateArgs)
	}
}

func installSystemd(templateArgs *ServiceTemplateArgs) error {
	for _, serviceTemplate := range systemdTemplates {
		err := serviceTemplate.Generate(templateArgs)
		if err != nil {
			Log.WithError(err).Infof("error generating service template")
			return err
		}
	}
	if err := runCommand("systemctl", "enable", "cloudflared.service"); err != nil {
		Log.WithError(err).Infof("systemctl enable cloudflared.service error")
		return err
	}
	if err := runCommand("systemctl", "start", "cloudflared-update.timer"); err != nil {
		Log.WithError(err).Infof("systemctl start cloudflared-update.timer error")
		return err
	}
	Log.Infof("systemctl daemon-reload")
	return runCommand("systemctl", "daemon-reload")
}

func installSysv(templateArgs *ServiceTemplateArgs) error {
	confPath, err := sysvTemplate.ResolvePath()
	if err != nil {
		Log.WithError(err).Infof("error resolving system path")
		return err
	}
	if err := sysvTemplate.Generate(templateArgs); err != nil {
		Log.WithError(err).Infof("error generating system template")
		return err
	}
	for _, i := range [...]string{"2", "3", "4", "5"} {
		if err := os.Symlink(confPath, "/etc/rc"+i+".d/S50et"); err != nil {
			continue
		}
	}
	for _, i := range [...]string{"0", "1", "6"} {
		if err := os.Symlink(confPath, "/etc/rc"+i+".d/K02et"); err != nil {
			continue
		}
	}
	return nil
}

func uninstallLinuxService(c *cli.Context) error {
	switch {
	case isSystemd():
		Log.Infof("Using Systemd")
		return uninstallSystemd()
	default:
		Log.Infof("Using Sysv")
		return uninstallSysv()
	}
}

func uninstallSystemd() error {
	if err := runCommand("systemctl", "disable", "cloudflared.service"); err != nil {
		Log.WithError(err).Infof("systemctl disable cloudflared.service error")
		return err
	}
	if err := runCommand("systemctl", "stop", "cloudflared-update.timer"); err != nil {
		Log.WithError(err).Infof("systemctl stop cloudflared-update.timer error")
		return err
	}
	for _, serviceTemplate := range systemdTemplates {
		if err := serviceTemplate.Remove(); err != nil {
			Log.WithError(err).Infof("error removing service template")
			return err
		}
	}
	Log.Infof("Successfully uninstalled cloudflared service")
	return nil
}

func uninstallSysv() error {
	if err := sysvTemplate.Remove(); err != nil {
		Log.WithError(err).Infof("error removing service template")
		return err
	}
	for _, i := range [...]string{"2", "3", "4", "5"} {
		if err := os.Remove("/etc/rc" + i + ".d/S50et"); err != nil {
			continue
		}
	}
	for _, i := range [...]string{"0", "1", "6"} {
		if err := os.Remove("/etc/rc" + i + ".d/K02et"); err != nil {
			continue
		}
	}
	Log.Infof("Successfully uninstalled cloudflared service")
	return nil
}
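
ServiceTemplate and ServiceTemplateArgs are defined elsewhere in the package and are not part of this diff, so the exact rendering code is not shown here. Under the assumption that it is a text/template substitution, the hypothetical sketch below only illustrates how a unit file like the one above turns {{ .Path }} into the executable path; the real ServiceTemplate also resolves the destination path and writes the file with the right mode.

package main

import (
	"bytes"
	"fmt"
	"text/template"
)

// exampleArgs mirrors the single field the unit templates above reference.
type exampleArgs struct {
	Path string
}

// renderUnit shows only the substitution step.
func renderUnit(content string, args exampleArgs) (string, error) {
	tmpl, err := template.New("unit").Parse(content)
	if err != nil {
		return "", err
	}
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, args); err != nil {
		return "", err
	}
	return buf.String(), nil
}

func main() {
	unit, err := renderUnit("ExecStart={{ .Path }} --config /etc/cloudflared/config.yml\n",
		exampleArgs{Path: "/usr/local/bin/cloudflared"})
	if err != nil {
		panic(err)
	}
	fmt.Print(unit) // ExecStart=/usr/local/bin/cloudflared --config /etc/cloudflared/config.yml
}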

@@ -0,0 +1,194 @@
package main

import (
	"crypto/rand"
	"encoding/base32"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"syscall"
	"time"

	homedir "github.com/mitchellh/go-homedir"
	cli "gopkg.in/urfave/cli.v2"
)

const baseLoginURL = "https://www.cloudflare.com/a/warp"
const baseCertStoreURL = "https://login.cloudflarewarp.com"
const clientTimeout = time.Minute * 20

func login(c *cli.Context) error {
	configPath, err := homedir.Expand(defaultConfigDirs[0])
	if err != nil {
		return err
	}
	ok, err := fileExists(configPath)
	if !ok && err == nil {
		// create config directory if doesn't already exist
		err = os.Mkdir(configPath, 0700)
	}
	if err != nil {
		return err
	}
	path := filepath.Join(configPath, credentialFile)
	fileInfo, err := os.Stat(path)
	if err == nil && fileInfo.Size() > 0 {
		fmt.Fprintf(os.Stderr, `You have an existing certificate at %s which login would overwrite.
If this is intentional, please move or delete that file then run this command again.
`, path)
		return nil
	}
	if err != nil && err.(*os.PathError).Err != syscall.ENOENT {
		return err
	}

	// for local debugging
	baseURL := baseCertStoreURL
	if c.IsSet("url") {
		baseURL = c.String("url")
	}
	// Generate a random post URL
	certURL := baseURL + generateRandomPath()
	loginURL, err := url.Parse(baseLoginURL)
	if err != nil {
		// shouldn't happen, URL is hardcoded
		return err
	}
	loginURL.RawQuery = "callback=" + url.QueryEscape(certURL)

	err = open(loginURL.String())
	if err != nil {
		fmt.Fprintf(os.Stderr, `Please open the following URL and log in with your Cloudflare account:

%s

Leave cloudflared running to install the certificate automatically.
`, loginURL.String())
	} else {
		fmt.Fprintf(os.Stderr, `A browser window should have opened at the following URL:

%s

If the browser failed to open, open it yourself and visit the URL above.

`, loginURL.String())
	}

	if download(certURL, path) {
		fmt.Fprintf(os.Stderr, `You have successfully logged in.
If you wish to copy your credentials to a server, they have been saved to:
%s
`, path)
	} else {
		fmt.Fprintf(os.Stderr, `Failed to write the certificate due to the following error:
%v

Your browser will download the certificate instead. You will have to manually
copy it to the following path:

%s

`, err, path)
	}
	return nil
}

// generateRandomPath generates a random URL to associate with the certificate.
func generateRandomPath() string {
	randomBytes := make([]byte, 40)
	_, err := rand.Read(randomBytes)
	if err != nil {
		panic(err)
	}
	return "/" + base32.StdEncoding.EncodeToString(randomBytes)
}

// open opens the specified URL in the default browser of the user.
func open(url string) error {
	var cmd string
	var args []string

	switch runtime.GOOS {
	case "windows":
		cmd = "cmd"
		args = []string{"/c", "start"}
	case "darwin":
		cmd = "open"
	default: // "linux", "freebsd", "openbsd", "netbsd"
		cmd = "xdg-open"
	}
	args = append(args, url)
	return exec.Command(cmd, args...).Start()
}

func download(certURL, filePath string) bool {
	client := &http.Client{Timeout: clientTimeout}
	// attempt a (long-running) certificate get
	for i := 0; i < 20; i++ {
		ok, err := tryDownload(client, certURL, filePath)
		if ok {
			putSuccess(client, certURL)
			return true
		}
		if err != nil {
			Log.WithError(err).Error("Error fetching certificate")
			return false
		}
	}
	return false
}

func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err error) {
	resp, err := client.Get(certURL)
	if err != nil {
		return false, err
	}
	defer resp.Body.Close()
	if resp.StatusCode == 404 {
		return false, nil
	}
	if resp.StatusCode != 200 {
		return false, fmt.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
	}
	if resp.Header.Get("Content-Type") != "application/x-pem-file" {
		return false, fmt.Errorf("Unexpected content type %s", resp.Header.Get("Content-Type"))
	}
	// write response
	file, err := os.Create(filePath)
	if err != nil {
		return false, err
	}
	defer file.Close()
	written, err := io.Copy(file, resp.Body)
	switch {
	case err != nil:
		return false, err
	case resp.ContentLength != written && resp.ContentLength != -1:
		return false, fmt.Errorf("Short read (%d bytes) from server while writing certificate", written)
	default:
		return true, nil
	}
}

func putSuccess(client *http.Client, certURL string) {
	// indicate success to the relay server
	req, err := http.NewRequest("PUT", certURL+"/ok", nil)
	if err != nil {
		Log.WithError(err).Error("HTTP request error")
		return
	}
	resp, err := client.Do(req)
	if err != nil {
		Log.WithError(err).Error("HTTP error")
		return
	}
	resp.Body.Close()
	if resp.StatusCode != 200 {
		Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
	}
}
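
A quick check on the callback path size: generateRandomPath encodes 40 random bytes with standard base32, and since base32 maps every 5 input bytes to 8 output characters, the path is exactly 64 characters with no padding (65 with the leading slash). The snippet below, illustrative only, just confirms that arithmetic.

package main

import (
	"crypto/rand"
	"encoding/base32"
	"fmt"
)

func main() {
	b := make([]byte, 40)
	if _, err := rand.Read(b); err != nil {
		panic(err)
	}
	path := "/" + base32.StdEncoding.EncodeToString(b)
	// 40 bytes / 5 * 8 = 64 characters, plus the leading slash.
	fmt.Println(len(path)) // 65
}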

@@ -0,0 +1,97 @@
// +build darwin

package main

import (
	"fmt"
	"os"

	"gopkg.in/urfave/cli.v2"
)

const launchAgentIdentifier = "com.cloudflare.cloudflared"

func runApp(app *cli.App) {
	app.Commands = append(app.Commands, &cli.Command{
		Name:  "service",
		Usage: "Manages the Argo Tunnel launch agent",
		Subcommands: []*cli.Command{
			{
				Name:   "install",
				Usage:  "Install Argo Tunnel as a user launch agent",
				Action: installLaunchd,
			},
			{
				Name:   "uninstall",
				Usage:  "Uninstall the Argo Tunnel launch agent",
				Action: uninstallLaunchd,
			},
		},
	})
	app.Run(os.Args)
}

var launchdTemplate = ServiceTemplate{
	Path: fmt.Sprintf("~/Library/LaunchAgents/%s.plist", launchAgentIdentifier),
	Content: fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
	<dict>
		<key>Label</key>
		<string>%s</string>
		<key>Program</key>
		<string>{{ .Path }}</string>
		<key>RunAtLoad</key>
		<true/>
		<key>StandardOutPath</key>
		<string>/tmp/%s.out.log</string>
		<key>StandardErrorPath</key>
		<string>/tmp/%s.err.log</string>
		<key>KeepAlive</key>
		<dict>
			<key>NetworkState</key>
			<true/>
		</dict>
		<key>ThrottleInterval</key>
		<integer>20</integer>
	</dict>
</plist>`, launchAgentIdentifier, launchAgentIdentifier, launchAgentIdentifier),
}

func installLaunchd(c *cli.Context) error {
	Log.Infof("Installing Argo Tunnel as a user launch agent")
	etPath, err := os.Executable()
	if err != nil {
		Log.WithError(err).Infof("error determining executable path")
		return fmt.Errorf("error determining executable path: %v", err)
	}
	templateArgs := ServiceTemplateArgs{Path: etPath}
	err = launchdTemplate.Generate(&templateArgs)
	if err != nil {
		Log.WithError(err).Infof("error generating launchd template")
		return err
	}
	plistPath, err := launchdTemplate.ResolvePath()
	if err != nil {
		Log.WithError(err).Infof("error resolving launchd template path")
		return err
	}
	Log.Infof("Outputs are logged in %s and %s", fmt.Sprintf("/tmp/%s.out.log", launchAgentIdentifier), fmt.Sprintf("/tmp/%s.err.log", launchAgentIdentifier))
	return runCommand("launchctl", "load", plistPath)
}

func uninstallLaunchd(c *cli.Context) error {
	Log.Infof("Uninstalling Argo Tunnel as a user launch agent")
	plistPath, err := launchdTemplate.ResolvePath()
	if err != nil {
		Log.WithError(err).Infof("error resolving launchd template path")
		return err
	}
	err = runCommand("launchctl", "unload", plistPath)
	if err != nil {
		Log.WithError(err).Infof("error unloading")
		return err
	}
	Log.Infof("Outputs are logged in %s and %s", fmt.Sprintf("/tmp/%s.out.log", launchAgentIdentifier), fmt.Sprintf("/tmp/%s.err.log", launchAgentIdentifier))
	return launchdTemplate.Remove()
}

@@ -0,0 +1,794 @@
package main

import (
	"crypto/tls"
	"encoding/hex"
	"fmt"
	"io/ioutil"
	"math/rand"
	"net"
	"net/http"
	"os"
	"os/signal"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/cloudflare/cloudflared/metrics"
	"github.com/cloudflare/cloudflared/origin"
	"github.com/cloudflare/cloudflared/tlsconfig"
	"github.com/cloudflare/cloudflared/tunneldns"
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
	"github.com/cloudflare/cloudflared/validation"

	"github.com/facebookgo/grace/gracenet"
	"github.com/getsentry/raven-go"
	"github.com/mitchellh/go-homedir"
	"github.com/rifflock/lfshook"
	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh/terminal"
	"gopkg.in/urfave/cli.v2"
	"gopkg.in/urfave/cli.v2/altsrc"

	"github.com/coreos/go-systemd/daemon"
	"github.com/pkg/errors"
)

const (
	sentryDSN           = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
	credentialFile      = "cert.pem"
	quickStartUrl       = "https://developers.cloudflare.com/argo-tunnel/quickstart/quickstart/"
	noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/"
	licenseUrl          = "https://developers.cloudflare.com/argo-tunnel/licence/"
)

var listeners = gracenet.Net{}
var Version = "DEV"
var BuildTime = "unknown"
var Log *logrus.Logger
var defaultConfigFiles = []string{"config.yml", "config.yaml"}

// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
var defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp"}

// Shutdown channel used by the app. When closed, app must terminate.
// May be closed by the Windows service runner.
var shutdownC chan struct{}

type BuildAndRuntimeInfo struct {
	GoOS        string                 `json:"go_os"`
	GoVersion   string                 `json:"go_version"`
	GoArch      string                 `json:"go_arch"`
	WarpVersion string                 `json:"warp_version"`
	WarpFlags   map[string]interface{} `json:"warp_flags"`
	WarpEnvs    map[string]string      `json:"warp_envs"`
}

func main() {
	metrics.RegisterBuildInfo(BuildTime, Version)
	raven.SetDSN(sentryDSN)
	raven.SetRelease(Version)
	shutdownC = make(chan struct{})
	app := &cli.App{}
	app.Name = "cloudflared"
	app.Copyright = fmt.Sprintf(`(c) %d Cloudflare Inc.
Use is subject to the license agreement at %s`, time.Now().Year(), licenseUrl)
	app.Usage = "Cloudflare reverse tunnelling proxy agent"
	app.ArgsUsage = "origin-url"
	app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime)
	app.Description = `A reverse tunnel proxy agent that connects to Cloudflare's infrastructure.
Upon connecting, you are assigned a unique subdomain on cftunnel.com.
You need to specify a hostname on a zone you control.
A DNS record will be created to CNAME your hostname to the unique subdomain on cftunnel.com.

Requests made to Cloudflare's servers for your hostname will be proxied
through the tunnel to your local webserver.`
	app.Flags = []cli.Flag{
		&cli.StringFlag{
			Name:  "config",
			Usage: "Specifies a config file in YAML format.",
			Value: findDefaultConfigPath(),
		},
		altsrc.NewDurationFlag(&cli.DurationFlag{
			Name:  "autoupdate-freq",
			Usage: "Autoupdate frequency. Default is 24h.",
			Value: time.Hour * 24,
		}),
		altsrc.NewBoolFlag(&cli.BoolFlag{
			Name:  "no-autoupdate",
			Usage: "Disable periodic check for updates, restarting the server with the new version.",
			Value: false,
		}),
		altsrc.NewBoolFlag(&cli.BoolFlag{
			Name:   "is-autoupdated",
			Usage:  "Signal the new process that Argo Tunnel client has been autoupdated",
			Value:  false,
			Hidden: true,
		}),
		altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
			Name:    "edge",
			Usage:   "Address of the Cloudflare tunnel server.",
			EnvVars: []string{"TUNNEL_EDGE"},
			Hidden:  true,
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "cacert",
			Usage:   "Certificate Authority authenticating the Cloudflare tunnel connection.",
			EnvVars: []string{"TUNNEL_CACERT"},
			Hidden:  true,
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "origincert",
			Usage:   "Path to the certificate generated for your origin when you run cloudflared login.",
			EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
			Value:   findDefaultOriginCertPath(),
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "url",
			Value:   "https://localhost:8080",
			Usage:   "Connect to the local webserver at `URL`.",
			EnvVars: []string{"TUNNEL_URL"},
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "hostname",
			Usage:   "Set a hostname on a Cloudflare zone to route traffic through this tunnel.",
			EnvVars: []string{"TUNNEL_HOSTNAME"},
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "origin-server-name",
			Usage:   "Hostname on the origin server certificate.",
			EnvVars: []string{"TUNNEL_ORIGIN_SERVER_NAME"},
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "id",
			Usage:   "A unique identifier used to tie connections to this tunnel instance.",
			EnvVars: []string{"TUNNEL_ID"},
			Hidden:  true,
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "lb-pool",
			Usage:   "The name of a (new/existing) load balancing pool to add this origin to.",
			EnvVars: []string{"TUNNEL_LB_POOL"},
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "api-key",
			Usage:   "This parameter has been deprecated since version 2017.10.1.",
			EnvVars: []string{"TUNNEL_API_KEY"},
			Hidden:  true,
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "api-email",
			Usage:   "This parameter has been deprecated since version 2017.10.1.",
			EnvVars: []string{"TUNNEL_API_EMAIL"},
			Hidden:  true,
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "api-ca-key",
			Usage:   "This parameter has been deprecated since version 2017.10.1.",
			EnvVars: []string{"TUNNEL_API_CA_KEY"},
			Hidden:  true,
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "metrics",
			Value:   "localhost:",
			Usage:   "Listen address for metrics reporting.",
			EnvVars: []string{"TUNNEL_METRICS"},
		}),
		altsrc.NewDurationFlag(&cli.DurationFlag{
			Name:    "metrics-update-freq",
			Usage:   "Frequency to update tunnel metrics",
			Value:   time.Second * 5,
			EnvVars: []string{"TUNNEL_METRICS_UPDATE_FREQ"},
		}),
		altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
			Name:    "tag",
			Usage:   "Custom tags used to identify this tunnel, in format `KEY=VALUE`. Multiple tags may be specified",
			EnvVars: []string{"TUNNEL_TAG"},
		}),
		altsrc.NewDurationFlag(&cli.DurationFlag{
			Name:   "heartbeat-interval",
			Usage:  "Minimum idle time before sending a heartbeat.",
			Value:  time.Second * 5,
			Hidden: true,
		}),
		altsrc.NewUint64Flag(&cli.Uint64Flag{
			Name:   "heartbeat-count",
			Usage:  "Minimum number of unacked heartbeats to send before closing the connection.",
			Value:  5,
			Hidden: true,
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "loglevel",
			Value:   "info",
			Usage:   "Application logging level {panic, fatal, error, warn, info, debug}",
			EnvVars: []string{"TUNNEL_LOGLEVEL"},
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "proto-loglevel",
			Value:   "warn",
			Usage:   "Protocol logging level {panic, fatal, error, warn, info, debug}",
			EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL"},
		}),
		altsrc.NewUintFlag(&cli.UintFlag{
			Name:    "retries",
			Value:   5,
			Usage:   "Maximum number of retries for connection/protocol errors.",
			EnvVars: []string{"TUNNEL_RETRIES"},
		}),
		altsrc.NewBoolFlag(&cli.BoolFlag{
			Name:    "hello-world",
			Value:   false,
			Usage:   "Run Hello World Server",
			EnvVars: []string{"TUNNEL_HELLO_WORLD"},
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "pidfile",
			Usage:   "Write the application's PID to this file after first successful connection.",
			EnvVars: []string{"TUNNEL_PIDFILE"},
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "logfile",
			Usage:   "Save application log to this file for reporting issues.",
			EnvVars: []string{"TUNNEL_LOGFILE"},
		}),
		altsrc.NewIntFlag(&cli.IntFlag{
			Name:   "ha-connections",
			Value:  4,
			Hidden: true,
		}),
		altsrc.NewDurationFlag(&cli.DurationFlag{
			Name:  "proxy-connect-timeout",
			Usage: "HTTP proxy timeout for establishing a new connection",
			Value: time.Second * 30,
		}),
		altsrc.NewDurationFlag(&cli.DurationFlag{
			Name:  "proxy-tls-timeout",
			Usage: "HTTP proxy timeout for completing a TLS handshake",
			Value: time.Second * 10,
		}),
		altsrc.NewDurationFlag(&cli.DurationFlag{
			Name:  "proxy-tcp-keepalive",
			Usage: "HTTP proxy TCP keepalive duration",
			Value: time.Second * 30,
		}),
		altsrc.NewBoolFlag(&cli.BoolFlag{
			Name:  "proxy-no-happy-eyeballs",
			Usage: "HTTP proxy should disable \"happy eyeballs\" for IPv4/v6 fallback",
		}),
		altsrc.NewIntFlag(&cli.IntFlag{
			Name:  "proxy-keepalive-connections",
			Usage: "HTTP proxy maximum keepalive connection pool size",
			Value: 100,
		}),
		altsrc.NewDurationFlag(&cli.DurationFlag{
			Name:  "proxy-keepalive-timeout",
			Usage: "HTTP proxy timeout for closing an idle connection",
			Value: time.Second * 90,
		}),
		altsrc.NewBoolFlag(&cli.BoolFlag{
			Name:    "proxy-dns",
			Usage:   "Run a DNS over HTTPS proxy server.",
			EnvVars: []string{"TUNNEL_DNS"},
		}),
		altsrc.NewUintFlag(&cli.UintFlag{
			Name:    "proxy-dns-port",
			Value:   53,
			Usage:   "Listen on given port for the DNS over HTTPS proxy server.",
			EnvVars: []string{"TUNNEL_DNS_PORT"},
		}),
		altsrc.NewStringFlag(&cli.StringFlag{
			Name:    "proxy-dns-address",
			Usage:   "Listen address for the DNS over HTTPS proxy server.",
			Value:   "localhost",
			EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
		}),
		altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
			Name:    "proxy-dns-upstream",
			Usage:   "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
			Value:   cli.NewStringSlice("https://dns.cloudflare.com/.well-known/dns-query"),
			EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
		}),
	}
	app.Action = func(c *cli.Context) error {
		raven.CapturePanic(func() { startServer(c) }, nil)
		return nil
	}
	app.Before = func(context *cli.Context) error {
		Log = logrus.New()
		inputSource, err := findInputSourceContext(context)
		if err != nil {
			Log.WithError(err).Infof("Cannot load configuration from %s", context.String("config"))
			return err
		} else if inputSource != nil {
			err := altsrc.ApplyInputSourceValues(context, inputSource, app.Flags)
			if err != nil {
				Log.WithError(err).Infof("Cannot apply configuration from %s", context.String("config"))
				return err
			}
			Log.Infof("Applied configuration from %s", context.String("config"))
		}
		return nil
	}
	app.Commands = []*cli.Command{
		{
			Name:      "update",
			Action:    update,
			Usage:     "Update the agent if a new version exists",
			ArgsUsage: " ",
			Description: `Looks for a new version on the official download server.
If a new version exists, updates the agent binary and quits.
Otherwise, does nothing.

To determine if an update happened in a script, check for error code 64.`,
		},
		{
			Name:      "login",
			Action:    login,
			Usage:     "Generate a configuration file with your login details",
			ArgsUsage: " ",
			Flags: []cli.Flag{
				&cli.StringFlag{
					Name:   "url",
					Hidden: true,
				},
			},
		},
		{
			Name:   "hello",
			Action: hello,
			Usage:  "Run a simple \"Hello World\" server for testing Argo Tunnel.",
			Flags: []cli.Flag{
				&cli.IntFlag{
					Name:  "port",
					Usage: "Listen on the selected port.",
					Value: 8080,
				},
			},
			ArgsUsage: " ", // can't be the empty string or we get the default output
		},
		{
			Name:   "proxy-dns",
			Action: tunneldns.Run,
			Usage:  "Run a DNS over HTTPS proxy server.",
			Flags: []cli.Flag{
				&cli.StringFlag{
					Name:    "metrics",
					Value:   "localhost:",
					Usage:   "Listen address for metrics reporting.",
					EnvVars: []string{"TUNNEL_METRICS"},
				},
				&cli.StringFlag{
					Name:    "address",
					Usage:   "Listen address for the DNS over HTTPS proxy server.",
					Value:   "localhost",
					EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
				},
				&cli.IntFlag{
					Name:    "port",
					Usage:   "Listen on given port for the DNS over HTTPS proxy server.",
					Value:   53,
					EnvVars: []string{"TUNNEL_DNS_PORT"},
				},
				&cli.StringSliceFlag{
					Name:    "upstream",
					Usage:   "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
					Value:   cli.NewStringSlice("https://dns.cloudflare.com/.well-known/dns-query"),
					EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
				},
			},
			ArgsUsage: " ", // can't be the empty string or we get the default output
		},
	}
	runApp(app)
}

func startServer(c *cli.Context) {
	var wg sync.WaitGroup
	errC := make(chan error)
	wg.Add(2)

	// If the user chooses to supply all options through env variables,
	// c.NumFlags() == 0 && c.NArg() == 0. For cloudflared to work, the user needs to at
	// least provide a hostname.
	if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" {
		Log.Infof("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl)
		cli.ShowAppHelp(c)
		return
	}
	logLevel, err := logrus.ParseLevel(c.String("loglevel"))
	if err != nil {
		Log.WithError(err).Fatal("Unknown logging level specified")
	}
	Log.SetLevel(logLevel)

	protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
	if err != nil {
		Log.WithError(err).Fatal("Unknown protocol logging level specified")
	}
	protoLogger := logrus.New()
	protoLogger.Level = protoLogLevel

	if c.String("logfile") != "" {
		if err := initLogFile(c, protoLogger); err != nil {
			Log.Error(err)
		}
	}

	if isAutoupdateEnabled(c) {
		if initUpdate() {
			return
		}
		Log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
		go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
	}

	hostname, err := validation.ValidateHostname(c.String("hostname"))
	if err != nil {
		Log.WithError(err).Fatal("Invalid hostname")
	}
	clientID := c.String("id")
	if !c.IsSet("id") {
		clientID = generateRandomClientID()
	}

	tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
	if err != nil {
		Log.WithError(err).Fatal("Tag parse failure")
	}

	tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
	if c.IsSet("hello-world") {
		wg.Add(1)
		listener, err := createListener("127.0.0.1:")
		if err != nil {
			listener.Close()
			Log.WithError(err).Fatal("Cannot start Hello World Server")
		}
		go func() {
			startHelloWorldServer(listener, shutdownC)
			wg.Done()
			listener.Close()
		}()
		c.Set("url", "https://"+listener.Addr().String())
	}

	if c.IsSet("proxy-dns") {
		wg.Add(1)
		listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(c.Uint("proxy-dns-port")), c.StringSlice("proxy-dns-upstream"))
		if err != nil {
			listener.Stop()
			Log.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server")
		}
		go func() {
			listener.Start()
			<-shutdownC
			listener.Stop()
			wg.Done()
		}()
	}

	url, err := validateUrl(c)
	if err != nil {
		Log.WithError(err).Fatal("Error validating url")
	}
	Log.Infof("Proxying tunnel requests to %s", url)

	// Fail if the user provided an old authentication method
	if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
		Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflared login")
	}

	// Check that the user has acquired a certificate using the log in command
	originCertPath, err := homedir.Expand(c.String("origincert"))
	if err != nil {
		Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
	}
	ok, err := fileExists(originCertPath)
	if err != nil {
		Log.Fatalf("Cannot check if origin cert exists at path %s", c.String("origincert"))
	}
	if !ok {
		Log.Fatalf(`Cannot find a valid certificate for your origin at the path:

%s

If the path above is wrong, specify the path with the -origincert option.
If you don't have a certificate signed by Cloudflare, run the command:

%s login
`, originCertPath, os.Args[0])
	}
	// Easier to send the certificate as []byte via RPC than decoding it at this point
	originCert, err := ioutil.ReadFile(originCertPath)
	if err != nil {
		Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
	}

	tunnelMetrics := origin.NewTunnelMetrics()
	httpTransport := &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		DialContext: (&net.Dialer{
			Timeout:   c.Duration("proxy-connect-timeout"),
			KeepAlive: c.Duration("proxy-tcp-keepalive"),
			DualStack: !c.Bool("proxy-no-happy-eyeballs"),
		}).DialContext,
		MaxIdleConns:          c.Int("proxy-keepalive-connections"),
		IdleConnTimeout:       c.Duration("proxy-keepalive-timeout"),
		TLSHandshakeTimeout:   c.Duration("proxy-tls-timeout"),
		ExpectContinueTimeout: 1 * time.Second,
		TLSClientConfig:       &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
	}

	if !c.IsSet("hello-world") && c.IsSet("origin-server-name") {
		httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
	}

	tunnelConfig := &origin.TunnelConfig{
		EdgeAddrs:         c.StringSlice("edge"),
		OriginUrl:         url,
		Hostname:          hostname,
		OriginCert:        originCert,
		TlsConfig:         tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
		ClientTlsConfig:   httpTransport.TLSClientConfig,
		Retries:           c.Uint("retries"),
		HeartbeatInterval: c.Duration("heartbeat-interval"),
		MaxHeartbeats:     c.Uint64("heartbeat-count"),
		ClientID:          clientID,
		ReportedVersion:   Version,
		LBPool:            c.String("lb-pool"),
		Tags:              tags,
		HAConnections:     c.Int("ha-connections"),
		HTTPTransport:     httpTransport,
		Metrics:           tunnelMetrics,
		MetricsUpdateFreq: c.Duration("metrics-update-freq"),
		ProtocolLogger:    protoLogger,
		Logger:            Log,
		IsAutoupdated:     c.Bool("is-autoupdated"),
	}
	connectedSignal := make(chan struct{})

	go writePidFile(connectedSignal, c.String("pidfile"))
	go func() {
		errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
		wg.Done()
	}()

	metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
	if err != nil {
		Log.WithError(err).Fatal("Error opening metrics server listener")
	}
	go func() {
		errC <- metrics.ServeMetrics(metricsListener, shutdownC)
		wg.Done()
	}()

	var errCode int
	err = WaitForSignal(errC, shutdownC)
	if err != nil {
		Log.WithError(err).Error("Quitting due to error")
		raven.CaptureErrorAndWait(err, nil)
		errCode = 1
	} else {
		Log.Info("Quitting...")
	}
	// Wait for clean exit, discarding all errors
	go func() {
		for range errC {
		}
	}()
	wg.Wait()
	os.Exit(errCode)
}

func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
	signals := make(chan os.Signal, 10)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
	defer signal.Stop(signals)
	select {
	case err := <-errC:
		close(shutdownC)
		return err
	case <-signals:
		close(shutdownC)
	case <-shutdownC:
	}
	return nil
}

func update(_ *cli.Context) error {
	if updateApplied() {
		os.Exit(64)
	}
	return nil
}

func initUpdate() bool {
	if updateApplied() {
		os.Args = append(os.Args, "--is-autoupdated=true")
		if _, err := listeners.StartProcess(); err != nil {
			Log.WithError(err).Error("Unable to restart server automatically")
			return false
		}
		return true
	}
	return false
}

func autoupdate(freq time.Duration, shutdownC chan struct{}) {
	for {
		if updateApplied() {
			os.Args = append(os.Args, "--is-autoupdated=true")
			if _, err := listeners.StartProcess(); err != nil {
				Log.WithError(err).Error("Unable to restart server automatically")
			}
			close(shutdownC)
			return
		}
		time.Sleep(freq)
	}
}

func updateApplied() bool {
	releaseInfo := checkForUpdates()
	if releaseInfo.Updated {
		Log.Infof("Updated to version %s", releaseInfo.Version)
		return true
	}
	if releaseInfo.Error != nil {
		Log.WithError(releaseInfo.Error).Error("Update check failed")
	}
	return false
}

func fileExists(path string) (bool, error) {
	f, err := os.Open(path)
	if err != nil {
		if os.IsNotExist(err) {
			// ignore missing files
			return false, nil
		}
		return false, err
	}
	f.Close()
	return true, nil
}

// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs
// (differs by OS for legacy reasons) contains a cert.pem file, return empty string
func findDefaultOriginCertPath() string {
	for _, defaultConfigDir := range defaultConfigDirs {
		originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, credentialFile))
		if ok, _ := fileExists(originCertPath); ok {
			return originCertPath
		}
	}
	return ""
}

// returns the first path that contains a config file. If none of the combination of
// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles
// contains a config file, return empty string
func findDefaultConfigPath() string {
	for _, configDir := range defaultConfigDirs {
		for _, configFile := range defaultConfigFiles {
			dirPath, err := homedir.Expand(configDir)
			if err != nil {
				return ""
			}
			path := filepath.Join(dirPath, configFile)
			if ok, _ := fileExists(path); ok {
				return path
			}
		}
	}
	return ""
}

func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
	if context.String("config") != "" {
		return altsrc.NewYamlSourceFromFile(context.String("config"))
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateRandomClientID() string {
|
||||||
|
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
id := make([]byte, 32)
|
||||||
|
r.Read(id)
|
||||||
|
return hex.EncodeToString(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writePidFile(waitForSignal chan struct{}, pidFile string) {
|
||||||
|
<-waitForSignal
|
||||||
|
daemon.SdNotify(false, "READY=1")
|
||||||
|
if pidFile == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file, err := os.Create(pidFile)
|
||||||
|
if err != nil {
|
||||||
|
Log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
fmt.Fprintf(file, "%d", os.Getpid())
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate url. It can be either from --url or argument
|
||||||
|
func validateUrl(c *cli.Context) (string, error) {
|
||||||
|
var url = c.String("url")
|
||||||
|
if c.NArg() > 0 {
|
||||||
|
if c.IsSet("url") {
|
||||||
|
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
|
||||||
|
}
|
||||||
|
url = c.Args().Get(0)
|
||||||
|
}
|
||||||
|
validUrl, err := validation.ValidateUrl(url)
|
||||||
|
return validUrl, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error {
|
||||||
|
filePath, err := homedir.Expand(c.String("logfile"))
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Cannot resolve logfile path")
|
||||||
|
}
|
||||||
|
|
||||||
|
fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC
|
||||||
|
// do not truncate log file if the client has been autoupdated
|
||||||
|
if c.Bool("is-autoupdated") {
|
||||||
|
fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE
|
||||||
|
}
|
||||||
|
f, err := os.OpenFile(filePath, fileMode, 0664)
|
||||||
|
if err != nil {
|
||||||
|
errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath))
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
pathMap := lfshook.PathMap{
|
||||||
|
logrus.InfoLevel: filePath,
|
||||||
|
logrus.ErrorLevel: filePath,
|
||||||
|
logrus.FatalLevel: filePath,
|
||||||
|
logrus.PanicLevel: filePath,
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
|
||||||
|
protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
|
||||||
|
|
||||||
|
flags := make(map[string]interface{})
|
||||||
|
envs := make(map[string]string)
|
||||||
|
|
||||||
|
for _, flag := range c.LocalFlagNames() {
|
||||||
|
flags[flag] = c.Generic(flag)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find env variables for Argo Tunnel
|
||||||
|
for _, env := range os.Environ() {
|
||||||
|
// All Argo Tunnel env variables start with TUNNEL_
|
||||||
|
if strings.Contains(env, "TUNNEL_") {
|
||||||
|
vars := strings.Split(env, "=")
|
||||||
|
if len(vars) == 2 {
|
||||||
|
envs[vars[0]] = vars[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.Infof("Argo Tunnel build and runtime configuration: %+v", BuildAndRuntimeInfo{
|
||||||
|
GoOS: runtime.GOOS,
|
||||||
|
GoVersion: runtime.Version(),
|
||||||
|
GoArch: runtime.GOARCH,
|
||||||
|
WarpVersion: Version,
|
||||||
|
WarpFlags: flags,
|
||||||
|
WarpEnvs: envs,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isAutoupdateEnabled(c *cli.Context) bool {
|
||||||
|
if terminal.IsTerminal(int(os.Stdout.Fd())) {
|
||||||
|
Log.Info(noAutoupdateMessage)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0
|
||||||
|
}
|
|
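For readers tracing the shutdown flow above: every long-running goroutine reports its exit error on errC, WaitForSignal closes shutdownC on the first error or OS signal, and the remaining errors are drained so the WaitGroup can finish. A minimal sketch of that pattern follows; it is illustrative only and not part of this commit (it assumes the sync package is imported and the WaitForSignal helper defined in this file).

	// Sketch of the errC/shutdownC coordination used by the tunnel command above.
	errC := make(chan error, 2)
	shutdownC := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(2)
	for i := 0; i < 2; i++ {
		go func() {
			defer wg.Done()
			<-shutdownC // a real worker runs until told to stop
			errC <- nil
		}()
	}
	// Blocks until a worker errors or SIGTERM/SIGINT arrives, then closes shutdownC.
	err := WaitForSignal(errC, shutdownC)
	go func() {
		for range errC { // discard remaining errors, as above
		}
	}()
	wg.Wait()
	_ = err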
@ -0,0 +1,192 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/mitchellh/go-homedir"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServiceTemplate struct {
|
||||||
|
Path string
|
||||||
|
Content string
|
||||||
|
FileMode os.FileMode
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServiceTemplateArgs struct {
|
||||||
|
Path string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (st *ServiceTemplate) ResolvePath() (string, error) {
|
||||||
|
resolvedPath, err := homedir.Expand(st.Path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error resolving path %s: %v", st.Path, err)
|
||||||
|
}
|
||||||
|
return resolvedPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
|
||||||
|
tmpl, err := template.New(st.Path).Parse(st.Content)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error generating %s template: %v", st.Path, err)
|
||||||
|
}
|
||||||
|
resolvedPath, err := st.ResolvePath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
err = tmpl.Execute(&buffer, args)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error generating %s: %v", st.Path, err)
|
||||||
|
}
|
||||||
|
fileMode := os.FileMode(0644)
|
||||||
|
if st.FileMode != 0 {
|
||||||
|
fileMode = st.FileMode
|
||||||
|
}
|
||||||
|
err = ioutil.WriteFile(resolvedPath, buffer.Bytes(), fileMode)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error writing %s: %v", resolvedPath, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (st *ServiceTemplate) Remove() error {
|
||||||
|
resolvedPath, err := st.ResolvePath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = os.Remove(resolvedPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error deleting %s: %v", resolvedPath, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runCommand(command string, args ...string) error {
|
||||||
|
cmd := exec.Command(command, args...)
|
||||||
|
stderr, err := cmd.StderrPipe()
|
||||||
|
if err != nil {
|
||||||
|
Log.WithError(err).Infof("error getting stderr pipe")
|
||||||
|
return fmt.Errorf("error getting stderr pipe: %v", err)
|
||||||
|
}
|
||||||
|
err = cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
Log.WithError(err).Infof("error starting %s", command)
|
||||||
|
return fmt.Errorf("error starting %s: %v", command, err)
|
||||||
|
}
|
||||||
|
commandErr, _ := ioutil.ReadAll(stderr)
|
||||||
|
if len(commandErr) > 0 {
|
||||||
|
Log.Errorf("%s: %s", command, commandErr)
|
||||||
|
}
|
||||||
|
err = cmd.Wait()
|
||||||
|
if err != nil {
|
||||||
|
Log.WithError(err).Infof("%s returned error", command)
|
||||||
|
return fmt.Errorf("%s returned with error: %v", command, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureConfigDirExists(configDir string) error {
|
||||||
|
ok, err := fileExists(configDir)
|
||||||
|
if !ok && err == nil {
|
||||||
|
err = os.Mkdir(configDir, 0700)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// openFile opens the file at path. If create is set and the file exists, returns nil, true, nil
|
||||||
|
func openFile(path string, create bool) (file *os.File, exists bool, err error) {
|
||||||
|
expandedPath, err := homedir.Expand(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
if create {
|
||||||
|
fileInfo, err := os.Stat(expandedPath)
|
||||||
|
if err == nil && fileInfo.Size() > 0 {
|
||||||
|
return nil, true, nil
|
||||||
|
}
|
||||||
|
file, err = os.OpenFile(expandedPath, os.O_RDWR|os.O_CREATE, 0600)
|
||||||
|
} else {
|
||||||
|
file, err = os.Open(expandedPath)
|
||||||
|
}
|
||||||
|
return file, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyCertificate(srcConfigDir, destConfigDir string) error {
|
||||||
|
destCredentialPath := filepath.Join(destConfigDir, credentialFile)
|
||||||
|
destFile, exists, err := openFile(destCredentialPath, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if exists {
|
||||||
|
// credentials already exist, do nothing
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer destFile.Close()
|
||||||
|
|
||||||
|
srcCredentialPath := filepath.Join(srcConfigDir, credentialFile)
|
||||||
|
srcFile, _, err := openFile(srcCredentialPath, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer srcFile.Close()
|
||||||
|
|
||||||
|
// Copy certificate
|
||||||
|
_, err = io.Copy(destFile, srcFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to copy %s to %s: %v", srcCredentialPath, destCredentialPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile string) error {
|
||||||
|
if err := ensureConfigDirExists(serviceConfigDir); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := copyCertificate(defaultConfigDir, serviceConfigDir); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy or create config
|
||||||
|
destConfigPath := filepath.Join(serviceConfigDir, defaultConfigFile)
|
||||||
|
destFile, exists, err := openFile(destConfigPath, true)
|
||||||
|
if err != nil {
|
||||||
|
Log.WithError(err).Infof("cannot open %s", destConfigPath)
|
||||||
|
return err
|
||||||
|
} else if exists {
|
||||||
|
// config already exists, do nothing
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer destFile.Close()
|
||||||
|
|
||||||
|
srcConfigPath := filepath.Join(defaultConfigDir, defaultConfigFile)
|
||||||
|
srcFile, _, err := openFile(srcConfigPath, false)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Your service needs a config file that at least specifies the hostname option.")
|
||||||
|
fmt.Println("Type in a hostname now, or leave it blank and create the config file later.")
|
||||||
|
fmt.Print("Hostname: ")
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
input, _ := reader.ReadString('\n')
|
||||||
|
if input == "" {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Fprintf(destFile, "hostname: %s\n", input)
|
||||||
|
} else {
|
||||||
|
defer srcFile.Close()
|
||||||
|
_, err = io.Copy(destFile, srcFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to copy %s to %s: %v", srcConfigPath, destConfigPath, err)
|
||||||
|
}
|
||||||
|
Log.Infof("Copied %s to %s", srcConfigPath, destConfigPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@@ -0,0 +1,32 @@
package main

import (
	"fmt"
	"regexp"

	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)

// Restrict key names to characters allowed in an HTTP header name.
// Restrict key values to printable characters (what is recognised as data in an HTTP header value).
var tagRegexp = regexp.MustCompile("^([a-zA-Z0-9!#$%&'*+\\-.^_`|~]+)=([[:print:]]+)$")

func NewTagFromCLI(compoundTag string) (tunnelpogs.Tag, bool) {
	matches := tagRegexp.FindStringSubmatch(compoundTag)
	if len(matches) == 0 {
		return tunnelpogs.Tag{}, false
	}
	return tunnelpogs.Tag{Name: matches[1], Value: matches[2]}, true
}

func NewTagSliceFromCLI(tags []string) ([]tunnelpogs.Tag, error) {
	var tagSlice []tunnelpogs.Tag
	for _, compoundTag := range tags {
		if tag, ok := NewTagFromCLI(compoundTag); ok {
			tagSlice = append(tagSlice, tag)
		} else {
			return nil, fmt.Errorf("Cannot parse tag value %s", compoundTag)
		}
	}
	return tagSlice, nil
}
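As an illustration of how these helpers are meant to be used (a sketch; the flag name "tag" and the surrounding variables are assumptions, not part of this commit): repeated NAME=VALUE arguments are collected as strings, parsed here, and the resulting slice is what ends up in the Tags field of origin.TunnelConfig shown earlier.

	// Sketch: converting CLI-supplied tags inside an action handler.
	// c is the *cli.Context; "tag" is assumed to be a StringSlice flag.
	tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
	if err != nil {
		Log.WithError(err).Fatal("Cannot parse tag values")
	}
	// tags is then passed on, e.g. origin.TunnelConfig{Tags: tags}.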
@@ -0,0 +1,46 @@
package main

import (
	"testing"

	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"

	"github.com/stretchr/testify/assert"
)

func TestSingleTag(t *testing.T) {
	testCases := []struct {
		Input  string
		Output tunnelpogs.Tag
		Fail   bool
	}{
		{Input: "x=y", Output: tunnelpogs.Tag{Name: "x", Value: "y"}},
		{Input: "More-Complex=Tag Values", Output: tunnelpogs.Tag{Name: "More-Complex", Value: "Tag Values"}},
		{Input: "First=Equals=Wins", Output: tunnelpogs.Tag{Name: "First", Value: "Equals=Wins"}},
		{Input: "x=", Fail: true},
		{Input: "=y", Fail: true},
		{Input: "=", Fail: true},
		{Input: "No spaces allowed=in key names", Fail: true},
		{Input: "omg\nwtf=bbq", Fail: true},
	}
	for i, testCase := range testCases {
		tag, ok := NewTagFromCLI(testCase.Input)
		assert.Equalf(t, !testCase.Fail, ok, "mismatched success for test case %d", i)
		assert.Equalf(t, testCase.Output, tag, "mismatched output for test case %d", i)
	}
}

func TestTagSlice(t *testing.T) {
	tagSlice, err := NewTagSliceFromCLI([]string{"a=b", "c=d", "e=f"})
	assert.NoError(t, err)
	assert.Len(t, tagSlice, 3)
	assert.Equal(t, "a", tagSlice[0].Name)
	assert.Equal(t, "b", tagSlice[0].Value)
	assert.Equal(t, "c", tagSlice[1].Name)
	assert.Equal(t, "d", tagSlice[1].Value)
	assert.Equal(t, "e", tagSlice[2].Name)
	assert.Equal(t, "f", tagSlice[2].Value)

	tagSlice, err = NewTagSliceFromCLI([]string{"a=b", "=", "e=f"})
	assert.Error(t, err)
}
@@ -0,0 +1,41 @@
package main

import "github.com/equinox-io/equinox"

const appID = "app_cwbQae3Tpea"

var publicKey = []byte(`
-----BEGIN ECDSA PUBLIC KEY-----
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE4OWZocTVZ8Do/L6ScLdkV+9A0IYMHoOf
dsCmJ/QZ6aw0w9qkkwEpne1Lmo6+0pGexZzFZOH6w5amShn+RXt7qkSid9iWlzGq
EKx0BZogHSor9Wy5VztdFaAaVbsJiCbO
-----END ECDSA PUBLIC KEY-----
`)

type ReleaseInfo struct {
	Updated bool
	Version string
	Error   error
}

func checkForUpdates() ReleaseInfo {
	var opts equinox.Options
	if err := opts.SetPublicKeyPEM(publicKey); err != nil {
		return ReleaseInfo{Error: err}
	}

	resp, err := equinox.Check(appID, opts)
	switch {
	case err == equinox.NotAvailableErr:
		return ReleaseInfo{}
	case err != nil:
		return ReleaseInfo{Error: err}
	}

	err = resp.Apply()
	if err != nil {
		return ReleaseInfo{Error: err}
	}

	return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion}
}
@ -0,0 +1,166 @@
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
// Copypasta from the example files:
|
||||||
|
// https://github.com/golang/sys/blob/master/windows/svc/example
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows/svc"
|
||||||
|
"golang.org/x/sys/windows/svc/eventlog"
|
||||||
|
"golang.org/x/sys/windows/svc/mgr"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
windowsServiceName = "Cloudflared"
|
||||||
|
windowsServiceDescription = "Argo Tunnel agent"
|
||||||
|
)
|
||||||
|
|
||||||
|
func runApp(app *cli.App) {
|
||||||
|
app.Commands = append(app.Commands, &cli.Command{
|
||||||
|
Name: "service",
|
||||||
|
Usage: "Manages the Argo Tunnel Windows service",
|
||||||
|
Subcommands: []*cli.Command{
|
||||||
|
&cli.Command{
|
||||||
|
Name: "install",
|
||||||
|
Usage: "Install Argo Tunnel as a Windows service",
|
||||||
|
Action: installWindowsService,
|
||||||
|
},
|
||||||
|
&cli.Command{
|
||||||
|
Name: "uninstall",
|
||||||
|
Usage: "Uninstall the Argo Tunnel service",
|
||||||
|
Action: uninstallWindowsService,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
isIntSess, err := svc.IsAnInteractiveSession()
|
||||||
|
if err != nil {
|
||||||
|
Log.Fatalf("failed to determine if we are running in an interactive session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isIntSess {
|
||||||
|
app.Run(os.Args)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
elog, err := eventlog.Open(windowsServiceName)
|
||||||
|
if err != nil {
|
||||||
|
Log.WithError(err).Infof("Cannot open event log for %s", windowsServiceName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer elog.Close()
|
||||||
|
|
||||||
|
elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName))
|
||||||
|
// Run executes service name by calling windowsService which is a Handler
|
||||||
|
// interface that implements Execute method
|
||||||
|
err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog})
|
||||||
|
if err != nil {
|
||||||
|
elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
elog.Info(1, fmt.Sprintf("%s service stopped", windowsServiceName))
|
||||||
|
}
|
||||||
|
|
||||||
|
type windowsService struct {
|
||||||
|
app *cli.App
|
||||||
|
elog *eventlog.Log
|
||||||
|
}
|
||||||
|
|
||||||
|
// called by the package code at the start of the service
|
||||||
|
func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) {
|
||||||
|
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
|
||||||
|
changes <- svc.Status{State: svc.StartPending}
|
||||||
|
go s.app.Run(args)
|
||||||
|
|
||||||
|
changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case c := <-r:
|
||||||
|
switch c.Cmd {
|
||||||
|
case svc.Interrogate:
|
||||||
|
s.elog.Info(1, fmt.Sprintf("control request 1 #%d", c))
|
||||||
|
changes <- c.CurrentStatus
|
||||||
|
case svc.Stop:
|
||||||
|
s.elog.Info(1, "received stop control request")
|
||||||
|
break loop
|
||||||
|
case svc.Shutdown:
|
||||||
|
s.elog.Info(1, "received shutdown control request")
|
||||||
|
break loop
|
||||||
|
default:
|
||||||
|
s.elog.Error(1, fmt.Sprintf("unexpected control request #%d", c))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(shutdownC)
|
||||||
|
changes <- svc.Status{State: svc.StopPending}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func installWindowsService(c *cli.Context) error {
|
||||||
|
Log.Infof("Installing Argo Tunnel Windows service")
|
||||||
|
exepath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
Log.Infof("Cannot find path name that start the process")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
Log.WithError(err).Infof("Cannot establish a connection to the service control manager")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
s, err := m.OpenService(windowsServiceName)
|
||||||
|
if err == nil {
|
||||||
|
s.Close()
|
||||||
|
Log.Errorf("service %s already exists", windowsServiceName)
|
||||||
|
return fmt.Errorf("service %s already exists", windowsServiceName)
|
||||||
|
}
|
||||||
|
config := mgr.Config{StartType: mgr.StartAutomatic, DisplayName: windowsServiceDescription}
|
||||||
|
s, err = m.CreateService(windowsServiceName, exepath, config)
|
||||||
|
if err != nil {
|
||||||
|
Log.Infof("Cannot install service %s", windowsServiceName)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info)
|
||||||
|
if err != nil {
|
||||||
|
s.Delete()
|
||||||
|
Log.WithError(err).Infof("Cannot install event logger")
|
||||||
|
return fmt.Errorf("SetupEventLogSource() failed: %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func uninstallWindowsService(c *cli.Context) error {
|
||||||
|
Log.Infof("Uninstalling Argo Tunnel Windows Service")
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
Log.Infof("Cannot establish a connection to the service control manager")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
s, err := m.OpenService(windowsServiceName)
|
||||||
|
if err != nil {
|
||||||
|
Log.Infof("service %s is not installed", windowsServiceName)
|
||||||
|
return fmt.Errorf("service %s is not installed", windowsServiceName)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
err = s.Delete()
|
||||||
|
if err != nil {
|
||||||
|
Log.Errorf("Cannot delete service %s", windowsServiceName)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = eventlog.Remove(windowsServiceName)
|
||||||
|
if err != nil {
|
||||||
|
Log.Infof("Cannot remove event logger")
|
||||||
|
return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@@ -24,12 +24,6 @@ type activeStreamMap struct {
 	ignoreNewStreams bool
 }
 
-type FlowControlMetrics struct {
-	AverageReceiveWindowSize, AverageSendWindowSize float64
-	MinReceiveWindowSize, MaxReceiveWindowSize      uint32
-	MinSendWindowSize, MaxSendWindowSize            uint32
-}
-
 func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap {
 	m := &activeStreamMap{
 		streams: make(map[uint32]*MuxedStream),
@@ -169,45 +163,3 @@ func (m *activeStreamMap) Abort() {
 	}
 	m.ignoreNewStreams = true
 }
-
-func (m *activeStreamMap) Metrics() *FlowControlMetrics {
-	m.Lock()
-	defer m.Unlock()
-	var averageReceiveWindowSize, averageSendWindowSize float64
-	var minReceiveWindowSize, maxReceiveWindowSize, minSendWindowSize, maxSendWindowSize uint32
-	i := 0
-	// The first variable in the range expression for map is the key, not index.
-	for _, stream := range m.streams {
-		// iterative mean: a(i+1) = a(i) + (x - a(i))/(i+1)
-		windows := stream.FlowControlWindow()
-		averageReceiveWindowSize += (float64(windows.receiveWindow) - averageReceiveWindowSize) / float64(i+1)
-		averageSendWindowSize += (float64(windows.sendWindow) - averageSendWindowSize) / float64(i+1)
-		if i == 0 {
-			maxReceiveWindowSize = windows.receiveWindow
-			minReceiveWindowSize = windows.receiveWindow
-			maxSendWindowSize = windows.sendWindow
-			minSendWindowSize = windows.sendWindow
-		} else {
-			if windows.receiveWindow > maxReceiveWindowSize {
-				maxReceiveWindowSize = windows.receiveWindow
-			} else if windows.receiveWindow < minReceiveWindowSize {
-				minReceiveWindowSize = windows.receiveWindow
-			}
-
-			if windows.sendWindow > maxSendWindowSize {
-				maxSendWindowSize = windows.sendWindow
-			} else if windows.sendWindow < minSendWindowSize {
-				minSendWindowSize = windows.sendWindow
-			}
-		}
-		i++
-	}
-	return &FlowControlMetrics{
-		MinReceiveWindowSize:     minReceiveWindowSize,
-		MaxReceiveWindowSize:     maxReceiveWindowSize,
-		AverageReceiveWindowSize: averageReceiveWindowSize,
-		MinSendWindowSize:        minSendWindowSize,
-		MaxSendWindowSize:        maxSendWindowSize,
-		AverageSendWindowSize:    averageSendWindowSize,
-	}
-}
@@ -0,0 +1,18 @@
package h2mux

import (
	"sync/atomic"
)

type AtomicCounter struct {
	count uint64
}

func (c *AtomicCounter) IncrementBy(number uint64) {
	atomic.AddUint64(&c.count, number)
}

// Count returns the current value of the counter and resets it to 0
func (c *AtomicCounter) Count() uint64 {
	return atomic.SwapUint64(&c.count, 0)
}
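A usage note (sketch only, not code from this commit): the counter is built for a sample-and-reset pattern, which is how the mux metrics code later in this commit turns a running byte count into a bytes-per-interval rate. The function below assumes the time package is imported; its names mirror, but are not identical to, the muxReader fields introduced further down.

	// Sketch: report and reset the counter once per second.
	func reportRate(counter *AtomicCounter, out chan<- uint64, abort <-chan struct{}) {
		tickC := time.Tick(time.Second)
		for {
			select {
			case <-abort:
				return
			case <-tickC:
				out <- counter.Count() // total since the last tick, then reset to 0
			}
		}
	}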
@@ -0,0 +1,23 @@
package h2mux

import (
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestCounter(t *testing.T) {
	var wg sync.WaitGroup
	wg.Add(dataPoints)
	c := AtomicCounter{}
	for i := 0; i < dataPoints; i++ {
		go func() {
			defer wg.Done()
			c.IncrementBy(uint64(1))
		}()
	}
	wg.Wait()
	assert.Equal(t, uint64(dataPoints), c.Count())
	assert.Equal(t, uint64(0), c.Count())
}
@ -59,6 +59,8 @@ type Muxer struct {
|
||||||
muxReader *MuxReader
|
muxReader *MuxReader
|
||||||
// muxWriter is the write process.
|
// muxWriter is the write process.
|
||||||
muxWriter *MuxWriter
|
muxWriter *MuxWriter
|
||||||
|
// muxMetricsUpdater is the process to update metrics
|
||||||
|
muxMetricsUpdater *muxMetricsUpdater
|
||||||
// newStreamChan is used to create new streams on the writer thread.
|
// newStreamChan is used to create new streams on the writer thread.
|
||||||
// The writer will assign the next available stream ID.
|
// The writer will assign the next available stream ID.
|
||||||
newStreamChan chan MuxedStreamRequest
|
newStreamChan chan MuxedStreamRequest
|
||||||
|
@ -133,6 +135,11 @@ func Handshake(
|
||||||
// set up reader/writer pair ready for serve
|
// set up reader/writer pair ready for serve
|
||||||
streamErrors := NewStreamErrorMap()
|
streamErrors := NewStreamErrorMap()
|
||||||
goAwayChan := make(chan http2.ErrCode, 1)
|
goAwayChan := make(chan http2.ErrCode, 1)
|
||||||
|
updateRTTChan := make(chan *roundTripMeasurement, 1)
|
||||||
|
updateReceiveWindowChan := make(chan uint32, 1)
|
||||||
|
updateSendWindowChan := make(chan uint32, 1)
|
||||||
|
updateInBoundBytesChan := make(chan uint64)
|
||||||
|
updateOutBoundBytesChan := make(chan uint64)
|
||||||
pingTimestamp := NewPingTimestamp()
|
pingTimestamp := NewPingTimestamp()
|
||||||
connActive := NewSignal()
|
connActive := NewSignal()
|
||||||
idleDuration := config.HeartbeatInterval
|
idleDuration := config.HeartbeatInterval
|
||||||
|
@ -149,34 +156,48 @@ func Handshake(
|
||||||
|
|
||||||
m.explicitShutdown = NewBooleanFuse()
|
m.explicitShutdown = NewBooleanFuse()
|
||||||
m.muxReader = &MuxReader{
|
m.muxReader = &MuxReader{
|
||||||
f: m.f,
|
f: m.f,
|
||||||
handler: m.config.Handler,
|
handler: m.config.Handler,
|
||||||
streams: m.streams,
|
streams: m.streams,
|
||||||
readyList: m.readyList,
|
readyList: m.readyList,
|
||||||
streamErrors: streamErrors,
|
streamErrors: streamErrors,
|
||||||
goAwayChan: goAwayChan,
|
goAwayChan: goAwayChan,
|
||||||
abortChan: m.abortChan,
|
abortChan: m.abortChan,
|
||||||
pingTimestamp: pingTimestamp,
|
pingTimestamp: pingTimestamp,
|
||||||
connActive: connActive,
|
connActive: connActive,
|
||||||
initialStreamWindow: defaultWindowSize,
|
initialStreamWindow: defaultWindowSize,
|
||||||
streamWindowMax: maxWindowSize,
|
streamWindowMax: maxWindowSize,
|
||||||
r: m.r,
|
r: m.r,
|
||||||
|
updateRTTChan: updateRTTChan,
|
||||||
|
updateReceiveWindowChan: updateReceiveWindowChan,
|
||||||
|
updateSendWindowChan: updateSendWindowChan,
|
||||||
|
updateInBoundBytesChan: updateInBoundBytesChan,
|
||||||
}
|
}
|
||||||
m.muxWriter = &MuxWriter{
|
m.muxWriter = &MuxWriter{
|
||||||
f: m.f,
|
f: m.f,
|
||||||
streams: m.streams,
|
streams: m.streams,
|
||||||
streamErrors: streamErrors,
|
streamErrors: streamErrors,
|
||||||
readyStreamChan: m.readyList.ReadyChannel(),
|
readyStreamChan: m.readyList.ReadyChannel(),
|
||||||
newStreamChan: m.newStreamChan,
|
newStreamChan: m.newStreamChan,
|
||||||
goAwayChan: goAwayChan,
|
goAwayChan: goAwayChan,
|
||||||
abortChan: m.abortChan,
|
abortChan: m.abortChan,
|
||||||
pingTimestamp: pingTimestamp,
|
pingTimestamp: pingTimestamp,
|
||||||
idleTimer: NewIdleTimer(idleDuration, maxRetries),
|
idleTimer: NewIdleTimer(idleDuration, maxRetries),
|
||||||
connActiveChan: connActive.WaitChannel(),
|
connActiveChan: connActive.WaitChannel(),
|
||||||
maxFrameSize: defaultFrameSize,
|
maxFrameSize: defaultFrameSize,
|
||||||
|
updateReceiveWindowChan: updateReceiveWindowChan,
|
||||||
|
updateSendWindowChan: updateSendWindowChan,
|
||||||
|
updateOutBoundBytesChan: updateOutBoundBytesChan,
|
||||||
}
|
}
|
||||||
m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
|
m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
|
||||||
|
m.muxMetricsUpdater = newMuxMetricsUpdater(
|
||||||
|
updateRTTChan,
|
||||||
|
updateReceiveWindowChan,
|
||||||
|
updateSendWindowChan,
|
||||||
|
updateInBoundBytesChan,
|
||||||
|
updateOutBoundBytesChan,
|
||||||
|
m.abortChan,
|
||||||
|
)
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -246,9 +267,13 @@ func (m *Muxer) Serve() error {
|
||||||
m.w.Close()
|
m.w.Close()
|
||||||
m.abort()
|
m.abort()
|
||||||
}()
|
}()
|
||||||
|
go func() {
|
||||||
|
errChan <- m.muxMetricsUpdater.run(logger)
|
||||||
|
}()
|
||||||
err := <-errChan
|
err := <-errChan
|
||||||
go func() {
|
go func() {
|
||||||
// discard error as other handler closes
|
// discard errors as other handler and muxMetricsUpdater close
|
||||||
|
<-errChan
|
||||||
<-errChan
|
<-errChan
|
||||||
close(errChan)
|
close(errChan)
|
||||||
}()
|
}()
|
||||||
|
@ -318,14 +343,8 @@ func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, erro
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the estimated round-trip time.
|
func (m *Muxer) Metrics() *MuxerMetrics {
|
||||||
func (m *Muxer) RTT() RTTMeasurement {
|
return m.muxMetricsUpdater.Metrics()
|
||||||
return m.muxReader.RTT()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return min/max/average of send/receive window for all streams on this connection
|
|
||||||
func (m *Muxer) FlowControlMetrics() *FlowControlMetrics {
|
|
||||||
return m.muxReader.FlowControlMetrics()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Muxer) abort() {
|
func (m *Muxer) abort() {
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
@ -134,9 +135,12 @@ func TestSingleStream(t *testing.T) {
|
||||||
if stream.Headers[0].Value != "headerValue" {
|
if stream.Headers[0].Value != "headerValue" {
|
||||||
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
||||||
}
|
}
|
||||||
stream.WriteHeaders([]Header{
|
headers := []Header{
|
||||||
Header{Name: "response-header", Value: "responseValue"},
|
Header{Name: "response-header", Value: "responseValue"},
|
||||||
})
|
}
|
||||||
|
stream.WriteHeaders(headers)
|
||||||
|
assert.Equal(t, headers, stream.writeHeaders)
|
||||||
|
assert.False(t, stream.headersSent)
|
||||||
buf := []byte("Hello world")
|
buf := []byte("Hello world")
|
||||||
stream.Write(buf)
|
stream.Write(buf)
|
||||||
// after this receive, the edge closed the stream
|
// after this receive, the edge closed the stream
|
||||||
|
|
|
@ -19,6 +19,7 @@ type MuxedStream struct {
|
||||||
receiveWindowCurrentMax uint32
|
receiveWindowCurrentMax uint32
|
||||||
// limit set in http2 spec. 2^31-1
|
// limit set in http2 spec. 2^31-1
|
||||||
receiveWindowMax uint32
|
receiveWindowMax uint32
|
||||||
|
|
||||||
// nonzero if a WINDOW_UPDATE frame for a stream needs to be sent
|
// nonzero if a WINDOW_UPDATE frame for a stream needs to be sent
|
||||||
windowUpdate uint32
|
windowUpdate uint32
|
||||||
|
|
||||||
|
@ -39,10 +40,6 @@ type MuxedStream struct {
|
||||||
receivedEOF bool
|
receivedEOF bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type flowControlWindow struct {
|
|
||||||
receiveWindow, sendWindow uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *MuxedStream) Read(p []byte) (n int, err error) {
|
func (s *MuxedStream) Read(p []byte) (n int, err error) {
|
||||||
return s.readBuffer.Read(p)
|
return s.readBuffer.Read(p)
|
||||||
}
|
}
|
||||||
|
@ -101,17 +98,21 @@ func (s *MuxedStream) WriteHeaders(headers []Header) error {
|
||||||
return ErrStreamHeadersSent
|
return ErrStreamHeadersSent
|
||||||
}
|
}
|
||||||
s.writeHeaders = headers
|
s.writeHeaders = headers
|
||||||
|
s.headersSent = false
|
||||||
s.writeNotify()
|
s.writeNotify()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *MuxedStream) FlowControlWindow() *flowControlWindow {
|
func (s *MuxedStream) getReceiveWindow() uint32 {
|
||||||
s.writeLock.Lock()
|
s.writeLock.Lock()
|
||||||
defer s.writeLock.Unlock()
|
defer s.writeLock.Unlock()
|
||||||
return &flowControlWindow{
|
return s.receiveWindow
|
||||||
receiveWindow: s.receiveWindow,
|
}
|
||||||
sendWindow: s.sendWindow,
|
|
||||||
}
|
func (s *MuxedStream) getSendWindow() uint32 {
|
||||||
|
s.writeLock.Lock()
|
||||||
|
defer s.writeLock.Unlock()
|
||||||
|
return s.sendWindow
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeNotify must happen while holding writeLock.
|
// writeNotify must happen while holding writeLock.
|
||||||
|
@ -209,9 +210,7 @@ func (s *MuxedStream) getChunk() *streamChunk {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copies at most s.sendWindow bytes
|
// Copies at most s.sendWindow bytes
|
||||||
//log.Infof("writeBuffer len %d stream %d", s.writeBuffer.Len(), s.streamID)
|
|
||||||
writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow))
|
writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow))
|
||||||
//log.Infof("writeLen %d stream %d", writeLen, s.streamID)
|
|
||||||
s.sendWindow -= uint32(writeLen)
|
s.sendWindow -= uint32(writeLen)
|
||||||
s.receiveWindow += s.windowUpdate
|
s.receiveWindow += s.windowUpdate
|
||||||
s.windowUpdate = 0
|
s.windowUpdate = 0
|
||||||
|
|
|
@ -0,0 +1,232 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-collections/collections/queue"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// data points used to compute average receive window and send window size
|
||||||
|
const (
|
||||||
|
// data points used to compute average receive window and send window size
|
||||||
|
dataPoints = 100
|
||||||
|
// updateFreq is set to 1 sec so we can get inbound & outbound byes/sec
|
||||||
|
updateFreq = time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type muxMetricsUpdater struct {
|
||||||
|
// rttData keeps record of rtt, rttMin, rttMax and last measured time
|
||||||
|
rttData *rttData
|
||||||
|
// receiveWindowData keeps record of receive window measurement
|
||||||
|
receiveWindowData *flowControlData
|
||||||
|
// sendWindowData keeps record of send window measurement
|
||||||
|
sendWindowData *flowControlData
|
||||||
|
// inBoundRate is incoming bytes/sec
|
||||||
|
inBoundRate *rate
|
||||||
|
// outBoundRate is outgoing bytes/sec
|
||||||
|
outBoundRate *rate
|
||||||
|
// updateRTTChan is the channel to receive new RTT measurement from muxReader
|
||||||
|
updateRTTChan <-chan *roundTripMeasurement
|
||||||
|
//updateReceiveWindowChan is the channel to receive updated receiveWindow size from muxReader and muxWriter
|
||||||
|
updateReceiveWindowChan <-chan uint32
|
||||||
|
//updateSendWindowChan is the channel to receive updated sendWindow size from muxReader and muxWriter
|
||||||
|
updateSendWindowChan <-chan uint32
|
||||||
|
// updateInBoundBytesChan us the channel to receive bytesRead from muxReader
|
||||||
|
updateInBoundBytesChan <-chan uint64
|
||||||
|
// updateOutBoundBytesChan us the channel to receive bytesWrote from muxWriter
|
||||||
|
updateOutBoundBytesChan <-chan uint64
|
||||||
|
// shutdownC is to signal the muxerMetricsUpdater to shutdown
|
||||||
|
abortChan <-chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type MuxerMetrics struct {
|
||||||
|
RTT, RTTMin, RTTMax time.Duration
|
||||||
|
ReceiveWindowAve, SendWindowAve float64
|
||||||
|
ReceiveWindowMin, ReceiveWindowMax, SendWindowMin, SendWindowMax uint32
|
||||||
|
InBoundRateCurr, InBoundRateMin, InBoundRateMax uint64
|
||||||
|
OutBoundRateCurr, OutBoundRateMin, OutBoundRateMax uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type roundTripMeasurement struct {
|
||||||
|
receiveTime, sendTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type rttData struct {
|
||||||
|
rtt, rttMin, rttMax time.Duration
|
||||||
|
lastMeasurementTime time.Time
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type flowControlData struct {
|
||||||
|
sum uint64
|
||||||
|
min, max uint32
|
||||||
|
queue *queue.Queue
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type rate struct {
|
||||||
|
curr uint64
|
||||||
|
min, max uint64
|
||||||
|
lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMuxMetricsUpdater(
|
||||||
|
updateRTTChan <-chan *roundTripMeasurement,
|
||||||
|
updateReceiveWindowChan <-chan uint32,
|
||||||
|
updateSendWindowChan <-chan uint32,
|
||||||
|
updateInBoundBytesChan <-chan uint64,
|
||||||
|
updateOutBoundBytesChan <-chan uint64,
|
||||||
|
abortChan <-chan struct{},
|
||||||
|
) *muxMetricsUpdater {
|
||||||
|
return &muxMetricsUpdater{
|
||||||
|
rttData: newRTTData(),
|
||||||
|
receiveWindowData: newFlowControlData(),
|
||||||
|
sendWindowData: newFlowControlData(),
|
||||||
|
inBoundRate: newRate(),
|
||||||
|
outBoundRate: newRate(),
|
||||||
|
updateRTTChan: updateRTTChan,
|
||||||
|
updateReceiveWindowChan: updateReceiveWindowChan,
|
||||||
|
updateSendWindowChan: updateSendWindowChan,
|
||||||
|
updateInBoundBytesChan: updateInBoundBytesChan,
|
||||||
|
updateOutBoundBytesChan: updateOutBoundBytesChan,
|
||||||
|
abortChan: abortChan,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (updater *muxMetricsUpdater) Metrics() *MuxerMetrics {
|
||||||
|
m := &MuxerMetrics{}
|
||||||
|
m.RTT, m.RTTMin, m.RTTMax = updater.rttData.metrics()
|
||||||
|
m.ReceiveWindowAve, m.ReceiveWindowMin, m.ReceiveWindowMax = updater.receiveWindowData.metrics()
|
||||||
|
m.SendWindowAve, m.SendWindowMin, m.SendWindowMax = updater.sendWindowData.metrics()
|
||||||
|
m.InBoundRateCurr, m.InBoundRateMin, m.InBoundRateMax = updater.inBoundRate.get()
|
||||||
|
m.OutBoundRateCurr, m.OutBoundRateMin, m.OutBoundRateMax = updater.outBoundRate.get()
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (updater *muxMetricsUpdater) run(parentLogger *log.Entry) error {
|
||||||
|
logger := parentLogger.WithFields(log.Fields{
|
||||||
|
"subsystem": "mux",
|
||||||
|
"dir": "metrics",
|
||||||
|
})
|
||||||
|
defer logger.Debug("event loop finished")
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-updater.abortChan:
|
||||||
|
logger.Infof("Stopping mux metrics updater")
|
||||||
|
return nil
|
||||||
|
case roundTripMeasurement := <-updater.updateRTTChan:
|
||||||
|
go updater.rttData.update(roundTripMeasurement)
|
||||||
|
logger.Debug("Update rtt")
|
||||||
|
case receiveWindow := <-updater.updateReceiveWindowChan:
|
||||||
|
go updater.receiveWindowData.update(receiveWindow)
|
||||||
|
logger.Debug("Update receive window")
|
||||||
|
case sendWindow := <-updater.updateSendWindowChan:
|
||||||
|
go updater.sendWindowData.update(sendWindow)
|
||||||
|
logger.Debug("Update send window")
|
||||||
|
case inBoundBytes := <-updater.updateInBoundBytesChan:
|
||||||
|
// inBoundBytes is bytes/sec because the update interval is 1 sec
|
||||||
|
go updater.inBoundRate.update(inBoundBytes)
|
||||||
|
logger.Debugf("Inbound bytes %d", inBoundBytes)
|
||||||
|
case outBoundBytes := <-updater.updateOutBoundBytesChan:
|
||||||
|
// outBoundBytes is bytes/sec because the update interval is 1 sec
|
||||||
|
go updater.outBoundRate.update(outBoundBytes)
|
||||||
|
logger.Debugf("Outbound bytes %d", outBoundBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRTTData() *rttData {
|
||||||
|
return &rttData{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rttData) update(measurement *roundTripMeasurement) {
|
||||||
|
r.lock.Lock()
|
||||||
|
defer r.lock.Unlock()
|
||||||
|
// discard pings before lastMeasurementTime
|
||||||
|
if r.lastMeasurementTime.After(measurement.sendTime) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.lastMeasurementTime = measurement.sendTime
|
||||||
|
r.rtt = measurement.receiveTime.Sub(measurement.sendTime)
|
||||||
|
if r.rttMax < r.rtt {
|
||||||
|
r.rttMax = r.rtt
|
||||||
|
}
|
||||||
|
if r.rttMin == 0 || r.rttMin > r.rtt {
|
||||||
|
r.rttMin = r.rtt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rttData) metrics() (rtt, rttMin, rttMax time.Duration) {
|
||||||
|
r.lock.RLock()
|
||||||
|
defer r.lock.RUnlock()
|
||||||
|
return r.rtt, r.rttMin, r.rttMax
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFlowControlData() *flowControlData {
|
||||||
|
return &flowControlData{queue: queue.New()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *flowControlData) update(measurement uint32) {
|
||||||
|
f.lock.Lock()
|
||||||
|
defer f.lock.Unlock()
|
||||||
|
var firstItem uint32
|
||||||
|
// store new data into queue, remove oldest data if queue is full
|
||||||
|
f.queue.Enqueue(measurement)
|
||||||
|
if f.queue.Len() > dataPoints {
|
||||||
|
// data type should always be uint32
|
||||||
|
firstItem = f.queue.Dequeue().(uint32)
|
||||||
|
}
|
||||||
|
// if (measurement - firstItem) < 0, uint64(measurement - firstItem)
|
||||||
|
// will overflow and become a large positive number
|
||||||
|
f.sum += uint64(measurement)
|
||||||
|
f.sum -= uint64(firstItem)
|
||||||
|
if measurement > f.max {
|
||||||
|
f.max = measurement
|
||||||
|
}
|
||||||
|
if f.min == 0 || measurement < f.min {
|
||||||
|
f.min = measurement
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// caller of ave() should acquire lock first
|
||||||
|
func (f *flowControlData) ave() float64 {
|
||||||
|
if f.queue.Len() == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return float64(f.sum) / float64(f.queue.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *flowControlData) metrics() (ave float64, min, max uint32) {
|
||||||
|
f.lock.RLock()
|
||||||
|
defer f.lock.RUnlock()
|
||||||
|
return f.ave(), f.min, f.max
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRate() *rate {
|
||||||
|
return &rate{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rate) update(measurement uint64) {
|
||||||
|
r.lock.Lock()
|
||||||
|
defer r.lock.Unlock()
|
||||||
|
r.curr = measurement
|
||||||
|
// if measurement is 0, then there is no incoming/outgoing connection, don't update min/max
|
||||||
|
if r.curr == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if measurement > r.max {
|
||||||
|
r.max = measurement
|
||||||
|
}
|
||||||
|
if r.min == 0 || measurement < r.min {
|
||||||
|
r.min = measurement
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rate) get() (curr, min, max uint64) {
|
||||||
|
r.lock.RLock()
|
||||||
|
defer r.lock.RUnlock()
|
||||||
|
return r.curr, r.min, r.max
|
||||||
|
}
|
|
@ -0,0 +1,176 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ave(sum uint64, len int) float64 {
|
||||||
|
return float64(sum) / float64(len)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRTTUpdate(t *testing.T) {
|
||||||
|
r := newRTTData()
|
||||||
|
start := time.Now()
|
||||||
|
// send at 0 ms, receive at 2 ms, RTT = 2ms
|
||||||
|
m := &roundTripMeasurement{receiveTime: start.Add(2 * time.Millisecond), sendTime: start}
|
||||||
|
r.update(m)
|
||||||
|
assert.Equal(t, start, r.lastMeasurementTime)
|
||||||
|
assert.Equal(t, 2*time.Millisecond, r.rtt)
|
||||||
|
assert.Equal(t, 2*time.Millisecond, r.rttMin)
|
||||||
|
assert.Equal(t, 2*time.Millisecond, r.rttMax)
|
||||||
|
|
||||||
|
// send at 3 ms, receive at 6 ms, RTT = 3ms
|
||||||
|
m = &roundTripMeasurement{receiveTime: start.Add(6 * time.Millisecond), sendTime: start.Add(3 * time.Millisecond)}
|
||||||
|
r.update(m)
|
||||||
|
assert.Equal(t, start.Add(3*time.Millisecond), r.lastMeasurementTime)
|
||||||
|
assert.Equal(t, 3*time.Millisecond, r.rtt)
|
||||||
|
assert.Equal(t, 2*time.Millisecond, r.rttMin)
|
||||||
|
assert.Equal(t, 3*time.Millisecond, r.rttMax)
|
||||||
|
|
||||||
|
// send at 7 ms, receive at 8 ms, RTT = 1ms
|
||||||
|
m = &roundTripMeasurement{receiveTime: start.Add(8 * time.Millisecond), sendTime: start.Add(7 * time.Millisecond)}
|
||||||
|
r.update(m)
|
||||||
|
assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime)
|
||||||
|
assert.Equal(t, 1*time.Millisecond, r.rtt)
|
||||||
|
assert.Equal(t, 1*time.Millisecond, r.rttMin)
|
||||||
|
assert.Equal(t, 3*time.Millisecond, r.rttMax)
|
||||||
|
|
||||||
|
// send at -4 ms, receive at 0 ms, RTT = 4ms, but this ping is before last measurement
|
||||||
|
// so it will be discarded
|
||||||
|
m = &roundTripMeasurement{receiveTime: start, sendTime: start.Add(-2 * time.Millisecond)}
|
||||||
|
r.update(m)
|
||||||
|
assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime)
|
||||||
|
assert.Equal(t, 1*time.Millisecond, r.rtt)
|
||||||
|
assert.Equal(t, 1*time.Millisecond, r.rttMin)
|
||||||
|
assert.Equal(t, 3*time.Millisecond, r.rttMax)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlowControlDataUpdate(t *testing.T) {
|
||||||
|
f := newFlowControlData()
|
||||||
|
assert.Equal(t, 0, f.queue.Len())
|
||||||
|
assert.Equal(t, float64(0), f.ave())
|
||||||
|
|
||||||
|
var sum uint64
|
||||||
|
min := maxWindowSize - dataPoints
|
||||||
|
max := maxWindowSize
|
||||||
|
for i := 1; i <= dataPoints; i++ {
|
||||||
|
size := maxWindowSize - uint32(i)
|
||||||
|
f.update(size)
|
||||||
|
assert.Equal(t, max - uint32(1), f.max)
|
||||||
|
assert.Equal(t, size, f.min)
|
||||||
|
|
||||||
|
assert.Equal(t, i, f.queue.Len())
|
||||||
|
|
||||||
|
sum += uint64(size)
|
||||||
|
assert.Equal(t, sum, f.sum)
|
||||||
|
assert.Equal(t, ave(sum, f.queue.Len()), f.ave())
|
||||||
|
}
|
||||||
|
|
||||||
|
// queue is full, should start to dequeue first element
|
||||||
|
for i := 1; i <= dataPoints; i++ {
|
||||||
|
f.update(max)
|
||||||
|
assert.Equal(t, max, f.max)
|
||||||
|
assert.Equal(t, min, f.min)
|
||||||
|
|
||||||
|
assert.Equal(t, dataPoints, f.queue.Len())
|
||||||
|
|
||||||
|
sum += uint64(i)
|
||||||
|
assert.Equal(t, sum, f.sum)
|
||||||
|
assert.Equal(t, ave(sum, dataPoints), f.ave())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMuxMetricsUpdater(t *testing.T) {
|
||||||
|
updateRTTChan := make(chan *roundTripMeasurement)
|
||||||
|
updateReceiveWindowChan := make(chan uint32)
|
||||||
|
updateSendWindowChan := make(chan uint32)
|
||||||
|
updateInBoundBytesChan := make(chan uint64)
|
||||||
|
updateOutBoundBytesChan := make(chan uint64)
|
||||||
|
abortChan := make(chan struct{})
|
||||||
|
errChan := make(chan error)
|
||||||
|
m := newMuxMetricsUpdater(updateRTTChan,
|
||||||
|
updateReceiveWindowChan,
|
||||||
|
updateSendWindowChan,
|
||||||
|
updateInBoundBytesChan,
|
||||||
|
updateOutBoundBytesChan,
|
||||||
|
abortChan,
|
||||||
|
)
|
||||||
|
logger := log.NewEntry(log.New())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errChan <- m.run(logger)
|
||||||
|
}()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
// mock muxReader
|
||||||
|
readerStart := time.Now()
|
||||||
|
rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart}
|
||||||
|
updateRTTChan <- rm
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
// Becareful if dataPoints is not divisibile by 4
|
||||||
|
readerSend := readerStart.Add(time.Millisecond)
|
||||||
|
for i := 1; i <= dataPoints/4; i++ {
|
||||||
|
readerReceive := readerSend.Add(time.Duration(i) * time.Millisecond)
|
||||||
|
rm := &roundTripMeasurement{receiveTime: readerReceive, sendTime: readerSend}
|
||||||
|
updateRTTChan <- rm
|
||||||
|
readerSend = readerReceive.Add(time.Millisecond)
|
||||||
|
|
||||||
|
updateReceiveWindowChan <- uint32(i)
|
||||||
|
updateSendWindowChan <- uint32(i)
|
||||||
|
|
||||||
|
updateInBoundBytesChan <- uint64(i)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// mock muxWriter
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := dataPoints/4 + 1; j <= dataPoints/2; j++ {
|
||||||
|
updateReceiveWindowChan <- uint32(j)
|
||||||
|
updateSendWindowChan <- uint32(j)
|
||||||
|
|
||||||
|
// should always be disgard since the send time is before readerSend
|
||||||
|
rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart.Add(-time.Duration(j*dataPoints) * time.Millisecond)}
|
||||||
|
updateRTTChan <- rm
|
||||||
|
|
||||||
|
updateOutBoundBytesChan <- uint64(j)
|
||||||
|
}
|
||||||
|
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
metrics := m.Metrics()
|
||||||
|
points := dataPoints / 2
|
||||||
|
assert.Equal(t, time.Millisecond, metrics.RTTMin)
|
||||||
|
assert.Equal(t, time.Duration(dataPoints/4)*time.Millisecond, metrics.RTTMax)
|
||||||
|
|
||||||
|
// sum(1..i) = i*(i+1)/2, ave(1..i) = i*(i+1)/2/i = (i+1)/2
|
||||||
|
assert.Equal(t, float64(points+1)/float64(2), metrics.ReceiveWindowAve)
|
||||||
|
assert.Equal(t, uint32(1), metrics.ReceiveWindowMin)
|
||||||
|
assert.Equal(t, uint32(points), metrics.ReceiveWindowMax)
|
||||||
|
|
||||||
|
assert.Equal(t, float64(points+1)/float64(2), metrics.SendWindowAve)
|
||||||
|
assert.Equal(t, uint32(1), metrics.SendWindowMin)
|
||||||
|
assert.Equal(t, uint32(points), metrics.SendWindowMax)
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateCurr)
|
||||||
|
assert.Equal(t, uint64(1), metrics.InBoundRateMin)
|
||||||
|
assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateMax)
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateCurr)
|
||||||
|
assert.Equal(t, uint64(dataPoints/4+1), metrics.OutBoundRateMin)
|
||||||
|
assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateMax)
|
||||||
|
|
||||||
|
close(abortChan)
|
||||||
|
assert.Nil(t, <-errChan)
|
||||||
|
close(errChan)
|
||||||
|
|
||||||
|
}
|
|
@@ -3,7 +3,6 @@ package h2mux

 import (
 	"encoding/binary"
 	"io"
-	"sync"
 	"time"

 	log "github.com/sirupsen/logrus"

@@ -34,14 +33,18 @@ type MuxReader struct {
 	initialStreamWindow uint32
 	// The max value for the send window of a stream.
 	streamWindowMax uint32
-	// windowMetrics keeps track of min/max/average of send/receive windows for all streams
-	flowControlMetrics *FlowControlMetrics
-	metricsMutex sync.Mutex
 	// r is a reference to the underlying connection used when shutting down.
 	r io.Closer
-	// rttMeasurement measures RTT based on ping timestamps.
-	rttMeasurement RTTMeasurement
-	rttMutex sync.Mutex
+	// updateRTTChan is the channel to send new RTT measurement to muxerMetricsUpdater
+	updateRTTChan chan<- *roundTripMeasurement
+	// updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
+	updateReceiveWindowChan chan<- uint32
+	// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
+	updateSendWindowChan chan<- uint32
+	// bytesRead is the amount of bytes read from data frames since the last time we sent bytes read to metrics
+	bytesRead AtomicCounter
+	// updateInBoundBytesChan is the channel to send bytesRead to muxerMetricsUpdater
+	updateInBoundBytesChan chan<- uint64
 }

 func (r *MuxReader) Shutdown() {

@@ -57,28 +60,26 @@ func (r *MuxReader) Shutdown() {
 	}()
 }

-func (r *MuxReader) RTT() RTTMeasurement {
-	r.rttMutex.Lock()
-	defer r.rttMutex.Unlock()
-	return r.rttMeasurement
-}
-
-func (r *MuxReader) FlowControlMetrics() *FlowControlMetrics {
-	r.metricsMutex.Lock()
-	defer r.metricsMutex.Unlock()
-	if r.flowControlMetrics != nil {
-		return r.flowControlMetrics
-	}
-	// No metrics available yet
-	return &FlowControlMetrics{}
-}
-
 func (r *MuxReader) run(parentLogger *log.Entry) error {
 	logger := parentLogger.WithFields(log.Fields{
 		"subsystem": "mux",
 		"dir":       "read",
 	})
 	defer logger.Debug("event loop finished")
+
+	// routine to periodically update bytesRead
+	go func() {
+		tickC := time.Tick(updateFreq)
+		for {
+			select {
+			case <-r.abortChan:
+				return
+			case <-tickC:
+				r.updateInBoundBytesChan <- r.bytesRead.Count()
+			}
+		}
+	}()
+
 	for {
 		frame, err := r.f.ReadFrame()
 		if err != nil {

@@ -120,6 +121,8 @@ func (r *MuxReader) run(parentLogger *log.Entry) error {
 			r.receivePingData(f)
 		case *http2.GoAwayFrame:
 			err = r.receiveGoAway(f)
+		// The receiver of a flow-controlled frame sends a WINDOW_UPDATE frame as it
+		// consumes data and frees up space in flow-control windows
 		case *http2.WindowUpdateFrame:
 			err = r.updateStreamWindow(f)
 		default:

@@ -236,10 +239,11 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.Entry) error {
 	}
 	data := frame.Data()
 	if len(data) > 0 {
-		_, err = stream.readBuffer.Write(data)
+		n, err := stream.readBuffer.Write(data)
 		if err != nil {
 			return r.streamError(stream.streamID, http2.ErrCodeInternal)
 		}
+		r.bytesRead.IncrementBy(uint64(n))
 	}
 	if frame.Header().Flags.Has(http2.FlagDataEndStream) {
 		if stream.receiveEOF() {

@@ -253,6 +257,7 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.Entry) error {
 	if !stream.consumeReceiveWindow(uint32(len(data))) {
 		return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
 	}
+	r.updateReceiveWindowChan <- stream.getReceiveWindow()
 	return nil
 }

@@ -263,10 +268,14 @@ func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
 		r.pingTimestamp.Set(ts)
 		return
 	}
-	r.rttMutex.Lock()
-	r.rttMeasurement.Update(time.Unix(0, ts))
-	r.rttMutex.Unlock()
-	r.flowControlMetrics = r.streams.Metrics()
+	// Update updates the computed values with a new measurement.
+	// outgoingTime is the time that the probe was sent.
+	// We assume that time.Now() is the time we received that probe.
+	r.updateRTTChan <- &roundTripMeasurement{
+		receiveTime: time.Now(),
+		sendTime:    time.Unix(0, ts),
+	}
 }

 // Receive a GOAWAY from the peer. Gracefully shut down our connection.

@@ -293,6 +302,7 @@ func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
 		return nil
 	}
 	stream.replenishSendWindow(frame.Increment)
+	r.updateSendWindowChan <- stream.getSendWindow()
 	return nil
 }
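Both the reader and the writer accumulate byte counts in an AtomicCounter (bytesRead, bytesWrote) and drain them on every tick of updateFreq. The counter itself is defined elsewhere in the package; the following is only a minimal sketch of a counter with the IncrementBy/Count shape these goroutines rely on, assuming Count returns and resets the running total (the real h2mux type may differ in detail).

    package h2mux

    import "sync/atomic"

    // AtomicCounter is an illustrative sketch, not the implementation shipped in this commit.
    type AtomicCounter struct {
        count uint64
    }

    // IncrementBy adds n to the running total.
    func (c *AtomicCounter) IncrementBy(n uint64) {
        atomic.AddUint64(&c.count, n)
    }

    // Count returns the bytes accumulated since the last call and resets the total,
    // which matches how the periodic update goroutines report deltas to the metrics updater.
    func (c *AtomicCounter) Count() uint64 {
        return atomic.SwapUint64(&c.count, 0)
    }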
@@ -40,6 +40,14 @@ type MuxWriter struct {
 	headerEncoder *hpack.Encoder
 	// headerBuffer is the temporary buffer used by headerEncoder.
 	headerBuffer bytes.Buffer
+	// updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
+	updateReceiveWindowChan chan<- uint32
+	// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
+	updateSendWindowChan chan<- uint32
+	// bytesWrote is the amount of bytes wrote to data frame since the last time we send bytes wrote to metrics
+	bytesWrote AtomicCounter
+	// updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
+	updateOutBoundBytesChan chan<- uint64
 }

 type MuxedStreamRequest struct {

@@ -64,6 +72,20 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error {
 		"dir": "write",
 	})
 	defer logger.Debug("event loop finished")
+
+	// routine to periodically communicate bytesWrote
+	go func() {
+		tickC := time.Tick(updateFreq)
+		for {
+			select {
+			case <-w.abortChan:
+				return
+			case <-tickC:
+				w.updateOutBoundBytesChan <- w.bytesWrote.Count()
+			}
+		}
+	}()
+
 	for {
 		select {
 		case <-w.abortChan:

@@ -141,7 +163,8 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error {
 func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
 	logger.Debug("writable")
 	chunk := stream.getChunk()
+	w.updateReceiveWindowChan <- stream.getReceiveWindow()
+	w.updateSendWindowChan <- stream.getSendWindow()
 	if chunk.sendHeadersFrame() {
 		err := w.writeHeaders(chunk.streamID, chunk.headers)
 		if err != nil {

@@ -154,7 +177,9 @@ func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
 	if chunk.sendWindowUpdateFrame() {
 		// Send a WINDOW_UPDATE frame to update our receive window.
 		// If the Stream ID is zero, the window update applies to the connection as a whole
-		// A WINDOW_UPDATE in a specific stream applies to the connection-level flow control as well.
+		// RFC7540 section-6.9.1 "A receiver that receives a flow-controlled frame MUST
+		// always account for its contribution against the connection flow-control
+		// window, unless the receiver treats this as a connection error"
 		err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
 		if err != nil {
 			logger.WithError(err).Warn("error writing window update")

@@ -170,6 +195,8 @@ func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
 		logger.WithError(err).Warn("error writing data")
 		return err
 	}
+	// update the amount of data wrote
+	w.bytesWrote.IncrementBy(uint64(len(payload)))
 	logger.WithField("len", len(payload)).Debug("output data")

 	if sentEOF {

@@ -214,19 +241,16 @@ func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
 		return err
 	}
 	blockSize := int(w.maxFrameSize)
-	continuation := false
 	endHeaders := len(encodedHeaders) == 0
 	for !endHeaders && err == nil {
 		blockFragment := encodedHeaders
 		if len(encodedHeaders) > blockSize {
 			blockFragment = blockFragment[:blockSize]
 			encodedHeaders = encodedHeaders[blockSize:]
-		} else {
-			endHeaders = true
-		}
-		if continuation {
+			// Send CONTINUATION frame if the headers can't be fit into 1 frame
 			err = w.f.WriteContinuation(streamID, endHeaders, blockFragment)
 		} else {
+			endHeaders = true
 			err = w.f.WriteHeaders(http2.HeadersFrameParam{
 				StreamID:   streamID,
 				EndHeaders: endHeaders,
h2mux/rtt.go (24 changed lines)

@@ -2,7 +2,6 @@ package h2mux

 import (
 	"sync/atomic"
-	"time"
 )

 // PingTimestamp is an atomic interface around ping timestamping and signalling.

@@ -28,26 +27,3 @@ func (pt *PingTimestamp) Get() int64 {
 func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} {
 	return pt.signal.WaitChannel()
 }
-
-// RTTMeasurement encapsulates a continuous round trip time measurement.
-type RTTMeasurement struct {
-	Current, Min, Max   time.Duration
-	lastMeasurementTime time.Time
-}
-
-// Update updates the computed values with a new measurement.
-// outgoingTime is the time that the probe was sent.
-// We assume that time.Now() is the time we received that probe.
-func (r *RTTMeasurement) Update(outgoingTime time.Time) {
-	if !r.lastMeasurementTime.Before(outgoingTime) {
-		return
-	}
-	r.lastMeasurementTime = outgoingTime
-	r.Current = time.Since(outgoingTime)
-	if r.Max < r.Current {
-		r.Max = r.Current
-	}
-	if r.Min > r.Current {
-		r.Min = r.Current
-	}
-}
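RTTMeasurement and its Update method are removed; the reader now only records the ping send/receive pair and hands it to the muxer metrics updater over updateRTTChan. A hedged sketch of how a single round-trip time falls out of that pair (the min/max/average bookkeeping lives in the metrics updater, which is outside this excerpt):

    package h2mux

    import "time"

    // roundTripMeasurement mirrors the struct sent on updateRTTChan above.
    type roundTripMeasurement struct {
        receiveTime, sendTime time.Time
    }

    // rttFromMeasurement is illustrative only: the round-trip time is the gap
    // between receiving the ping ack and sending the original ping.
    func rttFromMeasurement(m *roundTripMeasurement) time.Duration {
        return m.receiveTime.Sub(m.sendTime)
    }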
@@ -2,13 +2,32 @@ package origin

 import (
 	"sync"
+	"time"

-	"github.com/cloudflare/cloudflare-warp/h2mux"
+	"github.com/cloudflare/cloudflared/h2mux"

 	"github.com/prometheus/client_golang/prometheus"
 )

-type TunnelMetrics struct {
+type muxerMetrics struct {
+	rtt              *prometheus.GaugeVec
+	rttMin           *prometheus.GaugeVec
+	rttMax           *prometheus.GaugeVec
+	receiveWindowAve *prometheus.GaugeVec
+	sendWindowAve    *prometheus.GaugeVec
+	receiveWindowMin *prometheus.GaugeVec
+	receiveWindowMax *prometheus.GaugeVec
+	sendWindowMin    *prometheus.GaugeVec
+	sendWindowMax    *prometheus.GaugeVec
+	inBoundRateCurr  *prometheus.GaugeVec
+	inBoundRateMin   *prometheus.GaugeVec
+	inBoundRateMax   *prometheus.GaugeVec
+	outBoundRateCurr *prometheus.GaugeVec
+	outBoundRateMin  *prometheus.GaugeVec
+	outBoundRateMax  *prometheus.GaugeVec
+}
+
+type tunnelMetrics struct {
 	haConnections     prometheus.Gauge
 	totalRequests     prometheus.Counter
 	requestsPerTunnel *prometheus.CounterVec

@@ -20,16 +39,7 @@ type TunnelMetrics struct {
 	maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
 	// concurrentRequests records max count of concurrent requests for each tunnel
 	maxConcurrentRequests map[string]uint64
-	rtt                   prometheus.Gauge
-	rttMin                prometheus.Gauge
-	rttMax                prometheus.Gauge
 	timerRetries          prometheus.Gauge
-	receiveWindowSizeAve  prometheus.Gauge
-	sendWindowSizeAve     prometheus.Gauge
-	receiveWindowSizeMin  prometheus.Gauge
-	receiveWindowSizeMax  prometheus.Gauge
-	sendWindowSizeMin     prometheus.Gauge
-	sendWindowSizeMax     prometheus.Gauge
 	responseByCode        *prometheus.CounterVec
 	responseCodePerTunnel *prometheus.CounterVec
 	serverLocations       *prometheus.GaugeVec
@@ -37,10 +47,189 @@ type TunnelMetrics struct {
 	locationLock sync.Mutex
 	// oldServerLocations stores the last server the tunnel was connected to
 	oldServerLocations map[string]string
+
+	muxerMetrics *muxerMetrics
+}
+
+func newMuxerMetrics() *muxerMetrics {
+	rtt := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "rtt",
+			Help: "Round-trip time in millisecond",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(rtt)
+
+	rttMin := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "rtt_min",
+			Help: "Shortest round-trip time in millisecond",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(rttMin)
+
+	rttMax := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "rtt_max",
+			Help: "Longest round-trip time in millisecond",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(rttMax)
+
+	receiveWindowAve := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "receive_window_ave",
+			Help: "Average receive window size in bytes",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(receiveWindowAve)
+
+	sendWindowAve := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "send_window_ave",
+			Help: "Average send window size in bytes",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(sendWindowAve)
+
+	receiveWindowMin := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "receive_window_min",
+			Help: "Smallest receive window size in bytes",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(receiveWindowMin)
+
+	receiveWindowMax := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "receive_window_max",
+			Help: "Largest receive window size in bytes",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(receiveWindowMax)
+
+	sendWindowMin := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "send_window_min",
+			Help: "Smallest send window size in bytes",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(sendWindowMin)
+
+	sendWindowMax := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "send_window_max",
+			Help: "Largest send window size in bytes",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(sendWindowMax)
+
+	inBoundRateCurr := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "inbound_bytes_per_sec_curr",
+			Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(inBoundRateCurr)
+
+	inBoundRateMin := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "inbound_bytes_per_sec_min",
+			Help: "Minimum non-zero inbounding bytes per second",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(inBoundRateMin)
+
+	inBoundRateMax := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "inbound_bytes_per_sec_max",
+			Help: "Maximum inbounding bytes per second",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(inBoundRateMax)
+
+	outBoundRateCurr := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "outbound_bytes_per_sec_curr",
+			Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(outBoundRateCurr)
+
+	outBoundRateMin := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "outbound_bytes_per_sec_min",
+			Help: "Minimum non-zero outbounding bytes per second",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(outBoundRateMin)
+
+	outBoundRateMax := prometheus.NewGaugeVec(
+		prometheus.GaugeOpts{
+			Name: "outbound_bytes_per_sec_max",
+			Help: "Maximum outbounding bytes per second",
+		},
+		[]string{"connection_id"},
+	)
+	prometheus.MustRegister(outBoundRateMax)
+
+	return &muxerMetrics{
+		rtt:              rtt,
+		rttMin:           rttMin,
+		rttMax:           rttMax,
+		receiveWindowAve: receiveWindowAve,
+		sendWindowAve:    sendWindowAve,
+		receiveWindowMin: receiveWindowMin,
+		receiveWindowMax: receiveWindowMax,
+		sendWindowMin:    sendWindowMin,
+		sendWindowMax:    sendWindowMax,
+		inBoundRateCurr:  inBoundRateCurr,
+		inBoundRateMin:   inBoundRateMin,
+		inBoundRateMax:   inBoundRateMax,
+		outBoundRateCurr: outBoundRateCurr,
+		outBoundRateMin:  outBoundRateMin,
+		outBoundRateMax:  outBoundRateMax,
+	}
+}
+
+func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
+	m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
+	m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
+	m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
+	m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
+	m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
+	m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
+	m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
+	m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
+	m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
+	m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
+	m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
+	m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
+	m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
+	m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
+	m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
+}
+
+func convertRTTMilliSec(t time.Duration) float64 {
+	return float64(t / time.Millisecond)
 }

 // Metrics that can be collected without asking the edge
-func NewTunnelMetrics() *TunnelMetrics {
+func NewTunnelMetrics() *tunnelMetrics {
 	haConnections := prometheus.NewGauge(
 		prometheus.GaugeOpts{
 			Name: "ha_connections",
@@ -82,27 +271,6 @@ func NewTunnelMetrics() *TunnelMetrics {
 	)
 	prometheus.MustRegister(maxConcurrentRequestsPerTunnel)

-	rtt := prometheus.NewGauge(
-		prometheus.GaugeOpts{
-			Name: "rtt",
-			Help: "Round-trip time",
-		})
-	prometheus.MustRegister(rtt)
-
-	rttMin := prometheus.NewGauge(
-		prometheus.GaugeOpts{
-			Name: "rtt_min",
-			Help: "Shortest round-trip time",
-		})
-	prometheus.MustRegister(rttMin)
-
-	rttMax := prometheus.NewGauge(
-		prometheus.GaugeOpts{
-			Name: "rtt_max",
-			Help: "Longest round-trip time",
-		})
-	prometheus.MustRegister(rttMax)
-
 	timerRetries := prometheus.NewGauge(
 		prometheus.GaugeOpts{
 			Name: "timer_retries",

@@ -110,48 +278,6 @@ func NewTunnelMetrics() *TunnelMetrics {
 		})
 	prometheus.MustRegister(timerRetries)

-	receiveWindowSizeAve := prometheus.NewGauge(
-		prometheus.GaugeOpts{
-			Name: "receive_window_ave",
-			Help: "Average receive window size",
-		})
-	prometheus.MustRegister(receiveWindowSizeAve)
-
-	sendWindowSizeAve := prometheus.NewGauge(
-		prometheus.GaugeOpts{
-			Name: "send_window_ave",
-			Help: "Average send window size",
-		})
-	prometheus.MustRegister(sendWindowSizeAve)
-
-	receiveWindowSizeMin := prometheus.NewGauge(
-		prometheus.GaugeOpts{
-			Name: "receive_window_min",
-			Help: "Smallest receive window size",
-		})
-	prometheus.MustRegister(receiveWindowSizeMin)
-
-	receiveWindowSizeMax := prometheus.NewGauge(
-		prometheus.GaugeOpts{
-			Name: "receive_window_max",
-			Help: "Largest receive window size",
-		})
-	prometheus.MustRegister(receiveWindowSizeMax)
-
-	sendWindowSizeMin := prometheus.NewGauge(
-		prometheus.GaugeOpts{
-			Name: "send_window_min",
-			Help: "Smallest send window size",
-		})
-	prometheus.MustRegister(sendWindowSizeMin)
-
-	sendWindowSizeMax := prometheus.NewGauge(
-		prometheus.GaugeOpts{
-			Name: "send_window_max",
-			Help: "Largest send window size",
-		})
-	prometheus.MustRegister(sendWindowSizeMax)
-
 	responseByCode := prometheus.NewCounterVec(
 		prometheus.CounterOpts{
 			Name: "response_by_code",

@@ -179,7 +305,7 @@ func NewTunnelMetrics() *TunnelMetrics {
 	)
 	prometheus.MustRegister(serverLocations)

-	return &TunnelMetrics{
+	return &tunnelMetrics{
 		haConnections:     haConnections,
 		totalRequests:     totalRequests,
 		requestsPerTunnel: requestsPerTunnel,
@@ -187,41 +313,28 @@ func NewTunnelMetrics() *TunnelMetrics {
 		concurrentRequests:             make(map[string]uint64),
 		maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
 		maxConcurrentRequests:          make(map[string]uint64),
-		rtt:                   rtt,
-		rttMin:                rttMin,
-		rttMax:                rttMax,
-		timerRetries:          timerRetries,
-		receiveWindowSizeAve:  receiveWindowSizeAve,
-		sendWindowSizeAve:     sendWindowSizeAve,
-		receiveWindowSizeMin:  receiveWindowSizeMin,
-		receiveWindowSizeMax:  receiveWindowSizeMax,
-		sendWindowSizeMin:     sendWindowSizeMin,
-		sendWindowSizeMax:     sendWindowSizeMax,
-		responseByCode:        responseByCode,
-		responseCodePerTunnel: responseCodePerTunnel,
-		serverLocations:       serverLocations,
-		oldServerLocations:    make(map[string]string),
+		timerRetries:          timerRetries,
+		responseByCode:        responseByCode,
+		responseCodePerTunnel: responseCodePerTunnel,
+		serverLocations:       serverLocations,
+		oldServerLocations:    make(map[string]string),
+		muxerMetrics:          newMuxerMetrics(),
 	}
 }

-func (t *TunnelMetrics) incrementHaConnections() {
+func (t *tunnelMetrics) incrementHaConnections() {
 	t.haConnections.Inc()
 }

-func (t *TunnelMetrics) decrementHaConnections() {
+func (t *tunnelMetrics) decrementHaConnections() {
 	t.haConnections.Dec()
 }

-func (t *TunnelMetrics) updateTunnelFlowControlMetrics(metrics *h2mux.FlowControlMetrics) {
-	t.receiveWindowSizeAve.Set(float64(metrics.AverageReceiveWindowSize))
-	t.sendWindowSizeAve.Set(float64(metrics.AverageSendWindowSize))
-	t.receiveWindowSizeMin.Set(float64(metrics.MinReceiveWindowSize))
-	t.receiveWindowSizeMax.Set(float64(metrics.MaxReceiveWindowSize))
-	t.sendWindowSizeMin.Set(float64(metrics.MinSendWindowSize))
-	t.sendWindowSizeMax.Set(float64(metrics.MaxSendWindowSize))
+func (t *tunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
+	t.muxerMetrics.update(connectionID, metrics)
 }

-func (t *TunnelMetrics) incrementRequests(connectionID string) {
+func (t *tunnelMetrics) incrementRequests(connectionID string) {
 	t.concurrentRequestsLock.Lock()
 	var concurrentRequests uint64
 	var ok bool

@@ -243,7 +356,7 @@ func (t *TunnelMetrics) incrementRequests(connectionID string) {
 	t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
 }

-func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
+func (t *tunnelMetrics) decrementConcurrentRequests(connectionID string) {
 	t.concurrentRequestsLock.Lock()
 	if _, ok := t.concurrentRequests[connectionID]; ok {
 		t.concurrentRequests[connectionID] -= 1

@@ -255,13 +368,13 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
 	t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
 }

-func (t *TunnelMetrics) incrementResponses(connectionID, code string) {
+func (t *tunnelMetrics) incrementResponses(connectionID, code string) {
 	t.responseByCode.WithLabelValues(code).Inc()
 	t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()

 }

-func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) {
+func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) {
 	t.locationLock.Lock()
 	defer t.locationLock.Unlock()
 	if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
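The metrics rewrite above replaces each plain prometheus.Gauge with a prometheus.GaugeVec labelled by connection_id, so every HA connection exports its own time series. A small self-contained example of that pattern follows; the metric name, help text, and label values here are illustrative rather than ones added by this commit.

    package main

    import (
        "fmt"

        "github.com/prometheus/client_golang/prometheus"
    )

    func main() {
        // One gauge family, one child per connection, keyed by connection_id.
        rtt := prometheus.NewGaugeVec(
            prometheus.GaugeOpts{
                Name: "example_rtt",
                Help: "Round-trip time in millisecond (illustrative)",
            },
            []string{"connection_id"},
        )
        prometheus.MustRegister(rtt)

        // Each connection only touches its own labelled child.
        rtt.WithLabelValues("0").Set(3)
        rtt.WithLabelValues("1").Set(7)

        fmt.Println("registered example_rtt for connections 0 and 1")
    }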
@@ -15,11 +15,11 @@ import (

 	"golang.org/x/net/context"

-	"github.com/cloudflare/cloudflare-warp/h2mux"
-	"github.com/cloudflare/cloudflare-warp/tunnelrpc"
-	tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
-	"github.com/cloudflare/cloudflare-warp/validation"
-	"github.com/cloudflare/cloudflare-warp/websocket"
+	"github.com/cloudflare/cloudflared/h2mux"
+	"github.com/cloudflare/cloudflared/tunnelrpc"
+	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
+	"github.com/cloudflare/cloudflared/validation"
+	"github.com/cloudflare/cloudflared/websocket"

 	raven "github.com/getsentry/raven-go"
 	"github.com/pkg/errors"

@@ -53,7 +53,7 @@ type TunnelConfig struct {
 	Tags              []tunnelpogs.Tag
 	HAConnections     int
 	HTTPTransport     http.RoundTripper
-	Metrics           *TunnelMetrics
+	Metrics           *tunnelMetrics
 	MetricsUpdateFreq time.Duration
 	ProtocolLogger    *logrus.Logger
 	Logger            *logrus.Logger

@@ -185,6 +185,7 @@ func ServeTunnel(
 	serveCtx, serveCancel := context.WithCancel(ctx)
 	registerErrC := make(chan error, 1)
 	go func() {
+		defer wg.Done()
 		err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP)
 		if err == nil {
 			connectedFuse.Fuse(true)

@@ -193,18 +194,18 @@ func ServeTunnel(
 			serveCancel()
 		}
 		registerErrC <- err
-		wg.Done()
 	}()
 	updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
 	go func() {
 		defer wg.Done()
+		connectionTag := uint8ToString(connectionID)
 		for {
 			select {
 			case <-serveCtx.Done():
 				handler.muxer.Shutdown()
 				return
 			case <-updateMetricsTickC:
-				handler.UpdateMetrics()
+				handler.UpdateMetrics(connectionTag)
 			}
 		}
 	}()

@@ -303,7 +304,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig,
 func LogServerInfo(logger *logrus.Entry,
 	promise tunnelrpc.ServerInfo_Promise,
 	connectionID uint8,
-	metrics *TunnelMetrics,
+	metrics *tunnelMetrics,
 ) {
 	serverInfoMessage, err := promise.Struct()
 	if err != nil {

@@ -356,13 +357,17 @@ func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
 	return
 }

+func FindCfRayHeader(h1 *http.Request) string {
+	return h1.Header.Get("Cf-Ray")
+}
+
 type TunnelHandler struct {
 	originUrl  string
 	muxer      *h2mux.Muxer
 	httpClient http.RoundTripper
 	tlsConfig  *tls.Config
 	tags       []tunnelpogs.Tag
-	metrics    *TunnelMetrics
+	metrics    *tunnelMetrics
 	// connectionID is only used by metrics, and prometheus requires labels to be string
 	connectionID string
 }

@@ -435,7 +440,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 		Log.WithError(err).Error("invalid request received")
 	}
 	h.AppendTagHeaders(req)
+	cfRay := FindCfRayHeader(req)
+	h.logRequest(req, cfRay)
 	if websocket.IsWebSocketUpgrade(req) {
 		conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
 		if err != nil {

@@ -444,6 +450,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 			stream.WriteHeaders(H1ResponseToH2Response(response))
 			defer conn.Close()
 			websocket.Stream(conn.UnderlyingConn(), stream)
+			h.metrics.incrementResponses(h.connectionID, "200")
+			h.logResponse(response, cfRay)
 		}
 	} else {
 		response, err := h.httpClient.RoundTrip(req)

@@ -454,6 +462,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 			stream.WriteHeaders(H1ResponseToH2Response(response))
 			io.Copy(stream, response.Body)
 			h.metrics.incrementResponses(h.connectionID, "200")
+			h.logResponse(response, cfRay)
 		}
 	}
 	h.metrics.decrementConcurrentRequests(h.connectionID)

@@ -467,9 +476,27 @@ func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
 	h.metrics.incrementResponses(h.connectionID, "502")
 }

-func (h *TunnelHandler) UpdateMetrics() {
-	flowCtlMetrics := h.muxer.FlowControlMetrics()
-	h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics)
+func (h *TunnelHandler) logRequest(req *http.Request, cfRay string) {
+	if cfRay != "" {
+		Log.WithField("CF-RAY", cfRay).Infof("%s %s %s", req.Method, req.URL, req.Proto)
+	} else {
+		Log.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto)
+	}
+	Log.Debugf("Request Headers %+v", req.Header)
+}
+
+func (h *TunnelHandler) logResponse(r *http.Response, cfRay string) {
+	if cfRay != "" {
+		Log.WithField("CF-RAY", cfRay).Infof("%s", r.Status)
+	} else {
+		Log.Infof("%s", r.Status)
+	}
+	Log.Debugf("Response Headers %+v", r.Header)
+}
+
+func (h *TunnelHandler) UpdateMetrics(connectionID string) {
+	h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics())
 }

 func uint8ToString(input uint8) string {
@@ -11,28 +11,28 @@ const (
 BgUrgQQAIg==
 -----END EC PARAMETERS-----
 -----BEGIN EC PRIVATE KEY-----
-MIGkAgEBBDAdyQBXfxTDCQSOT0HugmH9pVBtIw8t5dYvm6HxGlNq6P57v5GeN02Z
-dH9FRl7+VSWgBwYFK4EEACKhZANiAATqpFzTxxV7D+/oqhKCTR6BEM9elTfKaRQE
-FsLufcmaTMw/9tTwgpHKao/QsLKDTNbQhbSQLkcmpCQKlSGhl+pCrqNt/oYUAhav
-UIwpwGiLCqGH/R2AqWLKRPOa/Rufs/U=
+MIGkAgEBBDBGGfwhIJdiUiJUVIItqJjEIMmlXxsMa8TQeer47+g+cIZ466rgg8EK
++Mdn6BY48GCgBwYFK4EEACKhZANiAASW//A9iDbPKg3OLkn7yJqLer32g9I5lBKR
+tPc/zBubQLLz9lAaYI6AOQiJXhGr5JkKmQfi1sYHK5rJITPFy4W8Et4hHLdazDZH
+WnEd+TStQABFUjrhtqXPWmGKcly0pOE=
 -----END EC PRIVATE KEY-----`

 	helloCRT = `
 -----BEGIN CERTIFICATE-----
-MIICkDCCAhigAwIBAgIJAPtKfUjc2lwGMAkGByqGSM49BAEwgYoxCzAJBgNVBAYT
-AlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQKDBBD
-bG91ZGZsYXJlLCBJbmMuMT8wPQYDVQQDDDZDbG91ZGZsYXJlIEFyZ28gVHVubmVs
-IFNhbXBsZSBIZWxsbyBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMjE1MjAxNjU5
-WhcNMjgwMjEzMjAxNjU5WjCBijELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVRleGFz
-MQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENsb3VkZmxhcmUsIEluYy4xPzA9
-BgNVBAMMNkNsb3VkZmxhcmUgQXJnbyBUdW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZl
-ciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49AgEGBSuBBAAiA2IABOqkXNPHFXsP7+iq
-EoJNHoEQz16VN8ppFAQWwu59yZpMzD/21PCCkcpqj9CwsoNM1tCFtJAuRyakJAqV
-IaGX6kKuo23+hhQCFq9QjCnAaIsKoYf9HYCpYspE85r9G5+z9aNJMEcwRQYDVR0R
-BD4wPIIJbG9jYWxob3N0ggp3YXJwLWhlbGxvggt3YXJwMi1oZWxsb4cEfwAAAYcQ
-AAAAAAAAAAAAAAAAAAAAATAJBgcqhkjOPQQBA2cAMGQCMHyVPufXZ6vQo6XRWRa0
-dAwtfgesOdZVP2Wt+t5v8jOIQQh1IQXYk5GtyoZGSObjhQIwd1fRgAyKXaZt+1DV
-ZtHTdf8pMvESfJsSd8AB1eQ6q+pAiRUYyaxcE1Mlo2YY5o+g
+MIICiDCCAg6gAwIBAgIJAJ/FfkBTtbuIMAkGByqGSM49BAEwfzELMAkGA1UEBhMC
+VVMxDjAMBgNVBAgMBVRleGFzMQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENs
+b3VkZmxhcmUsIEluYy4xNDAyBgNVBAMMK0FyZ28gVHVubmVsIFNhbXBsZSBIZWxs
+byBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMzE5MjMwNTMyWhcNMjgwMzE2MjMw
+NTMyWjB/MQswCQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1
+c3RpbjEZMBcGA1UECgwQQ2xvdWRmbGFyZSwgSW5jLjE0MDIGA1UEAwwrQXJnbyBU
+dW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZlciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49
+AgEGBSuBBAAiA2IABJb/8D2INs8qDc4uSfvImot6vfaD0jmUEpG09z/MG5tAsvP2
+UBpgjoA5CIleEavkmQqZB+LWxgcrmskhM8XLhbwS3iEct1rMNkdacR35NK1AAEVS
+OuG2pc9aYYpyXLSk4aNXMFUwUwYDVR0RBEwwSoIJbG9jYWxob3N0ghFjbG91ZGZs
+YXJlZC1oZWxsb4ISY2xvdWRmbGFyZWQyLWhlbGxvhwR/AAABhxAAAAAAAAAAAAAA
+AAAAAAABMAkGByqGSM49BAEDaQAwZgIxAPxkdghH6y8xLMnY9Bom3Llf4NYM6yB9
+PD1YsaNUJTsxjTk3YY1Jsp+yzK0yUKtTZwIxAPcdvqCF2/iR9H288pCT1TgtO0a9
+cJL9RY1lq7DIGN37v1ZXReWaD+3hNokY8NriVg==
 -----END CERTIFICATE-----`
 )
@@ -0,0 +1,38 @@
+package tunneldns
+
+import (
+	"github.com/coredns/coredns/plugin"
+	"github.com/miekg/dns"
+	"github.com/pkg/errors"
+	"golang.org/x/net/context"
+)
+
+// Upstream is a simplified interface for proxy destination
+type Upstream interface {
+	Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
+}
+
+// ProxyPlugin is a simplified DNS proxy using a generic upstream interface
+type ProxyPlugin struct {
+	Upstreams []Upstream
+	Next      plugin.Handler
+}
+
+// ServeDNS implements interface for CoreDNS plugin
+func (p ProxyPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+	var reply *dns.Msg
+	var backendErr error
+
+	for _, upstream := range p.Upstreams {
+		reply, backendErr = upstream.Exchange(ctx, r)
+		if backendErr == nil {
+			w.WriteMsg(reply)
+			return 0, nil
+		}
+	}
+
+	return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
+}
+
+// Name implements interface for CoreDNS plugin
+func (p ProxyPlugin) Name() string { return "proxy" }
@@ -0,0 +1,97 @@
+package tunneldns
+
+import (
+	"bytes"
+	"crypto/tls"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"net/url"
+	"time"
+
+	"github.com/miekg/dns"
+	"github.com/pkg/errors"
+	log "github.com/sirupsen/logrus"
+	"golang.org/x/net/context"
+)
+
+const (
+	defaultTimeout = 5 * time.Second
+)
+
+// UpstreamHTTPS is the upstream implementation for DNS over HTTPS service
+type UpstreamHTTPS struct {
+	client   *http.Client
+	endpoint *url.URL
+}
+
+// NewUpstreamHTTPS creates a new DNS over HTTPS upstream from hostname
+func NewUpstreamHTTPS(endpoint string) (Upstream, error) {
+	u, err := url.Parse(endpoint)
+	if err != nil {
+		return nil, err
+	}
+
+	// Update TLS and HTTP client configuration
+	tls := &tls.Config{ServerName: u.Hostname()}
+	client := &http.Client{
+		Timeout:   time.Second * defaultTimeout,
+		Transport: &http.Transport{TLSClientConfig: tls},
+	}
+
+	return &UpstreamHTTPS{client: client, endpoint: u}, nil
+}
+
+// Exchange provides an implementation for the Upstream interface
+func (u *UpstreamHTTPS) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
+	queryBuf, err := query.Pack()
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to pack DNS query")
+	}
+
+	// No content negotiation for now, use DNS wire format
+	buf, backendErr := u.exchangeWireformat(queryBuf)
+	if backendErr == nil {
+		response := &dns.Msg{}
+		if err := response.Unpack(buf); err != nil {
+			return nil, errors.Wrap(err, "failed to unpack DNS response from body")
+		}
+
+		response.Id = query.Id
+		return response, nil
+	}
+
+	log.WithError(backendErr).Errorf("failed to connect to an HTTPS backend %q", u.endpoint)
+	return nil, backendErr
+}
+
+// Perform message exchange with the default UDP wireformat defined in current draft
+// https://datatracker.ietf.org/doc/draft-ietf-doh-dns-over-https
+func (u *UpstreamHTTPS) exchangeWireformat(msg []byte) ([]byte, error) {
+	req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to create an HTTPS request")
+	}
+
+	req.Header.Add("Content-Type", "application/dns-udpwireformat")
+	req.Host = u.endpoint.Hostname()
+
+	resp, err := u.client.Do(req)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to perform an HTTPS request")
+	}
+
+	// Check response status code
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
+	}
+
+	// Read wireformat response from the body
+	buf, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to read the response body")
+	}
+
+	return buf, nil
+}
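A short usage sketch for the DNS-over-HTTPS upstream added above. The import path and the resolver endpoint are assumptions for illustration; any DoH server that accepts the draft UDP wire format over POST should work, and the call needs network access.

    package main

    import (
        "fmt"

        "github.com/cloudflare/cloudflared/tunneldns"
        "github.com/miekg/dns"
        "golang.org/x/net/context"
    )

    func main() {
        // Hypothetical endpoint; substitute the DoH resolver you actually use.
        upstream, err := tunneldns.NewUpstreamHTTPS("https://doh.example.com/dns-query")
        if err != nil {
            panic(err)
        }

        // Build a plain A query and exchange it over HTTPS.
        query := new(dns.Msg)
        query.SetQuestion("example.com.", dns.TypeA)

        reply, err := upstream.Exchange(context.Background(), query)
        if err != nil {
            panic(err)
        }
        fmt.Println(reply.Answer)
    }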
@@ -0,0 +1,45 @@
+package tunneldns
+
+import (
+	"github.com/coredns/coredns/plugin"
+	"github.com/coredns/coredns/plugin/metrics/vars"
+	"github.com/coredns/coredns/plugin/pkg/dnstest"
+	"github.com/coredns/coredns/plugin/pkg/rcode"
+	"github.com/coredns/coredns/request"
+	"github.com/miekg/dns"
+	"github.com/prometheus/client_golang/prometheus"
+	"golang.org/x/net/context"
+)
+
+// MetricsPlugin is an adapter for CoreDNS and built-in metrics
+type MetricsPlugin struct {
+	Next plugin.Handler
+}
+
+// NewMetricsPlugin creates a plugin with configured metrics
+func NewMetricsPlugin(next plugin.Handler) *MetricsPlugin {
+	prometheus.MustRegister(vars.RequestCount)
+	prometheus.MustRegister(vars.RequestDuration)
+	prometheus.MustRegister(vars.RequestSize)
+	prometheus.MustRegister(vars.RequestDo)
+	prometheus.MustRegister(vars.RequestType)
+	prometheus.MustRegister(vars.ResponseSize)
+	prometheus.MustRegister(vars.ResponseRcode)
+	return &MetricsPlugin{Next: next}
+}
+
+// ServeDNS implements the CoreDNS plugin interface
+func (p MetricsPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
+	state := request.Request{W: w, Req: r}
+
+	rw := dnstest.NewRecorder(w)
+	status, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, rw, r)
+
+	// Update built-in metrics
+	vars.Report(state, ".", rcode.ToString(rw.Rcode), rw.Len, rw.Start)
+
+	return status, err
+}
+
+// Name implements the CoreDNS plugin interface
+func (p MetricsPlugin) Name() string { return "metrics" }
@@ -0,0 +1,144 @@
+package tunneldns
+
+import (
+	"fmt"
+	"net"
+	"os"
+	"os/signal"
+	"strconv"
+	"sync"
+	"syscall"
+
+	"gopkg.in/urfave/cli.v2"
+
+	"github.com/cloudflare/cloudflared/metrics"
+	"github.com/coredns/coredns/core/dnsserver"
+	"github.com/coredns/coredns/plugin"
+	"github.com/coredns/coredns/plugin/cache"
+	"github.com/pkg/errors"
+	log "github.com/sirupsen/logrus"
+)
+
+// Listener is an adapter between CoreDNS server and Warp runnable
+type Listener struct {
+	server *dnsserver.Server
+	wg     sync.WaitGroup
+}
+
+// Run implements a foreground runner
+func Run(c *cli.Context) error {
+	metricsListener, err := net.Listen("tcp", c.String("metrics"))
+	if err != nil {
+		log.WithError(err).Fatal("Failed to open the metrics listener")
+	}
+
+	go metrics.ServeMetrics(metricsListener, nil)
+
+	listener, err := CreateListener(c.String("address"), uint16(c.Uint("port")), c.StringSlice("upstream"))
+	if err != nil {
+		log.WithError(err).Errorf("Failed to create the listeners")
+		return err
+	}
+
+	// Try to start the server
+	err = listener.Start()
+	if err != nil {
+		log.WithError(err).Errorf("Failed to start the listeners")
+		return listener.Stop()
+	}
+
+	// Wait for signal
+	signals := make(chan os.Signal, 10)
+	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
+	defer signal.Stop(signals)
+	<-signals
+
+	// Shut down server
+	err = listener.Stop()
+	if err != nil {
+		log.WithError(err).Errorf("failed to stop")
+	}
+	return err
+}
+
+// Create a CoreDNS server plugin from configuration
+func createConfig(address string, port uint16, p plugin.Handler) *dnsserver.Config {
+	c := &dnsserver.Config{
+		Zone:        ".",
+		Transport:   "dns",
+		ListenHosts: []string{address},
+		Port:        strconv.FormatUint(uint64(port), 10),
+	}
+
+	c.AddPlugin(func(next plugin.Handler) plugin.Handler { return p })
+	return c
+}
+
+// Start blocks for serving requests
+func (l *Listener) Start() error {
+	log.WithField("addr", l.server.Address()).Infof("Starting DNS over HTTPS proxy server")
+
+	// Start UDP listener
+	if udp, err := l.server.ListenPacket(); err == nil {
+		l.wg.Add(1)
+		go func() {
+			l.server.ServePacket(udp)
+			l.wg.Done()
+		}()
+	} else {
+		return errors.Wrap(err, "failed to create a UDP listener")
+	}
+
+	// Start TCP listener
+	tcp, err := l.server.Listen()
+	if err == nil {
+		l.wg.Add(1)
+		go func() {
+			l.server.Serve(tcp)
+			l.wg.Done()
+		}()
+	}
+
+	return errors.Wrap(err, "failed to create a TCP listener")
+}
+
+// Stop signals server shutdown and blocks until completed
+func (l *Listener) Stop() error {
+	if err := l.server.Stop(); err != nil {
+		return err
+	}
+
+	l.wg.Wait()
+	return nil
+}
+
+// CreateListener configures the server and bound sockets
+func CreateListener(address string, port uint16, upstreams []string) (*Listener, error) {
+	// Build the list of upstreams
+	upstreamList := make([]Upstream, 0)
+	for _, url := range upstreams {
+		log.WithField("url", url).Infof("Adding DNS upstream")
+		upstream, err := NewUpstreamHTTPS(url)
+		if err != nil {
+			return nil, errors.Wrap(err, "failed to create HTTPS upstream")
+		}
+		upstreamList = append(upstreamList, upstream)
+	}
+
+	// Create a local cache with HTTPS proxy plugin
+	chain := cache.New()
+	chain.Next = ProxyPlugin{
+		Upstreams: upstreamList,
+	}
+
+	// Format an endpoint
+	endpoint := fmt.Sprintf("dns://%s:%d", address, port)
+
+	// Create the actual middleware server
+	server, err := dnsserver.NewServer(endpoint, []*dnsserver.Config{createConfig(address, port, NewMetricsPlugin(chain))})
+	if err != nil {
+		return nil, err
+	}
+
+	return &Listener{server: server}, nil
+}
@@ -1,7 +1,7 @@
 package pogs

 import (
-	"github.com/cloudflare/cloudflare-warp/tunnelrpc"
+	"github.com/cloudflare/cloudflared/tunnelrpc"
 	"golang.org/x/net/context"
 	"zombiezen.com/go/capnproto2"
 	"zombiezen.com/go/capnproto2/pogs"

@@ -1,7 +1,7 @@
 using Go = import "go.capnp";
 @0xdb8274f9144abc7e;
 $Go.package("tunnelrpc");
-$Go.import("github.com/cloudflare/cloudflare-warp/tunnelrpc");
+$Go.import("github.com/cloudflare/cloudflared/tunnelrpc");

 struct Authentication {
     key @0 :Text;

@@ -35,7 +35,7 @@ struct RegistrationOptions {
     connectionId @6 :UInt8;
     # origin LAN IP
     originLocalIp @7 :Text;
-    # whether Warp client has been autoupdated
+    # whether Argo Tunnel client has been autoupdated
     isAutoupdated @8 :Bool;
 }

@@ -119,7 +119,7 @@ func validateScheme(scheme string) error {
 			return nil
 		}
 	}
-	return fmt.Errorf("Currently Cloudflare-Warp does not support %s protocol.", scheme)
+	return fmt.Errorf("Currently Argo Tunnel does not support %s protocol.", scheme)
 }

 func validateIP(scheme, host, port string) (string, error) {

@@ -126,7 +126,7 @@ func TestValidateUrl(t *testing.T) {
 	assert.Equal(t, "https://hello.example.com", validUrl)

 	validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
-	assert.Equal(t, "Currently Cloudflare-Warp does not support ftp protocol.", err.Error())
+	assert.Equal(t, "Currently Argo Tunnel does not support ftp protocol.", err.Error())
 	assert.Empty(t, validUrl)

 	validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")