Initial import
This commit is contained in:
commit
82cb539fbe
|
@ -0,0 +1,9 @@
|
||||||
|
# Cloudflare Warp client
|
||||||
|
|
||||||
|
Contains the command-line client and its libraries for Cloudflare Warp, a tunneling daemon that proxies any local webserver through the Cloudflare network.
|
||||||
|
|
||||||
|
## Getting started
|
||||||
|
|
||||||
|
go install github.com/cloudflare/cloudflare-warp/cmd/cloudflare-warp
|
||||||
|
|
||||||
|
User documentation for Warp can be found at https://warp.cloudflare.com
|
|
@ -0,0 +1,39 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
)
|
||||||
|
|
||||||
|
const cloudflareRootCA = `-----BEGIN CERTIFICATE-----
|
||||||
|
MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG
|
||||||
|
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG
|
||||||
|
bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN
|
||||||
|
U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4
|
||||||
|
NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv
|
||||||
|
dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl
|
||||||
|
cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG
|
||||||
|
A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
|
||||||
|
AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem
|
||||||
|
ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ
|
||||||
|
p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz
|
||||||
|
/9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl
|
||||||
|
yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK
|
||||||
|
xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud
|
||||||
|
EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G
|
||||||
|
A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA
|
||||||
|
cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD
|
||||||
|
gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh
|
||||||
|
QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz
|
||||||
|
ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y
|
||||||
|
4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX
|
||||||
|
Fu6q54beR89jDc+oABmOgg==
|
||||||
|
-----END CERTIFICATE-----`
|
||||||
|
|
||||||
|
func GetCloudflareRootCA() *x509.CertPool {
|
||||||
|
ca := x509.NewCertPool()
|
||||||
|
if !ca.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
|
||||||
|
// should never happen
|
||||||
|
panic("failure loading Cloudflare origin CA pem")
|
||||||
|
}
|
||||||
|
return ca
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
// +build !windows,!darwin,!linux
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func runApp(app *cli.App) {
|
||||||
|
app.Run(os.Args)
|
||||||
|
}
|
|
@ -0,0 +1,155 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"github.com/cloudflare/cloudflare-warp/origin"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||||
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type templateData struct {
|
||||||
|
ServerName string
|
||||||
|
Request *http.Request
|
||||||
|
Tags []tunnelpogs.Tag
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultServerName = "the Cloudflare Warp test server"
|
||||||
|
const indexTemplate = `
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<meta http-equiv="X-UA-Compatible" content="IE=Edge">
|
||||||
|
<title>
|
||||||
|
Cloudflare Warp Connection
|
||||||
|
</title>
|
||||||
|
<meta name="author" content="">
|
||||||
|
<meta name="description" content="Cloudflare Warp Connection">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||||
|
<style>
|
||||||
|
html{line-height:1.15;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}body{margin:0}section{display:block}h1{font-size:2em;margin:.67em 0}a{background-color:transparent;-webkit-text-decoration-skip:objects}/* 1 */::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}/* 1 */a,body,dd,div,dl,dt,h1,h4,html,p,section{box-sizing:border-box}.bt{border-top-style:solid;border-top-width:1px}.bl{border-left-style:solid;border-left-width:1px}.b--orange{border-color:#f38020}.br1{border-radius:.125rem}.bw2{border-width:.25rem}.dib{display:inline-block}.sans-serif{font-family:open sans,-apple-system,BlinkMacSystemFont,avenir next,avenir,helvetica neue,helvetica,ubuntu,roboto,noto,segoe ui,arial,sans-serif}.code{font-family:Consolas,monaco,monospace}.b{font-weight:700}.fw3{font-weight:300}.fw4{font-weight:400}.fw5{font-weight:500}.fw6{font-weight:600}.lh-copy{line-height:1.5}.link{text-decoration:none}.link,.link:active,.link:focus,.link:hover,.link:link,.link:visited{transition:color .15s ease-in}.link:focus{outline:1px dotted currentColor}.mw-100{max-width:100%}.mw4{max-width:8rem}.mw7{max-width:48rem}.bg-light-gray{background-color:#f7f7f7}.link-hover:hover{background-color:#1f679e}.white{color:#fff}.bg-white{background-color:#fff}.bg-blue{background-color:#408bc9}.pb2{padding-bottom:.5rem}.pb6{padding-bottom:8rem}.pt3{padding-top:1rem}.pt5{padding-top:4rem}.pv2{padding-top:.5rem;padding-bottom:.5rem}.ph3{padding-left:1rem;padding-right:1rem}.ph4{padding-left:2rem;padding-right:2rem}.ml0{margin-left:0}.mb1{margin-bottom:.25rem}.mb2{margin-bottom:.5rem}.mb3{margin-bottom:1rem}.mt5{margin-top:4rem}.ttu{text-transform:uppercase}.f4{font-size:1.25rem}.f5{font-size:1rem}.f6{font-size:.875rem}.f7{font-size:.75rem}.measure{max-width:30em}.center{margin-left:auto}.center{margin-right:auto}@media screen and (min-width:30em){.f2-ns{font-size:2.25rem}}@media screen and (min-width:30em) and (max-width:60em){.f5-m{font-size:1rem}}@media screen and (min-width:60em){.f4-l{font-size:1.25rem}}
|
||||||
|
.st0{fill:#FFF}.st1{fill:#f48120}.st2{fill:#faad3f}.st3{fill:#404041}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body class="sans-serif black">
|
||||||
|
<div class="bt bw2 b--orange bg-white pb6">
|
||||||
|
<div class="mw7 center ph4 pt3">
|
||||||
|
<svg id="Layer_2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 109 40.5" class="mw4">
|
||||||
|
<path class="st0" d="M98.6 14.2L93 12.9l-1-.4-25.7.2v12.4l32.3.1z"/>
|
||||||
|
<path class="st1" d="M88.1 24c.3-1 .2-2-.3-2.6-.5-.6-1.2-1-2.1-1.1l-17.4-.2c-.1 0-.2-.1-.3-.1-.1-.1-.1-.2 0-.3.1-.2.2-.3.4-.3l17.5-.2c2.1-.1 4.3-1.8 5.1-3.8l1-2.6c0-.1.1-.2 0-.3-1.1-5.1-5.7-8.9-11.1-8.9-5 0-9.3 3.2-10.8 7.7-1-.7-2.2-1.1-3.6-1-2.4.2-4.3 2.2-4.6 4.6-.1.6 0 1.2.1 1.8-3.9.1-7.1 3.3-7.1 7.3 0 .4 0 .7.1 1.1 0 .2.2.3.3.3h32.1c.2 0 .4-.1.4-.3l.3-1.1z"/>
|
||||||
|
<path class="st2" d="M93.6 12.8h-.5c-.1 0-.2.1-.3.2l-.7 2.4c-.3 1-.2 2 .3 2.6.5.6 1.2 1 2.1 1.1l3.7.2c.1 0 .2.1.3.1.1.1.1.2 0 .3-.1.2-.2.3-.4.3l-3.8.2c-2.1.1-4.3 1.8-5.1 3.8l-.2.9c-.1.1 0 .3.2.3h13.2c.2 0 .3-.1.3-.3.2-.8.4-1.7.4-2.6 0-5.2-4.3-9.5-9.5-9.5"/>
|
||||||
|
<path class="st3" d="M104.4 30.8c-.5 0-.9-.4-.9-.9s.4-.9.9-.9.9.4.9.9-.4.9-.9.9m0-1.6c-.4 0-.7.3-.7.7 0 .4.3.7.7.7.4 0 .7-.3.7-.7 0-.4-.3-.7-.7-.7m.4 1.2h-.2l-.2-.3h-.2v.3h-.2v-.9h.5c.2 0 .3.1.3.3 0 .1-.1.2-.2.3l.2.3zm-.3-.5c.1 0 .1 0 .1-.1s-.1-.1-.1-.1h-.3v.3h.3zM14.8 29H17v6h3.8v1.9h-6zM23.1 32.9c0-2.3 1.8-4.1 4.3-4.1s4.2 1.8 4.2 4.1-1.8 4.1-4.3 4.1c-2.4 0-4.2-1.8-4.2-4.1m6.3 0c0-1.2-.8-2.2-2-2.2s-2 1-2 2.1.8 2.1 2 2.1c1.2.2 2-.8 2-2M34.3 33.4V29h2.2v4.4c0 1.1.6 1.7 1.5 1.7s1.5-.5 1.5-1.6V29h2.2v4.4c0 2.6-1.5 3.7-3.7 3.7-2.3-.1-3.7-1.2-3.7-3.7M45 29h3.1c2.8 0 4.5 1.6 4.5 3.9s-1.7 4-4.5 4h-3V29zm3.1 5.9c1.3 0 2.2-.7 2.2-2s-.9-2-2.2-2h-.9v4h.9zM55.7 29H62v1.9h-4.1v1.3h3.7V34h-3.7v2.9h-2.2zM65.1 29h2.2v6h3.8v1.9h-6zM76.8 28.9H79l3.4 8H80l-.6-1.4h-3.1l-.6 1.4h-2.3l3.4-8zm2 4.9l-.9-2.2-.9 2.2h1.8zM85.2 29h3.7c1.2 0 2 .3 2.6.9.5.5.7 1.1.7 1.8 0 1.2-.6 2-1.6 2.4l1.9 2.8H90l-1.6-2.4h-1v2.4h-2.2V29zm3.6 3.8c.7 0 1.2-.4 1.2-.9 0-.6-.5-.9-1.2-.9h-1.4v1.9h1.4zM95.3 29h6.4v1.8h-4.2V32h3.8v1.8h-3.8V35h4.3v1.9h-6.5zM10 33.9c-.3.7-1 1.2-1.8 1.2-1.2 0-2-1-2-2.1s.8-2.1 2-2.1c.9 0 1.6.6 1.9 1.3h2.3c-.4-1.9-2-3.3-4.2-3.3-2.4 0-4.3 1.8-4.3 4.1s1.8 4.1 4.2 4.1c2.1 0 3.7-1.4 4.2-3.2H10z"/>
|
||||||
|
</svg>
|
||||||
|
<h1 class="f4 f2-ns mt5 fw5">Congrats! You created your first tunnel!</h1>
|
||||||
|
<p class="f6 f5-m f4-l measure lh-copy fw3">
|
||||||
|
Cloudflare Warp exposes locally running applications to the internet by
|
||||||
|
running an encrypted, virtual tunnel from your laptop or server to
|
||||||
|
Cloudflare's edge network.
|
||||||
|
</p>
|
||||||
|
<p class="b f5 mt5 fw6">Ready for the next step?</p>
|
||||||
|
<a
|
||||||
|
class="fw6 link white bg-blue ph4 pv2 br1 dib f5 link-hover"
|
||||||
|
style="border-bottom: 1px solid #1f679e"
|
||||||
|
href="https://warp.cloudflare.com">
|
||||||
|
Get started here
|
||||||
|
</a>
|
||||||
|
{{if .Tags}} <section>
|
||||||
|
<h4 class="f6 fw4 pt5 mb2">Connection</h4>
|
||||||
|
<dl class="bl bw2 b--orange ph3 pt3 pb2 bg-light-gray f7 code overflow-x-auto mw-100">
|
||||||
|
{{range .Tags}} <dt class="ttu mb1">{{.Name}}</dt>
|
||||||
|
<dd class="ml0 mb3 f5">{{.Value}}</dd>
|
||||||
|
{{end}} </dl>
|
||||||
|
</section>
|
||||||
|
{{end}} </div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`
|
||||||
|
|
||||||
|
func hello(c *cli.Context) error {
|
||||||
|
address := fmt.Sprintf(":%d", c.Int("port"))
|
||||||
|
server := NewHelloWorldServer()
|
||||||
|
if hostname, err := os.Hostname(); err != nil {
|
||||||
|
server.serverName = hostname
|
||||||
|
}
|
||||||
|
err := server.ListenAndServe(address)
|
||||||
|
return errors.Wrap(err, "Fail to start Hello World Server")
|
||||||
|
}
|
||||||
|
|
||||||
|
func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
|
||||||
|
server := NewHelloWorldServer()
|
||||||
|
if hostname, err := os.Hostname(); err != nil {
|
||||||
|
server.serverName = hostname
|
||||||
|
}
|
||||||
|
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: server}
|
||||||
|
go func() {
|
||||||
|
<-shutdownC
|
||||||
|
httpServer.Close()
|
||||||
|
}()
|
||||||
|
err := httpServer.Serve(listener)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type HelloWorldServer struct {
|
||||||
|
responseTemplate *template.Template
|
||||||
|
serverName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHelloWorldServer() *HelloWorldServer {
|
||||||
|
return &HelloWorldServer{
|
||||||
|
responseTemplate: template.Must(template.New("index").Parse(indexTemplate)),
|
||||||
|
serverName: defaultServerName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func findAvailablePort() (net.Listener, error) {
|
||||||
|
// If the port in address is empty, a port number is automatically chosen.
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:")
|
||||||
|
return listener, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HelloWorldServer) ListenAndServe(address string) error {
|
||||||
|
log.Infof("Starting Hello World server on %s", address)
|
||||||
|
err := http.ListenAndServe(address, s)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto)
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
err := s.responseTemplate.Execute(&buffer, &templateData{
|
||||||
|
ServerName: s.serverName,
|
||||||
|
Request: r,
|
||||||
|
Tags: tagsFromHeaders(r.Header),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
fmt.Fprintf(w, "error: %v", err)
|
||||||
|
} else {
|
||||||
|
buffer.WriteTo(w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func tagsFromHeaders(header http.Header) []tunnelpogs.Tag {
|
||||||
|
var tags []tunnelpogs.Tag
|
||||||
|
for headerName, headerValues := range header {
|
||||||
|
trimmed := strings.TrimPrefix(headerName, origin.TagHeaderNamePrefix)
|
||||||
|
if trimmed == headerName {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, value := range headerValues {
|
||||||
|
tags = append(tags, tunnelpogs.Tag{Name: trimmed, Value: value})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tags
|
||||||
|
}
|
|
@ -0,0 +1,23 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testPort = "8080"
|
||||||
|
|
||||||
|
func TestNewHelloWorldServer(t *testing.T) {
|
||||||
|
if NewHelloWorldServer() == nil {
|
||||||
|
t.Fatal("NewHelloWorldServer returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindAvailablePort(t *testing.T) {
|
||||||
|
listener, err := findAvailablePort()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Fail to find available port")
|
||||||
|
}
|
||||||
|
if listener.Addr().String() == "" {
|
||||||
|
t.Fatal("Fail to find available port")
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,264 @@
|
||||||
|
// +build linux
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func runApp(app *cli.App) {
|
||||||
|
app.Commands = append(app.Commands, &cli.Command{
|
||||||
|
Name: "service",
|
||||||
|
Usage: "Manages the Cloudflare Warp system service",
|
||||||
|
Subcommands: []*cli.Command{
|
||||||
|
&cli.Command{
|
||||||
|
Name: "install",
|
||||||
|
Usage: "Install Cloudflare Warp as a system service",
|
||||||
|
Action: installLinuxService,
|
||||||
|
},
|
||||||
|
&cli.Command{
|
||||||
|
Name: "uninstall",
|
||||||
|
Usage: "Uninstall the Cloudflare Warp service",
|
||||||
|
Action: uninstallLinuxService,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
app.Run(os.Args)
|
||||||
|
}
|
||||||
|
|
||||||
|
var systemdTemplates = []ServiceTemplate{
|
||||||
|
{
|
||||||
|
Path: "/etc/systemd/system/cloudflare-warp.service",
|
||||||
|
Content: `[Unit]
|
||||||
|
Description=Cloudflare Warp
|
||||||
|
After=network.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
TimeoutStartSec=0
|
||||||
|
Type=notify
|
||||||
|
ExecStart={{ .Path }} --config /etc/cloudflare-warp.yml --autoupdate 0s
|
||||||
|
User=nobody
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/etc/systemd/system/cloudflare-warp-update.service",
|
||||||
|
Content: `[Unit]
|
||||||
|
Description=Update Cloudflare Warp
|
||||||
|
After=network.target
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 64 ]; then systemctl restart cloudflare-warp; exit 0; fi; exit $code'
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Path: "/etc/systemd/system/cloudflare-warp-update.timer",
|
||||||
|
Content: `[Unit]
|
||||||
|
Description=Update Cloudflare Warp
|
||||||
|
|
||||||
|
[Timer]
|
||||||
|
OnUnitActiveSec=1d
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=timers.target
|
||||||
|
`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var sysvTemplate = ServiceTemplate{
|
||||||
|
Path: "/etc/init.d/cloudflare-warp",
|
||||||
|
FileMode: 0755,
|
||||||
|
Content: `# For RedHat and cousins:
|
||||||
|
# chkconfig: 2345 99 01
|
||||||
|
# description: Cloudflare Warp agent
|
||||||
|
# processname: {{.Path}}
|
||||||
|
### BEGIN INIT INFO
|
||||||
|
# Provides: {{.Path}}
|
||||||
|
# Required-Start:
|
||||||
|
# Required-Stop:
|
||||||
|
# Default-Start: 2 3 4 5
|
||||||
|
# Default-Stop: 0 1 6
|
||||||
|
# Short-Description: Cloudflare Warp
|
||||||
|
# Description: Cloudflare Warp agent
|
||||||
|
### END INIT INFO
|
||||||
|
cmd="{{.Path}} --config /etc/cloudflare-warp.yml --pidfile /var/run/$name.pid"
|
||||||
|
name=$(basename $(readlink -f $0))
|
||||||
|
pid_file="/var/run/$name.pid"
|
||||||
|
stdout_log="/var/log/$name.log"
|
||||||
|
stderr_log="/var/log/$name.err"
|
||||||
|
[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
|
||||||
|
get_pid() {
|
||||||
|
cat "$pid_file"
|
||||||
|
}
|
||||||
|
is_running() {
|
||||||
|
[ -f "$pid_file" ] && ps $(get_pid) > /dev/null 2>&1
|
||||||
|
}
|
||||||
|
case "$1" in
|
||||||
|
start)
|
||||||
|
if is_running; then
|
||||||
|
echo "Already started"
|
||||||
|
else
|
||||||
|
echo "Starting $name"
|
||||||
|
$cmd >> "$stdout_log" 2>> "$stderr_log" &
|
||||||
|
echo $! > "$pid_file"
|
||||||
|
if ! is_running; then
|
||||||
|
echo "Unable to start, see $stdout_log and $stderr_log"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
stop)
|
||||||
|
if is_running; then
|
||||||
|
echo -n "Stopping $name.."
|
||||||
|
kill $(get_pid)
|
||||||
|
for i in {1..10}
|
||||||
|
do
|
||||||
|
if ! is_running; then
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
echo -n "."
|
||||||
|
sleep 1
|
||||||
|
done
|
||||||
|
echo
|
||||||
|
if is_running; then
|
||||||
|
echo "Not stopped; may still be shutting down or shutdown may have failed"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "Stopped"
|
||||||
|
if [ -f "$pid_file" ]; then
|
||||||
|
rm "$pid_file"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "Not running"
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
restart)
|
||||||
|
$0 stop
|
||||||
|
if is_running; then
|
||||||
|
echo "Unable to stop, will not attempt to start"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
$0 start
|
||||||
|
;;
|
||||||
|
status)
|
||||||
|
if is_running; then
|
||||||
|
echo "Running"
|
||||||
|
else
|
||||||
|
echo "Stopped"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Usage: $0 {start|stop|restart|status}"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
exit 0
|
||||||
|
`,
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSystemd() bool {
|
||||||
|
if _, err := os.Stat("/run/systemd/system"); err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func installLinuxService(c *cli.Context) error {
|
||||||
|
etPath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error determining executable path: %v", err)
|
||||||
|
}
|
||||||
|
templateArgs := ServiceTemplateArgs{Path: etPath}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case isSystemd():
|
||||||
|
return installSystemd(&templateArgs)
|
||||||
|
default:
|
||||||
|
return installSysv(&templateArgs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func installSystemd(templateArgs *ServiceTemplateArgs) error {
|
||||||
|
for _, serviceTemplate := range systemdTemplates {
|
||||||
|
err := serviceTemplate.Generate(templateArgs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := runCommand("systemctl", "enable", "cloudflare-warp.service"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := runCommand("systemctl", "start", "cloudflare-warp-update.timer"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return runCommand("systemctl", "daemon-reload")
|
||||||
|
}
|
||||||
|
|
||||||
|
func installSysv(templateArgs *ServiceTemplateArgs) error {
|
||||||
|
confPath, err := sysvTemplate.ResolvePath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := sysvTemplate.Generate(templateArgs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, i := range [...]string{"2", "3", "4", "5"} {
|
||||||
|
if err := os.Symlink(confPath, "/etc/rc"+i+".d/S50et"); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, i := range [...]string{"0", "1", "6"} {
|
||||||
|
if err := os.Symlink(confPath, "/etc/rc"+i+".d/K02et"); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func uninstallLinuxService(c *cli.Context) error {
|
||||||
|
switch {
|
||||||
|
case isSystemd():
|
||||||
|
return uninstallSystemd()
|
||||||
|
default:
|
||||||
|
return uninstallSysv()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func uninstallSystemd() error {
|
||||||
|
if err := runCommand("systemctl", "disable", "cloudflare-warp.service"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := runCommand("systemctl", "stop", "cloudflare-warp-update.timer"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, serviceTemplate := range systemdTemplates {
|
||||||
|
if err := serviceTemplate.Remove(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func uninstallSysv() error {
|
||||||
|
if err := sysvTemplate.Remove(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, i := range [...]string{"2", "3", "4", "5"} {
|
||||||
|
if err := os.Remove("/etc/rc" + i + ".d/S50et"); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, i := range [...]string{"0", "1", "6"} {
|
||||||
|
if err := os.Remove("/etc/rc" + i + ".d/K02et"); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,31 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
homedir "github.com/mitchellh/go-homedir"
|
||||||
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func login(c *cli.Context) error {
|
||||||
|
path, err := homedir.Expand(defaultConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fileInfo, err := os.Stat(path)
|
||||||
|
if err == nil && fileInfo.Size() > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, `You have an existing config file at %s which login would overwrite.
|
||||||
|
If this is intentional, please move or delete that file then run this command again.
|
||||||
|
`, defaultConfigPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil && err.(*os.PathError).Err != syscall.ENOENT {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintln(os.Stderr, "Please visit https://www.cloudflare.com/a/warp to obtain a certificate.")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,82 @@
|
||||||
|
// +build darwin
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func runApp(app *cli.App) {
|
||||||
|
app.Commands = append(app.Commands, &cli.Command{
|
||||||
|
Name: "service",
|
||||||
|
Usage: "Manages the Cloudflare Warp launch agent",
|
||||||
|
Subcommands: []*cli.Command{
|
||||||
|
&cli.Command{
|
||||||
|
Name: "install",
|
||||||
|
Usage: "Install Cloudflare Warp as an user launch agent",
|
||||||
|
Action: installLaunchd,
|
||||||
|
},
|
||||||
|
&cli.Command{
|
||||||
|
Name: "uninstall",
|
||||||
|
Usage: "Uninstall the Cloudflare Warp launch agent",
|
||||||
|
Action: uninstallLaunchd,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
app.Run(os.Args)
|
||||||
|
}
|
||||||
|
|
||||||
|
var launchdTemplate = ServiceTemplate{
|
||||||
|
Path: "~/Library/LaunchAgents/com.cloudflare.warp.plist",
|
||||||
|
Content: `<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||||
|
<plist version="1.0">
|
||||||
|
<dict>
|
||||||
|
<key>Label</key>
|
||||||
|
<string>com.cloudflare.warp</string>
|
||||||
|
<key>Program</key>
|
||||||
|
<string>{{ .Path }}</string>
|
||||||
|
<key>RunAtLoad</key>
|
||||||
|
<true/>
|
||||||
|
<key>KeepAlive</key>
|
||||||
|
<dict>
|
||||||
|
<key>NetworkState</key>
|
||||||
|
<true/>
|
||||||
|
</dict>
|
||||||
|
<key>ThrottleInterval</key>
|
||||||
|
<integer>20</integer>
|
||||||
|
</dict>
|
||||||
|
</plist>`,
|
||||||
|
}
|
||||||
|
|
||||||
|
func installLaunchd(c *cli.Context) error {
|
||||||
|
etPath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error determining executable path: %v", err)
|
||||||
|
}
|
||||||
|
templateArgs := ServiceTemplateArgs{Path: etPath}
|
||||||
|
err = launchdTemplate.Generate(&templateArgs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
plistPath, err := launchdTemplate.ResolvePath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return runCommand("launchctl", "load", plistPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func uninstallLaunchd(c *cli.Context) error {
|
||||||
|
plistPath, err := launchdTemplate.ResolvePath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = runCommand("launchctl", "unload", plistPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return launchdTemplate.Remove()
|
||||||
|
}
|
|
@ -0,0 +1,473 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflare-warp/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflare-warp/metrics"
|
||||||
|
"github.com/cloudflare/cloudflare-warp/origin"
|
||||||
|
"github.com/cloudflare/cloudflare-warp/tlsconfig"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||||
|
"github.com/cloudflare/cloudflare-warp/validation"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"github.com/facebookgo/grace/gracenet"
|
||||||
|
raven "github.com/getsentry/raven-go"
|
||||||
|
homedir "github.com/mitchellh/go-homedir"
|
||||||
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
|
"gopkg.in/urfave/cli.v2/altsrc"
|
||||||
|
|
||||||
|
"github.com/coreos/go-systemd/daemon"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
|
||||||
|
const defaultConfigPath = "~/.cloudflare-warp.yml"
|
||||||
|
|
||||||
|
var listeners = gracenet.Net{}
|
||||||
|
var Version = "DEV"
|
||||||
|
var BuildTime = "unknown"
|
||||||
|
|
||||||
|
// Shutdown channel used by the app. When closed, app must terminate.
|
||||||
|
// May be closed by the Windows service runner.
|
||||||
|
var shutdownC chan struct{}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
metrics.RegisterBuildInfo(BuildTime, Version)
|
||||||
|
raven.SetDSN(sentryDSN)
|
||||||
|
raven.SetRelease(Version)
|
||||||
|
shutdownC = make(chan struct{})
|
||||||
|
app := &cli.App{}
|
||||||
|
app.Name = "cloudflare-warp"
|
||||||
|
app.Copyright = `(c) 2017 Cloudflare Inc.
|
||||||
|
Use is subject to the license agreement at https://warp.cloudflare.com/licence/`
|
||||||
|
app.Usage = "Cloudflare reverse tunnelling proxy agent \033[1;31m*BETA*\033[0m"
|
||||||
|
app.ArgsUsage = "origin-url"
|
||||||
|
app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime)
|
||||||
|
app.Description = `A reverse tunnel proxy agent that connects to Cloudflare's infrastructure.
|
||||||
|
Upon connecting, you are assigned a unique subdomain on cftunnel.com.
|
||||||
|
Alternatively, you can specify a hostname on a zone you control.
|
||||||
|
|
||||||
|
Requests made to Cloudflare's servers for your hostname will be proxied
|
||||||
|
through the tunnel to your local webserver.
|
||||||
|
|
||||||
|
WARNING:
|
||||||
|
` + "\033[1;31m*** THIS IS A BETA VERSION OF THE CLOUDFLARE WARP AGENT ***\033[0m" + `
|
||||||
|
|
||||||
|
At this time, do not use Cloudflare Warp for connecting production servers to Cloudflare.
|
||||||
|
Availability and reliability of this service is not guaranteed through the beta period.`
|
||||||
|
app.Flags = []cli.Flag{
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "config",
|
||||||
|
Usage: "Specifies a config file in YAML format.",
|
||||||
|
},
|
||||||
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
|
Name: "autoupdate",
|
||||||
|
Usage: "Periodically check for updates, restarting the server with the new version.",
|
||||||
|
Value: time.Hour * 24,
|
||||||
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "edge",
|
||||||
|
Value: "cftunnel.com:7844",
|
||||||
|
Usage: "Address of the Cloudflare tunnel server.",
|
||||||
|
EnvVars: []string{"TUNNEL_EDGE"},
|
||||||
|
Hidden: true,
|
||||||
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "cacert",
|
||||||
|
Usage: "Certificate Authority authenticating the Cloudflare tunnel connection.",
|
||||||
|
EnvVars: []string{"TUNNEL_CACERT"},
|
||||||
|
Hidden: true,
|
||||||
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "url",
|
||||||
|
Value: "http://localhost:8080",
|
||||||
|
Usage: "Connect to the local webserver at `URL`.",
|
||||||
|
EnvVars: []string{"TUNNEL_URL"},
|
||||||
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "hostname",
|
||||||
|
Usage: "Set a hostname on a Cloudflare zone to route traffic through this tunnel.",
|
||||||
|
EnvVars: []string{"TUNNEL_HOSTNAME"},
|
||||||
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "id",
|
||||||
|
Usage: "A unique identifier used to tie connections to this tunnel instance.",
|
||||||
|
EnvVars: []string{"TUNNEL_ID"},
|
||||||
|
Hidden: true,
|
||||||
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "lb-pool",
|
||||||
|
Usage: "The name of a (new/existing) load balancing pool to add this origin to.",
|
||||||
|
EnvVars: []string{"TUNNEL_LB_POOL"},
|
||||||
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "api-key",
|
||||||
|
Usage: "A Cloudflare API key. Required(can be in the config file) unless you are only running the hello command or login command.",
|
||||||
|
EnvVars: []string{"TUNNEL_API_KEY"},
|
||||||
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "api-email",
|
||||||
|
Usage: "The Cloudflare user's email address associated with the API key. Required(can be in the config file) unless you are only running the hello command or login command.",
|
||||||
|
EnvVars: []string{"TUNNEL_API_EMAIL"},
|
||||||
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "api-ca-key",
|
||||||
|
Usage: "The Origin CA service key associated with the user. Required(can be in the config file) unless you are only running the hello command or login command.",
|
||||||
|
EnvVars: []string{"TUNNEL_API_CA_KEY"},
|
||||||
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "metrics",
|
||||||
|
Value: "localhost:",
|
||||||
|
Usage: "Listen address for metrics reporting.",
|
||||||
|
EnvVars: []string{"TUNNEL_METRICS"},
|
||||||
|
}),
|
||||||
|
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
|
||||||
|
Name: "tag",
|
||||||
|
Usage: "Custom tags used to identify this tunnel, in format `KEY=VALUE`. Multiple tags may be specified",
|
||||||
|
EnvVars: []string{"TUNNEL_TAG"},
|
||||||
|
}),
|
||||||
|
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||||
|
Name: "heartbeat-interval",
|
||||||
|
Usage: "Minimum idle time before sending a heartbeat.",
|
||||||
|
Value: time.Second * 5,
|
||||||
|
Hidden: true,
|
||||||
|
}),
|
||||||
|
altsrc.NewUint64Flag(&cli.Uint64Flag{
|
||||||
|
Name: "heartbeat-count",
|
||||||
|
Usage: "Minimum number of unacked heartbeats to send before closing the connection.",
|
||||||
|
Value: 5,
|
||||||
|
Hidden: true,
|
||||||
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "loglevel",
|
||||||
|
Value: "info",
|
||||||
|
Usage: "Logging level {panic, fatal, error, warn, info, debug}",
|
||||||
|
EnvVars: []string{"TUNNEL_LOGLEVEL"},
|
||||||
|
}),
|
||||||
|
altsrc.NewUintFlag(&cli.UintFlag{
|
||||||
|
Name: "retries",
|
||||||
|
Value: 5,
|
||||||
|
Usage: "Maximum number of retries for connection/protocol errors.",
|
||||||
|
EnvVars: []string{"TUNNEL_RETRIES"},
|
||||||
|
}),
|
||||||
|
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||||
|
Name: "debug",
|
||||||
|
Value: false,
|
||||||
|
Usage: "Enable HTTP requests to the autogenerated cftunnel.com domain.",
|
||||||
|
EnvVars: []string{"TUNNEL_DEBUG"},
|
||||||
|
}),
|
||||||
|
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||||
|
Name: "hello-world",
|
||||||
|
Usage: "Run Hello World Server",
|
||||||
|
Value: false,
|
||||||
|
}),
|
||||||
|
altsrc.NewStringFlag(&cli.StringFlag{
|
||||||
|
Name: "pidfile",
|
||||||
|
Usage: "Write the application's PID to this file after first successful connection.",
|
||||||
|
EnvVars: []string{"TUNNEL_PIDFILE"},
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
app.Action = func(c *cli.Context) error {
|
||||||
|
raven.CapturePanic(func() { startServer(c) }, nil)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
app.Before = func(context *cli.Context) error {
|
||||||
|
inputSource, err := findInputSourceContext(context)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if inputSource != nil {
|
||||||
|
return altsrc.ApplyInputSourceValues(context, inputSource, app.Flags)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
app.Commands = []*cli.Command{
|
||||||
|
&cli.Command{
|
||||||
|
Name: "update",
|
||||||
|
Action: update,
|
||||||
|
Usage: "Update the agent if a new version exists",
|
||||||
|
ArgsUsage: " ",
|
||||||
|
Description: `Looks for a new version on the offical download server.
|
||||||
|
If a new version exists, updates the agent binary and quits.
|
||||||
|
Otherwise, does nothing.
|
||||||
|
|
||||||
|
To determine if an update happened in a script, check for error code 64.`,
|
||||||
|
},
|
||||||
|
&cli.Command{
|
||||||
|
Name: "login",
|
||||||
|
Action: login,
|
||||||
|
Usage: "Generate a configuration file with your login details",
|
||||||
|
ArgsUsage: " ",
|
||||||
|
},
|
||||||
|
&cli.Command{
|
||||||
|
Name: "hello",
|
||||||
|
Action: hello,
|
||||||
|
Usage: "Run a simple \"Hello World\" server for testing Cloudflare Warp.",
|
||||||
|
Flags: []cli.Flag{
|
||||||
|
&cli.IntFlag{
|
||||||
|
Name: "port",
|
||||||
|
Usage: "Listen on the selected port.",
|
||||||
|
Value: 8080,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ArgsUsage: " ", // can't be the empty string or we get the default output
|
||||||
|
},
|
||||||
|
}
|
||||||
|
runApp(app)
|
||||||
|
}
|
||||||
|
|
||||||
|
func startServer(c *cli.Context) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errC := make(chan error)
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
if c.NumFlags() == 0 && c.NArg() == 0 {
|
||||||
|
cli.ShowAppHelp(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logLevel, err := log.ParseLevel(c.String("loglevel"))
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Fatal("Unknown logging level specified")
|
||||||
|
}
|
||||||
|
log.SetLevel(logLevel)
|
||||||
|
hostname, err := validation.ValidateHostname(c.String("hostname"))
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Fatal("Invalid hostname")
|
||||||
|
|
||||||
|
}
|
||||||
|
clientID := c.String("id")
|
||||||
|
if !c.IsSet("id") {
|
||||||
|
clientID = generateRandomClientID()
|
||||||
|
}
|
||||||
|
|
||||||
|
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Fatal("Tag parse failure")
|
||||||
|
}
|
||||||
|
|
||||||
|
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
|
||||||
|
|
||||||
|
if c.IsSet("hello-world") {
|
||||||
|
wg.Add(1)
|
||||||
|
listener, err := findAvailablePort()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
listener.Close()
|
||||||
|
log.WithError(err).Fatal("Cannot start Hello World Server")
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
startHelloWorldServer(listener, shutdownC)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
c.Set("url", "http://"+listener.Addr().String())
|
||||||
|
log.Infof("Starting Hello World Server at %s", c.String("url"))
|
||||||
|
}
|
||||||
|
url, err := validateUrl(c)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Fatal("Error validating url")
|
||||||
|
}
|
||||||
|
// User must have api-key, api-email and api-ca-key
|
||||||
|
if !c.IsSet("api-key") {
|
||||||
|
log.Fatal("You need to give us your api-key either via the --api-key option or put it in the configuration file. You will also need to give us your api-email and api-ca-key.")
|
||||||
|
}
|
||||||
|
if !c.IsSet("api-email") {
|
||||||
|
log.Fatal("You need to give us your api-email either via the --api-email option or put it in the configuration file. You will also need to give us your api-ca-key.")
|
||||||
|
}
|
||||||
|
if !c.IsSet("api-ca-key") {
|
||||||
|
log.Fatal("You need to give us your api-ca-key either via the --api-ca-key option or put it in the configuration file.")
|
||||||
|
}
|
||||||
|
log.Infof("Proxying tunnel requests to %s", url)
|
||||||
|
tunnelConfig := &origin.TunnelConfig{
|
||||||
|
EdgeAddr: c.String("edge"),
|
||||||
|
OriginUrl: url,
|
||||||
|
Hostname: hostname,
|
||||||
|
APIKey: c.String("api-key"),
|
||||||
|
APIEmail: c.String("api-email"),
|
||||||
|
APICAKey: c.String("api-ca-key"),
|
||||||
|
TlsConfig: &tls.Config{},
|
||||||
|
Retries: c.Uint("retries"),
|
||||||
|
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
||||||
|
MaxHeartbeats: c.Uint64("heartbeat-count"),
|
||||||
|
ClientID: clientID,
|
||||||
|
ReportedVersion: Version,
|
||||||
|
LBPool: c.String("lb-pool"),
|
||||||
|
Tags: tags,
|
||||||
|
AccessInternalIP: c.Bool("debug"),
|
||||||
|
ConnectedSignal: h2mux.NewSignal(),
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnelConfig.TlsConfig = tlsconfig.CLIFlags{RootCA: "cacert"}.GetConfig(c)
|
||||||
|
if tunnelConfig.TlsConfig.RootCAs == nil {
|
||||||
|
tunnelConfig.TlsConfig.RootCAs = GetCloudflareRootCA()
|
||||||
|
tunnelConfig.TlsConfig.ServerName = "cftunnel.com"
|
||||||
|
} else {
|
||||||
|
tunnelConfig.TlsConfig.ServerName, _, _ = net.SplitHostPort(tunnelConfig.EdgeAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
go writePidFile(tunnelConfig.ConnectedSignal, c.String("pidfile"))
|
||||||
|
go func() {
|
||||||
|
errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Fatal("Error opening metrics server listener")
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
errC <- metrics.ServeMetrics(metricsListener, shutdownC)
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
go autoupdate(c.Duration("autoupdate"), shutdownC)
|
||||||
|
|
||||||
|
err = WaitForSignal(errC, shutdownC)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("Quitting due to error")
|
||||||
|
raven.CaptureErrorAndWait(err, nil)
|
||||||
|
} else {
|
||||||
|
log.Info("Quitting...")
|
||||||
|
}
|
||||||
|
// Wait for clean exit, discarding all errors
|
||||||
|
go func() {
|
||||||
|
for range errC {
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
|
||||||
|
signals := make(chan os.Signal, 10)
|
||||||
|
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
defer signal.Stop(signals)
|
||||||
|
select {
|
||||||
|
case err := <-errC:
|
||||||
|
close(shutdownC)
|
||||||
|
return err
|
||||||
|
case <-signals:
|
||||||
|
close(shutdownC)
|
||||||
|
case <-shutdownC:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func update(c *cli.Context) error {
|
||||||
|
if updateApplied() {
|
||||||
|
os.Exit(64)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func autoupdate(frequency time.Duration, shutdownC chan struct{}) {
|
||||||
|
if int64(frequency) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
if updateApplied() {
|
||||||
|
if _, err := listeners.StartProcess(); err != nil {
|
||||||
|
log.WithError(err).Error("Unable to restart server automatically")
|
||||||
|
}
|
||||||
|
close(shutdownC)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(frequency)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateApplied() bool {
|
||||||
|
releaseInfo := checkForUpdates()
|
||||||
|
if releaseInfo.Updated {
|
||||||
|
log.Infof("Updated to version %s", releaseInfo.Version)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if releaseInfo.Error != nil {
|
||||||
|
log.WithError(releaseInfo.Error).Error("Update check failed")
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileExists(path string) (bool, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
// ignore missing files
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
f.Close()
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
|
||||||
|
if context.IsSet("config") {
|
||||||
|
return altsrc.NewYamlSourceFromFile(context.String("config"))
|
||||||
|
}
|
||||||
|
for _, tryPath := range []string{
|
||||||
|
defaultConfigPath,
|
||||||
|
"~/.cloudflare-warp.yaml",
|
||||||
|
"~/cloudflare-warp.yaml",
|
||||||
|
"~/cloudflare-warp.yml",
|
||||||
|
"~/.et.yaml",
|
||||||
|
"~/et.yml",
|
||||||
|
"~/et.yaml",
|
||||||
|
"~/.cftunnel.yaml", // for existing users
|
||||||
|
"~/cftunnel.yaml",
|
||||||
|
} {
|
||||||
|
path, err := homedir.Expand(tryPath)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ok, err := fileExists(path)
|
||||||
|
if ok {
|
||||||
|
return altsrc.NewYamlSourceFromFile(path)
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateRandomClientID() string {
|
||||||
|
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
id := make([]byte, 32)
|
||||||
|
r.Read(id)
|
||||||
|
return hex.EncodeToString(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writePidFile(waitForSignal h2mux.Signal, pidFile string) {
|
||||||
|
waitForSignal.Wait()
|
||||||
|
daemon.SdNotify(false, "READY=1")
|
||||||
|
if pidFile == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
file, err := os.Create(pidFile)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
fmt.Fprintf(file, "%d", os.Getpid())
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate url. It can be either from --url or argument
|
||||||
|
func validateUrl(c *cli.Context) (string, error) {
|
||||||
|
var url = c.String("url")
|
||||||
|
if c.NArg() > 0 {
|
||||||
|
if c.IsSet("url") {
|
||||||
|
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
|
||||||
|
}
|
||||||
|
url = c.Args().Get(0)
|
||||||
|
}
|
||||||
|
validUrl, err := validation.ValidateUrl(url)
|
||||||
|
return validUrl, err
|
||||||
|
}
|
|
@ -0,0 +1,88 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
homedir "github.com/mitchellh/go-homedir"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServiceTemplate struct {
|
||||||
|
Path string
|
||||||
|
Content string
|
||||||
|
FileMode os.FileMode
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServiceTemplateArgs struct {
|
||||||
|
Path string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (st *ServiceTemplate) ResolvePath() (string, error) {
|
||||||
|
resolvedPath, err := homedir.Expand(st.Path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error resolving path %s: %v", st.Path, err)
|
||||||
|
}
|
||||||
|
return resolvedPath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
|
||||||
|
tmpl, err := template.New(st.Path).Parse(st.Content)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error generating %s template: %v", st.Path, err)
|
||||||
|
}
|
||||||
|
resolvedPath, err := st.ResolvePath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
err = tmpl.Execute(&buffer, args)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error generating %s: %v", st.Path, err)
|
||||||
|
}
|
||||||
|
fileMode := os.FileMode(0644)
|
||||||
|
if st.FileMode != 0 {
|
||||||
|
fileMode = st.FileMode
|
||||||
|
}
|
||||||
|
err = ioutil.WriteFile(resolvedPath, buffer.Bytes(), fileMode)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error writing %s: %v", resolvedPath, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (st *ServiceTemplate) Remove() error {
|
||||||
|
resolvedPath, err := st.ResolvePath()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = os.Remove(resolvedPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error deleting %s: %v", resolvedPath, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runCommand(command string, args ...string) error {
|
||||||
|
cmd := exec.Command(command, args...)
|
||||||
|
stderr, err := cmd.StderrPipe()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting stderr pipe: %v", err)
|
||||||
|
}
|
||||||
|
err = cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error starting %s: %v", command, err)
|
||||||
|
}
|
||||||
|
commandErr, _ := ioutil.ReadAll(stderr)
|
||||||
|
if len(commandErr) > 0 {
|
||||||
|
return fmt.Errorf("%s error: %s", command, commandErr)
|
||||||
|
}
|
||||||
|
err = cmd.Wait()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%s returned with error: %v", command, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Restrict key names to characters allowed in an HTTP header name.
|
||||||
|
// Restrict key values to printable characters (what is recognised as data in an HTTP header value).
|
||||||
|
var tagRegexp = regexp.MustCompile("^([a-zA-Z0-9!#$%&'*+\\-.^_`|~]+)=([[:print:]]+)$")
|
||||||
|
|
||||||
|
func NewTagFromCLI(compoundTag string) (tunnelpogs.Tag, bool) {
|
||||||
|
matches := tagRegexp.FindStringSubmatch(compoundTag)
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return tunnelpogs.Tag{}, false
|
||||||
|
}
|
||||||
|
return tunnelpogs.Tag{Name: matches[1], Value: matches[2]}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTagSliceFromCLI(tags []string) ([]tunnelpogs.Tag, error) {
|
||||||
|
var tagSlice []tunnelpogs.Tag
|
||||||
|
for _, compoundTag := range tags {
|
||||||
|
if tag, ok := NewTagFromCLI(compoundTag); ok {
|
||||||
|
tagSlice = append(tagSlice, tag)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("Cannot parse tag value %s", compoundTag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tagSlice, nil
|
||||||
|
}
|
|
@ -0,0 +1,46 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSingleTag(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
Input string
|
||||||
|
Output tunnelpogs.Tag
|
||||||
|
Fail bool
|
||||||
|
}{
|
||||||
|
{Input: "x=y", Output: tunnelpogs.Tag{Name: "x", Value: "y"}},
|
||||||
|
{Input: "More-Complex=Tag Values", Output: tunnelpogs.Tag{Name: "More-Complex", Value: "Tag Values"}},
|
||||||
|
{Input: "First=Equals=Wins", Output: tunnelpogs.Tag{Name: "First", Value: "Equals=Wins"}},
|
||||||
|
{Input: "x=", Fail: true},
|
||||||
|
{Input: "=y", Fail: true},
|
||||||
|
{Input: "=", Fail: true},
|
||||||
|
{Input: "No spaces allowed=in key names", Fail: true},
|
||||||
|
{Input: "omg\nwtf=bbq", Fail: true},
|
||||||
|
}
|
||||||
|
for i, testCase := range testCases {
|
||||||
|
tag, ok := NewTagFromCLI(testCase.Input)
|
||||||
|
assert.Equalf(t, !testCase.Fail, ok, "mismatched success for test case %d", i)
|
||||||
|
assert.Equalf(t, testCase.Output, tag, "mismatched output for test case %d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTagSlice(t *testing.T) {
|
||||||
|
tagSlice, err := NewTagSliceFromCLI([]string{"a=b", "c=d", "e=f"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, tagSlice, 3)
|
||||||
|
assert.Equal(t, "a", tagSlice[0].Name)
|
||||||
|
assert.Equal(t, "b", tagSlice[0].Value)
|
||||||
|
assert.Equal(t, "c", tagSlice[1].Name)
|
||||||
|
assert.Equal(t, "d", tagSlice[1].Value)
|
||||||
|
assert.Equal(t, "e", tagSlice[2].Name)
|
||||||
|
assert.Equal(t, "f", tagSlice[2].Value)
|
||||||
|
|
||||||
|
tagSlice, err = NewTagSliceFromCLI([]string{"a=b", "=", "e=f"})
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import "github.com/equinox-io/equinox"
|
||||||
|
|
||||||
|
const appID = "app_cwbQae3Tpea"
|
||||||
|
|
||||||
|
var publicKey = []byte(`
|
||||||
|
-----BEGIN ECDSA PUBLIC KEY-----
|
||||||
|
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE4OWZocTVZ8Do/L6ScLdkV+9A0IYMHoOf
|
||||||
|
dsCmJ/QZ6aw0w9qkkwEpne1Lmo6+0pGexZzFZOH6w5amShn+RXt7qkSid9iWlzGq
|
||||||
|
EKx0BZogHSor9Wy5VztdFaAaVbsJiCbO
|
||||||
|
-----END ECDSA PUBLIC KEY-----
|
||||||
|
`)
|
||||||
|
|
||||||
|
type ReleaseInfo struct {
|
||||||
|
Updated bool
|
||||||
|
Version string
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkForUpdates() ReleaseInfo {
|
||||||
|
var opts equinox.Options
|
||||||
|
if err := opts.SetPublicKeyPEM(publicKey); err != nil {
|
||||||
|
return ReleaseInfo{Error: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := equinox.Check(appID, opts)
|
||||||
|
switch {
|
||||||
|
case err == equinox.NotAvailableErr:
|
||||||
|
return ReleaseInfo{}
|
||||||
|
case err != nil:
|
||||||
|
return ReleaseInfo{Error: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = resp.Apply()
|
||||||
|
if err != nil {
|
||||||
|
return ReleaseInfo{Error: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion}
|
||||||
|
}
|
|
@ -0,0 +1,145 @@
|
||||||
|
// +build windows
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
// Copypasta from the example files:
|
||||||
|
// https://github.com/golang/sys/blob/master/windows/svc/example
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows/svc"
|
||||||
|
"golang.org/x/sys/windows/svc/eventlog"
|
||||||
|
"golang.org/x/sys/windows/svc/mgr"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
windowsServiceName = "CloudflareWarp"
|
||||||
|
windowsServiceDescription = "Cloudflare Warp agent"
|
||||||
|
)
|
||||||
|
|
||||||
|
func runApp(app *cli.App) {
|
||||||
|
app.Commands = append(app.Commands, &cli.Command{
|
||||||
|
Name: "service",
|
||||||
|
Usage: "Manages the Cloudflare Warp Windows service",
|
||||||
|
Subcommands: []*cli.Command{
|
||||||
|
&cli.Command{
|
||||||
|
Name: "install",
|
||||||
|
Usage: "Install Cloudflare Warp as a Windows service",
|
||||||
|
Action: installWindowsService,
|
||||||
|
},
|
||||||
|
&cli.Command{
|
||||||
|
Name: "uninstall",
|
||||||
|
Usage: "Uninstall the Cloudflare Warp service",
|
||||||
|
Action: uninstallWindowsService,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
isIntSess, err := svc.IsAnInteractiveSession()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to determine if we are running in an interactive session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isIntSess {
|
||||||
|
app.Run(os.Args)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
elog, err := eventlog.Open(windowsServiceName)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer elog.Close()
|
||||||
|
|
||||||
|
elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName))
|
||||||
|
err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog})
|
||||||
|
if err != nil {
|
||||||
|
elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
elog.Info(1, fmt.Sprintf("%s service stopped", windowsServiceName))
|
||||||
|
}
|
||||||
|
|
||||||
|
type windowsService struct {
|
||||||
|
app *cli.App
|
||||||
|
elog *eventlog.Log
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) {
|
||||||
|
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
|
||||||
|
changes <- svc.Status{State: svc.StartPending}
|
||||||
|
go s.app.Run(args)
|
||||||
|
changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case c := <-r:
|
||||||
|
switch c.Cmd {
|
||||||
|
case svc.Interrogate:
|
||||||
|
changes <- c.CurrentStatus
|
||||||
|
case svc.Stop, svc.Shutdown:
|
||||||
|
break loop
|
||||||
|
default:
|
||||||
|
s.elog.Error(1, fmt.Sprintf("unexpected control request #%d", c))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(shutdownC)
|
||||||
|
changes <- svc.Status{State: svc.StopPending}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func installWindowsService(c *cli.Context) error {
|
||||||
|
exepath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
s, err := m.OpenService(windowsServiceName)
|
||||||
|
if err == nil {
|
||||||
|
s.Close()
|
||||||
|
return fmt.Errorf("service %s already exists", windowsServiceName)
|
||||||
|
}
|
||||||
|
s, err = m.CreateService(windowsServiceName, exepath, mgr.Config{DisplayName: windowsServiceDescription}, "is", "auto-started")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info)
|
||||||
|
if err != nil {
|
||||||
|
s.Delete()
|
||||||
|
return fmt.Errorf("SetupEventLogSource() failed: %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func uninstallWindowsService(c *cli.Context) error {
|
||||||
|
m, err := mgr.Connect()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer m.Disconnect()
|
||||||
|
s, err := m.OpenService(windowsServiceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("service %s is not installed", windowsServiceName)
|
||||||
|
}
|
||||||
|
defer s.Close()
|
||||||
|
err = s.Delete()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = eventlog.Remove(windowsServiceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,213 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// activeStreamMap is used to moderate access to active streams between the read and write
|
||||||
|
// threads, and deny access to new peer streams while shutting down.
|
||||||
|
type activeStreamMap struct {
|
||||||
|
sync.RWMutex
|
||||||
|
// streams tracks open streams.
|
||||||
|
streams map[uint32]*MuxedStream
|
||||||
|
// streamsEmpty is a chan that should be closed when no more streams are open.
|
||||||
|
streamsEmpty chan struct{}
|
||||||
|
// nextStreamID is the next ID to use on our side of the connection.
|
||||||
|
// This is odd for clients, even for servers.
|
||||||
|
nextStreamID uint32
|
||||||
|
// maxPeerStreamID is the ID of the most recent stream opened by the peer.
|
||||||
|
maxPeerStreamID uint32
|
||||||
|
// ignoreNewStreams is true when the connection is being shut down. New streams
|
||||||
|
// cannot be registered.
|
||||||
|
ignoreNewStreams bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type FlowControlMetrics struct {
|
||||||
|
AverageReceiveWindowSize, AverageSendWindowSize float64
|
||||||
|
MinReceiveWindowSize, MaxReceiveWindowSize uint32
|
||||||
|
MinSendWindowSize, MaxSendWindowSize uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap {
|
||||||
|
m := &activeStreamMap{
|
||||||
|
streams: make(map[uint32]*MuxedStream),
|
||||||
|
streamsEmpty: make(chan struct{}),
|
||||||
|
nextStreamID: 1,
|
||||||
|
}
|
||||||
|
// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
|
||||||
|
if !useClientStreamNumbers {
|
||||||
|
m.nextStreamID = 2
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of active streams.
|
||||||
|
func (m *activeStreamMap) Len() int {
|
||||||
|
m.RLock()
|
||||||
|
defer m.RUnlock()
|
||||||
|
return len(m.streams)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *activeStreamMap) Get(streamID uint32) (*MuxedStream, bool) {
|
||||||
|
m.RLock()
|
||||||
|
defer m.RUnlock()
|
||||||
|
stream, ok := m.streams[streamID]
|
||||||
|
return stream, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set returns true if the stream was assigned successfully. If a stream
|
||||||
|
// already existed with that ID or we are shutting down, return false.
|
||||||
|
func (m *activeStreamMap) Set(newStream *MuxedStream) bool {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
if _, ok := m.streams[newStream.streamID]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if m.ignoreNewStreams {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
m.streams[newStream.streamID] = newStream
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete stops tracking the stream. It should be called only after it is closed and resetted.
|
||||||
|
func (m *activeStreamMap) Delete(streamID uint32) {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
delete(m.streams, streamID)
|
||||||
|
if len(m.streams) == 0 && m.streamsEmpty != nil {
|
||||||
|
close(m.streamsEmpty)
|
||||||
|
m.streamsEmpty = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown blocks new streams from being created. It returns a channel that receives an event
|
||||||
|
// once the last stream has closed, or nil if a shutdown is in progress.
|
||||||
|
func (m *activeStreamMap) Shutdown() <-chan struct{} {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
if m.ignoreNewStreams {
|
||||||
|
// already shutting down
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.ignoreNewStreams = true
|
||||||
|
done := make(chan struct{})
|
||||||
|
if len(m.streams) == 0 {
|
||||||
|
// nothing to shut down
|
||||||
|
close(done)
|
||||||
|
return done
|
||||||
|
}
|
||||||
|
m.streamsEmpty = done
|
||||||
|
return done
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcquireLocalID acquires a new stream ID for a stream you're opening.
|
||||||
|
func (m *activeStreamMap) AcquireLocalID() uint32 {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
x := m.nextStreamID
|
||||||
|
m.nextStreamID += 2
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
// ObservePeerID observes the ID of a stream opened by the peer. It returns true if we should accept
|
||||||
|
// the new stream, or false to reject it. The ErrCode gives the reason why.
|
||||||
|
func (m *activeStreamMap) AcquirePeerID(streamID uint32) (bool, http2.ErrCode) {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
switch {
|
||||||
|
case m.ignoreNewStreams:
|
||||||
|
return false, http2.ErrCodeStreamClosed
|
||||||
|
case streamID > m.maxPeerStreamID:
|
||||||
|
m.maxPeerStreamID = streamID
|
||||||
|
return true, http2.ErrCodeNo
|
||||||
|
default:
|
||||||
|
return false, http2.ErrCodeStreamClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPeerStreamID is true if the stream ID belongs to the peer.
|
||||||
|
func (m *activeStreamMap) IsPeerStreamID(streamID uint32) bool {
|
||||||
|
m.RLock()
|
||||||
|
defer m.RUnlock()
|
||||||
|
return (streamID % 2) != (m.nextStreamID % 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLocalStreamID is true if it is a stream we have opened, even if it is now closed.
|
||||||
|
func (m *activeStreamMap) IsLocalStreamID(streamID uint32) bool {
|
||||||
|
m.RLock()
|
||||||
|
defer m.RUnlock()
|
||||||
|
return (streamID%2) == (m.nextStreamID%2) && streamID < m.nextStreamID
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastPeerStreamID returns the most recently opened peer stream ID.
|
||||||
|
func (m *activeStreamMap) LastPeerStreamID() uint32 {
|
||||||
|
m.RLock()
|
||||||
|
defer m.RUnlock()
|
||||||
|
return m.maxPeerStreamID
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastLocalStreamID returns the most recently opened local stream ID.
|
||||||
|
func (m *activeStreamMap) LastLocalStreamID() uint32 {
|
||||||
|
m.RLock()
|
||||||
|
defer m.RUnlock()
|
||||||
|
if m.nextStreamID > 1 {
|
||||||
|
return m.nextStreamID - 2
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abort closes every active stream and prevents new ones being created. This should be used to
|
||||||
|
// return errors in pending read/writes when the underlying connection goes away.
|
||||||
|
func (m *activeStreamMap) Abort() {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
for _, stream := range m.streams {
|
||||||
|
stream.Close()
|
||||||
|
}
|
||||||
|
m.ignoreNewStreams = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *activeStreamMap) Metrics() *FlowControlMetrics {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
var averageReceiveWindowSize, averageSendWindowSize float64
|
||||||
|
var minReceiveWindowSize, maxReceiveWindowSize, minSendWindowSize, maxSendWindowSize uint32
|
||||||
|
i := 0
|
||||||
|
// The first variable in the range expression for map is the key, not index.
|
||||||
|
for _, stream := range m.streams {
|
||||||
|
// iterative mean: a(t+1) = a(t) + (a(t)-x)/(t+1)
|
||||||
|
windows := stream.FlowControlWindow()
|
||||||
|
averageReceiveWindowSize += (float64(windows.receiveWindow) - averageReceiveWindowSize) / float64(i+1)
|
||||||
|
averageSendWindowSize += (float64(windows.sendWindow) - averageSendWindowSize) / float64(i+1)
|
||||||
|
if i == 0 {
|
||||||
|
maxReceiveWindowSize = windows.receiveWindow
|
||||||
|
minReceiveWindowSize = windows.receiveWindow
|
||||||
|
maxSendWindowSize = windows.sendWindow
|
||||||
|
minSendWindowSize = windows.sendWindow
|
||||||
|
} else {
|
||||||
|
if windows.receiveWindow > maxReceiveWindowSize {
|
||||||
|
maxReceiveWindowSize = windows.receiveWindow
|
||||||
|
} else if windows.receiveWindow < minReceiveWindowSize {
|
||||||
|
minReceiveWindowSize = windows.receiveWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
if windows.sendWindow > maxSendWindowSize {
|
||||||
|
maxSendWindowSize = windows.sendWindow
|
||||||
|
} else if windows.sendWindow < minSendWindowSize {
|
||||||
|
minSendWindowSize = windows.sendWindow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return &FlowControlMetrics{
|
||||||
|
MinReceiveWindowSize: minReceiveWindowSize,
|
||||||
|
MaxReceiveWindowSize: maxReceiveWindowSize,
|
||||||
|
AverageReceiveWindowSize: averageReceiveWindowSize,
|
||||||
|
MinSendWindowSize: minSendWindowSize,
|
||||||
|
MaxSendWindowSize: maxSendWindowSize,
|
||||||
|
AverageSendWindowSize: averageSendWindowSize,
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMetrics(t *testing.T) {
|
||||||
|
streamMap := newActiveStreamMap(false)
|
||||||
|
for i := 1; i <= 7; i++ {
|
||||||
|
stream := new(MuxedStream)
|
||||||
|
*stream = MuxedStream{
|
||||||
|
streamID: defaultWindowSize * uint32(i),
|
||||||
|
receiveWindow: defaultWindowSize * uint32(i),
|
||||||
|
sendWindow: defaultWindowSize * uint32(i),
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, streamMap.Set(stream))
|
||||||
|
}
|
||||||
|
metrics := streamMap.Metrics()
|
||||||
|
assert.Equal(t, float64(defaultWindowSize*4), metrics.AverageReceiveWindowSize)
|
||||||
|
assert.Equal(t, float64(defaultWindowSize*4), metrics.AverageSendWindowSize)
|
||||||
|
assert.Equal(t, defaultWindowSize, metrics.MinReceiveWindowSize)
|
||||||
|
assert.Equal(t, defaultWindowSize, metrics.MinSendWindowSize)
|
||||||
|
assert.Equal(t, defaultWindowSize*7, metrics.MaxReceiveWindowSize)
|
||||||
|
assert.Equal(t, defaultWindowSize*7, metrics.MaxSendWindowSize)
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import "sync/atomic"
|
||||||
|
|
||||||
|
// BooleanFuse is a data structure that can be set once to a particular value using Fuse(value).
|
||||||
|
// Subsequent calls to Fuse() will have no effect.
|
||||||
|
type BooleanFuse struct {
|
||||||
|
value int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value gets the value
|
||||||
|
func (f *BooleanFuse) Value() bool {
|
||||||
|
// 0: unset
|
||||||
|
// 1: set true
|
||||||
|
// 2: set false
|
||||||
|
return atomic.LoadInt32(&f.value) == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *BooleanFuse) Fuse(result bool) {
|
||||||
|
newValue := int32(2)
|
||||||
|
if result {
|
||||||
|
newValue = 1
|
||||||
|
}
|
||||||
|
atomic.CompareAndSwapInt32(&f.value, 0, newValue)
|
||||||
|
}
|
|
@ -0,0 +1,61 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrHandshakeTimeout = MuxerHandshakeError{"1000 handshake timeout"}
|
||||||
|
ErrBadHandshakeNotSettings = MuxerHandshakeError{"1001 unexpected response"}
|
||||||
|
ErrBadHandshakeUnexpectedAck = MuxerHandshakeError{"1002 unexpected response"}
|
||||||
|
ErrBadHandshakeNoMagic = MuxerHandshakeError{"1003 unexpected response"}
|
||||||
|
ErrBadHandshakeWrongMagic = MuxerHandshakeError{"1004 connected to endpoint of wrong type"}
|
||||||
|
ErrBadHandshakeNotSettingsAck = MuxerHandshakeError{"1005 unexpected response"}
|
||||||
|
ErrBadHandshakeUnexpectedSettings = MuxerHandshakeError{"1006 unexpected response"}
|
||||||
|
|
||||||
|
ErrUnexpectedFrameType = MuxerProtocolError{"2001 unexpected frame type", http2.ErrCodeProtocol}
|
||||||
|
ErrUnknownStream = MuxerProtocolError{"2002 unknown stream", http2.ErrCodeProtocol}
|
||||||
|
ErrInvalidStream = MuxerProtocolError{"2003 invalid stream", http2.ErrCodeProtocol}
|
||||||
|
|
||||||
|
ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"}
|
||||||
|
ErrConnectionClosed = MuxerApplicationError{"3001 connection closed"}
|
||||||
|
ErrConnectionDropped = MuxerApplicationError{"3002 connection dropped"}
|
||||||
|
|
||||||
|
ErrClosedStream = MuxerStreamError{"4000 stream closed", http2.ErrCodeStreamClosed}
|
||||||
|
)
|
||||||
|
|
||||||
|
type MuxerHandshakeError struct {
|
||||||
|
cause string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e MuxerHandshakeError) Error() string {
|
||||||
|
return fmt.Sprintf("Handshake error: %s", e.cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MuxerProtocolError struct {
|
||||||
|
cause string
|
||||||
|
h2code http2.ErrCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e MuxerProtocolError) Error() string {
|
||||||
|
return fmt.Sprintf("Protocol error: %s", e.cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MuxerApplicationError struct {
|
||||||
|
cause string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e MuxerApplicationError) Error() string {
|
||||||
|
return fmt.Sprintf("Application error: %s", e.cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MuxerStreamError struct {
|
||||||
|
cause string
|
||||||
|
h2code http2.ErrCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e MuxerStreamError) Error() string {
|
||||||
|
return fmt.Sprintf("Stream error: %s", e.cause)
|
||||||
|
}
|
|
@ -0,0 +1,335 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/http2/hpack"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultFrameSize uint32 = 1 << 14 // Minimum frame size in http2 spec
|
||||||
|
defaultWindowSize uint32 = 65535
|
||||||
|
maxWindowSize uint32 = (1 << 31) - 1 // 2^31-1 = 2147483647, max window size specified in http2 spec
|
||||||
|
defaultTimeout time.Duration = 5 * time.Second
|
||||||
|
defaultRetries uint64 = 5
|
||||||
|
|
||||||
|
SettingMuxerMagic http2.SettingID = 0x42db
|
||||||
|
MuxerMagicOrigin uint32 = 0xa2e43c8b
|
||||||
|
MuxerMagicEdge uint32 = 0x1088ebf9
|
||||||
|
)
|
||||||
|
|
||||||
|
type MuxedStreamHandler interface {
|
||||||
|
ServeStream(*MuxedStream) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type MuxedStreamFunc func(stream *MuxedStream) error
|
||||||
|
|
||||||
|
func (f MuxedStreamFunc) ServeStream(stream *MuxedStream) error {
|
||||||
|
return f(stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MuxerConfig struct {
|
||||||
|
Timeout time.Duration
|
||||||
|
Handler MuxedStreamHandler
|
||||||
|
IsClient bool
|
||||||
|
// Name is used to identify this muxer instance when logging.
|
||||||
|
Name string
|
||||||
|
// The minimum time this connection can be idle before sending a heartbeat.
|
||||||
|
HeartbeatInterval time.Duration
|
||||||
|
// The minimum number of heartbeats to send before terminating the connection.
|
||||||
|
MaxHeartbeats uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type Muxer struct {
|
||||||
|
// f is used to read and write HTTP2 frames on the wire.
|
||||||
|
f *http2.Framer
|
||||||
|
// config is the MuxerConfig given in Handshake.
|
||||||
|
config MuxerConfig
|
||||||
|
// w, r are references to the underlying connection used.
|
||||||
|
w io.WriteCloser
|
||||||
|
r io.ReadCloser
|
||||||
|
// muxReader is the read process.
|
||||||
|
muxReader *MuxReader
|
||||||
|
// muxWriter is the write process.
|
||||||
|
muxWriter *MuxWriter
|
||||||
|
// newStreamChan is used to create new streams on the writer thread.
|
||||||
|
// The writer will assign the next available stream ID.
|
||||||
|
newStreamChan chan MuxedStreamRequest
|
||||||
|
// abortChan is used to abort the writer event loop.
|
||||||
|
abortChan chan struct{}
|
||||||
|
// abortOnce is used to ensure abortChan is closed once only.
|
||||||
|
abortOnce sync.Once
|
||||||
|
// readyList is used to signal writable streams.
|
||||||
|
readyList *ReadyList
|
||||||
|
// streams tracks currently-open streams.
|
||||||
|
streams *activeStreamMap
|
||||||
|
// explicitShutdown records whether the Muxer is closing because Shutdown was called, or due to another
|
||||||
|
// error.
|
||||||
|
explicitShutdown BooleanFuse
|
||||||
|
}
|
||||||
|
|
||||||
|
type Header struct {
|
||||||
|
Name, Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handshake establishes a muxed connection with the peer.
|
||||||
|
// After the handshake completes, it is possible to open and accept streams.
|
||||||
|
func Handshake(
|
||||||
|
w io.WriteCloser,
|
||||||
|
r io.ReadCloser,
|
||||||
|
config MuxerConfig,
|
||||||
|
) (*Muxer, error) {
|
||||||
|
// Initialise connection state fields
|
||||||
|
m := &Muxer{
|
||||||
|
f: http2.NewFramer(w, r), // A framer that writes to w and reads from r
|
||||||
|
config: config,
|
||||||
|
w: w,
|
||||||
|
r: r,
|
||||||
|
newStreamChan: make(chan MuxedStreamRequest),
|
||||||
|
abortChan: make(chan struct{}),
|
||||||
|
readyList: NewReadyList(),
|
||||||
|
streams: newActiveStreamMap(config.IsClient),
|
||||||
|
}
|
||||||
|
m.f.ReadMetaHeaders = hpack.NewDecoder(4096, func(hpack.HeaderField) {})
|
||||||
|
|
||||||
|
if config.Timeout == 0 {
|
||||||
|
config.Timeout = defaultTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialise the settings to identify this connection and confirm the other end is sane.
|
||||||
|
handshakeSetting := http2.Setting{ID: SettingMuxerMagic, Val: MuxerMagicEdge}
|
||||||
|
expectedMagic := MuxerMagicOrigin
|
||||||
|
if config.IsClient {
|
||||||
|
handshakeSetting.Val = MuxerMagicOrigin
|
||||||
|
expectedMagic = MuxerMagicEdge
|
||||||
|
}
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
// Simultaneously send our settings and verify the peer's settings.
|
||||||
|
go func() { errChan <- m.f.WriteSettings(handshakeSetting) }()
|
||||||
|
go func() { errChan <- m.readPeerSettings(expectedMagic) }()
|
||||||
|
err := joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Confirm sanity by ACKing the frame and expecting an ACK for our frame.
|
||||||
|
// Not strictly necessary, but let's pretend to be H2-like.
|
||||||
|
go func() { errChan <- m.f.WriteSettingsAck() }()
|
||||||
|
go func() { errChan <- m.readPeerSettingsAck() }()
|
||||||
|
err = joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set up reader/writer pair ready for serve
|
||||||
|
streamErrors := NewStreamErrorMap()
|
||||||
|
goAwayChan := make(chan http2.ErrCode, 1)
|
||||||
|
pingTimestamp := NewPingTimestamp()
|
||||||
|
connActive := NewSignal()
|
||||||
|
idleDuration := config.HeartbeatInterval
|
||||||
|
// Sanity check to enusre idelDuration is sane
|
||||||
|
if idleDuration == 0 || idleDuration < defaultTimeout {
|
||||||
|
idleDuration = defaultTimeout
|
||||||
|
log.Warn("Minimum idle time has been adjusted to ", defaultTimeout)
|
||||||
|
}
|
||||||
|
maxRetries := config.MaxHeartbeats
|
||||||
|
if maxRetries == 0 {
|
||||||
|
maxRetries = defaultRetries
|
||||||
|
log.Warn("Minimum number of unacked heartbeats to send before closing the connection has been adjusted to ", maxRetries)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.muxReader = &MuxReader{
|
||||||
|
f: m.f,
|
||||||
|
handler: m.config.Handler,
|
||||||
|
streams: m.streams,
|
||||||
|
readyList: m.readyList,
|
||||||
|
streamErrors: streamErrors,
|
||||||
|
goAwayChan: goAwayChan,
|
||||||
|
abortChan: m.abortChan,
|
||||||
|
pingTimestamp: pingTimestamp,
|
||||||
|
connActive: connActive,
|
||||||
|
initialStreamWindow: defaultWindowSize,
|
||||||
|
streamWindowMax: maxWindowSize,
|
||||||
|
r: m.r,
|
||||||
|
}
|
||||||
|
m.muxWriter = &MuxWriter{
|
||||||
|
f: m.f,
|
||||||
|
streams: m.streams,
|
||||||
|
streamErrors: streamErrors,
|
||||||
|
readyStreamChan: m.readyList.ReadyChannel(),
|
||||||
|
newStreamChan: m.newStreamChan,
|
||||||
|
goAwayChan: goAwayChan,
|
||||||
|
abortChan: m.abortChan,
|
||||||
|
pingTimestamp: pingTimestamp,
|
||||||
|
idleTimer: NewIdleTimer(idleDuration, maxRetries),
|
||||||
|
connActiveChan: connActive.WaitChannel(),
|
||||||
|
maxFrameSize: defaultFrameSize,
|
||||||
|
}
|
||||||
|
m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Muxer) readPeerSettings(magic uint32) error {
|
||||||
|
frame, err := m.f.ReadFrame()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
settingsFrame, ok := frame.(*http2.SettingsFrame)
|
||||||
|
if !ok {
|
||||||
|
return ErrBadHandshakeNotSettings
|
||||||
|
}
|
||||||
|
if settingsFrame.Header().Flags != 0 {
|
||||||
|
return ErrBadHandshakeUnexpectedAck
|
||||||
|
}
|
||||||
|
peerMagic, ok := settingsFrame.Value(SettingMuxerMagic)
|
||||||
|
if !ok {
|
||||||
|
return ErrBadHandshakeNoMagic
|
||||||
|
}
|
||||||
|
if magic != peerMagic {
|
||||||
|
return ErrBadHandshakeWrongMagic
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Muxer) readPeerSettingsAck() error {
|
||||||
|
frame, err := m.f.ReadFrame()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
settingsFrame, ok := frame.(*http2.SettingsFrame)
|
||||||
|
if !ok {
|
||||||
|
return ErrBadHandshakeNotSettingsAck
|
||||||
|
}
|
||||||
|
if settingsFrame.Header().Flags != http2.FlagSettingsAck {
|
||||||
|
return ErrBadHandshakeUnexpectedSettings
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time.Duration, timeoutError error) error {
|
||||||
|
for i := 0; i < receiveCount; i++ {
|
||||||
|
select {
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return timeoutError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Muxer) Serve() error {
|
||||||
|
logger := log.WithField("name", m.config.Name)
|
||||||
|
errChan := make(chan error)
|
||||||
|
go func() {
|
||||||
|
errChan <- m.muxReader.run(logger)
|
||||||
|
m.explicitShutdown.Fuse(false)
|
||||||
|
m.r.Close()
|
||||||
|
m.abort()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
errChan <- m.muxWriter.run(logger)
|
||||||
|
m.explicitShutdown.Fuse(false)
|
||||||
|
m.w.Close()
|
||||||
|
m.abort()
|
||||||
|
}()
|
||||||
|
err := <-errChan
|
||||||
|
go func() {
|
||||||
|
// discard error as other handler closes
|
||||||
|
<-errChan
|
||||||
|
close(errChan)
|
||||||
|
}()
|
||||||
|
if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Muxer) Shutdown() {
|
||||||
|
m.explicitShutdown.Fuse(true)
|
||||||
|
m.muxReader.Shutdown()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
|
||||||
|
// The set of expected errors change depending on whether we initiated shutdown or not.
|
||||||
|
func isUnexpectedTunnelError(err error, expectedShutdown bool) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !expectedShutdown {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return !isConnectionClosedError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isConnectionClosedError(err error) bool {
|
||||||
|
if err == io.EOF {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if err == io.ErrClosedPipe {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if err.Error() == "tls: use of closed connection" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenStream opens a new data stream with the given headers.
|
||||||
|
// Called by proxy server and tunnel
|
||||||
|
func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, error) {
|
||||||
|
stream := &MuxedStream{
|
||||||
|
responseHeadersReceived: make(chan struct{}),
|
||||||
|
readBuffer: NewSharedBuffer(),
|
||||||
|
receiveWindow: defaultWindowSize,
|
||||||
|
receiveWindowCurrentMax: defaultWindowSize, // Initial window size limit. exponentially increase it when receiveWindow is exhausted
|
||||||
|
receiveWindowMax: maxWindowSize,
|
||||||
|
sendWindow: defaultWindowSize,
|
||||||
|
readyList: m.readyList,
|
||||||
|
writeHeaders: headers,
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
// Will be received by mux writer
|
||||||
|
case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: body}:
|
||||||
|
case <-m.abortChan:
|
||||||
|
return nil, ErrConnectionClosed
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-stream.responseHeadersReceived:
|
||||||
|
return stream, nil
|
||||||
|
case <-m.abortChan:
|
||||||
|
return nil, ErrConnectionClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the estimated round-trip time.
|
||||||
|
func (m *Muxer) RTT() RTTMeasurement {
|
||||||
|
return m.muxReader.RTT()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return min/max/average of send/receive window for all streams on this connection
|
||||||
|
func (m *Muxer) FlowControlMetrics() *FlowControlMetrics {
|
||||||
|
return m.muxReader.FlowControlMetrics()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Muxer) abort() {
|
||||||
|
m.abortOnce.Do(func() {
|
||||||
|
close(m.abortChan)
|
||||||
|
m.streams.Abort()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return how many retries/ticks since the connection was last marked active
|
||||||
|
func (m *Muxer) TimerRetries() uint64 {
|
||||||
|
return m.muxWriter.idleTimer.RetryCount()
|
||||||
|
}
|
|
@ -0,0 +1,646 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
if os.Getenv("VERBOSE") == "1" {
|
||||||
|
log.SetLevel(log.DebugLevel)
|
||||||
|
}
|
||||||
|
os.Exit(m.Run())
|
||||||
|
}
|
||||||
|
|
||||||
|
type DefaultMuxerPair struct {
|
||||||
|
OriginMuxConfig MuxerConfig
|
||||||
|
OriginMux *Muxer
|
||||||
|
OriginWriter *io.PipeWriter
|
||||||
|
OriginReader *io.PipeReader
|
||||||
|
EdgeMuxConfig MuxerConfig
|
||||||
|
EdgeMux *Muxer
|
||||||
|
EdgeWriter *io.PipeWriter
|
||||||
|
EdgeReader *io.PipeReader
|
||||||
|
doneC chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDefaultMuxerPair() *DefaultMuxerPair {
|
||||||
|
originReader, edgeWriter := io.Pipe()
|
||||||
|
edgeReader, originWriter := io.Pipe()
|
||||||
|
return &DefaultMuxerPair{
|
||||||
|
OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"},
|
||||||
|
OriginWriter: originWriter,
|
||||||
|
OriginReader: originReader,
|
||||||
|
EdgeMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"},
|
||||||
|
EdgeWriter: edgeWriter,
|
||||||
|
EdgeReader: edgeReader,
|
||||||
|
doneC: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DefaultMuxerPair) Handshake(t *testing.T) {
|
||||||
|
edgeErrC := make(chan error)
|
||||||
|
originErrC := make(chan error)
|
||||||
|
go func() {
|
||||||
|
var err error
|
||||||
|
p.EdgeMux, err = Handshake(p.EdgeWriter, p.EdgeReader, p.EdgeMuxConfig)
|
||||||
|
edgeErrC <- err
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
var err error
|
||||||
|
p.OriginMux, err = Handshake(p.OriginWriter, p.OriginReader, p.OriginMuxConfig)
|
||||||
|
originErrC <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-edgeErrC:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("edge handshake failure: %s", err)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second * 5):
|
||||||
|
t.Fatalf("edge handshake timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-originErrC:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("origin handshake failure: %s", err)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second * 5):
|
||||||
|
t.Fatalf("origin handshake timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) {
|
||||||
|
p.Handshake(t)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
err := p.EdgeMux.Serve()
|
||||||
|
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
|
||||||
|
t.Errorf("error in edge muxer Serve(): %s", err)
|
||||||
|
}
|
||||||
|
p.OriginMux.Shutdown()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
err := p.OriginMux.Serve()
|
||||||
|
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
|
||||||
|
t.Errorf("error in origin muxer Serve(): %s", err)
|
||||||
|
}
|
||||||
|
p.EdgeMux.Shutdown()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
// notify when both muxes have stopped serving
|
||||||
|
wg.Wait()
|
||||||
|
close(p.doneC)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *DefaultMuxerPair) Wait(t *testing.T) {
|
||||||
|
select {
|
||||||
|
case <-p.doneC:
|
||||||
|
return
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for shutdown")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandshake(t *testing.T) {
|
||||||
|
muxPair := NewDefaultMuxerPair()
|
||||||
|
muxPair.Handshake(t)
|
||||||
|
AssertIfPipeReadable(t, muxPair.OriginReader)
|
||||||
|
AssertIfPipeReadable(t, muxPair.EdgeReader)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSingleStream(t *testing.T) {
|
||||||
|
closeC := make(chan struct{})
|
||||||
|
muxPair := NewDefaultMuxerPair()
|
||||||
|
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||||
|
defer close(closeC)
|
||||||
|
if len(stream.Headers) != 1 {
|
||||||
|
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Name != "test-header" {
|
||||||
|
t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Value != "headerValue" {
|
||||||
|
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
||||||
|
}
|
||||||
|
stream.WriteHeaders([]Header{
|
||||||
|
Header{Name: "response-header", Value: "responseValue"},
|
||||||
|
})
|
||||||
|
buf := []byte("Hello world")
|
||||||
|
stream.Write(buf)
|
||||||
|
// after this receive, the edge closed the stream
|
||||||
|
<-closeC
|
||||||
|
n, err := stream.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
t.Fatalf("read %d bytes after EOF", n)
|
||||||
|
}
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatalf("expected EOF, got %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
muxPair.HandshakeAndServe(t)
|
||||||
|
|
||||||
|
stream, err := muxPair.EdgeMux.OpenStream(
|
||||||
|
[]Header{Header{Name: "test-header", Value: "headerValue"}},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error in OpenStream: %s", err)
|
||||||
|
}
|
||||||
|
if len(stream.Headers) != 1 {
|
||||||
|
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Name != "response-header" {
|
||||||
|
t.Fatalf("expected header name %s, got %s", "response-header", stream.Headers[0].Name)
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Value != "responseValue" {
|
||||||
|
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
||||||
|
}
|
||||||
|
responseBody := make([]byte, 11)
|
||||||
|
n, err := stream.Read(responseBody)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error from (*MuxedStream).Read: %s", err)
|
||||||
|
}
|
||||||
|
if n != len(responseBody) {
|
||||||
|
t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
|
||||||
|
}
|
||||||
|
if string(responseBody) != "Hello world" {
|
||||||
|
t.Fatalf("expected response body %s, got %s", "Hello world", responseBody)
|
||||||
|
}
|
||||||
|
stream.Close()
|
||||||
|
closeC <- struct{}{}
|
||||||
|
n, err = stream.Write([]byte("aaaaa"))
|
||||||
|
if n > 0 {
|
||||||
|
t.Fatalf("wrote %d bytes after EOF", n)
|
||||||
|
}
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatalf("expected EOF, got %s", err)
|
||||||
|
}
|
||||||
|
<-closeC
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSingleStreamLargeResponseBody(t *testing.T) {
|
||||||
|
muxPair := NewDefaultMuxerPair()
|
||||||
|
bodySize := 1 << 24
|
||||||
|
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||||
|
if len(stream.Headers) != 1 {
|
||||||
|
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Name != "test-header" {
|
||||||
|
t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Value != "headerValue" {
|
||||||
|
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
||||||
|
}
|
||||||
|
stream.WriteHeaders([]Header{
|
||||||
|
Header{Name: "response-header", Value: "responseValue"},
|
||||||
|
})
|
||||||
|
payload := make([]byte, bodySize)
|
||||||
|
for i := range payload {
|
||||||
|
payload[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
n, err := stream.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("origin write error: %s", err)
|
||||||
|
}
|
||||||
|
if n != len(payload) {
|
||||||
|
t.Fatalf("origin short write: %d/%d bytes", n, len(payload))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
muxPair.HandshakeAndServe(t)
|
||||||
|
|
||||||
|
stream, err := muxPair.EdgeMux.OpenStream(
|
||||||
|
[]Header{Header{Name: "test-header", Value: "headerValue"}},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error in OpenStream: %s", err)
|
||||||
|
}
|
||||||
|
if len(stream.Headers) != 1 {
|
||||||
|
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Name != "response-header" {
|
||||||
|
t.Fatalf("expected header name %s, got %s", "response-header", stream.Headers[0].Name)
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Value != "responseValue" {
|
||||||
|
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
||||||
|
}
|
||||||
|
responseBody := make([]byte, bodySize)
|
||||||
|
n, err := stream.Read(responseBody)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error from (*MuxedStream).Read: %s", err)
|
||||||
|
}
|
||||||
|
if n != len(responseBody) {
|
||||||
|
t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleStreams(t *testing.T) {
|
||||||
|
muxPair := NewDefaultMuxerPair()
|
||||||
|
maxStreams := 64
|
||||||
|
errorsC := make(chan error, maxStreams)
|
||||||
|
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||||
|
if len(stream.Headers) != 1 {
|
||||||
|
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Name != "client-token" {
|
||||||
|
t.Fatalf("expected header name %s, got %s", "client-token", stream.Headers[0].Name)
|
||||||
|
}
|
||||||
|
log.Debugf("Got request for stream %s", stream.Headers[0].Value)
|
||||||
|
stream.WriteHeaders([]Header{
|
||||||
|
Header{Name: "response-token", Value: stream.Headers[0].Value},
|
||||||
|
})
|
||||||
|
log.Debugf("Wrote headers for stream %s", stream.Headers[0].Value)
|
||||||
|
stream.Write([]byte("OK"))
|
||||||
|
log.Debugf("Wrote body for stream %s", stream.Headers[0].Value)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
muxPair.HandshakeAndServe(t)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(maxStreams)
|
||||||
|
for i := 0; i < maxStreams; i++ {
|
||||||
|
go func(tokenId int) {
|
||||||
|
defer wg.Done()
|
||||||
|
tokenString := fmt.Sprintf("%d", tokenId)
|
||||||
|
stream, err := muxPair.EdgeMux.OpenStream(
|
||||||
|
[]Header{Header{Name: "client-token", Value: tokenString}},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
log.Debugf("Got headers for stream %d", tokenId)
|
||||||
|
if err != nil {
|
||||||
|
errorsC <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(stream.Headers) != 1 {
|
||||||
|
errorsC <- fmt.Errorf("stream %d has error: expected %d headers, got %d", stream.streamID, 1, len(stream.Headers))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Name != "response-token" {
|
||||||
|
errorsC <- fmt.Errorf("stream %d has error: expected header name %s, got %s", stream.streamID, "response-token", stream.Headers[0].Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Value != tokenString {
|
||||||
|
errorsC <- fmt.Errorf("stream %d has error: expected header value %s, got %s", stream.streamID, tokenString, stream.Headers[0].Value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
responseBody := make([]byte, 2)
|
||||||
|
n, err := stream.Read(responseBody)
|
||||||
|
if err != nil {
|
||||||
|
errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n != len(responseBody) {
|
||||||
|
errorsC <- fmt.Errorf("stream %d has error: expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if string(responseBody) != "OK" {
|
||||||
|
errorsC <- fmt.Errorf("stream %d has error: expected response body %s, got %s", stream.streamID, "OK", responseBody)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
close(errorsC)
|
||||||
|
testFail := false
|
||||||
|
for err := range errorsC {
|
||||||
|
testFail = true
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
if testFail {
|
||||||
|
t.Fatalf("TestMultipleStreamsFlowControl failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleStreamsFlowControl(t *testing.T) {
|
||||||
|
maxStreams := 32
|
||||||
|
errorsC := make(chan error, maxStreams)
|
||||||
|
responseSizes := make([]int32, maxStreams)
|
||||||
|
for i := 0; i < maxStreams; i++ {
|
||||||
|
responseSizes[i] = rand.Int31n(int32(defaultWindowSize << 4))
|
||||||
|
}
|
||||||
|
muxPair := NewDefaultMuxerPair()
|
||||||
|
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||||
|
if len(stream.Headers) != 1 {
|
||||||
|
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Name != "test-header" {
|
||||||
|
t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Value != "headerValue" {
|
||||||
|
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
||||||
|
}
|
||||||
|
stream.WriteHeaders([]Header{
|
||||||
|
Header{Name: "response-header", Value: "responseValue"},
|
||||||
|
})
|
||||||
|
payload := make([]byte, responseSizes[(stream.streamID-2)/2])
|
||||||
|
for i := range payload {
|
||||||
|
payload[i] = byte(i % 256)
|
||||||
|
}
|
||||||
|
n, err := stream.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("origin write error: %s", err)
|
||||||
|
}
|
||||||
|
if n != len(payload) {
|
||||||
|
t.Fatalf("origin short write: %d/%d bytes", n, len(payload))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
muxPair.HandshakeAndServe(t)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(maxStreams)
|
||||||
|
for i := 0; i < maxStreams; i++ {
|
||||||
|
go func(tokenId int) {
|
||||||
|
defer wg.Done()
|
||||||
|
stream, err := muxPair.EdgeMux.OpenStream(
|
||||||
|
[]Header{Header{Name: "test-header", Value: "headerValue"}},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
errorsC <- fmt.Errorf("stream %d error in OpenStream: %s", stream.streamID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(stream.Headers) != 1 {
|
||||||
|
errorsC <- fmt.Errorf("stream %d expected %d headers, got %d", stream.streamID, 1, len(stream.Headers))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Name != "response-header" {
|
||||||
|
errorsC <- fmt.Errorf("stream %d expected header name %s, got %s", stream.streamID, "response-header", stream.Headers[0].Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Value != "responseValue" {
|
||||||
|
errorsC <- fmt.Errorf("stream %d expected header value %s, got %s", stream.streamID, "responseValue", stream.Headers[0].Value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
responseBody := make([]byte, responseSizes[(stream.streamID-2)/2])
|
||||||
|
n, err := stream.Read(responseBody)
|
||||||
|
if err != nil {
|
||||||
|
errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n != len(responseBody) {
|
||||||
|
errorsC <- fmt.Errorf("stream %d expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
close(errorsC)
|
||||||
|
testFail := false
|
||||||
|
for err := range errorsC {
|
||||||
|
testFail = true
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
if testFail {
|
||||||
|
t.Fatalf("TestMultipleStreamsFlowControl failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGracefulShutdown(t *testing.T) {
|
||||||
|
sendC := make(chan struct{})
|
||||||
|
responseBuf := bytes.Repeat([]byte("Hello world"), 65536)
|
||||||
|
muxPair := NewDefaultMuxerPair()
|
||||||
|
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||||
|
stream.WriteHeaders([]Header{
|
||||||
|
Header{Name: "response-header", Value: "responseValue"},
|
||||||
|
})
|
||||||
|
<-sendC
|
||||||
|
log.Debugf("Writing %d bytes", len(responseBuf))
|
||||||
|
stream.Write(responseBuf)
|
||||||
|
stream.CloseWrite()
|
||||||
|
log.Debugf("Wrote %d bytes", len(responseBuf))
|
||||||
|
// Reading from the stream will block until the edge closes its end of the stream.
|
||||||
|
// Otherwise, we'll close the whole connection before receiving the 'stream closed'
|
||||||
|
// message from the edge.
|
||||||
|
// Graceful shutdown works if you omit this, it just gives spurious errors for now -
|
||||||
|
// TODO ignore errors when writing 'stream closed' and we're shutting down.
|
||||||
|
stream.Read([]byte{0})
|
||||||
|
log.Debugf("Handler ends")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
muxPair.HandshakeAndServe(t)
|
||||||
|
|
||||||
|
stream, err := muxPair.EdgeMux.OpenStream(
|
||||||
|
[]Header{Header{Name: "test-header", Value: "headerValue"}},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
// Start graceful shutdown of the edge mux - this should also close the origin mux when done
|
||||||
|
muxPair.EdgeMux.Shutdown()
|
||||||
|
close(sendC)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error in OpenStream: %s", err)
|
||||||
|
}
|
||||||
|
responseBody := make([]byte, len(responseBuf))
|
||||||
|
log.Debugf("Waiting for %d bytes", len(responseBuf))
|
||||||
|
n, err := stream.Read(responseBody)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err)
|
||||||
|
}
|
||||||
|
if n != len(responseBody) {
|
||||||
|
t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(responseBuf, responseBody) {
|
||||||
|
t.Fatalf("response body mismatch")
|
||||||
|
}
|
||||||
|
stream.Close()
|
||||||
|
muxPair.Wait(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnexpectedShutdown(t *testing.T) {
|
||||||
|
sendC := make(chan struct{})
|
||||||
|
handlerFinishC := make(chan struct{})
|
||||||
|
responseBuf := bytes.Repeat([]byte("Hello world"), 65536)
|
||||||
|
muxPair := NewDefaultMuxerPair()
|
||||||
|
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||||
|
defer close(handlerFinishC)
|
||||||
|
stream.WriteHeaders([]Header{
|
||||||
|
Header{Name: "response-header", Value: "responseValue"},
|
||||||
|
})
|
||||||
|
<-sendC
|
||||||
|
n, err := stream.Read([]byte{0})
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Fatalf("expected empty read, got %d bytes", n)
|
||||||
|
}
|
||||||
|
// Write comes after read, because write buffers data before it is flushed. It wouldn't know about EOF
|
||||||
|
// until some time later. Calling read first forces it to know about EOF now.
|
||||||
|
_, err = stream.Write(responseBuf)
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatalf("unexpected error from (*MuxedStream).Write: %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
muxPair.HandshakeAndServe(t)
|
||||||
|
|
||||||
|
stream, err := muxPair.EdgeMux.OpenStream(
|
||||||
|
[]Header{Header{Name: "test-header", Value: "headerValue"}},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
// Close the underlying connection before telling the origin to write.
|
||||||
|
muxPair.EdgeReader.Close()
|
||||||
|
close(sendC)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error in OpenStream: %s", err)
|
||||||
|
}
|
||||||
|
responseBody := make([]byte, len(responseBuf))
|
||||||
|
n, err := stream.Read(responseBody)
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Fatalf("expected response body to have %d bytes, got %d", 0, n)
|
||||||
|
}
|
||||||
|
// The write ordering requirement explained in the origin handler applies here too.
|
||||||
|
_, err = stream.Write(responseBuf)
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatalf("unexpected error from (*MuxedStream).Write: %s", err)
|
||||||
|
}
|
||||||
|
<-handlerFinishC
|
||||||
|
}
|
||||||
|
|
||||||
|
func EchoHandler(stream *MuxedStream) error {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
fmt.Fprintf(&buf, "Hello, world!\n\n# REQUEST HEADERS:\n\n")
|
||||||
|
for _, header := range stream.Headers {
|
||||||
|
fmt.Fprintf(&buf, "[%s] = %s\n", header.Name, header.Value)
|
||||||
|
}
|
||||||
|
stream.WriteHeaders([]Header{
|
||||||
|
{Name: ":status", Value: "200"},
|
||||||
|
{Name: "server", Value: "Echo-server/1.0"},
|
||||||
|
{Name: "date", Value: time.Now().Format(time.RFC850)},
|
||||||
|
{Name: "content-type", Value: "text/html; charset=utf-8"},
|
||||||
|
{Name: "content-length", Value: strconv.Itoa(buf.Len())},
|
||||||
|
})
|
||||||
|
buf.WriteTo(stream)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAfterDisconnect(t *testing.T) {
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
muxPair := NewDefaultMuxerPair()
|
||||||
|
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(EchoHandler)
|
||||||
|
muxPair.HandshakeAndServe(t)
|
||||||
|
|
||||||
|
switch i {
|
||||||
|
case 0:
|
||||||
|
// Close both directions of the connection to cause EOF on both peers.
|
||||||
|
muxPair.OriginReader.Close()
|
||||||
|
muxPair.OriginWriter.Close()
|
||||||
|
case 1:
|
||||||
|
// Close origin reader (edge writer) to cause EOF on origin only.
|
||||||
|
muxPair.OriginReader.Close()
|
||||||
|
case 2:
|
||||||
|
// Close origin writer (edge reader) to cause EOF on edge only.
|
||||||
|
muxPair.OriginWriter.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := muxPair.EdgeMux.OpenStream(
|
||||||
|
[]Header{Header{Name: "test-header", Value: "headerValue"}},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != ErrConnectionClosed {
|
||||||
|
t.Fatalf("unexpected error in OpenStream: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHPACK(t *testing.T) {
|
||||||
|
muxPair := NewDefaultMuxerPair()
|
||||||
|
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(EchoHandler)
|
||||||
|
muxPair.HandshakeAndServe(t)
|
||||||
|
|
||||||
|
stream, err := muxPair.EdgeMux.OpenStream(
|
||||||
|
[]Header{
|
||||||
|
{Name: ":method", Value: "RPC"},
|
||||||
|
{Name: ":scheme", Value: "capnp"},
|
||||||
|
{Name: ":path", Value: "*"},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error in OpenStream: %s", err)
|
||||||
|
}
|
||||||
|
stream.Close()
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
stream, err := muxPair.EdgeMux.OpenStream(
|
||||||
|
[]Header{
|
||||||
|
{Name: ":method", Value: "GET"},
|
||||||
|
{Name: ":scheme", Value: "https"},
|
||||||
|
{Name: ":authority", Value: "tunnel.otterlyadorable.co.uk"},
|
||||||
|
{Name: ":path", Value: "/get"},
|
||||||
|
{Name: "accept-encoding", Value: "gzip"},
|
||||||
|
{Name: "cf-ray", Value: "378948953f044408-SFO-DOG"},
|
||||||
|
{Name: "cf-visitor", Value: "{\"scheme\":\"https\"}"},
|
||||||
|
{Name: "cf-connecting-ip", Value: "2400:cb00:0025:010d:0000:0000:0000:0001"},
|
||||||
|
{Name: "x-forwarded-for", Value: "2400:cb00:0025:010d:0000:0000:0000:0001"},
|
||||||
|
{Name: "x-forwarded-proto", Value: "https"},
|
||||||
|
{Name: "accept-language", Value: "en-gb"},
|
||||||
|
{Name: "referer", Value: "https://tunnel.otterlyadorable.co.uk/"},
|
||||||
|
{Name: "cookie", Value: "__cfduid=d4555095065f92daedc059490771967d81493032162"},
|
||||||
|
{Name: "connection", Value: "Keep-Alive"},
|
||||||
|
{Name: "cf-ipcountry", Value: "US"},
|
||||||
|
{Name: "accept", Value: "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
|
||||||
|
{Name: "user-agent", Value: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/603.2.4 (KHTML, like Gecko) Version/10.1.1 Safari/603.2.4"},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error in OpenStream: %s", err)
|
||||||
|
}
|
||||||
|
if len(stream.Headers) == 0 {
|
||||||
|
t.Fatal("response has no headers")
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Name != ":status" {
|
||||||
|
t.Fatalf("first header should be status, found %s instead", stream.Headers[0].Name)
|
||||||
|
}
|
||||||
|
if stream.Headers[0].Value != "200" {
|
||||||
|
t.Fatalf("expected status 200, got %s", stream.Headers[0].Value)
|
||||||
|
}
|
||||||
|
ioutil.ReadAll(stream)
|
||||||
|
stream.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) {
|
||||||
|
errC := make(chan error)
|
||||||
|
go func() {
|
||||||
|
b := []byte{0}
|
||||||
|
n, err := pipe.Read(b)
|
||||||
|
if n > 0 {
|
||||||
|
t.Fatalf("read pipe was not empty")
|
||||||
|
}
|
||||||
|
errC <- err
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case err := <-errC:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read error: %s", err)
|
||||||
|
}
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
// nothing to read
|
||||||
|
pipe.Close()
|
||||||
|
<-errC
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,81 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IdleTimer is a type of Timer designed for managing heartbeats on an idle connection.
|
||||||
|
// The timer ticks on an interval with added jitter to avoid accidental synchronisation
|
||||||
|
// between two endpoints. It tracks the number of retries/ticks since the connection was
|
||||||
|
// last marked active.
|
||||||
|
//
|
||||||
|
// The methods of IdleTimer must not be called while a goroutine is reading from C.
|
||||||
|
type IdleTimer struct {
|
||||||
|
// The channel on which ticks are delivered.
|
||||||
|
C <-chan time.Time
|
||||||
|
|
||||||
|
// A timer used to measure idle connection time. Reset after sending data.
|
||||||
|
idleTimer *time.Timer
|
||||||
|
// The maximum length of time a connection is idle before sending a ping.
|
||||||
|
idleDuration time.Duration
|
||||||
|
// A pseudorandom source used to add jitter to the idle duration.
|
||||||
|
randomSource *rand.Rand
|
||||||
|
// The maximum number of retries allowed.
|
||||||
|
maxRetries uint64
|
||||||
|
// The number of retries since the connection was last marked active.
|
||||||
|
retries uint64
|
||||||
|
// A lock to prevent race condition while checking retries
|
||||||
|
stateLock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIdleTimer(idleDuration time.Duration, maxRetries uint64) *IdleTimer {
|
||||||
|
t := &IdleTimer{
|
||||||
|
idleTimer: time.NewTimer(idleDuration),
|
||||||
|
idleDuration: idleDuration,
|
||||||
|
randomSource: rand.New(rand.NewSource(time.Now().Unix())),
|
||||||
|
maxRetries: maxRetries,
|
||||||
|
}
|
||||||
|
t.C = t.idleTimer.C
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry should be called when retrying the idle timeout. If the maximum number of retries
|
||||||
|
// has been met, returns false.
|
||||||
|
// After calling this function and sending a heartbeat, call ResetTimer. Since sending the
|
||||||
|
// heartbeat could be a blocking operation, we resetting the timer after the write completes
|
||||||
|
// to avoid it expiring during the write.
|
||||||
|
func (t *IdleTimer) Retry() bool {
|
||||||
|
t.stateLock.Lock()
|
||||||
|
defer t.stateLock.Unlock()
|
||||||
|
if t.retries >= t.maxRetries {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
t.retries++
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *IdleTimer) RetryCount() uint64 {
|
||||||
|
t.stateLock.RLock()
|
||||||
|
defer t.stateLock.RUnlock()
|
||||||
|
return t.retries
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkActive resets the idle connection timer and suppresses any outstanding idle events.
|
||||||
|
func (t *IdleTimer) MarkActive() {
|
||||||
|
if !t.idleTimer.Stop() {
|
||||||
|
// eat the timer event to prevent spurious pings
|
||||||
|
<-t.idleTimer.C
|
||||||
|
}
|
||||||
|
t.stateLock.Lock()
|
||||||
|
t.retries = 0
|
||||||
|
t.stateLock.Unlock()
|
||||||
|
t.ResetTimer()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset the idle timer according to the configured duration, with some added jitter.
|
||||||
|
func (t *IdleTimer) ResetTimer() {
|
||||||
|
jitter := time.Duration(t.randomSource.Int63n(int64(t.idleDuration)))
|
||||||
|
t.idleTimer.Reset(t.idleDuration + jitter)
|
||||||
|
}
|
|
@ -0,0 +1,31 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRetry(t *testing.T) {
|
||||||
|
timer := NewIdleTimer(time.Second, 2)
|
||||||
|
assert.Equal(t, uint64(0), timer.RetryCount())
|
||||||
|
ok := timer.Retry()
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, uint64(1), timer.RetryCount())
|
||||||
|
ok = timer.Retry()
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, uint64(2), timer.RetryCount())
|
||||||
|
ok = timer.Retry()
|
||||||
|
assert.False(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMarkActive(t *testing.T) {
|
||||||
|
timer := NewIdleTimer(time.Second, 2)
|
||||||
|
assert.Equal(t, uint64(0), timer.RetryCount())
|
||||||
|
ok := timer.Retry()
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, uint64(1), timer.RetryCount())
|
||||||
|
timer.MarkActive()
|
||||||
|
assert.Equal(t, uint64(0), timer.RetryCount())
|
||||||
|
}
|
|
@ -0,0 +1,250 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MuxedStream struct {
|
||||||
|
Headers []Header
|
||||||
|
|
||||||
|
streamID uint32
|
||||||
|
|
||||||
|
responseHeadersReceived chan struct{}
|
||||||
|
|
||||||
|
readBuffer *SharedBuffer
|
||||||
|
receiveWindow uint32
|
||||||
|
// current window size limit. Exponentially increase it when it's exhausted
|
||||||
|
receiveWindowCurrentMax uint32
|
||||||
|
// limit set in http2 spec. 2^31-1
|
||||||
|
receiveWindowMax uint32
|
||||||
|
// nonzero if a WINDOW_UPDATE frame for a stream needs to be sent
|
||||||
|
windowUpdate uint32
|
||||||
|
|
||||||
|
writeLock sync.Mutex
|
||||||
|
// The zero value for Buffer is an empty buffer ready to use.
|
||||||
|
writeBuffer bytes.Buffer
|
||||||
|
|
||||||
|
sendWindow uint32
|
||||||
|
|
||||||
|
readyList *ReadyList
|
||||||
|
headersSent bool
|
||||||
|
writeHeaders []Header
|
||||||
|
// true if the write end of this stream has been closed
|
||||||
|
writeEOF bool
|
||||||
|
// true if we have sent EOF to the peer
|
||||||
|
sentEOF bool
|
||||||
|
// true if the peer sent us an EOF
|
||||||
|
receivedEOF bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type flowControlWindow struct {
|
||||||
|
receiveWindow, sendWindow uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MuxedStream) Read(p []byte) (n int, err error) {
|
||||||
|
return s.readBuffer.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MuxedStream) Write(p []byte) (n int, err error) {
|
||||||
|
s.writeLock.Lock()
|
||||||
|
defer s.writeLock.Unlock()
|
||||||
|
if s.writeEOF {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n, err = s.writeBuffer.Write(p)
|
||||||
|
if n != len(p) || err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
s.writeNotify()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MuxedStream) Close() error {
|
||||||
|
// TUN-115: Close the write buffer before the read buffer.
|
||||||
|
// In the case of shutdown, read will not get new data, but the write buffer can still receive
|
||||||
|
// new data. Closing read before write allows application to race between a failed read and a
|
||||||
|
// successful write, even though this close should appear to be atomic.
|
||||||
|
// This can't happen the other way because reads may succeed after a failed write; if we read
|
||||||
|
// past EOF the application will block until we close the buffer.
|
||||||
|
err := s.CloseWrite()
|
||||||
|
if err != nil {
|
||||||
|
if s.CloseRead() == nil {
|
||||||
|
// don't bother the caller with errors if at least one close succeeded
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.CloseRead()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MuxedStream) CloseRead() error {
|
||||||
|
return s.readBuffer.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MuxedStream) CloseWrite() error {
|
||||||
|
s.writeLock.Lock()
|
||||||
|
defer s.writeLock.Unlock()
|
||||||
|
if s.writeEOF {
|
||||||
|
return io.EOF
|
||||||
|
}
|
||||||
|
s.writeEOF = true
|
||||||
|
s.writeNotify()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MuxedStream) WriteHeaders(headers []Header) error {
|
||||||
|
s.writeLock.Lock()
|
||||||
|
defer s.writeLock.Unlock()
|
||||||
|
if s.writeHeaders != nil {
|
||||||
|
return ErrStreamHeadersSent
|
||||||
|
}
|
||||||
|
s.writeHeaders = headers
|
||||||
|
s.writeNotify()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MuxedStream) FlowControlWindow() *flowControlWindow {
|
||||||
|
s.writeLock.Lock()
|
||||||
|
defer s.writeLock.Unlock()
|
||||||
|
return &flowControlWindow{
|
||||||
|
receiveWindow: s.receiveWindow,
|
||||||
|
sendWindow: s.sendWindow,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeNotify must happen while holding writeLock.
|
||||||
|
func (s *MuxedStream) writeNotify() {
|
||||||
|
s.readyList.Signal(s.streamID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call by muxreader when it gets a WindowUpdateFrame. This is an update of the peer's
|
||||||
|
// receive window (how much data we can send).
|
||||||
|
func (s *MuxedStream) replenishSendWindow(bytes uint32) {
|
||||||
|
s.writeLock.Lock()
|
||||||
|
s.sendWindow += bytes
|
||||||
|
s.writeNotify()
|
||||||
|
s.writeLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call by muxreader when it receives a data frame
|
||||||
|
func (s *MuxedStream) consumeReceiveWindow(bytes uint32) bool {
|
||||||
|
s.writeLock.Lock()
|
||||||
|
defer s.writeLock.Unlock()
|
||||||
|
// received data size is greater than receive window/buffer
|
||||||
|
if s.receiveWindow < bytes {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s.receiveWindow -= bytes
|
||||||
|
if s.receiveWindow < s.receiveWindowCurrentMax/2 {
|
||||||
|
// exhausting client send window (how much data client can send)
|
||||||
|
if s.receiveWindowCurrentMax < s.receiveWindowMax {
|
||||||
|
s.receiveWindowCurrentMax <<= 1
|
||||||
|
}
|
||||||
|
s.windowUpdate += s.receiveWindowCurrentMax - s.receiveWindow
|
||||||
|
s.writeNotify()
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// receiveEOF should be called when the peer indicates no more data will be sent.
|
||||||
|
// Returns true if the socket is now closed (i.e. the write side is already closed).
|
||||||
|
func (s *MuxedStream) receiveEOF() (closed bool) {
|
||||||
|
s.writeLock.Lock()
|
||||||
|
defer s.writeLock.Unlock()
|
||||||
|
s.receivedEOF = true
|
||||||
|
s.CloseRead()
|
||||||
|
return s.writeEOF && s.writeBuffer.Len() == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MuxedStream) gotReceiveEOF() bool {
|
||||||
|
s.writeLock.Lock()
|
||||||
|
defer s.writeLock.Unlock()
|
||||||
|
return s.receivedEOF
|
||||||
|
}
|
||||||
|
|
||||||
|
// MuxedStreamReader implements io.ReadCloser for the read end of the stream.
|
||||||
|
// This is useful for passing to functions that close the object after it is done reading,
|
||||||
|
// but you still want to be able to write data afterwards (e.g. http.Client).
|
||||||
|
type MuxedStreamReader struct {
|
||||||
|
*MuxedStream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s MuxedStreamReader) Read(p []byte) (n int, err error) {
|
||||||
|
return s.MuxedStream.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s MuxedStreamReader) Close() error {
|
||||||
|
return s.MuxedStream.CloseRead()
|
||||||
|
}
|
||||||
|
|
||||||
|
// streamChunk represents a chunk of data to be written.
|
||||||
|
type streamChunk struct {
|
||||||
|
streamID uint32
|
||||||
|
// true if a HEADERS frame should be sent
|
||||||
|
sendHeaders bool
|
||||||
|
headers []Header
|
||||||
|
// nonzero if a WINDOW_UPDATE frame should be sent
|
||||||
|
windowUpdate uint32
|
||||||
|
// true if data frames should be sent
|
||||||
|
sendData bool
|
||||||
|
eof bool
|
||||||
|
buffer bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
// getChunk atomically extracts a chunk of data to be written by MuxWriter.
|
||||||
|
// The data returned will not exceed the send window for this stream.
|
||||||
|
func (s *MuxedStream) getChunk() *streamChunk {
|
||||||
|
s.writeLock.Lock()
|
||||||
|
defer s.writeLock.Unlock()
|
||||||
|
|
||||||
|
chunk := &streamChunk{
|
||||||
|
streamID: s.streamID,
|
||||||
|
sendHeaders: !s.headersSent,
|
||||||
|
headers: s.writeHeaders,
|
||||||
|
windowUpdate: s.windowUpdate,
|
||||||
|
sendData: !s.sentEOF,
|
||||||
|
eof: s.writeEOF && uint32(s.writeBuffer.Len()) <= s.sendWindow,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copies at most s.sendWindow bytes
|
||||||
|
//log.Infof("writeBuffer len %d stream %d", s.writeBuffer.Len(), s.streamID)
|
||||||
|
writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow))
|
||||||
|
//log.Infof("writeLen %d stream %d", writeLen, s.streamID)
|
||||||
|
s.sendWindow -= uint32(writeLen)
|
||||||
|
s.receiveWindow += s.windowUpdate
|
||||||
|
s.windowUpdate = 0
|
||||||
|
s.headersSent = true
|
||||||
|
|
||||||
|
// if this chunk contains the end of the stream, close the stream now
|
||||||
|
if chunk.sendData && chunk.eof {
|
||||||
|
s.sentEOF = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunk
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *streamChunk) sendHeadersFrame() bool {
|
||||||
|
return c.sendHeaders
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *streamChunk) sendWindowUpdateFrame() bool {
|
||||||
|
return c.windowUpdate > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *streamChunk) sendDataFrame() bool {
|
||||||
|
return c.sendData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *streamChunk) nextDataFrame(frameSize int) (payload []byte, endStream bool) {
|
||||||
|
payload = c.buffer.Next(frameSize)
|
||||||
|
if c.buffer.Len() == 0 {
|
||||||
|
// this is the last data frame in this chunk
|
||||||
|
c.sendData = false
|
||||||
|
if c.eof {
|
||||||
|
endStream = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testWindowSize uint32 = 65535
|
||||||
|
const testMaxWindowSize uint32 = testWindowSize << 2
|
||||||
|
|
||||||
|
// Only sending WINDOW_UPDATE frame, so sendWindow should never change
|
||||||
|
func TestFlowControlSingleStream(t *testing.T) {
|
||||||
|
stream := &MuxedStream{
|
||||||
|
responseHeadersReceived: make(chan struct{}),
|
||||||
|
readBuffer: NewSharedBuffer(),
|
||||||
|
receiveWindow: testWindowSize,
|
||||||
|
receiveWindowCurrentMax: testWindowSize,
|
||||||
|
receiveWindowMax: testMaxWindowSize,
|
||||||
|
sendWindow: testWindowSize,
|
||||||
|
readyList: NewReadyList(),
|
||||||
|
}
|
||||||
|
assert.True(t, stream.consumeReceiveWindow(testWindowSize/2))
|
||||||
|
dataSent := testWindowSize / 2
|
||||||
|
assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow)
|
||||||
|
assert.Equal(t, testWindowSize, stream.receiveWindowCurrentMax)
|
||||||
|
assert.Equal(t, uint32(0), stream.windowUpdate)
|
||||||
|
tempWindowUpdate := stream.windowUpdate
|
||||||
|
|
||||||
|
streamChunk := stream.getChunk()
|
||||||
|
assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate)
|
||||||
|
assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow)
|
||||||
|
assert.Equal(t, uint32(0), stream.windowUpdate)
|
||||||
|
assert.Equal(t, testWindowSize, stream.sendWindow)
|
||||||
|
|
||||||
|
assert.True(t, stream.consumeReceiveWindow(2))
|
||||||
|
dataSent += 2
|
||||||
|
assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow)
|
||||||
|
assert.Equal(t, testWindowSize<<1, stream.receiveWindowCurrentMax)
|
||||||
|
assert.Equal(t, (testWindowSize<<1)-stream.receiveWindow, stream.windowUpdate)
|
||||||
|
tempWindowUpdate = stream.windowUpdate
|
||||||
|
|
||||||
|
streamChunk = stream.getChunk()
|
||||||
|
assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate)
|
||||||
|
assert.Equal(t, testWindowSize<<1, stream.receiveWindow)
|
||||||
|
assert.Equal(t, uint32(0), stream.windowUpdate)
|
||||||
|
assert.Equal(t, testWindowSize, stream.sendWindow)
|
||||||
|
|
||||||
|
assert.True(t, stream.consumeReceiveWindow(testWindowSize+10))
|
||||||
|
dataSent = testWindowSize + 10
|
||||||
|
assert.Equal(t, (testWindowSize<<1)-dataSent, stream.receiveWindow)
|
||||||
|
assert.Equal(t, testWindowSize<<2, stream.receiveWindowCurrentMax)
|
||||||
|
assert.Equal(t, (testWindowSize<<2)-stream.receiveWindow, stream.windowUpdate)
|
||||||
|
tempWindowUpdate = stream.windowUpdate
|
||||||
|
|
||||||
|
streamChunk = stream.getChunk()
|
||||||
|
assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate)
|
||||||
|
assert.Equal(t, testWindowSize<<2, stream.receiveWindow)
|
||||||
|
assert.Equal(t, uint32(0), stream.windowUpdate)
|
||||||
|
assert.Equal(t, testWindowSize, stream.sendWindow)
|
||||||
|
|
||||||
|
assert.False(t, stream.consumeReceiveWindow(testMaxWindowSize+1))
|
||||||
|
assert.Equal(t, testWindowSize<<2, stream.receiveWindow)
|
||||||
|
assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax)
|
||||||
|
}
|
|
@ -0,0 +1,326 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MuxReader struct {
|
||||||
|
// f is used to read HTTP2 frames.
|
||||||
|
f *http2.Framer
|
||||||
|
// handler provides a callback to receive new streams. if nil, new streams cannot be accepted.
|
||||||
|
handler MuxedStreamHandler
|
||||||
|
// streams tracks currently-open streams.
|
||||||
|
streams *activeStreamMap
|
||||||
|
// readyList is used to signal writable streams.
|
||||||
|
readyList *ReadyList
|
||||||
|
// streamErrors lets us report stream errors to the MuxWriter.
|
||||||
|
streamErrors *StreamErrorMap
|
||||||
|
// goAwayChan is used to tell the writer to send a GOAWAY message.
|
||||||
|
goAwayChan chan<- http2.ErrCode
|
||||||
|
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
|
||||||
|
abortChan <-chan struct{}
|
||||||
|
// pingTimestamp is an atomic value containing the latest received ping timestamp.
|
||||||
|
pingTimestamp *PingTimestamp
|
||||||
|
// connActive is used to signal to the writer that something happened on the connection.
|
||||||
|
// This is used to clear idle timeout disconnection deadlines.
|
||||||
|
connActive Signal
|
||||||
|
// The initial value for the send and receive window of a new stream.
|
||||||
|
initialStreamWindow uint32
|
||||||
|
// The max value for the send window of a stream.
|
||||||
|
streamWindowMax uint32
|
||||||
|
// windowMetrics keeps track of min/max/average of send/receive windows for all streams
|
||||||
|
flowControlMetrics *FlowControlMetrics
|
||||||
|
metricsMutex sync.Mutex
|
||||||
|
// r is a reference to the underlying connection used when shutting down.
|
||||||
|
r io.Closer
|
||||||
|
// rttMeasurement measures RTT based on ping timestamps.
|
||||||
|
rttMeasurement RTTMeasurement
|
||||||
|
rttMutex sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MuxReader) Shutdown() {
|
||||||
|
done := r.streams.Shutdown()
|
||||||
|
if done == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.sendGoAway(http2.ErrCodeNo)
|
||||||
|
go func() {
|
||||||
|
// close reader side when last stream ends; this will cause the writer to abort
|
||||||
|
<-done
|
||||||
|
r.r.Close()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MuxReader) RTT() RTTMeasurement {
|
||||||
|
r.rttMutex.Lock()
|
||||||
|
defer r.rttMutex.Unlock()
|
||||||
|
return r.rttMeasurement
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MuxReader) FlowControlMetrics() *FlowControlMetrics {
|
||||||
|
r.metricsMutex.Lock()
|
||||||
|
defer r.metricsMutex.Unlock()
|
||||||
|
if r.flowControlMetrics != nil {
|
||||||
|
return r.flowControlMetrics
|
||||||
|
}
|
||||||
|
// No metrics available yet
|
||||||
|
return &FlowControlMetrics{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MuxReader) run(parentLogger *log.Entry) error {
|
||||||
|
logger := parentLogger.WithFields(log.Fields{
|
||||||
|
"subsystem": "mux",
|
||||||
|
"dir": "read",
|
||||||
|
})
|
||||||
|
defer logger.Debug("event loop finished")
|
||||||
|
for {
|
||||||
|
frame, err := r.f.ReadFrame()
|
||||||
|
if err != nil {
|
||||||
|
switch e := err.(type) {
|
||||||
|
case http2.StreamError:
|
||||||
|
logger.WithError(err).Warn("stream error")
|
||||||
|
r.streamError(e.StreamID, e.Code)
|
||||||
|
case http2.ConnectionError:
|
||||||
|
logger.WithError(err).Warn("connection error")
|
||||||
|
return r.connectionError(err)
|
||||||
|
default:
|
||||||
|
if isConnectionClosedError(err) {
|
||||||
|
if r.streams.Len() == 0 {
|
||||||
|
logger.Debug("shutting down")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
logger.Warn("connection closed unexpectedly")
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
logger.WithError(err).Warn("frame read error")
|
||||||
|
return r.connectionError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.connActive.Signal()
|
||||||
|
logger.WithField("data", frame).Debug("read frame")
|
||||||
|
switch f := frame.(type) {
|
||||||
|
case *http2.DataFrame:
|
||||||
|
err = r.receiveFrameData(f, logger)
|
||||||
|
case *http2.MetaHeadersFrame:
|
||||||
|
err = r.receiveHeaderData(f)
|
||||||
|
case *http2.RSTStreamFrame:
|
||||||
|
streamID := f.Header().StreamID
|
||||||
|
if streamID == 0 {
|
||||||
|
return ErrInvalidStream
|
||||||
|
}
|
||||||
|
r.streams.Delete(streamID)
|
||||||
|
case *http2.PingFrame:
|
||||||
|
r.receivePingData(f)
|
||||||
|
case *http2.GoAwayFrame:
|
||||||
|
err = r.receiveGoAway(f)
|
||||||
|
case *http2.WindowUpdateFrame:
|
||||||
|
err = r.updateStreamWindow(f)
|
||||||
|
default:
|
||||||
|
err = ErrUnexpectedFrameType
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logger.WithField("data", frame).WithError(err).Debug("frame error")
|
||||||
|
return r.connectionError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MuxReader) newMuxedStream(streamID uint32) *MuxedStream {
|
||||||
|
return &MuxedStream{
|
||||||
|
streamID: streamID,
|
||||||
|
readBuffer: NewSharedBuffer(),
|
||||||
|
receiveWindow: r.initialStreamWindow,
|
||||||
|
receiveWindowCurrentMax: r.initialStreamWindow,
|
||||||
|
receiveWindowMax: r.streamWindowMax,
|
||||||
|
sendWindow: r.initialStreamWindow,
|
||||||
|
readyList: r.readyList,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getStreamForFrame returns a stream if valid, or an error describing why the stream could not be returned.
|
||||||
|
func (r *MuxReader) getStreamForFrame(frame http2.Frame) (*MuxedStream, error) {
|
||||||
|
sid := frame.Header().StreamID
|
||||||
|
if sid == 0 {
|
||||||
|
return nil, ErrUnexpectedFrameType
|
||||||
|
}
|
||||||
|
if stream, ok := r.streams.Get(sid); ok {
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
if r.streams.IsLocalStreamID(sid) {
|
||||||
|
// no stream available, but no error
|
||||||
|
return nil, ErrClosedStream
|
||||||
|
}
|
||||||
|
if sid < r.streams.LastPeerStreamID() {
|
||||||
|
// no stream available, stream closed error
|
||||||
|
return nil, ErrClosedStream
|
||||||
|
}
|
||||||
|
return nil, ErrUnknownStream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MuxReader) defaultStreamErrorHandler(err error, header http2.FrameHeader) error {
|
||||||
|
if header.Flags.Has(http2.FlagHeadersEndStream) {
|
||||||
|
return nil
|
||||||
|
} else if err == ErrUnknownStream || err == ErrClosedStream {
|
||||||
|
return r.streamError(header.StreamID, http2.ErrCodeStreamClosed)
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receives header frames from a stream. A non-nil error is a connection error.
|
||||||
|
func (r *MuxReader) receiveHeaderData(frame *http2.MetaHeadersFrame) error {
|
||||||
|
var stream *MuxedStream
|
||||||
|
sid := frame.Header().StreamID
|
||||||
|
if sid == 0 {
|
||||||
|
return ErrUnexpectedFrameType
|
||||||
|
}
|
||||||
|
newStream := r.streams.IsPeerStreamID(sid)
|
||||||
|
if newStream {
|
||||||
|
// header request
|
||||||
|
// TODO support trailers (if stream exists)
|
||||||
|
ok, err := r.streams.AcquirePeerID(sid)
|
||||||
|
if !ok {
|
||||||
|
// ignore new streams while shutting down
|
||||||
|
return r.streamError(sid, err)
|
||||||
|
}
|
||||||
|
stream = r.newMuxedStream(sid)
|
||||||
|
// Set stream. Returns false if a stream already existed with that ID or we are shutting down, return false.
|
||||||
|
if !r.streams.Set(stream) {
|
||||||
|
// got HEADERS frame for an existing stream
|
||||||
|
// TODO support trailers
|
||||||
|
return r.streamError(sid, http2.ErrCodeInternal)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// header response
|
||||||
|
var err error
|
||||||
|
if stream, err = r.getStreamForFrame(frame); err != nil {
|
||||||
|
return r.defaultStreamErrorHandler(err, frame.Header())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
headers := make([]Header, len(frame.Fields))
|
||||||
|
for i, header := range frame.Fields {
|
||||||
|
headers[i].Name = header.Name
|
||||||
|
headers[i].Value = header.Value
|
||||||
|
}
|
||||||
|
stream.Headers = headers
|
||||||
|
if frame.Header().Flags.Has(http2.FlagHeadersEndStream) {
|
||||||
|
stream.receiveEOF()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if newStream {
|
||||||
|
go r.handleStream(stream)
|
||||||
|
} else {
|
||||||
|
close(stream.responseHeadersReceived)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MuxReader) handleStream(stream *MuxedStream) {
|
||||||
|
defer stream.Close()
|
||||||
|
r.handler.ServeStream(stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receives a data frame from a stream. A non-nil error is a connection error.
|
||||||
|
func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.Entry) error {
|
||||||
|
logger := parentLogger.WithField("stream", frame.Header().StreamID)
|
||||||
|
stream, err := r.getStreamForFrame(frame)
|
||||||
|
if err != nil {
|
||||||
|
return r.defaultStreamErrorHandler(err, frame.Header())
|
||||||
|
}
|
||||||
|
data := frame.Data()
|
||||||
|
if len(data) > 0 {
|
||||||
|
_, err = stream.readBuffer.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
return r.streamError(stream.streamID, http2.ErrCodeInternal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if frame.Header().Flags.Has(http2.FlagDataEndStream) {
|
||||||
|
if stream.receiveEOF() {
|
||||||
|
r.streams.Delete(stream.streamID)
|
||||||
|
logger.Debug("stream closed")
|
||||||
|
} else {
|
||||||
|
logger.Debug("shutdown receive side")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !stream.consumeReceiveWindow(uint32(len(data))) {
|
||||||
|
return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive a PING from the peer. Update RTT and send/receive window metrics if it's an ACK.
|
||||||
|
func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
|
||||||
|
ts := int64(binary.LittleEndian.Uint64(frame.Data[:]))
|
||||||
|
if !frame.IsAck() {
|
||||||
|
r.pingTimestamp.Set(ts)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.rttMutex.Lock()
|
||||||
|
r.rttMeasurement.Update(time.Unix(0, ts))
|
||||||
|
r.rttMutex.Unlock()
|
||||||
|
r.flowControlMetrics = r.streams.Metrics()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive a GOAWAY from the peer. Gracefully shut down our connection.
|
||||||
|
func (r *MuxReader) receiveGoAway(frame *http2.GoAwayFrame) error {
|
||||||
|
r.Shutdown()
|
||||||
|
// Close all streams above the last processed stream
|
||||||
|
lastStream := r.streams.LastLocalStreamID()
|
||||||
|
for i := frame.LastStreamID + 2; i <= lastStream; i++ {
|
||||||
|
if stream, ok := r.streams.Get(i); ok {
|
||||||
|
stream.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receives header frames from a stream. A non-nil error is a connection error.
|
||||||
|
func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
|
||||||
|
stream, err := r.getStreamForFrame(frame)
|
||||||
|
if err != nil && err != ErrUnknownStream && err != ErrClosedStream {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if stream == nil {
|
||||||
|
// ignore window updates on closed streams
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
stream.replenishSendWindow(frame.Increment)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Raise a stream processing error, closing the stream. Runs on the write thread.
|
||||||
|
func (r *MuxReader) streamError(streamID uint32, e http2.ErrCode) error {
|
||||||
|
r.streamErrors.RaiseError(streamID, e)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MuxReader) connectionError(err error) error {
|
||||||
|
http2Code := http2.ErrCodeInternal
|
||||||
|
switch e := err.(type) {
|
||||||
|
case http2.ConnectionError:
|
||||||
|
http2Code = http2.ErrCode(e)
|
||||||
|
case MuxerProtocolError:
|
||||||
|
http2Code = e.h2code
|
||||||
|
}
|
||||||
|
log.Warnf("Connection error %v", http2Code)
|
||||||
|
r.sendGoAway(http2Code)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Instruct the writer to send a GOAWAY message if possible. This may fail in
|
||||||
|
// the case where an existing GOAWAY message is in flight or the writer event
|
||||||
|
// loop already ended.
|
||||||
|
func (r *MuxReader) sendGoAway(errCode http2.ErrCode) {
|
||||||
|
select {
|
||||||
|
case r.goAwayChan <- errCode:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,238 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/http2/hpack"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MuxWriter struct {
|
||||||
|
// f is used to write HTTP2 frames.
|
||||||
|
f *http2.Framer
|
||||||
|
// streams tracks currently-open streams.
|
||||||
|
streams *activeStreamMap
|
||||||
|
// streamErrors receives stream errors raised by the MuxReader.
|
||||||
|
streamErrors *StreamErrorMap
|
||||||
|
// readyStreamChan is used to multiplex writable streams onto the single connection.
|
||||||
|
// When a stream becomes writable its ID is sent on this channel.
|
||||||
|
readyStreamChan <-chan uint32
|
||||||
|
// newStreamChan is used to create new streams with a given set of headers.
|
||||||
|
newStreamChan <-chan MuxedStreamRequest
|
||||||
|
// goAwayChan is used to send a single GOAWAY message to the peer. The element received
|
||||||
|
// is the HTTP/2 error code to send.
|
||||||
|
goAwayChan <-chan http2.ErrCode
|
||||||
|
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
|
||||||
|
abortChan <-chan struct{}
|
||||||
|
// pingTimestamp is an atomic value containing the latest received ping timestamp.
|
||||||
|
pingTimestamp *PingTimestamp
|
||||||
|
// A timer used to measure idle connection time. Reset after sending data.
|
||||||
|
idleTimer *IdleTimer
|
||||||
|
// connActiveChan receives a signal that the connection received some (read) activity.
|
||||||
|
connActiveChan <-chan struct{}
|
||||||
|
// Maximum size of all frames that can be sent on this connection.
|
||||||
|
maxFrameSize uint32
|
||||||
|
// headerEncoder is the stateful header encoder for this connection
|
||||||
|
headerEncoder *hpack.Encoder
|
||||||
|
// headerBuffer is the temporary buffer used by headerEncoder.
|
||||||
|
headerBuffer bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
type MuxedStreamRequest struct {
|
||||||
|
stream *MuxedStream
|
||||||
|
body io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *MuxedStreamRequest) flushBody() {
|
||||||
|
io.Copy(r.stream, r.body)
|
||||||
|
r.stream.CloseWrite()
|
||||||
|
}
|
||||||
|
|
||||||
|
func tsToPingData(ts int64) [8]byte {
|
||||||
|
pingData := [8]byte{}
|
||||||
|
binary.LittleEndian.PutUint64(pingData[:], uint64(ts))
|
||||||
|
return pingData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *MuxWriter) run(parentLogger *log.Entry) error {
|
||||||
|
logger := parentLogger.WithFields(log.Fields{
|
||||||
|
"subsystem": "mux",
|
||||||
|
"dir": "write",
|
||||||
|
})
|
||||||
|
defer logger.Debug("event loop finished")
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-w.abortChan:
|
||||||
|
logger.Debug("aborting writer thread")
|
||||||
|
return nil
|
||||||
|
case errCode := <-w.goAwayChan:
|
||||||
|
logger.Debug("sending GOAWAY code ", errCode)
|
||||||
|
err := w.f.WriteGoAway(w.streams.LastPeerStreamID(), errCode, []byte{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.idleTimer.MarkActive()
|
||||||
|
case <-w.pingTimestamp.GetUpdateChan():
|
||||||
|
logger.Debug("sending PING ACK")
|
||||||
|
err := w.f.WritePing(true, tsToPingData(w.pingTimestamp.Get()))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.idleTimer.MarkActive()
|
||||||
|
case <-w.idleTimer.C:
|
||||||
|
if !w.idleTimer.Retry() {
|
||||||
|
return ErrConnectionDropped
|
||||||
|
}
|
||||||
|
logger.Debug("sending PING")
|
||||||
|
err := w.f.WritePing(false, tsToPingData(time.Now().UnixNano()))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.idleTimer.ResetTimer()
|
||||||
|
case <-w.connActiveChan:
|
||||||
|
w.idleTimer.MarkActive()
|
||||||
|
case <-w.streamErrors.GetSignalChan():
|
||||||
|
for streamID, errCode := range w.streamErrors.GetErrors() {
|
||||||
|
logger.WithField("stream", streamID).WithField("code", errCode).Debug("resetting stream")
|
||||||
|
err := w.f.WriteRSTStream(streamID, errCode)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.idleTimer.MarkActive()
|
||||||
|
case streamRequest := <-w.newStreamChan:
|
||||||
|
streamID := w.streams.AcquireLocalID()
|
||||||
|
streamRequest.stream.streamID = streamID
|
||||||
|
if !w.streams.Set(streamRequest.stream) {
|
||||||
|
// Race between OpenStream and Shutdown, and Shutdown won. Let Shutdown (and the eventual abort) take
|
||||||
|
// care of this stream. Ideally we'd pass the error directly to the stream object somehow so the
|
||||||
|
// caller can be unblocked sooner, but the value of that optimisation is minimal for most of the
|
||||||
|
// reasons why you'd call Shutdown anyway.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if streamRequest.body != nil {
|
||||||
|
go streamRequest.flushBody()
|
||||||
|
}
|
||||||
|
streamLogger := logger.WithField("stream", streamID)
|
||||||
|
err := w.writeStreamData(streamRequest.stream, streamLogger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.idleTimer.MarkActive()
|
||||||
|
case streamID := <-w.readyStreamChan:
|
||||||
|
streamLogger := logger.WithField("stream", streamID)
|
||||||
|
stream, ok := w.streams.Get(streamID)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err := w.writeStreamData(stream, streamLogger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
w.idleTimer.MarkActive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
|
||||||
|
logger.Debug("writable")
|
||||||
|
chunk := stream.getChunk()
|
||||||
|
|
||||||
|
if chunk.sendHeadersFrame() {
|
||||||
|
err := w.writeHeaders(chunk.streamID, chunk.headers)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Warn("error writing headers")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Debug("output headers")
|
||||||
|
}
|
||||||
|
|
||||||
|
if chunk.sendWindowUpdateFrame() {
|
||||||
|
// Send a WINDOW_UPDATE frame to update our receive window.
|
||||||
|
// If the Stream ID is zero, the window update applies to the connection as a whole
|
||||||
|
// A WINDOW_UPDATE in a specific stream applies to the connection-level flow control as well.
|
||||||
|
err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Warn("error writing window update")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Debugf("increment receive window by %d", chunk.windowUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
for chunk.sendDataFrame() {
|
||||||
|
payload, sentEOF := chunk.nextDataFrame(int(w.maxFrameSize))
|
||||||
|
err := w.f.WriteData(chunk.streamID, sentEOF, payload)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Warn("error writing data")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.WithField("len", len(payload)).Debug("output data")
|
||||||
|
|
||||||
|
if sentEOF {
|
||||||
|
if stream.readBuffer.Closed() {
|
||||||
|
// transition into closed state
|
||||||
|
if !stream.gotReceiveEOF() {
|
||||||
|
// the peer may send data that we no longer want to receive. Force them into the
|
||||||
|
// closed state.
|
||||||
|
logger.Debug("resetting stream")
|
||||||
|
w.f.WriteRSTStream(chunk.streamID, http2.ErrCodeNo)
|
||||||
|
} else {
|
||||||
|
// Half-open stream transitioned into closed
|
||||||
|
logger.Debug("closing stream")
|
||||||
|
}
|
||||||
|
w.streams.Delete(chunk.streamID)
|
||||||
|
} else {
|
||||||
|
logger.Debug("closing stream write side")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *MuxWriter) encodeHeaders(headers []Header) ([]byte, error) {
|
||||||
|
w.headerBuffer.Reset()
|
||||||
|
for _, header := range headers {
|
||||||
|
err := w.headerEncoder.WriteField(hpack.HeaderField{
|
||||||
|
Name: header.Name,
|
||||||
|
Value: header.Value,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return w.headerBuffer.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeHeaders writes a block of encoded headers, splitting it into multiple frames if necessary.
|
||||||
|
func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
|
||||||
|
encodedHeaders, err := w.encodeHeaders(headers)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
blockSize := int(w.maxFrameSize)
|
||||||
|
continuation := false
|
||||||
|
endHeaders := len(encodedHeaders) == 0
|
||||||
|
for !endHeaders && err == nil {
|
||||||
|
blockFragment := encodedHeaders
|
||||||
|
if len(encodedHeaders) > blockSize {
|
||||||
|
blockFragment = blockFragment[:blockSize]
|
||||||
|
encodedHeaders = encodedHeaders[blockSize:]
|
||||||
|
} else {
|
||||||
|
endHeaders = true
|
||||||
|
}
|
||||||
|
if continuation {
|
||||||
|
err = w.f.WriteContinuation(streamID, endHeaders, blockFragment)
|
||||||
|
} else {
|
||||||
|
err = w.f.WriteHeaders(http2.HeadersFrameParam{
|
||||||
|
StreamID: streamID,
|
||||||
|
EndHeaders: endHeaders,
|
||||||
|
BlockFragment: blockFragment,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,140 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
// ReadyList multiplexes several event signals onto a single channel.
|
||||||
|
type ReadyList struct {
|
||||||
|
signalC chan uint32
|
||||||
|
waitC chan uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewReadyList() *ReadyList {
|
||||||
|
rl := &ReadyList{
|
||||||
|
signalC: make(chan uint32),
|
||||||
|
waitC: make(chan uint32),
|
||||||
|
}
|
||||||
|
go rl.run()
|
||||||
|
return rl
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID is the stream ID
|
||||||
|
func (r *ReadyList) Signal(ID uint32) {
|
||||||
|
r.signalC <- ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReadyList) ReadyChannel() <-chan uint32 {
|
||||||
|
return r.waitC
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReadyList) Close() {
|
||||||
|
close(r.signalC)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReadyList) run() {
|
||||||
|
defer close(r.waitC)
|
||||||
|
var queue readyDescriptorQueue
|
||||||
|
var firstReady *readyDescriptor
|
||||||
|
activeDescriptors := newReadyDescriptorMap()
|
||||||
|
for {
|
||||||
|
if firstReady == nil {
|
||||||
|
// Wait for first ready descriptor
|
||||||
|
i, ok := <-r.signalC
|
||||||
|
if !ok {
|
||||||
|
// closed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
firstReady = activeDescriptors.SetIfMissing(i)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case r.waitC <- firstReady.ID:
|
||||||
|
activeDescriptors.Delete(firstReady.ID)
|
||||||
|
firstReady = queue.Dequeue()
|
||||||
|
case i, ok := <-r.signalC:
|
||||||
|
if !ok {
|
||||||
|
// closed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
newReady := activeDescriptors.SetIfMissing(i)
|
||||||
|
if newReady != nil {
|
||||||
|
// key doesn't exist
|
||||||
|
queue.Enqueue(newReady)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type readyDescriptor struct {
|
||||||
|
ID uint32
|
||||||
|
Next *readyDescriptor
|
||||||
|
}
|
||||||
|
|
||||||
|
// readyDescriptorQueue is a queue of readyDescriptors in the form of a singly-linked list.
|
||||||
|
// The nil readyDescriptorQueue is an empty queue ready for use.
|
||||||
|
type readyDescriptorQueue struct {
|
||||||
|
Head *readyDescriptor
|
||||||
|
Tail *readyDescriptor
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *readyDescriptorQueue) Empty() bool {
|
||||||
|
return q.Head == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *readyDescriptorQueue) Enqueue(x *readyDescriptor) {
|
||||||
|
if x.Next != nil {
|
||||||
|
panic("enqueued already queued item")
|
||||||
|
}
|
||||||
|
if q.Empty() {
|
||||||
|
q.Head = x
|
||||||
|
q.Tail = x
|
||||||
|
} else {
|
||||||
|
q.Tail.Next = x
|
||||||
|
q.Tail = x
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dequeue returns the first readyDescriptor in the queue, or nil if empty.
|
||||||
|
func (q *readyDescriptorQueue) Dequeue() *readyDescriptor {
|
||||||
|
if q.Empty() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
x := q.Head
|
||||||
|
q.Head = x.Next
|
||||||
|
x.Next = nil
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
|
// readyDescriptorQueue is a map of readyDescriptors keyed by ID.
|
||||||
|
// It maintains a free list of deleted ready descriptors.
|
||||||
|
type readyDescriptorMap struct {
|
||||||
|
descriptors map[uint32]*readyDescriptor
|
||||||
|
free []*readyDescriptor
|
||||||
|
}
|
||||||
|
|
||||||
|
func newReadyDescriptorMap() *readyDescriptorMap {
|
||||||
|
return &readyDescriptorMap{descriptors: make(map[uint32]*readyDescriptor)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create or reuse a readyDescriptor if the stream is not in the queue.
|
||||||
|
// This avoid stream starvation caused by a single high-bandwidth stream monopolising the writer goroutine
|
||||||
|
func (m *readyDescriptorMap) SetIfMissing(key uint32) *readyDescriptor {
|
||||||
|
if _, ok := m.descriptors[key]; ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var newDescriptor *readyDescriptor
|
||||||
|
if len(m.free) > 0 {
|
||||||
|
// reuse deleted ready descriptors
|
||||||
|
newDescriptor = m.free[len(m.free)-1]
|
||||||
|
m.free = m.free[:len(m.free)-1]
|
||||||
|
} else {
|
||||||
|
newDescriptor = &readyDescriptor{}
|
||||||
|
}
|
||||||
|
newDescriptor.ID = key
|
||||||
|
m.descriptors[key] = newDescriptor
|
||||||
|
return newDescriptor
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *readyDescriptorMap) Delete(key uint32) {
|
||||||
|
if descriptor, ok := m.descriptors[key]; ok {
|
||||||
|
m.free = append(m.free, descriptor)
|
||||||
|
delete(m.descriptors, key)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,115 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReadyList(t *testing.T) {
|
||||||
|
rl := NewReadyList()
|
||||||
|
c := rl.ReadyChannel()
|
||||||
|
// helper functions
|
||||||
|
assertEmpty := func() {
|
||||||
|
select {
|
||||||
|
case <-c:
|
||||||
|
t.Fatalf("Spurious wakeup")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
receiveWithTimeout := func() uint32 {
|
||||||
|
select {
|
||||||
|
case i := <-c:
|
||||||
|
return i
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Fatalf("Timeout")
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// no signals, receive should fail
|
||||||
|
assertEmpty()
|
||||||
|
rl.Signal(0)
|
||||||
|
if receiveWithTimeout() != 0 {
|
||||||
|
t.Fatalf("Received wrong ID of signalled event")
|
||||||
|
}
|
||||||
|
// no new signals, receive should fail
|
||||||
|
assertEmpty()
|
||||||
|
// Signals should not block;
|
||||||
|
// Duplicate unhandled signals should not cause multiple wakeups
|
||||||
|
signalled := [5]bool{}
|
||||||
|
for i := range signalled {
|
||||||
|
rl.Signal(uint32(i))
|
||||||
|
rl.Signal(uint32(i))
|
||||||
|
}
|
||||||
|
// All signals should be received once (in any order)
|
||||||
|
for range signalled {
|
||||||
|
i := receiveWithTimeout()
|
||||||
|
if signalled[i] {
|
||||||
|
t.Fatalf("Received signal %d more than once", i)
|
||||||
|
}
|
||||||
|
signalled[i] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadyDescriptorQueue(t *testing.T) {
|
||||||
|
var queue readyDescriptorQueue
|
||||||
|
items := [4]readyDescriptor{}
|
||||||
|
for i := range items {
|
||||||
|
items[i].ID = uint32(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !queue.Empty() {
|
||||||
|
t.Fatalf("nil queue should be empty")
|
||||||
|
}
|
||||||
|
queue.Enqueue(&items[3])
|
||||||
|
queue.Enqueue(&items[1])
|
||||||
|
queue.Enqueue(&items[0])
|
||||||
|
queue.Enqueue(&items[2])
|
||||||
|
if queue.Empty() {
|
||||||
|
t.Fatalf("Empty should be false after enqueue")
|
||||||
|
}
|
||||||
|
i := queue.Dequeue().ID
|
||||||
|
if i != 3 {
|
||||||
|
t.Fatalf("item 3 should have been dequeued, got %d instead", i)
|
||||||
|
}
|
||||||
|
i = queue.Dequeue().ID
|
||||||
|
if i != 1 {
|
||||||
|
t.Fatalf("item 1 should have been dequeued, got %d instead", i)
|
||||||
|
}
|
||||||
|
i = queue.Dequeue().ID
|
||||||
|
if i != 0 {
|
||||||
|
t.Fatalf("item 0 should have been dequeued, got %d instead", i)
|
||||||
|
}
|
||||||
|
i = queue.Dequeue().ID
|
||||||
|
if i != 2 {
|
||||||
|
t.Fatalf("item 2 should have been dequeued, got %d instead", i)
|
||||||
|
}
|
||||||
|
if !queue.Empty() {
|
||||||
|
t.Fatal("queue should be empty after dequeuing all items")
|
||||||
|
}
|
||||||
|
if queue.Dequeue() != nil {
|
||||||
|
t.Fatal("dequeue on empty queue should return nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadyDescriptorMap(t *testing.T) {
|
||||||
|
m := newReadyDescriptorMap()
|
||||||
|
m.Delete(42)
|
||||||
|
// (delete of missing key should be a noop)
|
||||||
|
x := m.SetIfMissing(42)
|
||||||
|
if x == nil {
|
||||||
|
t.Fatal("SetIfMissing for new key returned nil")
|
||||||
|
}
|
||||||
|
if m.SetIfMissing(42) != nil {
|
||||||
|
t.Fatal("SetIfMissing for existing key returned non-nil")
|
||||||
|
}
|
||||||
|
// this delete has effect
|
||||||
|
m.Delete(42)
|
||||||
|
// the next set should reuse the old object
|
||||||
|
y := m.SetIfMissing(666)
|
||||||
|
if y == nil {
|
||||||
|
t.Fatal("SetIfMissing for new key returned nil")
|
||||||
|
}
|
||||||
|
if x != y {
|
||||||
|
t.Fatal("SetIfMissing didn't reuse freed object")
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,53 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PingTimestamp is an atomic interface around ping timestamping and signalling.
|
||||||
|
type PingTimestamp struct {
|
||||||
|
ts int64
|
||||||
|
signal Signal
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPingTimestamp() *PingTimestamp {
|
||||||
|
return &PingTimestamp{signal: NewSignal()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pt *PingTimestamp) Set(v int64) {
|
||||||
|
if atomic.SwapInt64(&pt.ts, v) != 0 {
|
||||||
|
pt.signal.Signal()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pt *PingTimestamp) Get() int64 {
|
||||||
|
return atomic.SwapInt64(&pt.ts, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} {
|
||||||
|
return pt.signal.WaitChannel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RTTMeasurement encapsulates a continuous round trip time measurement.
|
||||||
|
type RTTMeasurement struct {
|
||||||
|
Current, Min, Max time.Duration
|
||||||
|
lastMeasurementTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update updates the computed values with a new measurement.
|
||||||
|
// outgoingTime is the time that the probe was sent.
|
||||||
|
// We assume that time.Now() is the time we received that probe.
|
||||||
|
func (r *RTTMeasurement) Update(outgoingTime time.Time) {
|
||||||
|
if !r.lastMeasurementTime.Before(outgoingTime) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.lastMeasurementTime = outgoingTime
|
||||||
|
r.Current = time.Since(outgoingTime)
|
||||||
|
if r.Max < r.Current {
|
||||||
|
r.Max = r.Current
|
||||||
|
}
|
||||||
|
if r.Min > r.Current {
|
||||||
|
r.Min = r.Current
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,64 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SharedBuffer struct {
|
||||||
|
cond *sync.Cond
|
||||||
|
buffer bytes.Buffer
|
||||||
|
eof bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSharedBuffer() *SharedBuffer {
|
||||||
|
return &SharedBuffer{
|
||||||
|
cond: sync.NewCond(&sync.Mutex{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SharedBuffer) Read(p []byte) (n int, err error) {
|
||||||
|
totalRead := 0
|
||||||
|
s.cond.L.Lock()
|
||||||
|
for totalRead < len(p) {
|
||||||
|
n, err = s.buffer.Read(p[totalRead:])
|
||||||
|
totalRead += n
|
||||||
|
if err == io.EOF {
|
||||||
|
if s.eof {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = nil
|
||||||
|
s.cond.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.cond.L.Unlock()
|
||||||
|
return totalRead, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SharedBuffer) Write(p []byte) (n int, err error) {
|
||||||
|
s.cond.L.Lock()
|
||||||
|
defer s.cond.L.Unlock()
|
||||||
|
if s.eof {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n, err = s.buffer.Write(p)
|
||||||
|
s.cond.Signal()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SharedBuffer) Close() error {
|
||||||
|
s.cond.L.Lock()
|
||||||
|
defer s.cond.L.Unlock()
|
||||||
|
if !s.eof {
|
||||||
|
s.eof = true
|
||||||
|
s.cond.Signal()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SharedBuffer) Closed() bool {
|
||||||
|
s.cond.L.Lock()
|
||||||
|
defer s.cond.L.Unlock()
|
||||||
|
return s.eof
|
||||||
|
}
|
|
@ -0,0 +1,120 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) {
|
||||||
|
return func(actual int, err error) {
|
||||||
|
if expected != actual {
|
||||||
|
t.Fatalf("Expected %d bytes, got %d", expected, actual)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSharedBuffer(t *testing.T) {
|
||||||
|
b := NewSharedBuffer()
|
||||||
|
testData := []byte("Hello world")
|
||||||
|
AssertIOReturnIsGood(t, len(testData))(b.Write(testData))
|
||||||
|
bytesRead := make([]byte, len(testData))
|
||||||
|
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSharedBufferBlockingRead(t *testing.T) {
|
||||||
|
b := NewSharedBuffer()
|
||||||
|
testData := []byte("Hello world")
|
||||||
|
result := make(chan []byte)
|
||||||
|
go func() {
|
||||||
|
bytesRead := make([]byte, len(testData))
|
||||||
|
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
|
||||||
|
result <- bytesRead
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-result:
|
||||||
|
t.Fatalf("read returned early")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
AssertIOReturnIsGood(t, 5)(b.Write(testData[:5]))
|
||||||
|
select {
|
||||||
|
case <-result:
|
||||||
|
t.Fatalf("read returned early")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
AssertIOReturnIsGood(t, len(testData)-5)(b.Write(testData[5:]))
|
||||||
|
select {
|
||||||
|
case r := <-result:
|
||||||
|
if string(r) != string(testData) {
|
||||||
|
t.Fatalf("expected read to return %s, got %s", testData, r)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatalf("read timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is quite slow under the race detector
|
||||||
|
func TestSharedBufferConcurrentReadWrite(t *testing.T) {
|
||||||
|
b := NewSharedBuffer()
|
||||||
|
var expectedResult, actualResult bytes.Buffer
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
block := make([]byte, 256)
|
||||||
|
for i := range block {
|
||||||
|
block[i] = byte(i)
|
||||||
|
}
|
||||||
|
for blockSize := 1; blockSize <= 256; blockSize++ {
|
||||||
|
for i := 0; i < 256; i++ {
|
||||||
|
expectedResult.Write(block[:blockSize])
|
||||||
|
n, err := b.Write(block[:blockSize])
|
||||||
|
if n != blockSize || err != nil {
|
||||||
|
t.Fatalf("write error: %d %s", n, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
block := make([]byte, 256)
|
||||||
|
// Change block sizes in opposition to the write thread, to test blocking for new data.
|
||||||
|
for blockSize := 256; blockSize > 0; blockSize-- {
|
||||||
|
for i := 0; i < 256; i++ {
|
||||||
|
n, err := b.Read(block[:blockSize])
|
||||||
|
if n != blockSize || err != nil {
|
||||||
|
t.Fatalf("read error: %d %s", n, err)
|
||||||
|
}
|
||||||
|
actualResult.Write(block[:blockSize])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
if bytes.Compare(expectedResult.Bytes(), actualResult.Bytes()) != 0 {
|
||||||
|
t.Fatal("Result diverged")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSharedBufferClose(t *testing.T) {
|
||||||
|
b := NewSharedBuffer()
|
||||||
|
testData := []byte("Hello world")
|
||||||
|
AssertIOReturnIsGood(t, len(testData))(b.Write(testData))
|
||||||
|
err := b.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error from Close: %s", err)
|
||||||
|
}
|
||||||
|
bytesRead := make([]byte, len(testData))
|
||||||
|
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
|
||||||
|
n, err := b.Read(bytesRead)
|
||||||
|
if n != 0 {
|
||||||
|
t.Fatalf("extra bytes received: %d", n)
|
||||||
|
}
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatalf("expected EOF, got %s", err)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,34 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
// Signal describes an event that can be waited on for at least one signal.
|
||||||
|
// Signalling the event while it is in the signalled state is a noop.
|
||||||
|
// When the waiter wakes up, the signal is set to unsignalled.
|
||||||
|
// It is a way for any number of writers to inform a reader (without blocking)
|
||||||
|
// that an event has happened.
|
||||||
|
type Signal struct {
|
||||||
|
c chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSignal creates a new Signal.
|
||||||
|
func NewSignal() Signal {
|
||||||
|
return Signal{c: make(chan struct{}, 1)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signal signals the event.
|
||||||
|
func (s Signal) Signal() {
|
||||||
|
// This channel is buffered, so the nonblocking send will always succeed if the buffer is empty.
|
||||||
|
select {
|
||||||
|
case s.c <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the event to be signalled.
|
||||||
|
func (s Signal) Wait() {
|
||||||
|
<-s.c
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitChannel returns a channel that is readable after Signal is called.
|
||||||
|
func (s Signal) WaitChannel() <-chan struct{} {
|
||||||
|
return s.c
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
package h2mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StreamErrorMap is used to track stream errors. This is a separate structure to ActiveStreamMap because
|
||||||
|
// errors can be raised against non-existent or closed streams.
|
||||||
|
type StreamErrorMap struct {
|
||||||
|
sync.RWMutex
|
||||||
|
// errors tracks per-stream errors
|
||||||
|
errors map[uint32]http2.ErrCode
|
||||||
|
// hasError is signaled whenever an error is raised.
|
||||||
|
hasError Signal
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStreamErrorMap creates a new StreamErrorMap.
|
||||||
|
func NewStreamErrorMap() *StreamErrorMap {
|
||||||
|
return &StreamErrorMap{
|
||||||
|
errors: make(map[uint32]http2.ErrCode),
|
||||||
|
hasError: NewSignal(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RaiseError raises a stream error.
|
||||||
|
func (s *StreamErrorMap) RaiseError(streamID uint32, err http2.ErrCode) {
|
||||||
|
s.Lock()
|
||||||
|
s.errors[streamID] = err
|
||||||
|
s.Unlock()
|
||||||
|
s.hasError.Signal()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSignalChan returns a channel that is signalled when an error is raised.
|
||||||
|
func (s *StreamErrorMap) GetSignalChan() <-chan struct{} {
|
||||||
|
return s.hasError.WaitChannel()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetErrors retrieves all errors currently raised. This resets the currently-tracked errors.
|
||||||
|
func (s *StreamErrorMap) GetErrors() map[uint32]http2.ErrCode {
|
||||||
|
s.Lock()
|
||||||
|
errors := s.errors
|
||||||
|
s.errors = make(map[uint32]http2.ErrCode)
|
||||||
|
s.Unlock()
|
||||||
|
return errors
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ServeMetrics(l net.Listener, shutdownC <-chan struct{}) error {
|
||||||
|
server := &http.Server{
|
||||||
|
ReadTimeout: 5 * time.Second,
|
||||||
|
WriteTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
<-shutdownC
|
||||||
|
server.Shutdown(context.Background())
|
||||||
|
}()
|
||||||
|
http.Handle("/metrics", promhttp.Handler())
|
||||||
|
log.WithField("addr", l.Addr()).Info("Starting metrics server")
|
||||||
|
err := server.Serve(l)
|
||||||
|
if err == http.ErrServerClosed {
|
||||||
|
log.Info("Metrics server stopped")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.WithError(err).Error("Metrics server quit with error")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterBuildInfo(buildTime string, version string) {
|
||||||
|
buildInfo := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
// Don't namespace build_info, since we want it to be consistent across all Cloudflare services
|
||||||
|
Name: "build_info",
|
||||||
|
Help: "Build and version information",
|
||||||
|
},
|
||||||
|
[]string{"goversion", "revision", "version"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(buildInfo)
|
||||||
|
buildInfo.WithLabelValues(runtime.Version(), buildTime, version).Set(1)
|
||||||
|
}
|
|
@ -0,0 +1,70 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Redeclare time functions so they can be overridden in tests.
|
||||||
|
var (
|
||||||
|
timeNow = time.Now
|
||||||
|
timeAfter = time.After
|
||||||
|
)
|
||||||
|
|
||||||
|
// BackoffHandler manages exponential backoff and limits the maximum number of retries.
|
||||||
|
// The base time period is 1 second, doubling with each retry.
|
||||||
|
// After initial success, a grace period can be set to reset the backoff timer if
|
||||||
|
// a connection is maintained successfully for a long enough period. The base grace period
|
||||||
|
// is 2 seconds, doubling with each retry.
|
||||||
|
type BackoffHandler struct {
|
||||||
|
// MaxRetries sets the maximum number of retries to perform. The default value
|
||||||
|
// of 0 disables retry completely.
|
||||||
|
MaxRetries uint
|
||||||
|
|
||||||
|
retries uint
|
||||||
|
resetDeadline time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BackoffHandler) GetBackoffDuration(ctx context.Context) (time.Duration, bool) {
|
||||||
|
// Follows the same logic as Backoff, but without mutating the receiver.
|
||||||
|
// This select has to happen first to reflect the actual behaviour of the Backoff function.
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return time.Duration(0), false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) {
|
||||||
|
// b.retries would be set to 0 at this point
|
||||||
|
return time.Second, true
|
||||||
|
}
|
||||||
|
if b.retries >= b.MaxRetries {
|
||||||
|
return time.Duration(0), false
|
||||||
|
}
|
||||||
|
return time.Duration(time.Second * 1 << b.retries), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backoff is used to wait according to exponential backoff. Returns false if the
|
||||||
|
// maximum number of retries have been used or if the underlying context has been cancelled.
|
||||||
|
func (b *BackoffHandler) Backoff(ctx context.Context) bool {
|
||||||
|
if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) {
|
||||||
|
b.retries = 0
|
||||||
|
b.resetDeadline = time.Time{}
|
||||||
|
}
|
||||||
|
if b.retries >= b.MaxRetries {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-timeAfter(time.Duration(time.Second * 1 << b.retries)):
|
||||||
|
b.retries++
|
||||||
|
return true
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets a grace period within which the the backoff timer is maintained. After the grace
|
||||||
|
// period expires, the number of retries & backoff duration is reset.
|
||||||
|
func (b *BackoffHandler) SetGracePeriod() {
|
||||||
|
b.resetDeadline = timeNow().Add(time.Duration(time.Second * 2 << b.retries))
|
||||||
|
}
|
|
@ -0,0 +1,114 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
func immediateTimeAfter(time.Duration) <-chan time.Time {
|
||||||
|
c := make(chan time.Time, 1)
|
||||||
|
c <- time.Now()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackoffRetries(t *testing.T) {
|
||||||
|
// make backoff return immediately
|
||||||
|
timeAfter = immediateTimeAfter
|
||||||
|
ctx := context.Background()
|
||||||
|
backoff := BackoffHandler{MaxRetries: 3}
|
||||||
|
if !backoff.Backoff(ctx) {
|
||||||
|
t.Fatalf("backoff failed immediately")
|
||||||
|
}
|
||||||
|
if !backoff.Backoff(ctx) {
|
||||||
|
t.Fatalf("backoff failed after 1 retry")
|
||||||
|
}
|
||||||
|
if !backoff.Backoff(ctx) {
|
||||||
|
t.Fatalf("backoff failed after 2 retry")
|
||||||
|
}
|
||||||
|
if backoff.Backoff(ctx) {
|
||||||
|
t.Fatalf("backoff allowed after 3 (max) retries")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackoffCancel(t *testing.T) {
|
||||||
|
// prevent backoff from returning normally
|
||||||
|
timeAfter = func(time.Duration) <-chan time.Time { return make(chan time.Time) }
|
||||||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||||
|
backoff := BackoffHandler{MaxRetries: 3}
|
||||||
|
cancelFunc()
|
||||||
|
if backoff.Backoff(ctx) {
|
||||||
|
t.Fatalf("backoff allowed after cancel")
|
||||||
|
}
|
||||||
|
if _, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||||
|
t.Fatalf("backoff allowed after cancel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackoffGracePeriod(t *testing.T) {
|
||||||
|
currentTime := time.Now()
|
||||||
|
// make timeNow return whatever we like
|
||||||
|
timeNow = func() time.Time { return currentTime }
|
||||||
|
// make backoff return immediately
|
||||||
|
timeAfter = immediateTimeAfter
|
||||||
|
ctx := context.Background()
|
||||||
|
backoff := BackoffHandler{MaxRetries: 1}
|
||||||
|
if !backoff.Backoff(ctx) {
|
||||||
|
t.Fatalf("backoff failed immediately")
|
||||||
|
}
|
||||||
|
// the next call to Backoff would fail unless it's after the grace period
|
||||||
|
backoff.SetGracePeriod()
|
||||||
|
// advance time to after the grace period (~4 seconds) and see what happens
|
||||||
|
currentTime = currentTime.Add(time.Second * 5)
|
||||||
|
if !backoff.Backoff(ctx) {
|
||||||
|
t.Fatalf("backoff failed after the grace period expired")
|
||||||
|
}
|
||||||
|
// confirm we ignore grace period after backoff
|
||||||
|
if backoff.Backoff(ctx) {
|
||||||
|
t.Fatalf("backoff allowed after 1 (max) retry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetBackoffDurationRetries(t *testing.T) {
|
||||||
|
// make backoff return immediately
|
||||||
|
timeAfter = immediateTimeAfter
|
||||||
|
ctx := context.Background()
|
||||||
|
backoff := BackoffHandler{MaxRetries: 3}
|
||||||
|
if _, ok := backoff.GetBackoffDuration(ctx); !ok {
|
||||||
|
t.Fatalf("backoff failed immediately")
|
||||||
|
}
|
||||||
|
backoff.Backoff(ctx) // noop
|
||||||
|
if _, ok := backoff.GetBackoffDuration(ctx); !ok {
|
||||||
|
t.Fatalf("backoff failed after 1 retry")
|
||||||
|
}
|
||||||
|
backoff.Backoff(ctx) // noop
|
||||||
|
if _, ok := backoff.GetBackoffDuration(ctx); !ok {
|
||||||
|
t.Fatalf("backoff failed after 2 retry")
|
||||||
|
}
|
||||||
|
backoff.Backoff(ctx) // noop
|
||||||
|
if _, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||||
|
t.Fatalf("backoff allowed after 3 (max) retries")
|
||||||
|
}
|
||||||
|
if backoff.Backoff(ctx) {
|
||||||
|
t.Fatalf("backoff allowed after 3 (max) retries")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetBackoffDuration(t *testing.T) {
|
||||||
|
// make backoff return immediately
|
||||||
|
timeAfter = immediateTimeAfter
|
||||||
|
ctx := context.Background()
|
||||||
|
backoff := BackoffHandler{MaxRetries: 3}
|
||||||
|
if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second {
|
||||||
|
t.Fatalf("backoff didn't return 1 second on first retry")
|
||||||
|
}
|
||||||
|
backoff.Backoff(ctx) // noop
|
||||||
|
if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*2 {
|
||||||
|
t.Fatalf("backoff didn't return 2 seconds on second retry")
|
||||||
|
}
|
||||||
|
backoff.Backoff(ctx) // noop
|
||||||
|
if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*4 {
|
||||||
|
t.Fatalf("backoff didn't return 4 seconds on third retry")
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,360 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflare-warp/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflare-warp/tunnelrpc"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||||
|
"github.com/cloudflare/cloudflare-warp/validation"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
raven "github.com/getsentry/raven-go"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
rpc "zombiezen.com/go/capnproto2/rpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
dialTimeout = 15 * time.Second
|
||||||
|
|
||||||
|
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TunnelConfig struct {
|
||||||
|
EdgeAddr string
|
||||||
|
OriginUrl string
|
||||||
|
Hostname string
|
||||||
|
APIKey string
|
||||||
|
APIEmail string
|
||||||
|
APICAKey string
|
||||||
|
TlsConfig *tls.Config
|
||||||
|
Retries uint
|
||||||
|
HeartbeatInterval time.Duration
|
||||||
|
MaxHeartbeats uint64
|
||||||
|
ClientID string
|
||||||
|
ReportedVersion string
|
||||||
|
LBPool string
|
||||||
|
Tags []tunnelpogs.Tag
|
||||||
|
AccessInternalIP bool
|
||||||
|
ConnectedSignal h2mux.Signal
|
||||||
|
}
|
||||||
|
|
||||||
|
type dialError struct {
|
||||||
|
cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e dialError) Error() string {
|
||||||
|
return e.cause.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
type printableRegisterTunnelError struct {
|
||||||
|
cause error
|
||||||
|
permanent bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e printableRegisterTunnelError) Error() string {
|
||||||
|
return e.cause.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TunnelConfig) RegistrationOptions() *tunnelpogs.RegistrationOptions {
|
||||||
|
policy := tunnelrpc.ExistingTunnelPolicy_disconnect
|
||||||
|
if c.LBPool != "" {
|
||||||
|
policy = tunnelrpc.ExistingTunnelPolicy_balance
|
||||||
|
}
|
||||||
|
return &tunnelpogs.RegistrationOptions{
|
||||||
|
ClientID: c.ClientID,
|
||||||
|
Version: c.ReportedVersion,
|
||||||
|
OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH),
|
||||||
|
ExistingTunnelPolicy: policy,
|
||||||
|
PoolID: c.LBPool,
|
||||||
|
Tags: c.Tags,
|
||||||
|
ExposeInternalHostname: c.AccessInternalIP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}) error {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
<-shutdownC
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
backoff := BackoffHandler{MaxRetries: config.Retries}
|
||||||
|
for {
|
||||||
|
err, recoverable := ServeTunnel(ctx, config, &backoff)
|
||||||
|
if recoverable {
|
||||||
|
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||||
|
log.Infof("Retrying in %s seconds", duration)
|
||||||
|
backoff.Backoff(ctx)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ServeTunnel(
|
||||||
|
ctx context.Context,
|
||||||
|
config *TunnelConfig,
|
||||||
|
backoff *BackoffHandler,
|
||||||
|
) (err error, recoverable bool) {
|
||||||
|
// Returns error from parsing the origin URL or handshake errors
|
||||||
|
handler, err := NewTunnelHandler(ctx, config)
|
||||||
|
if err != nil {
|
||||||
|
errLog := log.WithError(err)
|
||||||
|
switch err.(type) {
|
||||||
|
case dialError:
|
||||||
|
errLog.Error("Unable to dial edge")
|
||||||
|
case h2mux.MuxerHandshakeError:
|
||||||
|
errLog.Error("Handshake failed with edge server")
|
||||||
|
default:
|
||||||
|
errLog.Error("Tunnel creation failure")
|
||||||
|
return err, false
|
||||||
|
}
|
||||||
|
return err, true
|
||||||
|
}
|
||||||
|
serveCtx, serveCancel := context.WithCancel(ctx)
|
||||||
|
registerErrC := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
err := RegisterTunnel(serveCtx, handler.muxer, config)
|
||||||
|
if err == nil {
|
||||||
|
config.ConnectedSignal.Signal()
|
||||||
|
backoff.SetGracePeriod()
|
||||||
|
} else {
|
||||||
|
serveCancel()
|
||||||
|
}
|
||||||
|
registerErrC <- err
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
<-serveCtx.Done()
|
||||||
|
handler.muxer.Shutdown()
|
||||||
|
}()
|
||||||
|
err = handler.muxer.Serve()
|
||||||
|
serveCancel()
|
||||||
|
registerErr := <-registerErrC
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("Tunnel error")
|
||||||
|
return err, true
|
||||||
|
}
|
||||||
|
if registerErr != nil {
|
||||||
|
raven.CaptureError(registerErr, nil)
|
||||||
|
// Don't retry on errors like entitlement failure or version too old
|
||||||
|
if e, ok := registerErr.(printableRegisterTunnelError); ok {
|
||||||
|
log.WithError(e).Error("Cannot register")
|
||||||
|
if e.permanent {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return e.cause, true
|
||||||
|
}
|
||||||
|
log.Error("Cannot register")
|
||||||
|
return err, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsRPCStreamResponse(headers []h2mux.Header) bool {
|
||||||
|
if len(headers) != 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if headers[0].Name != ":status" || headers[0].Value != "200" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig) error {
|
||||||
|
logger := log.WithField("subsystem", "rpc")
|
||||||
|
logger.Debug("initiating RPC stream")
|
||||||
|
stream, err := muxer.OpenStream([]h2mux.Header{
|
||||||
|
{Name: ":method", Value: "RPC"},
|
||||||
|
{Name: ":scheme", Value: "capnp"},
|
||||||
|
{Name: ":path", Value: "*"},
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
// RPC stream open error
|
||||||
|
raven.CaptureError(err, nil)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !IsRPCStreamResponse(stream.Headers) {
|
||||||
|
// stream response error
|
||||||
|
raven.CaptureError(err, nil)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn := rpc.NewConn(
|
||||||
|
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
|
||||||
|
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
|
||||||
|
)
|
||||||
|
defer conn.Close()
|
||||||
|
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
|
||||||
|
// Request server info without blocking tunnel registration; must use capnp library directly.
|
||||||
|
tsClient := tunnelrpc.TunnelServer{Client: ts.Client}
|
||||||
|
serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
registration, err := ts.RegisterTunnel(
|
||||||
|
ctx,
|
||||||
|
&tunnelpogs.Authentication{Key: config.APIKey, Email: config.APIEmail, OriginCAKey: config.APICAKey},
|
||||||
|
config.Hostname,
|
||||||
|
config.RegistrationOptions(),
|
||||||
|
)
|
||||||
|
LogServerInfo(logger, serverInfoPromise.Result())
|
||||||
|
if err != nil {
|
||||||
|
// RegisterTunnel RPC failure
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, logLine := range registration.LogLines {
|
||||||
|
logger.Info(logLine)
|
||||||
|
}
|
||||||
|
if registration.Err != "" {
|
||||||
|
return printableRegisterTunnelError{
|
||||||
|
cause: fmt.Errorf("Server error: %s", registration.Err),
|
||||||
|
permanent: registration.PermanentFailure,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, url := range registration.Urls {
|
||||||
|
log.Infof("Registered at %s", url)
|
||||||
|
}
|
||||||
|
for _, logLine := range registration.LogLines {
|
||||||
|
log.Infof(logLine)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogServerInfo(logger *log.Entry, promise tunnelrpc.ServerInfo_Promise) {
|
||||||
|
serverInfoMessage, err := promise.Struct()
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Warn("Failed to retrieve server information")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage)
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Warn("Failed to retrieve server information")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("Connected to %s", serverInfo.LocationName)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TunnelHandler struct {
|
||||||
|
originUrl string
|
||||||
|
muxer *h2mux.Muxer
|
||||||
|
httpClient *http.Client
|
||||||
|
tags []tunnelpogs.Tag
|
||||||
|
}
|
||||||
|
|
||||||
|
var dialer = net.Dialer{DualStack: true}
|
||||||
|
|
||||||
|
func NewTunnelHandler(ctx context.Context, config *TunnelConfig) (*TunnelHandler, error) {
|
||||||
|
url, err := validation.ValidateUrl(config.OriginUrl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Unable to parse origin url %#v", url)
|
||||||
|
}
|
||||||
|
h := &TunnelHandler{
|
||||||
|
originUrl: url,
|
||||||
|
httpClient: &http.Client{Timeout: time.Minute},
|
||||||
|
tags: config.Tags,
|
||||||
|
}
|
||||||
|
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
||||||
|
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
|
||||||
|
// TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one)
|
||||||
|
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", config.EdgeAddr)
|
||||||
|
dialCancel()
|
||||||
|
if err != nil {
|
||||||
|
return nil, dialError{cause: err}
|
||||||
|
}
|
||||||
|
edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig)
|
||||||
|
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
|
||||||
|
err = edgeConn.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
return nil, dialError{cause: err}
|
||||||
|
}
|
||||||
|
// clear the deadline on the conn; h2mux has its own timeouts
|
||||||
|
edgeConn.SetDeadline(time.Time{})
|
||||||
|
// Establish a muxed connection with the edge
|
||||||
|
// Client mux handshake with agent server
|
||||||
|
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
Handler: h,
|
||||||
|
IsClient: true,
|
||||||
|
HeartbeatInterval: config.HeartbeatInterval,
|
||||||
|
MaxHeartbeats: config.MaxHeartbeats,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return h, errors.New("TLS handshake error")
|
||||||
|
}
|
||||||
|
return h, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
|
||||||
|
for _, header := range h2 {
|
||||||
|
switch header.Name {
|
||||||
|
case ":method":
|
||||||
|
h1.Method = header.Value
|
||||||
|
case ":scheme":
|
||||||
|
case ":authority":
|
||||||
|
// Otherwise the host header will be based on the origin URL
|
||||||
|
h1.Host = header.Value
|
||||||
|
case ":path":
|
||||||
|
u, err := url.Parse(header.Value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unparseable path")
|
||||||
|
}
|
||||||
|
resolved := h1.URL.ResolveReference(u)
|
||||||
|
// prevent escaping base URL
|
||||||
|
if !strings.HasPrefix(resolved.String(), h1.URL.String()) {
|
||||||
|
return fmt.Errorf("invalid path")
|
||||||
|
}
|
||||||
|
h1.URL = resolved
|
||||||
|
default:
|
||||||
|
h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
|
||||||
|
h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}}
|
||||||
|
for headerName, headerValues := range h1.Header {
|
||||||
|
for _, headerValue := range headerValues {
|
||||||
|
h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
|
||||||
|
for _, tag := range h.tags {
|
||||||
|
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||||
|
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Panic("Unexpected error from http.NewRequest")
|
||||||
|
}
|
||||||
|
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("invalid request received")
|
||||||
|
}
|
||||||
|
h.AppendTagHeaders(req)
|
||||||
|
response, err := h.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Error("HTTP request error")
|
||||||
|
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
|
||||||
|
stream.Write([]byte("502 Bad Gateway"))
|
||||||
|
} else {
|
||||||
|
defer response.Body.Close()
|
||||||
|
stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||||
|
io.Copy(stream, response.Body)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,62 @@
|
||||||
|
// Package tlsconfig provides convenience functions for configuring TLS connections from the
|
||||||
|
// command line.
|
||||||
|
package tlsconfig
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"io/ioutil"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
cli "gopkg.in/urfave/cli.v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CLIFlags names the flags used to configure TLS for a command or subsystem.
|
||||||
|
// The nil value for a field means the flag is ignored.
|
||||||
|
type CLIFlags struct {
|
||||||
|
Cert string
|
||||||
|
Key string
|
||||||
|
ClientCert string
|
||||||
|
RootCA string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfig returns a TLS configuration according to the flags defined in f and
|
||||||
|
// set by the user.
|
||||||
|
func (f CLIFlags) GetConfig(c *cli.Context) *tls.Config {
|
||||||
|
config := &tls.Config{}
|
||||||
|
|
||||||
|
if c.IsSet(f.Cert) && c.IsSet(f.Key) {
|
||||||
|
cert, err := tls.LoadX509KeyPair(c.String(f.Cert), c.String(f.Key))
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Fatal("Error parsing X509 key pair")
|
||||||
|
}
|
||||||
|
config.Certificates = []tls.Certificate{cert}
|
||||||
|
config.BuildNameToCertificate()
|
||||||
|
}
|
||||||
|
if c.IsSet(f.ClientCert) {
|
||||||
|
// set of root certificate authorities that servers use if required to verify a client certificate
|
||||||
|
// by the policy in ClientAuth
|
||||||
|
config.ClientCAs = LoadCert(c.String(f.ClientCert))
|
||||||
|
// server's policy for TLS Client Authentication. Default is no client cert
|
||||||
|
config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
|
}
|
||||||
|
// set of root certificate authorities that clients use when verifying server certificates
|
||||||
|
if c.IsSet(f.RootCA) {
|
||||||
|
config.RootCAs = LoadCert(c.String(f.RootCA))
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadCert creates a CertPool containing all certificates in a PEM-format file.
|
||||||
|
func LoadCert(certPath string) *x509.CertPool {
|
||||||
|
caCert, err := ioutil.ReadFile(certPath)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Fatalf("Error reading certificate %s", certPath)
|
||||||
|
}
|
||||||
|
ca := x509.NewCertPool()
|
||||||
|
if !ca.AppendCertsFromPEM(caCert) {
|
||||||
|
log.WithError(err).Fatalf("Error parsing certificate %s", certPath)
|
||||||
|
}
|
||||||
|
return ca
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Generate go.capnp.out with:
|
||||||
|
# capnp compile -o- go.capnp > go.capnp.out
|
||||||
|
# Must run inside this directory to preserve paths.
|
||||||
|
|
||||||
|
@0xd12a1c51fedd6c88;
|
||||||
|
|
||||||
|
annotation package(file) :Text;
|
||||||
|
annotation import(file) :Text;
|
||||||
|
annotation doc(struct, field, enum) :Text;
|
||||||
|
annotation tag(enumerant) :Text;
|
||||||
|
annotation notag(enumerant) :Void;
|
||||||
|
annotation customtype(field) :Text;
|
||||||
|
annotation name(struct, field, union, enum, enumerant, interface, method, param, annotation, const, group) :Text;
|
||||||
|
|
||||||
|
$package("capnp");
|
|
@ -0,0 +1,26 @@
|
||||||
|
package tunnelrpc
|
||||||
|
|
||||||
|
//go:generate capnp compile -ogo -I./tunnelrpc/ tunnelrpc.capnp
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"zombiezen.com/go/capnproto2/rpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnLogger wraps a logrus *log.Entry for a connection.
|
||||||
|
type ConnLogger struct {
|
||||||
|
Entry *log.Entry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ConnLogger) Infof(ctx context.Context, format string, args ...interface{}) {
|
||||||
|
c.Entry.Infof(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ConnLogger) Errorf(ctx context.Context, format string, args ...interface{}) {
|
||||||
|
c.Entry.Errorf(format, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConnLog(log *log.Entry) rpc.ConnOption {
|
||||||
|
return rpc.ConnLog(ConnLogger{log})
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
// Package logtransport provides a transport that logs all of its messages.
|
||||||
|
package tunnelrpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"zombiezen.com/go/capnproto2/encoding/text"
|
||||||
|
"zombiezen.com/go/capnproto2/rpc"
|
||||||
|
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
type transport struct {
|
||||||
|
rpc.Transport
|
||||||
|
l *log.Entry
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new logger that proxies messages to and from t and
|
||||||
|
// logs them to l. If l is nil, then the log package's default
|
||||||
|
// logger is used.
|
||||||
|
func NewTransportLogger(l *log.Entry, t rpc.Transport) rpc.Transport {
|
||||||
|
return &transport{Transport: t, l: l}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *transport) SendMessage(ctx context.Context, msg rpccapnp.Message) error {
|
||||||
|
t.l.Debugf("tx %s", formatMsg(msg))
|
||||||
|
return t.Transport.SendMessage(ctx, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *transport) RecvMessage(ctx context.Context) (rpccapnp.Message, error) {
|
||||||
|
msg, err := t.Transport.RecvMessage(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.l.WithError(err).Debug("rx error")
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
t.l.Debugf("rx %s", formatMsg(msg))
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatMsg(m rpccapnp.Message) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
text.NewEncoder(&buf).Encode(0x91b79f1f808db032, m.Struct)
|
||||||
|
return buf.String()
|
||||||
|
}
|
|
@ -0,0 +1,194 @@
|
||||||
|
package pogs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/cloudflare/cloudflare-warp/tunnelrpc"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"zombiezen.com/go/capnproto2"
|
||||||
|
"zombiezen.com/go/capnproto2/pogs"
|
||||||
|
"zombiezen.com/go/capnproto2/rpc"
|
||||||
|
"zombiezen.com/go/capnproto2/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Authentication struct {
|
||||||
|
Key string
|
||||||
|
Email string
|
||||||
|
OriginCAKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalAuthentication(s tunnelrpc.Authentication, p *Authentication) error {
|
||||||
|
return pogs.Insert(tunnelrpc.Authentication_TypeID, s.Struct, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalAuthentication(s tunnelrpc.Authentication) (*Authentication, error) {
|
||||||
|
p := new(Authentication)
|
||||||
|
err := pogs.Extract(p, tunnelrpc.Authentication_TypeID, s.Struct)
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type TunnelRegistration struct {
|
||||||
|
Err string
|
||||||
|
Urls []string
|
||||||
|
LogLines []string
|
||||||
|
PermanentFailure bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error {
|
||||||
|
return pogs.Insert(tunnelrpc.TunnelRegistration_TypeID, s.Struct, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalTunnelRegistration(s tunnelrpc.TunnelRegistration) (*TunnelRegistration, error) {
|
||||||
|
p := new(TunnelRegistration)
|
||||||
|
err := pogs.Extract(p, tunnelrpc.TunnelRegistration_TypeID, s.Struct)
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type RegistrationOptions struct {
|
||||||
|
ClientID string `capnp:"clientId"`
|
||||||
|
Version string
|
||||||
|
OS string `capnp:"os"`
|
||||||
|
ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy
|
||||||
|
PoolID string `capnp:"poolId"`
|
||||||
|
ExposeInternalHostname bool
|
||||||
|
Tags []Tag
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {
|
||||||
|
return pogs.Insert(tunnelrpc.RegistrationOptions_TypeID, s.Struct, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalRegistrationOptions(s tunnelrpc.RegistrationOptions) (*RegistrationOptions, error) {
|
||||||
|
p := new(RegistrationOptions)
|
||||||
|
err := pogs.Extract(p, tunnelrpc.RegistrationOptions_TypeID, s.Struct)
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tag struct {
|
||||||
|
Name string
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerInfo struct {
|
||||||
|
LocationName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func MarshalServerInfo(s tunnelrpc.ServerInfo, p *ServerInfo) error {
|
||||||
|
return pogs.Insert(tunnelrpc.ServerInfo_TypeID, s.Struct, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnmarshalServerInfo(s tunnelrpc.ServerInfo) (*ServerInfo, error) {
|
||||||
|
p := new(ServerInfo)
|
||||||
|
err := pogs.Extract(p, tunnelrpc.ServerInfo_TypeID, s.Struct)
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
type TunnelServer interface {
|
||||||
|
RegisterTunnel(ctx context.Context, auth *Authentication, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
|
||||||
|
GetServerInfo(ctx context.Context) (*ServerInfo, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
|
||||||
|
return tunnelrpc.TunnelServer_ServerToClient(TunnelServer_PogsImpl{s})
|
||||||
|
}
|
||||||
|
|
||||||
|
type TunnelServer_PogsImpl struct {
|
||||||
|
impl TunnelServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i TunnelServer_PogsImpl) RegisterTunnel(p tunnelrpc.TunnelServer_registerTunnel) error {
|
||||||
|
authentication, err := p.Params.Auth()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pogsAuthentication, err := UnmarshalAuthentication(authentication)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
hostname, err := p.Params.Hostname()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
options, err := p.Params.Options()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pogsOptions, err := UnmarshalRegistrationOptions(options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
server.Ack(p.Options)
|
||||||
|
registration, err := i.impl.RegisterTunnel(p.Ctx, pogsAuthentication, hostname, pogsOptions)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result, err := p.Results.NewResult()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return MarshalTunnelRegistration(result, registration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i TunnelServer_PogsImpl) GetServerInfo(p tunnelrpc.TunnelServer_getServerInfo) error {
|
||||||
|
server.Ack(p.Options)
|
||||||
|
serverInfo, err := i.impl.GetServerInfo(p.Ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result, err := p.Results.NewResult()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return MarshalServerInfo(result, serverInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
type TunnelServer_PogsClient struct {
|
||||||
|
Client capnp.Client
|
||||||
|
Conn *rpc.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c TunnelServer_PogsClient) Close() error {
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, auth *Authentication, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) {
|
||||||
|
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||||
|
promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error {
|
||||||
|
authentication, err := p.NewAuth()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = MarshalAuthentication(authentication, auth)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = p.SetHostname(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
registrationOptions, err := p.NewOptions()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = MarshalRegistrationOptions(registrationOptions, options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
retval, err := promise.Result().Struct()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return UnmarshalTunnelRegistration(retval)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) {
|
||||||
|
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||||
|
promise := client.GetServerInfo(ctx, func(p tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
retval, err := promise.Result().Struct()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return UnmarshalServerInfo(retval)
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
using Go = import "go.capnp";
|
||||||
|
@0xdb8274f9144abc7e;
|
||||||
|
$Go.package("tunnelrpc");
|
||||||
|
$Go.import("github.com/cloudflare/cloudflare-warp/tunnelrpc");
|
||||||
|
|
||||||
|
struct Authentication {
|
||||||
|
key @0 :Text;
|
||||||
|
email @1 :Text;
|
||||||
|
originCAKey @2 :Text;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TunnelRegistration {
|
||||||
|
err @0 :Text;
|
||||||
|
# A list of URLs that the tunnel is accessible from.
|
||||||
|
urls @1 :List(Text);
|
||||||
|
# Used to inform the client of actions taken.
|
||||||
|
logLines @2 :List(Text);
|
||||||
|
# In case of error, whether the client should attempt to reconnect.
|
||||||
|
permanentFailure @3 :Bool;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RegistrationOptions {
|
||||||
|
# The tunnel client's unique identifier, used to verify a reconnection.
|
||||||
|
clientId @0 :Text;
|
||||||
|
# Information about the running binary.
|
||||||
|
version @1 :Text;
|
||||||
|
os @2 :Text;
|
||||||
|
# What to do with existing tunnels for the given hostname.
|
||||||
|
existingTunnelPolicy @3 :ExistingTunnelPolicy;
|
||||||
|
# If using the balancing policy, identifies the LB pool to use.
|
||||||
|
poolId @4 :Text;
|
||||||
|
# Prevents the tunnel from being accessed at <subdomain>.cftunnel.com
|
||||||
|
exposeInternalHostname @5 :Bool;
|
||||||
|
# Client-defined tags to associate with the tunnel
|
||||||
|
tags @6 :List(Tag);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Tag {
|
||||||
|
name @0 :Text;
|
||||||
|
value @1 :Text;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ExistingTunnelPolicy {
|
||||||
|
ignore @0;
|
||||||
|
disconnect @1;
|
||||||
|
balance @2;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ServerInfo {
|
||||||
|
locationName @0 :Text;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface TunnelServer {
|
||||||
|
registerTunnel @0 (auth :Authentication, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
|
||||||
|
getServerInfo @1 () -> (result :ServerInfo);
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,136 @@
|
||||||
|
package validation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/idna"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultScheme = "http"
|
||||||
|
|
||||||
|
var supportedProtocol = [2]string{"http", "https"}
|
||||||
|
|
||||||
|
func ValidateHostname(hostname string) (string, error) {
|
||||||
|
if hostname == "" {
|
||||||
|
return "", fmt.Errorf("Hostname should not be empty")
|
||||||
|
}
|
||||||
|
// users gives url(contains schema) not just hostname
|
||||||
|
if strings.Contains(hostname, ":") || strings.Contains(hostname, "%3A") {
|
||||||
|
unescapeHostname, err := url.PathUnescape(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Hostname(actually a URL) %s has invalid escape characters %s", hostname, unescapeHostname)
|
||||||
|
}
|
||||||
|
hostnameToURL, err := url.Parse(unescapeHostname)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Hostname(actually a URL) %s has invalid format %s", hostname, hostnameToURL)
|
||||||
|
}
|
||||||
|
asciiHostname, err := idna.ToASCII(hostnameToURL.Hostname())
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Hostname(actually a URL) %s has invalid ASCII encdoing %s", hostname, asciiHostname)
|
||||||
|
}
|
||||||
|
return asciiHostname, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
asciiHostname, err := idna.ToASCII(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Hostname %s has invalid ASCII encdoing %s", hostname, asciiHostname)
|
||||||
|
}
|
||||||
|
hostnameToURL, err := url.Parse(asciiHostname)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Hostname %s is not valid", hostnameToURL)
|
||||||
|
}
|
||||||
|
return hostnameToURL.RequestURI(), nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateUrl(originUrl string) (string, error) {
|
||||||
|
if originUrl == "" {
|
||||||
|
return "", fmt.Errorf("Url should not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if net.ParseIP(originUrl) != nil {
|
||||||
|
return validateIP("", originUrl, "")
|
||||||
|
} else if strings.HasPrefix(originUrl, "[") && strings.HasSuffix(originUrl, "]") {
|
||||||
|
// ParseIP doesn't recoginze [::1]
|
||||||
|
return validateIP("", originUrl[1:len(originUrl)-1], "")
|
||||||
|
}
|
||||||
|
|
||||||
|
host, port, err := net.SplitHostPort(originUrl)
|
||||||
|
// user might pass in an ip address like 127.0.0.1
|
||||||
|
if err == nil && net.ParseIP(host) != nil {
|
||||||
|
return validateIP("", host, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
unescapedUrl, err := url.PathUnescape(originUrl)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("URL %s has invalid escape characters %s", originUrl, unescapedUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedUrl, err := url.Parse(unescapedUrl)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("URL %s has invalid format", originUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if the url is in the form of host:port, IsAbs() will think host is the schema
|
||||||
|
var hostname string
|
||||||
|
hasScheme := parsedUrl.IsAbs() && parsedUrl.Host != ""
|
||||||
|
if hasScheme {
|
||||||
|
err := validateScheme(parsedUrl.Scheme)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// The earlier check for ip address will miss the case http://[::1]
|
||||||
|
// and http://[::1]:8080
|
||||||
|
if net.ParseIP(parsedUrl.Hostname()) != nil {
|
||||||
|
return validateIP(parsedUrl.Scheme, parsedUrl.Hostname(), parsedUrl.Port())
|
||||||
|
}
|
||||||
|
hostname, err = ValidateHostname(parsedUrl.Hostname())
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("URL %s has invalid format", originUrl)
|
||||||
|
}
|
||||||
|
if parsedUrl.Port() != "" {
|
||||||
|
return fmt.Sprintf("%s://%s", parsedUrl.Scheme, net.JoinHostPort(hostname, parsedUrl.Port())), nil
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s://%s", parsedUrl.Scheme, hostname), nil
|
||||||
|
} else {
|
||||||
|
if host == "" {
|
||||||
|
hostname, err = ValidateHostname(originUrl)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("URL no %s has invalid format", originUrl)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s://%s", defaultScheme, hostname), nil
|
||||||
|
} else {
|
||||||
|
hostname, err = ValidateHostname(host)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("URL %s has invalid format", originUrl)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateScheme(scheme string) error {
|
||||||
|
for _, protocol := range supportedProtocol {
|
||||||
|
if scheme == protocol {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("Currently Cloudflare-Warp does not support %s protocol.", scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateIP(scheme, host, port string) (string, error) {
|
||||||
|
if scheme == "" {
|
||||||
|
scheme = defaultScheme
|
||||||
|
}
|
||||||
|
if port != "" {
|
||||||
|
return fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(host, port)), nil
|
||||||
|
} else if strings.Contains(host, ":") {
|
||||||
|
// IPv6
|
||||||
|
return fmt.Sprintf("%s://[%s]", scheme, host), nil
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s://%s", scheme, host), nil
|
||||||
|
}
|
|
@ -0,0 +1,136 @@
|
||||||
|
package validation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateHostname(t *testing.T) {
|
||||||
|
var inputHostname string
|
||||||
|
hostname, err := ValidateHostname(inputHostname)
|
||||||
|
assert.Equal(t, err, fmt.Errorf("Hostname should not be empty"))
|
||||||
|
assert.Empty(t, hostname)
|
||||||
|
|
||||||
|
inputHostname = "hello.example.com"
|
||||||
|
hostname, err = ValidateHostname(inputHostname)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "hello.example.com", hostname)
|
||||||
|
|
||||||
|
inputHostname = "http://hello.example.com"
|
||||||
|
hostname, err = ValidateHostname(inputHostname)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "hello.example.com", hostname)
|
||||||
|
|
||||||
|
inputHostname = "bücher.example.com"
|
||||||
|
hostname, err = ValidateHostname(inputHostname)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "xn--bcher-kva.example.com", hostname)
|
||||||
|
|
||||||
|
inputHostname = "http://bücher.example.com"
|
||||||
|
hostname, err = ValidateHostname(inputHostname)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "xn--bcher-kva.example.com", hostname)
|
||||||
|
|
||||||
|
inputHostname = "http%3A%2F%2Fhello.example.com"
|
||||||
|
hostname, err = ValidateHostname(inputHostname)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "hello.example.com", hostname)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateUrl(t *testing.T) {
|
||||||
|
validUrl, err := ValidateUrl("")
|
||||||
|
assert.Equal(t, fmt.Errorf("Url should not be empty"), err)
|
||||||
|
assert.Empty(t, validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("https://localhost:8080")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "https://localhost:8080", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("localhost:8080")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://localhost:8080", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("http://localhost")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://localhost", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("http://127.0.0.1:8080")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://127.0.0.1:8080", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("127.0.0.1:8080")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://127.0.0.1:8080", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("127.0.0.1")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://127.0.0.1", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("https://127.0.0.1:8080")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "https://127.0.0.1:8080", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("[::1]:8080")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://[::1]:8080", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("http://[::1]")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://[::1]", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("http://[::1]:8080")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://[::1]:8080", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("[::1]")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://[::1]", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("https://example.com")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "https://example.com", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("example.com")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://example.com", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("http://hello.example.com")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://hello.example.com", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("hello.example.com")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://hello.example.com", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("hello.example.com:8080")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://hello.example.com:8080", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("https://hello.example.com:8080")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "https://hello.example.com:8080", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("https://bücher.example.com")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "https://xn--bcher-kva.example.com", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("bücher.example.com")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "http://xn--bcher-kva.example.com", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("https%3A%2F%2Fhello.example.com")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "https://hello.example.com", validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
|
||||||
|
assert.Equal(t, "Currently Cloudflare-Warp does not support ftp protocol.", err.Error())
|
||||||
|
assert.Empty(t, validUrl)
|
||||||
|
|
||||||
|
validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "https://hello.example.com:8080", validUrl)
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue