Release Argo Tunnel Client 2018.3.1
This commit is contained in:
		
							parent
							
								
									9f5cec8dbc
								
							
						
					
					
						commit
						d0a6a2a829
					
				| 
						 | 
				
			
			@ -0,0 +1,13 @@
 | 
			
		|||
// +build !windows,!darwin,!linux
 | 
			
		||||
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"os"
 | 
			
		||||
 | 
			
		||||
	cli "gopkg.in/urfave/cli.v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func runApp(app *cli.App) {
 | 
			
		||||
	app.Run(os.Args)
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,204 @@
 | 
			
		|||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gopkg.in/urfave/cli.v2"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tlsconfig"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type templateData struct {
 | 
			
		||||
	ServerName string
 | 
			
		||||
	Request    *http.Request
 | 
			
		||||
	Body       string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OriginUpTime struct {
 | 
			
		||||
	StartTime time.Time `json:"startTime"`
 | 
			
		||||
	UpTime    string    `json:"uptime"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const defaultServerName = "the Argo Tunnel test server"
 | 
			
		||||
const indexTemplate = `
 | 
			
		||||
<!DOCTYPE html>
 | 
			
		||||
<html lang="en">
 | 
			
		||||
  <head>
 | 
			
		||||
    <meta charset="utf-8">
 | 
			
		||||
    <meta http-equiv="X-UA-Compatible" content="IE=Edge">
 | 
			
		||||
    <title>
 | 
			
		||||
      Argo Tunnel Connection
 | 
			
		||||
    </title>
 | 
			
		||||
    <meta name="author" content="">
 | 
			
		||||
    <meta name="description" content="Argo Tunnel Connection">
 | 
			
		||||
    <meta name="viewport" content="width=device-width, initial-scale=1">
 | 
			
		||||
    <style>
 | 
			
		||||
      html{line-height:1.15;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}body{margin:0}section{display:block}h1{font-size:2em;margin:.67em 0}a{background-color:transparent;-webkit-text-decoration-skip:objects}/* 1 */::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}/* 1 */a,body,dd,div,dl,dt,h1,h4,html,p,section{box-sizing:border-box}.bt{border-top-style:solid;border-top-width:1px}.bl{border-left-style:solid;border-left-width:1px}.b--orange{border-color:#f38020}.br1{border-radius:.125rem}.bw2{border-width:.25rem}.dib{display:inline-block}.sans-serif{font-family:open sans,-apple-system,BlinkMacSystemFont,avenir next,avenir,helvetica neue,helvetica,ubuntu,roboto,noto,segoe ui,arial,sans-serif}.code{font-family:Consolas,monaco,monospace}.b{font-weight:700}.fw3{font-weight:300}.fw4{font-weight:400}.fw5{font-weight:500}.fw6{font-weight:600}.lh-copy{line-height:1.5}.link{text-decoration:none}.link,.link:active,.link:focus,.link:hover,.link:link,.link:visited{transition:color .15s ease-in}.link:focus{outline:1px dotted currentColor}.mw-100{max-width:100%}.mw4{max-width:8rem}.mw7{max-width:48rem}.bg-light-gray{background-color:#f7f7f7}.link-hover:hover{background-color:#1f679e}.white{color:#fff}.bg-white{background-color:#fff}.bg-blue{background-color:#408bc9}.pb2{padding-bottom:.5rem}.pb6{padding-bottom:8rem}.pt3{padding-top:1rem}.pt5{padding-top:4rem}.pv2{padding-top:.5rem;padding-bottom:.5rem}.ph3{padding-left:1rem;padding-right:1rem}.ph4{padding-left:2rem;padding-right:2rem}.ml0{margin-left:0}.mb1{margin-bottom:.25rem}.mb2{margin-bottom:.5rem}.mb3{margin-bottom:1rem}.mt5{margin-top:4rem}.ttu{text-transform:uppercase}.f4{font-size:1.25rem}.f5{font-size:1rem}.f6{font-size:.875rem}.f7{font-size:.75rem}.measure{max-width:30em}.center{margin-left:auto}.center{margin-right:auto}@media screen and (min-width:30em){.f2-ns{font-size:2.25rem}}@media screen and (min-width:30em) and (max-width:60em){.f5-m{font-size:1rem}}@media screen and (min-width:60em){.f4-l{font-size:1.25rem}}
 | 
			
		||||
    .st0{fill:#FFF}.st1{fill:#f48120}.st2{fill:#faad3f}.st3{fill:#404041}
 | 
			
		||||
    </style>
 | 
			
		||||
  </head>
 | 
			
		||||
  <body class="sans-serif black">
 | 
			
		||||
    <div class="bt bw2 b--orange bg-white pb6">
 | 
			
		||||
      <div class="mw7 center ph4 pt3">
 | 
			
		||||
        <svg id="Layer_2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 109 40.5" class="mw4">
 | 
			
		||||
          <path class="st0" d="M98.6 14.2L93 12.9l-1-.4-25.7.2v12.4l32.3.1z"/>
 | 
			
		||||
          <path class="st1" d="M88.1 24c.3-1 .2-2-.3-2.6-.5-.6-1.2-1-2.1-1.1l-17.4-.2c-.1 0-.2-.1-.3-.1-.1-.1-.1-.2 0-.3.1-.2.2-.3.4-.3l17.5-.2c2.1-.1 4.3-1.8 5.1-3.8l1-2.6c0-.1.1-.2 0-.3-1.1-5.1-5.7-8.9-11.1-8.9-5 0-9.3 3.2-10.8 7.7-1-.7-2.2-1.1-3.6-1-2.4.2-4.3 2.2-4.6 4.6-.1.6 0 1.2.1 1.8-3.9.1-7.1 3.3-7.1 7.3 0 .4 0 .7.1 1.1 0 .2.2.3.3.3h32.1c.2 0 .4-.1.4-.3l.3-1.1z"/>
 | 
			
		||||
          <path class="st2" d="M93.6 12.8h-.5c-.1 0-.2.1-.3.2l-.7 2.4c-.3 1-.2 2 .3 2.6.5.6 1.2 1 2.1 1.1l3.7.2c.1 0 .2.1.3.1.1.1.1.2 0 .3-.1.2-.2.3-.4.3l-3.8.2c-2.1.1-4.3 1.8-5.1 3.8l-.2.9c-.1.1 0 .3.2.3h13.2c.2 0 .3-.1.3-.3.2-.8.4-1.7.4-2.6 0-5.2-4.3-9.5-9.5-9.5"/>
 | 
			
		||||
          <path class="st3" d="M104.4 30.8c-.5 0-.9-.4-.9-.9s.4-.9.9-.9.9.4.9.9-.4.9-.9.9m0-1.6c-.4 0-.7.3-.7.7 0 .4.3.7.7.7.4 0 .7-.3.7-.7 0-.4-.3-.7-.7-.7m.4 1.2h-.2l-.2-.3h-.2v.3h-.2v-.9h.5c.2 0 .3.1.3.3 0 .1-.1.2-.2.3l.2.3zm-.3-.5c.1 0 .1 0 .1-.1s-.1-.1-.1-.1h-.3v.3h.3zM14.8 29H17v6h3.8v1.9h-6zM23.1 32.9c0-2.3 1.8-4.1 4.3-4.1s4.2 1.8 4.2 4.1-1.8 4.1-4.3 4.1c-2.4 0-4.2-1.8-4.2-4.1m6.3 0c0-1.2-.8-2.2-2-2.2s-2 1-2 2.1.8 2.1 2 2.1c1.2.2 2-.8 2-2M34.3 33.4V29h2.2v4.4c0 1.1.6 1.7 1.5 1.7s1.5-.5 1.5-1.6V29h2.2v4.4c0 2.6-1.5 3.7-3.7 3.7-2.3-.1-3.7-1.2-3.7-3.7M45 29h3.1c2.8 0 4.5 1.6 4.5 3.9s-1.7 4-4.5 4h-3V29zm3.1 5.9c1.3 0 2.2-.7 2.2-2s-.9-2-2.2-2h-.9v4h.9zM55.7 29H62v1.9h-4.1v1.3h3.7V34h-3.7v2.9h-2.2zM65.1 29h2.2v6h3.8v1.9h-6zM76.8 28.9H79l3.4 8H80l-.6-1.4h-3.1l-.6 1.4h-2.3l3.4-8zm2 4.9l-.9-2.2-.9 2.2h1.8zM85.2 29h3.7c1.2 0 2 .3 2.6.9.5.5.7 1.1.7 1.8 0 1.2-.6 2-1.6 2.4l1.9 2.8H90l-1.6-2.4h-1v2.4h-2.2V29zm3.6 3.8c.7 0 1.2-.4 1.2-.9 0-.6-.5-.9-1.2-.9h-1.4v1.9h1.4zM95.3 29h6.4v1.8h-4.2V32h3.8v1.8h-3.8V35h4.3v1.9h-6.5zM10 33.9c-.3.7-1 1.2-1.8 1.2-1.2 0-2-1-2-2.1s.8-2.1 2-2.1c.9 0 1.6.6 1.9 1.3h2.3c-.4-1.9-2-3.3-4.2-3.3-2.4 0-4.3 1.8-4.3 4.1s1.8 4.1 4.2 4.1c2.1 0 3.7-1.4 4.2-3.2H10z"/>
 | 
			
		||||
        </svg>
 | 
			
		||||
        <h1 class="f4 f2-ns mt5 fw5">Congrats! You created your first tunnel!</h1>
 | 
			
		||||
        <p class="f6 f5-m f4-l measure lh-copy fw3">
 | 
			
		||||
          Argo Tunnel exposes locally running applications to the internet by
 | 
			
		||||
          running an encrypted, virtual tunnel from your laptop or server to
 | 
			
		||||
          Cloudflare's edge network.
 | 
			
		||||
        </p>
 | 
			
		||||
        <p class="b f5 mt5 fw6">Ready for the next step?</p>
 | 
			
		||||
        <a
 | 
			
		||||
          class="fw6 link white bg-blue ph4 pv2 br1 dib f5 link-hover"
 | 
			
		||||
          style="border-bottom: 1px solid #1f679e"
 | 
			
		||||
          href="https://developers.cloudflare.com/argo-tunnel/">
 | 
			
		||||
          Get started here
 | 
			
		||||
        </a>
 | 
			
		||||
       <section>
 | 
			
		||||
          <h4 class="f6 fw4 pt5 mb2">Request</h4>
 | 
			
		||||
          <dl class="bl bw2 b--orange ph3 pt3 pb2 bg-light-gray f7 code overflow-x-auto mw-100">
 | 
			
		||||
						<dd class="ml0 mb3 f5">Method: {{.Request.Method}}</dd>
 | 
			
		||||
						<dd class="ml0 mb3 f5">Protocol: {{.Request.Proto}}</dd>
 | 
			
		||||
						<dd class="ml0 mb3 f5">Request URL: {{.Request.URL}}</dd>
 | 
			
		||||
						<dd class="ml0 mb3 f5">Transfer encoding: {{.Request.TransferEncoding}}</dd>
 | 
			
		||||
						<dd class="ml0 mb3 f5">Host: {{.Request.Host}}</dd>
 | 
			
		||||
						<dd class="ml0 mb3 f5">Remote address: {{.Request.RemoteAddr}}</dd>
 | 
			
		||||
						<dd class="ml0 mb3 f5">Request URI: {{.Request.RequestURI}}</dd>
 | 
			
		||||
{{range $key, $value := .Request.Header}}
 | 
			
		||||
						<dd class="ml0 mb3 f5">Header: {{$key}}, Value: {{$value}}</dd>
 | 
			
		||||
{{end}}
 | 
			
		||||
						<dd class="ml0 mb3 f5">Body: {{.Body}}</dd>
 | 
			
		||||
					</dl>
 | 
			
		||||
        </section>
 | 
			
		||||
     </div>
 | 
			
		||||
    </div>
 | 
			
		||||
  </body>
 | 
			
		||||
</html>
 | 
			
		||||
`
 | 
			
		||||
 | 
			
		||||
func hello(c *cli.Context) error {
 | 
			
		||||
	address := fmt.Sprintf(":%d", c.Int("port"))
 | 
			
		||||
	listener, err := createListener(address)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer listener.Close()
 | 
			
		||||
	err = startHelloWorldServer(listener, nil)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
 | 
			
		||||
	Log.Infof("Starting Hello World server at %s", listener.Addr())
 | 
			
		||||
	serverName := defaultServerName
 | 
			
		||||
	if hostname, err := os.Hostname(); err == nil {
 | 
			
		||||
		serverName = hostname
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	upgrader := websocket.Upgrader{
 | 
			
		||||
		ReadBufferSize:  1024,
 | 
			
		||||
		WriteBufferSize: 1024,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
 | 
			
		||||
	go func() {
 | 
			
		||||
		<-shutdownC
 | 
			
		||||
		httpServer.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	http.HandleFunc("/uptime", uptimeHandler(time.Now()))
 | 
			
		||||
	http.HandleFunc("/ws", websocketHandler(upgrader))
 | 
			
		||||
	http.HandleFunc("/", rootHandler(serverName))
 | 
			
		||||
	err := httpServer.Serve(listener)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func createListener(address string) (net.Listener, error) {
 | 
			
		||||
	certificate, err := tlsconfig.GetHelloCertificate()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If the port in address is empty, a port number is automatically chosen
 | 
			
		||||
	listener, err := tls.Listen(
 | 
			
		||||
		"tcp",
 | 
			
		||||
		address,
 | 
			
		||||
		&tls.Config{Certificates: []tls.Certificate{certificate}})
 | 
			
		||||
 | 
			
		||||
	return listener, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func uptimeHandler(startTime time.Time) http.HandlerFunc {
 | 
			
		||||
	return func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		// Note that if autoupdate is enabled, the uptime is reset when a new client
 | 
			
		||||
		// release is available
 | 
			
		||||
		resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()}
 | 
			
		||||
		respJson, err := json.Marshal(resp)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			w.WriteHeader(http.StatusInternalServerError)
 | 
			
		||||
		} else {
 | 
			
		||||
			w.Header().Set("Content-Type", "application/json")
 | 
			
		||||
			w.Write(respJson)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc {
 | 
			
		||||
	return func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		conn, err := upgrader.Upgrade(w, r, nil)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		defer conn.Close()
 | 
			
		||||
 | 
			
		||||
		for {
 | 
			
		||||
			mt, message, err := conn.ReadMessage()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err := conn.WriteMessage(mt, message); err != nil {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func rootHandler(serverName string) http.HandlerFunc {
 | 
			
		||||
	responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
 | 
			
		||||
	return func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		var buffer bytes.Buffer
 | 
			
		||||
		var body string
 | 
			
		||||
		rawBody, err := ioutil.ReadAll(r.Body)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			body = string(rawBody)
 | 
			
		||||
		} else {
 | 
			
		||||
			body = ""
 | 
			
		||||
		}
 | 
			
		||||
		err = responseTemplate.Execute(&buffer, &templateData{
 | 
			
		||||
			ServerName: serverName,
 | 
			
		||||
			Request:    r,
 | 
			
		||||
			Body:       body,
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			w.WriteHeader(http.StatusInternalServerError)
 | 
			
		||||
			fmt.Fprintf(w, "error: %v", err)
 | 
			
		||||
		} else {
 | 
			
		||||
			buffer.WriteTo(w)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,35 @@
 | 
			
		|||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestCreateListenerHostAndPortSuccess(t *testing.T) {
 | 
			
		||||
	listener, err := createListener("localhost:1234")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	if listener.Addr().String() == "" {
 | 
			
		||||
		t.Fatal("Fail to find available port")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCreateListenerOnlyHostSuccess(t *testing.T) {
 | 
			
		||||
	listener, err := createListener("localhost:")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	if listener.Addr().String() == "" {
 | 
			
		||||
		t.Fatal("Fail to find available port")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCreateListenerOnlyPortSuccess(t *testing.T) {
 | 
			
		||||
	listener, err := createListener(":8888")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	if listener.Addr().String() == "" {
 | 
			
		||||
		t.Fatal("Fail to find available port")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,292 @@
 | 
			
		|||
// +build linux
 | 
			
		||||
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
 | 
			
		||||
	cli "gopkg.in/urfave/cli.v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func runApp(app *cli.App) {
 | 
			
		||||
	app.Commands = append(app.Commands, &cli.Command{
 | 
			
		||||
		Name:  "service",
 | 
			
		||||
		Usage: "Manages the Argo Tunnel system service",
 | 
			
		||||
		Subcommands: []*cli.Command{
 | 
			
		||||
			&cli.Command{
 | 
			
		||||
				Name:   "install",
 | 
			
		||||
				Usage:  "Install Argo Tunnel as a system service",
 | 
			
		||||
				Action: installLinuxService,
 | 
			
		||||
			},
 | 
			
		||||
			&cli.Command{
 | 
			
		||||
				Name:   "uninstall",
 | 
			
		||||
				Usage:  "Uninstall the Argo Tunnel service",
 | 
			
		||||
				Action: uninstallLinuxService,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	app.Run(os.Args)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const serviceConfigDir = "/etc/cloudflared"
 | 
			
		||||
 | 
			
		||||
var systemdTemplates = []ServiceTemplate{
 | 
			
		||||
	{
 | 
			
		||||
		Path: "/etc/systemd/system/cloudflared.service",
 | 
			
		||||
		Content: `[Unit]
 | 
			
		||||
Description=Argo Tunnel
 | 
			
		||||
After=network.target
 | 
			
		||||
 | 
			
		||||
[Service]
 | 
			
		||||
TimeoutStartSec=0
 | 
			
		||||
Type=notify
 | 
			
		||||
ExecStart={{ .Path }} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --no-autoupdate
 | 
			
		||||
Restart=on-failure
 | 
			
		||||
RestartSec=5s
 | 
			
		||||
 | 
			
		||||
[Install]
 | 
			
		||||
WantedBy=multi-user.target
 | 
			
		||||
`,
 | 
			
		||||
	},
 | 
			
		||||
	{
 | 
			
		||||
		Path: "/etc/systemd/system/cloudflared-update.service",
 | 
			
		||||
		Content: `[Unit]
 | 
			
		||||
Description=Update Argo Tunnel
 | 
			
		||||
After=network.target
 | 
			
		||||
 | 
			
		||||
[Service]
 | 
			
		||||
ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 64 ]; then systemctl restart cloudflared; exit 0; fi; exit $code'
 | 
			
		||||
`,
 | 
			
		||||
	},
 | 
			
		||||
	{
 | 
			
		||||
		Path: "/etc/systemd/system/cloudflared-update.timer",
 | 
			
		||||
		Content: `[Unit]
 | 
			
		||||
Description=Update Argo Tunnel
 | 
			
		||||
 | 
			
		||||
[Timer]
 | 
			
		||||
OnUnitActiveSec=1d
 | 
			
		||||
 | 
			
		||||
[Install]
 | 
			
		||||
WantedBy=timers.target
 | 
			
		||||
`,
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var sysvTemplate = ServiceTemplate{
 | 
			
		||||
	Path:     "/etc/init.d/cloudflared",
 | 
			
		||||
	FileMode: 0755,
 | 
			
		||||
	Content: `# For RedHat and cousins:
 | 
			
		||||
# chkconfig: 2345 99 01
 | 
			
		||||
# description: Argo Tunnel agent
 | 
			
		||||
# processname: {{.Path}}
 | 
			
		||||
### BEGIN INIT INFO
 | 
			
		||||
# Provides:          {{.Path}}
 | 
			
		||||
# Required-Start:
 | 
			
		||||
# Required-Stop:
 | 
			
		||||
# Default-Start:     2 3 4 5
 | 
			
		||||
# Default-Stop:      0 1 6
 | 
			
		||||
# Short-Description: Argo Tunnel
 | 
			
		||||
# Description:       Argo Tunnel agent
 | 
			
		||||
### END INIT INFO
 | 
			
		||||
cmd="{{.Path}} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --pidfile /var/run/$name.pid --autoupdate-freq 24h0m0s"
 | 
			
		||||
name=$(basename $(readlink -f $0))
 | 
			
		||||
pid_file="/var/run/$name.pid"
 | 
			
		||||
stdout_log="/var/log/$name.log"
 | 
			
		||||
stderr_log="/var/log/$name.err"
 | 
			
		||||
[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
 | 
			
		||||
get_pid() {
 | 
			
		||||
    cat "$pid_file"
 | 
			
		||||
}
 | 
			
		||||
is_running() {
 | 
			
		||||
    [ -f "$pid_file" ] && ps $(get_pid) > /dev/null 2>&1
 | 
			
		||||
}
 | 
			
		||||
case "$1" in
 | 
			
		||||
    start)
 | 
			
		||||
        if is_running; then
 | 
			
		||||
            echo "Already started"
 | 
			
		||||
        else
 | 
			
		||||
            echo "Starting $name"
 | 
			
		||||
            $cmd >> "$stdout_log" 2>> "$stderr_log" &
 | 
			
		||||
            echo $! > "$pid_file"
 | 
			
		||||
            if ! is_running; then
 | 
			
		||||
                echo "Unable to start, see $stdout_log and $stderr_log"
 | 
			
		||||
                exit 1
 | 
			
		||||
            fi
 | 
			
		||||
        fi
 | 
			
		||||
    ;;
 | 
			
		||||
    stop)
 | 
			
		||||
        if is_running; then
 | 
			
		||||
            echo -n "Stopping $name.."
 | 
			
		||||
            kill $(get_pid)
 | 
			
		||||
            for i in {1..10}
 | 
			
		||||
            do
 | 
			
		||||
                if ! is_running; then
 | 
			
		||||
                    break
 | 
			
		||||
                fi
 | 
			
		||||
                echo -n "."
 | 
			
		||||
                sleep 1
 | 
			
		||||
            done
 | 
			
		||||
            echo
 | 
			
		||||
            if is_running; then
 | 
			
		||||
                echo "Not stopped; may still be shutting down or shutdown may have failed"
 | 
			
		||||
                exit 1
 | 
			
		||||
            else
 | 
			
		||||
                echo "Stopped"
 | 
			
		||||
                if [ -f "$pid_file" ]; then
 | 
			
		||||
                    rm "$pid_file"
 | 
			
		||||
                fi
 | 
			
		||||
            fi
 | 
			
		||||
        else
 | 
			
		||||
            echo "Not running"
 | 
			
		||||
        fi
 | 
			
		||||
    ;;
 | 
			
		||||
    restart)
 | 
			
		||||
        $0 stop
 | 
			
		||||
        if is_running; then
 | 
			
		||||
            echo "Unable to stop, will not attempt to start"
 | 
			
		||||
            exit 1
 | 
			
		||||
        fi
 | 
			
		||||
        $0 start
 | 
			
		||||
    ;;
 | 
			
		||||
    status)
 | 
			
		||||
        if is_running; then
 | 
			
		||||
            echo "Running"
 | 
			
		||||
        else
 | 
			
		||||
            echo "Stopped"
 | 
			
		||||
            exit 1
 | 
			
		||||
        fi
 | 
			
		||||
    ;;
 | 
			
		||||
    *)
 | 
			
		||||
    echo "Usage: $0 {start|stop|restart|status}"
 | 
			
		||||
    exit 1
 | 
			
		||||
    ;;
 | 
			
		||||
esac
 | 
			
		||||
exit 0
 | 
			
		||||
`,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isSystemd() bool {
 | 
			
		||||
	if _, err := os.Stat("/run/systemd/system"); err == nil {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func installLinuxService(c *cli.Context) error {
 | 
			
		||||
	etPath, err := os.Executable()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error determining executable path: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	templateArgs := ServiceTemplateArgs{Path: etPath}
 | 
			
		||||
 | 
			
		||||
	defaultConfigDir := filepath.Dir(c.String("config"))
 | 
			
		||||
	defaultConfigFile := filepath.Base(c.String("config"))
 | 
			
		||||
	if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile); err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("Failed to copy user configuration. Before running the service, ensure that %s contains two files, %s and %s",
 | 
			
		||||
			serviceConfigDir, credentialFile, defaultConfigFiles[0])
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch {
 | 
			
		||||
	case isSystemd():
 | 
			
		||||
		Log.Infof("Using Systemd")
 | 
			
		||||
		return installSystemd(&templateArgs)
 | 
			
		||||
	default:
 | 
			
		||||
		Log.Infof("Using Sysv")
 | 
			
		||||
		return installSysv(&templateArgs)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func installSystemd(templateArgs *ServiceTemplateArgs) error {
 | 
			
		||||
	for _, serviceTemplate := range systemdTemplates {
 | 
			
		||||
		err := serviceTemplate.Generate(templateArgs)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			Log.WithError(err).Infof("error generating service template")
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if err := runCommand("systemctl", "enable", "cloudflared.service"); err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("systemctl enable cloudflared.service error")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err := runCommand("systemctl", "start", "cloudflared-update.timer"); err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("systemctl start cloudflared-update.timer error")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	Log.Infof("systemctl daemon-reload")
 | 
			
		||||
	return runCommand("systemctl", "daemon-reload")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func installSysv(templateArgs *ServiceTemplateArgs) error {
 | 
			
		||||
	confPath, err := sysvTemplate.ResolvePath()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("error resolving system path")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err := sysvTemplate.Generate(templateArgs); err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("error generating system template")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	for _, i := range [...]string{"2", "3", "4", "5"} {
 | 
			
		||||
		if err := os.Symlink(confPath, "/etc/rc"+i+".d/S50et"); err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	for _, i := range [...]string{"0", "1", "6"} {
 | 
			
		||||
		if err := os.Symlink(confPath, "/etc/rc"+i+".d/K02et"); err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func uninstallLinuxService(c *cli.Context) error {
 | 
			
		||||
	switch {
 | 
			
		||||
	case isSystemd():
 | 
			
		||||
		Log.Infof("Using Systemd")
 | 
			
		||||
		return uninstallSystemd()
 | 
			
		||||
	default:
 | 
			
		||||
		Log.Infof("Using Sysv")
 | 
			
		||||
		return uninstallSysv()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func uninstallSystemd() error {
 | 
			
		||||
	if err := runCommand("systemctl", "disable", "cloudflared.service"); err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("systemctl disable cloudflared.service error")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err := runCommand("systemctl", "stop", "cloudflared-update.timer"); err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("systemctl stop cloudflared-update.timer error")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	for _, serviceTemplate := range systemdTemplates {
 | 
			
		||||
		if err := serviceTemplate.Remove(); err != nil {
 | 
			
		||||
			Log.WithError(err).Infof("error removing service template")
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	Log.Infof("Successfully uninstall cloudflared service")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func uninstallSysv() error {
 | 
			
		||||
	if err := sysvTemplate.Remove(); err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("error removing service template")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	for _, i := range [...]string{"2", "3", "4", "5"} {
 | 
			
		||||
		if err := os.Remove("/etc/rc" + i + ".d/S50et"); err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	for _, i := range [...]string{"0", "1", "6"} {
 | 
			
		||||
		if err := os.Remove("/etc/rc" + i + ".d/K02et"); err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	Log.Infof("Successfully uninstall cloudflared service")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,194 @@
 | 
			
		|||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"encoding/base32"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"syscall"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	homedir "github.com/mitchellh/go-homedir"
 | 
			
		||||
	cli "gopkg.in/urfave/cli.v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const baseLoginURL = "https://www.cloudflare.com/a/warp"
 | 
			
		||||
const baseCertStoreURL = "https://login.cloudflarewarp.com"
 | 
			
		||||
const clientTimeout = time.Minute * 20
 | 
			
		||||
 | 
			
		||||
func login(c *cli.Context) error {
 | 
			
		||||
	configPath, err := homedir.Expand(defaultConfigDirs[0])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	ok, err := fileExists(configPath)
 | 
			
		||||
	if !ok && err == nil {
 | 
			
		||||
		// create config directory if doesn't already exist
 | 
			
		||||
		err = os.Mkdir(configPath, 0700)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	path := filepath.Join(configPath, credentialFile)
 | 
			
		||||
	fileInfo, err := os.Stat(path)
 | 
			
		||||
	if err == nil && fileInfo.Size() > 0 {
 | 
			
		||||
		fmt.Fprintf(os.Stderr, `You have an existing certificate at %s which login would overwrite.
 | 
			
		||||
If this is intentional, please move or delete that file then run this command again.
 | 
			
		||||
`, path)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil && err.(*os.PathError).Err != syscall.ENOENT {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// for local debugging
 | 
			
		||||
	baseURL := baseCertStoreURL
 | 
			
		||||
	if c.IsSet("url") {
 | 
			
		||||
		baseURL = c.String("url")
 | 
			
		||||
	}
 | 
			
		||||
	// Generate a random post URL
 | 
			
		||||
	certURL := baseURL + generateRandomPath()
 | 
			
		||||
	loginURL, err := url.Parse(baseLoginURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		// shouldn't happen, URL is hardcoded
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	loginURL.RawQuery = "callback=" + url.QueryEscape(certURL)
 | 
			
		||||
 | 
			
		||||
	err = open(loginURL.String())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Fprintf(os.Stderr, `Please open the following URL and log in with your Cloudflare account:
 | 
			
		||||
 | 
			
		||||
%s
 | 
			
		||||
 | 
			
		||||
Leave cloudflared running to install the certificate automatically.
 | 
			
		||||
`, loginURL.String())
 | 
			
		||||
	} else {
 | 
			
		||||
		fmt.Fprintf(os.Stderr, `A browser window should have opened at the following URL:
 | 
			
		||||
 | 
			
		||||
%s
 | 
			
		||||
 | 
			
		||||
If the browser failed to open, open it yourself and visit the URL above.
 | 
			
		||||
 | 
			
		||||
`, loginURL.String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if download(certURL, path) {
 | 
			
		||||
		fmt.Fprintf(os.Stderr, `You have successfully logged in.
 | 
			
		||||
If you wish to copy your credentials to a server, they have been saved to:
 | 
			
		||||
%s
 | 
			
		||||
`, path)
 | 
			
		||||
	} else {
 | 
			
		||||
		fmt.Fprintf(os.Stderr, `Failed to write the certificate due to the following error:
 | 
			
		||||
%v
 | 
			
		||||
 | 
			
		||||
Your browser will download the certificate instead. You will have to manually
 | 
			
		||||
copy it to the following path:
 | 
			
		||||
 | 
			
		||||
%s
 | 
			
		||||
 | 
			
		||||
`, err, path)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// generateRandomPath generates a random URL to associate with the certificate.
 | 
			
		||||
func generateRandomPath() string {
 | 
			
		||||
	randomBytes := make([]byte, 40)
 | 
			
		||||
	_, err := rand.Read(randomBytes)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	return "/" + base32.StdEncoding.EncodeToString(randomBytes)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// open opens the specified URL in the default browser of the user.
 | 
			
		||||
func open(url string) error {
 | 
			
		||||
	var cmd string
 | 
			
		||||
	var args []string
 | 
			
		||||
 | 
			
		||||
	switch runtime.GOOS {
 | 
			
		||||
	case "windows":
 | 
			
		||||
		cmd = "cmd"
 | 
			
		||||
		args = []string{"/c", "start"}
 | 
			
		||||
	case "darwin":
 | 
			
		||||
		cmd = "open"
 | 
			
		||||
	default: // "linux", "freebsd", "openbsd", "netbsd"
 | 
			
		||||
		cmd = "xdg-open"
 | 
			
		||||
	}
 | 
			
		||||
	args = append(args, url)
 | 
			
		||||
	return exec.Command(cmd, args...).Start()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func download(certURL, filePath string) bool {
 | 
			
		||||
	client := &http.Client{Timeout: clientTimeout}
 | 
			
		||||
	// attempt a (long-running) certificate get
 | 
			
		||||
	for i := 0; i < 20; i++ {
 | 
			
		||||
		ok, err := tryDownload(client, certURL, filePath)
 | 
			
		||||
		if ok {
 | 
			
		||||
			putSuccess(client, certURL)
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			Log.WithError(err).Error("Error fetching certificate")
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err error) {
 | 
			
		||||
	resp, err := client.Get(certURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	if resp.StatusCode == 404 {
 | 
			
		||||
		return false, nil
 | 
			
		||||
	}
 | 
			
		||||
	if resp.StatusCode != 200 {
 | 
			
		||||
		return false, fmt.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
 | 
			
		||||
	}
 | 
			
		||||
	if resp.Header.Get("Content-Type") != "application/x-pem-file" {
 | 
			
		||||
		return false, fmt.Errorf("Unexpected content type %s", resp.Header.Get("Content-Type"))
 | 
			
		||||
	}
 | 
			
		||||
	// write response
 | 
			
		||||
	file, err := os.Create(filePath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	defer file.Close()
 | 
			
		||||
	written, err := io.Copy(file, resp.Body)
 | 
			
		||||
	switch {
 | 
			
		||||
	case err != nil:
 | 
			
		||||
		return false, err
 | 
			
		||||
	case resp.ContentLength != written && resp.ContentLength != -1:
 | 
			
		||||
		return false, fmt.Errorf("Short read (%d bytes) from server while writing certificate", written)
 | 
			
		||||
	default:
 | 
			
		||||
		return true, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func putSuccess(client *http.Client, certURL string) {
 | 
			
		||||
	// indicate success to the relay server
 | 
			
		||||
	req, err := http.NewRequest("PUT", certURL+"/ok", nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Error("HTTP request error")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp, err := client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Error("HTTP error")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.Body.Close()
 | 
			
		||||
	if resp.StatusCode != 200 {
 | 
			
		||||
		Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,97 @@
 | 
			
		|||
// +build darwin
 | 
			
		||||
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
 | 
			
		||||
	"gopkg.in/urfave/cli.v2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const launchAgentIdentifier = "com.cloudflare.cloudflared"
 | 
			
		||||
 | 
			
		||||
func runApp(app *cli.App) {
 | 
			
		||||
	app.Commands = append(app.Commands, &cli.Command{
 | 
			
		||||
		Name:  "service",
 | 
			
		||||
		Usage: "Manages the Argo Tunnel launch agent",
 | 
			
		||||
		Subcommands: []*cli.Command{
 | 
			
		||||
			{
 | 
			
		||||
				Name:   "install",
 | 
			
		||||
				Usage:  "Install Argo Tunnel as an user launch agent",
 | 
			
		||||
				Action: installLaunchd,
 | 
			
		||||
			},
 | 
			
		||||
			{
 | 
			
		||||
				Name:   "uninstall",
 | 
			
		||||
				Usage:  "Uninstall the Argo Tunnel launch agent",
 | 
			
		||||
				Action: uninstallLaunchd,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
	app.Run(os.Args)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var launchdTemplate = ServiceTemplate{
 | 
			
		||||
	Path: fmt.Sprintf("~/Library/LaunchAgents/%s.plist", launchAgentIdentifier),
 | 
			
		||||
	Content: fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
 | 
			
		||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
 | 
			
		||||
<plist version="1.0">
 | 
			
		||||
	<dict>
 | 
			
		||||
		<key>Label</key>
 | 
			
		||||
		<string>%s</string>
 | 
			
		||||
		<key>Program</key>
 | 
			
		||||
		<string>{{ .Path }}</string>
 | 
			
		||||
		<key>RunAtLoad</key>
 | 
			
		||||
		<true/>
 | 
			
		||||
		<key>StandardOutPath</key>
 | 
			
		||||
		<string>/tmp/%s.out.log</string>
 | 
			
		||||
    <key>StandardErrorPath</key>
 | 
			
		||||
		<string>/tmp/%s.err.log</string>
 | 
			
		||||
		<key>KeepAlive</key>
 | 
			
		||||
		<dict>
 | 
			
		||||
			<key>NetworkState</key>
 | 
			
		||||
			<true/>
 | 
			
		||||
		</dict>
 | 
			
		||||
		<key>ThrottleInterval</key>
 | 
			
		||||
		<integer>20</integer>
 | 
			
		||||
	</dict>
 | 
			
		||||
</plist>`, launchAgentIdentifier, launchAgentIdentifier, launchAgentIdentifier),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func installLaunchd(c *cli.Context) error {
 | 
			
		||||
	Log.Infof("Installing Argo Tunnel as an user launch agent")
 | 
			
		||||
	etPath, err := os.Executable()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("error determining executable path")
 | 
			
		||||
		return fmt.Errorf("error determining executable path: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	templateArgs := ServiceTemplateArgs{Path: etPath}
 | 
			
		||||
	err = launchdTemplate.Generate(&templateArgs)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("error generating launchd template")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	plistPath, err := launchdTemplate.ResolvePath()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("error resolving launchd template path")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	Log.Infof("Outputs are logged in %s and %s", fmt.Sprintf("/tmp/%s.out.log", launchAgentIdentifier), fmt.Sprintf("/tmp/%s.err.log", launchAgentIdentifier))
 | 
			
		||||
	return runCommand("launchctl", "load", plistPath)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func uninstallLaunchd(c *cli.Context) error {
 | 
			
		||||
	Log.Infof("Uninstalling Argo Tunnel as an user launch agent")
 | 
			
		||||
	plistPath, err := launchdTemplate.ResolvePath()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("error resolving launchd template path")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	err = runCommand("launchctl", "unload", plistPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("error unloading")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	Log.Infof("Outputs are logged in %s and %s", fmt.Sprintf("/tmp/%s.out.log", launchAgentIdentifier), fmt.Sprintf("/tmp/%s.err.log", launchAgentIdentifier))
 | 
			
		||||
	return launchdTemplate.Remove()
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,794 @@
 | 
			
		|||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/signal"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"syscall"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/metrics"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/origin"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tlsconfig"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tunneldns"
 | 
			
		||||
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/validation"
 | 
			
		||||
 | 
			
		||||
	"github.com/facebookgo/grace/gracenet"
 | 
			
		||||
	"github.com/getsentry/raven-go"
 | 
			
		||||
	"github.com/mitchellh/go-homedir"
 | 
			
		||||
	"github.com/rifflock/lfshook"
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	"golang.org/x/crypto/ssh/terminal"
 | 
			
		||||
	"gopkg.in/urfave/cli.v2"
 | 
			
		||||
	"gopkg.in/urfave/cli.v2/altsrc"
 | 
			
		||||
 | 
			
		||||
	"github.com/coreos/go-systemd/daemon"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	sentryDSN           = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
 | 
			
		||||
	credentialFile      = "cert.pem"
 | 
			
		||||
	quickStartUrl       = "https://developers.cloudflare.com/argo-tunnel/quickstart/quickstart/"
 | 
			
		||||
	noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/"
 | 
			
		||||
	licenseUrl          = "https://developers.cloudflare.com/argo-tunnel/licence/"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var listeners = gracenet.Net{}
 | 
			
		||||
var Version = "DEV"
 | 
			
		||||
var BuildTime = "unknown"
 | 
			
		||||
var Log *logrus.Logger
 | 
			
		||||
var defaultConfigFiles = []string{"config.yml", "config.yaml"}
 | 
			
		||||
 | 
			
		||||
// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
 | 
			
		||||
var defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp"}
 | 
			
		||||
 | 
			
		||||
// Shutdown channel used by the app. When closed, app must terminate.
 | 
			
		||||
// May be closed by the Windows service runner.
 | 
			
		||||
var shutdownC chan struct{}
 | 
			
		||||
 | 
			
		||||
type BuildAndRuntimeInfo struct {
 | 
			
		||||
	GoOS        string                 `json:"go_os"`
 | 
			
		||||
	GoVersion   string                 `json:"go_version"`
 | 
			
		||||
	GoArch      string                 `json:"go_arch"`
 | 
			
		||||
	WarpVersion string                 `json:"warp_version"`
 | 
			
		||||
	WarpFlags   map[string]interface{} `json:"warp_flags"`
 | 
			
		||||
	WarpEnvs    map[string]string      `json:"warp_envs"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	metrics.RegisterBuildInfo(BuildTime, Version)
 | 
			
		||||
	raven.SetDSN(sentryDSN)
 | 
			
		||||
	raven.SetRelease(Version)
 | 
			
		||||
	shutdownC = make(chan struct{})
 | 
			
		||||
	app := &cli.App{}
 | 
			
		||||
	app.Name = "cloudflared"
 | 
			
		||||
	app.Copyright = fmt.Sprintf(`(c) %d Cloudflare Inc.
 | 
			
		||||
   Use is subject to the license agreement at %s`, time.Now().Year(), licenseUrl)
 | 
			
		||||
	app.Usage = "Cloudflare reverse tunnelling proxy agent"
 | 
			
		||||
	app.ArgsUsage = "origin-url"
 | 
			
		||||
	app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime)
 | 
			
		||||
	app.Description = `A reverse tunnel proxy agent that connects to Cloudflare's infrastructure.
 | 
			
		||||
   Upon connecting, you are assigned a unique subdomain on cftunnel.com.
 | 
			
		||||
   You need to specify a hostname on a zone you control.
 | 
			
		||||
   A DNS record will be created to CNAME your hostname to the unique subdomain on cftunnel.com.
 | 
			
		||||
 | 
			
		||||
   Requests made to Cloudflare's servers for your hostname will be proxied
 | 
			
		||||
   through the tunnel to your local webserver.`
 | 
			
		||||
	app.Flags = []cli.Flag{
 | 
			
		||||
		&cli.StringFlag{
 | 
			
		||||
			Name:  "config",
 | 
			
		||||
			Usage: "Specifies a config file in YAML format.",
 | 
			
		||||
			Value: findDefaultConfigPath(),
 | 
			
		||||
		},
 | 
			
		||||
		altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
			
		||||
			Name:  "autoupdate-freq",
 | 
			
		||||
			Usage: "Autoupdate frequency. Default is 24h.",
 | 
			
		||||
			Value: time.Hour * 24,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewBoolFlag(&cli.BoolFlag{
 | 
			
		||||
			Name:  "no-autoupdate",
 | 
			
		||||
			Usage: "Disable periodic check for updates, restarting the server with the new version.",
 | 
			
		||||
			Value: false,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewBoolFlag(&cli.BoolFlag{
 | 
			
		||||
			Name:   "is-autoupdated",
 | 
			
		||||
			Usage:  "Signal the new process that Argo Tunnel client has been autoupdated",
 | 
			
		||||
			Value:  false,
 | 
			
		||||
			Hidden: true,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
 | 
			
		||||
			Name:    "edge",
 | 
			
		||||
			Usage:   "Address of the Cloudflare tunnel server.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_EDGE"},
 | 
			
		||||
			Hidden:  true,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "cacert",
 | 
			
		||||
			Usage:   "Certificate Authority authenticating the Cloudflare tunnel connection.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_CACERT"},
 | 
			
		||||
			Hidden:  true,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "origincert",
 | 
			
		||||
			Usage:   "Path to the certificate generated for your origin when you run cloudflared login.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
 | 
			
		||||
			Value:   findDefaultOriginCertPath(),
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "url",
 | 
			
		||||
			Value:   "https://localhost:8080",
 | 
			
		||||
			Usage:   "Connect to the local webserver at `URL`.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_URL"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "hostname",
 | 
			
		||||
			Usage:   "Set a hostname on a Cloudflare zone to route traffic through this tunnel.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_HOSTNAME"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "origin-server-name",
 | 
			
		||||
			Usage:   "Hostname on the origin server certificate.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_ORIGIN_SERVER_NAME"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "id",
 | 
			
		||||
			Usage:   "A unique identifier used to tie connections to this tunnel instance.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_ID"},
 | 
			
		||||
			Hidden:  true,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "lb-pool",
 | 
			
		||||
			Usage:   "The name of a (new/existing) load balancing pool to add this origin to.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_LB_POOL"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "api-key",
 | 
			
		||||
			Usage:   "This parameter has been deprecated since version 2017.10.1.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_API_KEY"},
 | 
			
		||||
			Hidden:  true,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "api-email",
 | 
			
		||||
			Usage:   "This parameter has been deprecated since version 2017.10.1.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_API_EMAIL"},
 | 
			
		||||
			Hidden:  true,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "api-ca-key",
 | 
			
		||||
			Usage:   "This parameter has been deprecated since version 2017.10.1.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_API_CA_KEY"},
 | 
			
		||||
			Hidden:  true,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "metrics",
 | 
			
		||||
			Value:   "localhost:",
 | 
			
		||||
			Usage:   "Listen address for metrics reporting.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_METRICS"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
			
		||||
			Name:    "metrics-update-freq",
 | 
			
		||||
			Usage:   "Frequency to update tunnel metrics",
 | 
			
		||||
			Value:   time.Second * 5,
 | 
			
		||||
			EnvVars: []string{"TUNNEL_METRICS_UPDATE_FREQ"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
 | 
			
		||||
			Name:    "tag",
 | 
			
		||||
			Usage:   "Custom tags used to identify this tunnel, in format `KEY=VALUE`. Multiple tags may be specified",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_TAG"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
			
		||||
			Name:   "heartbeat-interval",
 | 
			
		||||
			Usage:  "Minimum idle time before sending a heartbeat.",
 | 
			
		||||
			Value:  time.Second * 5,
 | 
			
		||||
			Hidden: true,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewUint64Flag(&cli.Uint64Flag{
 | 
			
		||||
			Name:   "heartbeat-count",
 | 
			
		||||
			Usage:  "Minimum number of unacked heartbeats to send before closing the connection.",
 | 
			
		||||
			Value:  5,
 | 
			
		||||
			Hidden: true,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "loglevel",
 | 
			
		||||
			Value:   "info",
 | 
			
		||||
			Usage:   "Application logging level {panic, fatal, error, warn, info, debug}",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_LOGLEVEL"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "proto-loglevel",
 | 
			
		||||
			Value:   "warn",
 | 
			
		||||
			Usage:   "Protocol logging level {panic, fatal, error, warn, info, debug}",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewUintFlag(&cli.UintFlag{
 | 
			
		||||
			Name:    "retries",
 | 
			
		||||
			Value:   5,
 | 
			
		||||
			Usage:   "Maximum number of retries for connection/protocol errors.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_RETRIES"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewBoolFlag(&cli.BoolFlag{
 | 
			
		||||
			Name:    "hello-world",
 | 
			
		||||
			Value:   false,
 | 
			
		||||
			Usage:   "Run Hello World Server",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_HELLO_WORLD"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "pidfile",
 | 
			
		||||
			Usage:   "Write the application's PID to this file after first successful connection.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_PIDFILE"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "logfile",
 | 
			
		||||
			Usage:   "Save application log to this file for reporting issues.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_LOGFILE"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewIntFlag(&cli.IntFlag{
 | 
			
		||||
			Name:   "ha-connections",
 | 
			
		||||
			Value:  4,
 | 
			
		||||
			Hidden: true,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
			
		||||
			Name:  "proxy-connect-timeout",
 | 
			
		||||
			Usage: "HTTP proxy timeout for establishing a new connection",
 | 
			
		||||
			Value: time.Second * 30,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
			
		||||
			Name:  "proxy-tls-timeout",
 | 
			
		||||
			Usage: "HTTP proxy timeout for completing a TLS handshake",
 | 
			
		||||
			Value: time.Second * 10,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
			
		||||
			Name:  "proxy-tcp-keepalive",
 | 
			
		||||
			Usage: "HTTP proxy TCP keepalive duration",
 | 
			
		||||
			Value: time.Second * 30,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewBoolFlag(&cli.BoolFlag{
 | 
			
		||||
			Name:  "proxy-no-happy-eyeballs",
 | 
			
		||||
			Usage: "HTTP proxy should disable \"happy eyeballs\" for IPv4/v6 fallback",
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewIntFlag(&cli.IntFlag{
 | 
			
		||||
			Name:  "proxy-keepalive-connections",
 | 
			
		||||
			Usage: "HTTP proxy maximum keepalive connection pool size",
 | 
			
		||||
			Value: 100,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
			
		||||
			Name:  "proxy-keepalive-timeout",
 | 
			
		||||
			Usage: "HTTP proxy timeout for closing an idle connection",
 | 
			
		||||
			Value: time.Second * 90,
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewBoolFlag(&cli.BoolFlag{
 | 
			
		||||
			Name:    "proxy-dns",
 | 
			
		||||
			Usage:   "Run a DNS over HTTPS proxy server.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_DNS"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewUintFlag(&cli.UintFlag{
 | 
			
		||||
			Name:    "proxy-dns-port",
 | 
			
		||||
			Value:   53,
 | 
			
		||||
			Usage:   "Listen on given port for the DNS over HTTPS proxy server.",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_DNS_PORT"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringFlag(&cli.StringFlag{
 | 
			
		||||
			Name:    "proxy-dns-address",
 | 
			
		||||
			Usage:   "Listen address for the DNS over HTTPS proxy server.",
 | 
			
		||||
			Value:   "localhost",
 | 
			
		||||
			EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
 | 
			
		||||
		}),
 | 
			
		||||
		altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
 | 
			
		||||
			Name:    "proxy-dns-upstream",
 | 
			
		||||
			Usage:   "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
 | 
			
		||||
			Value:   cli.NewStringSlice("https://dns.cloudflare.com/.well-known/dns-query"),
 | 
			
		||||
			EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
 | 
			
		||||
		}),
 | 
			
		||||
	}
 | 
			
		||||
	app.Action = func(c *cli.Context) error {
 | 
			
		||||
		raven.CapturePanic(func() { startServer(c) }, nil)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	app.Before = func(context *cli.Context) error {
 | 
			
		||||
		Log = logrus.New()
 | 
			
		||||
		inputSource, err := findInputSourceContext(context)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			Log.WithError(err).Infof("Cannot load configuration from %s", context.String("config"))
 | 
			
		||||
			return err
 | 
			
		||||
		} else if inputSource != nil {
 | 
			
		||||
			err := altsrc.ApplyInputSourceValues(context, inputSource, app.Flags)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				Log.WithError(err).Infof("Cannot apply configuration from %s", context.String("config"))
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			Log.Infof("Applied configuration from %s", context.String("config"))
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	app.Commands = []*cli.Command{
 | 
			
		||||
		{
 | 
			
		||||
			Name:      "update",
 | 
			
		||||
			Action:    update,
 | 
			
		||||
			Usage:     "Update the agent if a new version exists",
 | 
			
		||||
			ArgsUsage: " ",
 | 
			
		||||
			Description: `Looks for a new version on the offical download server.
 | 
			
		||||
   If a new version exists, updates the agent binary and quits.
 | 
			
		||||
   Otherwise, does nothing.
 | 
			
		||||
 | 
			
		||||
   To determine if an update happened in a script, check for error code 64.`,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Name:      "login",
 | 
			
		||||
			Action:    login,
 | 
			
		||||
			Usage:     "Generate a configuration file with your login details",
 | 
			
		||||
			ArgsUsage: " ",
 | 
			
		||||
			Flags: []cli.Flag{
 | 
			
		||||
				&cli.StringFlag{
 | 
			
		||||
					Name:   "url",
 | 
			
		||||
					Hidden: true,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Name:   "hello",
 | 
			
		||||
			Action: hello,
 | 
			
		||||
			Usage:  "Run a simple \"Hello World\" server for testing Argo Tunnel.",
 | 
			
		||||
			Flags: []cli.Flag{
 | 
			
		||||
				&cli.IntFlag{
 | 
			
		||||
					Name:  "port",
 | 
			
		||||
					Usage: "Listen on the selected port.",
 | 
			
		||||
					Value: 8080,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			ArgsUsage: " ", // can't be the empty string or we get the default output
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Name:   "proxy-dns",
 | 
			
		||||
			Action: tunneldns.Run,
 | 
			
		||||
			Usage:  "Run a DNS over HTTPS proxy server.",
 | 
			
		||||
			Flags: []cli.Flag{
 | 
			
		||||
				&cli.StringFlag{
 | 
			
		||||
					Name:    "metrics",
 | 
			
		||||
					Value:   "localhost:",
 | 
			
		||||
					Usage:   "Listen address for metrics reporting.",
 | 
			
		||||
					EnvVars: []string{"TUNNEL_METRICS"},
 | 
			
		||||
				},
 | 
			
		||||
				&cli.StringFlag{
 | 
			
		||||
					Name:    "address",
 | 
			
		||||
					Usage:   "Listen address for the DNS over HTTPS proxy server.",
 | 
			
		||||
					Value:   "localhost",
 | 
			
		||||
					EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
 | 
			
		||||
				},
 | 
			
		||||
				&cli.IntFlag{
 | 
			
		||||
					Name:    "port",
 | 
			
		||||
					Usage:   "Listen on given port for the DNS over HTTPS proxy server.",
 | 
			
		||||
					Value:   53,
 | 
			
		||||
					EnvVars: []string{"TUNNEL_DNS_PORT"},
 | 
			
		||||
				},
 | 
			
		||||
				&cli.StringSliceFlag{
 | 
			
		||||
					Name:    "upstream",
 | 
			
		||||
					Usage:   "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
 | 
			
		||||
					Value:   cli.NewStringSlice("https://dns.cloudflare.com/.well-known/dns-query"),
 | 
			
		||||
					EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			ArgsUsage: " ", // can't be the empty string or we get the default output
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	runApp(app)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func startServer(c *cli.Context) {
 | 
			
		||||
	var wg sync.WaitGroup
 | 
			
		||||
	errC := make(chan error)
 | 
			
		||||
	wg.Add(2)
 | 
			
		||||
 | 
			
		||||
	// If the user choose to supply all options through env variables,
 | 
			
		||||
	// c.NumFlags() == 0 && c.NArg() == 0. For cloudflared to work, the user needs to at
 | 
			
		||||
	// least provide a hostname.
 | 
			
		||||
	if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" {
 | 
			
		||||
		Log.Infof("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl)
 | 
			
		||||
		cli.ShowAppHelp(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	logLevel, err := logrus.ParseLevel(c.String("loglevel"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Fatal("Unknown logging level specified")
 | 
			
		||||
	}
 | 
			
		||||
	Log.SetLevel(logLevel)
 | 
			
		||||
 | 
			
		||||
	protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Fatal("Unknown protocol logging level specified")
 | 
			
		||||
	}
 | 
			
		||||
	protoLogger := logrus.New()
 | 
			
		||||
	protoLogger.Level = protoLogLevel
 | 
			
		||||
 | 
			
		||||
	if c.String("logfile") != "" {
 | 
			
		||||
		if err := initLogFile(c, protoLogger); err != nil {
 | 
			
		||||
			Log.Error(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if isAutoupdateEnabled(c) {
 | 
			
		||||
		if initUpdate() {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		Log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
 | 
			
		||||
		go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	hostname, err := validation.ValidateHostname(c.String("hostname"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Fatal("Invalid hostname")
 | 
			
		||||
	}
 | 
			
		||||
	clientID := c.String("id")
 | 
			
		||||
	if !c.IsSet("id") {
 | 
			
		||||
		clientID = generateRandomClientID()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Fatal("Tag parse failure")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
 | 
			
		||||
	if c.IsSet("hello-world") {
 | 
			
		||||
		wg.Add(1)
 | 
			
		||||
		listener, err := createListener("127.0.0.1:")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			listener.Close()
 | 
			
		||||
			Log.WithError(err).Fatal("Cannot start Hello World Server")
 | 
			
		||||
		}
 | 
			
		||||
		go func() {
 | 
			
		||||
			startHelloWorldServer(listener, shutdownC)
 | 
			
		||||
			wg.Done()
 | 
			
		||||
			listener.Close()
 | 
			
		||||
		}()
 | 
			
		||||
		c.Set("url", "https://"+listener.Addr().String())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if c.IsSet("proxy-dns") {
 | 
			
		||||
		wg.Add(1)
 | 
			
		||||
		listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(c.Uint("proxy-dns-port")), c.StringSlice("proxy-dns-upstream"))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			listener.Stop()
 | 
			
		||||
			Log.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server")
 | 
			
		||||
		}
 | 
			
		||||
		go func() {
 | 
			
		||||
			listener.Start()
 | 
			
		||||
			<-shutdownC
 | 
			
		||||
			listener.Stop()
 | 
			
		||||
			wg.Done()
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url, err := validateUrl(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Fatal("Error validating url")
 | 
			
		||||
	}
 | 
			
		||||
	Log.Infof("Proxying tunnel requests to %s", url)
 | 
			
		||||
 | 
			
		||||
	// Fail if the user provided an old authentication method
 | 
			
		||||
	if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
 | 
			
		||||
		Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflared login")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check that the user has acquired a certificate using the log in command
 | 
			
		||||
	originCertPath, err := homedir.Expand(c.String("origincert"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
 | 
			
		||||
	}
 | 
			
		||||
	ok, err := fileExists(originCertPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.Fatalf("Cannot check if origin cert exists at path %s", c.String("origincert"))
 | 
			
		||||
	}
 | 
			
		||||
	if !ok {
 | 
			
		||||
		Log.Fatalf(`Cannot find a valid certificate for your origin at the path:
 | 
			
		||||
 | 
			
		||||
    %s
 | 
			
		||||
 | 
			
		||||
If the path above is wrong, specify the path with the -origincert option.
 | 
			
		||||
If you don't have a certificate signed by Cloudflare, run the command:
 | 
			
		||||
 | 
			
		||||
    %s login
 | 
			
		||||
`, originCertPath, os.Args[0])
 | 
			
		||||
	}
 | 
			
		||||
	// Easier to send the certificate as []byte via RPC than decoding it at this point
 | 
			
		||||
	originCert, err := ioutil.ReadFile(originCertPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tunnelMetrics := origin.NewTunnelMetrics()
 | 
			
		||||
	httpTransport := &http.Transport{
 | 
			
		||||
		Proxy: http.ProxyFromEnvironment,
 | 
			
		||||
		DialContext: (&net.Dialer{
 | 
			
		||||
			Timeout:   c.Duration("proxy-connect-timeout"),
 | 
			
		||||
			KeepAlive: c.Duration("proxy-tcp-keepalive"),
 | 
			
		||||
			DualStack: !c.Bool("proxy-no-happy-eyeballs"),
 | 
			
		||||
		}).DialContext,
 | 
			
		||||
		MaxIdleConns:          c.Int("proxy-keepalive-connections"),
 | 
			
		||||
		IdleConnTimeout:       c.Duration("proxy-keepalive-timeout"),
 | 
			
		||||
		TLSHandshakeTimeout:   c.Duration("proxy-tls-timeout"),
 | 
			
		||||
		ExpectContinueTimeout: 1 * time.Second,
 | 
			
		||||
		TLSClientConfig:       &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !c.IsSet("hello-world") && c.IsSet("origin-server-name") {
 | 
			
		||||
		httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tunnelConfig := &origin.TunnelConfig{
 | 
			
		||||
		EdgeAddrs:         c.StringSlice("edge"),
 | 
			
		||||
		OriginUrl:         url,
 | 
			
		||||
		Hostname:          hostname,
 | 
			
		||||
		OriginCert:        originCert,
 | 
			
		||||
		TlsConfig:         tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
 | 
			
		||||
		ClientTlsConfig:   httpTransport.TLSClientConfig,
 | 
			
		||||
		Retries:           c.Uint("retries"),
 | 
			
		||||
		HeartbeatInterval: c.Duration("heartbeat-interval"),
 | 
			
		||||
		MaxHeartbeats:     c.Uint64("heartbeat-count"),
 | 
			
		||||
		ClientID:          clientID,
 | 
			
		||||
		ReportedVersion:   Version,
 | 
			
		||||
		LBPool:            c.String("lb-pool"),
 | 
			
		||||
		Tags:              tags,
 | 
			
		||||
		HAConnections:     c.Int("ha-connections"),
 | 
			
		||||
		HTTPTransport:     httpTransport,
 | 
			
		||||
		Metrics:           tunnelMetrics,
 | 
			
		||||
		MetricsUpdateFreq: c.Duration("metrics-update-freq"),
 | 
			
		||||
		ProtocolLogger:    protoLogger,
 | 
			
		||||
		Logger:            Log,
 | 
			
		||||
		IsAutoupdated:     c.Bool("is-autoupdated"),
 | 
			
		||||
	}
 | 
			
		||||
	connectedSignal := make(chan struct{})
 | 
			
		||||
 | 
			
		||||
	go writePidFile(connectedSignal, c.String("pidfile"))
 | 
			
		||||
	go func() {
 | 
			
		||||
		errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
 | 
			
		||||
		wg.Done()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Fatal("Error opening metrics server listener")
 | 
			
		||||
	}
 | 
			
		||||
	go func() {
 | 
			
		||||
		errC <- metrics.ServeMetrics(metricsListener, shutdownC)
 | 
			
		||||
		wg.Done()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	var errCode int
 | 
			
		||||
	err = WaitForSignal(errC, shutdownC)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Error("Quitting due to error")
 | 
			
		||||
		raven.CaptureErrorAndWait(err, nil)
 | 
			
		||||
		errCode = 1
 | 
			
		||||
	} else {
 | 
			
		||||
		Log.Info("Quitting...")
 | 
			
		||||
	}
 | 
			
		||||
	// Wait for clean exit, discarding all errors
 | 
			
		||||
	go func() {
 | 
			
		||||
		for range errC {
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
	os.Exit(errCode)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
 | 
			
		||||
	signals := make(chan os.Signal, 10)
 | 
			
		||||
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
 | 
			
		||||
	defer signal.Stop(signals)
 | 
			
		||||
	select {
 | 
			
		||||
	case err := <-errC:
 | 
			
		||||
		close(shutdownC)
 | 
			
		||||
		return err
 | 
			
		||||
	case <-signals:
 | 
			
		||||
		close(shutdownC)
 | 
			
		||||
	case <-shutdownC:
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func update(_ *cli.Context) error {
 | 
			
		||||
	if updateApplied() {
 | 
			
		||||
		os.Exit(64)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func initUpdate() bool {
 | 
			
		||||
	if updateApplied() {
 | 
			
		||||
		os.Args = append(os.Args, "--is-autoupdated=true")
 | 
			
		||||
		if _, err := listeners.StartProcess(); err != nil {
 | 
			
		||||
			Log.WithError(err).Error("Unable to restart server automatically")
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func autoupdate(freq time.Duration, shutdownC chan struct{}) {
 | 
			
		||||
	for {
 | 
			
		||||
		if updateApplied() {
 | 
			
		||||
			os.Args = append(os.Args, "--is-autoupdated=true")
 | 
			
		||||
			if _, err := listeners.StartProcess(); err != nil {
 | 
			
		||||
				Log.WithError(err).Error("Unable to restart server automatically")
 | 
			
		||||
			}
 | 
			
		||||
			close(shutdownC)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		time.Sleep(freq)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateApplied() bool {
 | 
			
		||||
	releaseInfo := checkForUpdates()
 | 
			
		||||
	if releaseInfo.Updated {
 | 
			
		||||
		Log.Infof("Updated to version %s", releaseInfo.Version)
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if releaseInfo.Error != nil {
 | 
			
		||||
		Log.WithError(releaseInfo.Error).Error("Update check failed")
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func fileExists(path string) (bool, error) {
 | 
			
		||||
	f, err := os.Open(path)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if os.IsNotExist(err) {
 | 
			
		||||
			// ignore missing files
 | 
			
		||||
			return false, nil
 | 
			
		||||
		}
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	f.Close()
 | 
			
		||||
	return true, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs
 | 
			
		||||
// (differs by OS for legacy reasons) contains a cert.pem file, return empty string
 | 
			
		||||
func findDefaultOriginCertPath() string {
 | 
			
		||||
	for _, defaultConfigDir := range defaultConfigDirs {
 | 
			
		||||
		originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, credentialFile))
 | 
			
		||||
		if ok, _ := fileExists(originCertPath); ok {
 | 
			
		||||
			return originCertPath
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// returns the firt path that contains a config file. If none of the combination of
 | 
			
		||||
// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles
 | 
			
		||||
// contains a config file, return empty string
 | 
			
		||||
func findDefaultConfigPath() string {
 | 
			
		||||
	for _, configDir := range defaultConfigDirs {
 | 
			
		||||
		for _, configFile := range defaultConfigFiles {
 | 
			
		||||
			dirPath, err := homedir.Expand(configDir)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return ""
 | 
			
		||||
			}
 | 
			
		||||
			path := filepath.Join(dirPath, configFile)
 | 
			
		||||
			if ok, _ := fileExists(path); ok {
 | 
			
		||||
				return path
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
 | 
			
		||||
	if context.String("config") != "" {
 | 
			
		||||
		return altsrc.NewYamlSourceFromFile(context.String("config"))
 | 
			
		||||
	}
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func generateRandomClientID() string {
 | 
			
		||||
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
 | 
			
		||||
	id := make([]byte, 32)
 | 
			
		||||
	r.Read(id)
 | 
			
		||||
	return hex.EncodeToString(id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func writePidFile(waitForSignal chan struct{}, pidFile string) {
 | 
			
		||||
	<-waitForSignal
 | 
			
		||||
	daemon.SdNotify(false, "READY=1")
 | 
			
		||||
	if pidFile == "" {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	file, err := os.Create(pidFile)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
 | 
			
		||||
	}
 | 
			
		||||
	defer file.Close()
 | 
			
		||||
	fmt.Fprintf(file, "%d", os.Getpid())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// validate url. It can be either from --url or argument
 | 
			
		||||
func validateUrl(c *cli.Context) (string, error) {
 | 
			
		||||
	var url = c.String("url")
 | 
			
		||||
	if c.NArg() > 0 {
 | 
			
		||||
		if c.IsSet("url") {
 | 
			
		||||
			return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
 | 
			
		||||
		}
 | 
			
		||||
		url = c.Args().Get(0)
 | 
			
		||||
	}
 | 
			
		||||
	validUrl, err := validation.ValidateUrl(url)
 | 
			
		||||
	return validUrl, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error {
 | 
			
		||||
	filePath, err := homedir.Expand(c.String("logfile"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.Wrap(err, "Cannot resolve logfile path")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC
 | 
			
		||||
	// do not truncate log file if the client has been autoupdated
 | 
			
		||||
	if c.Bool("is-autoupdated") {
 | 
			
		||||
		fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE
 | 
			
		||||
	}
 | 
			
		||||
	f, err := os.OpenFile(filePath, fileMode, 0664)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath))
 | 
			
		||||
	}
 | 
			
		||||
	defer f.Close()
 | 
			
		||||
	pathMap := lfshook.PathMap{
 | 
			
		||||
		logrus.InfoLevel:  filePath,
 | 
			
		||||
		logrus.ErrorLevel: filePath,
 | 
			
		||||
		logrus.FatalLevel: filePath,
 | 
			
		||||
		logrus.PanicLevel: filePath,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	Log.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
 | 
			
		||||
	protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
 | 
			
		||||
 | 
			
		||||
	flags := make(map[string]interface{})
 | 
			
		||||
	envs := make(map[string]string)
 | 
			
		||||
 | 
			
		||||
	for _, flag := range c.LocalFlagNames() {
 | 
			
		||||
		flags[flag] = c.Generic(flag)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Find env variables for Argo Tunnel
 | 
			
		||||
	for _, env := range os.Environ() {
 | 
			
		||||
		// All Argo Tunnel env variables start with TUNNEL_
 | 
			
		||||
		if strings.Contains(env, "TUNNEL_") {
 | 
			
		||||
			vars := strings.Split(env, "=")
 | 
			
		||||
			if len(vars) == 2 {
 | 
			
		||||
				envs[vars[0]] = vars[1]
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	Log.Infof("Argo Tunnel build and runtime configuration: %+v", BuildAndRuntimeInfo{
 | 
			
		||||
		GoOS:        runtime.GOOS,
 | 
			
		||||
		GoVersion:   runtime.Version(),
 | 
			
		||||
		GoArch:      runtime.GOARCH,
 | 
			
		||||
		WarpVersion: Version,
 | 
			
		||||
		WarpFlags:   flags,
 | 
			
		||||
		WarpEnvs:    envs,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isAutoupdateEnabled(c *cli.Context) bool {
 | 
			
		||||
	if terminal.IsTerminal(int(os.Stdout.Fd())) {
 | 
			
		||||
		Log.Info(noAutoupdateMessage)
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,192 @@
 | 
			
		|||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"text/template"
 | 
			
		||||
 | 
			
		||||
	"github.com/mitchellh/go-homedir"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ServiceTemplate struct {
 | 
			
		||||
	Path     string
 | 
			
		||||
	Content  string
 | 
			
		||||
	FileMode os.FileMode
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ServiceTemplateArgs struct {
 | 
			
		||||
	Path string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (st *ServiceTemplate) ResolvePath() (string, error) {
 | 
			
		||||
	resolvedPath, err := homedir.Expand(st.Path)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error resolving path %s: %v", st.Path, err)
 | 
			
		||||
	}
 | 
			
		||||
	return resolvedPath, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
 | 
			
		||||
	tmpl, err := template.New(st.Path).Parse(st.Content)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error generating %s template: %v", st.Path, err)
 | 
			
		||||
	}
 | 
			
		||||
	resolvedPath, err := st.ResolvePath()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	var buffer bytes.Buffer
 | 
			
		||||
	err = tmpl.Execute(&buffer, args)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error generating %s: %v", st.Path, err)
 | 
			
		||||
	}
 | 
			
		||||
	fileMode := os.FileMode(0644)
 | 
			
		||||
	if st.FileMode != 0 {
 | 
			
		||||
		fileMode = st.FileMode
 | 
			
		||||
	}
 | 
			
		||||
	err = ioutil.WriteFile(resolvedPath, buffer.Bytes(), fileMode)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error writing %s: %v", resolvedPath, err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (st *ServiceTemplate) Remove() error {
 | 
			
		||||
	resolvedPath, err := st.ResolvePath()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	err = os.Remove(resolvedPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error deleting %s: %v", resolvedPath, err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func runCommand(command string, args ...string) error {
 | 
			
		||||
	cmd := exec.Command(command, args...)
 | 
			
		||||
	stderr, err := cmd.StderrPipe()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("error getting stderr pipe")
 | 
			
		||||
		return fmt.Errorf("error getting stderr pipe: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	err = cmd.Start()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("error starting %s", command)
 | 
			
		||||
		return fmt.Errorf("error starting %s: %v", command, err)
 | 
			
		||||
	}
 | 
			
		||||
	commandErr, _ := ioutil.ReadAll(stderr)
 | 
			
		||||
	if len(commandErr) > 0 {
 | 
			
		||||
		Log.Errorf("%s: %s", command, commandErr)
 | 
			
		||||
	}
 | 
			
		||||
	err = cmd.Wait()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("%s returned error", command)
 | 
			
		||||
		return fmt.Errorf("%s returned with error: %v", command, err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ensureConfigDirExists(configDir string) error {
 | 
			
		||||
	ok, err := fileExists(configDir)
 | 
			
		||||
	if !ok && err == nil {
 | 
			
		||||
		err = os.Mkdir(configDir, 0700)
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// openFile opens the file at path. If create is set and the file exists, returns nil, true, nil
 | 
			
		||||
func openFile(path string, create bool) (file *os.File, exists bool, err error) {
 | 
			
		||||
	expandedPath, err := homedir.Expand(path)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, false, err
 | 
			
		||||
	}
 | 
			
		||||
	if create {
 | 
			
		||||
		fileInfo, err := os.Stat(expandedPath)
 | 
			
		||||
		if err == nil && fileInfo.Size() > 0 {
 | 
			
		||||
			return nil, true, nil
 | 
			
		||||
		}
 | 
			
		||||
		file, err = os.OpenFile(expandedPath, os.O_RDWR|os.O_CREATE, 0600)
 | 
			
		||||
	} else {
 | 
			
		||||
		file, err = os.Open(expandedPath)
 | 
			
		||||
	}
 | 
			
		||||
	return file, false, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func copyCertificate(srcConfigDir, destConfigDir string) error {
 | 
			
		||||
	destCredentialPath := filepath.Join(destConfigDir, credentialFile)
 | 
			
		||||
	destFile, exists, err := openFile(destCredentialPath, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	} else if exists {
 | 
			
		||||
		// credentials already exist, do nothing
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	defer destFile.Close()
 | 
			
		||||
 | 
			
		||||
	srcCredentialPath := filepath.Join(srcConfigDir, credentialFile)
 | 
			
		||||
	srcFile, _, err := openFile(srcCredentialPath, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer srcFile.Close()
 | 
			
		||||
 | 
			
		||||
	// Copy certificate
 | 
			
		||||
	_, err = io.Copy(destFile, srcFile)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("unable to copy %s to %s: %v", srcCredentialPath, destCredentialPath, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile string) error {
 | 
			
		||||
	if err := ensureConfigDirExists(serviceConfigDir); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := copyCertificate(defaultConfigDir, serviceConfigDir); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Copy or create config
 | 
			
		||||
	destConfigPath := filepath.Join(serviceConfigDir, defaultConfigFile)
 | 
			
		||||
	destFile, exists, err := openFile(destConfigPath, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("cannot open %s", destConfigPath)
 | 
			
		||||
		return err
 | 
			
		||||
	} else if exists {
 | 
			
		||||
		// config already exists, do nothing
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	defer destFile.Close()
 | 
			
		||||
 | 
			
		||||
	srcConfigPath := filepath.Join(defaultConfigDir, defaultConfigFile)
 | 
			
		||||
	srcFile, _, err := openFile(srcConfigPath, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		fmt.Println("Your service needs a config file that at least specifies the hostname option.")
 | 
			
		||||
		fmt.Println("Type in a hostname now, or leave it blank and create the config file later.")
 | 
			
		||||
		fmt.Print("Hostname: ")
 | 
			
		||||
		reader := bufio.NewReader(os.Stdin)
 | 
			
		||||
		input, _ := reader.ReadString('\n')
 | 
			
		||||
		if input == "" {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		fmt.Fprintf(destFile, "hostname: %s\n", input)
 | 
			
		||||
	} else {
 | 
			
		||||
		defer srcFile.Close()
 | 
			
		||||
		_, err = io.Copy(destFile, srcFile)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("unable to copy %s to %s: %v", srcConfigPath, destConfigPath, err)
 | 
			
		||||
		}
 | 
			
		||||
		Log.Infof("Copied %s to %s", srcConfigPath, destConfigPath)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,32 @@
 | 
			
		|||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"regexp"
 | 
			
		||||
 | 
			
		||||
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Restrict key names to characters allowed in an HTTP header name.
 | 
			
		||||
// Restrict key values to printable characters (what is recognised as data in an HTTP header value).
 | 
			
		||||
var tagRegexp = regexp.MustCompile("^([a-zA-Z0-9!#$%&'*+\\-.^_`|~]+)=([[:print:]]+)$")
 | 
			
		||||
 | 
			
		||||
func NewTagFromCLI(compoundTag string) (tunnelpogs.Tag, bool) {
 | 
			
		||||
	matches := tagRegexp.FindStringSubmatch(compoundTag)
 | 
			
		||||
	if len(matches) == 0 {
 | 
			
		||||
		return tunnelpogs.Tag{}, false
 | 
			
		||||
	}
 | 
			
		||||
	return tunnelpogs.Tag{Name: matches[1], Value: matches[2]}, true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewTagSliceFromCLI(tags []string) ([]tunnelpogs.Tag, error) {
 | 
			
		||||
	var tagSlice []tunnelpogs.Tag
 | 
			
		||||
	for _, compoundTag := range tags {
 | 
			
		||||
		if tag, ok := NewTagFromCLI(compoundTag); ok {
 | 
			
		||||
			tagSlice = append(tagSlice, tag)
 | 
			
		||||
		} else {
 | 
			
		||||
			return nil, fmt.Errorf("Cannot parse tag value %s", compoundTag)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return tagSlice, nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,46 @@
 | 
			
		|||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestSingleTag(t *testing.T) {
 | 
			
		||||
	testCases := []struct {
 | 
			
		||||
		Input  string
 | 
			
		||||
		Output tunnelpogs.Tag
 | 
			
		||||
		Fail   bool
 | 
			
		||||
	}{
 | 
			
		||||
		{Input: "x=y", Output: tunnelpogs.Tag{Name: "x", Value: "y"}},
 | 
			
		||||
		{Input: "More-Complex=Tag Values", Output: tunnelpogs.Tag{Name: "More-Complex", Value: "Tag Values"}},
 | 
			
		||||
		{Input: "First=Equals=Wins", Output: tunnelpogs.Tag{Name: "First", Value: "Equals=Wins"}},
 | 
			
		||||
		{Input: "x=", Fail: true},
 | 
			
		||||
		{Input: "=y", Fail: true},
 | 
			
		||||
		{Input: "=", Fail: true},
 | 
			
		||||
		{Input: "No spaces allowed=in key names", Fail: true},
 | 
			
		||||
		{Input: "omg\nwtf=bbq", Fail: true},
 | 
			
		||||
	}
 | 
			
		||||
	for i, testCase := range testCases {
 | 
			
		||||
		tag, ok := NewTagFromCLI(testCase.Input)
 | 
			
		||||
		assert.Equalf(t, !testCase.Fail, ok, "mismatched success for test case %d", i)
 | 
			
		||||
		assert.Equalf(t, testCase.Output, tag, "mismatched output for test case %d", i)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTagSlice(t *testing.T) {
 | 
			
		||||
	tagSlice, err := NewTagSliceFromCLI([]string{"a=b", "c=d", "e=f"})
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Len(t, tagSlice, 3)
 | 
			
		||||
	assert.Equal(t, "a", tagSlice[0].Name)
 | 
			
		||||
	assert.Equal(t, "b", tagSlice[0].Value)
 | 
			
		||||
	assert.Equal(t, "c", tagSlice[1].Name)
 | 
			
		||||
	assert.Equal(t, "d", tagSlice[1].Value)
 | 
			
		||||
	assert.Equal(t, "e", tagSlice[2].Name)
 | 
			
		||||
	assert.Equal(t, "f", tagSlice[2].Value)
 | 
			
		||||
 | 
			
		||||
	tagSlice, err = NewTagSliceFromCLI([]string{"a=b", "=", "e=f"})
 | 
			
		||||
	assert.Error(t, err)
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,41 @@
 | 
			
		|||
package main
 | 
			
		||||
 | 
			
		||||
import "github.com/equinox-io/equinox"
 | 
			
		||||
 | 
			
		||||
const appID = "app_cwbQae3Tpea"
 | 
			
		||||
 | 
			
		||||
var publicKey = []byte(`
 | 
			
		||||
-----BEGIN ECDSA PUBLIC KEY-----
 | 
			
		||||
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE4OWZocTVZ8Do/L6ScLdkV+9A0IYMHoOf
 | 
			
		||||
dsCmJ/QZ6aw0w9qkkwEpne1Lmo6+0pGexZzFZOH6w5amShn+RXt7qkSid9iWlzGq
 | 
			
		||||
EKx0BZogHSor9Wy5VztdFaAaVbsJiCbO
 | 
			
		||||
-----END ECDSA PUBLIC KEY-----
 | 
			
		||||
`)
 | 
			
		||||
 | 
			
		||||
type ReleaseInfo struct {
 | 
			
		||||
	Updated bool
 | 
			
		||||
	Version string
 | 
			
		||||
	Error   error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func checkForUpdates() ReleaseInfo {
 | 
			
		||||
	var opts equinox.Options
 | 
			
		||||
	if err := opts.SetPublicKeyPEM(publicKey); err != nil {
 | 
			
		||||
		return ReleaseInfo{Error: err}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp, err := equinox.Check(appID, opts)
 | 
			
		||||
	switch {
 | 
			
		||||
	case err == equinox.NotAvailableErr:
 | 
			
		||||
		return ReleaseInfo{}
 | 
			
		||||
	case err != nil:
 | 
			
		||||
		return ReleaseInfo{Error: err}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = resp.Apply()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ReleaseInfo{Error: err}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,166 @@
 | 
			
		|||
// +build windows
 | 
			
		||||
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
// Copypasta from the example files:
 | 
			
		||||
// https://github.com/golang/sys/blob/master/windows/svc/example
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
 | 
			
		||||
	cli "gopkg.in/urfave/cli.v2"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/sys/windows/svc"
 | 
			
		||||
	"golang.org/x/sys/windows/svc/eventlog"
 | 
			
		||||
	"golang.org/x/sys/windows/svc/mgr"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	windowsServiceName        = "Cloudflared"
 | 
			
		||||
	windowsServiceDescription = "Argo Tunnel agent"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func runApp(app *cli.App) {
 | 
			
		||||
	app.Commands = append(app.Commands, &cli.Command{
 | 
			
		||||
		Name:  "service",
 | 
			
		||||
		Usage: "Manages the Argo Tunnel Windows service",
 | 
			
		||||
		Subcommands: []*cli.Command{
 | 
			
		||||
			&cli.Command{
 | 
			
		||||
				Name:   "install",
 | 
			
		||||
				Usage:  "Install Argo Tunnel as a Windows service",
 | 
			
		||||
				Action: installWindowsService,
 | 
			
		||||
			},
 | 
			
		||||
			&cli.Command{
 | 
			
		||||
				Name:   "uninstall",
 | 
			
		||||
				Usage:  "Uninstall the Argo Tunnel service",
 | 
			
		||||
				Action: uninstallWindowsService,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	isIntSess, err := svc.IsAnInteractiveSession()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.Fatalf("failed to determine if we are running in an interactive session: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if isIntSess {
 | 
			
		||||
		app.Run(os.Args)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	elog, err := eventlog.Open(windowsServiceName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("Cannot open event log for %s", windowsServiceName)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer elog.Close()
 | 
			
		||||
 | 
			
		||||
	elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName))
 | 
			
		||||
	// Run executes service name by calling windowsService which is a Handler
 | 
			
		||||
	// interface that implements Execute method
 | 
			
		||||
	err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	elog.Info(1, fmt.Sprintf("%s service stopped", windowsServiceName))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type windowsService struct {
 | 
			
		||||
	app  *cli.App
 | 
			
		||||
	elog *eventlog.Log
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// called by the package code at the start of the service
 | 
			
		||||
func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) {
 | 
			
		||||
	const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
 | 
			
		||||
	changes <- svc.Status{State: svc.StartPending}
 | 
			
		||||
	go s.app.Run(args)
 | 
			
		||||
 | 
			
		||||
	changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
 | 
			
		||||
loop:
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case c := <-r:
 | 
			
		||||
			switch c.Cmd {
 | 
			
		||||
			case svc.Interrogate:
 | 
			
		||||
				s.elog.Info(1, fmt.Sprintf("control request 1 #%d", c))
 | 
			
		||||
				changes <- c.CurrentStatus
 | 
			
		||||
			case svc.Stop:
 | 
			
		||||
				s.elog.Info(1, "received stop control request")
 | 
			
		||||
				break loop
 | 
			
		||||
			case svc.Shutdown:
 | 
			
		||||
				s.elog.Info(1, "received shutdown control request")
 | 
			
		||||
				break loop
 | 
			
		||||
			default:
 | 
			
		||||
				s.elog.Error(1, fmt.Sprintf("unexpected control request #%d", c))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	close(shutdownC)
 | 
			
		||||
	changes <- svc.Status{State: svc.StopPending}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func installWindowsService(c *cli.Context) error {
 | 
			
		||||
	Log.Infof("Installing Argo Tunnel Windows service")
 | 
			
		||||
	exepath, err := os.Executable()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.Infof("Cannot find path name that start the process")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	m, err := mgr.Connect()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.WithError(err).Infof("Cannot establish a connection to the service control manager")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer m.Disconnect()
 | 
			
		||||
	s, err := m.OpenService(windowsServiceName)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		s.Close()
 | 
			
		||||
		Log.Errorf("service %s already exists", windowsServiceName)
 | 
			
		||||
		return fmt.Errorf("service %s already exists", windowsServiceName)
 | 
			
		||||
	}
 | 
			
		||||
	config := mgr.Config{StartType: mgr.StartAutomatic, DisplayName: windowsServiceDescription}
 | 
			
		||||
	s, err = m.CreateService(windowsServiceName, exepath, config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.Infof("Cannot install service %s", windowsServiceName)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer s.Close()
 | 
			
		||||
	err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.Delete()
 | 
			
		||||
		Log.WithError(err).Infof("Cannot install event logger")
 | 
			
		||||
		return fmt.Errorf("SetupEventLogSource() failed: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func uninstallWindowsService(c *cli.Context) error {
 | 
			
		||||
	Log.Infof("Uninstalling Argo Tunnel Windows Service")
 | 
			
		||||
	m, err := mgr.Connect()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.Infof("Cannot establish a connection to the service control manager")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer m.Disconnect()
 | 
			
		||||
	s, err := m.OpenService(windowsServiceName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.Infof("service %s is not installed", windowsServiceName)
 | 
			
		||||
		return fmt.Errorf("service %s is not installed", windowsServiceName)
 | 
			
		||||
	}
 | 
			
		||||
	defer s.Close()
 | 
			
		||||
	err = s.Delete()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.Errorf("Cannot delete service %s", windowsServiceName)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	err = eventlog.Remove(windowsServiceName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		Log.Infof("Cannot remove event logger")
 | 
			
		||||
		return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -24,12 +24,6 @@ type activeStreamMap struct {
 | 
			
		|||
	ignoreNewStreams bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type FlowControlMetrics struct {
 | 
			
		||||
	AverageReceiveWindowSize, AverageSendWindowSize float64
 | 
			
		||||
	MinReceiveWindowSize, MaxReceiveWindowSize      uint32
 | 
			
		||||
	MinSendWindowSize, MaxSendWindowSize            uint32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap {
 | 
			
		||||
	m := &activeStreamMap{
 | 
			
		||||
		streams:      make(map[uint32]*MuxedStream),
 | 
			
		||||
| 
						 | 
				
			
			@ -169,45 +163,3 @@ func (m *activeStreamMap) Abort() {
 | 
			
		|||
	}
 | 
			
		||||
	m.ignoreNewStreams = true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *activeStreamMap) Metrics() *FlowControlMetrics {
 | 
			
		||||
	m.Lock()
 | 
			
		||||
	defer m.Unlock()
 | 
			
		||||
	var averageReceiveWindowSize, averageSendWindowSize float64
 | 
			
		||||
	var minReceiveWindowSize, maxReceiveWindowSize, minSendWindowSize, maxSendWindowSize uint32
 | 
			
		||||
	i := 0
 | 
			
		||||
	// The first variable in the range expression for map is the key, not index.
 | 
			
		||||
	for _, stream := range m.streams {
 | 
			
		||||
		// iterative mean: a(t+1) = a(t) + (a(t)-x)/(t+1)
 | 
			
		||||
		windows := stream.FlowControlWindow()
 | 
			
		||||
		averageReceiveWindowSize += (float64(windows.receiveWindow) - averageReceiveWindowSize) / float64(i+1)
 | 
			
		||||
		averageSendWindowSize += (float64(windows.sendWindow) - averageSendWindowSize) / float64(i+1)
 | 
			
		||||
		if i == 0 {
 | 
			
		||||
			maxReceiveWindowSize = windows.receiveWindow
 | 
			
		||||
			minReceiveWindowSize = windows.receiveWindow
 | 
			
		||||
			maxSendWindowSize = windows.sendWindow
 | 
			
		||||
			minSendWindowSize = windows.sendWindow
 | 
			
		||||
		} else {
 | 
			
		||||
			if windows.receiveWindow > maxReceiveWindowSize {
 | 
			
		||||
				maxReceiveWindowSize = windows.receiveWindow
 | 
			
		||||
			} else if windows.receiveWindow < minReceiveWindowSize {
 | 
			
		||||
				minReceiveWindowSize = windows.receiveWindow
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if windows.sendWindow > maxSendWindowSize {
 | 
			
		||||
				maxSendWindowSize = windows.sendWindow
 | 
			
		||||
			} else if windows.sendWindow < minSendWindowSize {
 | 
			
		||||
				minSendWindowSize = windows.sendWindow
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		i++
 | 
			
		||||
	}
 | 
			
		||||
	return &FlowControlMetrics{
 | 
			
		||||
		MinReceiveWindowSize:     minReceiveWindowSize,
 | 
			
		||||
		MaxReceiveWindowSize:     maxReceiveWindowSize,
 | 
			
		||||
		AverageReceiveWindowSize: averageReceiveWindowSize,
 | 
			
		||||
		MinSendWindowSize:        minSendWindowSize,
 | 
			
		||||
		MaxSendWindowSize:        maxSendWindowSize,
 | 
			
		||||
		AverageSendWindowSize:    averageSendWindowSize,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,18 @@
 | 
			
		|||
package h2mux
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type AtomicCounter struct {
 | 
			
		||||
	count uint64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *AtomicCounter) IncrementBy(number uint64) {
 | 
			
		||||
	atomic.AddUint64(&c.count, number)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get returns the current value of counter and reset it to 0
 | 
			
		||||
func (c *AtomicCounter) Count() uint64 {
 | 
			
		||||
	return atomic.SwapUint64(&c.count, 0)
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,23 @@
 | 
			
		|||
package h2mux
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sync"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestCounter(t *testing.T) {
 | 
			
		||||
	var wg sync.WaitGroup
 | 
			
		||||
	wg.Add(dataPoints)
 | 
			
		||||
	c := AtomicCounter{}
 | 
			
		||||
	for i := 0; i < dataPoints; i++ {
 | 
			
		||||
		go func() {
 | 
			
		||||
			defer wg.Done()
 | 
			
		||||
			c.IncrementBy(uint64(1))
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
	assert.Equal(t, uint64(dataPoints), c.Count())
 | 
			
		||||
	assert.Equal(t, uint64(0), c.Count())
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -59,6 +59,8 @@ type Muxer struct {
 | 
			
		|||
	muxReader *MuxReader
 | 
			
		||||
	// muxWriter is the write process.
 | 
			
		||||
	muxWriter *MuxWriter
 | 
			
		||||
	// muxMetricsUpdater is the process to update metrics
 | 
			
		||||
	muxMetricsUpdater *muxMetricsUpdater
 | 
			
		||||
	// newStreamChan is used to create new streams on the writer thread.
 | 
			
		||||
	// The writer will assign the next available stream ID.
 | 
			
		||||
	newStreamChan chan MuxedStreamRequest
 | 
			
		||||
| 
						 | 
				
			
			@ -133,6 +135,11 @@ func Handshake(
 | 
			
		|||
	// set up reader/writer pair ready for serve
 | 
			
		||||
	streamErrors := NewStreamErrorMap()
 | 
			
		||||
	goAwayChan := make(chan http2.ErrCode, 1)
 | 
			
		||||
	updateRTTChan := make(chan *roundTripMeasurement, 1)
 | 
			
		||||
	updateReceiveWindowChan := make(chan uint32, 1)
 | 
			
		||||
	updateSendWindowChan := make(chan uint32, 1)
 | 
			
		||||
	updateInBoundBytesChan := make(chan uint64)
 | 
			
		||||
	updateOutBoundBytesChan := make(chan uint64)
 | 
			
		||||
	pingTimestamp := NewPingTimestamp()
 | 
			
		||||
	connActive := NewSignal()
 | 
			
		||||
	idleDuration := config.HeartbeatInterval
 | 
			
		||||
| 
						 | 
				
			
			@ -149,34 +156,48 @@ func Handshake(
 | 
			
		|||
 | 
			
		||||
	m.explicitShutdown = NewBooleanFuse()
 | 
			
		||||
	m.muxReader = &MuxReader{
 | 
			
		||||
		f:                   m.f,
 | 
			
		||||
		handler:             m.config.Handler,
 | 
			
		||||
		streams:             m.streams,
 | 
			
		||||
		readyList:           m.readyList,
 | 
			
		||||
		streamErrors:        streamErrors,
 | 
			
		||||
		goAwayChan:          goAwayChan,
 | 
			
		||||
		abortChan:           m.abortChan,
 | 
			
		||||
		pingTimestamp:       pingTimestamp,
 | 
			
		||||
		connActive:          connActive,
 | 
			
		||||
		initialStreamWindow: defaultWindowSize,
 | 
			
		||||
		streamWindowMax:     maxWindowSize,
 | 
			
		||||
		r:                   m.r,
 | 
			
		||||
		f:                       m.f,
 | 
			
		||||
		handler:                 m.config.Handler,
 | 
			
		||||
		streams:                 m.streams,
 | 
			
		||||
		readyList:               m.readyList,
 | 
			
		||||
		streamErrors:            streamErrors,
 | 
			
		||||
		goAwayChan:              goAwayChan,
 | 
			
		||||
		abortChan:               m.abortChan,
 | 
			
		||||
		pingTimestamp:           pingTimestamp,
 | 
			
		||||
		connActive:              connActive,
 | 
			
		||||
		initialStreamWindow:     defaultWindowSize,
 | 
			
		||||
		streamWindowMax:         maxWindowSize,
 | 
			
		||||
		r:                       m.r,
 | 
			
		||||
		updateRTTChan:           updateRTTChan,
 | 
			
		||||
		updateReceiveWindowChan: updateReceiveWindowChan,
 | 
			
		||||
		updateSendWindowChan:    updateSendWindowChan,
 | 
			
		||||
		updateInBoundBytesChan:  updateInBoundBytesChan,
 | 
			
		||||
	}
 | 
			
		||||
	m.muxWriter = &MuxWriter{
 | 
			
		||||
		f:               m.f,
 | 
			
		||||
		streams:         m.streams,
 | 
			
		||||
		streamErrors:    streamErrors,
 | 
			
		||||
		readyStreamChan: m.readyList.ReadyChannel(),
 | 
			
		||||
		newStreamChan:   m.newStreamChan,
 | 
			
		||||
		goAwayChan:      goAwayChan,
 | 
			
		||||
		abortChan:       m.abortChan,
 | 
			
		||||
		pingTimestamp:   pingTimestamp,
 | 
			
		||||
		idleTimer:       NewIdleTimer(idleDuration, maxRetries),
 | 
			
		||||
		connActiveChan:  connActive.WaitChannel(),
 | 
			
		||||
		maxFrameSize:    defaultFrameSize,
 | 
			
		||||
		f:                       m.f,
 | 
			
		||||
		streams:                 m.streams,
 | 
			
		||||
		streamErrors:            streamErrors,
 | 
			
		||||
		readyStreamChan:         m.readyList.ReadyChannel(),
 | 
			
		||||
		newStreamChan:           m.newStreamChan,
 | 
			
		||||
		goAwayChan:              goAwayChan,
 | 
			
		||||
		abortChan:               m.abortChan,
 | 
			
		||||
		pingTimestamp:           pingTimestamp,
 | 
			
		||||
		idleTimer:               NewIdleTimer(idleDuration, maxRetries),
 | 
			
		||||
		connActiveChan:          connActive.WaitChannel(),
 | 
			
		||||
		maxFrameSize:            defaultFrameSize,
 | 
			
		||||
		updateReceiveWindowChan: updateReceiveWindowChan,
 | 
			
		||||
		updateSendWindowChan:    updateSendWindowChan,
 | 
			
		||||
		updateOutBoundBytesChan: updateOutBoundBytesChan,
 | 
			
		||||
	}
 | 
			
		||||
	m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
 | 
			
		||||
 | 
			
		||||
	m.muxMetricsUpdater = newMuxMetricsUpdater(
 | 
			
		||||
		updateRTTChan,
 | 
			
		||||
		updateReceiveWindowChan,
 | 
			
		||||
		updateSendWindowChan,
 | 
			
		||||
		updateInBoundBytesChan,
 | 
			
		||||
		updateOutBoundBytesChan,
 | 
			
		||||
		m.abortChan,
 | 
			
		||||
	)
 | 
			
		||||
	return m, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -246,9 +267,13 @@ func (m *Muxer) Serve() error {
 | 
			
		|||
		m.w.Close()
 | 
			
		||||
		m.abort()
 | 
			
		||||
	}()
 | 
			
		||||
	go func() {
 | 
			
		||||
		errChan <- m.muxMetricsUpdater.run(logger)
 | 
			
		||||
	}()
 | 
			
		||||
	err := <-errChan
 | 
			
		||||
	go func() {
 | 
			
		||||
		// discard error as other handler closes
 | 
			
		||||
		// discard errors as other handler and muxMetricsUpdater close
 | 
			
		||||
		<-errChan
 | 
			
		||||
		<-errChan
 | 
			
		||||
		close(errChan)
 | 
			
		||||
	}()
 | 
			
		||||
| 
						 | 
				
			
			@ -318,14 +343,8 @@ func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, erro
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Return the estimated round-trip time.
 | 
			
		||||
func (m *Muxer) RTT() RTTMeasurement {
 | 
			
		||||
	return m.muxReader.RTT()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Return min/max/average of send/receive window for all streams on this connection
 | 
			
		||||
func (m *Muxer) FlowControlMetrics() *FlowControlMetrics {
 | 
			
		||||
	return m.muxReader.FlowControlMetrics()
 | 
			
		||||
func (m *Muxer) Metrics() *MuxerMetrics {
 | 
			
		||||
	return m.muxMetricsUpdater.Metrics()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *Muxer) abort() {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,6 +14,7 @@ import (
 | 
			
		|||
	"time"
 | 
			
		||||
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestMain(m *testing.M) {
 | 
			
		||||
| 
						 | 
				
			
			@ -134,9 +135,12 @@ func TestSingleStream(t *testing.T) {
 | 
			
		|||
		if stream.Headers[0].Value != "headerValue" {
 | 
			
		||||
			t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
 | 
			
		||||
		}
 | 
			
		||||
		stream.WriteHeaders([]Header{
 | 
			
		||||
		headers := []Header{
 | 
			
		||||
			Header{Name: "response-header", Value: "responseValue"},
 | 
			
		||||
		})
 | 
			
		||||
		}
 | 
			
		||||
		stream.WriteHeaders(headers)
 | 
			
		||||
		assert.Equal(t, headers, stream.writeHeaders)
 | 
			
		||||
		assert.False(t, stream.headersSent)
 | 
			
		||||
		buf := []byte("Hello world")
 | 
			
		||||
		stream.Write(buf)
 | 
			
		||||
		// after this receive, the edge closed the stream
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,6 +19,7 @@ type MuxedStream struct {
 | 
			
		|||
	receiveWindowCurrentMax uint32
 | 
			
		||||
	// limit set in http2 spec. 2^31-1
 | 
			
		||||
	receiveWindowMax uint32
 | 
			
		||||
 | 
			
		||||
	// nonzero if a WINDOW_UPDATE frame for a stream needs to be sent
 | 
			
		||||
	windowUpdate uint32
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -39,10 +40,6 @@ type MuxedStream struct {
 | 
			
		|||
	receivedEOF bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type flowControlWindow struct {
 | 
			
		||||
	receiveWindow, sendWindow uint32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *MuxedStream) Read(p []byte) (n int, err error) {
 | 
			
		||||
	return s.readBuffer.Read(p)
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -101,17 +98,21 @@ func (s *MuxedStream) WriteHeaders(headers []Header) error {
 | 
			
		|||
		return ErrStreamHeadersSent
 | 
			
		||||
	}
 | 
			
		||||
	s.writeHeaders = headers
 | 
			
		||||
	s.headersSent = false
 | 
			
		||||
	s.writeNotify()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *MuxedStream) FlowControlWindow() *flowControlWindow {
 | 
			
		||||
func (s *MuxedStream) getReceiveWindow() uint32 {
 | 
			
		||||
	s.writeLock.Lock()
 | 
			
		||||
	defer s.writeLock.Unlock()
 | 
			
		||||
	return &flowControlWindow{
 | 
			
		||||
		receiveWindow: s.receiveWindow,
 | 
			
		||||
		sendWindow:    s.sendWindow,
 | 
			
		||||
	}
 | 
			
		||||
	return s.receiveWindow
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *MuxedStream) getSendWindow() uint32 {
 | 
			
		||||
	s.writeLock.Lock()
 | 
			
		||||
	defer s.writeLock.Unlock()
 | 
			
		||||
	return s.sendWindow
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// writeNotify must happen while holding writeLock.
 | 
			
		||||
| 
						 | 
				
			
			@ -209,9 +210,7 @@ func (s *MuxedStream) getChunk() *streamChunk {
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	// Copies at most s.sendWindow bytes
 | 
			
		||||
	//log.Infof("writeBuffer len %d stream %d", s.writeBuffer.Len(), s.streamID)
 | 
			
		||||
	writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow))
 | 
			
		||||
	//log.Infof("writeLen %d stream %d", writeLen, s.streamID)
 | 
			
		||||
	s.sendWindow -= uint32(writeLen)
 | 
			
		||||
	s.receiveWindow += s.windowUpdate
 | 
			
		||||
	s.windowUpdate = 0
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,232 @@
 | 
			
		|||
package h2mux
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/golang-collections/collections/queue"
 | 
			
		||||
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// data points used to compute average receive window and send window size
 | 
			
		||||
const (
 | 
			
		||||
	// data points used to compute average receive window and send window size
 | 
			
		||||
	dataPoints = 100
 | 
			
		||||
	// updateFreq is set to 1 sec so we can get inbound & outbound byes/sec
 | 
			
		||||
	updateFreq = time.Second
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type muxMetricsUpdater struct {
 | 
			
		||||
	// rttData keeps record of rtt, rttMin, rttMax and last measured time
 | 
			
		||||
	rttData *rttData
 | 
			
		||||
	// receiveWindowData keeps record of receive window measurement
 | 
			
		||||
	receiveWindowData *flowControlData
 | 
			
		||||
	// sendWindowData keeps record of send window measurement
 | 
			
		||||
	sendWindowData *flowControlData
 | 
			
		||||
	// inBoundRate is incoming bytes/sec
 | 
			
		||||
	inBoundRate *rate
 | 
			
		||||
	// outBoundRate is outgoing bytes/sec
 | 
			
		||||
	outBoundRate *rate
 | 
			
		||||
	// updateRTTChan is the channel to receive new RTT measurement from muxReader
 | 
			
		||||
	updateRTTChan <-chan *roundTripMeasurement
 | 
			
		||||
	//updateReceiveWindowChan is the channel to receive updated receiveWindow size from muxReader and muxWriter
 | 
			
		||||
	updateReceiveWindowChan <-chan uint32
 | 
			
		||||
	//updateSendWindowChan is the channel to receive updated sendWindow size from muxReader and muxWriter
 | 
			
		||||
	updateSendWindowChan <-chan uint32
 | 
			
		||||
	// updateInBoundBytesChan us the channel to receive bytesRead from muxReader
 | 
			
		||||
	updateInBoundBytesChan <-chan uint64
 | 
			
		||||
	// updateOutBoundBytesChan us the channel to receive bytesWrote from muxWriter
 | 
			
		||||
	updateOutBoundBytesChan <-chan uint64
 | 
			
		||||
	// shutdownC is to signal the muxerMetricsUpdater to shutdown
 | 
			
		||||
	abortChan <-chan struct{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MuxerMetrics struct {
 | 
			
		||||
	RTT, RTTMin, RTTMax                                              time.Duration
 | 
			
		||||
	ReceiveWindowAve, SendWindowAve                                  float64
 | 
			
		||||
	ReceiveWindowMin, ReceiveWindowMax, SendWindowMin, SendWindowMax uint32
 | 
			
		||||
	InBoundRateCurr, InBoundRateMin, InBoundRateMax                  uint64
 | 
			
		||||
	OutBoundRateCurr, OutBoundRateMin, OutBoundRateMax               uint64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type roundTripMeasurement struct {
 | 
			
		||||
	receiveTime, sendTime time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type rttData struct {
 | 
			
		||||
	rtt, rttMin, rttMax time.Duration
 | 
			
		||||
	lastMeasurementTime time.Time
 | 
			
		||||
	lock                sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type flowControlData struct {
 | 
			
		||||
	sum      uint64
 | 
			
		||||
	min, max uint32
 | 
			
		||||
	queue    *queue.Queue
 | 
			
		||||
	lock     sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type rate struct {
 | 
			
		||||
	curr uint64
 | 
			
		||||
	min, max uint64
 | 
			
		||||
	lock sync.RWMutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newMuxMetricsUpdater(
 | 
			
		||||
	updateRTTChan <-chan *roundTripMeasurement,
 | 
			
		||||
	updateReceiveWindowChan <-chan uint32,
 | 
			
		||||
	updateSendWindowChan <-chan uint32,
 | 
			
		||||
	updateInBoundBytesChan <-chan uint64,
 | 
			
		||||
	updateOutBoundBytesChan <-chan uint64,
 | 
			
		||||
	abortChan <-chan struct{},
 | 
			
		||||
) *muxMetricsUpdater {
 | 
			
		||||
	return &muxMetricsUpdater{
 | 
			
		||||
		rttData:                 newRTTData(),
 | 
			
		||||
		receiveWindowData:       newFlowControlData(),
 | 
			
		||||
		sendWindowData:          newFlowControlData(),
 | 
			
		||||
		inBoundRate:             newRate(),
 | 
			
		||||
		outBoundRate:            newRate(),
 | 
			
		||||
		updateRTTChan:           updateRTTChan,
 | 
			
		||||
		updateReceiveWindowChan: updateReceiveWindowChan,
 | 
			
		||||
		updateSendWindowChan:    updateSendWindowChan,
 | 
			
		||||
		updateInBoundBytesChan:  updateInBoundBytesChan,
 | 
			
		||||
		updateOutBoundBytesChan: updateOutBoundBytesChan,
 | 
			
		||||
		abortChan:               abortChan,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (updater *muxMetricsUpdater) Metrics() *MuxerMetrics {
 | 
			
		||||
	m := &MuxerMetrics{}
 | 
			
		||||
	m.RTT, m.RTTMin, m.RTTMax = updater.rttData.metrics()
 | 
			
		||||
	m.ReceiveWindowAve, m.ReceiveWindowMin, m.ReceiveWindowMax = updater.receiveWindowData.metrics()
 | 
			
		||||
	m.SendWindowAve, m.SendWindowMin, m.SendWindowMax = updater.sendWindowData.metrics()
 | 
			
		||||
	m.InBoundRateCurr, m.InBoundRateMin, m.InBoundRateMax = updater.inBoundRate.get()
 | 
			
		||||
	m.OutBoundRateCurr, m.OutBoundRateMin, m.OutBoundRateMax = updater.outBoundRate.get()
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (updater *muxMetricsUpdater) run(parentLogger *log.Entry) error {
 | 
			
		||||
	logger := parentLogger.WithFields(log.Fields{
 | 
			
		||||
		"subsystem": "mux",
 | 
			
		||||
		"dir":       "metrics",
 | 
			
		||||
	})
 | 
			
		||||
	defer logger.Debug("event loop finished")
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-updater.abortChan:
 | 
			
		||||
			logger.Infof("Stopping mux metrics updater")
 | 
			
		||||
			return nil
 | 
			
		||||
		case roundTripMeasurement := <-updater.updateRTTChan:
 | 
			
		||||
			go updater.rttData.update(roundTripMeasurement)
 | 
			
		||||
			logger.Debug("Update rtt")
 | 
			
		||||
		case receiveWindow := <-updater.updateReceiveWindowChan:
 | 
			
		||||
			go updater.receiveWindowData.update(receiveWindow)
 | 
			
		||||
			logger.Debug("Update receive window")
 | 
			
		||||
		case sendWindow := <-updater.updateSendWindowChan:
 | 
			
		||||
			go updater.sendWindowData.update(sendWindow)
 | 
			
		||||
			logger.Debug("Update send window")
 | 
			
		||||
		case inBoundBytes := <-updater.updateInBoundBytesChan:
 | 
			
		||||
			// inBoundBytes is bytes/sec because the update interval is 1 sec
 | 
			
		||||
			go updater.inBoundRate.update(inBoundBytes)
 | 
			
		||||
			logger.Debugf("Inbound bytes %d", inBoundBytes)
 | 
			
		||||
		case outBoundBytes := <-updater.updateOutBoundBytesChan:
 | 
			
		||||
			// outBoundBytes is bytes/sec because the update interval is 1 sec
 | 
			
		||||
			go updater.outBoundRate.update(outBoundBytes)
 | 
			
		||||
			logger.Debugf("Outbound bytes %d", outBoundBytes)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newRTTData() *rttData {
 | 
			
		||||
	return &rttData{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *rttData) update(measurement *roundTripMeasurement) {
 | 
			
		||||
	r.lock.Lock()
 | 
			
		||||
	defer r.lock.Unlock()
 | 
			
		||||
	// discard pings before lastMeasurementTime
 | 
			
		||||
	if r.lastMeasurementTime.After(measurement.sendTime) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	r.lastMeasurementTime = measurement.sendTime
 | 
			
		||||
	r.rtt = measurement.receiveTime.Sub(measurement.sendTime)
 | 
			
		||||
	if r.rttMax < r.rtt {
 | 
			
		||||
		r.rttMax = r.rtt
 | 
			
		||||
	}
 | 
			
		||||
	if r.rttMin == 0 || r.rttMin > r.rtt {
 | 
			
		||||
		r.rttMin = r.rtt
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *rttData) metrics() (rtt, rttMin, rttMax time.Duration) {
 | 
			
		||||
	r.lock.RLock()
 | 
			
		||||
	defer r.lock.RUnlock()
 | 
			
		||||
	return r.rtt, r.rttMin, r.rttMax
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newFlowControlData() *flowControlData {
 | 
			
		||||
	return &flowControlData{queue: queue.New()}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *flowControlData) update(measurement uint32) {
 | 
			
		||||
	f.lock.Lock()
 | 
			
		||||
	defer f.lock.Unlock()
 | 
			
		||||
	var firstItem uint32
 | 
			
		||||
	// store new data into queue, remove oldest data if queue is full
 | 
			
		||||
	f.queue.Enqueue(measurement)
 | 
			
		||||
	if f.queue.Len() > dataPoints {
 | 
			
		||||
		// data type should always be uint32
 | 
			
		||||
		firstItem = f.queue.Dequeue().(uint32)
 | 
			
		||||
	}
 | 
			
		||||
	// if (measurement - firstItem) < 0, uint64(measurement - firstItem)
 | 
			
		||||
	// will overflow and become a large positive number
 | 
			
		||||
	f.sum += uint64(measurement)
 | 
			
		||||
	f.sum -= uint64(firstItem)
 | 
			
		||||
	if measurement > f.max {
 | 
			
		||||
		f.max = measurement
 | 
			
		||||
	}
 | 
			
		||||
	if f.min == 0 || measurement < f.min {
 | 
			
		||||
		f.min = measurement
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// caller of ave() should acquire lock first
 | 
			
		||||
func (f *flowControlData) ave() float64 {
 | 
			
		||||
	if f.queue.Len() == 0 {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	return float64(f.sum) / float64(f.queue.Len())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *flowControlData) metrics() (ave float64, min, max uint32) {
 | 
			
		||||
	f.lock.RLock()
 | 
			
		||||
	defer f.lock.RUnlock()
 | 
			
		||||
	return f.ave(), f.min, f.max
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newRate() *rate {
 | 
			
		||||
	return &rate{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *rate) update(measurement uint64) {
 | 
			
		||||
	r.lock.Lock()
 | 
			
		||||
	defer r.lock.Unlock()
 | 
			
		||||
	r.curr = measurement
 | 
			
		||||
	// if measurement is 0, then there is no incoming/outgoing connection, don't update min/max
 | 
			
		||||
	if r.curr == 0 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if measurement > r.max {
 | 
			
		||||
		r.max = measurement
 | 
			
		||||
	}
 | 
			
		||||
	if r.min == 0 || measurement < r.min {
 | 
			
		||||
		r.min = measurement
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *rate) get() (curr, min, max uint64) {
 | 
			
		||||
	r.lock.RLock()
 | 
			
		||||
	defer r.lock.RUnlock()
 | 
			
		||||
	return r.curr, r.min, r.max
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,176 @@
 | 
			
		|||
package h2mux
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sync"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ave(sum uint64, len int) float64 {
 | 
			
		||||
	return float64(sum) / float64(len)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRTTUpdate(t *testing.T) {
 | 
			
		||||
	r := newRTTData()
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	// send at 0 ms, receive at 2 ms, RTT = 2ms
 | 
			
		||||
	m := &roundTripMeasurement{receiveTime: start.Add(2 * time.Millisecond), sendTime: start}
 | 
			
		||||
	r.update(m)
 | 
			
		||||
	assert.Equal(t, start, r.lastMeasurementTime)
 | 
			
		||||
	assert.Equal(t, 2*time.Millisecond, r.rtt)
 | 
			
		||||
	assert.Equal(t, 2*time.Millisecond, r.rttMin)
 | 
			
		||||
	assert.Equal(t, 2*time.Millisecond, r.rttMax)
 | 
			
		||||
 | 
			
		||||
	// send at 3 ms, receive at 6 ms, RTT = 3ms
 | 
			
		||||
	m = &roundTripMeasurement{receiveTime: start.Add(6 * time.Millisecond), sendTime: start.Add(3 * time.Millisecond)}
 | 
			
		||||
	r.update(m)
 | 
			
		||||
	assert.Equal(t, start.Add(3*time.Millisecond), r.lastMeasurementTime)
 | 
			
		||||
	assert.Equal(t, 3*time.Millisecond, r.rtt)
 | 
			
		||||
	assert.Equal(t, 2*time.Millisecond, r.rttMin)
 | 
			
		||||
	assert.Equal(t, 3*time.Millisecond, r.rttMax)
 | 
			
		||||
 | 
			
		||||
	// send at 7 ms, receive at 8 ms, RTT = 1ms
 | 
			
		||||
	m = &roundTripMeasurement{receiveTime: start.Add(8 * time.Millisecond), sendTime: start.Add(7 * time.Millisecond)}
 | 
			
		||||
	r.update(m)
 | 
			
		||||
	assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime)
 | 
			
		||||
	assert.Equal(t, 1*time.Millisecond, r.rtt)
 | 
			
		||||
	assert.Equal(t, 1*time.Millisecond, r.rttMin)
 | 
			
		||||
	assert.Equal(t, 3*time.Millisecond, r.rttMax)
 | 
			
		||||
 | 
			
		||||
	// send at -4 ms, receive at 0 ms, RTT = 4ms, but this ping is before last measurement
 | 
			
		||||
	// so it will be discarded
 | 
			
		||||
	m = &roundTripMeasurement{receiveTime: start, sendTime: start.Add(-2 * time.Millisecond)}
 | 
			
		||||
	r.update(m)
 | 
			
		||||
	assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime)
 | 
			
		||||
	assert.Equal(t, 1*time.Millisecond, r.rtt)
 | 
			
		||||
	assert.Equal(t, 1*time.Millisecond, r.rttMin)
 | 
			
		||||
	assert.Equal(t, 3*time.Millisecond, r.rttMax)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestFlowControlDataUpdate(t *testing.T) {
 | 
			
		||||
	f := newFlowControlData()
 | 
			
		||||
	assert.Equal(t, 0, f.queue.Len())
 | 
			
		||||
	assert.Equal(t, float64(0), f.ave())
 | 
			
		||||
 | 
			
		||||
	var sum uint64
 | 
			
		||||
	min := maxWindowSize - dataPoints
 | 
			
		||||
	max := maxWindowSize
 | 
			
		||||
	for i := 1; i <= dataPoints; i++ {
 | 
			
		||||
		size := maxWindowSize - uint32(i)
 | 
			
		||||
		f.update(size)
 | 
			
		||||
		assert.Equal(t, max - uint32(1), f.max)
 | 
			
		||||
		assert.Equal(t, size, f.min)
 | 
			
		||||
 | 
			
		||||
		assert.Equal(t, i, f.queue.Len())
 | 
			
		||||
 | 
			
		||||
		sum += uint64(size)
 | 
			
		||||
		assert.Equal(t, sum, f.sum)
 | 
			
		||||
		assert.Equal(t, ave(sum, f.queue.Len()), f.ave())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// queue is full, should start to dequeue first element
 | 
			
		||||
	for i := 1; i <= dataPoints; i++ {
 | 
			
		||||
		f.update(max)
 | 
			
		||||
		assert.Equal(t, max, f.max)
 | 
			
		||||
		assert.Equal(t, min, f.min)
 | 
			
		||||
 | 
			
		||||
		assert.Equal(t, dataPoints, f.queue.Len())
 | 
			
		||||
 | 
			
		||||
		sum += uint64(i)
 | 
			
		||||
		assert.Equal(t, sum, f.sum)
 | 
			
		||||
		assert.Equal(t, ave(sum, dataPoints), f.ave())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMuxMetricsUpdater(t *testing.T) {
 | 
			
		||||
	updateRTTChan := make(chan *roundTripMeasurement)
 | 
			
		||||
	updateReceiveWindowChan := make(chan uint32)
 | 
			
		||||
	updateSendWindowChan := make(chan uint32)
 | 
			
		||||
	updateInBoundBytesChan := make(chan uint64)
 | 
			
		||||
	updateOutBoundBytesChan := make(chan uint64)
 | 
			
		||||
	abortChan := make(chan struct{})
 | 
			
		||||
	errChan := make(chan error)
 | 
			
		||||
	m := newMuxMetricsUpdater(updateRTTChan,
 | 
			
		||||
		updateReceiveWindowChan,
 | 
			
		||||
		updateSendWindowChan,
 | 
			
		||||
		updateInBoundBytesChan,
 | 
			
		||||
		updateOutBoundBytesChan,
 | 
			
		||||
		abortChan,
 | 
			
		||||
	)
 | 
			
		||||
	logger := log.NewEntry(log.New())
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		errChan <- m.run(logger)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	var wg sync.WaitGroup
 | 
			
		||||
	wg.Add(2)
 | 
			
		||||
 | 
			
		||||
	// mock muxReader
 | 
			
		||||
	readerStart := time.Now()
 | 
			
		||||
	rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart}
 | 
			
		||||
	updateRTTChan <- rm
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer wg.Done()
 | 
			
		||||
		// Becareful if dataPoints is not divisibile by 4
 | 
			
		||||
		readerSend := readerStart.Add(time.Millisecond)
 | 
			
		||||
		for i := 1; i <= dataPoints/4; i++ {
 | 
			
		||||
			readerReceive := readerSend.Add(time.Duration(i) * time.Millisecond)
 | 
			
		||||
			rm := &roundTripMeasurement{receiveTime: readerReceive, sendTime: readerSend}
 | 
			
		||||
			updateRTTChan <- rm
 | 
			
		||||
			readerSend = readerReceive.Add(time.Millisecond)
 | 
			
		||||
 | 
			
		||||
			updateReceiveWindowChan <- uint32(i)
 | 
			
		||||
			updateSendWindowChan <- uint32(i)
 | 
			
		||||
 | 
			
		||||
			updateInBoundBytesChan <- uint64(i)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// mock muxWriter
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer wg.Done()
 | 
			
		||||
		for j := dataPoints/4 + 1; j <= dataPoints/2; j++ {
 | 
			
		||||
			updateReceiveWindowChan <- uint32(j)
 | 
			
		||||
			updateSendWindowChan <- uint32(j)
 | 
			
		||||
 | 
			
		||||
			// should always be disgard since the send time is before readerSend
 | 
			
		||||
			rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart.Add(-time.Duration(j*dataPoints) * time.Millisecond)}
 | 
			
		||||
			updateRTTChan <- rm
 | 
			
		||||
 | 
			
		||||
			updateOutBoundBytesChan <- uint64(j)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}()
 | 
			
		||||
	wg.Wait()
 | 
			
		||||
 | 
			
		||||
	metrics := m.Metrics()
 | 
			
		||||
	points := dataPoints / 2
 | 
			
		||||
	assert.Equal(t, time.Millisecond, metrics.RTTMin)
 | 
			
		||||
	assert.Equal(t, time.Duration(dataPoints/4)*time.Millisecond, metrics.RTTMax)
 | 
			
		||||
 | 
			
		||||
	// sum(1..i) = i*(i+1)/2, ave(1..i) = i*(i+1)/2/i = (i+1)/2
 | 
			
		||||
	assert.Equal(t, float64(points+1)/float64(2), metrics.ReceiveWindowAve)
 | 
			
		||||
	assert.Equal(t, uint32(1), metrics.ReceiveWindowMin)
 | 
			
		||||
	assert.Equal(t, uint32(points), metrics.ReceiveWindowMax)
 | 
			
		||||
 | 
			
		||||
	assert.Equal(t, float64(points+1)/float64(2), metrics.SendWindowAve)
 | 
			
		||||
	assert.Equal(t, uint32(1), metrics.SendWindowMin)
 | 
			
		||||
	assert.Equal(t, uint32(points), metrics.SendWindowMax)
 | 
			
		||||
 | 
			
		||||
	assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateCurr)
 | 
			
		||||
	assert.Equal(t, uint64(1), metrics.InBoundRateMin)
 | 
			
		||||
	assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateMax)
 | 
			
		||||
 | 
			
		||||
	assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateCurr)
 | 
			
		||||
	assert.Equal(t, uint64(dataPoints/4+1), metrics.OutBoundRateMin)
 | 
			
		||||
	assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateMax)
 | 
			
		||||
 | 
			
		||||
	close(abortChan)
 | 
			
		||||
	assert.Nil(t, <-errChan)
 | 
			
		||||
	close(errChan)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -3,7 +3,6 @@ package h2mux
 | 
			
		|||
import (
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"io"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
| 
						 | 
				
			
			@ -34,14 +33,18 @@ type MuxReader struct {
 | 
			
		|||
	initialStreamWindow uint32
 | 
			
		||||
	// The max value for the send window of a stream.
 | 
			
		||||
	streamWindowMax uint32
 | 
			
		||||
	// windowMetrics keeps track of min/max/average of send/receive windows for all streams
 | 
			
		||||
	flowControlMetrics *FlowControlMetrics
 | 
			
		||||
	metricsMutex       sync.Mutex
 | 
			
		||||
	// r is a reference to the underlying connection used when shutting down.
 | 
			
		||||
	r io.Closer
 | 
			
		||||
	// rttMeasurement measures RTT based on ping timestamps.
 | 
			
		||||
	rttMeasurement RTTMeasurement
 | 
			
		||||
	rttMutex       sync.Mutex
 | 
			
		||||
	// updateRTTChan is the channel to send new RTT measurement to muxerMetricsUpdater
 | 
			
		||||
	updateRTTChan chan<- *roundTripMeasurement
 | 
			
		||||
	// updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
 | 
			
		||||
	updateReceiveWindowChan chan<- uint32
 | 
			
		||||
	// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
 | 
			
		||||
	updateSendWindowChan chan<- uint32
 | 
			
		||||
	// bytesRead is the amount of bytes read from data frame since the last time we send bytes read to metrics
 | 
			
		||||
	bytesRead AtomicCounter
 | 
			
		||||
	// updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
 | 
			
		||||
	updateInBoundBytesChan chan<- uint64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *MuxReader) Shutdown() {
 | 
			
		||||
| 
						 | 
				
			
			@ -57,28 +60,26 @@ func (r *MuxReader) Shutdown() {
 | 
			
		|||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *MuxReader) RTT() RTTMeasurement {
 | 
			
		||||
	r.rttMutex.Lock()
 | 
			
		||||
	defer r.rttMutex.Unlock()
 | 
			
		||||
	return r.rttMeasurement
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *MuxReader) FlowControlMetrics() *FlowControlMetrics {
 | 
			
		||||
	r.metricsMutex.Lock()
 | 
			
		||||
	defer r.metricsMutex.Unlock()
 | 
			
		||||
	if r.flowControlMetrics != nil {
 | 
			
		||||
		return r.flowControlMetrics
 | 
			
		||||
	}
 | 
			
		||||
	// No metrics available yet
 | 
			
		||||
	return &FlowControlMetrics{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *MuxReader) run(parentLogger *log.Entry) error {
 | 
			
		||||
	logger := parentLogger.WithFields(log.Fields{
 | 
			
		||||
		"subsystem": "mux",
 | 
			
		||||
		"dir":       "read",
 | 
			
		||||
	})
 | 
			
		||||
	defer logger.Debug("event loop finished")
 | 
			
		||||
 | 
			
		||||
	// routine to periodically update bytesRead
 | 
			
		||||
	go func() {
 | 
			
		||||
		tickC := time.Tick(updateFreq)
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case <-r.abortChan:
 | 
			
		||||
				return
 | 
			
		||||
			case <-tickC:
 | 
			
		||||
				r.updateInBoundBytesChan <- r.bytesRead.Count()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		frame, err := r.f.ReadFrame()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -120,6 +121,8 @@ func (r *MuxReader) run(parentLogger *log.Entry) error {
 | 
			
		|||
			r.receivePingData(f)
 | 
			
		||||
		case *http2.GoAwayFrame:
 | 
			
		||||
			err = r.receiveGoAway(f)
 | 
			
		||||
		// The receiver of a flow-controlled frame sends a WINDOW_UPDATE frame as it
 | 
			
		||||
		// consumes data and frees up space in flow-control windows
 | 
			
		||||
		case *http2.WindowUpdateFrame:
 | 
			
		||||
			err = r.updateStreamWindow(f)
 | 
			
		||||
		default:
 | 
			
		||||
| 
						 | 
				
			
			@ -236,10 +239,11 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.E
 | 
			
		|||
	}
 | 
			
		||||
	data := frame.Data()
 | 
			
		||||
	if len(data) > 0 {
 | 
			
		||||
		_, err = stream.readBuffer.Write(data)
 | 
			
		||||
		n, err := stream.readBuffer.Write(data)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return r.streamError(stream.streamID, http2.ErrCodeInternal)
 | 
			
		||||
		}
 | 
			
		||||
		r.bytesRead.IncrementBy(uint64(n))
 | 
			
		||||
	}
 | 
			
		||||
	if frame.Header().Flags.Has(http2.FlagDataEndStream) {
 | 
			
		||||
		if stream.receiveEOF() {
 | 
			
		||||
| 
						 | 
				
			
			@ -253,6 +257,7 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.E
 | 
			
		|||
	if !stream.consumeReceiveWindow(uint32(len(data))) {
 | 
			
		||||
		return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
 | 
			
		||||
	}
 | 
			
		||||
	r.updateReceiveWindowChan <- stream.getReceiveWindow()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -263,10 +268,14 @@ func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
 | 
			
		|||
		r.pingTimestamp.Set(ts)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	r.rttMutex.Lock()
 | 
			
		||||
	r.rttMeasurement.Update(time.Unix(0, ts))
 | 
			
		||||
	r.rttMutex.Unlock()
 | 
			
		||||
	r.flowControlMetrics = r.streams.Metrics()
 | 
			
		||||
 | 
			
		||||
	// Update updates the computed values with a new measurement.
 | 
			
		||||
	// outgoingTime is the time that the probe was sent.
 | 
			
		||||
	// We assume that time.Now() is the time we received that probe.
 | 
			
		||||
	r.updateRTTChan <- &roundTripMeasurement{
 | 
			
		||||
		receiveTime: time.Now(),
 | 
			
		||||
		sendTime:    time.Unix(0, ts),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Receive a GOAWAY from the peer. Gracefully shut down our connection.
 | 
			
		||||
| 
						 | 
				
			
			@ -293,6 +302,7 @@ func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
 | 
			
		|||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	stream.replenishSendWindow(frame.Increment)
 | 
			
		||||
	r.updateSendWindowChan <- stream.getSendWindow()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -40,6 +40,14 @@ type MuxWriter struct {
 | 
			
		|||
	headerEncoder *hpack.Encoder
 | 
			
		||||
	// headerBuffer is the temporary buffer used by headerEncoder.
 | 
			
		||||
	headerBuffer bytes.Buffer
 | 
			
		||||
	// updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
 | 
			
		||||
	updateReceiveWindowChan chan<- uint32
 | 
			
		||||
	// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
 | 
			
		||||
	updateSendWindowChan chan<- uint32
 | 
			
		||||
	// bytesWrote is the amount of bytes wrote to data frame since the last time we send bytes wrote to metrics
 | 
			
		||||
	bytesWrote AtomicCounter
 | 
			
		||||
	// updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
 | 
			
		||||
	updateOutBoundBytesChan chan<- uint64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MuxedStreamRequest struct {
 | 
			
		||||
| 
						 | 
				
			
			@ -64,6 +72,20 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error {
 | 
			
		|||
		"dir":       "write",
 | 
			
		||||
	})
 | 
			
		||||
	defer logger.Debug("event loop finished")
 | 
			
		||||
 | 
			
		||||
	// routine to periodically communicate bytesWrote
 | 
			
		||||
	go func() {
 | 
			
		||||
		tickC := time.Tick(updateFreq)
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case <-w.abortChan:
 | 
			
		||||
				return
 | 
			
		||||
			case <-tickC:
 | 
			
		||||
				w.updateOutBoundBytesChan <- w.bytesWrote.Count()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-w.abortChan:
 | 
			
		||||
| 
						 | 
				
			
			@ -141,7 +163,8 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error {
 | 
			
		|||
func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
 | 
			
		||||
	logger.Debug("writable")
 | 
			
		||||
	chunk := stream.getChunk()
 | 
			
		||||
 | 
			
		||||
	w.updateReceiveWindowChan <- stream.getReceiveWindow()
 | 
			
		||||
	w.updateSendWindowChan <- stream.getSendWindow()
 | 
			
		||||
	if chunk.sendHeadersFrame() {
 | 
			
		||||
		err := w.writeHeaders(chunk.streamID, chunk.headers)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -154,7 +177,9 @@ func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) erro
 | 
			
		|||
	if chunk.sendWindowUpdateFrame() {
 | 
			
		||||
		// Send a WINDOW_UPDATE frame to update our receive window.
 | 
			
		||||
		// If the Stream ID is zero, the window update applies to the connection as a whole
 | 
			
		||||
		// A WINDOW_UPDATE in a specific stream applies to the connection-level flow control as well.
 | 
			
		||||
		// RFC7540 section-6.9.1 "A receiver that receives a flow-controlled frame MUST
 | 
			
		||||
		// always account for  its contribution against the connection flow-control
 | 
			
		||||
		// window, unless the receiver treats this as a connection error"
 | 
			
		||||
		err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.WithError(err).Warn("error writing window update")
 | 
			
		||||
| 
						 | 
				
			
			@ -170,6 +195,8 @@ func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) erro
 | 
			
		|||
			logger.WithError(err).Warn("error writing data")
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		// update the amount of data wrote
 | 
			
		||||
		w.bytesWrote.IncrementBy(uint64(len(payload)))
 | 
			
		||||
		logger.WithField("len", len(payload)).Debug("output data")
 | 
			
		||||
 | 
			
		||||
		if sentEOF {
 | 
			
		||||
| 
						 | 
				
			
			@ -214,19 +241,16 @@ func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
 | 
			
		|||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	blockSize := int(w.maxFrameSize)
 | 
			
		||||
	continuation := false
 | 
			
		||||
	endHeaders := len(encodedHeaders) == 0
 | 
			
		||||
	for !endHeaders && err == nil {
 | 
			
		||||
		blockFragment := encodedHeaders
 | 
			
		||||
		if len(encodedHeaders) > blockSize {
 | 
			
		||||
			blockFragment = blockFragment[:blockSize]
 | 
			
		||||
			encodedHeaders = encodedHeaders[blockSize:]
 | 
			
		||||
		} else {
 | 
			
		||||
			endHeaders = true
 | 
			
		||||
		}
 | 
			
		||||
		if continuation {
 | 
			
		||||
			// Send CONTINUATION frame if the headers can't be fit into 1 frame
 | 
			
		||||
			err = w.f.WriteContinuation(streamID, endHeaders, blockFragment)
 | 
			
		||||
		} else {
 | 
			
		||||
			endHeaders = true
 | 
			
		||||
			err = w.f.WriteHeaders(http2.HeadersFrameParam{
 | 
			
		||||
				StreamID:      streamID,
 | 
			
		||||
				EndHeaders:    endHeaders,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										24
									
								
								h2mux/rtt.go
								
								
								
								
							
							
						
						
									
										24
									
								
								h2mux/rtt.go
								
								
								
								
							| 
						 | 
				
			
			@ -2,7 +2,6 @@ package h2mux
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// PingTimestamp is an atomic interface around ping timestamping and signalling.
 | 
			
		||||
| 
						 | 
				
			
			@ -28,26 +27,3 @@ func (pt *PingTimestamp) Get() int64 {
 | 
			
		|||
func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} {
 | 
			
		||||
	return pt.signal.WaitChannel()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RTTMeasurement encapsulates a continuous round trip time measurement.
 | 
			
		||||
type RTTMeasurement struct {
 | 
			
		||||
	Current, Min, Max   time.Duration
 | 
			
		||||
	lastMeasurementTime time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Update updates the computed values with a new measurement.
 | 
			
		||||
// outgoingTime is the time that the probe was sent.
 | 
			
		||||
// We assume that time.Now() is the time we received that probe.
 | 
			
		||||
func (r *RTTMeasurement) Update(outgoingTime time.Time) {
 | 
			
		||||
	if !r.lastMeasurementTime.Before(outgoingTime) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	r.lastMeasurementTime = outgoingTime
 | 
			
		||||
	r.Current = time.Since(outgoingTime)
 | 
			
		||||
	if r.Max < r.Current {
 | 
			
		||||
		r.Max = r.Current
 | 
			
		||||
	}
 | 
			
		||||
	if r.Min > r.Current {
 | 
			
		||||
		r.Min = r.Current
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,13 +2,32 @@ package origin
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflare-warp/h2mux"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/h2mux"
 | 
			
		||||
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TunnelMetrics struct {
 | 
			
		||||
type muxerMetrics struct {
 | 
			
		||||
	rtt              *prometheus.GaugeVec
 | 
			
		||||
	rttMin           *prometheus.GaugeVec
 | 
			
		||||
	rttMax           *prometheus.GaugeVec
 | 
			
		||||
	receiveWindowAve *prometheus.GaugeVec
 | 
			
		||||
	sendWindowAve    *prometheus.GaugeVec
 | 
			
		||||
	receiveWindowMin *prometheus.GaugeVec
 | 
			
		||||
	receiveWindowMax *prometheus.GaugeVec
 | 
			
		||||
	sendWindowMin    *prometheus.GaugeVec
 | 
			
		||||
	sendWindowMax    *prometheus.GaugeVec
 | 
			
		||||
	inBoundRateCurr  *prometheus.GaugeVec
 | 
			
		||||
	inBoundRateMin   *prometheus.GaugeVec
 | 
			
		||||
	inBoundRateMax   *prometheus.GaugeVec
 | 
			
		||||
	outBoundRateCurr *prometheus.GaugeVec
 | 
			
		||||
	outBoundRateMin  *prometheus.GaugeVec
 | 
			
		||||
	outBoundRateMax  *prometheus.GaugeVec
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type tunnelMetrics struct {
 | 
			
		||||
	haConnections     prometheus.Gauge
 | 
			
		||||
	totalRequests     prometheus.Counter
 | 
			
		||||
	requestsPerTunnel *prometheus.CounterVec
 | 
			
		||||
| 
						 | 
				
			
			@ -20,16 +39,7 @@ type TunnelMetrics struct {
 | 
			
		|||
	maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
 | 
			
		||||
	// concurrentRequests records max count of concurrent requests for each tunnel
 | 
			
		||||
	maxConcurrentRequests map[string]uint64
 | 
			
		||||
	rtt                   prometheus.Gauge
 | 
			
		||||
	rttMin                prometheus.Gauge
 | 
			
		||||
	rttMax                prometheus.Gauge
 | 
			
		||||
	timerRetries          prometheus.Gauge
 | 
			
		||||
	receiveWindowSizeAve  prometheus.Gauge
 | 
			
		||||
	sendWindowSizeAve     prometheus.Gauge
 | 
			
		||||
	receiveWindowSizeMin  prometheus.Gauge
 | 
			
		||||
	receiveWindowSizeMax  prometheus.Gauge
 | 
			
		||||
	sendWindowSizeMin     prometheus.Gauge
 | 
			
		||||
	sendWindowSizeMax     prometheus.Gauge
 | 
			
		||||
	responseByCode        *prometheus.CounterVec
 | 
			
		||||
	responseCodePerTunnel *prometheus.CounterVec
 | 
			
		||||
	serverLocations       *prometheus.GaugeVec
 | 
			
		||||
| 
						 | 
				
			
			@ -37,10 +47,189 @@ type TunnelMetrics struct {
 | 
			
		|||
	locationLock sync.Mutex
 | 
			
		||||
	// oldServerLocations stores the last server the tunnel was connected to
 | 
			
		||||
	oldServerLocations map[string]string
 | 
			
		||||
 | 
			
		||||
	muxerMetrics *muxerMetrics
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newMuxerMetrics() *muxerMetrics {
 | 
			
		||||
	rtt := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "rtt",
 | 
			
		||||
			Help: "Round-trip time in millisecond",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(rtt)
 | 
			
		||||
 | 
			
		||||
	rttMin := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "rtt_min",
 | 
			
		||||
			Help: "Shortest round-trip time in millisecond",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(rttMin)
 | 
			
		||||
 | 
			
		||||
	rttMax := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "rtt_max",
 | 
			
		||||
			Help: "Longest round-trip time in millisecond",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(rttMax)
 | 
			
		||||
 | 
			
		||||
	receiveWindowAve := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "receive_window_ave",
 | 
			
		||||
			Help: "Average receive window size in bytes",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(receiveWindowAve)
 | 
			
		||||
 | 
			
		||||
	sendWindowAve := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "send_window_ave",
 | 
			
		||||
			Help: "Average send window size in bytes",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(sendWindowAve)
 | 
			
		||||
 | 
			
		||||
	receiveWindowMin := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "receive_window_min",
 | 
			
		||||
			Help: "Smallest receive window size in bytes",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(receiveWindowMin)
 | 
			
		||||
 | 
			
		||||
	receiveWindowMax := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "receive_window_max",
 | 
			
		||||
			Help: "Largest receive window size in bytes",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(receiveWindowMax)
 | 
			
		||||
 | 
			
		||||
	sendWindowMin := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "send_window_min",
 | 
			
		||||
			Help: "Smallest send window size in bytes",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(sendWindowMin)
 | 
			
		||||
 | 
			
		||||
	sendWindowMax := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "send_window_max",
 | 
			
		||||
			Help: "Largest send window size in bytes",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(sendWindowMax)
 | 
			
		||||
 | 
			
		||||
	inBoundRateCurr := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "inbound_bytes_per_sec_curr",
 | 
			
		||||
			Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(inBoundRateCurr)
 | 
			
		||||
 | 
			
		||||
	inBoundRateMin := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "inbound_bytes_per_sec_min",
 | 
			
		||||
			Help: "Minimum non-zero inbounding bytes per second",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(inBoundRateMin)
 | 
			
		||||
 | 
			
		||||
	inBoundRateMax := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "inbound_bytes_per_sec_max",
 | 
			
		||||
			Help: "Maximum inbounding bytes per second",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(inBoundRateMax)
 | 
			
		||||
 | 
			
		||||
	outBoundRateCurr := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "outbound_bytes_per_sec_curr",
 | 
			
		||||
			Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(outBoundRateCurr)
 | 
			
		||||
 | 
			
		||||
	outBoundRateMin := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "outbound_bytes_per_sec_min",
 | 
			
		||||
			Help: "Minimum non-zero outbounding bytes per second",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(outBoundRateMin)
 | 
			
		||||
 | 
			
		||||
	outBoundRateMax := prometheus.NewGaugeVec(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "outbound_bytes_per_sec_max",
 | 
			
		||||
			Help: "Maximum outbounding bytes per second",
 | 
			
		||||
		},
 | 
			
		||||
		[]string{"connection_id"},
 | 
			
		||||
	)
 | 
			
		||||
	prometheus.MustRegister(outBoundRateMax)
 | 
			
		||||
 | 
			
		||||
	return &muxerMetrics{
 | 
			
		||||
		rtt:              rtt,
 | 
			
		||||
		rttMin:           rttMin,
 | 
			
		||||
		rttMax:           rttMax,
 | 
			
		||||
		receiveWindowAve: receiveWindowAve,
 | 
			
		||||
		sendWindowAve:    sendWindowAve,
 | 
			
		||||
		receiveWindowMin: receiveWindowMin,
 | 
			
		||||
		receiveWindowMax: receiveWindowMax,
 | 
			
		||||
		sendWindowMin:    sendWindowMin,
 | 
			
		||||
		sendWindowMax:    sendWindowMax,
 | 
			
		||||
		inBoundRateCurr:  inBoundRateCurr,
 | 
			
		||||
		inBoundRateMin:   inBoundRateMin,
 | 
			
		||||
		inBoundRateMax:   inBoundRateMax,
 | 
			
		||||
		outBoundRateCurr: outBoundRateCurr,
 | 
			
		||||
		outBoundRateMin:  outBoundRateMin,
 | 
			
		||||
		outBoundRateMax:  outBoundRateMax,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
 | 
			
		||||
	m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
 | 
			
		||||
	m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
 | 
			
		||||
	m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
 | 
			
		||||
	m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
 | 
			
		||||
	m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
 | 
			
		||||
	m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
 | 
			
		||||
	m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
 | 
			
		||||
	m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
 | 
			
		||||
	m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
 | 
			
		||||
	m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
 | 
			
		||||
	m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
 | 
			
		||||
	m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
 | 
			
		||||
	m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
 | 
			
		||||
	m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
 | 
			
		||||
	m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func convertRTTMilliSec(t time.Duration) float64 {
 | 
			
		||||
	return float64(t / time.Millisecond)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Metrics that can be collected without asking the edge
 | 
			
		||||
func NewTunnelMetrics() *TunnelMetrics {
 | 
			
		||||
func NewTunnelMetrics() *tunnelMetrics {
 | 
			
		||||
	haConnections := prometheus.NewGauge(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "ha_connections",
 | 
			
		||||
| 
						 | 
				
			
			@ -82,27 +271,6 @@ func NewTunnelMetrics() *TunnelMetrics {
 | 
			
		|||
	)
 | 
			
		||||
	prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
 | 
			
		||||
 | 
			
		||||
	rtt := prometheus.NewGauge(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "rtt",
 | 
			
		||||
			Help: "Round-trip time",
 | 
			
		||||
		})
 | 
			
		||||
	prometheus.MustRegister(rtt)
 | 
			
		||||
 | 
			
		||||
	rttMin := prometheus.NewGauge(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "rtt_min",
 | 
			
		||||
			Help: "Shortest round-trip time",
 | 
			
		||||
		})
 | 
			
		||||
	prometheus.MustRegister(rttMin)
 | 
			
		||||
 | 
			
		||||
	rttMax := prometheus.NewGauge(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "rtt_max",
 | 
			
		||||
			Help: "Longest round-trip time",
 | 
			
		||||
		})
 | 
			
		||||
	prometheus.MustRegister(rttMax)
 | 
			
		||||
 | 
			
		||||
	timerRetries := prometheus.NewGauge(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "timer_retries",
 | 
			
		||||
| 
						 | 
				
			
			@ -110,48 +278,6 @@ func NewTunnelMetrics() *TunnelMetrics {
 | 
			
		|||
		})
 | 
			
		||||
	prometheus.MustRegister(timerRetries)
 | 
			
		||||
 | 
			
		||||
	receiveWindowSizeAve := prometheus.NewGauge(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "receive_window_ave",
 | 
			
		||||
			Help: "Average receive window size",
 | 
			
		||||
		})
 | 
			
		||||
	prometheus.MustRegister(receiveWindowSizeAve)
 | 
			
		||||
 | 
			
		||||
	sendWindowSizeAve := prometheus.NewGauge(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "send_window_ave",
 | 
			
		||||
			Help: "Average send window size",
 | 
			
		||||
		})
 | 
			
		||||
	prometheus.MustRegister(sendWindowSizeAve)
 | 
			
		||||
 | 
			
		||||
	receiveWindowSizeMin := prometheus.NewGauge(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "receive_window_min",
 | 
			
		||||
			Help: "Smallest receive window size",
 | 
			
		||||
		})
 | 
			
		||||
	prometheus.MustRegister(receiveWindowSizeMin)
 | 
			
		||||
 | 
			
		||||
	receiveWindowSizeMax := prometheus.NewGauge(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "receive_window_max",
 | 
			
		||||
			Help: "Largest receive window size",
 | 
			
		||||
		})
 | 
			
		||||
	prometheus.MustRegister(receiveWindowSizeMax)
 | 
			
		||||
 | 
			
		||||
	sendWindowSizeMin := prometheus.NewGauge(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "send_window_min",
 | 
			
		||||
			Help: "Smallest send window size",
 | 
			
		||||
		})
 | 
			
		||||
	prometheus.MustRegister(sendWindowSizeMin)
 | 
			
		||||
 | 
			
		||||
	sendWindowSizeMax := prometheus.NewGauge(
 | 
			
		||||
		prometheus.GaugeOpts{
 | 
			
		||||
			Name: "send_window_max",
 | 
			
		||||
			Help: "Largest send window size",
 | 
			
		||||
		})
 | 
			
		||||
	prometheus.MustRegister(sendWindowSizeMax)
 | 
			
		||||
 | 
			
		||||
	responseByCode := prometheus.NewCounterVec(
 | 
			
		||||
		prometheus.CounterOpts{
 | 
			
		||||
			Name: "response_by_code",
 | 
			
		||||
| 
						 | 
				
			
			@ -179,7 +305,7 @@ func NewTunnelMetrics() *TunnelMetrics {
 | 
			
		|||
	)
 | 
			
		||||
	prometheus.MustRegister(serverLocations)
 | 
			
		||||
 | 
			
		||||
	return &TunnelMetrics{
 | 
			
		||||
	return &tunnelMetrics{
 | 
			
		||||
		haConnections:                  haConnections,
 | 
			
		||||
		totalRequests:                  totalRequests,
 | 
			
		||||
		requestsPerTunnel:              requestsPerTunnel,
 | 
			
		||||
| 
						 | 
				
			
			@ -187,41 +313,28 @@ func NewTunnelMetrics() *TunnelMetrics {
 | 
			
		|||
		concurrentRequests:             make(map[string]uint64),
 | 
			
		||||
		maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
 | 
			
		||||
		maxConcurrentRequests:          make(map[string]uint64),
 | 
			
		||||
		rtt:                   rtt,
 | 
			
		||||
		rttMin:                rttMin,
 | 
			
		||||
		rttMax:                rttMax,
 | 
			
		||||
		timerRetries:          timerRetries,
 | 
			
		||||
		receiveWindowSizeAve:  receiveWindowSizeAve,
 | 
			
		||||
		sendWindowSizeAve:     sendWindowSizeAve,
 | 
			
		||||
		receiveWindowSizeMin:  receiveWindowSizeMin,
 | 
			
		||||
		receiveWindowSizeMax:  receiveWindowSizeMax,
 | 
			
		||||
		sendWindowSizeMin:     sendWindowSizeMin,
 | 
			
		||||
		sendWindowSizeMax:     sendWindowSizeMax,
 | 
			
		||||
		responseByCode:        responseByCode,
 | 
			
		||||
		responseCodePerTunnel: responseCodePerTunnel,
 | 
			
		||||
		serverLocations:       serverLocations,
 | 
			
		||||
		oldServerLocations:    make(map[string]string),
 | 
			
		||||
		timerRetries:                   timerRetries,
 | 
			
		||||
		responseByCode:                 responseByCode,
 | 
			
		||||
		responseCodePerTunnel:          responseCodePerTunnel,
 | 
			
		||||
		serverLocations:                serverLocations,
 | 
			
		||||
		oldServerLocations:             make(map[string]string),
 | 
			
		||||
		muxerMetrics:                   newMuxerMetrics(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *TunnelMetrics) incrementHaConnections() {
 | 
			
		||||
func (t *tunnelMetrics) incrementHaConnections() {
 | 
			
		||||
	t.haConnections.Inc()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *TunnelMetrics) decrementHaConnections() {
 | 
			
		||||
func (t *tunnelMetrics) decrementHaConnections() {
 | 
			
		||||
	t.haConnections.Dec()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *TunnelMetrics) updateTunnelFlowControlMetrics(metrics *h2mux.FlowControlMetrics) {
 | 
			
		||||
	t.receiveWindowSizeAve.Set(float64(metrics.AverageReceiveWindowSize))
 | 
			
		||||
	t.sendWindowSizeAve.Set(float64(metrics.AverageSendWindowSize))
 | 
			
		||||
	t.receiveWindowSizeMin.Set(float64(metrics.MinReceiveWindowSize))
 | 
			
		||||
	t.receiveWindowSizeMax.Set(float64(metrics.MaxReceiveWindowSize))
 | 
			
		||||
	t.sendWindowSizeMin.Set(float64(metrics.MinSendWindowSize))
 | 
			
		||||
	t.sendWindowSizeMax.Set(float64(metrics.MaxSendWindowSize))
 | 
			
		||||
func (t *tunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
 | 
			
		||||
	t.muxerMetrics.update(connectionID, metrics)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *TunnelMetrics) incrementRequests(connectionID string) {
 | 
			
		||||
func (t *tunnelMetrics) incrementRequests(connectionID string) {
 | 
			
		||||
	t.concurrentRequestsLock.Lock()
 | 
			
		||||
	var concurrentRequests uint64
 | 
			
		||||
	var ok bool
 | 
			
		||||
| 
						 | 
				
			
			@ -243,7 +356,7 @@ func (t *TunnelMetrics) incrementRequests(connectionID string) {
 | 
			
		|||
	t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
 | 
			
		||||
func (t *tunnelMetrics) decrementConcurrentRequests(connectionID string) {
 | 
			
		||||
	t.concurrentRequestsLock.Lock()
 | 
			
		||||
	if _, ok := t.concurrentRequests[connectionID]; ok {
 | 
			
		||||
		t.concurrentRequests[connectionID] -= 1
 | 
			
		||||
| 
						 | 
				
			
			@ -255,13 +368,13 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
 | 
			
		|||
	t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *TunnelMetrics) incrementResponses(connectionID, code string) {
 | 
			
		||||
func (t *tunnelMetrics) incrementResponses(connectionID, code string) {
 | 
			
		||||
	t.responseByCode.WithLabelValues(code).Inc()
 | 
			
		||||
	t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) {
 | 
			
		||||
func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) {
 | 
			
		||||
	t.locationLock.Lock()
 | 
			
		||||
	defer t.locationLock.Unlock()
 | 
			
		||||
	if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,11 +15,11 @@ import (
 | 
			
		|||
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflare-warp/h2mux"
 | 
			
		||||
	"github.com/cloudflare/cloudflare-warp/tunnelrpc"
 | 
			
		||||
	tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
 | 
			
		||||
	"github.com/cloudflare/cloudflare-warp/validation"
 | 
			
		||||
	"github.com/cloudflare/cloudflare-warp/websocket"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/h2mux"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tunnelrpc"
 | 
			
		||||
	tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/validation"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/websocket"
 | 
			
		||||
 | 
			
		||||
	raven "github.com/getsentry/raven-go"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
| 
						 | 
				
			
			@ -53,7 +53,7 @@ type TunnelConfig struct {
 | 
			
		|||
	Tags              []tunnelpogs.Tag
 | 
			
		||||
	HAConnections     int
 | 
			
		||||
	HTTPTransport     http.RoundTripper
 | 
			
		||||
	Metrics           *TunnelMetrics
 | 
			
		||||
	Metrics           *tunnelMetrics
 | 
			
		||||
	MetricsUpdateFreq time.Duration
 | 
			
		||||
	ProtocolLogger    *logrus.Logger
 | 
			
		||||
	Logger            *logrus.Logger
 | 
			
		||||
| 
						 | 
				
			
			@ -185,6 +185,7 @@ func ServeTunnel(
 | 
			
		|||
	serveCtx, serveCancel := context.WithCancel(ctx)
 | 
			
		||||
	registerErrC := make(chan error, 1)
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer wg.Done()
 | 
			
		||||
		err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			connectedFuse.Fuse(true)
 | 
			
		||||
| 
						 | 
				
			
			@ -193,18 +194,18 @@ func ServeTunnel(
 | 
			
		|||
			serveCancel()
 | 
			
		||||
		}
 | 
			
		||||
		registerErrC <- err
 | 
			
		||||
		wg.Done()
 | 
			
		||||
	}()
 | 
			
		||||
	updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer wg.Done()
 | 
			
		||||
		connectionTag := uint8ToString(connectionID)
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case <-serveCtx.Done():
 | 
			
		||||
				handler.muxer.Shutdown()
 | 
			
		||||
				return
 | 
			
		||||
			case <-updateMetricsTickC:
 | 
			
		||||
				handler.UpdateMetrics()
 | 
			
		||||
				handler.UpdateMetrics(connectionTag)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
| 
						 | 
				
			
			@ -303,7 +304,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
 | 
			
		|||
func LogServerInfo(logger *logrus.Entry,
 | 
			
		||||
	promise tunnelrpc.ServerInfo_Promise,
 | 
			
		||||
	connectionID uint8,
 | 
			
		||||
	metrics *TunnelMetrics,
 | 
			
		||||
	metrics *tunnelMetrics,
 | 
			
		||||
) {
 | 
			
		||||
	serverInfoMessage, err := promise.Struct()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -356,13 +357,17 @@ func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
 | 
			
		|||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FindCfRayHeader(h1 *http.Request) string {
 | 
			
		||||
	return h1.Header.Get("Cf-Ray")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TunnelHandler struct {
 | 
			
		||||
	originUrl  string
 | 
			
		||||
	muxer      *h2mux.Muxer
 | 
			
		||||
	httpClient http.RoundTripper
 | 
			
		||||
	tlsConfig  *tls.Config
 | 
			
		||||
	tags       []tunnelpogs.Tag
 | 
			
		||||
	metrics    *TunnelMetrics
 | 
			
		||||
	metrics    *tunnelMetrics
 | 
			
		||||
	// connectionID is only used by metrics, and prometheus requires labels to be string
 | 
			
		||||
	connectionID string
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -435,7 +440,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 | 
			
		|||
		Log.WithError(err).Error("invalid request received")
 | 
			
		||||
	}
 | 
			
		||||
	h.AppendTagHeaders(req)
 | 
			
		||||
 | 
			
		||||
	cfRay := FindCfRayHeader(req)
 | 
			
		||||
	h.logRequest(req, cfRay)
 | 
			
		||||
	if websocket.IsWebSocketUpgrade(req) {
 | 
			
		||||
		conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -444,6 +450,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 | 
			
		|||
			stream.WriteHeaders(H1ResponseToH2Response(response))
 | 
			
		||||
			defer conn.Close()
 | 
			
		||||
			websocket.Stream(conn.UnderlyingConn(), stream)
 | 
			
		||||
			h.metrics.incrementResponses(h.connectionID, "200")
 | 
			
		||||
			h.logResponse(response, cfRay)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		response, err := h.httpClient.RoundTrip(req)
 | 
			
		||||
| 
						 | 
				
			
			@ -454,6 +462,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 | 
			
		|||
			stream.WriteHeaders(H1ResponseToH2Response(response))
 | 
			
		||||
			io.Copy(stream, response.Body)
 | 
			
		||||
			h.metrics.incrementResponses(h.connectionID, "200")
 | 
			
		||||
			h.logResponse(response, cfRay)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	h.metrics.decrementConcurrentRequests(h.connectionID)
 | 
			
		||||
| 
						 | 
				
			
			@ -467,9 +476,27 @@ func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
 | 
			
		|||
	h.metrics.incrementResponses(h.connectionID, "502")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *TunnelHandler) UpdateMetrics() {
 | 
			
		||||
	flowCtlMetrics := h.muxer.FlowControlMetrics()
 | 
			
		||||
	h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics)
 | 
			
		||||
func (h *TunnelHandler) logRequest(req *http.Request, cfRay string) {
 | 
			
		||||
	if cfRay != "" {
 | 
			
		||||
		Log.WithField("CF-RAY", cfRay).Infof("%s %s %s", req.Method, req.URL, req.Proto)
 | 
			
		||||
	} else {
 | 
			
		||||
		Log.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto)
 | 
			
		||||
	}
 | 
			
		||||
	Log.Debugf("Request Headers %+v", req.Header)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *TunnelHandler) logResponse(r *http.Response, cfRay string) {
 | 
			
		||||
	if cfRay != "" {
 | 
			
		||||
		Log.WithField("CF-RAY", cfRay).Infof("%s", r.Status)
 | 
			
		||||
	} else {
 | 
			
		||||
		Log.Infof("%s", r.Status)
 | 
			
		||||
	}
 | 
			
		||||
	Log.Debugf("Response Headers %+v", r.Header)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
func (h *TunnelHandler) UpdateMetrics(connectionID string) {
 | 
			
		||||
	h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func uint8ToString(input uint8) string {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -11,28 +11,28 @@ const (
 | 
			
		|||
BgUrgQQAIg==
 | 
			
		||||
-----END EC PARAMETERS-----
 | 
			
		||||
-----BEGIN EC PRIVATE KEY-----
 | 
			
		||||
MIGkAgEBBDAdyQBXfxTDCQSOT0HugmH9pVBtIw8t5dYvm6HxGlNq6P57v5GeN02Z
 | 
			
		||||
dH9FRl7+VSWgBwYFK4EEACKhZANiAATqpFzTxxV7D+/oqhKCTR6BEM9elTfKaRQE
 | 
			
		||||
FsLufcmaTMw/9tTwgpHKao/QsLKDTNbQhbSQLkcmpCQKlSGhl+pCrqNt/oYUAhav
 | 
			
		||||
UIwpwGiLCqGH/R2AqWLKRPOa/Rufs/U=
 | 
			
		||||
MIGkAgEBBDBGGfwhIJdiUiJUVIItqJjEIMmlXxsMa8TQeer47+g+cIZ466rgg8EK
 | 
			
		||||
+Mdn6BY48GCgBwYFK4EEACKhZANiAASW//A9iDbPKg3OLkn7yJqLer32g9I5lBKR
 | 
			
		||||
tPc/zBubQLLz9lAaYI6AOQiJXhGr5JkKmQfi1sYHK5rJITPFy4W8Et4hHLdazDZH
 | 
			
		||||
WnEd+TStQABFUjrhtqXPWmGKcly0pOE=
 | 
			
		||||
-----END EC PRIVATE KEY-----`
 | 
			
		||||
 | 
			
		||||
	helloCRT = `
 | 
			
		||||
-----BEGIN CERTIFICATE-----
 | 
			
		||||
MIICkDCCAhigAwIBAgIJAPtKfUjc2lwGMAkGByqGSM49BAEwgYoxCzAJBgNVBAYT
 | 
			
		||||
AlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQKDBBD
 | 
			
		||||
bG91ZGZsYXJlLCBJbmMuMT8wPQYDVQQDDDZDbG91ZGZsYXJlIEFyZ28gVHVubmVs
 | 
			
		||||
IFNhbXBsZSBIZWxsbyBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMjE1MjAxNjU5
 | 
			
		||||
WhcNMjgwMjEzMjAxNjU5WjCBijELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVRleGFz
 | 
			
		||||
MQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENsb3VkZmxhcmUsIEluYy4xPzA9
 | 
			
		||||
BgNVBAMMNkNsb3VkZmxhcmUgQXJnbyBUdW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZl
 | 
			
		||||
ciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49AgEGBSuBBAAiA2IABOqkXNPHFXsP7+iq
 | 
			
		||||
EoJNHoEQz16VN8ppFAQWwu59yZpMzD/21PCCkcpqj9CwsoNM1tCFtJAuRyakJAqV
 | 
			
		||||
IaGX6kKuo23+hhQCFq9QjCnAaIsKoYf9HYCpYspE85r9G5+z9aNJMEcwRQYDVR0R
 | 
			
		||||
BD4wPIIJbG9jYWxob3N0ggp3YXJwLWhlbGxvggt3YXJwMi1oZWxsb4cEfwAAAYcQ
 | 
			
		||||
AAAAAAAAAAAAAAAAAAAAATAJBgcqhkjOPQQBA2cAMGQCMHyVPufXZ6vQo6XRWRa0
 | 
			
		||||
dAwtfgesOdZVP2Wt+t5v8jOIQQh1IQXYk5GtyoZGSObjhQIwd1fRgAyKXaZt+1DV
 | 
			
		||||
ZtHTdf8pMvESfJsSd8AB1eQ6q+pAiRUYyaxcE1Mlo2YY5o+g
 | 
			
		||||
MIICiDCCAg6gAwIBAgIJAJ/FfkBTtbuIMAkGByqGSM49BAEwfzELMAkGA1UEBhMC
 | 
			
		||||
VVMxDjAMBgNVBAgMBVRleGFzMQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENs
 | 
			
		||||
b3VkZmxhcmUsIEluYy4xNDAyBgNVBAMMK0FyZ28gVHVubmVsIFNhbXBsZSBIZWxs
 | 
			
		||||
byBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMzE5MjMwNTMyWhcNMjgwMzE2MjMw
 | 
			
		||||
NTMyWjB/MQswCQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1
 | 
			
		||||
c3RpbjEZMBcGA1UECgwQQ2xvdWRmbGFyZSwgSW5jLjE0MDIGA1UEAwwrQXJnbyBU
 | 
			
		||||
dW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZlciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49
 | 
			
		||||
AgEGBSuBBAAiA2IABJb/8D2INs8qDc4uSfvImot6vfaD0jmUEpG09z/MG5tAsvP2
 | 
			
		||||
UBpgjoA5CIleEavkmQqZB+LWxgcrmskhM8XLhbwS3iEct1rMNkdacR35NK1AAEVS
 | 
			
		||||
OuG2pc9aYYpyXLSk4aNXMFUwUwYDVR0RBEwwSoIJbG9jYWxob3N0ghFjbG91ZGZs
 | 
			
		||||
YXJlZC1oZWxsb4ISY2xvdWRmbGFyZWQyLWhlbGxvhwR/AAABhxAAAAAAAAAAAAAA
 | 
			
		||||
AAAAAAABMAkGByqGSM49BAEDaQAwZgIxAPxkdghH6y8xLMnY9Bom3Llf4NYM6yB9
 | 
			
		||||
PD1YsaNUJTsxjTk3YY1Jsp+yzK0yUKtTZwIxAPcdvqCF2/iR9H288pCT1TgtO0a9
 | 
			
		||||
cJL9RY1lq7DIGN37v1ZXReWaD+3hNokY8NriVg==
 | 
			
		||||
-----END CERTIFICATE-----`
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,38 @@
 | 
			
		|||
package tunneldns
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/coredns/coredns/plugin"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Upstream is a simplified interface for proxy destination
 | 
			
		||||
type Upstream interface {
 | 
			
		||||
	Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ProxyPlugin is a simplified DNS proxy using a generic upstream interface
 | 
			
		||||
type ProxyPlugin struct {
 | 
			
		||||
	Upstreams []Upstream
 | 
			
		||||
	Next      plugin.Handler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ServeDNS implements interface for CoreDNS plugin
 | 
			
		||||
func (p ProxyPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
 | 
			
		||||
	var reply *dns.Msg
 | 
			
		||||
	var backendErr error
 | 
			
		||||
 | 
			
		||||
	for _, upstream := range p.Upstreams {
 | 
			
		||||
		reply, backendErr = upstream.Exchange(ctx, r)
 | 
			
		||||
		if backendErr == nil {
 | 
			
		||||
			w.WriteMsg(reply)
 | 
			
		||||
			return 0, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Name implements interface for CoreDNS plugin
 | 
			
		||||
func (p ProxyPlugin) Name() string { return "proxy" }
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,97 @@
 | 
			
		|||
package tunneldns
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	defaultTimeout = 5 * time.Second
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// UpstreamHTTPS is the upstream implementation for DNS over HTTPS service
 | 
			
		||||
type UpstreamHTTPS struct {
 | 
			
		||||
	client   *http.Client
 | 
			
		||||
	endpoint *url.URL
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewUpstreamHTTPS creates a new DNS over HTTPS upstream from hostname
 | 
			
		||||
func NewUpstreamHTTPS(endpoint string) (Upstream, error) {
 | 
			
		||||
	u, err := url.Parse(endpoint)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Update TLS and HTTP client configuration
 | 
			
		||||
	tls := &tls.Config{ServerName: u.Hostname()}
 | 
			
		||||
	client := &http.Client{
 | 
			
		||||
		Timeout:   time.Second * defaultTimeout,
 | 
			
		||||
		Transport: &http.Transport{TLSClientConfig: tls},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &UpstreamHTTPS{client: client, endpoint: u}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Exchange provides an implementation for the Upstream interface
 | 
			
		||||
func (u *UpstreamHTTPS) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
 | 
			
		||||
	queryBuf, err := query.Pack()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "failed to pack DNS query")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// No content negotiation for now, use DNS wire format
 | 
			
		||||
	buf, backendErr := u.exchangeWireformat(queryBuf)
 | 
			
		||||
	if backendErr == nil {
 | 
			
		||||
		response := &dns.Msg{}
 | 
			
		||||
		if err := response.Unpack(buf); err != nil {
 | 
			
		||||
			return nil, errors.Wrap(err, "failed to unpack DNS response from body")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		response.Id = query.Id
 | 
			
		||||
		return response, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.WithError(backendErr).Errorf("failed to connect to an HTTPS backend %q", u.endpoint)
 | 
			
		||||
	return nil, backendErr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Perform message exchange with the default UDP wireformat defined in current draft
 | 
			
		||||
// https://datatracker.ietf.org/doc/draft-ietf-doh-dns-over-https
 | 
			
		||||
func (u *UpstreamHTTPS) exchangeWireformat(msg []byte) ([]byte, error) {
 | 
			
		||||
	req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "failed to create an HTTPS request")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req.Header.Add("Content-Type", "application/dns-udpwireformat")
 | 
			
		||||
	req.Host = u.endpoint.Hostname()
 | 
			
		||||
 | 
			
		||||
	resp, err := u.client.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "failed to perform an HTTPS request")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Check response status code
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	if resp.StatusCode != http.StatusOK {
 | 
			
		||||
		return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Read wireformat response from the body
 | 
			
		||||
	buf, err := ioutil.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "failed to read the response body")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return buf, nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,45 @@
 | 
			
		|||
package tunneldns
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/coredns/coredns/plugin"
 | 
			
		||||
	"github.com/coredns/coredns/plugin/metrics/vars"
 | 
			
		||||
	"github.com/coredns/coredns/plugin/pkg/dnstest"
 | 
			
		||||
	"github.com/coredns/coredns/plugin/pkg/rcode"
 | 
			
		||||
	"github.com/coredns/coredns/request"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	"github.com/prometheus/client_golang/prometheus"
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MetricsPlugin is an adapter for CoreDNS and built-in metrics
 | 
			
		||||
type MetricsPlugin struct {
 | 
			
		||||
	Next plugin.Handler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewMetricsPlugin creates a plugin with configured metrics
 | 
			
		||||
func NewMetricsPlugin(next plugin.Handler) *MetricsPlugin {
 | 
			
		||||
	prometheus.MustRegister(vars.RequestCount)
 | 
			
		||||
	prometheus.MustRegister(vars.RequestDuration)
 | 
			
		||||
	prometheus.MustRegister(vars.RequestSize)
 | 
			
		||||
	prometheus.MustRegister(vars.RequestDo)
 | 
			
		||||
	prometheus.MustRegister(vars.RequestType)
 | 
			
		||||
	prometheus.MustRegister(vars.ResponseSize)
 | 
			
		||||
	prometheus.MustRegister(vars.ResponseRcode)
 | 
			
		||||
	return &MetricsPlugin{Next: next}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ServeDNS implements the CoreDNS plugin interface
 | 
			
		||||
func (p MetricsPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
 | 
			
		||||
	state := request.Request{W: w, Req: r}
 | 
			
		||||
 | 
			
		||||
	rw := dnstest.NewRecorder(w)
 | 
			
		||||
	status, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, rw, r)
 | 
			
		||||
 | 
			
		||||
	// Update built-in metrics
 | 
			
		||||
	vars.Report(state, ".", rcode.ToString(rw.Rcode), rw.Len, rw.Start)
 | 
			
		||||
 | 
			
		||||
	return status, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Name implements the CoreDNS plugin interface
 | 
			
		||||
func (p MetricsPlugin) Name() string { return "metrics" }
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,144 @@
 | 
			
		|||
package tunneldns
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os"
 | 
			
		||||
	"os/signal"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"syscall"
 | 
			
		||||
 | 
			
		||||
	"gopkg.in/urfave/cli.v2"
 | 
			
		||||
 | 
			
		||||
	"github.com/cloudflare/cloudflared/metrics"
 | 
			
		||||
	"github.com/coredns/coredns/core/dnsserver"
 | 
			
		||||
	"github.com/coredns/coredns/plugin"
 | 
			
		||||
	"github.com/coredns/coredns/plugin/cache"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	log "github.com/sirupsen/logrus"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Listener is an adapter between CoreDNS server and Warp runnable
 | 
			
		||||
type Listener struct {
 | 
			
		||||
	server *dnsserver.Server
 | 
			
		||||
	wg     sync.WaitGroup
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Run implements a foreground runner
 | 
			
		||||
func Run(c *cli.Context) error {
 | 
			
		||||
	metricsListener, err := net.Listen("tcp", c.String("metrics"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.WithError(err).Fatal("Failed to open the metrics listener")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go metrics.ServeMetrics(metricsListener, nil)
 | 
			
		||||
 | 
			
		||||
	listener, err := CreateListener(c.String("address"), uint16(c.Uint("port")), c.StringSlice("upstream"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.WithError(err).Errorf("Failed to create the listeners")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Try to start the server
 | 
			
		||||
	err = listener.Start()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.WithError(err).Errorf("Failed to start the listeners")
 | 
			
		||||
		return listener.Stop()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Wait for signal
 | 
			
		||||
	signals := make(chan os.Signal, 10)
 | 
			
		||||
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
 | 
			
		||||
	defer signal.Stop(signals)
 | 
			
		||||
	<-signals
 | 
			
		||||
 | 
			
		||||
	// Shut down server
 | 
			
		||||
	err = listener.Stop()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.WithError(err).Errorf("failed to stop")
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Create a CoreDNS server plugin from configuration
 | 
			
		||||
func createConfig(address string, port uint16, p plugin.Handler) *dnsserver.Config {
 | 
			
		||||
	c := &dnsserver.Config{
 | 
			
		||||
		Zone:        ".",
 | 
			
		||||
		Transport:   "dns",
 | 
			
		||||
		ListenHosts: []string{address},
 | 
			
		||||
		Port:        strconv.FormatUint(uint64(port), 10),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.AddPlugin(func(next plugin.Handler) plugin.Handler { return p })
 | 
			
		||||
	return c
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Start blocks for serving requests
 | 
			
		||||
func (l *Listener) Start() error {
 | 
			
		||||
	log.WithField("addr", l.server.Address()).Infof("Starting DNS over HTTPS proxy server")
 | 
			
		||||
 | 
			
		||||
	// Start UDP listener
 | 
			
		||||
	if udp, err := l.server.ListenPacket(); err == nil {
 | 
			
		||||
		l.wg.Add(1)
 | 
			
		||||
		go func() {
 | 
			
		||||
			l.server.ServePacket(udp)
 | 
			
		||||
			l.wg.Done()
 | 
			
		||||
		}()
 | 
			
		||||
	} else {
 | 
			
		||||
		return errors.Wrap(err, "failed to create a UDP listener")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Start TCP listener
 | 
			
		||||
	tcp, err := l.server.Listen()
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		l.wg.Add(1)
 | 
			
		||||
		go func() {
 | 
			
		||||
			l.server.Serve(tcp)
 | 
			
		||||
			l.wg.Done()
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return errors.Wrap(err, "failed to create a TCP listener")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Stop signals server shutdown and blocks until completed
 | 
			
		||||
func (l *Listener) Stop() error {
 | 
			
		||||
	if err := l.server.Stop(); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	l.wg.Wait()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateListener configures the server and bound sockets
 | 
			
		||||
func CreateListener(address string, port uint16, upstreams []string) (*Listener, error) {
 | 
			
		||||
	// Build the list of upstreams
 | 
			
		||||
	upstreamList := make([]Upstream, 0)
 | 
			
		||||
	for _, url := range upstreams {
 | 
			
		||||
		log.WithField("url", url).Infof("Adding DNS upstream")
 | 
			
		||||
		upstream, err := NewUpstreamHTTPS(url)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, errors.Wrap(err, "failed to create HTTPS upstream")
 | 
			
		||||
		}
 | 
			
		||||
		upstreamList = append(upstreamList, upstream)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create a local cache with HTTPS proxy plugin
 | 
			
		||||
	chain := cache.New()
 | 
			
		||||
	chain.Next = ProxyPlugin{
 | 
			
		||||
		Upstreams: upstreamList,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Format an endpoint
 | 
			
		||||
	endpoint := fmt.Sprintf("dns://%s:%d", address, port)
 | 
			
		||||
 | 
			
		||||
	// Create the actual middleware server
 | 
			
		||||
	server, err := dnsserver.NewServer(endpoint, []*dnsserver.Config{createConfig(address, port, NewMetricsPlugin(chain))})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &Listener{server: server}, nil
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -1,7 +1,7 @@
 | 
			
		|||
package pogs
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/cloudflare/cloudflare-warp/tunnelrpc"
 | 
			
		||||
	"github.com/cloudflare/cloudflared/tunnelrpc"
 | 
			
		||||
	"golang.org/x/net/context"
 | 
			
		||||
	"zombiezen.com/go/capnproto2"
 | 
			
		||||
	"zombiezen.com/go/capnproto2/pogs"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,7 +1,7 @@
 | 
			
		|||
using Go = import "go.capnp";
 | 
			
		||||
@0xdb8274f9144abc7e;
 | 
			
		||||
$Go.package("tunnelrpc");
 | 
			
		||||
$Go.import("github.com/cloudflare/cloudflare-warp/tunnelrpc");
 | 
			
		||||
$Go.import("github.com/cloudflare/cloudflared/tunnelrpc");
 | 
			
		||||
 | 
			
		||||
struct Authentication {
 | 
			
		||||
    key @0 :Text;
 | 
			
		||||
| 
						 | 
				
			
			@ -35,7 +35,7 @@ struct RegistrationOptions {
 | 
			
		|||
    connectionId @6 :UInt8;
 | 
			
		||||
    # origin LAN IP
 | 
			
		||||
    originLocalIp @7 :Text;
 | 
			
		||||
    # whether Warp client has been autoupdated
 | 
			
		||||
    # whether Argo Tunnel client has been autoupdated
 | 
			
		||||
    isAutoupdated @8 :Bool;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -119,7 +119,7 @@ func validateScheme(scheme string) error {
 | 
			
		|||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Errorf("Currently Cloudflare-Warp does not support %s protocol.", scheme)
 | 
			
		||||
	return fmt.Errorf("Currently Argo Tunnel does not support %s protocol.", scheme)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validateIP(scheme, host, port string) (string, error) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -126,7 +126,7 @@ func TestValidateUrl(t *testing.T) {
 | 
			
		|||
	assert.Equal(t, "https://hello.example.com", validUrl)
 | 
			
		||||
 | 
			
		||||
	validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
 | 
			
		||||
	assert.Equal(t, "Currently Cloudflare-Warp does not support ftp protocol.", err.Error())
 | 
			
		||||
	assert.Equal(t, "Currently Argo Tunnel does not support ftp protocol.", err.Error())
 | 
			
		||||
	assert.Empty(t, validUrl)
 | 
			
		||||
 | 
			
		||||
	validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue