Release Argo Tunnel Client 2018.3.1
This commit is contained in:
		
							parent
							
								
									9f5cec8dbc
								
							
						
					
					
						commit
						d0a6a2a829
					
				| 
						 | 
					@ -0,0 +1,13 @@
 | 
				
			||||||
 | 
					// +build !windows,!darwin,!linux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cli "gopkg.in/urfave/cli.v2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func runApp(app *cli.App) {
 | 
				
			||||||
 | 
						app.Run(os.Args)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,204 @@
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"crypto/tls"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"html/template"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/gorilla/websocket"
 | 
				
			||||||
 | 
						"gopkg.in/urfave/cli.v2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/cloudflare/cloudflared/tlsconfig"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type templateData struct {
 | 
				
			||||||
 | 
						ServerName string
 | 
				
			||||||
 | 
						Request    *http.Request
 | 
				
			||||||
 | 
						Body       string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type OriginUpTime struct {
 | 
				
			||||||
 | 
						StartTime time.Time `json:"startTime"`
 | 
				
			||||||
 | 
						UpTime    string    `json:"uptime"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const defaultServerName = "the Argo Tunnel test server"
 | 
				
			||||||
 | 
					const indexTemplate = `
 | 
				
			||||||
 | 
					<!DOCTYPE html>
 | 
				
			||||||
 | 
					<html lang="en">
 | 
				
			||||||
 | 
					  <head>
 | 
				
			||||||
 | 
					    <meta charset="utf-8">
 | 
				
			||||||
 | 
					    <meta http-equiv="X-UA-Compatible" content="IE=Edge">
 | 
				
			||||||
 | 
					    <title>
 | 
				
			||||||
 | 
					      Argo Tunnel Connection
 | 
				
			||||||
 | 
					    </title>
 | 
				
			||||||
 | 
					    <meta name="author" content="">
 | 
				
			||||||
 | 
					    <meta name="description" content="Argo Tunnel Connection">
 | 
				
			||||||
 | 
					    <meta name="viewport" content="width=device-width, initial-scale=1">
 | 
				
			||||||
 | 
					    <style>
 | 
				
			||||||
 | 
					      html{line-height:1.15;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}body{margin:0}section{display:block}h1{font-size:2em;margin:.67em 0}a{background-color:transparent;-webkit-text-decoration-skip:objects}/* 1 */::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}/* 1 */a,body,dd,div,dl,dt,h1,h4,html,p,section{box-sizing:border-box}.bt{border-top-style:solid;border-top-width:1px}.bl{border-left-style:solid;border-left-width:1px}.b--orange{border-color:#f38020}.br1{border-radius:.125rem}.bw2{border-width:.25rem}.dib{display:inline-block}.sans-serif{font-family:open sans,-apple-system,BlinkMacSystemFont,avenir next,avenir,helvetica neue,helvetica,ubuntu,roboto,noto,segoe ui,arial,sans-serif}.code{font-family:Consolas,monaco,monospace}.b{font-weight:700}.fw3{font-weight:300}.fw4{font-weight:400}.fw5{font-weight:500}.fw6{font-weight:600}.lh-copy{line-height:1.5}.link{text-decoration:none}.link,.link:active,.link:focus,.link:hover,.link:link,.link:visited{transition:color .15s ease-in}.link:focus{outline:1px dotted currentColor}.mw-100{max-width:100%}.mw4{max-width:8rem}.mw7{max-width:48rem}.bg-light-gray{background-color:#f7f7f7}.link-hover:hover{background-color:#1f679e}.white{color:#fff}.bg-white{background-color:#fff}.bg-blue{background-color:#408bc9}.pb2{padding-bottom:.5rem}.pb6{padding-bottom:8rem}.pt3{padding-top:1rem}.pt5{padding-top:4rem}.pv2{padding-top:.5rem;padding-bottom:.5rem}.ph3{padding-left:1rem;padding-right:1rem}.ph4{padding-left:2rem;padding-right:2rem}.ml0{margin-left:0}.mb1{margin-bottom:.25rem}.mb2{margin-bottom:.5rem}.mb3{margin-bottom:1rem}.mt5{margin-top:4rem}.ttu{text-transform:uppercase}.f4{font-size:1.25rem}.f5{font-size:1rem}.f6{font-size:.875rem}.f7{font-size:.75rem}.measure{max-width:30em}.center{margin-left:auto}.center{margin-right:auto}@media screen and (min-width:30em){.f2-ns{font-size:2.25rem}}@media screen and (min-width:30em) and (max-width:60em){.f5-m{font-size:1rem}}@media screen and (min-width:60em){.f4-l{font-size:1.25rem}}
 | 
				
			||||||
 | 
					    .st0{fill:#FFF}.st1{fill:#f48120}.st2{fill:#faad3f}.st3{fill:#404041}
 | 
				
			||||||
 | 
					    </style>
 | 
				
			||||||
 | 
					  </head>
 | 
				
			||||||
 | 
					  <body class="sans-serif black">
 | 
				
			||||||
 | 
					    <div class="bt bw2 b--orange bg-white pb6">
 | 
				
			||||||
 | 
					      <div class="mw7 center ph4 pt3">
 | 
				
			||||||
 | 
					        <svg id="Layer_2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 109 40.5" class="mw4">
 | 
				
			||||||
 | 
					          <path class="st0" d="M98.6 14.2L93 12.9l-1-.4-25.7.2v12.4l32.3.1z"/>
 | 
				
			||||||
 | 
					          <path class="st1" d="M88.1 24c.3-1 .2-2-.3-2.6-.5-.6-1.2-1-2.1-1.1l-17.4-.2c-.1 0-.2-.1-.3-.1-.1-.1-.1-.2 0-.3.1-.2.2-.3.4-.3l17.5-.2c2.1-.1 4.3-1.8 5.1-3.8l1-2.6c0-.1.1-.2 0-.3-1.1-5.1-5.7-8.9-11.1-8.9-5 0-9.3 3.2-10.8 7.7-1-.7-2.2-1.1-3.6-1-2.4.2-4.3 2.2-4.6 4.6-.1.6 0 1.2.1 1.8-3.9.1-7.1 3.3-7.1 7.3 0 .4 0 .7.1 1.1 0 .2.2.3.3.3h32.1c.2 0 .4-.1.4-.3l.3-1.1z"/>
 | 
				
			||||||
 | 
					          <path class="st2" d="M93.6 12.8h-.5c-.1 0-.2.1-.3.2l-.7 2.4c-.3 1-.2 2 .3 2.6.5.6 1.2 1 2.1 1.1l3.7.2c.1 0 .2.1.3.1.1.1.1.2 0 .3-.1.2-.2.3-.4.3l-3.8.2c-2.1.1-4.3 1.8-5.1 3.8l-.2.9c-.1.1 0 .3.2.3h13.2c.2 0 .3-.1.3-.3.2-.8.4-1.7.4-2.6 0-5.2-4.3-9.5-9.5-9.5"/>
 | 
				
			||||||
 | 
					          <path class="st3" d="M104.4 30.8c-.5 0-.9-.4-.9-.9s.4-.9.9-.9.9.4.9.9-.4.9-.9.9m0-1.6c-.4 0-.7.3-.7.7 0 .4.3.7.7.7.4 0 .7-.3.7-.7 0-.4-.3-.7-.7-.7m.4 1.2h-.2l-.2-.3h-.2v.3h-.2v-.9h.5c.2 0 .3.1.3.3 0 .1-.1.2-.2.3l.2.3zm-.3-.5c.1 0 .1 0 .1-.1s-.1-.1-.1-.1h-.3v.3h.3zM14.8 29H17v6h3.8v1.9h-6zM23.1 32.9c0-2.3 1.8-4.1 4.3-4.1s4.2 1.8 4.2 4.1-1.8 4.1-4.3 4.1c-2.4 0-4.2-1.8-4.2-4.1m6.3 0c0-1.2-.8-2.2-2-2.2s-2 1-2 2.1.8 2.1 2 2.1c1.2.2 2-.8 2-2M34.3 33.4V29h2.2v4.4c0 1.1.6 1.7 1.5 1.7s1.5-.5 1.5-1.6V29h2.2v4.4c0 2.6-1.5 3.7-3.7 3.7-2.3-.1-3.7-1.2-3.7-3.7M45 29h3.1c2.8 0 4.5 1.6 4.5 3.9s-1.7 4-4.5 4h-3V29zm3.1 5.9c1.3 0 2.2-.7 2.2-2s-.9-2-2.2-2h-.9v4h.9zM55.7 29H62v1.9h-4.1v1.3h3.7V34h-3.7v2.9h-2.2zM65.1 29h2.2v6h3.8v1.9h-6zM76.8 28.9H79l3.4 8H80l-.6-1.4h-3.1l-.6 1.4h-2.3l3.4-8zm2 4.9l-.9-2.2-.9 2.2h1.8zM85.2 29h3.7c1.2 0 2 .3 2.6.9.5.5.7 1.1.7 1.8 0 1.2-.6 2-1.6 2.4l1.9 2.8H90l-1.6-2.4h-1v2.4h-2.2V29zm3.6 3.8c.7 0 1.2-.4 1.2-.9 0-.6-.5-.9-1.2-.9h-1.4v1.9h1.4zM95.3 29h6.4v1.8h-4.2V32h3.8v1.8h-3.8V35h4.3v1.9h-6.5zM10 33.9c-.3.7-1 1.2-1.8 1.2-1.2 0-2-1-2-2.1s.8-2.1 2-2.1c.9 0 1.6.6 1.9 1.3h2.3c-.4-1.9-2-3.3-4.2-3.3-2.4 0-4.3 1.8-4.3 4.1s1.8 4.1 4.2 4.1c2.1 0 3.7-1.4 4.2-3.2H10z"/>
 | 
				
			||||||
 | 
					        </svg>
 | 
				
			||||||
 | 
					        <h1 class="f4 f2-ns mt5 fw5">Congrats! You created your first tunnel!</h1>
 | 
				
			||||||
 | 
					        <p class="f6 f5-m f4-l measure lh-copy fw3">
 | 
				
			||||||
 | 
					          Argo Tunnel exposes locally running applications to the internet by
 | 
				
			||||||
 | 
					          running an encrypted, virtual tunnel from your laptop or server to
 | 
				
			||||||
 | 
					          Cloudflare's edge network.
 | 
				
			||||||
 | 
					        </p>
 | 
				
			||||||
 | 
					        <p class="b f5 mt5 fw6">Ready for the next step?</p>
 | 
				
			||||||
 | 
					        <a
 | 
				
			||||||
 | 
					          class="fw6 link white bg-blue ph4 pv2 br1 dib f5 link-hover"
 | 
				
			||||||
 | 
					          style="border-bottom: 1px solid #1f679e"
 | 
				
			||||||
 | 
					          href="https://developers.cloudflare.com/argo-tunnel/">
 | 
				
			||||||
 | 
					          Get started here
 | 
				
			||||||
 | 
					        </a>
 | 
				
			||||||
 | 
					       <section>
 | 
				
			||||||
 | 
					          <h4 class="f6 fw4 pt5 mb2">Request</h4>
 | 
				
			||||||
 | 
					          <dl class="bl bw2 b--orange ph3 pt3 pb2 bg-light-gray f7 code overflow-x-auto mw-100">
 | 
				
			||||||
 | 
											<dd class="ml0 mb3 f5">Method: {{.Request.Method}}</dd>
 | 
				
			||||||
 | 
											<dd class="ml0 mb3 f5">Protocol: {{.Request.Proto}}</dd>
 | 
				
			||||||
 | 
											<dd class="ml0 mb3 f5">Request URL: {{.Request.URL}}</dd>
 | 
				
			||||||
 | 
											<dd class="ml0 mb3 f5">Transfer encoding: {{.Request.TransferEncoding}}</dd>
 | 
				
			||||||
 | 
											<dd class="ml0 mb3 f5">Host: {{.Request.Host}}</dd>
 | 
				
			||||||
 | 
											<dd class="ml0 mb3 f5">Remote address: {{.Request.RemoteAddr}}</dd>
 | 
				
			||||||
 | 
											<dd class="ml0 mb3 f5">Request URI: {{.Request.RequestURI}}</dd>
 | 
				
			||||||
 | 
					{{range $key, $value := .Request.Header}}
 | 
				
			||||||
 | 
											<dd class="ml0 mb3 f5">Header: {{$key}}, Value: {{$value}}</dd>
 | 
				
			||||||
 | 
					{{end}}
 | 
				
			||||||
 | 
											<dd class="ml0 mb3 f5">Body: {{.Body}}</dd>
 | 
				
			||||||
 | 
										</dl>
 | 
				
			||||||
 | 
					        </section>
 | 
				
			||||||
 | 
					     </div>
 | 
				
			||||||
 | 
					    </div>
 | 
				
			||||||
 | 
					  </body>
 | 
				
			||||||
 | 
					</html>
 | 
				
			||||||
 | 
					`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func hello(c *cli.Context) error {
 | 
				
			||||||
 | 
						address := fmt.Sprintf(":%d", c.Int("port"))
 | 
				
			||||||
 | 
						listener, err := createListener(address)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer listener.Close()
 | 
				
			||||||
 | 
						err = startHelloWorldServer(listener, nil)
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
 | 
				
			||||||
 | 
						Log.Infof("Starting Hello World server at %s", listener.Addr())
 | 
				
			||||||
 | 
						serverName := defaultServerName
 | 
				
			||||||
 | 
						if hostname, err := os.Hostname(); err == nil {
 | 
				
			||||||
 | 
							serverName = hostname
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						upgrader := websocket.Upgrader{
 | 
				
			||||||
 | 
							ReadBufferSize:  1024,
 | 
				
			||||||
 | 
							WriteBufferSize: 1024,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							<-shutdownC
 | 
				
			||||||
 | 
							httpServer.Close()
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						http.HandleFunc("/uptime", uptimeHandler(time.Now()))
 | 
				
			||||||
 | 
						http.HandleFunc("/ws", websocketHandler(upgrader))
 | 
				
			||||||
 | 
						http.HandleFunc("/", rootHandler(serverName))
 | 
				
			||||||
 | 
						err := httpServer.Serve(listener)
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func createListener(address string) (net.Listener, error) {
 | 
				
			||||||
 | 
						certificate, err := tlsconfig.GetHelloCertificate()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If the port in address is empty, a port number is automatically chosen
 | 
				
			||||||
 | 
						listener, err := tls.Listen(
 | 
				
			||||||
 | 
							"tcp",
 | 
				
			||||||
 | 
							address,
 | 
				
			||||||
 | 
							&tls.Config{Certificates: []tls.Certificate{certificate}})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return listener, err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func uptimeHandler(startTime time.Time) http.HandlerFunc {
 | 
				
			||||||
 | 
						return func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							// Note that if autoupdate is enabled, the uptime is reset when a new client
 | 
				
			||||||
 | 
							// release is available
 | 
				
			||||||
 | 
							resp := &OriginUpTime{StartTime: startTime, UpTime: time.Now().Sub(startTime).String()}
 | 
				
			||||||
 | 
							respJson, err := json.Marshal(resp)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								w.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								w.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
 | 
								w.Write(respJson)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func websocketHandler(upgrader websocket.Upgrader) http.HandlerFunc {
 | 
				
			||||||
 | 
						return func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							conn, err := upgrader.Upgrade(w, r, nil)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							defer conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for {
 | 
				
			||||||
 | 
								mt, message, err := conn.ReadMessage()
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									break
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if err := conn.WriteMessage(mt, message); err != nil {
 | 
				
			||||||
 | 
									break
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func rootHandler(serverName string) http.HandlerFunc {
 | 
				
			||||||
 | 
						responseTemplate := template.Must(template.New("index").Parse(indexTemplate))
 | 
				
			||||||
 | 
						return func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							var buffer bytes.Buffer
 | 
				
			||||||
 | 
							var body string
 | 
				
			||||||
 | 
							rawBody, err := ioutil.ReadAll(r.Body)
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								body = string(rawBody)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								body = ""
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							err = responseTemplate.Execute(&buffer, &templateData{
 | 
				
			||||||
 | 
								ServerName: serverName,
 | 
				
			||||||
 | 
								Request:    r,
 | 
				
			||||||
 | 
								Body:       body,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								w.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
 | 
								fmt.Fprintf(w, "error: %v", err)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								buffer.WriteTo(w)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,35 @@
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCreateListenerHostAndPortSuccess(t *testing.T) {
 | 
				
			||||||
 | 
						listener, err := createListener("localhost:1234")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if listener.Addr().String() == "" {
 | 
				
			||||||
 | 
							t.Fatal("Fail to find available port")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCreateListenerOnlyHostSuccess(t *testing.T) {
 | 
				
			||||||
 | 
						listener, err := createListener("localhost:")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if listener.Addr().String() == "" {
 | 
				
			||||||
 | 
							t.Fatal("Fail to find available port")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCreateListenerOnlyPortSuccess(t *testing.T) {
 | 
				
			||||||
 | 
						listener, err := createListener(":8888")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if listener.Addr().String() == "" {
 | 
				
			||||||
 | 
							t.Fatal("Fail to find available port")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,292 @@
 | 
				
			||||||
 | 
					// +build linux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"path/filepath"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cli "gopkg.in/urfave/cli.v2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func runApp(app *cli.App) {
 | 
				
			||||||
 | 
						app.Commands = append(app.Commands, &cli.Command{
 | 
				
			||||||
 | 
							Name:  "service",
 | 
				
			||||||
 | 
							Usage: "Manages the Argo Tunnel system service",
 | 
				
			||||||
 | 
							Subcommands: []*cli.Command{
 | 
				
			||||||
 | 
								&cli.Command{
 | 
				
			||||||
 | 
									Name:   "install",
 | 
				
			||||||
 | 
									Usage:  "Install Argo Tunnel as a system service",
 | 
				
			||||||
 | 
									Action: installLinuxService,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								&cli.Command{
 | 
				
			||||||
 | 
									Name:   "uninstall",
 | 
				
			||||||
 | 
									Usage:  "Uninstall the Argo Tunnel service",
 | 
				
			||||||
 | 
									Action: uninstallLinuxService,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						app.Run(os.Args)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const serviceConfigDir = "/etc/cloudflared"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var systemdTemplates = []ServiceTemplate{
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Path: "/etc/systemd/system/cloudflared.service",
 | 
				
			||||||
 | 
							Content: `[Unit]
 | 
				
			||||||
 | 
					Description=Argo Tunnel
 | 
				
			||||||
 | 
					After=network.target
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[Service]
 | 
				
			||||||
 | 
					TimeoutStartSec=0
 | 
				
			||||||
 | 
					Type=notify
 | 
				
			||||||
 | 
					ExecStart={{ .Path }} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --no-autoupdate
 | 
				
			||||||
 | 
					Restart=on-failure
 | 
				
			||||||
 | 
					RestartSec=5s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[Install]
 | 
				
			||||||
 | 
					WantedBy=multi-user.target
 | 
				
			||||||
 | 
					`,
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Path: "/etc/systemd/system/cloudflared-update.service",
 | 
				
			||||||
 | 
							Content: `[Unit]
 | 
				
			||||||
 | 
					Description=Update Argo Tunnel
 | 
				
			||||||
 | 
					After=network.target
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[Service]
 | 
				
			||||||
 | 
					ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 64 ]; then systemctl restart cloudflared; exit 0; fi; exit $code'
 | 
				
			||||||
 | 
					`,
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							Path: "/etc/systemd/system/cloudflared-update.timer",
 | 
				
			||||||
 | 
							Content: `[Unit]
 | 
				
			||||||
 | 
					Description=Update Argo Tunnel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[Timer]
 | 
				
			||||||
 | 
					OnUnitActiveSec=1d
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[Install]
 | 
				
			||||||
 | 
					WantedBy=timers.target
 | 
				
			||||||
 | 
					`,
 | 
				
			||||||
 | 
						},
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var sysvTemplate = ServiceTemplate{
 | 
				
			||||||
 | 
						Path:     "/etc/init.d/cloudflared",
 | 
				
			||||||
 | 
						FileMode: 0755,
 | 
				
			||||||
 | 
						Content: `# For RedHat and cousins:
 | 
				
			||||||
 | 
					# chkconfig: 2345 99 01
 | 
				
			||||||
 | 
					# description: Argo Tunnel agent
 | 
				
			||||||
 | 
					# processname: {{.Path}}
 | 
				
			||||||
 | 
					### BEGIN INIT INFO
 | 
				
			||||||
 | 
					# Provides:          {{.Path}}
 | 
				
			||||||
 | 
					# Required-Start:
 | 
				
			||||||
 | 
					# Required-Stop:
 | 
				
			||||||
 | 
					# Default-Start:     2 3 4 5
 | 
				
			||||||
 | 
					# Default-Stop:      0 1 6
 | 
				
			||||||
 | 
					# Short-Description: Argo Tunnel
 | 
				
			||||||
 | 
					# Description:       Argo Tunnel agent
 | 
				
			||||||
 | 
					### END INIT INFO
 | 
				
			||||||
 | 
					cmd="{{.Path}} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --pidfile /var/run/$name.pid --autoupdate-freq 24h0m0s"
 | 
				
			||||||
 | 
					name=$(basename $(readlink -f $0))
 | 
				
			||||||
 | 
					pid_file="/var/run/$name.pid"
 | 
				
			||||||
 | 
					stdout_log="/var/log/$name.log"
 | 
				
			||||||
 | 
					stderr_log="/var/log/$name.err"
 | 
				
			||||||
 | 
					[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
 | 
				
			||||||
 | 
					get_pid() {
 | 
				
			||||||
 | 
					    cat "$pid_file"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					is_running() {
 | 
				
			||||||
 | 
					    [ -f "$pid_file" ] && ps $(get_pid) > /dev/null 2>&1
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					case "$1" in
 | 
				
			||||||
 | 
					    start)
 | 
				
			||||||
 | 
					        if is_running; then
 | 
				
			||||||
 | 
					            echo "Already started"
 | 
				
			||||||
 | 
					        else
 | 
				
			||||||
 | 
					            echo "Starting $name"
 | 
				
			||||||
 | 
					            $cmd >> "$stdout_log" 2>> "$stderr_log" &
 | 
				
			||||||
 | 
					            echo $! > "$pid_file"
 | 
				
			||||||
 | 
					            if ! is_running; then
 | 
				
			||||||
 | 
					                echo "Unable to start, see $stdout_log and $stderr_log"
 | 
				
			||||||
 | 
					                exit 1
 | 
				
			||||||
 | 
					            fi
 | 
				
			||||||
 | 
					        fi
 | 
				
			||||||
 | 
					    ;;
 | 
				
			||||||
 | 
					    stop)
 | 
				
			||||||
 | 
					        if is_running; then
 | 
				
			||||||
 | 
					            echo -n "Stopping $name.."
 | 
				
			||||||
 | 
					            kill $(get_pid)
 | 
				
			||||||
 | 
					            for i in {1..10}
 | 
				
			||||||
 | 
					            do
 | 
				
			||||||
 | 
					                if ! is_running; then
 | 
				
			||||||
 | 
					                    break
 | 
				
			||||||
 | 
					                fi
 | 
				
			||||||
 | 
					                echo -n "."
 | 
				
			||||||
 | 
					                sleep 1
 | 
				
			||||||
 | 
					            done
 | 
				
			||||||
 | 
					            echo
 | 
				
			||||||
 | 
					            if is_running; then
 | 
				
			||||||
 | 
					                echo "Not stopped; may still be shutting down or shutdown may have failed"
 | 
				
			||||||
 | 
					                exit 1
 | 
				
			||||||
 | 
					            else
 | 
				
			||||||
 | 
					                echo "Stopped"
 | 
				
			||||||
 | 
					                if [ -f "$pid_file" ]; then
 | 
				
			||||||
 | 
					                    rm "$pid_file"
 | 
				
			||||||
 | 
					                fi
 | 
				
			||||||
 | 
					            fi
 | 
				
			||||||
 | 
					        else
 | 
				
			||||||
 | 
					            echo "Not running"
 | 
				
			||||||
 | 
					        fi
 | 
				
			||||||
 | 
					    ;;
 | 
				
			||||||
 | 
					    restart)
 | 
				
			||||||
 | 
					        $0 stop
 | 
				
			||||||
 | 
					        if is_running; then
 | 
				
			||||||
 | 
					            echo "Unable to stop, will not attempt to start"
 | 
				
			||||||
 | 
					            exit 1
 | 
				
			||||||
 | 
					        fi
 | 
				
			||||||
 | 
					        $0 start
 | 
				
			||||||
 | 
					    ;;
 | 
				
			||||||
 | 
					    status)
 | 
				
			||||||
 | 
					        if is_running; then
 | 
				
			||||||
 | 
					            echo "Running"
 | 
				
			||||||
 | 
					        else
 | 
				
			||||||
 | 
					            echo "Stopped"
 | 
				
			||||||
 | 
					            exit 1
 | 
				
			||||||
 | 
					        fi
 | 
				
			||||||
 | 
					    ;;
 | 
				
			||||||
 | 
					    *)
 | 
				
			||||||
 | 
					    echo "Usage: $0 {start|stop|restart|status}"
 | 
				
			||||||
 | 
					    exit 1
 | 
				
			||||||
 | 
					    ;;
 | 
				
			||||||
 | 
					esac
 | 
				
			||||||
 | 
					exit 0
 | 
				
			||||||
 | 
					`,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func isSystemd() bool {
 | 
				
			||||||
 | 
						if _, err := os.Stat("/run/systemd/system"); err == nil {
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func installLinuxService(c *cli.Context) error {
 | 
				
			||||||
 | 
						etPath, err := os.Executable()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("error determining executable path: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						templateArgs := ServiceTemplateArgs{Path: etPath}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defaultConfigDir := filepath.Dir(c.String("config"))
 | 
				
			||||||
 | 
						defaultConfigFile := filepath.Base(c.String("config"))
 | 
				
			||||||
 | 
						if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile); err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("Failed to copy user configuration. Before running the service, ensure that %s contains two files, %s and %s",
 | 
				
			||||||
 | 
								serviceConfigDir, credentialFile, defaultConfigFiles[0])
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						switch {
 | 
				
			||||||
 | 
						case isSystemd():
 | 
				
			||||||
 | 
							Log.Infof("Using Systemd")
 | 
				
			||||||
 | 
							return installSystemd(&templateArgs)
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							Log.Infof("Using Sysv")
 | 
				
			||||||
 | 
							return installSysv(&templateArgs)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func installSystemd(templateArgs *ServiceTemplateArgs) error {
 | 
				
			||||||
 | 
						for _, serviceTemplate := range systemdTemplates {
 | 
				
			||||||
 | 
							err := serviceTemplate.Generate(templateArgs)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								Log.WithError(err).Infof("error generating service template")
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := runCommand("systemctl", "enable", "cloudflared.service"); err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("systemctl enable cloudflared.service error")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := runCommand("systemctl", "start", "cloudflared-update.timer"); err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("systemctl start cloudflared-update.timer error")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						Log.Infof("systemctl daemon-reload")
 | 
				
			||||||
 | 
						return runCommand("systemctl", "daemon-reload")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func installSysv(templateArgs *ServiceTemplateArgs) error {
 | 
				
			||||||
 | 
						confPath, err := sysvTemplate.ResolvePath()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("error resolving system path")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := sysvTemplate.Generate(templateArgs); err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("error generating system template")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, i := range [...]string{"2", "3", "4", "5"} {
 | 
				
			||||||
 | 
							if err := os.Symlink(confPath, "/etc/rc"+i+".d/S50et"); err != nil {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, i := range [...]string{"0", "1", "6"} {
 | 
				
			||||||
 | 
							if err := os.Symlink(confPath, "/etc/rc"+i+".d/K02et"); err != nil {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func uninstallLinuxService(c *cli.Context) error {
 | 
				
			||||||
 | 
						switch {
 | 
				
			||||||
 | 
						case isSystemd():
 | 
				
			||||||
 | 
							Log.Infof("Using Systemd")
 | 
				
			||||||
 | 
							return uninstallSystemd()
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							Log.Infof("Using Sysv")
 | 
				
			||||||
 | 
							return uninstallSysv()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func uninstallSystemd() error {
 | 
				
			||||||
 | 
						if err := runCommand("systemctl", "disable", "cloudflared.service"); err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("systemctl disable cloudflared.service error")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := runCommand("systemctl", "stop", "cloudflared-update.timer"); err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("systemctl stop cloudflared-update.timer error")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, serviceTemplate := range systemdTemplates {
 | 
				
			||||||
 | 
							if err := serviceTemplate.Remove(); err != nil {
 | 
				
			||||||
 | 
								Log.WithError(err).Infof("error removing service template")
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						Log.Infof("Successfully uninstall cloudflared service")
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func uninstallSysv() error {
 | 
				
			||||||
 | 
						if err := sysvTemplate.Remove(); err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("error removing service template")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, i := range [...]string{"2", "3", "4", "5"} {
 | 
				
			||||||
 | 
							if err := os.Remove("/etc/rc" + i + ".d/S50et"); err != nil {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, i := range [...]string{"0", "1", "6"} {
 | 
				
			||||||
 | 
							if err := os.Remove("/etc/rc" + i + ".d/K02et"); err != nil {
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						Log.Infof("Successfully uninstall cloudflared service")
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,194 @@
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"crypto/rand"
 | 
				
			||||||
 | 
						"encoding/base32"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"os/exec"
 | 
				
			||||||
 | 
						"path/filepath"
 | 
				
			||||||
 | 
						"runtime"
 | 
				
			||||||
 | 
						"syscall"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						homedir "github.com/mitchellh/go-homedir"
 | 
				
			||||||
 | 
						cli "gopkg.in/urfave/cli.v2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const baseLoginURL = "https://www.cloudflare.com/a/warp"
 | 
				
			||||||
 | 
					const baseCertStoreURL = "https://login.cloudflarewarp.com"
 | 
				
			||||||
 | 
					const clientTimeout = time.Minute * 20
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func login(c *cli.Context) error {
 | 
				
			||||||
 | 
						configPath, err := homedir.Expand(defaultConfigDirs[0])
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						ok, err := fileExists(configPath)
 | 
				
			||||||
 | 
						if !ok && err == nil {
 | 
				
			||||||
 | 
							// create config directory if doesn't already exist
 | 
				
			||||||
 | 
							err = os.Mkdir(configPath, 0700)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						path := filepath.Join(configPath, credentialFile)
 | 
				
			||||||
 | 
						fileInfo, err := os.Stat(path)
 | 
				
			||||||
 | 
						if err == nil && fileInfo.Size() > 0 {
 | 
				
			||||||
 | 
							fmt.Fprintf(os.Stderr, `You have an existing certificate at %s which login would overwrite.
 | 
				
			||||||
 | 
					If this is intentional, please move or delete that file then run this command again.
 | 
				
			||||||
 | 
					`, path)
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err != nil && err.(*os.PathError).Err != syscall.ENOENT {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// for local debugging
 | 
				
			||||||
 | 
						baseURL := baseCertStoreURL
 | 
				
			||||||
 | 
						if c.IsSet("url") {
 | 
				
			||||||
 | 
							baseURL = c.String("url")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// Generate a random post URL
 | 
				
			||||||
 | 
						certURL := baseURL + generateRandomPath()
 | 
				
			||||||
 | 
						loginURL, err := url.Parse(baseLoginURL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							// shouldn't happen, URL is hardcoded
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						loginURL.RawQuery = "callback=" + url.QueryEscape(certURL)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = open(loginURL.String())
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							fmt.Fprintf(os.Stderr, `Please open the following URL and log in with your Cloudflare account:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					%s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Leave cloudflared running to install the certificate automatically.
 | 
				
			||||||
 | 
					`, loginURL.String())
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							fmt.Fprintf(os.Stderr, `A browser window should have opened at the following URL:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					%s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					If the browser failed to open, open it yourself and visit the URL above.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					`, loginURL.String())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if download(certURL, path) {
 | 
				
			||||||
 | 
							fmt.Fprintf(os.Stderr, `You have successfully logged in.
 | 
				
			||||||
 | 
					If you wish to copy your credentials to a server, they have been saved to:
 | 
				
			||||||
 | 
					%s
 | 
				
			||||||
 | 
					`, path)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							fmt.Fprintf(os.Stderr, `Failed to write the certificate due to the following error:
 | 
				
			||||||
 | 
					%v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Your browser will download the certificate instead. You will have to manually
 | 
				
			||||||
 | 
					copy it to the following path:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					%s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					`, err, path)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// generateRandomPath generates a random URL to associate with the certificate.
 | 
				
			||||||
 | 
					func generateRandomPath() string {
 | 
				
			||||||
 | 
						randomBytes := make([]byte, 40)
 | 
				
			||||||
 | 
						_, err := rand.Read(randomBytes)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							panic(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return "/" + base32.StdEncoding.EncodeToString(randomBytes)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// open opens the specified URL in the default browser of the user.
 | 
				
			||||||
 | 
					func open(url string) error {
 | 
				
			||||||
 | 
						var cmd string
 | 
				
			||||||
 | 
						var args []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						switch runtime.GOOS {
 | 
				
			||||||
 | 
						case "windows":
 | 
				
			||||||
 | 
							cmd = "cmd"
 | 
				
			||||||
 | 
							args = []string{"/c", "start"}
 | 
				
			||||||
 | 
						case "darwin":
 | 
				
			||||||
 | 
							cmd = "open"
 | 
				
			||||||
 | 
						default: // "linux", "freebsd", "openbsd", "netbsd"
 | 
				
			||||||
 | 
							cmd = "xdg-open"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						args = append(args, url)
 | 
				
			||||||
 | 
						return exec.Command(cmd, args...).Start()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func download(certURL, filePath string) bool {
 | 
				
			||||||
 | 
						client := &http.Client{Timeout: clientTimeout}
 | 
				
			||||||
 | 
						// attempt a (long-running) certificate get
 | 
				
			||||||
 | 
						for i := 0; i < 20; i++ {
 | 
				
			||||||
 | 
							ok, err := tryDownload(client, certURL, filePath)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								putSuccess(client, certURL)
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								Log.WithError(err).Error("Error fetching certificate")
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err error) {
 | 
				
			||||||
 | 
						resp, err := client.Get(certURL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer resp.Body.Close()
 | 
				
			||||||
 | 
						if resp.StatusCode == 404 {
 | 
				
			||||||
 | 
							return false, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if resp.StatusCode != 200 {
 | 
				
			||||||
 | 
							return false, fmt.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if resp.Header.Get("Content-Type") != "application/x-pem-file" {
 | 
				
			||||||
 | 
							return false, fmt.Errorf("Unexpected content type %s", resp.Header.Get("Content-Type"))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// write response
 | 
				
			||||||
 | 
						file, err := os.Create(filePath)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer file.Close()
 | 
				
			||||||
 | 
						written, err := io.Copy(file, resp.Body)
 | 
				
			||||||
 | 
						switch {
 | 
				
			||||||
 | 
						case err != nil:
 | 
				
			||||||
 | 
							return false, err
 | 
				
			||||||
 | 
						case resp.ContentLength != written && resp.ContentLength != -1:
 | 
				
			||||||
 | 
							return false, fmt.Errorf("Short read (%d bytes) from server while writing certificate", written)
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return true, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func putSuccess(client *http.Client, certURL string) {
 | 
				
			||||||
 | 
						// indicate success to the relay server
 | 
				
			||||||
 | 
						req, err := http.NewRequest("PUT", certURL+"/ok", nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Error("HTTP request error")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp, err := client.Do(req)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Error("HTTP error")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resp.Body.Close()
 | 
				
			||||||
 | 
						if resp.StatusCode != 200 {
 | 
				
			||||||
 | 
							Log.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,97 @@
 | 
				
			||||||
 | 
					// +build darwin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"gopkg.in/urfave/cli.v2"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const launchAgentIdentifier = "com.cloudflare.cloudflared"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func runApp(app *cli.App) {
 | 
				
			||||||
 | 
						app.Commands = append(app.Commands, &cli.Command{
 | 
				
			||||||
 | 
							Name:  "service",
 | 
				
			||||||
 | 
							Usage: "Manages the Argo Tunnel launch agent",
 | 
				
			||||||
 | 
							Subcommands: []*cli.Command{
 | 
				
			||||||
 | 
								{
 | 
				
			||||||
 | 
									Name:   "install",
 | 
				
			||||||
 | 
									Usage:  "Install Argo Tunnel as an user launch agent",
 | 
				
			||||||
 | 
									Action: installLaunchd,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								{
 | 
				
			||||||
 | 
									Name:   "uninstall",
 | 
				
			||||||
 | 
									Usage:  "Uninstall the Argo Tunnel launch agent",
 | 
				
			||||||
 | 
									Action: uninstallLaunchd,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						app.Run(os.Args)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var launchdTemplate = ServiceTemplate{
 | 
				
			||||||
 | 
						Path: fmt.Sprintf("~/Library/LaunchAgents/%s.plist", launchAgentIdentifier),
 | 
				
			||||||
 | 
						Content: fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
 | 
				
			||||||
 | 
					<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
 | 
				
			||||||
 | 
					<plist version="1.0">
 | 
				
			||||||
 | 
						<dict>
 | 
				
			||||||
 | 
							<key>Label</key>
 | 
				
			||||||
 | 
							<string>%s</string>
 | 
				
			||||||
 | 
							<key>Program</key>
 | 
				
			||||||
 | 
							<string>{{ .Path }}</string>
 | 
				
			||||||
 | 
							<key>RunAtLoad</key>
 | 
				
			||||||
 | 
							<true/>
 | 
				
			||||||
 | 
							<key>StandardOutPath</key>
 | 
				
			||||||
 | 
							<string>/tmp/%s.out.log</string>
 | 
				
			||||||
 | 
					    <key>StandardErrorPath</key>
 | 
				
			||||||
 | 
							<string>/tmp/%s.err.log</string>
 | 
				
			||||||
 | 
							<key>KeepAlive</key>
 | 
				
			||||||
 | 
							<dict>
 | 
				
			||||||
 | 
								<key>NetworkState</key>
 | 
				
			||||||
 | 
								<true/>
 | 
				
			||||||
 | 
							</dict>
 | 
				
			||||||
 | 
							<key>ThrottleInterval</key>
 | 
				
			||||||
 | 
							<integer>20</integer>
 | 
				
			||||||
 | 
						</dict>
 | 
				
			||||||
 | 
					</plist>`, launchAgentIdentifier, launchAgentIdentifier, launchAgentIdentifier),
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func installLaunchd(c *cli.Context) error {
 | 
				
			||||||
 | 
						Log.Infof("Installing Argo Tunnel as an user launch agent")
 | 
				
			||||||
 | 
						etPath, err := os.Executable()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("error determining executable path")
 | 
				
			||||||
 | 
							return fmt.Errorf("error determining executable path: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						templateArgs := ServiceTemplateArgs{Path: etPath}
 | 
				
			||||||
 | 
						err = launchdTemplate.Generate(&templateArgs)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("error generating launchd template")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						plistPath, err := launchdTemplate.ResolvePath()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("error resolving launchd template path")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						Log.Infof("Outputs are logged in %s and %s", fmt.Sprintf("/tmp/%s.out.log", launchAgentIdentifier), fmt.Sprintf("/tmp/%s.err.log", launchAgentIdentifier))
 | 
				
			||||||
 | 
						return runCommand("launchctl", "load", plistPath)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func uninstallLaunchd(c *cli.Context) error {
 | 
				
			||||||
 | 
						Log.Infof("Uninstalling Argo Tunnel as an user launch agent")
 | 
				
			||||||
 | 
						plistPath, err := launchdTemplate.ResolvePath()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("error resolving launchd template path")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = runCommand("launchctl", "unload", plistPath)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("error unloading")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						Log.Infof("Outputs are logged in %s and %s", fmt.Sprintf("/tmp/%s.out.log", launchAgentIdentifier), fmt.Sprintf("/tmp/%s.err.log", launchAgentIdentifier))
 | 
				
			||||||
 | 
						return launchdTemplate.Remove()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,794 @@
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"crypto/tls"
 | 
				
			||||||
 | 
						"encoding/hex"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"math/rand"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"os/signal"
 | 
				
			||||||
 | 
						"path/filepath"
 | 
				
			||||||
 | 
						"runtime"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"syscall"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/cloudflare/cloudflared/metrics"
 | 
				
			||||||
 | 
						"github.com/cloudflare/cloudflared/origin"
 | 
				
			||||||
 | 
						"github.com/cloudflare/cloudflared/tlsconfig"
 | 
				
			||||||
 | 
						"github.com/cloudflare/cloudflared/tunneldns"
 | 
				
			||||||
 | 
						tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
				
			||||||
 | 
						"github.com/cloudflare/cloudflared/validation"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/facebookgo/grace/gracenet"
 | 
				
			||||||
 | 
						"github.com/getsentry/raven-go"
 | 
				
			||||||
 | 
						"github.com/mitchellh/go-homedir"
 | 
				
			||||||
 | 
						"github.com/rifflock/lfshook"
 | 
				
			||||||
 | 
						"github.com/sirupsen/logrus"
 | 
				
			||||||
 | 
						"golang.org/x/crypto/ssh/terminal"
 | 
				
			||||||
 | 
						"gopkg.in/urfave/cli.v2"
 | 
				
			||||||
 | 
						"gopkg.in/urfave/cli.v2/altsrc"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/coreos/go-systemd/daemon"
 | 
				
			||||||
 | 
						"github.com/pkg/errors"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						sentryDSN           = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
 | 
				
			||||||
 | 
						credentialFile      = "cert.pem"
 | 
				
			||||||
 | 
						quickStartUrl       = "https://developers.cloudflare.com/argo-tunnel/quickstart/quickstart/"
 | 
				
			||||||
 | 
						noAutoupdateMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/"
 | 
				
			||||||
 | 
						licenseUrl          = "https://developers.cloudflare.com/argo-tunnel/licence/"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var listeners = gracenet.Net{}
 | 
				
			||||||
 | 
					var Version = "DEV"
 | 
				
			||||||
 | 
					var BuildTime = "unknown"
 | 
				
			||||||
 | 
					var Log *logrus.Logger
 | 
				
			||||||
 | 
					var defaultConfigFiles = []string{"config.yml", "config.yaml"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
 | 
				
			||||||
 | 
					var defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Shutdown channel used by the app. When closed, app must terminate.
 | 
				
			||||||
 | 
					// May be closed by the Windows service runner.
 | 
				
			||||||
 | 
					var shutdownC chan struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type BuildAndRuntimeInfo struct {
 | 
				
			||||||
 | 
						GoOS        string                 `json:"go_os"`
 | 
				
			||||||
 | 
						GoVersion   string                 `json:"go_version"`
 | 
				
			||||||
 | 
						GoArch      string                 `json:"go_arch"`
 | 
				
			||||||
 | 
						WarpVersion string                 `json:"warp_version"`
 | 
				
			||||||
 | 
						WarpFlags   map[string]interface{} `json:"warp_flags"`
 | 
				
			||||||
 | 
						WarpEnvs    map[string]string      `json:"warp_envs"`
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func main() {
 | 
				
			||||||
 | 
						metrics.RegisterBuildInfo(BuildTime, Version)
 | 
				
			||||||
 | 
						raven.SetDSN(sentryDSN)
 | 
				
			||||||
 | 
						raven.SetRelease(Version)
 | 
				
			||||||
 | 
						shutdownC = make(chan struct{})
 | 
				
			||||||
 | 
						app := &cli.App{}
 | 
				
			||||||
 | 
						app.Name = "cloudflared"
 | 
				
			||||||
 | 
						app.Copyright = fmt.Sprintf(`(c) %d Cloudflare Inc.
 | 
				
			||||||
 | 
					   Use is subject to the license agreement at %s`, time.Now().Year(), licenseUrl)
 | 
				
			||||||
 | 
						app.Usage = "Cloudflare reverse tunnelling proxy agent"
 | 
				
			||||||
 | 
						app.ArgsUsage = "origin-url"
 | 
				
			||||||
 | 
						app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime)
 | 
				
			||||||
 | 
						app.Description = `A reverse tunnel proxy agent that connects to Cloudflare's infrastructure.
 | 
				
			||||||
 | 
					   Upon connecting, you are assigned a unique subdomain on cftunnel.com.
 | 
				
			||||||
 | 
					   You need to specify a hostname on a zone you control.
 | 
				
			||||||
 | 
					   A DNS record will be created to CNAME your hostname to the unique subdomain on cftunnel.com.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   Requests made to Cloudflare's servers for your hostname will be proxied
 | 
				
			||||||
 | 
					   through the tunnel to your local webserver.`
 | 
				
			||||||
 | 
						app.Flags = []cli.Flag{
 | 
				
			||||||
 | 
							&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:  "config",
 | 
				
			||||||
 | 
								Usage: "Specifies a config file in YAML format.",
 | 
				
			||||||
 | 
								Value: findDefaultConfigPath(),
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
				
			||||||
 | 
								Name:  "autoupdate-freq",
 | 
				
			||||||
 | 
								Usage: "Autoupdate frequency. Default is 24h.",
 | 
				
			||||||
 | 
								Value: time.Hour * 24,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewBoolFlag(&cli.BoolFlag{
 | 
				
			||||||
 | 
								Name:  "no-autoupdate",
 | 
				
			||||||
 | 
								Usage: "Disable periodic check for updates, restarting the server with the new version.",
 | 
				
			||||||
 | 
								Value: false,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewBoolFlag(&cli.BoolFlag{
 | 
				
			||||||
 | 
								Name:   "is-autoupdated",
 | 
				
			||||||
 | 
								Usage:  "Signal the new process that Argo Tunnel client has been autoupdated",
 | 
				
			||||||
 | 
								Value:  false,
 | 
				
			||||||
 | 
								Hidden: true,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
 | 
				
			||||||
 | 
								Name:    "edge",
 | 
				
			||||||
 | 
								Usage:   "Address of the Cloudflare tunnel server.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_EDGE"},
 | 
				
			||||||
 | 
								Hidden:  true,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "cacert",
 | 
				
			||||||
 | 
								Usage:   "Certificate Authority authenticating the Cloudflare tunnel connection.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_CACERT"},
 | 
				
			||||||
 | 
								Hidden:  true,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "origincert",
 | 
				
			||||||
 | 
								Usage:   "Path to the certificate generated for your origin when you run cloudflared login.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
 | 
				
			||||||
 | 
								Value:   findDefaultOriginCertPath(),
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "url",
 | 
				
			||||||
 | 
								Value:   "https://localhost:8080",
 | 
				
			||||||
 | 
								Usage:   "Connect to the local webserver at `URL`.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_URL"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "hostname",
 | 
				
			||||||
 | 
								Usage:   "Set a hostname on a Cloudflare zone to route traffic through this tunnel.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_HOSTNAME"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "origin-server-name",
 | 
				
			||||||
 | 
								Usage:   "Hostname on the origin server certificate.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_ORIGIN_SERVER_NAME"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "id",
 | 
				
			||||||
 | 
								Usage:   "A unique identifier used to tie connections to this tunnel instance.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_ID"},
 | 
				
			||||||
 | 
								Hidden:  true,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "lb-pool",
 | 
				
			||||||
 | 
								Usage:   "The name of a (new/existing) load balancing pool to add this origin to.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_LB_POOL"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "api-key",
 | 
				
			||||||
 | 
								Usage:   "This parameter has been deprecated since version 2017.10.1.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_API_KEY"},
 | 
				
			||||||
 | 
								Hidden:  true,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "api-email",
 | 
				
			||||||
 | 
								Usage:   "This parameter has been deprecated since version 2017.10.1.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_API_EMAIL"},
 | 
				
			||||||
 | 
								Hidden:  true,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "api-ca-key",
 | 
				
			||||||
 | 
								Usage:   "This parameter has been deprecated since version 2017.10.1.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_API_CA_KEY"},
 | 
				
			||||||
 | 
								Hidden:  true,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "metrics",
 | 
				
			||||||
 | 
								Value:   "localhost:",
 | 
				
			||||||
 | 
								Usage:   "Listen address for metrics reporting.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_METRICS"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
				
			||||||
 | 
								Name:    "metrics-update-freq",
 | 
				
			||||||
 | 
								Usage:   "Frequency to update tunnel metrics",
 | 
				
			||||||
 | 
								Value:   time.Second * 5,
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_METRICS_UPDATE_FREQ"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
 | 
				
			||||||
 | 
								Name:    "tag",
 | 
				
			||||||
 | 
								Usage:   "Custom tags used to identify this tunnel, in format `KEY=VALUE`. Multiple tags may be specified",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_TAG"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
				
			||||||
 | 
								Name:   "heartbeat-interval",
 | 
				
			||||||
 | 
								Usage:  "Minimum idle time before sending a heartbeat.",
 | 
				
			||||||
 | 
								Value:  time.Second * 5,
 | 
				
			||||||
 | 
								Hidden: true,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewUint64Flag(&cli.Uint64Flag{
 | 
				
			||||||
 | 
								Name:   "heartbeat-count",
 | 
				
			||||||
 | 
								Usage:  "Minimum number of unacked heartbeats to send before closing the connection.",
 | 
				
			||||||
 | 
								Value:  5,
 | 
				
			||||||
 | 
								Hidden: true,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "loglevel",
 | 
				
			||||||
 | 
								Value:   "info",
 | 
				
			||||||
 | 
								Usage:   "Application logging level {panic, fatal, error, warn, info, debug}",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_LOGLEVEL"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "proto-loglevel",
 | 
				
			||||||
 | 
								Value:   "warn",
 | 
				
			||||||
 | 
								Usage:   "Protocol logging level {panic, fatal, error, warn, info, debug}",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewUintFlag(&cli.UintFlag{
 | 
				
			||||||
 | 
								Name:    "retries",
 | 
				
			||||||
 | 
								Value:   5,
 | 
				
			||||||
 | 
								Usage:   "Maximum number of retries for connection/protocol errors.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_RETRIES"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewBoolFlag(&cli.BoolFlag{
 | 
				
			||||||
 | 
								Name:    "hello-world",
 | 
				
			||||||
 | 
								Value:   false,
 | 
				
			||||||
 | 
								Usage:   "Run Hello World Server",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_HELLO_WORLD"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "pidfile",
 | 
				
			||||||
 | 
								Usage:   "Write the application's PID to this file after first successful connection.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_PIDFILE"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "logfile",
 | 
				
			||||||
 | 
								Usage:   "Save application log to this file for reporting issues.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_LOGFILE"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewIntFlag(&cli.IntFlag{
 | 
				
			||||||
 | 
								Name:   "ha-connections",
 | 
				
			||||||
 | 
								Value:  4,
 | 
				
			||||||
 | 
								Hidden: true,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
				
			||||||
 | 
								Name:  "proxy-connect-timeout",
 | 
				
			||||||
 | 
								Usage: "HTTP proxy timeout for establishing a new connection",
 | 
				
			||||||
 | 
								Value: time.Second * 30,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
				
			||||||
 | 
								Name:  "proxy-tls-timeout",
 | 
				
			||||||
 | 
								Usage: "HTTP proxy timeout for completing a TLS handshake",
 | 
				
			||||||
 | 
								Value: time.Second * 10,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
				
			||||||
 | 
								Name:  "proxy-tcp-keepalive",
 | 
				
			||||||
 | 
								Usage: "HTTP proxy TCP keepalive duration",
 | 
				
			||||||
 | 
								Value: time.Second * 30,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewBoolFlag(&cli.BoolFlag{
 | 
				
			||||||
 | 
								Name:  "proxy-no-happy-eyeballs",
 | 
				
			||||||
 | 
								Usage: "HTTP proxy should disable \"happy eyeballs\" for IPv4/v6 fallback",
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewIntFlag(&cli.IntFlag{
 | 
				
			||||||
 | 
								Name:  "proxy-keepalive-connections",
 | 
				
			||||||
 | 
								Usage: "HTTP proxy maximum keepalive connection pool size",
 | 
				
			||||||
 | 
								Value: 100,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewDurationFlag(&cli.DurationFlag{
 | 
				
			||||||
 | 
								Name:  "proxy-keepalive-timeout",
 | 
				
			||||||
 | 
								Usage: "HTTP proxy timeout for closing an idle connection",
 | 
				
			||||||
 | 
								Value: time.Second * 90,
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewBoolFlag(&cli.BoolFlag{
 | 
				
			||||||
 | 
								Name:    "proxy-dns",
 | 
				
			||||||
 | 
								Usage:   "Run a DNS over HTTPS proxy server.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_DNS"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewUintFlag(&cli.UintFlag{
 | 
				
			||||||
 | 
								Name:    "proxy-dns-port",
 | 
				
			||||||
 | 
								Value:   53,
 | 
				
			||||||
 | 
								Usage:   "Listen on given port for the DNS over HTTPS proxy server.",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_DNS_PORT"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringFlag(&cli.StringFlag{
 | 
				
			||||||
 | 
								Name:    "proxy-dns-address",
 | 
				
			||||||
 | 
								Usage:   "Listen address for the DNS over HTTPS proxy server.",
 | 
				
			||||||
 | 
								Value:   "localhost",
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
							altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
 | 
				
			||||||
 | 
								Name:    "proxy-dns-upstream",
 | 
				
			||||||
 | 
								Usage:   "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
 | 
				
			||||||
 | 
								Value:   cli.NewStringSlice("https://dns.cloudflare.com/.well-known/dns-query"),
 | 
				
			||||||
 | 
								EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						app.Action = func(c *cli.Context) error {
 | 
				
			||||||
 | 
							raven.CapturePanic(func() { startServer(c) }, nil)
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						app.Before = func(context *cli.Context) error {
 | 
				
			||||||
 | 
							Log = logrus.New()
 | 
				
			||||||
 | 
							inputSource, err := findInputSourceContext(context)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								Log.WithError(err).Infof("Cannot load configuration from %s", context.String("config"))
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							} else if inputSource != nil {
 | 
				
			||||||
 | 
								err := altsrc.ApplyInputSourceValues(context, inputSource, app.Flags)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									Log.WithError(err).Infof("Cannot apply configuration from %s", context.String("config"))
 | 
				
			||||||
 | 
									return err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								Log.Infof("Applied configuration from %s", context.String("config"))
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						app.Commands = []*cli.Command{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Name:      "update",
 | 
				
			||||||
 | 
								Action:    update,
 | 
				
			||||||
 | 
								Usage:     "Update the agent if a new version exists",
 | 
				
			||||||
 | 
								ArgsUsage: " ",
 | 
				
			||||||
 | 
								Description: `Looks for a new version on the offical download server.
 | 
				
			||||||
 | 
					   If a new version exists, updates the agent binary and quits.
 | 
				
			||||||
 | 
					   Otherwise, does nothing.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					   To determine if an update happened in a script, check for error code 64.`,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Name:      "login",
 | 
				
			||||||
 | 
								Action:    login,
 | 
				
			||||||
 | 
								Usage:     "Generate a configuration file with your login details",
 | 
				
			||||||
 | 
								ArgsUsage: " ",
 | 
				
			||||||
 | 
								Flags: []cli.Flag{
 | 
				
			||||||
 | 
									&cli.StringFlag{
 | 
				
			||||||
 | 
										Name:   "url",
 | 
				
			||||||
 | 
										Hidden: true,
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Name:   "hello",
 | 
				
			||||||
 | 
								Action: hello,
 | 
				
			||||||
 | 
								Usage:  "Run a simple \"Hello World\" server for testing Argo Tunnel.",
 | 
				
			||||||
 | 
								Flags: []cli.Flag{
 | 
				
			||||||
 | 
									&cli.IntFlag{
 | 
				
			||||||
 | 
										Name:  "port",
 | 
				
			||||||
 | 
										Usage: "Listen on the selected port.",
 | 
				
			||||||
 | 
										Value: 8080,
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								ArgsUsage: " ", // can't be the empty string or we get the default output
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								Name:   "proxy-dns",
 | 
				
			||||||
 | 
								Action: tunneldns.Run,
 | 
				
			||||||
 | 
								Usage:  "Run a DNS over HTTPS proxy server.",
 | 
				
			||||||
 | 
								Flags: []cli.Flag{
 | 
				
			||||||
 | 
									&cli.StringFlag{
 | 
				
			||||||
 | 
										Name:    "metrics",
 | 
				
			||||||
 | 
										Value:   "localhost:",
 | 
				
			||||||
 | 
										Usage:   "Listen address for metrics reporting.",
 | 
				
			||||||
 | 
										EnvVars: []string{"TUNNEL_METRICS"},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									&cli.StringFlag{
 | 
				
			||||||
 | 
										Name:    "address",
 | 
				
			||||||
 | 
										Usage:   "Listen address for the DNS over HTTPS proxy server.",
 | 
				
			||||||
 | 
										Value:   "localhost",
 | 
				
			||||||
 | 
										EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									&cli.IntFlag{
 | 
				
			||||||
 | 
										Name:    "port",
 | 
				
			||||||
 | 
										Usage:   "Listen on given port for the DNS over HTTPS proxy server.",
 | 
				
			||||||
 | 
										Value:   53,
 | 
				
			||||||
 | 
										EnvVars: []string{"TUNNEL_DNS_PORT"},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
									&cli.StringSliceFlag{
 | 
				
			||||||
 | 
										Name:    "upstream",
 | 
				
			||||||
 | 
										Usage:   "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
 | 
				
			||||||
 | 
										Value:   cli.NewStringSlice("https://dns.cloudflare.com/.well-known/dns-query"),
 | 
				
			||||||
 | 
										EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								ArgsUsage: " ", // can't be the empty string or we get the default output
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						runApp(app)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func startServer(c *cli.Context) {
 | 
				
			||||||
 | 
						var wg sync.WaitGroup
 | 
				
			||||||
 | 
						errC := make(chan error)
 | 
				
			||||||
 | 
						wg.Add(2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// If the user choose to supply all options through env variables,
 | 
				
			||||||
 | 
						// c.NumFlags() == 0 && c.NArg() == 0. For cloudflared to work, the user needs to at
 | 
				
			||||||
 | 
						// least provide a hostname.
 | 
				
			||||||
 | 
						if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" {
 | 
				
			||||||
 | 
							Log.Infof("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl)
 | 
				
			||||||
 | 
							cli.ShowAppHelp(c)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						logLevel, err := logrus.ParseLevel(c.String("loglevel"))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Fatal("Unknown logging level specified")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						Log.SetLevel(logLevel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Fatal("Unknown protocol logging level specified")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						protoLogger := logrus.New()
 | 
				
			||||||
 | 
						protoLogger.Level = protoLogLevel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if c.String("logfile") != "" {
 | 
				
			||||||
 | 
							if err := initLogFile(c, protoLogger); err != nil {
 | 
				
			||||||
 | 
								Log.Error(err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if isAutoupdateEnabled(c) {
 | 
				
			||||||
 | 
							if initUpdate() {
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							Log.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
 | 
				
			||||||
 | 
							go autoupdate(c.Duration("autoupdate-freq"), shutdownC)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						hostname, err := validation.ValidateHostname(c.String("hostname"))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Fatal("Invalid hostname")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						clientID := c.String("id")
 | 
				
			||||||
 | 
						if !c.IsSet("id") {
 | 
				
			||||||
 | 
							clientID = generateRandomClientID()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Fatal("Tag parse failure")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
 | 
				
			||||||
 | 
						if c.IsSet("hello-world") {
 | 
				
			||||||
 | 
							wg.Add(1)
 | 
				
			||||||
 | 
							listener, err := createListener("127.0.0.1:")
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								listener.Close()
 | 
				
			||||||
 | 
								Log.WithError(err).Fatal("Cannot start Hello World Server")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							go func() {
 | 
				
			||||||
 | 
								startHelloWorldServer(listener, shutdownC)
 | 
				
			||||||
 | 
								wg.Done()
 | 
				
			||||||
 | 
								listener.Close()
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
							c.Set("url", "https://"+listener.Addr().String())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if c.IsSet("proxy-dns") {
 | 
				
			||||||
 | 
							wg.Add(1)
 | 
				
			||||||
 | 
							listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(c.Uint("proxy-dns-port")), c.StringSlice("proxy-dns-upstream"))
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								listener.Stop()
 | 
				
			||||||
 | 
								Log.WithError(err).Fatal("Cannot start the DNS over HTTPS proxy server")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							go func() {
 | 
				
			||||||
 | 
								listener.Start()
 | 
				
			||||||
 | 
								<-shutdownC
 | 
				
			||||||
 | 
								listener.Stop()
 | 
				
			||||||
 | 
								wg.Done()
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						url, err := validateUrl(c)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Fatal("Error validating url")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						Log.Infof("Proxying tunnel requests to %s", url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Fail if the user provided an old authentication method
 | 
				
			||||||
 | 
						if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
 | 
				
			||||||
 | 
							Log.Fatal("You don't need to give us your api-key anymore. Please use the new log in method. Just run cloudflared login")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check that the user has acquired a certificate using the log in command
 | 
				
			||||||
 | 
						originCertPath, err := homedir.Expand(c.String("origincert"))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Fatalf("Cannot resolve path %s", c.String("origincert"))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						ok, err := fileExists(originCertPath)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.Fatalf("Cannot check if origin cert exists at path %s", c.String("origincert"))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							Log.Fatalf(`Cannot find a valid certificate for your origin at the path:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    %s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					If the path above is wrong, specify the path with the -origincert option.
 | 
				
			||||||
 | 
					If you don't have a certificate signed by Cloudflare, run the command:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    %s login
 | 
				
			||||||
 | 
					`, originCertPath, os.Args[0])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// Easier to send the certificate as []byte via RPC than decoding it at this point
 | 
				
			||||||
 | 
						originCert, err := ioutil.ReadFile(originCertPath)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Fatalf("Cannot read %s to load origin certificate", originCertPath)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tunnelMetrics := origin.NewTunnelMetrics()
 | 
				
			||||||
 | 
						httpTransport := &http.Transport{
 | 
				
			||||||
 | 
							Proxy: http.ProxyFromEnvironment,
 | 
				
			||||||
 | 
							DialContext: (&net.Dialer{
 | 
				
			||||||
 | 
								Timeout:   c.Duration("proxy-connect-timeout"),
 | 
				
			||||||
 | 
								KeepAlive: c.Duration("proxy-tcp-keepalive"),
 | 
				
			||||||
 | 
								DualStack: !c.Bool("proxy-no-happy-eyeballs"),
 | 
				
			||||||
 | 
							}).DialContext,
 | 
				
			||||||
 | 
							MaxIdleConns:          c.Int("proxy-keepalive-connections"),
 | 
				
			||||||
 | 
							IdleConnTimeout:       c.Duration("proxy-keepalive-timeout"),
 | 
				
			||||||
 | 
							TLSHandshakeTimeout:   c.Duration("proxy-tls-timeout"),
 | 
				
			||||||
 | 
							ExpectContinueTimeout: 1 * time.Second,
 | 
				
			||||||
 | 
							TLSClientConfig:       &tls.Config{RootCAs: tlsconfig.LoadOriginCertsPool()},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if !c.IsSet("hello-world") && c.IsSet("origin-server-name") {
 | 
				
			||||||
 | 
							httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tunnelConfig := &origin.TunnelConfig{
 | 
				
			||||||
 | 
							EdgeAddrs:         c.StringSlice("edge"),
 | 
				
			||||||
 | 
							OriginUrl:         url,
 | 
				
			||||||
 | 
							Hostname:          hostname,
 | 
				
			||||||
 | 
							OriginCert:        originCert,
 | 
				
			||||||
 | 
							TlsConfig:         tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
 | 
				
			||||||
 | 
							ClientTlsConfig:   httpTransport.TLSClientConfig,
 | 
				
			||||||
 | 
							Retries:           c.Uint("retries"),
 | 
				
			||||||
 | 
							HeartbeatInterval: c.Duration("heartbeat-interval"),
 | 
				
			||||||
 | 
							MaxHeartbeats:     c.Uint64("heartbeat-count"),
 | 
				
			||||||
 | 
							ClientID:          clientID,
 | 
				
			||||||
 | 
							ReportedVersion:   Version,
 | 
				
			||||||
 | 
							LBPool:            c.String("lb-pool"),
 | 
				
			||||||
 | 
							Tags:              tags,
 | 
				
			||||||
 | 
							HAConnections:     c.Int("ha-connections"),
 | 
				
			||||||
 | 
							HTTPTransport:     httpTransport,
 | 
				
			||||||
 | 
							Metrics:           tunnelMetrics,
 | 
				
			||||||
 | 
							MetricsUpdateFreq: c.Duration("metrics-update-freq"),
 | 
				
			||||||
 | 
							ProtocolLogger:    protoLogger,
 | 
				
			||||||
 | 
							Logger:            Log,
 | 
				
			||||||
 | 
							IsAutoupdated:     c.Bool("is-autoupdated"),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						connectedSignal := make(chan struct{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go writePidFile(connectedSignal, c.String("pidfile"))
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC, connectedSignal)
 | 
				
			||||||
 | 
							wg.Done()
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Fatal("Error opening metrics server listener")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							errC <- metrics.ServeMetrics(metricsListener, shutdownC)
 | 
				
			||||||
 | 
							wg.Done()
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var errCode int
 | 
				
			||||||
 | 
						err = WaitForSignal(errC, shutdownC)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Error("Quitting due to error")
 | 
				
			||||||
 | 
							raven.CaptureErrorAndWait(err, nil)
 | 
				
			||||||
 | 
							errCode = 1
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							Log.Info("Quitting...")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// Wait for clean exit, discarding all errors
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							for range errC {
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						wg.Wait()
 | 
				
			||||||
 | 
						os.Exit(errCode)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
 | 
				
			||||||
 | 
						signals := make(chan os.Signal, 10)
 | 
				
			||||||
 | 
						signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
 | 
				
			||||||
 | 
						defer signal.Stop(signals)
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case err := <-errC:
 | 
				
			||||||
 | 
							close(shutdownC)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						case <-signals:
 | 
				
			||||||
 | 
							close(shutdownC)
 | 
				
			||||||
 | 
						case <-shutdownC:
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func update(_ *cli.Context) error {
 | 
				
			||||||
 | 
						if updateApplied() {
 | 
				
			||||||
 | 
							os.Exit(64)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func initUpdate() bool {
 | 
				
			||||||
 | 
						if updateApplied() {
 | 
				
			||||||
 | 
							os.Args = append(os.Args, "--is-autoupdated=true")
 | 
				
			||||||
 | 
							if _, err := listeners.StartProcess(); err != nil {
 | 
				
			||||||
 | 
								Log.WithError(err).Error("Unable to restart server automatically")
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func autoupdate(freq time.Duration, shutdownC chan struct{}) {
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							if updateApplied() {
 | 
				
			||||||
 | 
								os.Args = append(os.Args, "--is-autoupdated=true")
 | 
				
			||||||
 | 
								if _, err := listeners.StartProcess(); err != nil {
 | 
				
			||||||
 | 
									Log.WithError(err).Error("Unable to restart server automatically")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								close(shutdownC)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							time.Sleep(freq)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func updateApplied() bool {
 | 
				
			||||||
 | 
						releaseInfo := checkForUpdates()
 | 
				
			||||||
 | 
						if releaseInfo.Updated {
 | 
				
			||||||
 | 
							Log.Infof("Updated to version %s", releaseInfo.Version)
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if releaseInfo.Error != nil {
 | 
				
			||||||
 | 
							Log.WithError(releaseInfo.Error).Error("Update check failed")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func fileExists(path string) (bool, error) {
 | 
				
			||||||
 | 
						f, err := os.Open(path)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							if os.IsNotExist(err) {
 | 
				
			||||||
 | 
								// ignore missing files
 | 
				
			||||||
 | 
								return false, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						f.Close()
 | 
				
			||||||
 | 
						return true, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs
 | 
				
			||||||
 | 
					// (differs by OS for legacy reasons) contains a cert.pem file, return empty string
 | 
				
			||||||
 | 
					func findDefaultOriginCertPath() string {
 | 
				
			||||||
 | 
						for _, defaultConfigDir := range defaultConfigDirs {
 | 
				
			||||||
 | 
							originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, credentialFile))
 | 
				
			||||||
 | 
							if ok, _ := fileExists(originCertPath); ok {
 | 
				
			||||||
 | 
								return originCertPath
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// returns the firt path that contains a config file. If none of the combination of
 | 
				
			||||||
 | 
					// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles
 | 
				
			||||||
 | 
					// contains a config file, return empty string
 | 
				
			||||||
 | 
					func findDefaultConfigPath() string {
 | 
				
			||||||
 | 
						for _, configDir := range defaultConfigDirs {
 | 
				
			||||||
 | 
							for _, configFile := range defaultConfigFiles {
 | 
				
			||||||
 | 
								dirPath, err := homedir.Expand(configDir)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return ""
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								path := filepath.Join(dirPath, configFile)
 | 
				
			||||||
 | 
								if ok, _ := fileExists(path); ok {
 | 
				
			||||||
 | 
									return path
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ""
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
 | 
				
			||||||
 | 
						if context.String("config") != "" {
 | 
				
			||||||
 | 
							return altsrc.NewYamlSourceFromFile(context.String("config"))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func generateRandomClientID() string {
 | 
				
			||||||
 | 
						r := rand.New(rand.NewSource(time.Now().UnixNano()))
 | 
				
			||||||
 | 
						id := make([]byte, 32)
 | 
				
			||||||
 | 
						r.Read(id)
 | 
				
			||||||
 | 
						return hex.EncodeToString(id)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func writePidFile(waitForSignal chan struct{}, pidFile string) {
 | 
				
			||||||
 | 
						<-waitForSignal
 | 
				
			||||||
 | 
						daemon.SdNotify(false, "READY=1")
 | 
				
			||||||
 | 
						if pidFile == "" {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						file, err := os.Create(pidFile)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer file.Close()
 | 
				
			||||||
 | 
						fmt.Fprintf(file, "%d", os.Getpid())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// validate url. It can be either from --url or argument
 | 
				
			||||||
 | 
					func validateUrl(c *cli.Context) (string, error) {
 | 
				
			||||||
 | 
						var url = c.String("url")
 | 
				
			||||||
 | 
						if c.NArg() > 0 {
 | 
				
			||||||
 | 
							if c.IsSet("url") {
 | 
				
			||||||
 | 
								return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							url = c.Args().Get(0)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						validUrl, err := validation.ValidateUrl(url)
 | 
				
			||||||
 | 
						return validUrl, err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func initLogFile(c *cli.Context, protoLogger *logrus.Logger) error {
 | 
				
			||||||
 | 
						filePath, err := homedir.Expand(c.String("logfile"))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errors.Wrap(err, "Cannot resolve logfile path")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC
 | 
				
			||||||
 | 
						// do not truncate log file if the client has been autoupdated
 | 
				
			||||||
 | 
						if c.Bool("is-autoupdated") {
 | 
				
			||||||
 | 
							fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						f, err := os.OpenFile(filePath, fileMode, 0664)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer f.Close()
 | 
				
			||||||
 | 
						pathMap := lfshook.PathMap{
 | 
				
			||||||
 | 
							logrus.InfoLevel:  filePath,
 | 
				
			||||||
 | 
							logrus.ErrorLevel: filePath,
 | 
				
			||||||
 | 
							logrus.FatalLevel: filePath,
 | 
				
			||||||
 | 
							logrus.PanicLevel: filePath,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						Log.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
 | 
				
			||||||
 | 
						protoLogger.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						flags := make(map[string]interface{})
 | 
				
			||||||
 | 
						envs := make(map[string]string)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, flag := range c.LocalFlagNames() {
 | 
				
			||||||
 | 
							flags[flag] = c.Generic(flag)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Find env variables for Argo Tunnel
 | 
				
			||||||
 | 
						for _, env := range os.Environ() {
 | 
				
			||||||
 | 
							// All Argo Tunnel env variables start with TUNNEL_
 | 
				
			||||||
 | 
							if strings.Contains(env, "TUNNEL_") {
 | 
				
			||||||
 | 
								vars := strings.Split(env, "=")
 | 
				
			||||||
 | 
								if len(vars) == 2 {
 | 
				
			||||||
 | 
									envs[vars[0]] = vars[1]
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						Log.Infof("Argo Tunnel build and runtime configuration: %+v", BuildAndRuntimeInfo{
 | 
				
			||||||
 | 
							GoOS:        runtime.GOOS,
 | 
				
			||||||
 | 
							GoVersion:   runtime.Version(),
 | 
				
			||||||
 | 
							GoArch:      runtime.GOARCH,
 | 
				
			||||||
 | 
							WarpVersion: Version,
 | 
				
			||||||
 | 
							WarpFlags:   flags,
 | 
				
			||||||
 | 
							WarpEnvs:    envs,
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func isAutoupdateEnabled(c *cli.Context) bool {
 | 
				
			||||||
 | 
						if terminal.IsTerminal(int(os.Stdout.Fd())) {
 | 
				
			||||||
 | 
							Log.Info(noAutoupdateMessage)
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,192 @@
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bufio"
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"os/exec"
 | 
				
			||||||
 | 
						"path/filepath"
 | 
				
			||||||
 | 
						"text/template"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/mitchellh/go-homedir"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ServiceTemplate struct {
 | 
				
			||||||
 | 
						Path     string
 | 
				
			||||||
 | 
						Content  string
 | 
				
			||||||
 | 
						FileMode os.FileMode
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ServiceTemplateArgs struct {
 | 
				
			||||||
 | 
						Path string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (st *ServiceTemplate) ResolvePath() (string, error) {
 | 
				
			||||||
 | 
						resolvedPath, err := homedir.Expand(st.Path)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("error resolving path %s: %v", st.Path, err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return resolvedPath, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
 | 
				
			||||||
 | 
						tmpl, err := template.New(st.Path).Parse(st.Content)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("error generating %s template: %v", st.Path, err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						resolvedPath, err := st.ResolvePath()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						var buffer bytes.Buffer
 | 
				
			||||||
 | 
						err = tmpl.Execute(&buffer, args)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("error generating %s: %v", st.Path, err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						fileMode := os.FileMode(0644)
 | 
				
			||||||
 | 
						if st.FileMode != 0 {
 | 
				
			||||||
 | 
							fileMode = st.FileMode
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = ioutil.WriteFile(resolvedPath, buffer.Bytes(), fileMode)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("error writing %s: %v", resolvedPath, err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (st *ServiceTemplate) Remove() error {
 | 
				
			||||||
 | 
						resolvedPath, err := st.ResolvePath()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = os.Remove(resolvedPath)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("error deleting %s: %v", resolvedPath, err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func runCommand(command string, args ...string) error {
 | 
				
			||||||
 | 
						cmd := exec.Command(command, args...)
 | 
				
			||||||
 | 
						stderr, err := cmd.StderrPipe()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("error getting stderr pipe")
 | 
				
			||||||
 | 
							return fmt.Errorf("error getting stderr pipe: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = cmd.Start()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("error starting %s", command)
 | 
				
			||||||
 | 
							return fmt.Errorf("error starting %s: %v", command, err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						commandErr, _ := ioutil.ReadAll(stderr)
 | 
				
			||||||
 | 
						if len(commandErr) > 0 {
 | 
				
			||||||
 | 
							Log.Errorf("%s: %s", command, commandErr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = cmd.Wait()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("%s returned error", command)
 | 
				
			||||||
 | 
							return fmt.Errorf("%s returned with error: %v", command, err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ensureConfigDirExists(configDir string) error {
 | 
				
			||||||
 | 
						ok, err := fileExists(configDir)
 | 
				
			||||||
 | 
						if !ok && err == nil {
 | 
				
			||||||
 | 
							err = os.Mkdir(configDir, 0700)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// openFile opens the file at path. If create is set and the file exists, returns nil, true, nil
 | 
				
			||||||
 | 
					func openFile(path string, create bool) (file *os.File, exists bool, err error) {
 | 
				
			||||||
 | 
						expandedPath, err := homedir.Expand(path)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if create {
 | 
				
			||||||
 | 
							fileInfo, err := os.Stat(expandedPath)
 | 
				
			||||||
 | 
							if err == nil && fileInfo.Size() > 0 {
 | 
				
			||||||
 | 
								return nil, true, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							file, err = os.OpenFile(expandedPath, os.O_RDWR|os.O_CREATE, 0600)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							file, err = os.Open(expandedPath)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return file, false, err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func copyCertificate(srcConfigDir, destConfigDir string) error {
 | 
				
			||||||
 | 
						destCredentialPath := filepath.Join(destConfigDir, credentialFile)
 | 
				
			||||||
 | 
						destFile, exists, err := openFile(destCredentialPath, true)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						} else if exists {
 | 
				
			||||||
 | 
							// credentials already exist, do nothing
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer destFile.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						srcCredentialPath := filepath.Join(srcConfigDir, credentialFile)
 | 
				
			||||||
 | 
						srcFile, _, err := openFile(srcCredentialPath, false)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer srcFile.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Copy certificate
 | 
				
			||||||
 | 
						_, err = io.Copy(destFile, srcFile)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return fmt.Errorf("unable to copy %s to %s: %v", srcCredentialPath, destCredentialPath, err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile string) error {
 | 
				
			||||||
 | 
						if err := ensureConfigDirExists(serviceConfigDir); err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := copyCertificate(defaultConfigDir, serviceConfigDir); err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Copy or create config
 | 
				
			||||||
 | 
						destConfigPath := filepath.Join(serviceConfigDir, defaultConfigFile)
 | 
				
			||||||
 | 
						destFile, exists, err := openFile(destConfigPath, true)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("cannot open %s", destConfigPath)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						} else if exists {
 | 
				
			||||||
 | 
							// config already exists, do nothing
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer destFile.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						srcConfigPath := filepath.Join(defaultConfigDir, defaultConfigFile)
 | 
				
			||||||
 | 
						srcFile, _, err := openFile(srcConfigPath, false)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							fmt.Println("Your service needs a config file that at least specifies the hostname option.")
 | 
				
			||||||
 | 
							fmt.Println("Type in a hostname now, or leave it blank and create the config file later.")
 | 
				
			||||||
 | 
							fmt.Print("Hostname: ")
 | 
				
			||||||
 | 
							reader := bufio.NewReader(os.Stdin)
 | 
				
			||||||
 | 
							input, _ := reader.ReadString('\n')
 | 
				
			||||||
 | 
							if input == "" {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							fmt.Fprintf(destFile, "hostname: %s\n", input)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							defer srcFile.Close()
 | 
				
			||||||
 | 
							_, err = io.Copy(destFile, srcFile)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return fmt.Errorf("unable to copy %s to %s: %v", srcConfigPath, destConfigPath, err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							Log.Infof("Copied %s to %s", srcConfigPath, destConfigPath)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,32 @@
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"regexp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Restrict key names to characters allowed in an HTTP header name.
 | 
				
			||||||
 | 
					// Restrict key values to printable characters (what is recognised as data in an HTTP header value).
 | 
				
			||||||
 | 
					var tagRegexp = regexp.MustCompile("^([a-zA-Z0-9!#$%&'*+\\-.^_`|~]+)=([[:print:]]+)$")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewTagFromCLI(compoundTag string) (tunnelpogs.Tag, bool) {
 | 
				
			||||||
 | 
						matches := tagRegexp.FindStringSubmatch(compoundTag)
 | 
				
			||||||
 | 
						if len(matches) == 0 {
 | 
				
			||||||
 | 
							return tunnelpogs.Tag{}, false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return tunnelpogs.Tag{Name: matches[1], Value: matches[2]}, true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewTagSliceFromCLI(tags []string) ([]tunnelpogs.Tag, error) {
 | 
				
			||||||
 | 
						var tagSlice []tunnelpogs.Tag
 | 
				
			||||||
 | 
						for _, compoundTag := range tags {
 | 
				
			||||||
 | 
							if tag, ok := NewTagFromCLI(compoundTag); ok {
 | 
				
			||||||
 | 
								tagSlice = append(tagSlice, tag)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return nil, fmt.Errorf("Cannot parse tag value %s", compoundTag)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return tagSlice, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,46 @@
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestSingleTag(t *testing.T) {
 | 
				
			||||||
 | 
						testCases := []struct {
 | 
				
			||||||
 | 
							Input  string
 | 
				
			||||||
 | 
							Output tunnelpogs.Tag
 | 
				
			||||||
 | 
							Fail   bool
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{Input: "x=y", Output: tunnelpogs.Tag{Name: "x", Value: "y"}},
 | 
				
			||||||
 | 
							{Input: "More-Complex=Tag Values", Output: tunnelpogs.Tag{Name: "More-Complex", Value: "Tag Values"}},
 | 
				
			||||||
 | 
							{Input: "First=Equals=Wins", Output: tunnelpogs.Tag{Name: "First", Value: "Equals=Wins"}},
 | 
				
			||||||
 | 
							{Input: "x=", Fail: true},
 | 
				
			||||||
 | 
							{Input: "=y", Fail: true},
 | 
				
			||||||
 | 
							{Input: "=", Fail: true},
 | 
				
			||||||
 | 
							{Input: "No spaces allowed=in key names", Fail: true},
 | 
				
			||||||
 | 
							{Input: "omg\nwtf=bbq", Fail: true},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for i, testCase := range testCases {
 | 
				
			||||||
 | 
							tag, ok := NewTagFromCLI(testCase.Input)
 | 
				
			||||||
 | 
							assert.Equalf(t, !testCase.Fail, ok, "mismatched success for test case %d", i)
 | 
				
			||||||
 | 
							assert.Equalf(t, testCase.Output, tag, "mismatched output for test case %d", i)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestTagSlice(t *testing.T) {
 | 
				
			||||||
 | 
						tagSlice, err := NewTagSliceFromCLI([]string{"a=b", "c=d", "e=f"})
 | 
				
			||||||
 | 
						assert.NoError(t, err)
 | 
				
			||||||
 | 
						assert.Len(t, tagSlice, 3)
 | 
				
			||||||
 | 
						assert.Equal(t, "a", tagSlice[0].Name)
 | 
				
			||||||
 | 
						assert.Equal(t, "b", tagSlice[0].Value)
 | 
				
			||||||
 | 
						assert.Equal(t, "c", tagSlice[1].Name)
 | 
				
			||||||
 | 
						assert.Equal(t, "d", tagSlice[1].Value)
 | 
				
			||||||
 | 
						assert.Equal(t, "e", tagSlice[2].Name)
 | 
				
			||||||
 | 
						assert.Equal(t, "f", tagSlice[2].Value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tagSlice, err = NewTagSliceFromCLI([]string{"a=b", "=", "e=f"})
 | 
				
			||||||
 | 
						assert.Error(t, err)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,41 @@
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "github.com/equinox-io/equinox"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const appID = "app_cwbQae3Tpea"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var publicKey = []byte(`
 | 
				
			||||||
 | 
					-----BEGIN ECDSA PUBLIC KEY-----
 | 
				
			||||||
 | 
					MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE4OWZocTVZ8Do/L6ScLdkV+9A0IYMHoOf
 | 
				
			||||||
 | 
					dsCmJ/QZ6aw0w9qkkwEpne1Lmo6+0pGexZzFZOH6w5amShn+RXt7qkSid9iWlzGq
 | 
				
			||||||
 | 
					EKx0BZogHSor9Wy5VztdFaAaVbsJiCbO
 | 
				
			||||||
 | 
					-----END ECDSA PUBLIC KEY-----
 | 
				
			||||||
 | 
					`)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type ReleaseInfo struct {
 | 
				
			||||||
 | 
						Updated bool
 | 
				
			||||||
 | 
						Version string
 | 
				
			||||||
 | 
						Error   error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func checkForUpdates() ReleaseInfo {
 | 
				
			||||||
 | 
						var opts equinox.Options
 | 
				
			||||||
 | 
						if err := opts.SetPublicKeyPEM(publicKey); err != nil {
 | 
				
			||||||
 | 
							return ReleaseInfo{Error: err}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp, err := equinox.Check(appID, opts)
 | 
				
			||||||
 | 
						switch {
 | 
				
			||||||
 | 
						case err == equinox.NotAvailableErr:
 | 
				
			||||||
 | 
							return ReleaseInfo{}
 | 
				
			||||||
 | 
						case err != nil:
 | 
				
			||||||
 | 
							return ReleaseInfo{Error: err}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err = resp.Apply()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return ReleaseInfo{Error: err}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,166 @@
 | 
				
			||||||
 | 
					// +build windows
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Copypasta from the example files:
 | 
				
			||||||
 | 
					// https://github.com/golang/sys/blob/master/windows/svc/example
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						cli "gopkg.in/urfave/cli.v2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"golang.org/x/sys/windows/svc"
 | 
				
			||||||
 | 
						"golang.org/x/sys/windows/svc/eventlog"
 | 
				
			||||||
 | 
						"golang.org/x/sys/windows/svc/mgr"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						windowsServiceName        = "Cloudflared"
 | 
				
			||||||
 | 
						windowsServiceDescription = "Argo Tunnel agent"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func runApp(app *cli.App) {
 | 
				
			||||||
 | 
						app.Commands = append(app.Commands, &cli.Command{
 | 
				
			||||||
 | 
							Name:  "service",
 | 
				
			||||||
 | 
							Usage: "Manages the Argo Tunnel Windows service",
 | 
				
			||||||
 | 
							Subcommands: []*cli.Command{
 | 
				
			||||||
 | 
								&cli.Command{
 | 
				
			||||||
 | 
									Name:   "install",
 | 
				
			||||||
 | 
									Usage:  "Install Argo Tunnel as a Windows service",
 | 
				
			||||||
 | 
									Action: installWindowsService,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								&cli.Command{
 | 
				
			||||||
 | 
									Name:   "uninstall",
 | 
				
			||||||
 | 
									Usage:  "Uninstall the Argo Tunnel service",
 | 
				
			||||||
 | 
									Action: uninstallWindowsService,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						isIntSess, err := svc.IsAnInteractiveSession()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.Fatalf("failed to determine if we are running in an interactive session: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if isIntSess {
 | 
				
			||||||
 | 
							app.Run(os.Args)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						elog, err := eventlog.Open(windowsServiceName)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("Cannot open event log for %s", windowsServiceName)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer elog.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName))
 | 
				
			||||||
 | 
						// Run executes service name by calling windowsService which is a Handler
 | 
				
			||||||
 | 
						// interface that implements Execute method
 | 
				
			||||||
 | 
						err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						elog.Info(1, fmt.Sprintf("%s service stopped", windowsServiceName))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type windowsService struct {
 | 
				
			||||||
 | 
						app  *cli.App
 | 
				
			||||||
 | 
						elog *eventlog.Log
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// called by the package code at the start of the service
 | 
				
			||||||
 | 
					func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) {
 | 
				
			||||||
 | 
						const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
 | 
				
			||||||
 | 
						changes <- svc.Status{State: svc.StartPending}
 | 
				
			||||||
 | 
						go s.app.Run(args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
 | 
				
			||||||
 | 
					loop:
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case c := <-r:
 | 
				
			||||||
 | 
								switch c.Cmd {
 | 
				
			||||||
 | 
								case svc.Interrogate:
 | 
				
			||||||
 | 
									s.elog.Info(1, fmt.Sprintf("control request 1 #%d", c))
 | 
				
			||||||
 | 
									changes <- c.CurrentStatus
 | 
				
			||||||
 | 
								case svc.Stop:
 | 
				
			||||||
 | 
									s.elog.Info(1, "received stop control request")
 | 
				
			||||||
 | 
									break loop
 | 
				
			||||||
 | 
								case svc.Shutdown:
 | 
				
			||||||
 | 
									s.elog.Info(1, "received shutdown control request")
 | 
				
			||||||
 | 
									break loop
 | 
				
			||||||
 | 
								default:
 | 
				
			||||||
 | 
									s.elog.Error(1, fmt.Sprintf("unexpected control request #%d", c))
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						close(shutdownC)
 | 
				
			||||||
 | 
						changes <- svc.Status{State: svc.StopPending}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func installWindowsService(c *cli.Context) error {
 | 
				
			||||||
 | 
						Log.Infof("Installing Argo Tunnel Windows service")
 | 
				
			||||||
 | 
						exepath, err := os.Executable()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.Infof("Cannot find path name that start the process")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						m, err := mgr.Connect()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("Cannot establish a connection to the service control manager")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer m.Disconnect()
 | 
				
			||||||
 | 
						s, err := m.OpenService(windowsServiceName)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							s.Close()
 | 
				
			||||||
 | 
							Log.Errorf("service %s already exists", windowsServiceName)
 | 
				
			||||||
 | 
							return fmt.Errorf("service %s already exists", windowsServiceName)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						config := mgr.Config{StartType: mgr.StartAutomatic, DisplayName: windowsServiceDescription}
 | 
				
			||||||
 | 
						s, err = m.CreateService(windowsServiceName, exepath, config)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.Infof("Cannot install service %s", windowsServiceName)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer s.Close()
 | 
				
			||||||
 | 
						err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							s.Delete()
 | 
				
			||||||
 | 
							Log.WithError(err).Infof("Cannot install event logger")
 | 
				
			||||||
 | 
							return fmt.Errorf("SetupEventLogSource() failed: %s", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func uninstallWindowsService(c *cli.Context) error {
 | 
				
			||||||
 | 
						Log.Infof("Uninstalling Argo Tunnel Windows Service")
 | 
				
			||||||
 | 
						m, err := mgr.Connect()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.Infof("Cannot establish a connection to the service control manager")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer m.Disconnect()
 | 
				
			||||||
 | 
						s, err := m.OpenService(windowsServiceName)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.Infof("service %s is not installed", windowsServiceName)
 | 
				
			||||||
 | 
							return fmt.Errorf("service %s is not installed", windowsServiceName)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer s.Close()
 | 
				
			||||||
 | 
						err = s.Delete()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.Errorf("Cannot delete service %s", windowsServiceName)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err = eventlog.Remove(windowsServiceName)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							Log.Infof("Cannot remove event logger")
 | 
				
			||||||
 | 
							return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -24,12 +24,6 @@ type activeStreamMap struct {
 | 
				
			||||||
	ignoreNewStreams bool
 | 
						ignoreNewStreams bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type FlowControlMetrics struct {
 | 
					 | 
				
			||||||
	AverageReceiveWindowSize, AverageSendWindowSize float64
 | 
					 | 
				
			||||||
	MinReceiveWindowSize, MaxReceiveWindowSize      uint32
 | 
					 | 
				
			||||||
	MinSendWindowSize, MaxSendWindowSize            uint32
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap {
 | 
					func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap {
 | 
				
			||||||
	m := &activeStreamMap{
 | 
						m := &activeStreamMap{
 | 
				
			||||||
		streams:      make(map[uint32]*MuxedStream),
 | 
							streams:      make(map[uint32]*MuxedStream),
 | 
				
			||||||
| 
						 | 
					@ -169,45 +163,3 @@ func (m *activeStreamMap) Abort() {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	m.ignoreNewStreams = true
 | 
						m.ignoreNewStreams = true
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
func (m *activeStreamMap) Metrics() *FlowControlMetrics {
 | 
					 | 
				
			||||||
	m.Lock()
 | 
					 | 
				
			||||||
	defer m.Unlock()
 | 
					 | 
				
			||||||
	var averageReceiveWindowSize, averageSendWindowSize float64
 | 
					 | 
				
			||||||
	var minReceiveWindowSize, maxReceiveWindowSize, minSendWindowSize, maxSendWindowSize uint32
 | 
					 | 
				
			||||||
	i := 0
 | 
					 | 
				
			||||||
	// The first variable in the range expression for map is the key, not index.
 | 
					 | 
				
			||||||
	for _, stream := range m.streams {
 | 
					 | 
				
			||||||
		// iterative mean: a(t+1) = a(t) + (a(t)-x)/(t+1)
 | 
					 | 
				
			||||||
		windows := stream.FlowControlWindow()
 | 
					 | 
				
			||||||
		averageReceiveWindowSize += (float64(windows.receiveWindow) - averageReceiveWindowSize) / float64(i+1)
 | 
					 | 
				
			||||||
		averageSendWindowSize += (float64(windows.sendWindow) - averageSendWindowSize) / float64(i+1)
 | 
					 | 
				
			||||||
		if i == 0 {
 | 
					 | 
				
			||||||
			maxReceiveWindowSize = windows.receiveWindow
 | 
					 | 
				
			||||||
			minReceiveWindowSize = windows.receiveWindow
 | 
					 | 
				
			||||||
			maxSendWindowSize = windows.sendWindow
 | 
					 | 
				
			||||||
			minSendWindowSize = windows.sendWindow
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			if windows.receiveWindow > maxReceiveWindowSize {
 | 
					 | 
				
			||||||
				maxReceiveWindowSize = windows.receiveWindow
 | 
					 | 
				
			||||||
			} else if windows.receiveWindow < minReceiveWindowSize {
 | 
					 | 
				
			||||||
				minReceiveWindowSize = windows.receiveWindow
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if windows.sendWindow > maxSendWindowSize {
 | 
					 | 
				
			||||||
				maxSendWindowSize = windows.sendWindow
 | 
					 | 
				
			||||||
			} else if windows.sendWindow < minSendWindowSize {
 | 
					 | 
				
			||||||
				minSendWindowSize = windows.sendWindow
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		i++
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return &FlowControlMetrics{
 | 
					 | 
				
			||||||
		MinReceiveWindowSize:     minReceiveWindowSize,
 | 
					 | 
				
			||||||
		MaxReceiveWindowSize:     maxReceiveWindowSize,
 | 
					 | 
				
			||||||
		AverageReceiveWindowSize: averageReceiveWindowSize,
 | 
					 | 
				
			||||||
		MinSendWindowSize:        minSendWindowSize,
 | 
					 | 
				
			||||||
		MaxSendWindowSize:        maxSendWindowSize,
 | 
					 | 
				
			||||||
		AverageSendWindowSize:    averageSendWindowSize,
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,18 @@
 | 
				
			||||||
 | 
					package h2mux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AtomicCounter struct {
 | 
				
			||||||
 | 
						count uint64
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *AtomicCounter) IncrementBy(number uint64) {
 | 
				
			||||||
 | 
						atomic.AddUint64(&c.count, number)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Get returns the current value of counter and reset it to 0
 | 
				
			||||||
 | 
					func (c *AtomicCounter) Count() uint64 {
 | 
				
			||||||
 | 
						return atomic.SwapUint64(&c.count, 0)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,23 @@
 | 
				
			||||||
 | 
					package h2mux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestCounter(t *testing.T) {
 | 
				
			||||||
 | 
						var wg sync.WaitGroup
 | 
				
			||||||
 | 
						wg.Add(dataPoints)
 | 
				
			||||||
 | 
						c := AtomicCounter{}
 | 
				
			||||||
 | 
						for i := 0; i < dataPoints; i++ {
 | 
				
			||||||
 | 
							go func() {
 | 
				
			||||||
 | 
								defer wg.Done()
 | 
				
			||||||
 | 
								c.IncrementBy(uint64(1))
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						wg.Wait()
 | 
				
			||||||
 | 
						assert.Equal(t, uint64(dataPoints), c.Count())
 | 
				
			||||||
 | 
						assert.Equal(t, uint64(0), c.Count())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -59,6 +59,8 @@ type Muxer struct {
 | 
				
			||||||
	muxReader *MuxReader
 | 
						muxReader *MuxReader
 | 
				
			||||||
	// muxWriter is the write process.
 | 
						// muxWriter is the write process.
 | 
				
			||||||
	muxWriter *MuxWriter
 | 
						muxWriter *MuxWriter
 | 
				
			||||||
 | 
						// muxMetricsUpdater is the process to update metrics
 | 
				
			||||||
 | 
						muxMetricsUpdater *muxMetricsUpdater
 | 
				
			||||||
	// newStreamChan is used to create new streams on the writer thread.
 | 
						// newStreamChan is used to create new streams on the writer thread.
 | 
				
			||||||
	// The writer will assign the next available stream ID.
 | 
						// The writer will assign the next available stream ID.
 | 
				
			||||||
	newStreamChan chan MuxedStreamRequest
 | 
						newStreamChan chan MuxedStreamRequest
 | 
				
			||||||
| 
						 | 
					@ -133,6 +135,11 @@ func Handshake(
 | 
				
			||||||
	// set up reader/writer pair ready for serve
 | 
						// set up reader/writer pair ready for serve
 | 
				
			||||||
	streamErrors := NewStreamErrorMap()
 | 
						streamErrors := NewStreamErrorMap()
 | 
				
			||||||
	goAwayChan := make(chan http2.ErrCode, 1)
 | 
						goAwayChan := make(chan http2.ErrCode, 1)
 | 
				
			||||||
 | 
						updateRTTChan := make(chan *roundTripMeasurement, 1)
 | 
				
			||||||
 | 
						updateReceiveWindowChan := make(chan uint32, 1)
 | 
				
			||||||
 | 
						updateSendWindowChan := make(chan uint32, 1)
 | 
				
			||||||
 | 
						updateInBoundBytesChan := make(chan uint64)
 | 
				
			||||||
 | 
						updateOutBoundBytesChan := make(chan uint64)
 | 
				
			||||||
	pingTimestamp := NewPingTimestamp()
 | 
						pingTimestamp := NewPingTimestamp()
 | 
				
			||||||
	connActive := NewSignal()
 | 
						connActive := NewSignal()
 | 
				
			||||||
	idleDuration := config.HeartbeatInterval
 | 
						idleDuration := config.HeartbeatInterval
 | 
				
			||||||
| 
						 | 
					@ -149,34 +156,48 @@ func Handshake(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	m.explicitShutdown = NewBooleanFuse()
 | 
						m.explicitShutdown = NewBooleanFuse()
 | 
				
			||||||
	m.muxReader = &MuxReader{
 | 
						m.muxReader = &MuxReader{
 | 
				
			||||||
		f:                   m.f,
 | 
							f:                       m.f,
 | 
				
			||||||
		handler:             m.config.Handler,
 | 
							handler:                 m.config.Handler,
 | 
				
			||||||
		streams:             m.streams,
 | 
							streams:                 m.streams,
 | 
				
			||||||
		readyList:           m.readyList,
 | 
							readyList:               m.readyList,
 | 
				
			||||||
		streamErrors:        streamErrors,
 | 
							streamErrors:            streamErrors,
 | 
				
			||||||
		goAwayChan:          goAwayChan,
 | 
							goAwayChan:              goAwayChan,
 | 
				
			||||||
		abortChan:           m.abortChan,
 | 
							abortChan:               m.abortChan,
 | 
				
			||||||
		pingTimestamp:       pingTimestamp,
 | 
							pingTimestamp:           pingTimestamp,
 | 
				
			||||||
		connActive:          connActive,
 | 
							connActive:              connActive,
 | 
				
			||||||
		initialStreamWindow: defaultWindowSize,
 | 
							initialStreamWindow:     defaultWindowSize,
 | 
				
			||||||
		streamWindowMax:     maxWindowSize,
 | 
							streamWindowMax:         maxWindowSize,
 | 
				
			||||||
		r:                   m.r,
 | 
							r:                       m.r,
 | 
				
			||||||
 | 
							updateRTTChan:           updateRTTChan,
 | 
				
			||||||
 | 
							updateReceiveWindowChan: updateReceiveWindowChan,
 | 
				
			||||||
 | 
							updateSendWindowChan:    updateSendWindowChan,
 | 
				
			||||||
 | 
							updateInBoundBytesChan:  updateInBoundBytesChan,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	m.muxWriter = &MuxWriter{
 | 
						m.muxWriter = &MuxWriter{
 | 
				
			||||||
		f:               m.f,
 | 
							f:                       m.f,
 | 
				
			||||||
		streams:         m.streams,
 | 
							streams:                 m.streams,
 | 
				
			||||||
		streamErrors:    streamErrors,
 | 
							streamErrors:            streamErrors,
 | 
				
			||||||
		readyStreamChan: m.readyList.ReadyChannel(),
 | 
							readyStreamChan:         m.readyList.ReadyChannel(),
 | 
				
			||||||
		newStreamChan:   m.newStreamChan,
 | 
							newStreamChan:           m.newStreamChan,
 | 
				
			||||||
		goAwayChan:      goAwayChan,
 | 
							goAwayChan:              goAwayChan,
 | 
				
			||||||
		abortChan:       m.abortChan,
 | 
							abortChan:               m.abortChan,
 | 
				
			||||||
		pingTimestamp:   pingTimestamp,
 | 
							pingTimestamp:           pingTimestamp,
 | 
				
			||||||
		idleTimer:       NewIdleTimer(idleDuration, maxRetries),
 | 
							idleTimer:               NewIdleTimer(idleDuration, maxRetries),
 | 
				
			||||||
		connActiveChan:  connActive.WaitChannel(),
 | 
							connActiveChan:          connActive.WaitChannel(),
 | 
				
			||||||
		maxFrameSize:    defaultFrameSize,
 | 
							maxFrameSize:            defaultFrameSize,
 | 
				
			||||||
 | 
							updateReceiveWindowChan: updateReceiveWindowChan,
 | 
				
			||||||
 | 
							updateSendWindowChan:    updateSendWindowChan,
 | 
				
			||||||
 | 
							updateOutBoundBytesChan: updateOutBoundBytesChan,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
 | 
						m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
 | 
				
			||||||
 | 
						m.muxMetricsUpdater = newMuxMetricsUpdater(
 | 
				
			||||||
 | 
							updateRTTChan,
 | 
				
			||||||
 | 
							updateReceiveWindowChan,
 | 
				
			||||||
 | 
							updateSendWindowChan,
 | 
				
			||||||
 | 
							updateInBoundBytesChan,
 | 
				
			||||||
 | 
							updateOutBoundBytesChan,
 | 
				
			||||||
 | 
							m.abortChan,
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
	return m, nil
 | 
						return m, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -246,9 +267,13 @@ func (m *Muxer) Serve() error {
 | 
				
			||||||
		m.w.Close()
 | 
							m.w.Close()
 | 
				
			||||||
		m.abort()
 | 
							m.abort()
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							errChan <- m.muxMetricsUpdater.run(logger)
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
	err := <-errChan
 | 
						err := <-errChan
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		// discard error as other handler closes
 | 
							// discard errors as other handler and muxMetricsUpdater close
 | 
				
			||||||
 | 
							<-errChan
 | 
				
			||||||
		<-errChan
 | 
							<-errChan
 | 
				
			||||||
		close(errChan)
 | 
							close(errChan)
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
| 
						 | 
					@ -318,14 +343,8 @@ func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, erro
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Return the estimated round-trip time.
 | 
					func (m *Muxer) Metrics() *MuxerMetrics {
 | 
				
			||||||
func (m *Muxer) RTT() RTTMeasurement {
 | 
						return m.muxMetricsUpdater.Metrics()
 | 
				
			||||||
	return m.muxReader.RTT()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Return min/max/average of send/receive window for all streams on this connection
 | 
					 | 
				
			||||||
func (m *Muxer) FlowControlMetrics() *FlowControlMetrics {
 | 
					 | 
				
			||||||
	return m.muxReader.FlowControlMetrics()
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m *Muxer) abort() {
 | 
					func (m *Muxer) abort() {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -14,6 +14,7 @@ import (
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log "github.com/sirupsen/logrus"
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestMain(m *testing.M) {
 | 
					func TestMain(m *testing.M) {
 | 
				
			||||||
| 
						 | 
					@ -134,9 +135,12 @@ func TestSingleStream(t *testing.T) {
 | 
				
			||||||
		if stream.Headers[0].Value != "headerValue" {
 | 
							if stream.Headers[0].Value != "headerValue" {
 | 
				
			||||||
			t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
 | 
								t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		stream.WriteHeaders([]Header{
 | 
							headers := []Header{
 | 
				
			||||||
			Header{Name: "response-header", Value: "responseValue"},
 | 
								Header{Name: "response-header", Value: "responseValue"},
 | 
				
			||||||
		})
 | 
							}
 | 
				
			||||||
 | 
							stream.WriteHeaders(headers)
 | 
				
			||||||
 | 
							assert.Equal(t, headers, stream.writeHeaders)
 | 
				
			||||||
 | 
							assert.False(t, stream.headersSent)
 | 
				
			||||||
		buf := []byte("Hello world")
 | 
							buf := []byte("Hello world")
 | 
				
			||||||
		stream.Write(buf)
 | 
							stream.Write(buf)
 | 
				
			||||||
		// after this receive, the edge closed the stream
 | 
							// after this receive, the edge closed the stream
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,6 +19,7 @@ type MuxedStream struct {
 | 
				
			||||||
	receiveWindowCurrentMax uint32
 | 
						receiveWindowCurrentMax uint32
 | 
				
			||||||
	// limit set in http2 spec. 2^31-1
 | 
						// limit set in http2 spec. 2^31-1
 | 
				
			||||||
	receiveWindowMax uint32
 | 
						receiveWindowMax uint32
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// nonzero if a WINDOW_UPDATE frame for a stream needs to be sent
 | 
						// nonzero if a WINDOW_UPDATE frame for a stream needs to be sent
 | 
				
			||||||
	windowUpdate uint32
 | 
						windowUpdate uint32
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -39,10 +40,6 @@ type MuxedStream struct {
 | 
				
			||||||
	receivedEOF bool
 | 
						receivedEOF bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type flowControlWindow struct {
 | 
					 | 
				
			||||||
	receiveWindow, sendWindow uint32
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s *MuxedStream) Read(p []byte) (n int, err error) {
 | 
					func (s *MuxedStream) Read(p []byte) (n int, err error) {
 | 
				
			||||||
	return s.readBuffer.Read(p)
 | 
						return s.readBuffer.Read(p)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -101,17 +98,21 @@ func (s *MuxedStream) WriteHeaders(headers []Header) error {
 | 
				
			||||||
		return ErrStreamHeadersSent
 | 
							return ErrStreamHeadersSent
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	s.writeHeaders = headers
 | 
						s.writeHeaders = headers
 | 
				
			||||||
 | 
						s.headersSent = false
 | 
				
			||||||
	s.writeNotify()
 | 
						s.writeNotify()
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *MuxedStream) FlowControlWindow() *flowControlWindow {
 | 
					func (s *MuxedStream) getReceiveWindow() uint32 {
 | 
				
			||||||
	s.writeLock.Lock()
 | 
						s.writeLock.Lock()
 | 
				
			||||||
	defer s.writeLock.Unlock()
 | 
						defer s.writeLock.Unlock()
 | 
				
			||||||
	return &flowControlWindow{
 | 
						return s.receiveWindow
 | 
				
			||||||
		receiveWindow: s.receiveWindow,
 | 
					}
 | 
				
			||||||
		sendWindow:    s.sendWindow,
 | 
					
 | 
				
			||||||
	}
 | 
					func (s *MuxedStream) getSendWindow() uint32 {
 | 
				
			||||||
 | 
						s.writeLock.Lock()
 | 
				
			||||||
 | 
						defer s.writeLock.Unlock()
 | 
				
			||||||
 | 
						return s.sendWindow
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// writeNotify must happen while holding writeLock.
 | 
					// writeNotify must happen while holding writeLock.
 | 
				
			||||||
| 
						 | 
					@ -209,9 +210,7 @@ func (s *MuxedStream) getChunk() *streamChunk {
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Copies at most s.sendWindow bytes
 | 
						// Copies at most s.sendWindow bytes
 | 
				
			||||||
	//log.Infof("writeBuffer len %d stream %d", s.writeBuffer.Len(), s.streamID)
 | 
					 | 
				
			||||||
	writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow))
 | 
						writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow))
 | 
				
			||||||
	//log.Infof("writeLen %d stream %d", writeLen, s.streamID)
 | 
					 | 
				
			||||||
	s.sendWindow -= uint32(writeLen)
 | 
						s.sendWindow -= uint32(writeLen)
 | 
				
			||||||
	s.receiveWindow += s.windowUpdate
 | 
						s.receiveWindow += s.windowUpdate
 | 
				
			||||||
	s.windowUpdate = 0
 | 
						s.windowUpdate = 0
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,232 @@
 | 
				
			||||||
 | 
					package h2mux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/golang-collections/collections/queue"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// data points used to compute average receive window and send window size
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						// data points used to compute average receive window and send window size
 | 
				
			||||||
 | 
						dataPoints = 100
 | 
				
			||||||
 | 
						// updateFreq is set to 1 sec so we can get inbound & outbound byes/sec
 | 
				
			||||||
 | 
						updateFreq = time.Second
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type muxMetricsUpdater struct {
 | 
				
			||||||
 | 
						// rttData keeps record of rtt, rttMin, rttMax and last measured time
 | 
				
			||||||
 | 
						rttData *rttData
 | 
				
			||||||
 | 
						// receiveWindowData keeps record of receive window measurement
 | 
				
			||||||
 | 
						receiveWindowData *flowControlData
 | 
				
			||||||
 | 
						// sendWindowData keeps record of send window measurement
 | 
				
			||||||
 | 
						sendWindowData *flowControlData
 | 
				
			||||||
 | 
						// inBoundRate is incoming bytes/sec
 | 
				
			||||||
 | 
						inBoundRate *rate
 | 
				
			||||||
 | 
						// outBoundRate is outgoing bytes/sec
 | 
				
			||||||
 | 
						outBoundRate *rate
 | 
				
			||||||
 | 
						// updateRTTChan is the channel to receive new RTT measurement from muxReader
 | 
				
			||||||
 | 
						updateRTTChan <-chan *roundTripMeasurement
 | 
				
			||||||
 | 
						//updateReceiveWindowChan is the channel to receive updated receiveWindow size from muxReader and muxWriter
 | 
				
			||||||
 | 
						updateReceiveWindowChan <-chan uint32
 | 
				
			||||||
 | 
						//updateSendWindowChan is the channel to receive updated sendWindow size from muxReader and muxWriter
 | 
				
			||||||
 | 
						updateSendWindowChan <-chan uint32
 | 
				
			||||||
 | 
						// updateInBoundBytesChan us the channel to receive bytesRead from muxReader
 | 
				
			||||||
 | 
						updateInBoundBytesChan <-chan uint64
 | 
				
			||||||
 | 
						// updateOutBoundBytesChan us the channel to receive bytesWrote from muxWriter
 | 
				
			||||||
 | 
						updateOutBoundBytesChan <-chan uint64
 | 
				
			||||||
 | 
						// shutdownC is to signal the muxerMetricsUpdater to shutdown
 | 
				
			||||||
 | 
						abortChan <-chan struct{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type MuxerMetrics struct {
 | 
				
			||||||
 | 
						RTT, RTTMin, RTTMax                                              time.Duration
 | 
				
			||||||
 | 
						ReceiveWindowAve, SendWindowAve                                  float64
 | 
				
			||||||
 | 
						ReceiveWindowMin, ReceiveWindowMax, SendWindowMin, SendWindowMax uint32
 | 
				
			||||||
 | 
						InBoundRateCurr, InBoundRateMin, InBoundRateMax                  uint64
 | 
				
			||||||
 | 
						OutBoundRateCurr, OutBoundRateMin, OutBoundRateMax               uint64
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type roundTripMeasurement struct {
 | 
				
			||||||
 | 
						receiveTime, sendTime time.Time
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type rttData struct {
 | 
				
			||||||
 | 
						rtt, rttMin, rttMax time.Duration
 | 
				
			||||||
 | 
						lastMeasurementTime time.Time
 | 
				
			||||||
 | 
						lock                sync.RWMutex
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type flowControlData struct {
 | 
				
			||||||
 | 
						sum      uint64
 | 
				
			||||||
 | 
						min, max uint32
 | 
				
			||||||
 | 
						queue    *queue.Queue
 | 
				
			||||||
 | 
						lock     sync.RWMutex
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type rate struct {
 | 
				
			||||||
 | 
						curr uint64
 | 
				
			||||||
 | 
						min, max uint64
 | 
				
			||||||
 | 
						lock sync.RWMutex
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newMuxMetricsUpdater(
 | 
				
			||||||
 | 
						updateRTTChan <-chan *roundTripMeasurement,
 | 
				
			||||||
 | 
						updateReceiveWindowChan <-chan uint32,
 | 
				
			||||||
 | 
						updateSendWindowChan <-chan uint32,
 | 
				
			||||||
 | 
						updateInBoundBytesChan <-chan uint64,
 | 
				
			||||||
 | 
						updateOutBoundBytesChan <-chan uint64,
 | 
				
			||||||
 | 
						abortChan <-chan struct{},
 | 
				
			||||||
 | 
					) *muxMetricsUpdater {
 | 
				
			||||||
 | 
						return &muxMetricsUpdater{
 | 
				
			||||||
 | 
							rttData:                 newRTTData(),
 | 
				
			||||||
 | 
							receiveWindowData:       newFlowControlData(),
 | 
				
			||||||
 | 
							sendWindowData:          newFlowControlData(),
 | 
				
			||||||
 | 
							inBoundRate:             newRate(),
 | 
				
			||||||
 | 
							outBoundRate:            newRate(),
 | 
				
			||||||
 | 
							updateRTTChan:           updateRTTChan,
 | 
				
			||||||
 | 
							updateReceiveWindowChan: updateReceiveWindowChan,
 | 
				
			||||||
 | 
							updateSendWindowChan:    updateSendWindowChan,
 | 
				
			||||||
 | 
							updateInBoundBytesChan:  updateInBoundBytesChan,
 | 
				
			||||||
 | 
							updateOutBoundBytesChan: updateOutBoundBytesChan,
 | 
				
			||||||
 | 
							abortChan:               abortChan,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (updater *muxMetricsUpdater) Metrics() *MuxerMetrics {
 | 
				
			||||||
 | 
						m := &MuxerMetrics{}
 | 
				
			||||||
 | 
						m.RTT, m.RTTMin, m.RTTMax = updater.rttData.metrics()
 | 
				
			||||||
 | 
						m.ReceiveWindowAve, m.ReceiveWindowMin, m.ReceiveWindowMax = updater.receiveWindowData.metrics()
 | 
				
			||||||
 | 
						m.SendWindowAve, m.SendWindowMin, m.SendWindowMax = updater.sendWindowData.metrics()
 | 
				
			||||||
 | 
						m.InBoundRateCurr, m.InBoundRateMin, m.InBoundRateMax = updater.inBoundRate.get()
 | 
				
			||||||
 | 
						m.OutBoundRateCurr, m.OutBoundRateMin, m.OutBoundRateMax = updater.outBoundRate.get()
 | 
				
			||||||
 | 
						return m
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (updater *muxMetricsUpdater) run(parentLogger *log.Entry) error {
 | 
				
			||||||
 | 
						logger := parentLogger.WithFields(log.Fields{
 | 
				
			||||||
 | 
							"subsystem": "mux",
 | 
				
			||||||
 | 
							"dir":       "metrics",
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
						defer logger.Debug("event loop finished")
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case <-updater.abortChan:
 | 
				
			||||||
 | 
								logger.Infof("Stopping mux metrics updater")
 | 
				
			||||||
 | 
								return nil
 | 
				
			||||||
 | 
							case roundTripMeasurement := <-updater.updateRTTChan:
 | 
				
			||||||
 | 
								go updater.rttData.update(roundTripMeasurement)
 | 
				
			||||||
 | 
								logger.Debug("Update rtt")
 | 
				
			||||||
 | 
							case receiveWindow := <-updater.updateReceiveWindowChan:
 | 
				
			||||||
 | 
								go updater.receiveWindowData.update(receiveWindow)
 | 
				
			||||||
 | 
								logger.Debug("Update receive window")
 | 
				
			||||||
 | 
							case sendWindow := <-updater.updateSendWindowChan:
 | 
				
			||||||
 | 
								go updater.sendWindowData.update(sendWindow)
 | 
				
			||||||
 | 
								logger.Debug("Update send window")
 | 
				
			||||||
 | 
							case inBoundBytes := <-updater.updateInBoundBytesChan:
 | 
				
			||||||
 | 
								// inBoundBytes is bytes/sec because the update interval is 1 sec
 | 
				
			||||||
 | 
								go updater.inBoundRate.update(inBoundBytes)
 | 
				
			||||||
 | 
								logger.Debugf("Inbound bytes %d", inBoundBytes)
 | 
				
			||||||
 | 
							case outBoundBytes := <-updater.updateOutBoundBytesChan:
 | 
				
			||||||
 | 
								// outBoundBytes is bytes/sec because the update interval is 1 sec
 | 
				
			||||||
 | 
								go updater.outBoundRate.update(outBoundBytes)
 | 
				
			||||||
 | 
								logger.Debugf("Outbound bytes %d", outBoundBytes)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newRTTData() *rttData {
 | 
				
			||||||
 | 
						return &rttData{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *rttData) update(measurement *roundTripMeasurement) {
 | 
				
			||||||
 | 
						r.lock.Lock()
 | 
				
			||||||
 | 
						defer r.lock.Unlock()
 | 
				
			||||||
 | 
						// discard pings before lastMeasurementTime
 | 
				
			||||||
 | 
						if r.lastMeasurementTime.After(measurement.sendTime) {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						r.lastMeasurementTime = measurement.sendTime
 | 
				
			||||||
 | 
						r.rtt = measurement.receiveTime.Sub(measurement.sendTime)
 | 
				
			||||||
 | 
						if r.rttMax < r.rtt {
 | 
				
			||||||
 | 
							r.rttMax = r.rtt
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if r.rttMin == 0 || r.rttMin > r.rtt {
 | 
				
			||||||
 | 
							r.rttMin = r.rtt
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *rttData) metrics() (rtt, rttMin, rttMax time.Duration) {
 | 
				
			||||||
 | 
						r.lock.RLock()
 | 
				
			||||||
 | 
						defer r.lock.RUnlock()
 | 
				
			||||||
 | 
						return r.rtt, r.rttMin, r.rttMax
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newFlowControlData() *flowControlData {
 | 
				
			||||||
 | 
						return &flowControlData{queue: queue.New()}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (f *flowControlData) update(measurement uint32) {
 | 
				
			||||||
 | 
						f.lock.Lock()
 | 
				
			||||||
 | 
						defer f.lock.Unlock()
 | 
				
			||||||
 | 
						var firstItem uint32
 | 
				
			||||||
 | 
						// store new data into queue, remove oldest data if queue is full
 | 
				
			||||||
 | 
						f.queue.Enqueue(measurement)
 | 
				
			||||||
 | 
						if f.queue.Len() > dataPoints {
 | 
				
			||||||
 | 
							// data type should always be uint32
 | 
				
			||||||
 | 
							firstItem = f.queue.Dequeue().(uint32)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// if (measurement - firstItem) < 0, uint64(measurement - firstItem)
 | 
				
			||||||
 | 
						// will overflow and become a large positive number
 | 
				
			||||||
 | 
						f.sum += uint64(measurement)
 | 
				
			||||||
 | 
						f.sum -= uint64(firstItem)
 | 
				
			||||||
 | 
						if measurement > f.max {
 | 
				
			||||||
 | 
							f.max = measurement
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if f.min == 0 || measurement < f.min {
 | 
				
			||||||
 | 
							f.min = measurement
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// caller of ave() should acquire lock first
 | 
				
			||||||
 | 
					func (f *flowControlData) ave() float64 {
 | 
				
			||||||
 | 
						if f.queue.Len() == 0 {
 | 
				
			||||||
 | 
							return 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return float64(f.sum) / float64(f.queue.Len())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (f *flowControlData) metrics() (ave float64, min, max uint32) {
 | 
				
			||||||
 | 
						f.lock.RLock()
 | 
				
			||||||
 | 
						defer f.lock.RUnlock()
 | 
				
			||||||
 | 
						return f.ave(), f.min, f.max
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newRate() *rate {
 | 
				
			||||||
 | 
						return &rate{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *rate) update(measurement uint64) {
 | 
				
			||||||
 | 
						r.lock.Lock()
 | 
				
			||||||
 | 
						defer r.lock.Unlock()
 | 
				
			||||||
 | 
						r.curr = measurement
 | 
				
			||||||
 | 
						// if measurement is 0, then there is no incoming/outgoing connection, don't update min/max
 | 
				
			||||||
 | 
						if r.curr == 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if measurement > r.max {
 | 
				
			||||||
 | 
							r.max = measurement
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if r.min == 0 || measurement < r.min {
 | 
				
			||||||
 | 
							r.min = measurement
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *rate) get() (curr, min, max uint64) {
 | 
				
			||||||
 | 
						r.lock.RLock()
 | 
				
			||||||
 | 
						defer r.lock.RUnlock()
 | 
				
			||||||
 | 
						return r.curr, r.min, r.max
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,176 @@
 | 
				
			||||||
 | 
					package h2mux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func ave(sum uint64, len int) float64 {
 | 
				
			||||||
 | 
						return float64(sum) / float64(len)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRTTUpdate(t *testing.T) {
 | 
				
			||||||
 | 
						r := newRTTData()
 | 
				
			||||||
 | 
						start := time.Now()
 | 
				
			||||||
 | 
						// send at 0 ms, receive at 2 ms, RTT = 2ms
 | 
				
			||||||
 | 
						m := &roundTripMeasurement{receiveTime: start.Add(2 * time.Millisecond), sendTime: start}
 | 
				
			||||||
 | 
						r.update(m)
 | 
				
			||||||
 | 
						assert.Equal(t, start, r.lastMeasurementTime)
 | 
				
			||||||
 | 
						assert.Equal(t, 2*time.Millisecond, r.rtt)
 | 
				
			||||||
 | 
						assert.Equal(t, 2*time.Millisecond, r.rttMin)
 | 
				
			||||||
 | 
						assert.Equal(t, 2*time.Millisecond, r.rttMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// send at 3 ms, receive at 6 ms, RTT = 3ms
 | 
				
			||||||
 | 
						m = &roundTripMeasurement{receiveTime: start.Add(6 * time.Millisecond), sendTime: start.Add(3 * time.Millisecond)}
 | 
				
			||||||
 | 
						r.update(m)
 | 
				
			||||||
 | 
						assert.Equal(t, start.Add(3*time.Millisecond), r.lastMeasurementTime)
 | 
				
			||||||
 | 
						assert.Equal(t, 3*time.Millisecond, r.rtt)
 | 
				
			||||||
 | 
						assert.Equal(t, 2*time.Millisecond, r.rttMin)
 | 
				
			||||||
 | 
						assert.Equal(t, 3*time.Millisecond, r.rttMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// send at 7 ms, receive at 8 ms, RTT = 1ms
 | 
				
			||||||
 | 
						m = &roundTripMeasurement{receiveTime: start.Add(8 * time.Millisecond), sendTime: start.Add(7 * time.Millisecond)}
 | 
				
			||||||
 | 
						r.update(m)
 | 
				
			||||||
 | 
						assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime)
 | 
				
			||||||
 | 
						assert.Equal(t, 1*time.Millisecond, r.rtt)
 | 
				
			||||||
 | 
						assert.Equal(t, 1*time.Millisecond, r.rttMin)
 | 
				
			||||||
 | 
						assert.Equal(t, 3*time.Millisecond, r.rttMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// send at -4 ms, receive at 0 ms, RTT = 4ms, but this ping is before last measurement
 | 
				
			||||||
 | 
						// so it will be discarded
 | 
				
			||||||
 | 
						m = &roundTripMeasurement{receiveTime: start, sendTime: start.Add(-2 * time.Millisecond)}
 | 
				
			||||||
 | 
						r.update(m)
 | 
				
			||||||
 | 
						assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime)
 | 
				
			||||||
 | 
						assert.Equal(t, 1*time.Millisecond, r.rtt)
 | 
				
			||||||
 | 
						assert.Equal(t, 1*time.Millisecond, r.rttMin)
 | 
				
			||||||
 | 
						assert.Equal(t, 3*time.Millisecond, r.rttMax)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestFlowControlDataUpdate(t *testing.T) {
 | 
				
			||||||
 | 
						f := newFlowControlData()
 | 
				
			||||||
 | 
						assert.Equal(t, 0, f.queue.Len())
 | 
				
			||||||
 | 
						assert.Equal(t, float64(0), f.ave())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var sum uint64
 | 
				
			||||||
 | 
						min := maxWindowSize - dataPoints
 | 
				
			||||||
 | 
						max := maxWindowSize
 | 
				
			||||||
 | 
						for i := 1; i <= dataPoints; i++ {
 | 
				
			||||||
 | 
							size := maxWindowSize - uint32(i)
 | 
				
			||||||
 | 
							f.update(size)
 | 
				
			||||||
 | 
							assert.Equal(t, max - uint32(1), f.max)
 | 
				
			||||||
 | 
							assert.Equal(t, size, f.min)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							assert.Equal(t, i, f.queue.Len())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							sum += uint64(size)
 | 
				
			||||||
 | 
							assert.Equal(t, sum, f.sum)
 | 
				
			||||||
 | 
							assert.Equal(t, ave(sum, f.queue.Len()), f.ave())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// queue is full, should start to dequeue first element
 | 
				
			||||||
 | 
						for i := 1; i <= dataPoints; i++ {
 | 
				
			||||||
 | 
							f.update(max)
 | 
				
			||||||
 | 
							assert.Equal(t, max, f.max)
 | 
				
			||||||
 | 
							assert.Equal(t, min, f.min)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							assert.Equal(t, dataPoints, f.queue.Len())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							sum += uint64(i)
 | 
				
			||||||
 | 
							assert.Equal(t, sum, f.sum)
 | 
				
			||||||
 | 
							assert.Equal(t, ave(sum, dataPoints), f.ave())
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestMuxMetricsUpdater(t *testing.T) {
 | 
				
			||||||
 | 
						updateRTTChan := make(chan *roundTripMeasurement)
 | 
				
			||||||
 | 
						updateReceiveWindowChan := make(chan uint32)
 | 
				
			||||||
 | 
						updateSendWindowChan := make(chan uint32)
 | 
				
			||||||
 | 
						updateInBoundBytesChan := make(chan uint64)
 | 
				
			||||||
 | 
						updateOutBoundBytesChan := make(chan uint64)
 | 
				
			||||||
 | 
						abortChan := make(chan struct{})
 | 
				
			||||||
 | 
						errChan := make(chan error)
 | 
				
			||||||
 | 
						m := newMuxMetricsUpdater(updateRTTChan,
 | 
				
			||||||
 | 
							updateReceiveWindowChan,
 | 
				
			||||||
 | 
							updateSendWindowChan,
 | 
				
			||||||
 | 
							updateInBoundBytesChan,
 | 
				
			||||||
 | 
							updateOutBoundBytesChan,
 | 
				
			||||||
 | 
							abortChan,
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						logger := log.NewEntry(log.New())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							errChan <- m.run(logger)
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var wg sync.WaitGroup
 | 
				
			||||||
 | 
						wg.Add(2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// mock muxReader
 | 
				
			||||||
 | 
						readerStart := time.Now()
 | 
				
			||||||
 | 
						rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart}
 | 
				
			||||||
 | 
						updateRTTChan <- rm
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							defer wg.Done()
 | 
				
			||||||
 | 
							// Becareful if dataPoints is not divisibile by 4
 | 
				
			||||||
 | 
							readerSend := readerStart.Add(time.Millisecond)
 | 
				
			||||||
 | 
							for i := 1; i <= dataPoints/4; i++ {
 | 
				
			||||||
 | 
								readerReceive := readerSend.Add(time.Duration(i) * time.Millisecond)
 | 
				
			||||||
 | 
								rm := &roundTripMeasurement{receiveTime: readerReceive, sendTime: readerSend}
 | 
				
			||||||
 | 
								updateRTTChan <- rm
 | 
				
			||||||
 | 
								readerSend = readerReceive.Add(time.Millisecond)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								updateReceiveWindowChan <- uint32(i)
 | 
				
			||||||
 | 
								updateSendWindowChan <- uint32(i)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								updateInBoundBytesChan <- uint64(i)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// mock muxWriter
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							defer wg.Done()
 | 
				
			||||||
 | 
							for j := dataPoints/4 + 1; j <= dataPoints/2; j++ {
 | 
				
			||||||
 | 
								updateReceiveWindowChan <- uint32(j)
 | 
				
			||||||
 | 
								updateSendWindowChan <- uint32(j)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// should always be disgard since the send time is before readerSend
 | 
				
			||||||
 | 
								rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart.Add(-time.Duration(j*dataPoints) * time.Millisecond)}
 | 
				
			||||||
 | 
								updateRTTChan <- rm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								updateOutBoundBytesChan <- uint64(j)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
						wg.Wait()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						metrics := m.Metrics()
 | 
				
			||||||
 | 
						points := dataPoints / 2
 | 
				
			||||||
 | 
						assert.Equal(t, time.Millisecond, metrics.RTTMin)
 | 
				
			||||||
 | 
						assert.Equal(t, time.Duration(dataPoints/4)*time.Millisecond, metrics.RTTMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// sum(1..i) = i*(i+1)/2, ave(1..i) = i*(i+1)/2/i = (i+1)/2
 | 
				
			||||||
 | 
						assert.Equal(t, float64(points+1)/float64(2), metrics.ReceiveWindowAve)
 | 
				
			||||||
 | 
						assert.Equal(t, uint32(1), metrics.ReceiveWindowMin)
 | 
				
			||||||
 | 
						assert.Equal(t, uint32(points), metrics.ReceiveWindowMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.Equal(t, float64(points+1)/float64(2), metrics.SendWindowAve)
 | 
				
			||||||
 | 
						assert.Equal(t, uint32(1), metrics.SendWindowMin)
 | 
				
			||||||
 | 
						assert.Equal(t, uint32(points), metrics.SendWindowMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateCurr)
 | 
				
			||||||
 | 
						assert.Equal(t, uint64(1), metrics.InBoundRateMin)
 | 
				
			||||||
 | 
						assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateCurr)
 | 
				
			||||||
 | 
						assert.Equal(t, uint64(dataPoints/4+1), metrics.OutBoundRateMin)
 | 
				
			||||||
 | 
						assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						close(abortChan)
 | 
				
			||||||
 | 
						assert.Nil(t, <-errChan)
 | 
				
			||||||
 | 
						close(errChan)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -3,7 +3,6 @@ package h2mux
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"encoding/binary"
 | 
						"encoding/binary"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"sync"
 | 
					 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log "github.com/sirupsen/logrus"
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
| 
						 | 
					@ -34,14 +33,18 @@ type MuxReader struct {
 | 
				
			||||||
	initialStreamWindow uint32
 | 
						initialStreamWindow uint32
 | 
				
			||||||
	// The max value for the send window of a stream.
 | 
						// The max value for the send window of a stream.
 | 
				
			||||||
	streamWindowMax uint32
 | 
						streamWindowMax uint32
 | 
				
			||||||
	// windowMetrics keeps track of min/max/average of send/receive windows for all streams
 | 
					 | 
				
			||||||
	flowControlMetrics *FlowControlMetrics
 | 
					 | 
				
			||||||
	metricsMutex       sync.Mutex
 | 
					 | 
				
			||||||
	// r is a reference to the underlying connection used when shutting down.
 | 
						// r is a reference to the underlying connection used when shutting down.
 | 
				
			||||||
	r io.Closer
 | 
						r io.Closer
 | 
				
			||||||
	// rttMeasurement measures RTT based on ping timestamps.
 | 
						// updateRTTChan is the channel to send new RTT measurement to muxerMetricsUpdater
 | 
				
			||||||
	rttMeasurement RTTMeasurement
 | 
						updateRTTChan chan<- *roundTripMeasurement
 | 
				
			||||||
	rttMutex       sync.Mutex
 | 
						// updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
 | 
				
			||||||
 | 
						updateReceiveWindowChan chan<- uint32
 | 
				
			||||||
 | 
						// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
 | 
				
			||||||
 | 
						updateSendWindowChan chan<- uint32
 | 
				
			||||||
 | 
						// bytesRead is the amount of bytes read from data frame since the last time we send bytes read to metrics
 | 
				
			||||||
 | 
						bytesRead AtomicCounter
 | 
				
			||||||
 | 
						// updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
 | 
				
			||||||
 | 
						updateInBoundBytesChan chan<- uint64
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (r *MuxReader) Shutdown() {
 | 
					func (r *MuxReader) Shutdown() {
 | 
				
			||||||
| 
						 | 
					@ -57,28 +60,26 @@ func (r *MuxReader) Shutdown() {
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (r *MuxReader) RTT() RTTMeasurement {
 | 
					 | 
				
			||||||
	r.rttMutex.Lock()
 | 
					 | 
				
			||||||
	defer r.rttMutex.Unlock()
 | 
					 | 
				
			||||||
	return r.rttMeasurement
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (r *MuxReader) FlowControlMetrics() *FlowControlMetrics {
 | 
					 | 
				
			||||||
	r.metricsMutex.Lock()
 | 
					 | 
				
			||||||
	defer r.metricsMutex.Unlock()
 | 
					 | 
				
			||||||
	if r.flowControlMetrics != nil {
 | 
					 | 
				
			||||||
		return r.flowControlMetrics
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// No metrics available yet
 | 
					 | 
				
			||||||
	return &FlowControlMetrics{}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (r *MuxReader) run(parentLogger *log.Entry) error {
 | 
					func (r *MuxReader) run(parentLogger *log.Entry) error {
 | 
				
			||||||
	logger := parentLogger.WithFields(log.Fields{
 | 
						logger := parentLogger.WithFields(log.Fields{
 | 
				
			||||||
		"subsystem": "mux",
 | 
							"subsystem": "mux",
 | 
				
			||||||
		"dir":       "read",
 | 
							"dir":       "read",
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	defer logger.Debug("event loop finished")
 | 
						defer logger.Debug("event loop finished")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// routine to periodically update bytesRead
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							tickC := time.Tick(updateFreq)
 | 
				
			||||||
 | 
							for {
 | 
				
			||||||
 | 
								select {
 | 
				
			||||||
 | 
								case <-r.abortChan:
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								case <-tickC:
 | 
				
			||||||
 | 
									r.updateInBoundBytesChan <- r.bytesRead.Count()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
		frame, err := r.f.ReadFrame()
 | 
							frame, err := r.f.ReadFrame()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -120,6 +121,8 @@ func (r *MuxReader) run(parentLogger *log.Entry) error {
 | 
				
			||||||
			r.receivePingData(f)
 | 
								r.receivePingData(f)
 | 
				
			||||||
		case *http2.GoAwayFrame:
 | 
							case *http2.GoAwayFrame:
 | 
				
			||||||
			err = r.receiveGoAway(f)
 | 
								err = r.receiveGoAway(f)
 | 
				
			||||||
 | 
							// The receiver of a flow-controlled frame sends a WINDOW_UPDATE frame as it
 | 
				
			||||||
 | 
							// consumes data and frees up space in flow-control windows
 | 
				
			||||||
		case *http2.WindowUpdateFrame:
 | 
							case *http2.WindowUpdateFrame:
 | 
				
			||||||
			err = r.updateStreamWindow(f)
 | 
								err = r.updateStreamWindow(f)
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
| 
						 | 
					@ -236,10 +239,11 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.E
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	data := frame.Data()
 | 
						data := frame.Data()
 | 
				
			||||||
	if len(data) > 0 {
 | 
						if len(data) > 0 {
 | 
				
			||||||
		_, err = stream.readBuffer.Write(data)
 | 
							n, err := stream.readBuffer.Write(data)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return r.streamError(stream.streamID, http2.ErrCodeInternal)
 | 
								return r.streamError(stream.streamID, http2.ErrCodeInternal)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							r.bytesRead.IncrementBy(uint64(n))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if frame.Header().Flags.Has(http2.FlagDataEndStream) {
 | 
						if frame.Header().Flags.Has(http2.FlagDataEndStream) {
 | 
				
			||||||
		if stream.receiveEOF() {
 | 
							if stream.receiveEOF() {
 | 
				
			||||||
| 
						 | 
					@ -253,6 +257,7 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.E
 | 
				
			||||||
	if !stream.consumeReceiveWindow(uint32(len(data))) {
 | 
						if !stream.consumeReceiveWindow(uint32(len(data))) {
 | 
				
			||||||
		return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
 | 
							return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						r.updateReceiveWindowChan <- stream.getReceiveWindow()
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -263,10 +268,14 @@ func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
 | 
				
			||||||
		r.pingTimestamp.Set(ts)
 | 
							r.pingTimestamp.Set(ts)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	r.rttMutex.Lock()
 | 
					
 | 
				
			||||||
	r.rttMeasurement.Update(time.Unix(0, ts))
 | 
						// Update updates the computed values with a new measurement.
 | 
				
			||||||
	r.rttMutex.Unlock()
 | 
						// outgoingTime is the time that the probe was sent.
 | 
				
			||||||
	r.flowControlMetrics = r.streams.Metrics()
 | 
						// We assume that time.Now() is the time we received that probe.
 | 
				
			||||||
 | 
						r.updateRTTChan <- &roundTripMeasurement{
 | 
				
			||||||
 | 
							receiveTime: time.Now(),
 | 
				
			||||||
 | 
							sendTime:    time.Unix(0, ts),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Receive a GOAWAY from the peer. Gracefully shut down our connection.
 | 
					// Receive a GOAWAY from the peer. Gracefully shut down our connection.
 | 
				
			||||||
| 
						 | 
					@ -293,6 +302,7 @@ func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	stream.replenishSendWindow(frame.Increment)
 | 
						stream.replenishSendWindow(frame.Increment)
 | 
				
			||||||
 | 
						r.updateSendWindowChan <- stream.getSendWindow()
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -40,6 +40,14 @@ type MuxWriter struct {
 | 
				
			||||||
	headerEncoder *hpack.Encoder
 | 
						headerEncoder *hpack.Encoder
 | 
				
			||||||
	// headerBuffer is the temporary buffer used by headerEncoder.
 | 
						// headerBuffer is the temporary buffer used by headerEncoder.
 | 
				
			||||||
	headerBuffer bytes.Buffer
 | 
						headerBuffer bytes.Buffer
 | 
				
			||||||
 | 
						// updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
 | 
				
			||||||
 | 
						updateReceiveWindowChan chan<- uint32
 | 
				
			||||||
 | 
						// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
 | 
				
			||||||
 | 
						updateSendWindowChan chan<- uint32
 | 
				
			||||||
 | 
						// bytesWrote is the amount of bytes wrote to data frame since the last time we send bytes wrote to metrics
 | 
				
			||||||
 | 
						bytesWrote AtomicCounter
 | 
				
			||||||
 | 
						// updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
 | 
				
			||||||
 | 
						updateOutBoundBytesChan chan<- uint64
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type MuxedStreamRequest struct {
 | 
					type MuxedStreamRequest struct {
 | 
				
			||||||
| 
						 | 
					@ -64,6 +72,20 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error {
 | 
				
			||||||
		"dir":       "write",
 | 
							"dir":       "write",
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	defer logger.Debug("event loop finished")
 | 
						defer logger.Debug("event loop finished")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// routine to periodically communicate bytesWrote
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							tickC := time.Tick(updateFreq)
 | 
				
			||||||
 | 
							for {
 | 
				
			||||||
 | 
								select {
 | 
				
			||||||
 | 
								case <-w.abortChan:
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								case <-tickC:
 | 
				
			||||||
 | 
									w.updateOutBoundBytesChan <- w.bytesWrote.Count()
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
		select {
 | 
							select {
 | 
				
			||||||
		case <-w.abortChan:
 | 
							case <-w.abortChan:
 | 
				
			||||||
| 
						 | 
					@ -141,7 +163,8 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error {
 | 
				
			||||||
func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
 | 
					func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
 | 
				
			||||||
	logger.Debug("writable")
 | 
						logger.Debug("writable")
 | 
				
			||||||
	chunk := stream.getChunk()
 | 
						chunk := stream.getChunk()
 | 
				
			||||||
 | 
						w.updateReceiveWindowChan <- stream.getReceiveWindow()
 | 
				
			||||||
 | 
						w.updateSendWindowChan <- stream.getSendWindow()
 | 
				
			||||||
	if chunk.sendHeadersFrame() {
 | 
						if chunk.sendHeadersFrame() {
 | 
				
			||||||
		err := w.writeHeaders(chunk.streamID, chunk.headers)
 | 
							err := w.writeHeaders(chunk.streamID, chunk.headers)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -154,7 +177,9 @@ func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) erro
 | 
				
			||||||
	if chunk.sendWindowUpdateFrame() {
 | 
						if chunk.sendWindowUpdateFrame() {
 | 
				
			||||||
		// Send a WINDOW_UPDATE frame to update our receive window.
 | 
							// Send a WINDOW_UPDATE frame to update our receive window.
 | 
				
			||||||
		// If the Stream ID is zero, the window update applies to the connection as a whole
 | 
							// If the Stream ID is zero, the window update applies to the connection as a whole
 | 
				
			||||||
		// A WINDOW_UPDATE in a specific stream applies to the connection-level flow control as well.
 | 
							// RFC7540 section-6.9.1 "A receiver that receives a flow-controlled frame MUST
 | 
				
			||||||
 | 
							// always account for  its contribution against the connection flow-control
 | 
				
			||||||
 | 
							// window, unless the receiver treats this as a connection error"
 | 
				
			||||||
		err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
 | 
							err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			logger.WithError(err).Warn("error writing window update")
 | 
								logger.WithError(err).Warn("error writing window update")
 | 
				
			||||||
| 
						 | 
					@ -170,6 +195,8 @@ func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) erro
 | 
				
			||||||
			logger.WithError(err).Warn("error writing data")
 | 
								logger.WithError(err).Warn("error writing data")
 | 
				
			||||||
			return err
 | 
								return err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							// update the amount of data wrote
 | 
				
			||||||
 | 
							w.bytesWrote.IncrementBy(uint64(len(payload)))
 | 
				
			||||||
		logger.WithField("len", len(payload)).Debug("output data")
 | 
							logger.WithField("len", len(payload)).Debug("output data")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if sentEOF {
 | 
							if sentEOF {
 | 
				
			||||||
| 
						 | 
					@ -214,19 +241,16 @@ func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	blockSize := int(w.maxFrameSize)
 | 
						blockSize := int(w.maxFrameSize)
 | 
				
			||||||
	continuation := false
 | 
					 | 
				
			||||||
	endHeaders := len(encodedHeaders) == 0
 | 
						endHeaders := len(encodedHeaders) == 0
 | 
				
			||||||
	for !endHeaders && err == nil {
 | 
						for !endHeaders && err == nil {
 | 
				
			||||||
		blockFragment := encodedHeaders
 | 
							blockFragment := encodedHeaders
 | 
				
			||||||
		if len(encodedHeaders) > blockSize {
 | 
							if len(encodedHeaders) > blockSize {
 | 
				
			||||||
			blockFragment = blockFragment[:blockSize]
 | 
								blockFragment = blockFragment[:blockSize]
 | 
				
			||||||
			encodedHeaders = encodedHeaders[blockSize:]
 | 
								encodedHeaders = encodedHeaders[blockSize:]
 | 
				
			||||||
		} else {
 | 
								// Send CONTINUATION frame if the headers can't be fit into 1 frame
 | 
				
			||||||
			endHeaders = true
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		if continuation {
 | 
					 | 
				
			||||||
			err = w.f.WriteContinuation(streamID, endHeaders, blockFragment)
 | 
								err = w.f.WriteContinuation(streamID, endHeaders, blockFragment)
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
 | 
								endHeaders = true
 | 
				
			||||||
			err = w.f.WriteHeaders(http2.HeadersFrameParam{
 | 
								err = w.f.WriteHeaders(http2.HeadersFrameParam{
 | 
				
			||||||
				StreamID:      streamID,
 | 
									StreamID:      streamID,
 | 
				
			||||||
				EndHeaders:    endHeaders,
 | 
									EndHeaders:    endHeaders,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										24
									
								
								h2mux/rtt.go
								
								
								
								
							
							
						
						
									
										24
									
								
								h2mux/rtt.go
								
								
								
								
							| 
						 | 
					@ -2,7 +2,6 @@ package h2mux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"sync/atomic"
 | 
						"sync/atomic"
 | 
				
			||||||
	"time"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// PingTimestamp is an atomic interface around ping timestamping and signalling.
 | 
					// PingTimestamp is an atomic interface around ping timestamping and signalling.
 | 
				
			||||||
| 
						 | 
					@ -28,26 +27,3 @@ func (pt *PingTimestamp) Get() int64 {
 | 
				
			||||||
func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} {
 | 
					func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} {
 | 
				
			||||||
	return pt.signal.WaitChannel()
 | 
						return pt.signal.WaitChannel()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
// RTTMeasurement encapsulates a continuous round trip time measurement.
 | 
					 | 
				
			||||||
type RTTMeasurement struct {
 | 
					 | 
				
			||||||
	Current, Min, Max   time.Duration
 | 
					 | 
				
			||||||
	lastMeasurementTime time.Time
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Update updates the computed values with a new measurement.
 | 
					 | 
				
			||||||
// outgoingTime is the time that the probe was sent.
 | 
					 | 
				
			||||||
// We assume that time.Now() is the time we received that probe.
 | 
					 | 
				
			||||||
func (r *RTTMeasurement) Update(outgoingTime time.Time) {
 | 
					 | 
				
			||||||
	if !r.lastMeasurementTime.Before(outgoingTime) {
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	r.lastMeasurementTime = outgoingTime
 | 
					 | 
				
			||||||
	r.Current = time.Since(outgoingTime)
 | 
					 | 
				
			||||||
	if r.Max < r.Current {
 | 
					 | 
				
			||||||
		r.Max = r.Current
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if r.Min > r.Current {
 | 
					 | 
				
			||||||
		r.Min = r.Current
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,13 +2,32 @@ package origin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/cloudflare/cloudflare-warp/h2mux"
 | 
						"github.com/cloudflare/cloudflared/h2mux"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/prometheus/client_golang/prometheus"
 | 
						"github.com/prometheus/client_golang/prometheus"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type TunnelMetrics struct {
 | 
					type muxerMetrics struct {
 | 
				
			||||||
 | 
						rtt              *prometheus.GaugeVec
 | 
				
			||||||
 | 
						rttMin           *prometheus.GaugeVec
 | 
				
			||||||
 | 
						rttMax           *prometheus.GaugeVec
 | 
				
			||||||
 | 
						receiveWindowAve *prometheus.GaugeVec
 | 
				
			||||||
 | 
						sendWindowAve    *prometheus.GaugeVec
 | 
				
			||||||
 | 
						receiveWindowMin *prometheus.GaugeVec
 | 
				
			||||||
 | 
						receiveWindowMax *prometheus.GaugeVec
 | 
				
			||||||
 | 
						sendWindowMin    *prometheus.GaugeVec
 | 
				
			||||||
 | 
						sendWindowMax    *prometheus.GaugeVec
 | 
				
			||||||
 | 
						inBoundRateCurr  *prometheus.GaugeVec
 | 
				
			||||||
 | 
						inBoundRateMin   *prometheus.GaugeVec
 | 
				
			||||||
 | 
						inBoundRateMax   *prometheus.GaugeVec
 | 
				
			||||||
 | 
						outBoundRateCurr *prometheus.GaugeVec
 | 
				
			||||||
 | 
						outBoundRateMin  *prometheus.GaugeVec
 | 
				
			||||||
 | 
						outBoundRateMax  *prometheus.GaugeVec
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type tunnelMetrics struct {
 | 
				
			||||||
	haConnections     prometheus.Gauge
 | 
						haConnections     prometheus.Gauge
 | 
				
			||||||
	totalRequests     prometheus.Counter
 | 
						totalRequests     prometheus.Counter
 | 
				
			||||||
	requestsPerTunnel *prometheus.CounterVec
 | 
						requestsPerTunnel *prometheus.CounterVec
 | 
				
			||||||
| 
						 | 
					@ -20,16 +39,7 @@ type TunnelMetrics struct {
 | 
				
			||||||
	maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
 | 
						maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
 | 
				
			||||||
	// concurrentRequests records max count of concurrent requests for each tunnel
 | 
						// concurrentRequests records max count of concurrent requests for each tunnel
 | 
				
			||||||
	maxConcurrentRequests map[string]uint64
 | 
						maxConcurrentRequests map[string]uint64
 | 
				
			||||||
	rtt                   prometheus.Gauge
 | 
					 | 
				
			||||||
	rttMin                prometheus.Gauge
 | 
					 | 
				
			||||||
	rttMax                prometheus.Gauge
 | 
					 | 
				
			||||||
	timerRetries          prometheus.Gauge
 | 
						timerRetries          prometheus.Gauge
 | 
				
			||||||
	receiveWindowSizeAve  prometheus.Gauge
 | 
					 | 
				
			||||||
	sendWindowSizeAve     prometheus.Gauge
 | 
					 | 
				
			||||||
	receiveWindowSizeMin  prometheus.Gauge
 | 
					 | 
				
			||||||
	receiveWindowSizeMax  prometheus.Gauge
 | 
					 | 
				
			||||||
	sendWindowSizeMin     prometheus.Gauge
 | 
					 | 
				
			||||||
	sendWindowSizeMax     prometheus.Gauge
 | 
					 | 
				
			||||||
	responseByCode        *prometheus.CounterVec
 | 
						responseByCode        *prometheus.CounterVec
 | 
				
			||||||
	responseCodePerTunnel *prometheus.CounterVec
 | 
						responseCodePerTunnel *prometheus.CounterVec
 | 
				
			||||||
	serverLocations       *prometheus.GaugeVec
 | 
						serverLocations       *prometheus.GaugeVec
 | 
				
			||||||
| 
						 | 
					@ -37,10 +47,189 @@ type TunnelMetrics struct {
 | 
				
			||||||
	locationLock sync.Mutex
 | 
						locationLock sync.Mutex
 | 
				
			||||||
	// oldServerLocations stores the last server the tunnel was connected to
 | 
						// oldServerLocations stores the last server the tunnel was connected to
 | 
				
			||||||
	oldServerLocations map[string]string
 | 
						oldServerLocations map[string]string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						muxerMetrics *muxerMetrics
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newMuxerMetrics() *muxerMetrics {
 | 
				
			||||||
 | 
						rtt := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "rtt",
 | 
				
			||||||
 | 
								Help: "Round-trip time in millisecond",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(rtt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rttMin := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "rtt_min",
 | 
				
			||||||
 | 
								Help: "Shortest round-trip time in millisecond",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(rttMin)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rttMax := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "rtt_max",
 | 
				
			||||||
 | 
								Help: "Longest round-trip time in millisecond",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(rttMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						receiveWindowAve := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "receive_window_ave",
 | 
				
			||||||
 | 
								Help: "Average receive window size in bytes",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(receiveWindowAve)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sendWindowAve := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "send_window_ave",
 | 
				
			||||||
 | 
								Help: "Average send window size in bytes",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(sendWindowAve)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						receiveWindowMin := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "receive_window_min",
 | 
				
			||||||
 | 
								Help: "Smallest receive window size in bytes",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(receiveWindowMin)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						receiveWindowMax := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "receive_window_max",
 | 
				
			||||||
 | 
								Help: "Largest receive window size in bytes",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(receiveWindowMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sendWindowMin := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "send_window_min",
 | 
				
			||||||
 | 
								Help: "Smallest send window size in bytes",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(sendWindowMin)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						sendWindowMax := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "send_window_max",
 | 
				
			||||||
 | 
								Help: "Largest send window size in bytes",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(sendWindowMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						inBoundRateCurr := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "inbound_bytes_per_sec_curr",
 | 
				
			||||||
 | 
								Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(inBoundRateCurr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						inBoundRateMin := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "inbound_bytes_per_sec_min",
 | 
				
			||||||
 | 
								Help: "Minimum non-zero inbounding bytes per second",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(inBoundRateMin)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						inBoundRateMax := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "inbound_bytes_per_sec_max",
 | 
				
			||||||
 | 
								Help: "Maximum inbounding bytes per second",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(inBoundRateMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						outBoundRateCurr := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "outbound_bytes_per_sec_curr",
 | 
				
			||||||
 | 
								Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(outBoundRateCurr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						outBoundRateMin := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "outbound_bytes_per_sec_min",
 | 
				
			||||||
 | 
								Help: "Minimum non-zero outbounding bytes per second",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(outBoundRateMin)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						outBoundRateMax := prometheus.NewGaugeVec(
 | 
				
			||||||
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
 | 
								Name: "outbound_bytes_per_sec_max",
 | 
				
			||||||
 | 
								Help: "Maximum outbounding bytes per second",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							[]string{"connection_id"},
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
						prometheus.MustRegister(outBoundRateMax)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &muxerMetrics{
 | 
				
			||||||
 | 
							rtt:              rtt,
 | 
				
			||||||
 | 
							rttMin:           rttMin,
 | 
				
			||||||
 | 
							rttMax:           rttMax,
 | 
				
			||||||
 | 
							receiveWindowAve: receiveWindowAve,
 | 
				
			||||||
 | 
							sendWindowAve:    sendWindowAve,
 | 
				
			||||||
 | 
							receiveWindowMin: receiveWindowMin,
 | 
				
			||||||
 | 
							receiveWindowMax: receiveWindowMax,
 | 
				
			||||||
 | 
							sendWindowMin:    sendWindowMin,
 | 
				
			||||||
 | 
							sendWindowMax:    sendWindowMax,
 | 
				
			||||||
 | 
							inBoundRateCurr:  inBoundRateCurr,
 | 
				
			||||||
 | 
							inBoundRateMin:   inBoundRateMin,
 | 
				
			||||||
 | 
							inBoundRateMax:   inBoundRateMax,
 | 
				
			||||||
 | 
							outBoundRateCurr: outBoundRateCurr,
 | 
				
			||||||
 | 
							outBoundRateMin:  outBoundRateMin,
 | 
				
			||||||
 | 
							outBoundRateMax:  outBoundRateMax,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
 | 
				
			||||||
 | 
						m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
 | 
				
			||||||
 | 
						m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
 | 
				
			||||||
 | 
						m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
 | 
				
			||||||
 | 
						m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
 | 
				
			||||||
 | 
						m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
 | 
				
			||||||
 | 
						m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
 | 
				
			||||||
 | 
						m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
 | 
				
			||||||
 | 
						m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
 | 
				
			||||||
 | 
						m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
 | 
				
			||||||
 | 
						m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
 | 
				
			||||||
 | 
						m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
 | 
				
			||||||
 | 
						m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
 | 
				
			||||||
 | 
						m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
 | 
				
			||||||
 | 
						m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
 | 
				
			||||||
 | 
						m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func convertRTTMilliSec(t time.Duration) float64 {
 | 
				
			||||||
 | 
						return float64(t / time.Millisecond)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Metrics that can be collected without asking the edge
 | 
					// Metrics that can be collected without asking the edge
 | 
				
			||||||
func NewTunnelMetrics() *TunnelMetrics {
 | 
					func NewTunnelMetrics() *tunnelMetrics {
 | 
				
			||||||
	haConnections := prometheus.NewGauge(
 | 
						haConnections := prometheus.NewGauge(
 | 
				
			||||||
		prometheus.GaugeOpts{
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
			Name: "ha_connections",
 | 
								Name: "ha_connections",
 | 
				
			||||||
| 
						 | 
					@ -82,27 +271,6 @@ func NewTunnelMetrics() *TunnelMetrics {
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
 | 
						prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rtt := prometheus.NewGauge(
 | 
					 | 
				
			||||||
		prometheus.GaugeOpts{
 | 
					 | 
				
			||||||
			Name: "rtt",
 | 
					 | 
				
			||||||
			Help: "Round-trip time",
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	prometheus.MustRegister(rtt)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	rttMin := prometheus.NewGauge(
 | 
					 | 
				
			||||||
		prometheus.GaugeOpts{
 | 
					 | 
				
			||||||
			Name: "rtt_min",
 | 
					 | 
				
			||||||
			Help: "Shortest round-trip time",
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	prometheus.MustRegister(rttMin)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	rttMax := prometheus.NewGauge(
 | 
					 | 
				
			||||||
		prometheus.GaugeOpts{
 | 
					 | 
				
			||||||
			Name: "rtt_max",
 | 
					 | 
				
			||||||
			Help: "Longest round-trip time",
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	prometheus.MustRegister(rttMax)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	timerRetries := prometheus.NewGauge(
 | 
						timerRetries := prometheus.NewGauge(
 | 
				
			||||||
		prometheus.GaugeOpts{
 | 
							prometheus.GaugeOpts{
 | 
				
			||||||
			Name: "timer_retries",
 | 
								Name: "timer_retries",
 | 
				
			||||||
| 
						 | 
					@ -110,48 +278,6 @@ func NewTunnelMetrics() *TunnelMetrics {
 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	prometheus.MustRegister(timerRetries)
 | 
						prometheus.MustRegister(timerRetries)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	receiveWindowSizeAve := prometheus.NewGauge(
 | 
					 | 
				
			||||||
		prometheus.GaugeOpts{
 | 
					 | 
				
			||||||
			Name: "receive_window_ave",
 | 
					 | 
				
			||||||
			Help: "Average receive window size",
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	prometheus.MustRegister(receiveWindowSizeAve)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sendWindowSizeAve := prometheus.NewGauge(
 | 
					 | 
				
			||||||
		prometheus.GaugeOpts{
 | 
					 | 
				
			||||||
			Name: "send_window_ave",
 | 
					 | 
				
			||||||
			Help: "Average send window size",
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	prometheus.MustRegister(sendWindowSizeAve)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	receiveWindowSizeMin := prometheus.NewGauge(
 | 
					 | 
				
			||||||
		prometheus.GaugeOpts{
 | 
					 | 
				
			||||||
			Name: "receive_window_min",
 | 
					 | 
				
			||||||
			Help: "Smallest receive window size",
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	prometheus.MustRegister(receiveWindowSizeMin)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	receiveWindowSizeMax := prometheus.NewGauge(
 | 
					 | 
				
			||||||
		prometheus.GaugeOpts{
 | 
					 | 
				
			||||||
			Name: "receive_window_max",
 | 
					 | 
				
			||||||
			Help: "Largest receive window size",
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	prometheus.MustRegister(receiveWindowSizeMax)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sendWindowSizeMin := prometheus.NewGauge(
 | 
					 | 
				
			||||||
		prometheus.GaugeOpts{
 | 
					 | 
				
			||||||
			Name: "send_window_min",
 | 
					 | 
				
			||||||
			Help: "Smallest send window size",
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	prometheus.MustRegister(sendWindowSizeMin)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sendWindowSizeMax := prometheus.NewGauge(
 | 
					 | 
				
			||||||
		prometheus.GaugeOpts{
 | 
					 | 
				
			||||||
			Name: "send_window_max",
 | 
					 | 
				
			||||||
			Help: "Largest send window size",
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
	prometheus.MustRegister(sendWindowSizeMax)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	responseByCode := prometheus.NewCounterVec(
 | 
						responseByCode := prometheus.NewCounterVec(
 | 
				
			||||||
		prometheus.CounterOpts{
 | 
							prometheus.CounterOpts{
 | 
				
			||||||
			Name: "response_by_code",
 | 
								Name: "response_by_code",
 | 
				
			||||||
| 
						 | 
					@ -179,7 +305,7 @@ func NewTunnelMetrics() *TunnelMetrics {
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
	prometheus.MustRegister(serverLocations)
 | 
						prometheus.MustRegister(serverLocations)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &TunnelMetrics{
 | 
						return &tunnelMetrics{
 | 
				
			||||||
		haConnections:                  haConnections,
 | 
							haConnections:                  haConnections,
 | 
				
			||||||
		totalRequests:                  totalRequests,
 | 
							totalRequests:                  totalRequests,
 | 
				
			||||||
		requestsPerTunnel:              requestsPerTunnel,
 | 
							requestsPerTunnel:              requestsPerTunnel,
 | 
				
			||||||
| 
						 | 
					@ -187,41 +313,28 @@ func NewTunnelMetrics() *TunnelMetrics {
 | 
				
			||||||
		concurrentRequests:             make(map[string]uint64),
 | 
							concurrentRequests:             make(map[string]uint64),
 | 
				
			||||||
		maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
 | 
							maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
 | 
				
			||||||
		maxConcurrentRequests:          make(map[string]uint64),
 | 
							maxConcurrentRequests:          make(map[string]uint64),
 | 
				
			||||||
		rtt:                   rtt,
 | 
							timerRetries:                   timerRetries,
 | 
				
			||||||
		rttMin:                rttMin,
 | 
							responseByCode:                 responseByCode,
 | 
				
			||||||
		rttMax:                rttMax,
 | 
							responseCodePerTunnel:          responseCodePerTunnel,
 | 
				
			||||||
		timerRetries:          timerRetries,
 | 
							serverLocations:                serverLocations,
 | 
				
			||||||
		receiveWindowSizeAve:  receiveWindowSizeAve,
 | 
							oldServerLocations:             make(map[string]string),
 | 
				
			||||||
		sendWindowSizeAve:     sendWindowSizeAve,
 | 
							muxerMetrics:                   newMuxerMetrics(),
 | 
				
			||||||
		receiveWindowSizeMin:  receiveWindowSizeMin,
 | 
					 | 
				
			||||||
		receiveWindowSizeMax:  receiveWindowSizeMax,
 | 
					 | 
				
			||||||
		sendWindowSizeMin:     sendWindowSizeMin,
 | 
					 | 
				
			||||||
		sendWindowSizeMax:     sendWindowSizeMax,
 | 
					 | 
				
			||||||
		responseByCode:        responseByCode,
 | 
					 | 
				
			||||||
		responseCodePerTunnel: responseCodePerTunnel,
 | 
					 | 
				
			||||||
		serverLocations:       serverLocations,
 | 
					 | 
				
			||||||
		oldServerLocations:    make(map[string]string),
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *TunnelMetrics) incrementHaConnections() {
 | 
					func (t *tunnelMetrics) incrementHaConnections() {
 | 
				
			||||||
	t.haConnections.Inc()
 | 
						t.haConnections.Inc()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *TunnelMetrics) decrementHaConnections() {
 | 
					func (t *tunnelMetrics) decrementHaConnections() {
 | 
				
			||||||
	t.haConnections.Dec()
 | 
						t.haConnections.Dec()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *TunnelMetrics) updateTunnelFlowControlMetrics(metrics *h2mux.FlowControlMetrics) {
 | 
					func (t *tunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
 | 
				
			||||||
	t.receiveWindowSizeAve.Set(float64(metrics.AverageReceiveWindowSize))
 | 
						t.muxerMetrics.update(connectionID, metrics)
 | 
				
			||||||
	t.sendWindowSizeAve.Set(float64(metrics.AverageSendWindowSize))
 | 
					 | 
				
			||||||
	t.receiveWindowSizeMin.Set(float64(metrics.MinReceiveWindowSize))
 | 
					 | 
				
			||||||
	t.receiveWindowSizeMax.Set(float64(metrics.MaxReceiveWindowSize))
 | 
					 | 
				
			||||||
	t.sendWindowSizeMin.Set(float64(metrics.MinSendWindowSize))
 | 
					 | 
				
			||||||
	t.sendWindowSizeMax.Set(float64(metrics.MaxSendWindowSize))
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *TunnelMetrics) incrementRequests(connectionID string) {
 | 
					func (t *tunnelMetrics) incrementRequests(connectionID string) {
 | 
				
			||||||
	t.concurrentRequestsLock.Lock()
 | 
						t.concurrentRequestsLock.Lock()
 | 
				
			||||||
	var concurrentRequests uint64
 | 
						var concurrentRequests uint64
 | 
				
			||||||
	var ok bool
 | 
						var ok bool
 | 
				
			||||||
| 
						 | 
					@ -243,7 +356,7 @@ func (t *TunnelMetrics) incrementRequests(connectionID string) {
 | 
				
			||||||
	t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
 | 
						t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
 | 
					func (t *tunnelMetrics) decrementConcurrentRequests(connectionID string) {
 | 
				
			||||||
	t.concurrentRequestsLock.Lock()
 | 
						t.concurrentRequestsLock.Lock()
 | 
				
			||||||
	if _, ok := t.concurrentRequests[connectionID]; ok {
 | 
						if _, ok := t.concurrentRequests[connectionID]; ok {
 | 
				
			||||||
		t.concurrentRequests[connectionID] -= 1
 | 
							t.concurrentRequests[connectionID] -= 1
 | 
				
			||||||
| 
						 | 
					@ -255,13 +368,13 @@ func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
 | 
				
			||||||
	t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
 | 
						t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *TunnelMetrics) incrementResponses(connectionID, code string) {
 | 
					func (t *tunnelMetrics) incrementResponses(connectionID, code string) {
 | 
				
			||||||
	t.responseByCode.WithLabelValues(code).Inc()
 | 
						t.responseByCode.WithLabelValues(code).Inc()
 | 
				
			||||||
	t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
 | 
						t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) {
 | 
					func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) {
 | 
				
			||||||
	t.locationLock.Lock()
 | 
						t.locationLock.Lock()
 | 
				
			||||||
	defer t.locationLock.Unlock()
 | 
						defer t.locationLock.Unlock()
 | 
				
			||||||
	if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
 | 
						if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,11 +15,11 @@ import (
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"golang.org/x/net/context"
 | 
						"golang.org/x/net/context"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/cloudflare/cloudflare-warp/h2mux"
 | 
						"github.com/cloudflare/cloudflared/h2mux"
 | 
				
			||||||
	"github.com/cloudflare/cloudflare-warp/tunnelrpc"
 | 
						"github.com/cloudflare/cloudflared/tunnelrpc"
 | 
				
			||||||
	tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
 | 
						tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
 | 
				
			||||||
	"github.com/cloudflare/cloudflare-warp/validation"
 | 
						"github.com/cloudflare/cloudflared/validation"
 | 
				
			||||||
	"github.com/cloudflare/cloudflare-warp/websocket"
 | 
						"github.com/cloudflare/cloudflared/websocket"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	raven "github.com/getsentry/raven-go"
 | 
						raven "github.com/getsentry/raven-go"
 | 
				
			||||||
	"github.com/pkg/errors"
 | 
						"github.com/pkg/errors"
 | 
				
			||||||
| 
						 | 
					@ -53,7 +53,7 @@ type TunnelConfig struct {
 | 
				
			||||||
	Tags              []tunnelpogs.Tag
 | 
						Tags              []tunnelpogs.Tag
 | 
				
			||||||
	HAConnections     int
 | 
						HAConnections     int
 | 
				
			||||||
	HTTPTransport     http.RoundTripper
 | 
						HTTPTransport     http.RoundTripper
 | 
				
			||||||
	Metrics           *TunnelMetrics
 | 
						Metrics           *tunnelMetrics
 | 
				
			||||||
	MetricsUpdateFreq time.Duration
 | 
						MetricsUpdateFreq time.Duration
 | 
				
			||||||
	ProtocolLogger    *logrus.Logger
 | 
						ProtocolLogger    *logrus.Logger
 | 
				
			||||||
	Logger            *logrus.Logger
 | 
						Logger            *logrus.Logger
 | 
				
			||||||
| 
						 | 
					@ -185,6 +185,7 @@ func ServeTunnel(
 | 
				
			||||||
	serveCtx, serveCancel := context.WithCancel(ctx)
 | 
						serveCtx, serveCancel := context.WithCancel(ctx)
 | 
				
			||||||
	registerErrC := make(chan error, 1)
 | 
						registerErrC := make(chan error, 1)
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
 | 
							defer wg.Done()
 | 
				
			||||||
		err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP)
 | 
							err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP)
 | 
				
			||||||
		if err == nil {
 | 
							if err == nil {
 | 
				
			||||||
			connectedFuse.Fuse(true)
 | 
								connectedFuse.Fuse(true)
 | 
				
			||||||
| 
						 | 
					@ -193,18 +194,18 @@ func ServeTunnel(
 | 
				
			||||||
			serveCancel()
 | 
								serveCancel()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		registerErrC <- err
 | 
							registerErrC <- err
 | 
				
			||||||
		wg.Done()
 | 
					 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
	updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
 | 
						updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		defer wg.Done()
 | 
							defer wg.Done()
 | 
				
			||||||
 | 
							connectionTag := uint8ToString(connectionID)
 | 
				
			||||||
		for {
 | 
							for {
 | 
				
			||||||
			select {
 | 
								select {
 | 
				
			||||||
			case <-serveCtx.Done():
 | 
								case <-serveCtx.Done():
 | 
				
			||||||
				handler.muxer.Shutdown()
 | 
									handler.muxer.Shutdown()
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			case <-updateMetricsTickC:
 | 
								case <-updateMetricsTickC:
 | 
				
			||||||
				handler.UpdateMetrics()
 | 
									handler.UpdateMetrics(connectionTag)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
| 
						 | 
					@ -303,7 +304,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
 | 
				
			||||||
func LogServerInfo(logger *logrus.Entry,
 | 
					func LogServerInfo(logger *logrus.Entry,
 | 
				
			||||||
	promise tunnelrpc.ServerInfo_Promise,
 | 
						promise tunnelrpc.ServerInfo_Promise,
 | 
				
			||||||
	connectionID uint8,
 | 
						connectionID uint8,
 | 
				
			||||||
	metrics *TunnelMetrics,
 | 
						metrics *tunnelMetrics,
 | 
				
			||||||
) {
 | 
					) {
 | 
				
			||||||
	serverInfoMessage, err := promise.Struct()
 | 
						serverInfoMessage, err := promise.Struct()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -356,13 +357,17 @@ func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func FindCfRayHeader(h1 *http.Request) string {
 | 
				
			||||||
 | 
						return h1.Header.Get("Cf-Ray")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type TunnelHandler struct {
 | 
					type TunnelHandler struct {
 | 
				
			||||||
	originUrl  string
 | 
						originUrl  string
 | 
				
			||||||
	muxer      *h2mux.Muxer
 | 
						muxer      *h2mux.Muxer
 | 
				
			||||||
	httpClient http.RoundTripper
 | 
						httpClient http.RoundTripper
 | 
				
			||||||
	tlsConfig  *tls.Config
 | 
						tlsConfig  *tls.Config
 | 
				
			||||||
	tags       []tunnelpogs.Tag
 | 
						tags       []tunnelpogs.Tag
 | 
				
			||||||
	metrics    *TunnelMetrics
 | 
						metrics    *tunnelMetrics
 | 
				
			||||||
	// connectionID is only used by metrics, and prometheus requires labels to be string
 | 
						// connectionID is only used by metrics, and prometheus requires labels to be string
 | 
				
			||||||
	connectionID string
 | 
						connectionID string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -435,7 +440,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 | 
				
			||||||
		Log.WithError(err).Error("invalid request received")
 | 
							Log.WithError(err).Error("invalid request received")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	h.AppendTagHeaders(req)
 | 
						h.AppendTagHeaders(req)
 | 
				
			||||||
 | 
						cfRay := FindCfRayHeader(req)
 | 
				
			||||||
 | 
						h.logRequest(req, cfRay)
 | 
				
			||||||
	if websocket.IsWebSocketUpgrade(req) {
 | 
						if websocket.IsWebSocketUpgrade(req) {
 | 
				
			||||||
		conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
 | 
							conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
| 
						 | 
					@ -444,6 +450,8 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 | 
				
			||||||
			stream.WriteHeaders(H1ResponseToH2Response(response))
 | 
								stream.WriteHeaders(H1ResponseToH2Response(response))
 | 
				
			||||||
			defer conn.Close()
 | 
								defer conn.Close()
 | 
				
			||||||
			websocket.Stream(conn.UnderlyingConn(), stream)
 | 
								websocket.Stream(conn.UnderlyingConn(), stream)
 | 
				
			||||||
 | 
								h.metrics.incrementResponses(h.connectionID, "200")
 | 
				
			||||||
 | 
								h.logResponse(response, cfRay)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		response, err := h.httpClient.RoundTrip(req)
 | 
							response, err := h.httpClient.RoundTrip(req)
 | 
				
			||||||
| 
						 | 
					@ -454,6 +462,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 | 
				
			||||||
			stream.WriteHeaders(H1ResponseToH2Response(response))
 | 
								stream.WriteHeaders(H1ResponseToH2Response(response))
 | 
				
			||||||
			io.Copy(stream, response.Body)
 | 
								io.Copy(stream, response.Body)
 | 
				
			||||||
			h.metrics.incrementResponses(h.connectionID, "200")
 | 
								h.metrics.incrementResponses(h.connectionID, "200")
 | 
				
			||||||
 | 
								h.logResponse(response, cfRay)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	h.metrics.decrementConcurrentRequests(h.connectionID)
 | 
						h.metrics.decrementConcurrentRequests(h.connectionID)
 | 
				
			||||||
| 
						 | 
					@ -467,9 +476,27 @@ func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
 | 
				
			||||||
	h.metrics.incrementResponses(h.connectionID, "502")
 | 
						h.metrics.incrementResponses(h.connectionID, "502")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *TunnelHandler) UpdateMetrics() {
 | 
					func (h *TunnelHandler) logRequest(req *http.Request, cfRay string) {
 | 
				
			||||||
	flowCtlMetrics := h.muxer.FlowControlMetrics()
 | 
						if cfRay != "" {
 | 
				
			||||||
	h.metrics.updateTunnelFlowControlMetrics(flowCtlMetrics)
 | 
							Log.WithField("CF-RAY", cfRay).Infof("%s %s %s", req.Method, req.URL, req.Proto)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							Log.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						Log.Debugf("Request Headers %+v", req.Header)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *TunnelHandler) logResponse(r *http.Response, cfRay string) {
 | 
				
			||||||
 | 
						if cfRay != "" {
 | 
				
			||||||
 | 
							Log.WithField("CF-RAY", cfRay).Infof("%s", r.Status)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							Log.Infof("%s", r.Status)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						Log.Debugf("Response Headers %+v", r.Header)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *TunnelHandler) UpdateMetrics(connectionID string) {
 | 
				
			||||||
 | 
						h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics())
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func uint8ToString(input uint8) string {
 | 
					func uint8ToString(input uint8) string {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -11,28 +11,28 @@ const (
 | 
				
			||||||
BgUrgQQAIg==
 | 
					BgUrgQQAIg==
 | 
				
			||||||
-----END EC PARAMETERS-----
 | 
					-----END EC PARAMETERS-----
 | 
				
			||||||
-----BEGIN EC PRIVATE KEY-----
 | 
					-----BEGIN EC PRIVATE KEY-----
 | 
				
			||||||
MIGkAgEBBDAdyQBXfxTDCQSOT0HugmH9pVBtIw8t5dYvm6HxGlNq6P57v5GeN02Z
 | 
					MIGkAgEBBDBGGfwhIJdiUiJUVIItqJjEIMmlXxsMa8TQeer47+g+cIZ466rgg8EK
 | 
				
			||||||
dH9FRl7+VSWgBwYFK4EEACKhZANiAATqpFzTxxV7D+/oqhKCTR6BEM9elTfKaRQE
 | 
					+Mdn6BY48GCgBwYFK4EEACKhZANiAASW//A9iDbPKg3OLkn7yJqLer32g9I5lBKR
 | 
				
			||||||
FsLufcmaTMw/9tTwgpHKao/QsLKDTNbQhbSQLkcmpCQKlSGhl+pCrqNt/oYUAhav
 | 
					tPc/zBubQLLz9lAaYI6AOQiJXhGr5JkKmQfi1sYHK5rJITPFy4W8Et4hHLdazDZH
 | 
				
			||||||
UIwpwGiLCqGH/R2AqWLKRPOa/Rufs/U=
 | 
					WnEd+TStQABFUjrhtqXPWmGKcly0pOE=
 | 
				
			||||||
-----END EC PRIVATE KEY-----`
 | 
					-----END EC PRIVATE KEY-----`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	helloCRT = `
 | 
						helloCRT = `
 | 
				
			||||||
-----BEGIN CERTIFICATE-----
 | 
					-----BEGIN CERTIFICATE-----
 | 
				
			||||||
MIICkDCCAhigAwIBAgIJAPtKfUjc2lwGMAkGByqGSM49BAEwgYoxCzAJBgNVBAYT
 | 
					MIICiDCCAg6gAwIBAgIJAJ/FfkBTtbuIMAkGByqGSM49BAEwfzELMAkGA1UEBhMC
 | 
				
			||||||
AlVTMQ4wDAYDVQQIDAVUZXhhczEPMA0GA1UEBwwGQXVzdGluMRkwFwYDVQQKDBBD
 | 
					VVMxDjAMBgNVBAgMBVRleGFzMQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENs
 | 
				
			||||||
bG91ZGZsYXJlLCBJbmMuMT8wPQYDVQQDDDZDbG91ZGZsYXJlIEFyZ28gVHVubmVs
 | 
					b3VkZmxhcmUsIEluYy4xNDAyBgNVBAMMK0FyZ28gVHVubmVsIFNhbXBsZSBIZWxs
 | 
				
			||||||
IFNhbXBsZSBIZWxsbyBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMjE1MjAxNjU5
 | 
					byBTZXJ2ZXIgQ2VydGlmaWNhdGUwHhcNMTgwMzE5MjMwNTMyWhcNMjgwMzE2MjMw
 | 
				
			||||||
WhcNMjgwMjEzMjAxNjU5WjCBijELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVRleGFz
 | 
					NTMyWjB/MQswCQYDVQQGEwJVUzEOMAwGA1UECAwFVGV4YXMxDzANBgNVBAcMBkF1
 | 
				
			||||||
MQ8wDQYDVQQHDAZBdXN0aW4xGTAXBgNVBAoMEENsb3VkZmxhcmUsIEluYy4xPzA9
 | 
					c3RpbjEZMBcGA1UECgwQQ2xvdWRmbGFyZSwgSW5jLjE0MDIGA1UEAwwrQXJnbyBU
 | 
				
			||||||
BgNVBAMMNkNsb3VkZmxhcmUgQXJnbyBUdW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZl
 | 
					dW5uZWwgU2FtcGxlIEhlbGxvIFNlcnZlciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49
 | 
				
			||||||
ciBDZXJ0aWZpY2F0ZTB2MBAGByqGSM49AgEGBSuBBAAiA2IABOqkXNPHFXsP7+iq
 | 
					AgEGBSuBBAAiA2IABJb/8D2INs8qDc4uSfvImot6vfaD0jmUEpG09z/MG5tAsvP2
 | 
				
			||||||
EoJNHoEQz16VN8ppFAQWwu59yZpMzD/21PCCkcpqj9CwsoNM1tCFtJAuRyakJAqV
 | 
					UBpgjoA5CIleEavkmQqZB+LWxgcrmskhM8XLhbwS3iEct1rMNkdacR35NK1AAEVS
 | 
				
			||||||
IaGX6kKuo23+hhQCFq9QjCnAaIsKoYf9HYCpYspE85r9G5+z9aNJMEcwRQYDVR0R
 | 
					OuG2pc9aYYpyXLSk4aNXMFUwUwYDVR0RBEwwSoIJbG9jYWxob3N0ghFjbG91ZGZs
 | 
				
			||||||
BD4wPIIJbG9jYWxob3N0ggp3YXJwLWhlbGxvggt3YXJwMi1oZWxsb4cEfwAAAYcQ
 | 
					YXJlZC1oZWxsb4ISY2xvdWRmbGFyZWQyLWhlbGxvhwR/AAABhxAAAAAAAAAAAAAA
 | 
				
			||||||
AAAAAAAAAAAAAAAAAAAAATAJBgcqhkjOPQQBA2cAMGQCMHyVPufXZ6vQo6XRWRa0
 | 
					AAAAAAABMAkGByqGSM49BAEDaQAwZgIxAPxkdghH6y8xLMnY9Bom3Llf4NYM6yB9
 | 
				
			||||||
dAwtfgesOdZVP2Wt+t5v8jOIQQh1IQXYk5GtyoZGSObjhQIwd1fRgAyKXaZt+1DV
 | 
					PD1YsaNUJTsxjTk3YY1Jsp+yzK0yUKtTZwIxAPcdvqCF2/iR9H288pCT1TgtO0a9
 | 
				
			||||||
ZtHTdf8pMvESfJsSd8AB1eQ6q+pAiRUYyaxcE1Mlo2YY5o+g
 | 
					cJL9RY1lq7DIGN37v1ZXReWaD+3hNokY8NriVg==
 | 
				
			||||||
-----END CERTIFICATE-----`
 | 
					-----END CERTIFICATE-----`
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,38 @@
 | 
				
			||||||
 | 
					package tunneldns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/coredns/coredns/plugin"
 | 
				
			||||||
 | 
						"github.com/miekg/dns"
 | 
				
			||||||
 | 
						"github.com/pkg/errors"
 | 
				
			||||||
 | 
						"golang.org/x/net/context"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Upstream is a simplified interface for proxy destination
 | 
				
			||||||
 | 
					type Upstream interface {
 | 
				
			||||||
 | 
						Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ProxyPlugin is a simplified DNS proxy using a generic upstream interface
 | 
				
			||||||
 | 
					type ProxyPlugin struct {
 | 
				
			||||||
 | 
						Upstreams []Upstream
 | 
				
			||||||
 | 
						Next      plugin.Handler
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ServeDNS implements interface for CoreDNS plugin
 | 
				
			||||||
 | 
					func (p ProxyPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
 | 
				
			||||||
 | 
						var reply *dns.Msg
 | 
				
			||||||
 | 
						var backendErr error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, upstream := range p.Upstreams {
 | 
				
			||||||
 | 
							reply, backendErr = upstream.Exchange(ctx, r)
 | 
				
			||||||
 | 
							if backendErr == nil {
 | 
				
			||||||
 | 
								w.WriteMsg(reply)
 | 
				
			||||||
 | 
								return 0, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Name implements interface for CoreDNS plugin
 | 
				
			||||||
 | 
					func (p ProxyPlugin) Name() string { return "proxy" }
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,97 @@
 | 
				
			||||||
 | 
					package tunneldns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"crypto/tls"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io/ioutil"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/miekg/dns"
 | 
				
			||||||
 | 
						"github.com/pkg/errors"
 | 
				
			||||||
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
 | 
						"golang.org/x/net/context"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						defaultTimeout = 5 * time.Second
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// UpstreamHTTPS is the upstream implementation for DNS over HTTPS service
 | 
				
			||||||
 | 
					type UpstreamHTTPS struct {
 | 
				
			||||||
 | 
						client   *http.Client
 | 
				
			||||||
 | 
						endpoint *url.URL
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewUpstreamHTTPS creates a new DNS over HTTPS upstream from hostname
 | 
				
			||||||
 | 
					func NewUpstreamHTTPS(endpoint string) (Upstream, error) {
 | 
				
			||||||
 | 
						u, err := url.Parse(endpoint)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Update TLS and HTTP client configuration
 | 
				
			||||||
 | 
						tls := &tls.Config{ServerName: u.Hostname()}
 | 
				
			||||||
 | 
						client := &http.Client{
 | 
				
			||||||
 | 
							Timeout:   time.Second * defaultTimeout,
 | 
				
			||||||
 | 
							Transport: &http.Transport{TLSClientConfig: tls},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &UpstreamHTTPS{client: client, endpoint: u}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Exchange provides an implementation for the Upstream interface
 | 
				
			||||||
 | 
					func (u *UpstreamHTTPS) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
 | 
				
			||||||
 | 
						queryBuf, err := query.Pack()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, errors.Wrap(err, "failed to pack DNS query")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// No content negotiation for now, use DNS wire format
 | 
				
			||||||
 | 
						buf, backendErr := u.exchangeWireformat(queryBuf)
 | 
				
			||||||
 | 
						if backendErr == nil {
 | 
				
			||||||
 | 
							response := &dns.Msg{}
 | 
				
			||||||
 | 
							if err := response.Unpack(buf); err != nil {
 | 
				
			||||||
 | 
								return nil, errors.Wrap(err, "failed to unpack DNS response from body")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							response.Id = query.Id
 | 
				
			||||||
 | 
							return response, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.WithError(backendErr).Errorf("failed to connect to an HTTPS backend %q", u.endpoint)
 | 
				
			||||||
 | 
						return nil, backendErr
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Perform message exchange with the default UDP wireformat defined in current draft
 | 
				
			||||||
 | 
					// https://datatracker.ietf.org/doc/draft-ietf-doh-dns-over-https
 | 
				
			||||||
 | 
					func (u *UpstreamHTTPS) exchangeWireformat(msg []byte) ([]byte, error) {
 | 
				
			||||||
 | 
						req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, errors.Wrap(err, "failed to create an HTTPS request")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req.Header.Add("Content-Type", "application/dns-udpwireformat")
 | 
				
			||||||
 | 
						req.Host = u.endpoint.Hostname()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						resp, err := u.client.Do(req)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, errors.Wrap(err, "failed to perform an HTTPS request")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check response status code
 | 
				
			||||||
 | 
						defer resp.Body.Close()
 | 
				
			||||||
 | 
						if resp.StatusCode != http.StatusOK {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Read wireformat response from the body
 | 
				
			||||||
 | 
						buf, err := ioutil.ReadAll(resp.Body)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, errors.Wrap(err, "failed to read the response body")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return buf, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,45 @@
 | 
				
			||||||
 | 
					package tunneldns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"github.com/coredns/coredns/plugin"
 | 
				
			||||||
 | 
						"github.com/coredns/coredns/plugin/metrics/vars"
 | 
				
			||||||
 | 
						"github.com/coredns/coredns/plugin/pkg/dnstest"
 | 
				
			||||||
 | 
						"github.com/coredns/coredns/plugin/pkg/rcode"
 | 
				
			||||||
 | 
						"github.com/coredns/coredns/request"
 | 
				
			||||||
 | 
						"github.com/miekg/dns"
 | 
				
			||||||
 | 
						"github.com/prometheus/client_golang/prometheus"
 | 
				
			||||||
 | 
						"golang.org/x/net/context"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// MetricsPlugin is an adapter for CoreDNS and built-in metrics
 | 
				
			||||||
 | 
					type MetricsPlugin struct {
 | 
				
			||||||
 | 
						Next plugin.Handler
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewMetricsPlugin creates a plugin with configured metrics
 | 
				
			||||||
 | 
					func NewMetricsPlugin(next plugin.Handler) *MetricsPlugin {
 | 
				
			||||||
 | 
						prometheus.MustRegister(vars.RequestCount)
 | 
				
			||||||
 | 
						prometheus.MustRegister(vars.RequestDuration)
 | 
				
			||||||
 | 
						prometheus.MustRegister(vars.RequestSize)
 | 
				
			||||||
 | 
						prometheus.MustRegister(vars.RequestDo)
 | 
				
			||||||
 | 
						prometheus.MustRegister(vars.RequestType)
 | 
				
			||||||
 | 
						prometheus.MustRegister(vars.ResponseSize)
 | 
				
			||||||
 | 
						prometheus.MustRegister(vars.ResponseRcode)
 | 
				
			||||||
 | 
						return &MetricsPlugin{Next: next}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ServeDNS implements the CoreDNS plugin interface
 | 
				
			||||||
 | 
					func (p MetricsPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
 | 
				
			||||||
 | 
						state := request.Request{W: w, Req: r}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rw := dnstest.NewRecorder(w)
 | 
				
			||||||
 | 
						status, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, rw, r)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Update built-in metrics
 | 
				
			||||||
 | 
						vars.Report(state, ".", rcode.ToString(rw.Rcode), rw.Len, rw.Start)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return status, err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Name implements the CoreDNS plugin interface
 | 
				
			||||||
 | 
					func (p MetricsPlugin) Name() string { return "metrics" }
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,144 @@
 | 
				
			||||||
 | 
					package tunneldns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
 | 
						"os/signal"
 | 
				
			||||||
 | 
						"strconv"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"syscall"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"gopkg.in/urfave/cli.v2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/cloudflare/cloudflared/metrics"
 | 
				
			||||||
 | 
						"github.com/coredns/coredns/core/dnsserver"
 | 
				
			||||||
 | 
						"github.com/coredns/coredns/plugin"
 | 
				
			||||||
 | 
						"github.com/coredns/coredns/plugin/cache"
 | 
				
			||||||
 | 
						"github.com/pkg/errors"
 | 
				
			||||||
 | 
						log "github.com/sirupsen/logrus"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Listener is an adapter between CoreDNS server and Warp runnable
 | 
				
			||||||
 | 
					type Listener struct {
 | 
				
			||||||
 | 
						server *dnsserver.Server
 | 
				
			||||||
 | 
						wg     sync.WaitGroup
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Run implements a foreground runner
 | 
				
			||||||
 | 
					func Run(c *cli.Context) error {
 | 
				
			||||||
 | 
						metricsListener, err := net.Listen("tcp", c.String("metrics"))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.WithError(err).Fatal("Failed to open the metrics listener")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go metrics.ServeMetrics(metricsListener, nil)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						listener, err := CreateListener(c.String("address"), uint16(c.Uint("port")), c.StringSlice("upstream"))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.WithError(err).Errorf("Failed to create the listeners")
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Try to start the server
 | 
				
			||||||
 | 
						err = listener.Start()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.WithError(err).Errorf("Failed to start the listeners")
 | 
				
			||||||
 | 
							return listener.Stop()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Wait for signal
 | 
				
			||||||
 | 
						signals := make(chan os.Signal, 10)
 | 
				
			||||||
 | 
						signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
 | 
				
			||||||
 | 
						defer signal.Stop(signals)
 | 
				
			||||||
 | 
						<-signals
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Shut down server
 | 
				
			||||||
 | 
						err = listener.Stop()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.WithError(err).Errorf("failed to stop")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Create a CoreDNS server plugin from configuration
 | 
				
			||||||
 | 
					func createConfig(address string, port uint16, p plugin.Handler) *dnsserver.Config {
 | 
				
			||||||
 | 
						c := &dnsserver.Config{
 | 
				
			||||||
 | 
							Zone:        ".",
 | 
				
			||||||
 | 
							Transport:   "dns",
 | 
				
			||||||
 | 
							ListenHosts: []string{address},
 | 
				
			||||||
 | 
							Port:        strconv.FormatUint(uint64(port), 10),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c.AddPlugin(func(next plugin.Handler) plugin.Handler { return p })
 | 
				
			||||||
 | 
						return c
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Start blocks for serving requests
 | 
				
			||||||
 | 
					func (l *Listener) Start() error {
 | 
				
			||||||
 | 
						log.WithField("addr", l.server.Address()).Infof("Starting DNS over HTTPS proxy server")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Start UDP listener
 | 
				
			||||||
 | 
						if udp, err := l.server.ListenPacket(); err == nil {
 | 
				
			||||||
 | 
							l.wg.Add(1)
 | 
				
			||||||
 | 
							go func() {
 | 
				
			||||||
 | 
								l.server.ServePacket(udp)
 | 
				
			||||||
 | 
								l.wg.Done()
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							return errors.Wrap(err, "failed to create a UDP listener")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Start TCP listener
 | 
				
			||||||
 | 
						tcp, err := l.server.Listen()
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							l.wg.Add(1)
 | 
				
			||||||
 | 
							go func() {
 | 
				
			||||||
 | 
								l.server.Serve(tcp)
 | 
				
			||||||
 | 
								l.wg.Done()
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return errors.Wrap(err, "failed to create a TCP listener")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Stop signals server shutdown and blocks until completed
 | 
				
			||||||
 | 
					func (l *Listener) Stop() error {
 | 
				
			||||||
 | 
						if err := l.server.Stop(); err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						l.wg.Wait()
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CreateListener configures the server and bound sockets
 | 
				
			||||||
 | 
					func CreateListener(address string, port uint16, upstreams []string) (*Listener, error) {
 | 
				
			||||||
 | 
						// Build the list of upstreams
 | 
				
			||||||
 | 
						upstreamList := make([]Upstream, 0)
 | 
				
			||||||
 | 
						for _, url := range upstreams {
 | 
				
			||||||
 | 
							log.WithField("url", url).Infof("Adding DNS upstream")
 | 
				
			||||||
 | 
							upstream, err := NewUpstreamHTTPS(url)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, errors.Wrap(err, "failed to create HTTPS upstream")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							upstreamList = append(upstreamList, upstream)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create a local cache with HTTPS proxy plugin
 | 
				
			||||||
 | 
						chain := cache.New()
 | 
				
			||||||
 | 
						chain.Next = ProxyPlugin{
 | 
				
			||||||
 | 
							Upstreams: upstreamList,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Format an endpoint
 | 
				
			||||||
 | 
						endpoint := fmt.Sprintf("dns://%s:%d", address, port)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create the actual middleware server
 | 
				
			||||||
 | 
						server, err := dnsserver.NewServer(endpoint, []*dnsserver.Config{createConfig(address, port, NewMetricsPlugin(chain))})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &Listener{server: server}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -1,7 +1,7 @@
 | 
				
			||||||
package pogs
 | 
					package pogs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"github.com/cloudflare/cloudflare-warp/tunnelrpc"
 | 
						"github.com/cloudflare/cloudflared/tunnelrpc"
 | 
				
			||||||
	"golang.org/x/net/context"
 | 
						"golang.org/x/net/context"
 | 
				
			||||||
	"zombiezen.com/go/capnproto2"
 | 
						"zombiezen.com/go/capnproto2"
 | 
				
			||||||
	"zombiezen.com/go/capnproto2/pogs"
 | 
						"zombiezen.com/go/capnproto2/pogs"
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,7 +1,7 @@
 | 
				
			||||||
using Go = import "go.capnp";
 | 
					using Go = import "go.capnp";
 | 
				
			||||||
@0xdb8274f9144abc7e;
 | 
					@0xdb8274f9144abc7e;
 | 
				
			||||||
$Go.package("tunnelrpc");
 | 
					$Go.package("tunnelrpc");
 | 
				
			||||||
$Go.import("github.com/cloudflare/cloudflare-warp/tunnelrpc");
 | 
					$Go.import("github.com/cloudflare/cloudflared/tunnelrpc");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct Authentication {
 | 
					struct Authentication {
 | 
				
			||||||
    key @0 :Text;
 | 
					    key @0 :Text;
 | 
				
			||||||
| 
						 | 
					@ -35,7 +35,7 @@ struct RegistrationOptions {
 | 
				
			||||||
    connectionId @6 :UInt8;
 | 
					    connectionId @6 :UInt8;
 | 
				
			||||||
    # origin LAN IP
 | 
					    # origin LAN IP
 | 
				
			||||||
    originLocalIp @7 :Text;
 | 
					    originLocalIp @7 :Text;
 | 
				
			||||||
    # whether Warp client has been autoupdated
 | 
					    # whether Argo Tunnel client has been autoupdated
 | 
				
			||||||
    isAutoupdated @8 :Bool;
 | 
					    isAutoupdated @8 :Bool;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -119,7 +119,7 @@ func validateScheme(scheme string) error {
 | 
				
			||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return fmt.Errorf("Currently Cloudflare-Warp does not support %s protocol.", scheme)
 | 
						return fmt.Errorf("Currently Argo Tunnel does not support %s protocol.", scheme)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func validateIP(scheme, host, port string) (string, error) {
 | 
					func validateIP(scheme, host, port string) (string, error) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -126,7 +126,7 @@ func TestValidateUrl(t *testing.T) {
 | 
				
			||||||
	assert.Equal(t, "https://hello.example.com", validUrl)
 | 
						assert.Equal(t, "https://hello.example.com", validUrl)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
 | 
						validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
 | 
				
			||||||
	assert.Equal(t, "Currently Cloudflare-Warp does not support ftp protocol.", err.Error())
 | 
						assert.Equal(t, "Currently Argo Tunnel does not support ftp protocol.", err.Error())
 | 
				
			||||||
	assert.Empty(t, validUrl)
 | 
						assert.Empty(t, validUrl)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")
 | 
						validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue