Initial import
This commit is contained in:
commit
82cb539fbe
|
@ -0,0 +1,9 @@
|
|||
# Cloudflare Warp client
|
||||
|
||||
Contains the command-line client and its libraries for Cloudflare Warp, a tunneling daemon that proxies any local webserver through the Cloudflare network.
|
||||
|
||||
## Getting started
|
||||
|
||||
go install github.com/cloudflare/cloudflare-warp/cmd/cloudflare-warp
|
||||
|
||||
User documentation for Warp can be found at https://warp.cloudflare.com
|
|
@ -0,0 +1,39 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
)
|
||||
|
||||
const cloudflareRootCA = `-----BEGIN CERTIFICATE-----
|
||||
MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG
|
||||
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG
|
||||
bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN
|
||||
U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4
|
||||
NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv
|
||||
dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl
|
||||
cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG
|
||||
A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
|
||||
AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem
|
||||
ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ
|
||||
p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz
|
||||
/9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl
|
||||
yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK
|
||||
xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud
|
||||
EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G
|
||||
A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA
|
||||
cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD
|
||||
gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh
|
||||
QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz
|
||||
ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y
|
||||
4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX
|
||||
Fu6q54beR89jDc+oABmOgg==
|
||||
-----END CERTIFICATE-----`
|
||||
|
||||
func GetCloudflareRootCA() *x509.CertPool {
|
||||
ca := x509.NewCertPool()
|
||||
if !ca.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
|
||||
// should never happen
|
||||
panic("failure loading Cloudflare origin CA pem")
|
||||
}
|
||||
return ca
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
// +build !windows,!darwin,!linux
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
func runApp(app *cli.App) {
|
||||
app.Run(os.Args)
|
||||
}
|
|
@ -0,0 +1,155 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/cloudflare/cloudflare-warp/origin"
|
||||
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
type templateData struct {
|
||||
ServerName string
|
||||
Request *http.Request
|
||||
Tags []tunnelpogs.Tag
|
||||
}
|
||||
|
||||
const defaultServerName = "the Cloudflare Warp test server"
|
||||
const indexTemplate = `
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=Edge">
|
||||
<title>
|
||||
Cloudflare Warp Connection
|
||||
</title>
|
||||
<meta name="author" content="">
|
||||
<meta name="description" content="Cloudflare Warp Connection">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<style>
|
||||
html{line-height:1.15;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}body{margin:0}section{display:block}h1{font-size:2em;margin:.67em 0}a{background-color:transparent;-webkit-text-decoration-skip:objects}/* 1 */::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}/* 1 */a,body,dd,div,dl,dt,h1,h4,html,p,section{box-sizing:border-box}.bt{border-top-style:solid;border-top-width:1px}.bl{border-left-style:solid;border-left-width:1px}.b--orange{border-color:#f38020}.br1{border-radius:.125rem}.bw2{border-width:.25rem}.dib{display:inline-block}.sans-serif{font-family:open sans,-apple-system,BlinkMacSystemFont,avenir next,avenir,helvetica neue,helvetica,ubuntu,roboto,noto,segoe ui,arial,sans-serif}.code{font-family:Consolas,monaco,monospace}.b{font-weight:700}.fw3{font-weight:300}.fw4{font-weight:400}.fw5{font-weight:500}.fw6{font-weight:600}.lh-copy{line-height:1.5}.link{text-decoration:none}.link,.link:active,.link:focus,.link:hover,.link:link,.link:visited{transition:color .15s ease-in}.link:focus{outline:1px dotted currentColor}.mw-100{max-width:100%}.mw4{max-width:8rem}.mw7{max-width:48rem}.bg-light-gray{background-color:#f7f7f7}.link-hover:hover{background-color:#1f679e}.white{color:#fff}.bg-white{background-color:#fff}.bg-blue{background-color:#408bc9}.pb2{padding-bottom:.5rem}.pb6{padding-bottom:8rem}.pt3{padding-top:1rem}.pt5{padding-top:4rem}.pv2{padding-top:.5rem;padding-bottom:.5rem}.ph3{padding-left:1rem;padding-right:1rem}.ph4{padding-left:2rem;padding-right:2rem}.ml0{margin-left:0}.mb1{margin-bottom:.25rem}.mb2{margin-bottom:.5rem}.mb3{margin-bottom:1rem}.mt5{margin-top:4rem}.ttu{text-transform:uppercase}.f4{font-size:1.25rem}.f5{font-size:1rem}.f6{font-size:.875rem}.f7{font-size:.75rem}.measure{max-width:30em}.center{margin-left:auto}.center{margin-right:auto}@media screen and (min-width:30em){.f2-ns{font-size:2.25rem}}@media screen and (min-width:30em) and (max-width:60em){.f5-m{font-size:1rem}}@media screen and (min-width:60em){.f4-l{font-size:1.25rem}}
|
||||
.st0{fill:#FFF}.st1{fill:#f48120}.st2{fill:#faad3f}.st3{fill:#404041}
|
||||
</style>
|
||||
</head>
|
||||
<body class="sans-serif black">
|
||||
<div class="bt bw2 b--orange bg-white pb6">
|
||||
<div class="mw7 center ph4 pt3">
|
||||
<svg id="Layer_2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 109 40.5" class="mw4">
|
||||
<path class="st0" d="M98.6 14.2L93 12.9l-1-.4-25.7.2v12.4l32.3.1z"/>
|
||||
<path class="st1" d="M88.1 24c.3-1 .2-2-.3-2.6-.5-.6-1.2-1-2.1-1.1l-17.4-.2c-.1 0-.2-.1-.3-.1-.1-.1-.1-.2 0-.3.1-.2.2-.3.4-.3l17.5-.2c2.1-.1 4.3-1.8 5.1-3.8l1-2.6c0-.1.1-.2 0-.3-1.1-5.1-5.7-8.9-11.1-8.9-5 0-9.3 3.2-10.8 7.7-1-.7-2.2-1.1-3.6-1-2.4.2-4.3 2.2-4.6 4.6-.1.6 0 1.2.1 1.8-3.9.1-7.1 3.3-7.1 7.3 0 .4 0 .7.1 1.1 0 .2.2.3.3.3h32.1c.2 0 .4-.1.4-.3l.3-1.1z"/>
|
||||
<path class="st2" d="M93.6 12.8h-.5c-.1 0-.2.1-.3.2l-.7 2.4c-.3 1-.2 2 .3 2.6.5.6 1.2 1 2.1 1.1l3.7.2c.1 0 .2.1.3.1.1.1.1.2 0 .3-.1.2-.2.3-.4.3l-3.8.2c-2.1.1-4.3 1.8-5.1 3.8l-.2.9c-.1.1 0 .3.2.3h13.2c.2 0 .3-.1.3-.3.2-.8.4-1.7.4-2.6 0-5.2-4.3-9.5-9.5-9.5"/>
|
||||
<path class="st3" d="M104.4 30.8c-.5 0-.9-.4-.9-.9s.4-.9.9-.9.9.4.9.9-.4.9-.9.9m0-1.6c-.4 0-.7.3-.7.7 0 .4.3.7.7.7.4 0 .7-.3.7-.7 0-.4-.3-.7-.7-.7m.4 1.2h-.2l-.2-.3h-.2v.3h-.2v-.9h.5c.2 0 .3.1.3.3 0 .1-.1.2-.2.3l.2.3zm-.3-.5c.1 0 .1 0 .1-.1s-.1-.1-.1-.1h-.3v.3h.3zM14.8 29H17v6h3.8v1.9h-6zM23.1 32.9c0-2.3 1.8-4.1 4.3-4.1s4.2 1.8 4.2 4.1-1.8 4.1-4.3 4.1c-2.4 0-4.2-1.8-4.2-4.1m6.3 0c0-1.2-.8-2.2-2-2.2s-2 1-2 2.1.8 2.1 2 2.1c1.2.2 2-.8 2-2M34.3 33.4V29h2.2v4.4c0 1.1.6 1.7 1.5 1.7s1.5-.5 1.5-1.6V29h2.2v4.4c0 2.6-1.5 3.7-3.7 3.7-2.3-.1-3.7-1.2-3.7-3.7M45 29h3.1c2.8 0 4.5 1.6 4.5 3.9s-1.7 4-4.5 4h-3V29zm3.1 5.9c1.3 0 2.2-.7 2.2-2s-.9-2-2.2-2h-.9v4h.9zM55.7 29H62v1.9h-4.1v1.3h3.7V34h-3.7v2.9h-2.2zM65.1 29h2.2v6h3.8v1.9h-6zM76.8 28.9H79l3.4 8H80l-.6-1.4h-3.1l-.6 1.4h-2.3l3.4-8zm2 4.9l-.9-2.2-.9 2.2h1.8zM85.2 29h3.7c1.2 0 2 .3 2.6.9.5.5.7 1.1.7 1.8 0 1.2-.6 2-1.6 2.4l1.9 2.8H90l-1.6-2.4h-1v2.4h-2.2V29zm3.6 3.8c.7 0 1.2-.4 1.2-.9 0-.6-.5-.9-1.2-.9h-1.4v1.9h1.4zM95.3 29h6.4v1.8h-4.2V32h3.8v1.8h-3.8V35h4.3v1.9h-6.5zM10 33.9c-.3.7-1 1.2-1.8 1.2-1.2 0-2-1-2-2.1s.8-2.1 2-2.1c.9 0 1.6.6 1.9 1.3h2.3c-.4-1.9-2-3.3-4.2-3.3-2.4 0-4.3 1.8-4.3 4.1s1.8 4.1 4.2 4.1c2.1 0 3.7-1.4 4.2-3.2H10z"/>
|
||||
</svg>
|
||||
<h1 class="f4 f2-ns mt5 fw5">Congrats! You created your first tunnel!</h1>
|
||||
<p class="f6 f5-m f4-l measure lh-copy fw3">
|
||||
Cloudflare Warp exposes locally running applications to the internet by
|
||||
running an encrypted, virtual tunnel from your laptop or server to
|
||||
Cloudflare's edge network.
|
||||
</p>
|
||||
<p class="b f5 mt5 fw6">Ready for the next step?</p>
|
||||
<a
|
||||
class="fw6 link white bg-blue ph4 pv2 br1 dib f5 link-hover"
|
||||
style="border-bottom: 1px solid #1f679e"
|
||||
href="https://warp.cloudflare.com">
|
||||
Get started here
|
||||
</a>
|
||||
{{if .Tags}} <section>
|
||||
<h4 class="f6 fw4 pt5 mb2">Connection</h4>
|
||||
<dl class="bl bw2 b--orange ph3 pt3 pb2 bg-light-gray f7 code overflow-x-auto mw-100">
|
||||
{{range .Tags}} <dt class="ttu mb1">{{.Name}}</dt>
|
||||
<dd class="ml0 mb3 f5">{{.Value}}</dd>
|
||||
{{end}} </dl>
|
||||
</section>
|
||||
{{end}} </div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
func hello(c *cli.Context) error {
|
||||
address := fmt.Sprintf(":%d", c.Int("port"))
|
||||
server := NewHelloWorldServer()
|
||||
if hostname, err := os.Hostname(); err != nil {
|
||||
server.serverName = hostname
|
||||
}
|
||||
err := server.ListenAndServe(address)
|
||||
return errors.Wrap(err, "Fail to start Hello World Server")
|
||||
}
|
||||
|
||||
func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
|
||||
server := NewHelloWorldServer()
|
||||
if hostname, err := os.Hostname(); err != nil {
|
||||
server.serverName = hostname
|
||||
}
|
||||
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: server}
|
||||
go func() {
|
||||
<-shutdownC
|
||||
httpServer.Close()
|
||||
}()
|
||||
err := httpServer.Serve(listener)
|
||||
return err
|
||||
}
|
||||
|
||||
type HelloWorldServer struct {
|
||||
responseTemplate *template.Template
|
||||
serverName string
|
||||
}
|
||||
|
||||
func NewHelloWorldServer() *HelloWorldServer {
|
||||
return &HelloWorldServer{
|
||||
responseTemplate: template.Must(template.New("index").Parse(indexTemplate)),
|
||||
serverName: defaultServerName,
|
||||
}
|
||||
}
|
||||
|
||||
func findAvailablePort() (net.Listener, error) {
|
||||
// If the port in address is empty, a port number is automatically chosen.
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:")
|
||||
return listener, err
|
||||
}
|
||||
|
||||
func (s *HelloWorldServer) ListenAndServe(address string) error {
|
||||
log.Infof("Starting Hello World server on %s", address)
|
||||
err := http.ListenAndServe(address, s)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto)
|
||||
var buffer bytes.Buffer
|
||||
err := s.responseTemplate.Execute(&buffer, &templateData{
|
||||
ServerName: s.serverName,
|
||||
Request: r,
|
||||
Tags: tagsFromHeaders(r.Header),
|
||||
})
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
fmt.Fprintf(w, "error: %v", err)
|
||||
} else {
|
||||
buffer.WriteTo(w)
|
||||
}
|
||||
}
|
||||
|
||||
func tagsFromHeaders(header http.Header) []tunnelpogs.Tag {
|
||||
var tags []tunnelpogs.Tag
|
||||
for headerName, headerValues := range header {
|
||||
trimmed := strings.TrimPrefix(headerName, origin.TagHeaderNamePrefix)
|
||||
if trimmed == headerName {
|
||||
continue
|
||||
}
|
||||
for _, value := range headerValues {
|
||||
tags = append(tags, tunnelpogs.Tag{Name: trimmed, Value: value})
|
||||
}
|
||||
}
|
||||
return tags
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
const testPort = "8080"
|
||||
|
||||
func TestNewHelloWorldServer(t *testing.T) {
|
||||
if NewHelloWorldServer() == nil {
|
||||
t.Fatal("NewHelloWorldServer returned nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAvailablePort(t *testing.T) {
|
||||
listener, err := findAvailablePort()
|
||||
if err != nil {
|
||||
t.Fatal("Fail to find available port")
|
||||
}
|
||||
if listener.Addr().String() == "" {
|
||||
t.Fatal("Fail to find available port")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,264 @@
|
|||
// +build linux
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
func runApp(app *cli.App) {
|
||||
app.Commands = append(app.Commands, &cli.Command{
|
||||
Name: "service",
|
||||
Usage: "Manages the Cloudflare Warp system service",
|
||||
Subcommands: []*cli.Command{
|
||||
&cli.Command{
|
||||
Name: "install",
|
||||
Usage: "Install Cloudflare Warp as a system service",
|
||||
Action: installLinuxService,
|
||||
},
|
||||
&cli.Command{
|
||||
Name: "uninstall",
|
||||
Usage: "Uninstall the Cloudflare Warp service",
|
||||
Action: uninstallLinuxService,
|
||||
},
|
||||
},
|
||||
})
|
||||
app.Run(os.Args)
|
||||
}
|
||||
|
||||
var systemdTemplates = []ServiceTemplate{
|
||||
{
|
||||
Path: "/etc/systemd/system/cloudflare-warp.service",
|
||||
Content: `[Unit]
|
||||
Description=Cloudflare Warp
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
TimeoutStartSec=0
|
||||
Type=notify
|
||||
ExecStart={{ .Path }} --config /etc/cloudflare-warp.yml --autoupdate 0s
|
||||
User=nobody
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
`,
|
||||
},
|
||||
{
|
||||
Path: "/etc/systemd/system/cloudflare-warp-update.service",
|
||||
Content: `[Unit]
|
||||
Description=Update Cloudflare Warp
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 64 ]; then systemctl restart cloudflare-warp; exit 0; fi; exit $code'
|
||||
`,
|
||||
},
|
||||
{
|
||||
Path: "/etc/systemd/system/cloudflare-warp-update.timer",
|
||||
Content: `[Unit]
|
||||
Description=Update Cloudflare Warp
|
||||
|
||||
[Timer]
|
||||
OnUnitActiveSec=1d
|
||||
|
||||
[Install]
|
||||
WantedBy=timers.target
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
var sysvTemplate = ServiceTemplate{
|
||||
Path: "/etc/init.d/cloudflare-warp",
|
||||
FileMode: 0755,
|
||||
Content: `# For RedHat and cousins:
|
||||
# chkconfig: 2345 99 01
|
||||
# description: Cloudflare Warp agent
|
||||
# processname: {{.Path}}
|
||||
### BEGIN INIT INFO
|
||||
# Provides: {{.Path}}
|
||||
# Required-Start:
|
||||
# Required-Stop:
|
||||
# Default-Start: 2 3 4 5
|
||||
# Default-Stop: 0 1 6
|
||||
# Short-Description: Cloudflare Warp
|
||||
# Description: Cloudflare Warp agent
|
||||
### END INIT INFO
|
||||
cmd="{{.Path}} --config /etc/cloudflare-warp.yml --pidfile /var/run/$name.pid"
|
||||
name=$(basename $(readlink -f $0))
|
||||
pid_file="/var/run/$name.pid"
|
||||
stdout_log="/var/log/$name.log"
|
||||
stderr_log="/var/log/$name.err"
|
||||
[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
|
||||
get_pid() {
|
||||
cat "$pid_file"
|
||||
}
|
||||
is_running() {
|
||||
[ -f "$pid_file" ] && ps $(get_pid) > /dev/null 2>&1
|
||||
}
|
||||
case "$1" in
|
||||
start)
|
||||
if is_running; then
|
||||
echo "Already started"
|
||||
else
|
||||
echo "Starting $name"
|
||||
$cmd >> "$stdout_log" 2>> "$stderr_log" &
|
||||
echo $! > "$pid_file"
|
||||
if ! is_running; then
|
||||
echo "Unable to start, see $stdout_log and $stderr_log"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
stop)
|
||||
if is_running; then
|
||||
echo -n "Stopping $name.."
|
||||
kill $(get_pid)
|
||||
for i in {1..10}
|
||||
do
|
||||
if ! is_running; then
|
||||
break
|
||||
fi
|
||||
echo -n "."
|
||||
sleep 1
|
||||
done
|
||||
echo
|
||||
if is_running; then
|
||||
echo "Not stopped; may still be shutting down or shutdown may have failed"
|
||||
exit 1
|
||||
else
|
||||
echo "Stopped"
|
||||
if [ -f "$pid_file" ]; then
|
||||
rm "$pid_file"
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "Not running"
|
||||
fi
|
||||
;;
|
||||
restart)
|
||||
$0 stop
|
||||
if is_running; then
|
||||
echo "Unable to stop, will not attempt to start"
|
||||
exit 1
|
||||
fi
|
||||
$0 start
|
||||
;;
|
||||
status)
|
||||
if is_running; then
|
||||
echo "Running"
|
||||
else
|
||||
echo "Stopped"
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {start|stop|restart|status}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
exit 0
|
||||
`,
|
||||
}
|
||||
|
||||
func isSystemd() bool {
|
||||
if _, err := os.Stat("/run/systemd/system"); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func installLinuxService(c *cli.Context) error {
|
||||
etPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error determining executable path: %v", err)
|
||||
}
|
||||
templateArgs := ServiceTemplateArgs{Path: etPath}
|
||||
|
||||
switch {
|
||||
case isSystemd():
|
||||
return installSystemd(&templateArgs)
|
||||
default:
|
||||
return installSysv(&templateArgs)
|
||||
}
|
||||
}
|
||||
|
||||
func installSystemd(templateArgs *ServiceTemplateArgs) error {
|
||||
for _, serviceTemplate := range systemdTemplates {
|
||||
err := serviceTemplate.Generate(templateArgs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := runCommand("systemctl", "enable", "cloudflare-warp.service"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := runCommand("systemctl", "start", "cloudflare-warp-update.timer"); err != nil {
|
||||
return err
|
||||
}
|
||||
return runCommand("systemctl", "daemon-reload")
|
||||
}
|
||||
|
||||
func installSysv(templateArgs *ServiceTemplateArgs) error {
|
||||
confPath, err := sysvTemplate.ResolvePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := sysvTemplate.Generate(templateArgs); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, i := range [...]string{"2", "3", "4", "5"} {
|
||||
if err := os.Symlink(confPath, "/etc/rc"+i+".d/S50et"); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
for _, i := range [...]string{"0", "1", "6"} {
|
||||
if err := os.Symlink(confPath, "/etc/rc"+i+".d/K02et"); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func uninstallLinuxService(c *cli.Context) error {
|
||||
switch {
|
||||
case isSystemd():
|
||||
return uninstallSystemd()
|
||||
default:
|
||||
return uninstallSysv()
|
||||
}
|
||||
}
|
||||
|
||||
func uninstallSystemd() error {
|
||||
if err := runCommand("systemctl", "disable", "cloudflare-warp.service"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := runCommand("systemctl", "stop", "cloudflare-warp-update.timer"); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, serviceTemplate := range systemdTemplates {
|
||||
if err := serviceTemplate.Remove(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func uninstallSysv() error {
|
||||
if err := sysvTemplate.Remove(); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, i := range [...]string{"2", "3", "4", "5"} {
|
||||
if err := os.Remove("/etc/rc" + i + ".d/S50et"); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
for _, i := range [...]string{"0", "1", "6"} {
|
||||
if err := os.Remove("/etc/rc" + i + ".d/K02et"); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
func login(c *cli.Context) error {
|
||||
path, err := homedir.Expand(defaultConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fileInfo, err := os.Stat(path)
|
||||
if err == nil && fileInfo.Size() > 0 {
|
||||
fmt.Fprintf(os.Stderr, `You have an existing config file at %s which login would overwrite.
|
||||
If this is intentional, please move or delete that file then run this command again.
|
||||
`, defaultConfigPath)
|
||||
return nil
|
||||
}
|
||||
if err != nil && err.(*os.PathError).Err != syscall.ENOENT {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "Please visit https://www.cloudflare.com/a/warp to obtain a certificate.")
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
// +build darwin
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
func runApp(app *cli.App) {
|
||||
app.Commands = append(app.Commands, &cli.Command{
|
||||
Name: "service",
|
||||
Usage: "Manages the Cloudflare Warp launch agent",
|
||||
Subcommands: []*cli.Command{
|
||||
&cli.Command{
|
||||
Name: "install",
|
||||
Usage: "Install Cloudflare Warp as an user launch agent",
|
||||
Action: installLaunchd,
|
||||
},
|
||||
&cli.Command{
|
||||
Name: "uninstall",
|
||||
Usage: "Uninstall the Cloudflare Warp launch agent",
|
||||
Action: uninstallLaunchd,
|
||||
},
|
||||
},
|
||||
})
|
||||
app.Run(os.Args)
|
||||
}
|
||||
|
||||
var launchdTemplate = ServiceTemplate{
|
||||
Path: "~/Library/LaunchAgents/com.cloudflare.warp.plist",
|
||||
Content: `<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>com.cloudflare.warp</string>
|
||||
<key>Program</key>
|
||||
<string>{{ .Path }}</string>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>KeepAlive</key>
|
||||
<dict>
|
||||
<key>NetworkState</key>
|
||||
<true/>
|
||||
</dict>
|
||||
<key>ThrottleInterval</key>
|
||||
<integer>20</integer>
|
||||
</dict>
|
||||
</plist>`,
|
||||
}
|
||||
|
||||
func installLaunchd(c *cli.Context) error {
|
||||
etPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error determining executable path: %v", err)
|
||||
}
|
||||
templateArgs := ServiceTemplateArgs{Path: etPath}
|
||||
err = launchdTemplate.Generate(&templateArgs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
plistPath, err := launchdTemplate.ResolvePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return runCommand("launchctl", "load", plistPath)
|
||||
}
|
||||
|
||||
func uninstallLaunchd(c *cli.Context) error {
|
||||
plistPath, err := launchdTemplate.ResolvePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = runCommand("launchctl", "unload", plistPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return launchdTemplate.Remove()
|
||||
}
|
|
@ -0,0 +1,473 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflare-warp/h2mux"
|
||||
"github.com/cloudflare/cloudflare-warp/metrics"
|
||||
"github.com/cloudflare/cloudflare-warp/origin"
|
||||
"github.com/cloudflare/cloudflare-warp/tlsconfig"
|
||||
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||
"github.com/cloudflare/cloudflare-warp/validation"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/facebookgo/grace/gracenet"
|
||||
raven "github.com/getsentry/raven-go"
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
"gopkg.in/urfave/cli.v2/altsrc"
|
||||
|
||||
"github.com/coreos/go-systemd/daemon"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
|
||||
const defaultConfigPath = "~/.cloudflare-warp.yml"
|
||||
|
||||
var listeners = gracenet.Net{}
|
||||
var Version = "DEV"
|
||||
var BuildTime = "unknown"
|
||||
|
||||
// Shutdown channel used by the app. When closed, app must terminate.
|
||||
// May be closed by the Windows service runner.
|
||||
var shutdownC chan struct{}
|
||||
|
||||
func main() {
|
||||
metrics.RegisterBuildInfo(BuildTime, Version)
|
||||
raven.SetDSN(sentryDSN)
|
||||
raven.SetRelease(Version)
|
||||
shutdownC = make(chan struct{})
|
||||
app := &cli.App{}
|
||||
app.Name = "cloudflare-warp"
|
||||
app.Copyright = `(c) 2017 Cloudflare Inc.
|
||||
Use is subject to the license agreement at https://warp.cloudflare.com/licence/`
|
||||
app.Usage = "Cloudflare reverse tunnelling proxy agent \033[1;31m*BETA*\033[0m"
|
||||
app.ArgsUsage = "origin-url"
|
||||
app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime)
|
||||
app.Description = `A reverse tunnel proxy agent that connects to Cloudflare's infrastructure.
|
||||
Upon connecting, you are assigned a unique subdomain on cftunnel.com.
|
||||
Alternatively, you can specify a hostname on a zone you control.
|
||||
|
||||
Requests made to Cloudflare's servers for your hostname will be proxied
|
||||
through the tunnel to your local webserver.
|
||||
|
||||
WARNING:
|
||||
` + "\033[1;31m*** THIS IS A BETA VERSION OF THE CLOUDFLARE WARP AGENT ***\033[0m" + `
|
||||
|
||||
At this time, do not use Cloudflare Warp for connecting production servers to Cloudflare.
|
||||
Availability and reliability of this service is not guaranteed through the beta period.`
|
||||
app.Flags = []cli.Flag{
|
||||
&cli.StringFlag{
|
||||
Name: "config",
|
||||
Usage: "Specifies a config file in YAML format.",
|
||||
},
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||
Name: "autoupdate",
|
||||
Usage: "Periodically check for updates, restarting the server with the new version.",
|
||||
Value: time.Hour * 24,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "edge",
|
||||
Value: "cftunnel.com:7844",
|
||||
Usage: "Address of the Cloudflare tunnel server.",
|
||||
EnvVars: []string{"TUNNEL_EDGE"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "cacert",
|
||||
Usage: "Certificate Authority authenticating the Cloudflare tunnel connection.",
|
||||
EnvVars: []string{"TUNNEL_CACERT"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "url",
|
||||
Value: "http://localhost:8080",
|
||||
Usage: "Connect to the local webserver at `URL`.",
|
||||
EnvVars: []string{"TUNNEL_URL"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "hostname",
|
||||
Usage: "Set a hostname on a Cloudflare zone to route traffic through this tunnel.",
|
||||
EnvVars: []string{"TUNNEL_HOSTNAME"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "id",
|
||||
Usage: "A unique identifier used to tie connections to this tunnel instance.",
|
||||
EnvVars: []string{"TUNNEL_ID"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "lb-pool",
|
||||
Usage: "The name of a (new/existing) load balancing pool to add this origin to.",
|
||||
EnvVars: []string{"TUNNEL_LB_POOL"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "api-key",
|
||||
Usage: "A Cloudflare API key. Required(can be in the config file) unless you are only running the hello command or login command.",
|
||||
EnvVars: []string{"TUNNEL_API_KEY"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "api-email",
|
||||
Usage: "The Cloudflare user's email address associated with the API key. Required(can be in the config file) unless you are only running the hello command or login command.",
|
||||
EnvVars: []string{"TUNNEL_API_EMAIL"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "api-ca-key",
|
||||
Usage: "The Origin CA service key associated with the user. Required(can be in the config file) unless you are only running the hello command or login command.",
|
||||
EnvVars: []string{"TUNNEL_API_CA_KEY"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "metrics",
|
||||
Value: "localhost:",
|
||||
Usage: "Listen address for metrics reporting.",
|
||||
EnvVars: []string{"TUNNEL_METRICS"},
|
||||
}),
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
|
||||
Name: "tag",
|
||||
Usage: "Custom tags used to identify this tunnel, in format `KEY=VALUE`. Multiple tags may be specified",
|
||||
EnvVars: []string{"TUNNEL_TAG"},
|
||||
}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||
Name: "heartbeat-interval",
|
||||
Usage: "Minimum idle time before sending a heartbeat.",
|
||||
Value: time.Second * 5,
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewUint64Flag(&cli.Uint64Flag{
|
||||
Name: "heartbeat-count",
|
||||
Usage: "Minimum number of unacked heartbeats to send before closing the connection.",
|
||||
Value: 5,
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "loglevel",
|
||||
Value: "info",
|
||||
Usage: "Logging level {panic, fatal, error, warn, info, debug}",
|
||||
EnvVars: []string{"TUNNEL_LOGLEVEL"},
|
||||
}),
|
||||
altsrc.NewUintFlag(&cli.UintFlag{
|
||||
Name: "retries",
|
||||
Value: 5,
|
||||
Usage: "Maximum number of retries for connection/protocol errors.",
|
||||
EnvVars: []string{"TUNNEL_RETRIES"},
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "debug",
|
||||
Value: false,
|
||||
Usage: "Enable HTTP requests to the autogenerated cftunnel.com domain.",
|
||||
EnvVars: []string{"TUNNEL_DEBUG"},
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "hello-world",
|
||||
Usage: "Run Hello World Server",
|
||||
Value: false,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "pidfile",
|
||||
Usage: "Write the application's PID to this file after first successful connection.",
|
||||
EnvVars: []string{"TUNNEL_PIDFILE"},
|
||||
}),
|
||||
}
|
||||
app.Action = func(c *cli.Context) error {
|
||||
raven.CapturePanic(func() { startServer(c) }, nil)
|
||||
return nil
|
||||
}
|
||||
app.Before = func(context *cli.Context) error {
|
||||
inputSource, err := findInputSourceContext(context)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if inputSource != nil {
|
||||
return altsrc.ApplyInputSourceValues(context, inputSource, app.Flags)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
app.Commands = []*cli.Command{
|
||||
&cli.Command{
|
||||
Name: "update",
|
||||
Action: update,
|
||||
Usage: "Update the agent if a new version exists",
|
||||
ArgsUsage: " ",
|
||||
Description: `Looks for a new version on the offical download server.
|
||||
If a new version exists, updates the agent binary and quits.
|
||||
Otherwise, does nothing.
|
||||
|
||||
To determine if an update happened in a script, check for error code 64.`,
|
||||
},
|
||||
&cli.Command{
|
||||
Name: "login",
|
||||
Action: login,
|
||||
Usage: "Generate a configuration file with your login details",
|
||||
ArgsUsage: " ",
|
||||
},
|
||||
&cli.Command{
|
||||
Name: "hello",
|
||||
Action: hello,
|
||||
Usage: "Run a simple \"Hello World\" server for testing Cloudflare Warp.",
|
||||
Flags: []cli.Flag{
|
||||
&cli.IntFlag{
|
||||
Name: "port",
|
||||
Usage: "Listen on the selected port.",
|
||||
Value: 8080,
|
||||
},
|
||||
},
|
||||
ArgsUsage: " ", // can't be the empty string or we get the default output
|
||||
},
|
||||
}
|
||||
runApp(app)
|
||||
}
|
||||
|
||||
func startServer(c *cli.Context) {
|
||||
var wg sync.WaitGroup
|
||||
errC := make(chan error)
|
||||
wg.Add(2)
|
||||
|
||||
if c.NumFlags() == 0 && c.NArg() == 0 {
|
||||
cli.ShowAppHelp(c)
|
||||
return
|
||||
}
|
||||
|
||||
logLevel, err := log.ParseLevel(c.String("loglevel"))
|
||||
if err != nil {
|
||||
log.WithError(err).Fatal("Unknown logging level specified")
|
||||
}
|
||||
log.SetLevel(logLevel)
|
||||
hostname, err := validation.ValidateHostname(c.String("hostname"))
|
||||
if err != nil {
|
||||
log.WithError(err).Fatal("Invalid hostname")
|
||||
|
||||
}
|
||||
clientID := c.String("id")
|
||||
if !c.IsSet("id") {
|
||||
clientID = generateRandomClientID()
|
||||
}
|
||||
|
||||
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
|
||||
if err != nil {
|
||||
log.WithError(err).Fatal("Tag parse failure")
|
||||
}
|
||||
|
||||
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
|
||||
|
||||
if c.IsSet("hello-world") {
|
||||
wg.Add(1)
|
||||
listener, err := findAvailablePort()
|
||||
|
||||
if err != nil {
|
||||
listener.Close()
|
||||
log.WithError(err).Fatal("Cannot start Hello World Server")
|
||||
}
|
||||
go func() {
|
||||
startHelloWorldServer(listener, shutdownC)
|
||||
wg.Done()
|
||||
}()
|
||||
c.Set("url", "http://"+listener.Addr().String())
|
||||
log.Infof("Starting Hello World Server at %s", c.String("url"))
|
||||
}
|
||||
url, err := validateUrl(c)
|
||||
if err != nil {
|
||||
log.WithError(err).Fatal("Error validating url")
|
||||
}
|
||||
// User must have api-key, api-email and api-ca-key
|
||||
if !c.IsSet("api-key") {
|
||||
log.Fatal("You need to give us your api-key either via the --api-key option or put it in the configuration file. You will also need to give us your api-email and api-ca-key.")
|
||||
}
|
||||
if !c.IsSet("api-email") {
|
||||
log.Fatal("You need to give us your api-email either via the --api-email option or put it in the configuration file. You will also need to give us your api-ca-key.")
|
||||
}
|
||||
if !c.IsSet("api-ca-key") {
|
||||
log.Fatal("You need to give us your api-ca-key either via the --api-ca-key option or put it in the configuration file.")
|
||||
}
|
||||
log.Infof("Proxying tunnel requests to %s", url)
|
||||
tunnelConfig := &origin.TunnelConfig{
|
||||
EdgeAddr: c.String("edge"),
|
||||
OriginUrl: url,
|
||||
Hostname: hostname,
|
||||
APIKey: c.String("api-key"),
|
||||
APIEmail: c.String("api-email"),
|
||||
APICAKey: c.String("api-ca-key"),
|
||||
TlsConfig: &tls.Config{},
|
||||
Retries: c.Uint("retries"),
|
||||
HeartbeatInterval: c.Duration("heartbeat-interval"),
|
||||
MaxHeartbeats: c.Uint64("heartbeat-count"),
|
||||
ClientID: clientID,
|
||||
ReportedVersion: Version,
|
||||
LBPool: c.String("lb-pool"),
|
||||
Tags: tags,
|
||||
AccessInternalIP: c.Bool("debug"),
|
||||
ConnectedSignal: h2mux.NewSignal(),
|
||||
}
|
||||
|
||||
tunnelConfig.TlsConfig = tlsconfig.CLIFlags{RootCA: "cacert"}.GetConfig(c)
|
||||
if tunnelConfig.TlsConfig.RootCAs == nil {
|
||||
tunnelConfig.TlsConfig.RootCAs = GetCloudflareRootCA()
|
||||
tunnelConfig.TlsConfig.ServerName = "cftunnel.com"
|
||||
} else {
|
||||
tunnelConfig.TlsConfig.ServerName, _, _ = net.SplitHostPort(tunnelConfig.EdgeAddr)
|
||||
}
|
||||
|
||||
go writePidFile(tunnelConfig.ConnectedSignal, c.String("pidfile"))
|
||||
go func() {
|
||||
errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
|
||||
if err != nil {
|
||||
log.WithError(err).Fatal("Error opening metrics server listener")
|
||||
}
|
||||
go func() {
|
||||
errC <- metrics.ServeMetrics(metricsListener, shutdownC)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go autoupdate(c.Duration("autoupdate"), shutdownC)
|
||||
|
||||
err = WaitForSignal(errC, shutdownC)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Quitting due to error")
|
||||
raven.CaptureErrorAndWait(err, nil)
|
||||
} else {
|
||||
log.Info("Quitting...")
|
||||
}
|
||||
// Wait for clean exit, discarding all errors
|
||||
go func() {
|
||||
for range errC {
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
|
||||
signals := make(chan os.Signal, 10)
|
||||
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
|
||||
defer signal.Stop(signals)
|
||||
select {
|
||||
case err := <-errC:
|
||||
close(shutdownC)
|
||||
return err
|
||||
case <-signals:
|
||||
close(shutdownC)
|
||||
case <-shutdownC:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func update(c *cli.Context) error {
|
||||
if updateApplied() {
|
||||
os.Exit(64)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func autoupdate(frequency time.Duration, shutdownC chan struct{}) {
|
||||
if int64(frequency) == 0 {
|
||||
return
|
||||
}
|
||||
for {
|
||||
if updateApplied() {
|
||||
if _, err := listeners.StartProcess(); err != nil {
|
||||
log.WithError(err).Error("Unable to restart server automatically")
|
||||
}
|
||||
close(shutdownC)
|
||||
return
|
||||
}
|
||||
time.Sleep(frequency)
|
||||
}
|
||||
}
|
||||
|
||||
func updateApplied() bool {
|
||||
releaseInfo := checkForUpdates()
|
||||
if releaseInfo.Updated {
|
||||
log.Infof("Updated to version %s", releaseInfo.Version)
|
||||
return true
|
||||
}
|
||||
if releaseInfo.Error != nil {
|
||||
log.WithError(releaseInfo.Error).Error("Update check failed")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func fileExists(path string) (bool, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// ignore missing files
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
f.Close()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
|
||||
if context.IsSet("config") {
|
||||
return altsrc.NewYamlSourceFromFile(context.String("config"))
|
||||
}
|
||||
for _, tryPath := range []string{
|
||||
defaultConfigPath,
|
||||
"~/.cloudflare-warp.yaml",
|
||||
"~/cloudflare-warp.yaml",
|
||||
"~/cloudflare-warp.yml",
|
||||
"~/.et.yaml",
|
||||
"~/et.yml",
|
||||
"~/et.yaml",
|
||||
"~/.cftunnel.yaml", // for existing users
|
||||
"~/cftunnel.yaml",
|
||||
} {
|
||||
path, err := homedir.Expand(tryPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ok, err := fileExists(path)
|
||||
if ok {
|
||||
return altsrc.NewYamlSourceFromFile(path)
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func generateRandomClientID() string {
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
id := make([]byte, 32)
|
||||
r.Read(id)
|
||||
return hex.EncodeToString(id)
|
||||
}
|
||||
|
||||
func writePidFile(waitForSignal h2mux.Signal, pidFile string) {
|
||||
waitForSignal.Wait()
|
||||
daemon.SdNotify(false, "READY=1")
|
||||
if pidFile == "" {
|
||||
return
|
||||
}
|
||||
file, err := os.Create(pidFile)
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
|
||||
}
|
||||
defer file.Close()
|
||||
fmt.Fprintf(file, "%d", os.Getpid())
|
||||
}
|
||||
|
||||
// validate url. It can be either from --url or argument
|
||||
func validateUrl(c *cli.Context) (string, error) {
|
||||
var url = c.String("url")
|
||||
if c.NArg() > 0 {
|
||||
if c.IsSet("url") {
|
||||
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
|
||||
}
|
||||
url = c.Args().Get(0)
|
||||
}
|
||||
validUrl, err := validation.ValidateUrl(url)
|
||||
return validUrl, err
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"text/template"
|
||||
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
)
|
||||
|
||||
type ServiceTemplate struct {
|
||||
Path string
|
||||
Content string
|
||||
FileMode os.FileMode
|
||||
}
|
||||
|
||||
type ServiceTemplateArgs struct {
|
||||
Path string
|
||||
}
|
||||
|
||||
func (st *ServiceTemplate) ResolvePath() (string, error) {
|
||||
resolvedPath, err := homedir.Expand(st.Path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error resolving path %s: %v", st.Path, err)
|
||||
}
|
||||
return resolvedPath, nil
|
||||
}
|
||||
|
||||
func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
|
||||
tmpl, err := template.New(st.Path).Parse(st.Content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating %s template: %v", st.Path, err)
|
||||
}
|
||||
resolvedPath, err := st.ResolvePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var buffer bytes.Buffer
|
||||
err = tmpl.Execute(&buffer, args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating %s: %v", st.Path, err)
|
||||
}
|
||||
fileMode := os.FileMode(0644)
|
||||
if st.FileMode != 0 {
|
||||
fileMode = st.FileMode
|
||||
}
|
||||
err = ioutil.WriteFile(resolvedPath, buffer.Bytes(), fileMode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing %s: %v", resolvedPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *ServiceTemplate) Remove() error {
|
||||
resolvedPath, err := st.ResolvePath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = os.Remove(resolvedPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deleting %s: %v", resolvedPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runCommand(command string, args ...string) error {
|
||||
cmd := exec.Command(command, args...)
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting stderr pipe: %v", err)
|
||||
}
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error starting %s: %v", command, err)
|
||||
}
|
||||
commandErr, _ := ioutil.ReadAll(stderr)
|
||||
if len(commandErr) > 0 {
|
||||
return fmt.Errorf("%s error: %s", command, commandErr)
|
||||
}
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s returned with error: %v", command, err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
// Restrict key names to characters allowed in an HTTP header name.
|
||||
// Restrict key values to printable characters (what is recognised as data in an HTTP header value).
|
||||
var tagRegexp = regexp.MustCompile("^([a-zA-Z0-9!#$%&'*+\\-.^_`|~]+)=([[:print:]]+)$")
|
||||
|
||||
func NewTagFromCLI(compoundTag string) (tunnelpogs.Tag, bool) {
|
||||
matches := tagRegexp.FindStringSubmatch(compoundTag)
|
||||
if len(matches) == 0 {
|
||||
return tunnelpogs.Tag{}, false
|
||||
}
|
||||
return tunnelpogs.Tag{Name: matches[1], Value: matches[2]}, true
|
||||
}
|
||||
|
||||
func NewTagSliceFromCLI(tags []string) ([]tunnelpogs.Tag, error) {
|
||||
var tagSlice []tunnelpogs.Tag
|
||||
for _, compoundTag := range tags {
|
||||
if tag, ok := NewTagFromCLI(compoundTag); ok {
|
||||
tagSlice = append(tagSlice, tag)
|
||||
} else {
|
||||
return nil, fmt.Errorf("Cannot parse tag value %s", compoundTag)
|
||||
}
|
||||
}
|
||||
return tagSlice, nil
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSingleTag(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Input string
|
||||
Output tunnelpogs.Tag
|
||||
Fail bool
|
||||
}{
|
||||
{Input: "x=y", Output: tunnelpogs.Tag{Name: "x", Value: "y"}},
|
||||
{Input: "More-Complex=Tag Values", Output: tunnelpogs.Tag{Name: "More-Complex", Value: "Tag Values"}},
|
||||
{Input: "First=Equals=Wins", Output: tunnelpogs.Tag{Name: "First", Value: "Equals=Wins"}},
|
||||
{Input: "x=", Fail: true},
|
||||
{Input: "=y", Fail: true},
|
||||
{Input: "=", Fail: true},
|
||||
{Input: "No spaces allowed=in key names", Fail: true},
|
||||
{Input: "omg\nwtf=bbq", Fail: true},
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
tag, ok := NewTagFromCLI(testCase.Input)
|
||||
assert.Equalf(t, !testCase.Fail, ok, "mismatched success for test case %d", i)
|
||||
assert.Equalf(t, testCase.Output, tag, "mismatched output for test case %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagSlice(t *testing.T) {
|
||||
tagSlice, err := NewTagSliceFromCLI([]string{"a=b", "c=d", "e=f"})
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, tagSlice, 3)
|
||||
assert.Equal(t, "a", tagSlice[0].Name)
|
||||
assert.Equal(t, "b", tagSlice[0].Value)
|
||||
assert.Equal(t, "c", tagSlice[1].Name)
|
||||
assert.Equal(t, "d", tagSlice[1].Value)
|
||||
assert.Equal(t, "e", tagSlice[2].Name)
|
||||
assert.Equal(t, "f", tagSlice[2].Value)
|
||||
|
||||
tagSlice, err = NewTagSliceFromCLI([]string{"a=b", "=", "e=f"})
|
||||
assert.Error(t, err)
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
package main
|
||||
|
||||
import "github.com/equinox-io/equinox"
|
||||
|
||||
const appID = "app_cwbQae3Tpea"
|
||||
|
||||
var publicKey = []byte(`
|
||||
-----BEGIN ECDSA PUBLIC KEY-----
|
||||
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE4OWZocTVZ8Do/L6ScLdkV+9A0IYMHoOf
|
||||
dsCmJ/QZ6aw0w9qkkwEpne1Lmo6+0pGexZzFZOH6w5amShn+RXt7qkSid9iWlzGq
|
||||
EKx0BZogHSor9Wy5VztdFaAaVbsJiCbO
|
||||
-----END ECDSA PUBLIC KEY-----
|
||||
`)
|
||||
|
||||
type ReleaseInfo struct {
|
||||
Updated bool
|
||||
Version string
|
||||
Error error
|
||||
}
|
||||
|
||||
func checkForUpdates() ReleaseInfo {
|
||||
var opts equinox.Options
|
||||
if err := opts.SetPublicKeyPEM(publicKey); err != nil {
|
||||
return ReleaseInfo{Error: err}
|
||||
}
|
||||
|
||||
resp, err := equinox.Check(appID, opts)
|
||||
switch {
|
||||
case err == equinox.NotAvailableErr:
|
||||
return ReleaseInfo{}
|
||||
case err != nil:
|
||||
return ReleaseInfo{Error: err}
|
||||
}
|
||||
|
||||
err = resp.Apply()
|
||||
if err != nil {
|
||||
return ReleaseInfo{Error: err}
|
||||
}
|
||||
|
||||
return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion}
|
||||
}
|
|
@ -0,0 +1,145 @@
|
|||
// +build windows
|
||||
|
||||
package main
|
||||
|
||||
// Copypasta from the example files:
|
||||
// https://github.com/golang/sys/blob/master/windows/svc/example
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
|
||||
"golang.org/x/sys/windows/svc"
|
||||
"golang.org/x/sys/windows/svc/eventlog"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
const (
|
||||
windowsServiceName = "CloudflareWarp"
|
||||
windowsServiceDescription = "Cloudflare Warp agent"
|
||||
)
|
||||
|
||||
func runApp(app *cli.App) {
|
||||
app.Commands = append(app.Commands, &cli.Command{
|
||||
Name: "service",
|
||||
Usage: "Manages the Cloudflare Warp Windows service",
|
||||
Subcommands: []*cli.Command{
|
||||
&cli.Command{
|
||||
Name: "install",
|
||||
Usage: "Install Cloudflare Warp as a Windows service",
|
||||
Action: installWindowsService,
|
||||
},
|
||||
&cli.Command{
|
||||
Name: "uninstall",
|
||||
Usage: "Uninstall the Cloudflare Warp service",
|
||||
Action: uninstallWindowsService,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
isIntSess, err := svc.IsAnInteractiveSession()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to determine if we are running in an interactive session: %v", err)
|
||||
}
|
||||
|
||||
if isIntSess {
|
||||
app.Run(os.Args)
|
||||
return
|
||||
}
|
||||
|
||||
elog, err := eventlog.Open(windowsServiceName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer elog.Close()
|
||||
|
||||
elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName))
|
||||
err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog})
|
||||
if err != nil {
|
||||
elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
|
||||
return
|
||||
}
|
||||
elog.Info(1, fmt.Sprintf("%s service stopped", windowsServiceName))
|
||||
}
|
||||
|
||||
type windowsService struct {
|
||||
app *cli.App
|
||||
elog *eventlog.Log
|
||||
}
|
||||
|
||||
func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) {
|
||||
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
|
||||
changes <- svc.Status{State: svc.StartPending}
|
||||
go s.app.Run(args)
|
||||
changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case c := <-r:
|
||||
switch c.Cmd {
|
||||
case svc.Interrogate:
|
||||
changes <- c.CurrentStatus
|
||||
case svc.Stop, svc.Shutdown:
|
||||
break loop
|
||||
default:
|
||||
s.elog.Error(1, fmt.Sprintf("unexpected control request #%d", c))
|
||||
}
|
||||
}
|
||||
}
|
||||
close(shutdownC)
|
||||
changes <- svc.Status{State: svc.StopPending}
|
||||
return
|
||||
}
|
||||
|
||||
func installWindowsService(c *cli.Context) error {
|
||||
exepath, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer m.Disconnect()
|
||||
s, err := m.OpenService(windowsServiceName)
|
||||
if err == nil {
|
||||
s.Close()
|
||||
return fmt.Errorf("service %s already exists", windowsServiceName)
|
||||
}
|
||||
s, err = m.CreateService(windowsServiceName, exepath, mgr.Config{DisplayName: windowsServiceDescription}, "is", "auto-started")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info)
|
||||
if err != nil {
|
||||
s.Delete()
|
||||
return fmt.Errorf("SetupEventLogSource() failed: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func uninstallWindowsService(c *cli.Context) error {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer m.Disconnect()
|
||||
s, err := m.OpenService(windowsServiceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("service %s is not installed", windowsServiceName)
|
||||
}
|
||||
defer s.Close()
|
||||
err = s.Delete()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = eventlog.Remove(windowsServiceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,213 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// activeStreamMap is used to moderate access to active streams between the read and write
|
||||
// threads, and deny access to new peer streams while shutting down.
|
||||
type activeStreamMap struct {
|
||||
sync.RWMutex
|
||||
// streams tracks open streams.
|
||||
streams map[uint32]*MuxedStream
|
||||
// streamsEmpty is a chan that should be closed when no more streams are open.
|
||||
streamsEmpty chan struct{}
|
||||
// nextStreamID is the next ID to use on our side of the connection.
|
||||
// This is odd for clients, even for servers.
|
||||
nextStreamID uint32
|
||||
// maxPeerStreamID is the ID of the most recent stream opened by the peer.
|
||||
maxPeerStreamID uint32
|
||||
// ignoreNewStreams is true when the connection is being shut down. New streams
|
||||
// cannot be registered.
|
||||
ignoreNewStreams bool
|
||||
}
|
||||
|
||||
type FlowControlMetrics struct {
|
||||
AverageReceiveWindowSize, AverageSendWindowSize float64
|
||||
MinReceiveWindowSize, MaxReceiveWindowSize uint32
|
||||
MinSendWindowSize, MaxSendWindowSize uint32
|
||||
}
|
||||
|
||||
func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap {
|
||||
m := &activeStreamMap{
|
||||
streams: make(map[uint32]*MuxedStream),
|
||||
streamsEmpty: make(chan struct{}),
|
||||
nextStreamID: 1,
|
||||
}
|
||||
// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
|
||||
if !useClientStreamNumbers {
|
||||
m.nextStreamID = 2
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Len returns the number of active streams.
|
||||
func (m *activeStreamMap) Len() int {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return len(m.streams)
|
||||
}
|
||||
|
||||
func (m *activeStreamMap) Get(streamID uint32) (*MuxedStream, bool) {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
stream, ok := m.streams[streamID]
|
||||
return stream, ok
|
||||
}
|
||||
|
||||
// Set returns true if the stream was assigned successfully. If a stream
|
||||
// already existed with that ID or we are shutting down, return false.
|
||||
func (m *activeStreamMap) Set(newStream *MuxedStream) bool {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
if _, ok := m.streams[newStream.streamID]; ok {
|
||||
return false
|
||||
}
|
||||
if m.ignoreNewStreams {
|
||||
return false
|
||||
}
|
||||
m.streams[newStream.streamID] = newStream
|
||||
return true
|
||||
}
|
||||
|
||||
// Delete stops tracking the stream. It should be called only after it is closed and resetted.
|
||||
func (m *activeStreamMap) Delete(streamID uint32) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
delete(m.streams, streamID)
|
||||
if len(m.streams) == 0 && m.streamsEmpty != nil {
|
||||
close(m.streamsEmpty)
|
||||
m.streamsEmpty = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown blocks new streams from being created. It returns a channel that receives an event
|
||||
// once the last stream has closed, or nil if a shutdown is in progress.
|
||||
func (m *activeStreamMap) Shutdown() <-chan struct{} {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
if m.ignoreNewStreams {
|
||||
// already shutting down
|
||||
return nil
|
||||
}
|
||||
m.ignoreNewStreams = true
|
||||
done := make(chan struct{})
|
||||
if len(m.streams) == 0 {
|
||||
// nothing to shut down
|
||||
close(done)
|
||||
return done
|
||||
}
|
||||
m.streamsEmpty = done
|
||||
return done
|
||||
}
|
||||
|
||||
// AcquireLocalID acquires a new stream ID for a stream you're opening.
|
||||
func (m *activeStreamMap) AcquireLocalID() uint32 {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
x := m.nextStreamID
|
||||
m.nextStreamID += 2
|
||||
return x
|
||||
}
|
||||
|
||||
// ObservePeerID observes the ID of a stream opened by the peer. It returns true if we should accept
|
||||
// the new stream, or false to reject it. The ErrCode gives the reason why.
|
||||
func (m *activeStreamMap) AcquirePeerID(streamID uint32) (bool, http2.ErrCode) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
switch {
|
||||
case m.ignoreNewStreams:
|
||||
return false, http2.ErrCodeStreamClosed
|
||||
case streamID > m.maxPeerStreamID:
|
||||
m.maxPeerStreamID = streamID
|
||||
return true, http2.ErrCodeNo
|
||||
default:
|
||||
return false, http2.ErrCodeStreamClosed
|
||||
}
|
||||
}
|
||||
|
||||
// IsPeerStreamID is true if the stream ID belongs to the peer.
|
||||
func (m *activeStreamMap) IsPeerStreamID(streamID uint32) bool {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return (streamID % 2) != (m.nextStreamID % 2)
|
||||
}
|
||||
|
||||
// IsLocalStreamID is true if it is a stream we have opened, even if it is now closed.
|
||||
func (m *activeStreamMap) IsLocalStreamID(streamID uint32) bool {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return (streamID%2) == (m.nextStreamID%2) && streamID < m.nextStreamID
|
||||
}
|
||||
|
||||
// LastPeerStreamID returns the most recently opened peer stream ID.
|
||||
func (m *activeStreamMap) LastPeerStreamID() uint32 {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return m.maxPeerStreamID
|
||||
}
|
||||
|
||||
// LastLocalStreamID returns the most recently opened local stream ID.
|
||||
func (m *activeStreamMap) LastLocalStreamID() uint32 {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
if m.nextStreamID > 1 {
|
||||
return m.nextStreamID - 2
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Abort closes every active stream and prevents new ones being created. This should be used to
|
||||
// return errors in pending read/writes when the underlying connection goes away.
|
||||
func (m *activeStreamMap) Abort() {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
for _, stream := range m.streams {
|
||||
stream.Close()
|
||||
}
|
||||
m.ignoreNewStreams = true
|
||||
}
|
||||
|
||||
func (m *activeStreamMap) Metrics() *FlowControlMetrics {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
var averageReceiveWindowSize, averageSendWindowSize float64
|
||||
var minReceiveWindowSize, maxReceiveWindowSize, minSendWindowSize, maxSendWindowSize uint32
|
||||
i := 0
|
||||
// The first variable in the range expression for map is the key, not index.
|
||||
for _, stream := range m.streams {
|
||||
// iterative mean: a(t+1) = a(t) + (a(t)-x)/(t+1)
|
||||
windows := stream.FlowControlWindow()
|
||||
averageReceiveWindowSize += (float64(windows.receiveWindow) - averageReceiveWindowSize) / float64(i+1)
|
||||
averageSendWindowSize += (float64(windows.sendWindow) - averageSendWindowSize) / float64(i+1)
|
||||
if i == 0 {
|
||||
maxReceiveWindowSize = windows.receiveWindow
|
||||
minReceiveWindowSize = windows.receiveWindow
|
||||
maxSendWindowSize = windows.sendWindow
|
||||
minSendWindowSize = windows.sendWindow
|
||||
} else {
|
||||
if windows.receiveWindow > maxReceiveWindowSize {
|
||||
maxReceiveWindowSize = windows.receiveWindow
|
||||
} else if windows.receiveWindow < minReceiveWindowSize {
|
||||
minReceiveWindowSize = windows.receiveWindow
|
||||
}
|
||||
|
||||
if windows.sendWindow > maxSendWindowSize {
|
||||
maxSendWindowSize = windows.sendWindow
|
||||
} else if windows.sendWindow < minSendWindowSize {
|
||||
minSendWindowSize = windows.sendWindow
|
||||
}
|
||||
}
|
||||
i++
|
||||
}
|
||||
return &FlowControlMetrics{
|
||||
MinReceiveWindowSize: minReceiveWindowSize,
|
||||
MaxReceiveWindowSize: maxReceiveWindowSize,
|
||||
AverageReceiveWindowSize: averageReceiveWindowSize,
|
||||
MinSendWindowSize: minSendWindowSize,
|
||||
MaxSendWindowSize: maxSendWindowSize,
|
||||
AverageSendWindowSize: averageSendWindowSize,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestMetrics(t *testing.T) {
|
||||
streamMap := newActiveStreamMap(false)
|
||||
for i := 1; i <= 7; i++ {
|
||||
stream := new(MuxedStream)
|
||||
*stream = MuxedStream{
|
||||
streamID: defaultWindowSize * uint32(i),
|
||||
receiveWindow: defaultWindowSize * uint32(i),
|
||||
sendWindow: defaultWindowSize * uint32(i),
|
||||
}
|
||||
|
||||
assert.True(t, streamMap.Set(stream))
|
||||
}
|
||||
metrics := streamMap.Metrics()
|
||||
assert.Equal(t, float64(defaultWindowSize*4), metrics.AverageReceiveWindowSize)
|
||||
assert.Equal(t, float64(defaultWindowSize*4), metrics.AverageSendWindowSize)
|
||||
assert.Equal(t, defaultWindowSize, metrics.MinReceiveWindowSize)
|
||||
assert.Equal(t, defaultWindowSize, metrics.MinSendWindowSize)
|
||||
assert.Equal(t, defaultWindowSize*7, metrics.MaxReceiveWindowSize)
|
||||
assert.Equal(t, defaultWindowSize*7, metrics.MaxSendWindowSize)
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
package h2mux
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// BooleanFuse is a data structure that can be set once to a particular value using Fuse(value).
|
||||
// Subsequent calls to Fuse() will have no effect.
|
||||
type BooleanFuse struct {
|
||||
value int32
|
||||
}
|
||||
|
||||
// Value gets the value
|
||||
func (f *BooleanFuse) Value() bool {
|
||||
// 0: unset
|
||||
// 1: set true
|
||||
// 2: set false
|
||||
return atomic.LoadInt32(&f.value) == 1
|
||||
}
|
||||
|
||||
func (f *BooleanFuse) Fuse(result bool) {
|
||||
newValue := int32(2)
|
||||
if result {
|
||||
newValue = 1
|
||||
}
|
||||
atomic.CompareAndSwapInt32(&f.value, 0, newValue)
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrHandshakeTimeout = MuxerHandshakeError{"1000 handshake timeout"}
|
||||
ErrBadHandshakeNotSettings = MuxerHandshakeError{"1001 unexpected response"}
|
||||
ErrBadHandshakeUnexpectedAck = MuxerHandshakeError{"1002 unexpected response"}
|
||||
ErrBadHandshakeNoMagic = MuxerHandshakeError{"1003 unexpected response"}
|
||||
ErrBadHandshakeWrongMagic = MuxerHandshakeError{"1004 connected to endpoint of wrong type"}
|
||||
ErrBadHandshakeNotSettingsAck = MuxerHandshakeError{"1005 unexpected response"}
|
||||
ErrBadHandshakeUnexpectedSettings = MuxerHandshakeError{"1006 unexpected response"}
|
||||
|
||||
ErrUnexpectedFrameType = MuxerProtocolError{"2001 unexpected frame type", http2.ErrCodeProtocol}
|
||||
ErrUnknownStream = MuxerProtocolError{"2002 unknown stream", http2.ErrCodeProtocol}
|
||||
ErrInvalidStream = MuxerProtocolError{"2003 invalid stream", http2.ErrCodeProtocol}
|
||||
|
||||
ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"}
|
||||
ErrConnectionClosed = MuxerApplicationError{"3001 connection closed"}
|
||||
ErrConnectionDropped = MuxerApplicationError{"3002 connection dropped"}
|
||||
|
||||
ErrClosedStream = MuxerStreamError{"4000 stream closed", http2.ErrCodeStreamClosed}
|
||||
)
|
||||
|
||||
type MuxerHandshakeError struct {
|
||||
cause string
|
||||
}
|
||||
|
||||
func (e MuxerHandshakeError) Error() string {
|
||||
return fmt.Sprintf("Handshake error: %s", e.cause)
|
||||
}
|
||||
|
||||
type MuxerProtocolError struct {
|
||||
cause string
|
||||
h2code http2.ErrCode
|
||||
}
|
||||
|
||||
func (e MuxerProtocolError) Error() string {
|
||||
return fmt.Sprintf("Protocol error: %s", e.cause)
|
||||
}
|
||||
|
||||
type MuxerApplicationError struct {
|
||||
cause string
|
||||
}
|
||||
|
||||
func (e MuxerApplicationError) Error() string {
|
||||
return fmt.Sprintf("Application error: %s", e.cause)
|
||||
}
|
||||
|
||||
type MuxerStreamError struct {
|
||||
cause string
|
||||
h2code http2.ErrCode
|
||||
}
|
||||
|
||||
func (e MuxerStreamError) Error() string {
|
||||
return fmt.Sprintf("Stream error: %s", e.cause)
|
||||
}
|
|
@ -0,0 +1,335 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultFrameSize uint32 = 1 << 14 // Minimum frame size in http2 spec
|
||||
defaultWindowSize uint32 = 65535
|
||||
maxWindowSize uint32 = (1 << 31) - 1 // 2^31-1 = 2147483647, max window size specified in http2 spec
|
||||
defaultTimeout time.Duration = 5 * time.Second
|
||||
defaultRetries uint64 = 5
|
||||
|
||||
SettingMuxerMagic http2.SettingID = 0x42db
|
||||
MuxerMagicOrigin uint32 = 0xa2e43c8b
|
||||
MuxerMagicEdge uint32 = 0x1088ebf9
|
||||
)
|
||||
|
||||
type MuxedStreamHandler interface {
|
||||
ServeStream(*MuxedStream) error
|
||||
}
|
||||
|
||||
type MuxedStreamFunc func(stream *MuxedStream) error
|
||||
|
||||
func (f MuxedStreamFunc) ServeStream(stream *MuxedStream) error {
|
||||
return f(stream)
|
||||
}
|
||||
|
||||
type MuxerConfig struct {
|
||||
Timeout time.Duration
|
||||
Handler MuxedStreamHandler
|
||||
IsClient bool
|
||||
// Name is used to identify this muxer instance when logging.
|
||||
Name string
|
||||
// The minimum time this connection can be idle before sending a heartbeat.
|
||||
HeartbeatInterval time.Duration
|
||||
// The minimum number of heartbeats to send before terminating the connection.
|
||||
MaxHeartbeats uint64
|
||||
}
|
||||
|
||||
type Muxer struct {
|
||||
// f is used to read and write HTTP2 frames on the wire.
|
||||
f *http2.Framer
|
||||
// config is the MuxerConfig given in Handshake.
|
||||
config MuxerConfig
|
||||
// w, r are references to the underlying connection used.
|
||||
w io.WriteCloser
|
||||
r io.ReadCloser
|
||||
// muxReader is the read process.
|
||||
muxReader *MuxReader
|
||||
// muxWriter is the write process.
|
||||
muxWriter *MuxWriter
|
||||
// newStreamChan is used to create new streams on the writer thread.
|
||||
// The writer will assign the next available stream ID.
|
||||
newStreamChan chan MuxedStreamRequest
|
||||
// abortChan is used to abort the writer event loop.
|
||||
abortChan chan struct{}
|
||||
// abortOnce is used to ensure abortChan is closed once only.
|
||||
abortOnce sync.Once
|
||||
// readyList is used to signal writable streams.
|
||||
readyList *ReadyList
|
||||
// streams tracks currently-open streams.
|
||||
streams *activeStreamMap
|
||||
// explicitShutdown records whether the Muxer is closing because Shutdown was called, or due to another
|
||||
// error.
|
||||
explicitShutdown BooleanFuse
|
||||
}
|
||||
|
||||
type Header struct {
|
||||
Name, Value string
|
||||
}
|
||||
|
||||
// Handshake establishes a muxed connection with the peer.
|
||||
// After the handshake completes, it is possible to open and accept streams.
|
||||
func Handshake(
|
||||
w io.WriteCloser,
|
||||
r io.ReadCloser,
|
||||
config MuxerConfig,
|
||||
) (*Muxer, error) {
|
||||
// Initialise connection state fields
|
||||
m := &Muxer{
|
||||
f: http2.NewFramer(w, r), // A framer that writes to w and reads from r
|
||||
config: config,
|
||||
w: w,
|
||||
r: r,
|
||||
newStreamChan: make(chan MuxedStreamRequest),
|
||||
abortChan: make(chan struct{}),
|
||||
readyList: NewReadyList(),
|
||||
streams: newActiveStreamMap(config.IsClient),
|
||||
}
|
||||
m.f.ReadMetaHeaders = hpack.NewDecoder(4096, func(hpack.HeaderField) {})
|
||||
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = defaultTimeout
|
||||
}
|
||||
|
||||
// Initialise the settings to identify this connection and confirm the other end is sane.
|
||||
handshakeSetting := http2.Setting{ID: SettingMuxerMagic, Val: MuxerMagicEdge}
|
||||
expectedMagic := MuxerMagicOrigin
|
||||
if config.IsClient {
|
||||
handshakeSetting.Val = MuxerMagicOrigin
|
||||
expectedMagic = MuxerMagicEdge
|
||||
}
|
||||
errChan := make(chan error, 2)
|
||||
// Simultaneously send our settings and verify the peer's settings.
|
||||
go func() { errChan <- m.f.WriteSettings(handshakeSetting) }()
|
||||
go func() { errChan <- m.readPeerSettings(expectedMagic) }()
|
||||
err := joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Confirm sanity by ACKing the frame and expecting an ACK for our frame.
|
||||
// Not strictly necessary, but let's pretend to be H2-like.
|
||||
go func() { errChan <- m.f.WriteSettingsAck() }()
|
||||
go func() { errChan <- m.readPeerSettingsAck() }()
|
||||
err = joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// set up reader/writer pair ready for serve
|
||||
streamErrors := NewStreamErrorMap()
|
||||
goAwayChan := make(chan http2.ErrCode, 1)
|
||||
pingTimestamp := NewPingTimestamp()
|
||||
connActive := NewSignal()
|
||||
idleDuration := config.HeartbeatInterval
|
||||
// Sanity check to enusre idelDuration is sane
|
||||
if idleDuration == 0 || idleDuration < defaultTimeout {
|
||||
idleDuration = defaultTimeout
|
||||
log.Warn("Minimum idle time has been adjusted to ", defaultTimeout)
|
||||
}
|
||||
maxRetries := config.MaxHeartbeats
|
||||
if maxRetries == 0 {
|
||||
maxRetries = defaultRetries
|
||||
log.Warn("Minimum number of unacked heartbeats to send before closing the connection has been adjusted to ", maxRetries)
|
||||
}
|
||||
|
||||
m.muxReader = &MuxReader{
|
||||
f: m.f,
|
||||
handler: m.config.Handler,
|
||||
streams: m.streams,
|
||||
readyList: m.readyList,
|
||||
streamErrors: streamErrors,
|
||||
goAwayChan: goAwayChan,
|
||||
abortChan: m.abortChan,
|
||||
pingTimestamp: pingTimestamp,
|
||||
connActive: connActive,
|
||||
initialStreamWindow: defaultWindowSize,
|
||||
streamWindowMax: maxWindowSize,
|
||||
r: m.r,
|
||||
}
|
||||
m.muxWriter = &MuxWriter{
|
||||
f: m.f,
|
||||
streams: m.streams,
|
||||
streamErrors: streamErrors,
|
||||
readyStreamChan: m.readyList.ReadyChannel(),
|
||||
newStreamChan: m.newStreamChan,
|
||||
goAwayChan: goAwayChan,
|
||||
abortChan: m.abortChan,
|
||||
pingTimestamp: pingTimestamp,
|
||||
idleTimer: NewIdleTimer(idleDuration, maxRetries),
|
||||
connActiveChan: connActive.WaitChannel(),
|
||||
maxFrameSize: defaultFrameSize,
|
||||
}
|
||||
m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Muxer) readPeerSettings(magic uint32) error {
|
||||
frame, err := m.f.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
settingsFrame, ok := frame.(*http2.SettingsFrame)
|
||||
if !ok {
|
||||
return ErrBadHandshakeNotSettings
|
||||
}
|
||||
if settingsFrame.Header().Flags != 0 {
|
||||
return ErrBadHandshakeUnexpectedAck
|
||||
}
|
||||
peerMagic, ok := settingsFrame.Value(SettingMuxerMagic)
|
||||
if !ok {
|
||||
return ErrBadHandshakeNoMagic
|
||||
}
|
||||
if magic != peerMagic {
|
||||
return ErrBadHandshakeWrongMagic
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Muxer) readPeerSettingsAck() error {
|
||||
frame, err := m.f.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
settingsFrame, ok := frame.(*http2.SettingsFrame)
|
||||
if !ok {
|
||||
return ErrBadHandshakeNotSettingsAck
|
||||
}
|
||||
if settingsFrame.Header().Flags != http2.FlagSettingsAck {
|
||||
return ErrBadHandshakeUnexpectedSettings
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time.Duration, timeoutError error) error {
|
||||
for i := 0; i < receiveCount; i++ {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case <-time.After(timeout):
|
||||
return timeoutError
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Muxer) Serve() error {
|
||||
logger := log.WithField("name", m.config.Name)
|
||||
errChan := make(chan error)
|
||||
go func() {
|
||||
errChan <- m.muxReader.run(logger)
|
||||
m.explicitShutdown.Fuse(false)
|
||||
m.r.Close()
|
||||
m.abort()
|
||||
}()
|
||||
go func() {
|
||||
errChan <- m.muxWriter.run(logger)
|
||||
m.explicitShutdown.Fuse(false)
|
||||
m.w.Close()
|
||||
m.abort()
|
||||
}()
|
||||
err := <-errChan
|
||||
go func() {
|
||||
// discard error as other handler closes
|
||||
<-errChan
|
||||
close(errChan)
|
||||
}()
|
||||
if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Muxer) Shutdown() {
|
||||
m.explicitShutdown.Fuse(true)
|
||||
m.muxReader.Shutdown()
|
||||
}
|
||||
|
||||
// IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
|
||||
// The set of expected errors change depending on whether we initiated shutdown or not.
|
||||
func isUnexpectedTunnelError(err error, expectedShutdown bool) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if !expectedShutdown {
|
||||
return true
|
||||
}
|
||||
return !isConnectionClosedError(err)
|
||||
}
|
||||
|
||||
func isConnectionClosedError(err error) bool {
|
||||
if err == io.EOF {
|
||||
return true
|
||||
}
|
||||
if err == io.ErrClosedPipe {
|
||||
return true
|
||||
}
|
||||
if err.Error() == "tls: use of closed connection" {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// OpenStream opens a new data stream with the given headers.
|
||||
// Called by proxy server and tunnel
|
||||
func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, error) {
|
||||
stream := &MuxedStream{
|
||||
responseHeadersReceived: make(chan struct{}),
|
||||
readBuffer: NewSharedBuffer(),
|
||||
receiveWindow: defaultWindowSize,
|
||||
receiveWindowCurrentMax: defaultWindowSize, // Initial window size limit. exponentially increase it when receiveWindow is exhausted
|
||||
receiveWindowMax: maxWindowSize,
|
||||
sendWindow: defaultWindowSize,
|
||||
readyList: m.readyList,
|
||||
writeHeaders: headers,
|
||||
}
|
||||
select {
|
||||
// Will be received by mux writer
|
||||
case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: body}:
|
||||
case <-m.abortChan:
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
select {
|
||||
case <-stream.responseHeadersReceived:
|
||||
return stream, nil
|
||||
case <-m.abortChan:
|
||||
return nil, ErrConnectionClosed
|
||||
}
|
||||
}
|
||||
|
||||
// Return the estimated round-trip time.
|
||||
func (m *Muxer) RTT() RTTMeasurement {
|
||||
return m.muxReader.RTT()
|
||||
}
|
||||
|
||||
// Return min/max/average of send/receive window for all streams on this connection
|
||||
func (m *Muxer) FlowControlMetrics() *FlowControlMetrics {
|
||||
return m.muxReader.FlowControlMetrics()
|
||||
}
|
||||
|
||||
func (m *Muxer) abort() {
|
||||
m.abortOnce.Do(func() {
|
||||
close(m.abortChan)
|
||||
m.streams.Abort()
|
||||
})
|
||||
}
|
||||
|
||||
// Return how many retries/ticks since the connection was last marked active
|
||||
func (m *Muxer) TimerRetries() uint64 {
|
||||
return m.muxWriter.idleTimer.RetryCount()
|
||||
}
|
|
@ -0,0 +1,646 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if os.Getenv("VERBOSE") == "1" {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
type DefaultMuxerPair struct {
|
||||
OriginMuxConfig MuxerConfig
|
||||
OriginMux *Muxer
|
||||
OriginWriter *io.PipeWriter
|
||||
OriginReader *io.PipeReader
|
||||
EdgeMuxConfig MuxerConfig
|
||||
EdgeMux *Muxer
|
||||
EdgeWriter *io.PipeWriter
|
||||
EdgeReader *io.PipeReader
|
||||
doneC chan struct{}
|
||||
}
|
||||
|
||||
func NewDefaultMuxerPair() *DefaultMuxerPair {
|
||||
originReader, edgeWriter := io.Pipe()
|
||||
edgeReader, originWriter := io.Pipe()
|
||||
return &DefaultMuxerPair{
|
||||
OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"},
|
||||
OriginWriter: originWriter,
|
||||
OriginReader: originReader,
|
||||
EdgeMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"},
|
||||
EdgeWriter: edgeWriter,
|
||||
EdgeReader: edgeReader,
|
||||
doneC: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DefaultMuxerPair) Handshake(t *testing.T) {
|
||||
edgeErrC := make(chan error)
|
||||
originErrC := make(chan error)
|
||||
go func() {
|
||||
var err error
|
||||
p.EdgeMux, err = Handshake(p.EdgeWriter, p.EdgeReader, p.EdgeMuxConfig)
|
||||
edgeErrC <- err
|
||||
}()
|
||||
go func() {
|
||||
var err error
|
||||
p.OriginMux, err = Handshake(p.OriginWriter, p.OriginReader, p.OriginMuxConfig)
|
||||
originErrC <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-edgeErrC:
|
||||
if err != nil {
|
||||
t.Fatalf("edge handshake failure: %s", err)
|
||||
}
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Fatalf("edge handshake timeout")
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-originErrC:
|
||||
if err != nil {
|
||||
t.Fatalf("origin handshake failure: %s", err)
|
||||
}
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Fatalf("origin handshake timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) {
|
||||
p.Handshake(t)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
err := p.EdgeMux.Serve()
|
||||
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
|
||||
t.Errorf("error in edge muxer Serve(): %s", err)
|
||||
}
|
||||
p.OriginMux.Shutdown()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
err := p.OriginMux.Serve()
|
||||
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
|
||||
t.Errorf("error in origin muxer Serve(): %s", err)
|
||||
}
|
||||
p.EdgeMux.Shutdown()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
// notify when both muxes have stopped serving
|
||||
wg.Wait()
|
||||
close(p.doneC)
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *DefaultMuxerPair) Wait(t *testing.T) {
|
||||
select {
|
||||
case <-p.doneC:
|
||||
return
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
muxPair := NewDefaultMuxerPair()
|
||||
muxPair.Handshake(t)
|
||||
AssertIfPipeReadable(t, muxPair.OriginReader)
|
||||
AssertIfPipeReadable(t, muxPair.EdgeReader)
|
||||
}
|
||||
|
||||
func TestSingleStream(t *testing.T) {
|
||||
closeC := make(chan struct{})
|
||||
muxPair := NewDefaultMuxerPair()
|
||||
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
defer close(closeC)
|
||||
if len(stream.Headers) != 1 {
|
||||
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "test-header" {
|
||||
t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "headerValue" {
|
||||
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
||||
}
|
||||
stream.WriteHeaders([]Header{
|
||||
Header{Name: "response-header", Value: "responseValue"},
|
||||
})
|
||||
buf := []byte("Hello world")
|
||||
stream.Write(buf)
|
||||
// after this receive, the edge closed the stream
|
||||
<-closeC
|
||||
n, err := stream.Read(buf)
|
||||
if n > 0 {
|
||||
t.Fatalf("read %d bytes after EOF", n)
|
||||
}
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %s", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
muxPair.HandshakeAndServe(t)
|
||||
|
||||
stream, err := muxPair.EdgeMux.OpenStream(
|
||||
[]Header{Header{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
if len(stream.Headers) != 1 {
|
||||
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "response-header" {
|
||||
t.Fatalf("expected header name %s, got %s", "response-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "responseValue" {
|
||||
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
||||
}
|
||||
responseBody := make([]byte, 11)
|
||||
n, err := stream.Read(responseBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error from (*MuxedStream).Read: %s", err)
|
||||
}
|
||||
if n != len(responseBody) {
|
||||
t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
|
||||
}
|
||||
if string(responseBody) != "Hello world" {
|
||||
t.Fatalf("expected response body %s, got %s", "Hello world", responseBody)
|
||||
}
|
||||
stream.Close()
|
||||
closeC <- struct{}{}
|
||||
n, err = stream.Write([]byte("aaaaa"))
|
||||
if n > 0 {
|
||||
t.Fatalf("wrote %d bytes after EOF", n)
|
||||
}
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %s", err)
|
||||
}
|
||||
<-closeC
|
||||
}
|
||||
|
||||
func TestSingleStreamLargeResponseBody(t *testing.T) {
|
||||
muxPair := NewDefaultMuxerPair()
|
||||
bodySize := 1 << 24
|
||||
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
if len(stream.Headers) != 1 {
|
||||
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "test-header" {
|
||||
t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "headerValue" {
|
||||
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
||||
}
|
||||
stream.WriteHeaders([]Header{
|
||||
Header{Name: "response-header", Value: "responseValue"},
|
||||
})
|
||||
payload := make([]byte, bodySize)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i % 256)
|
||||
}
|
||||
n, err := stream.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("origin write error: %s", err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
t.Fatalf("origin short write: %d/%d bytes", n, len(payload))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
muxPair.HandshakeAndServe(t)
|
||||
|
||||
stream, err := muxPair.EdgeMux.OpenStream(
|
||||
[]Header{Header{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
if len(stream.Headers) != 1 {
|
||||
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "response-header" {
|
||||
t.Fatalf("expected header name %s, got %s", "response-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "responseValue" {
|
||||
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
||||
}
|
||||
responseBody := make([]byte, bodySize)
|
||||
n, err := stream.Read(responseBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error from (*MuxedStream).Read: %s", err)
|
||||
}
|
||||
if n != len(responseBody) {
|
||||
t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleStreams(t *testing.T) {
|
||||
muxPair := NewDefaultMuxerPair()
|
||||
maxStreams := 64
|
||||
errorsC := make(chan error, maxStreams)
|
||||
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
if len(stream.Headers) != 1 {
|
||||
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "client-token" {
|
||||
t.Fatalf("expected header name %s, got %s", "client-token", stream.Headers[0].Name)
|
||||
}
|
||||
log.Debugf("Got request for stream %s", stream.Headers[0].Value)
|
||||
stream.WriteHeaders([]Header{
|
||||
Header{Name: "response-token", Value: stream.Headers[0].Value},
|
||||
})
|
||||
log.Debugf("Wrote headers for stream %s", stream.Headers[0].Value)
|
||||
stream.Write([]byte("OK"))
|
||||
log.Debugf("Wrote body for stream %s", stream.Headers[0].Value)
|
||||
return nil
|
||||
})
|
||||
muxPair.HandshakeAndServe(t)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(maxStreams)
|
||||
for i := 0; i < maxStreams; i++ {
|
||||
go func(tokenId int) {
|
||||
defer wg.Done()
|
||||
tokenString := fmt.Sprintf("%d", tokenId)
|
||||
stream, err := muxPair.EdgeMux.OpenStream(
|
||||
[]Header{Header{Name: "client-token", Value: tokenString}},
|
||||
nil,
|
||||
)
|
||||
log.Debugf("Got headers for stream %d", tokenId)
|
||||
if err != nil {
|
||||
errorsC <- err
|
||||
return
|
||||
}
|
||||
if len(stream.Headers) != 1 {
|
||||
errorsC <- fmt.Errorf("stream %d has error: expected %d headers, got %d", stream.streamID, 1, len(stream.Headers))
|
||||
return
|
||||
}
|
||||
if stream.Headers[0].Name != "response-token" {
|
||||
errorsC <- fmt.Errorf("stream %d has error: expected header name %s, got %s", stream.streamID, "response-token", stream.Headers[0].Name)
|
||||
return
|
||||
}
|
||||
if stream.Headers[0].Value != tokenString {
|
||||
errorsC <- fmt.Errorf("stream %d has error: expected header value %s, got %s", stream.streamID, tokenString, stream.Headers[0].Value)
|
||||
return
|
||||
}
|
||||
responseBody := make([]byte, 2)
|
||||
n, err := stream.Read(responseBody)
|
||||
if err != nil {
|
||||
errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err)
|
||||
return
|
||||
}
|
||||
if n != len(responseBody) {
|
||||
errorsC <- fmt.Errorf("stream %d has error: expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n)
|
||||
return
|
||||
}
|
||||
if string(responseBody) != "OK" {
|
||||
errorsC <- fmt.Errorf("stream %d has error: expected response body %s, got %s", stream.streamID, "OK", responseBody)
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
close(errorsC)
|
||||
testFail := false
|
||||
for err := range errorsC {
|
||||
testFail = true
|
||||
log.Error(err)
|
||||
}
|
||||
if testFail {
|
||||
t.Fatalf("TestMultipleStreamsFlowControl failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleStreamsFlowControl(t *testing.T) {
|
||||
maxStreams := 32
|
||||
errorsC := make(chan error, maxStreams)
|
||||
responseSizes := make([]int32, maxStreams)
|
||||
for i := 0; i < maxStreams; i++ {
|
||||
responseSizes[i] = rand.Int31n(int32(defaultWindowSize << 4))
|
||||
}
|
||||
muxPair := NewDefaultMuxerPair()
|
||||
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
if len(stream.Headers) != 1 {
|
||||
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "test-header" {
|
||||
t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "headerValue" {
|
||||
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
||||
}
|
||||
stream.WriteHeaders([]Header{
|
||||
Header{Name: "response-header", Value: "responseValue"},
|
||||
})
|
||||
payload := make([]byte, responseSizes[(stream.streamID-2)/2])
|
||||
for i := range payload {
|
||||
payload[i] = byte(i % 256)
|
||||
}
|
||||
n, err := stream.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("origin write error: %s", err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
t.Fatalf("origin short write: %d/%d bytes", n, len(payload))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
muxPair.HandshakeAndServe(t)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(maxStreams)
|
||||
for i := 0; i < maxStreams; i++ {
|
||||
go func(tokenId int) {
|
||||
defer wg.Done()
|
||||
stream, err := muxPair.EdgeMux.OpenStream(
|
||||
[]Header{Header{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
errorsC <- fmt.Errorf("stream %d error in OpenStream: %s", stream.streamID, err)
|
||||
return
|
||||
}
|
||||
if len(stream.Headers) != 1 {
|
||||
errorsC <- fmt.Errorf("stream %d expected %d headers, got %d", stream.streamID, 1, len(stream.Headers))
|
||||
return
|
||||
}
|
||||
if stream.Headers[0].Name != "response-header" {
|
||||
errorsC <- fmt.Errorf("stream %d expected header name %s, got %s", stream.streamID, "response-header", stream.Headers[0].Name)
|
||||
return
|
||||
}
|
||||
if stream.Headers[0].Value != "responseValue" {
|
||||
errorsC <- fmt.Errorf("stream %d expected header value %s, got %s", stream.streamID, "responseValue", stream.Headers[0].Value)
|
||||
return
|
||||
}
|
||||
|
||||
responseBody := make([]byte, responseSizes[(stream.streamID-2)/2])
|
||||
n, err := stream.Read(responseBody)
|
||||
if err != nil {
|
||||
errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err)
|
||||
return
|
||||
}
|
||||
if n != len(responseBody) {
|
||||
errorsC <- fmt.Errorf("stream %d expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n)
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
close(errorsC)
|
||||
testFail := false
|
||||
for err := range errorsC {
|
||||
testFail = true
|
||||
log.Error(err)
|
||||
}
|
||||
if testFail {
|
||||
t.Fatalf("TestMultipleStreamsFlowControl failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGracefulShutdown(t *testing.T) {
|
||||
sendC := make(chan struct{})
|
||||
responseBuf := bytes.Repeat([]byte("Hello world"), 65536)
|
||||
muxPair := NewDefaultMuxerPair()
|
||||
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
stream.WriteHeaders([]Header{
|
||||
Header{Name: "response-header", Value: "responseValue"},
|
||||
})
|
||||
<-sendC
|
||||
log.Debugf("Writing %d bytes", len(responseBuf))
|
||||
stream.Write(responseBuf)
|
||||
stream.CloseWrite()
|
||||
log.Debugf("Wrote %d bytes", len(responseBuf))
|
||||
// Reading from the stream will block until the edge closes its end of the stream.
|
||||
// Otherwise, we'll close the whole connection before receiving the 'stream closed'
|
||||
// message from the edge.
|
||||
// Graceful shutdown works if you omit this, it just gives spurious errors for now -
|
||||
// TODO ignore errors when writing 'stream closed' and we're shutting down.
|
||||
stream.Read([]byte{0})
|
||||
log.Debugf("Handler ends")
|
||||
return nil
|
||||
})
|
||||
muxPair.HandshakeAndServe(t)
|
||||
|
||||
stream, err := muxPair.EdgeMux.OpenStream(
|
||||
[]Header{Header{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
// Start graceful shutdown of the edge mux - this should also close the origin mux when done
|
||||
muxPair.EdgeMux.Shutdown()
|
||||
close(sendC)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
responseBody := make([]byte, len(responseBuf))
|
||||
log.Debugf("Waiting for %d bytes", len(responseBuf))
|
||||
n, err := stream.Read(responseBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err)
|
||||
}
|
||||
if n != len(responseBody) {
|
||||
t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
|
||||
}
|
||||
if !bytes.Equal(responseBuf, responseBody) {
|
||||
t.Fatalf("response body mismatch")
|
||||
}
|
||||
stream.Close()
|
||||
muxPair.Wait(t)
|
||||
}
|
||||
|
||||
func TestUnexpectedShutdown(t *testing.T) {
|
||||
sendC := make(chan struct{})
|
||||
handlerFinishC := make(chan struct{})
|
||||
responseBuf := bytes.Repeat([]byte("Hello world"), 65536)
|
||||
muxPair := NewDefaultMuxerPair()
|
||||
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
defer close(handlerFinishC)
|
||||
stream.WriteHeaders([]Header{
|
||||
Header{Name: "response-header", Value: "responseValue"},
|
||||
})
|
||||
<-sendC
|
||||
n, err := stream.Read([]byte{0})
|
||||
if err != io.EOF {
|
||||
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Fatalf("expected empty read, got %d bytes", n)
|
||||
}
|
||||
// Write comes after read, because write buffers data before it is flushed. It wouldn't know about EOF
|
||||
// until some time later. Calling read first forces it to know about EOF now.
|
||||
_, err = stream.Write(responseBuf)
|
||||
if err != io.EOF {
|
||||
t.Fatalf("unexpected error from (*MuxedStream).Write: %s", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
muxPair.HandshakeAndServe(t)
|
||||
|
||||
stream, err := muxPair.EdgeMux.OpenStream(
|
||||
[]Header{Header{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
// Close the underlying connection before telling the origin to write.
|
||||
muxPair.EdgeReader.Close()
|
||||
close(sendC)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
responseBody := make([]byte, len(responseBuf))
|
||||
n, err := stream.Read(responseBody)
|
||||
if err != io.EOF {
|
||||
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Fatalf("expected response body to have %d bytes, got %d", 0, n)
|
||||
}
|
||||
// The write ordering requirement explained in the origin handler applies here too.
|
||||
_, err = stream.Write(responseBuf)
|
||||
if err != io.EOF {
|
||||
t.Fatalf("unexpected error from (*MuxedStream).Write: %s", err)
|
||||
}
|
||||
<-handlerFinishC
|
||||
}
|
||||
|
||||
func EchoHandler(stream *MuxedStream) error {
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprintf(&buf, "Hello, world!\n\n# REQUEST HEADERS:\n\n")
|
||||
for _, header := range stream.Headers {
|
||||
fmt.Fprintf(&buf, "[%s] = %s\n", header.Name, header.Value)
|
||||
}
|
||||
stream.WriteHeaders([]Header{
|
||||
{Name: ":status", Value: "200"},
|
||||
{Name: "server", Value: "Echo-server/1.0"},
|
||||
{Name: "date", Value: time.Now().Format(time.RFC850)},
|
||||
{Name: "content-type", Value: "text/html; charset=utf-8"},
|
||||
{Name: "content-length", Value: strconv.Itoa(buf.Len())},
|
||||
})
|
||||
buf.WriteTo(stream)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOpenAfterDisconnect(t *testing.T) {
|
||||
for i := 0; i < 3; i++ {
|
||||
muxPair := NewDefaultMuxerPair()
|
||||
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(EchoHandler)
|
||||
muxPair.HandshakeAndServe(t)
|
||||
|
||||
switch i {
|
||||
case 0:
|
||||
// Close both directions of the connection to cause EOF on both peers.
|
||||
muxPair.OriginReader.Close()
|
||||
muxPair.OriginWriter.Close()
|
||||
case 1:
|
||||
// Close origin reader (edge writer) to cause EOF on origin only.
|
||||
muxPair.OriginReader.Close()
|
||||
case 2:
|
||||
// Close origin writer (edge reader) to cause EOF on edge only.
|
||||
muxPair.OriginWriter.Close()
|
||||
}
|
||||
|
||||
_, err := muxPair.EdgeMux.OpenStream(
|
||||
[]Header{Header{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
if err != ErrConnectionClosed {
|
||||
t.Fatalf("unexpected error in OpenStream: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHPACK(t *testing.T) {
|
||||
muxPair := NewDefaultMuxerPair()
|
||||
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(EchoHandler)
|
||||
muxPair.HandshakeAndServe(t)
|
||||
|
||||
stream, err := muxPair.EdgeMux.OpenStream(
|
||||
[]Header{
|
||||
{Name: ":method", Value: "RPC"},
|
||||
{Name: ":scheme", Value: "capnp"},
|
||||
{Name: ":path", Value: "*"},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
stream.Close()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
stream, err := muxPair.EdgeMux.OpenStream(
|
||||
[]Header{
|
||||
{Name: ":method", Value: "GET"},
|
||||
{Name: ":scheme", Value: "https"},
|
||||
{Name: ":authority", Value: "tunnel.otterlyadorable.co.uk"},
|
||||
{Name: ":path", Value: "/get"},
|
||||
{Name: "accept-encoding", Value: "gzip"},
|
||||
{Name: "cf-ray", Value: "378948953f044408-SFO-DOG"},
|
||||
{Name: "cf-visitor", Value: "{\"scheme\":\"https\"}"},
|
||||
{Name: "cf-connecting-ip", Value: "2400:cb00:0025:010d:0000:0000:0000:0001"},
|
||||
{Name: "x-forwarded-for", Value: "2400:cb00:0025:010d:0000:0000:0000:0001"},
|
||||
{Name: "x-forwarded-proto", Value: "https"},
|
||||
{Name: "accept-language", Value: "en-gb"},
|
||||
{Name: "referer", Value: "https://tunnel.otterlyadorable.co.uk/"},
|
||||
{Name: "cookie", Value: "__cfduid=d4555095065f92daedc059490771967d81493032162"},
|
||||
{Name: "connection", Value: "Keep-Alive"},
|
||||
{Name: "cf-ipcountry", Value: "US"},
|
||||
{Name: "accept", Value: "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
|
||||
{Name: "user-agent", Value: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/603.2.4 (KHTML, like Gecko) Version/10.1.1 Safari/603.2.4"},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
if len(stream.Headers) == 0 {
|
||||
t.Fatal("response has no headers")
|
||||
}
|
||||
if stream.Headers[0].Name != ":status" {
|
||||
t.Fatalf("first header should be status, found %s instead", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "200" {
|
||||
t.Fatalf("expected status 200, got %s", stream.Headers[0].Value)
|
||||
}
|
||||
ioutil.ReadAll(stream)
|
||||
stream.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) {
|
||||
errC := make(chan error)
|
||||
go func() {
|
||||
b := []byte{0}
|
||||
n, err := pipe.Read(b)
|
||||
if n > 0 {
|
||||
t.Fatalf("read pipe was not empty")
|
||||
}
|
||||
errC <- err
|
||||
}()
|
||||
select {
|
||||
case err := <-errC:
|
||||
if err != nil {
|
||||
t.Fatalf("read error: %s", err)
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// nothing to read
|
||||
pipe.Close()
|
||||
<-errC
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IdleTimer is a type of Timer designed for managing heartbeats on an idle connection.
|
||||
// The timer ticks on an interval with added jitter to avoid accidental synchronisation
|
||||
// between two endpoints. It tracks the number of retries/ticks since the connection was
|
||||
// last marked active.
|
||||
//
|
||||
// The methods of IdleTimer must not be called while a goroutine is reading from C.
|
||||
type IdleTimer struct {
|
||||
// The channel on which ticks are delivered.
|
||||
C <-chan time.Time
|
||||
|
||||
// A timer used to measure idle connection time. Reset after sending data.
|
||||
idleTimer *time.Timer
|
||||
// The maximum length of time a connection is idle before sending a ping.
|
||||
idleDuration time.Duration
|
||||
// A pseudorandom source used to add jitter to the idle duration.
|
||||
randomSource *rand.Rand
|
||||
// The maximum number of retries allowed.
|
||||
maxRetries uint64
|
||||
// The number of retries since the connection was last marked active.
|
||||
retries uint64
|
||||
// A lock to prevent race condition while checking retries
|
||||
stateLock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewIdleTimer(idleDuration time.Duration, maxRetries uint64) *IdleTimer {
|
||||
t := &IdleTimer{
|
||||
idleTimer: time.NewTimer(idleDuration),
|
||||
idleDuration: idleDuration,
|
||||
randomSource: rand.New(rand.NewSource(time.Now().Unix())),
|
||||
maxRetries: maxRetries,
|
||||
}
|
||||
t.C = t.idleTimer.C
|
||||
return t
|
||||
}
|
||||
|
||||
// Retry should be called when retrying the idle timeout. If the maximum number of retries
|
||||
// has been met, returns false.
|
||||
// After calling this function and sending a heartbeat, call ResetTimer. Since sending the
|
||||
// heartbeat could be a blocking operation, we resetting the timer after the write completes
|
||||
// to avoid it expiring during the write.
|
||||
func (t *IdleTimer) Retry() bool {
|
||||
t.stateLock.Lock()
|
||||
defer t.stateLock.Unlock()
|
||||
if t.retries >= t.maxRetries {
|
||||
return false
|
||||
}
|
||||
t.retries++
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *IdleTimer) RetryCount() uint64 {
|
||||
t.stateLock.RLock()
|
||||
defer t.stateLock.RUnlock()
|
||||
return t.retries
|
||||
}
|
||||
|
||||
// MarkActive resets the idle connection timer and suppresses any outstanding idle events.
|
||||
func (t *IdleTimer) MarkActive() {
|
||||
if !t.idleTimer.Stop() {
|
||||
// eat the timer event to prevent spurious pings
|
||||
<-t.idleTimer.C
|
||||
}
|
||||
t.stateLock.Lock()
|
||||
t.retries = 0
|
||||
t.stateLock.Unlock()
|
||||
t.ResetTimer()
|
||||
}
|
||||
|
||||
// Reset the idle timer according to the configured duration, with some added jitter.
|
||||
func (t *IdleTimer) ResetTimer() {
|
||||
jitter := time.Duration(t.randomSource.Int63n(int64(t.idleDuration)))
|
||||
t.idleTimer.Reset(t.idleDuration + jitter)
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRetry(t *testing.T) {
|
||||
timer := NewIdleTimer(time.Second, 2)
|
||||
assert.Equal(t, uint64(0), timer.RetryCount())
|
||||
ok := timer.Retry()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, uint64(1), timer.RetryCount())
|
||||
ok = timer.Retry()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, uint64(2), timer.RetryCount())
|
||||
ok = timer.Retry()
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestMarkActive(t *testing.T) {
|
||||
timer := NewIdleTimer(time.Second, 2)
|
||||
assert.Equal(t, uint64(0), timer.RetryCount())
|
||||
ok := timer.Retry()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, uint64(1), timer.RetryCount())
|
||||
timer.MarkActive()
|
||||
assert.Equal(t, uint64(0), timer.RetryCount())
|
||||
}
|
|
@ -0,0 +1,250 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type MuxedStream struct {
|
||||
Headers []Header
|
||||
|
||||
streamID uint32
|
||||
|
||||
responseHeadersReceived chan struct{}
|
||||
|
||||
readBuffer *SharedBuffer
|
||||
receiveWindow uint32
|
||||
// current window size limit. Exponentially increase it when it's exhausted
|
||||
receiveWindowCurrentMax uint32
|
||||
// limit set in http2 spec. 2^31-1
|
||||
receiveWindowMax uint32
|
||||
// nonzero if a WINDOW_UPDATE frame for a stream needs to be sent
|
||||
windowUpdate uint32
|
||||
|
||||
writeLock sync.Mutex
|
||||
// The zero value for Buffer is an empty buffer ready to use.
|
||||
writeBuffer bytes.Buffer
|
||||
|
||||
sendWindow uint32
|
||||
|
||||
readyList *ReadyList
|
||||
headersSent bool
|
||||
writeHeaders []Header
|
||||
// true if the write end of this stream has been closed
|
||||
writeEOF bool
|
||||
// true if we have sent EOF to the peer
|
||||
sentEOF bool
|
||||
// true if the peer sent us an EOF
|
||||
receivedEOF bool
|
||||
}
|
||||
|
||||
type flowControlWindow struct {
|
||||
receiveWindow, sendWindow uint32
|
||||
}
|
||||
|
||||
func (s *MuxedStream) Read(p []byte) (n int, err error) {
|
||||
return s.readBuffer.Read(p)
|
||||
}
|
||||
|
||||
func (s *MuxedStream) Write(p []byte) (n int, err error) {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
if s.writeEOF {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n, err = s.writeBuffer.Write(p)
|
||||
if n != len(p) || err != nil {
|
||||
return n, err
|
||||
}
|
||||
s.writeNotify()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *MuxedStream) Close() error {
|
||||
// TUN-115: Close the write buffer before the read buffer.
|
||||
// In the case of shutdown, read will not get new data, but the write buffer can still receive
|
||||
// new data. Closing read before write allows application to race between a failed read and a
|
||||
// successful write, even though this close should appear to be atomic.
|
||||
// This can't happen the other way because reads may succeed after a failed write; if we read
|
||||
// past EOF the application will block until we close the buffer.
|
||||
err := s.CloseWrite()
|
||||
if err != nil {
|
||||
if s.CloseRead() == nil {
|
||||
// don't bother the caller with errors if at least one close succeeded
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return s.CloseRead()
|
||||
}
|
||||
|
||||
func (s *MuxedStream) CloseRead() error {
|
||||
return s.readBuffer.Close()
|
||||
}
|
||||
|
||||
func (s *MuxedStream) CloseWrite() error {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
if s.writeEOF {
|
||||
return io.EOF
|
||||
}
|
||||
s.writeEOF = true
|
||||
s.writeNotify()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MuxedStream) WriteHeaders(headers []Header) error {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
if s.writeHeaders != nil {
|
||||
return ErrStreamHeadersSent
|
||||
}
|
||||
s.writeHeaders = headers
|
||||
s.writeNotify()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MuxedStream) FlowControlWindow() *flowControlWindow {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
return &flowControlWindow{
|
||||
receiveWindow: s.receiveWindow,
|
||||
sendWindow: s.sendWindow,
|
||||
}
|
||||
}
|
||||
|
||||
// writeNotify must happen while holding writeLock.
|
||||
func (s *MuxedStream) writeNotify() {
|
||||
s.readyList.Signal(s.streamID)
|
||||
}
|
||||
|
||||
// Call by muxreader when it gets a WindowUpdateFrame. This is an update of the peer's
|
||||
// receive window (how much data we can send).
|
||||
func (s *MuxedStream) replenishSendWindow(bytes uint32) {
|
||||
s.writeLock.Lock()
|
||||
s.sendWindow += bytes
|
||||
s.writeNotify()
|
||||
s.writeLock.Unlock()
|
||||
}
|
||||
|
||||
// Call by muxreader when it receives a data frame
|
||||
func (s *MuxedStream) consumeReceiveWindow(bytes uint32) bool {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
// received data size is greater than receive window/buffer
|
||||
if s.receiveWindow < bytes {
|
||||
return false
|
||||
}
|
||||
s.receiveWindow -= bytes
|
||||
if s.receiveWindow < s.receiveWindowCurrentMax/2 {
|
||||
// exhausting client send window (how much data client can send)
|
||||
if s.receiveWindowCurrentMax < s.receiveWindowMax {
|
||||
s.receiveWindowCurrentMax <<= 1
|
||||
}
|
||||
s.windowUpdate += s.receiveWindowCurrentMax - s.receiveWindow
|
||||
s.writeNotify()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// receiveEOF should be called when the peer indicates no more data will be sent.
|
||||
// Returns true if the socket is now closed (i.e. the write side is already closed).
|
||||
func (s *MuxedStream) receiveEOF() (closed bool) {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
s.receivedEOF = true
|
||||
s.CloseRead()
|
||||
return s.writeEOF && s.writeBuffer.Len() == 0
|
||||
}
|
||||
|
||||
func (s *MuxedStream) gotReceiveEOF() bool {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
return s.receivedEOF
|
||||
}
|
||||
|
||||
// MuxedStreamReader implements io.ReadCloser for the read end of the stream.
|
||||
// This is useful for passing to functions that close the object after it is done reading,
|
||||
// but you still want to be able to write data afterwards (e.g. http.Client).
|
||||
type MuxedStreamReader struct {
|
||||
*MuxedStream
|
||||
}
|
||||
|
||||
func (s MuxedStreamReader) Read(p []byte) (n int, err error) {
|
||||
return s.MuxedStream.Read(p)
|
||||
}
|
||||
|
||||
func (s MuxedStreamReader) Close() error {
|
||||
return s.MuxedStream.CloseRead()
|
||||
}
|
||||
|
||||
// streamChunk represents a chunk of data to be written.
|
||||
type streamChunk struct {
|
||||
streamID uint32
|
||||
// true if a HEADERS frame should be sent
|
||||
sendHeaders bool
|
||||
headers []Header
|
||||
// nonzero if a WINDOW_UPDATE frame should be sent
|
||||
windowUpdate uint32
|
||||
// true if data frames should be sent
|
||||
sendData bool
|
||||
eof bool
|
||||
buffer bytes.Buffer
|
||||
}
|
||||
|
||||
// getChunk atomically extracts a chunk of data to be written by MuxWriter.
|
||||
// The data returned will not exceed the send window for this stream.
|
||||
func (s *MuxedStream) getChunk() *streamChunk {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
|
||||
chunk := &streamChunk{
|
||||
streamID: s.streamID,
|
||||
sendHeaders: !s.headersSent,
|
||||
headers: s.writeHeaders,
|
||||
windowUpdate: s.windowUpdate,
|
||||
sendData: !s.sentEOF,
|
||||
eof: s.writeEOF && uint32(s.writeBuffer.Len()) <= s.sendWindow,
|
||||
}
|
||||
|
||||
// Copies at most s.sendWindow bytes
|
||||
//log.Infof("writeBuffer len %d stream %d", s.writeBuffer.Len(), s.streamID)
|
||||
writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow))
|
||||
//log.Infof("writeLen %d stream %d", writeLen, s.streamID)
|
||||
s.sendWindow -= uint32(writeLen)
|
||||
s.receiveWindow += s.windowUpdate
|
||||
s.windowUpdate = 0
|
||||
s.headersSent = true
|
||||
|
||||
// if this chunk contains the end of the stream, close the stream now
|
||||
if chunk.sendData && chunk.eof {
|
||||
s.sentEOF = true
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
func (c *streamChunk) sendHeadersFrame() bool {
|
||||
return c.sendHeaders
|
||||
}
|
||||
|
||||
func (c *streamChunk) sendWindowUpdateFrame() bool {
|
||||
return c.windowUpdate > 0
|
||||
}
|
||||
|
||||
func (c *streamChunk) sendDataFrame() bool {
|
||||
return c.sendData
|
||||
}
|
||||
|
||||
func (c *streamChunk) nextDataFrame(frameSize int) (payload []byte, endStream bool) {
|
||||
payload = c.buffer.Next(frameSize)
|
||||
if c.buffer.Len() == 0 {
|
||||
// this is the last data frame in this chunk
|
||||
c.sendData = false
|
||||
if c.eof {
|
||||
endStream = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const testWindowSize uint32 = 65535
|
||||
const testMaxWindowSize uint32 = testWindowSize << 2
|
||||
|
||||
// Only sending WINDOW_UPDATE frame, so sendWindow should never change
|
||||
func TestFlowControlSingleStream(t *testing.T) {
|
||||
stream := &MuxedStream{
|
||||
responseHeadersReceived: make(chan struct{}),
|
||||
readBuffer: NewSharedBuffer(),
|
||||
receiveWindow: testWindowSize,
|
||||
receiveWindowCurrentMax: testWindowSize,
|
||||
receiveWindowMax: testMaxWindowSize,
|
||||
sendWindow: testWindowSize,
|
||||
readyList: NewReadyList(),
|
||||
}
|
||||
assert.True(t, stream.consumeReceiveWindow(testWindowSize/2))
|
||||
dataSent := testWindowSize / 2
|
||||
assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow)
|
||||
assert.Equal(t, testWindowSize, stream.receiveWindowCurrentMax)
|
||||
assert.Equal(t, uint32(0), stream.windowUpdate)
|
||||
tempWindowUpdate := stream.windowUpdate
|
||||
|
||||
streamChunk := stream.getChunk()
|
||||
assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate)
|
||||
assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow)
|
||||
assert.Equal(t, uint32(0), stream.windowUpdate)
|
||||
assert.Equal(t, testWindowSize, stream.sendWindow)
|
||||
|
||||
assert.True(t, stream.consumeReceiveWindow(2))
|
||||
dataSent += 2
|
||||
assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow)
|
||||
assert.Equal(t, testWindowSize<<1, stream.receiveWindowCurrentMax)
|
||||
assert.Equal(t, (testWindowSize<<1)-stream.receiveWindow, stream.windowUpdate)
|
||||
tempWindowUpdate = stream.windowUpdate
|
||||
|
||||
streamChunk = stream.getChunk()
|
||||
assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate)
|
||||
assert.Equal(t, testWindowSize<<1, stream.receiveWindow)
|
||||
assert.Equal(t, uint32(0), stream.windowUpdate)
|
||||
assert.Equal(t, testWindowSize, stream.sendWindow)
|
||||
|
||||
assert.True(t, stream.consumeReceiveWindow(testWindowSize+10))
|
||||
dataSent = testWindowSize + 10
|
||||
assert.Equal(t, (testWindowSize<<1)-dataSent, stream.receiveWindow)
|
||||
assert.Equal(t, testWindowSize<<2, stream.receiveWindowCurrentMax)
|
||||
assert.Equal(t, (testWindowSize<<2)-stream.receiveWindow, stream.windowUpdate)
|
||||
tempWindowUpdate = stream.windowUpdate
|
||||
|
||||
streamChunk = stream.getChunk()
|
||||
assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate)
|
||||
assert.Equal(t, testWindowSize<<2, stream.receiveWindow)
|
||||
assert.Equal(t, uint32(0), stream.windowUpdate)
|
||||
assert.Equal(t, testWindowSize, stream.sendWindow)
|
||||
|
||||
assert.False(t, stream.consumeReceiveWindow(testMaxWindowSize+1))
|
||||
assert.Equal(t, testWindowSize<<2, stream.receiveWindow)
|
||||
assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax)
|
||||
}
|
|
@ -0,0 +1,326 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type MuxReader struct {
|
||||
// f is used to read HTTP2 frames.
|
||||
f *http2.Framer
|
||||
// handler provides a callback to receive new streams. if nil, new streams cannot be accepted.
|
||||
handler MuxedStreamHandler
|
||||
// streams tracks currently-open streams.
|
||||
streams *activeStreamMap
|
||||
// readyList is used to signal writable streams.
|
||||
readyList *ReadyList
|
||||
// streamErrors lets us report stream errors to the MuxWriter.
|
||||
streamErrors *StreamErrorMap
|
||||
// goAwayChan is used to tell the writer to send a GOAWAY message.
|
||||
goAwayChan chan<- http2.ErrCode
|
||||
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
|
||||
abortChan <-chan struct{}
|
||||
// pingTimestamp is an atomic value containing the latest received ping timestamp.
|
||||
pingTimestamp *PingTimestamp
|
||||
// connActive is used to signal to the writer that something happened on the connection.
|
||||
// This is used to clear idle timeout disconnection deadlines.
|
||||
connActive Signal
|
||||
// The initial value for the send and receive window of a new stream.
|
||||
initialStreamWindow uint32
|
||||
// The max value for the send window of a stream.
|
||||
streamWindowMax uint32
|
||||
// windowMetrics keeps track of min/max/average of send/receive windows for all streams
|
||||
flowControlMetrics *FlowControlMetrics
|
||||
metricsMutex sync.Mutex
|
||||
// r is a reference to the underlying connection used when shutting down.
|
||||
r io.Closer
|
||||
// rttMeasurement measures RTT based on ping timestamps.
|
||||
rttMeasurement RTTMeasurement
|
||||
rttMutex sync.Mutex
|
||||
}
|
||||
|
||||
func (r *MuxReader) Shutdown() {
|
||||
done := r.streams.Shutdown()
|
||||
if done == nil {
|
||||
return
|
||||
}
|
||||
r.sendGoAway(http2.ErrCodeNo)
|
||||
go func() {
|
||||
// close reader side when last stream ends; this will cause the writer to abort
|
||||
<-done
|
||||
r.r.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *MuxReader) RTT() RTTMeasurement {
|
||||
r.rttMutex.Lock()
|
||||
defer r.rttMutex.Unlock()
|
||||
return r.rttMeasurement
|
||||
}
|
||||
|
||||
func (r *MuxReader) FlowControlMetrics() *FlowControlMetrics {
|
||||
r.metricsMutex.Lock()
|
||||
defer r.metricsMutex.Unlock()
|
||||
if r.flowControlMetrics != nil {
|
||||
return r.flowControlMetrics
|
||||
}
|
||||
// No metrics available yet
|
||||
return &FlowControlMetrics{}
|
||||
}
|
||||
|
||||
func (r *MuxReader) run(parentLogger *log.Entry) error {
|
||||
logger := parentLogger.WithFields(log.Fields{
|
||||
"subsystem": "mux",
|
||||
"dir": "read",
|
||||
})
|
||||
defer logger.Debug("event loop finished")
|
||||
for {
|
||||
frame, err := r.f.ReadFrame()
|
||||
if err != nil {
|
||||
switch e := err.(type) {
|
||||
case http2.StreamError:
|
||||
logger.WithError(err).Warn("stream error")
|
||||
r.streamError(e.StreamID, e.Code)
|
||||
case http2.ConnectionError:
|
||||
logger.WithError(err).Warn("connection error")
|
||||
return r.connectionError(err)
|
||||
default:
|
||||
if isConnectionClosedError(err) {
|
||||
if r.streams.Len() == 0 {
|
||||
logger.Debug("shutting down")
|
||||
return nil
|
||||
}
|
||||
logger.Warn("connection closed unexpectedly")
|
||||
return err
|
||||
} else {
|
||||
logger.WithError(err).Warn("frame read error")
|
||||
return r.connectionError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
r.connActive.Signal()
|
||||
logger.WithField("data", frame).Debug("read frame")
|
||||
switch f := frame.(type) {
|
||||
case *http2.DataFrame:
|
||||
err = r.receiveFrameData(f, logger)
|
||||
case *http2.MetaHeadersFrame:
|
||||
err = r.receiveHeaderData(f)
|
||||
case *http2.RSTStreamFrame:
|
||||
streamID := f.Header().StreamID
|
||||
if streamID == 0 {
|
||||
return ErrInvalidStream
|
||||
}
|
||||
r.streams.Delete(streamID)
|
||||
case *http2.PingFrame:
|
||||
r.receivePingData(f)
|
||||
case *http2.GoAwayFrame:
|
||||
err = r.receiveGoAway(f)
|
||||
case *http2.WindowUpdateFrame:
|
||||
err = r.updateStreamWindow(f)
|
||||
default:
|
||||
err = ErrUnexpectedFrameType
|
||||
}
|
||||
if err != nil {
|
||||
logger.WithField("data", frame).WithError(err).Debug("frame error")
|
||||
return r.connectionError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MuxReader) newMuxedStream(streamID uint32) *MuxedStream {
|
||||
return &MuxedStream{
|
||||
streamID: streamID,
|
||||
readBuffer: NewSharedBuffer(),
|
||||
receiveWindow: r.initialStreamWindow,
|
||||
receiveWindowCurrentMax: r.initialStreamWindow,
|
||||
receiveWindowMax: r.streamWindowMax,
|
||||
sendWindow: r.initialStreamWindow,
|
||||
readyList: r.readyList,
|
||||
}
|
||||
}
|
||||
|
||||
// getStreamForFrame returns a stream if valid, or an error describing why the stream could not be returned.
|
||||
func (r *MuxReader) getStreamForFrame(frame http2.Frame) (*MuxedStream, error) {
|
||||
sid := frame.Header().StreamID
|
||||
if sid == 0 {
|
||||
return nil, ErrUnexpectedFrameType
|
||||
}
|
||||
if stream, ok := r.streams.Get(sid); ok {
|
||||
return stream, nil
|
||||
}
|
||||
if r.streams.IsLocalStreamID(sid) {
|
||||
// no stream available, but no error
|
||||
return nil, ErrClosedStream
|
||||
}
|
||||
if sid < r.streams.LastPeerStreamID() {
|
||||
// no stream available, stream closed error
|
||||
return nil, ErrClosedStream
|
||||
}
|
||||
return nil, ErrUnknownStream
|
||||
}
|
||||
|
||||
func (r *MuxReader) defaultStreamErrorHandler(err error, header http2.FrameHeader) error {
|
||||
if header.Flags.Has(http2.FlagHeadersEndStream) {
|
||||
return nil
|
||||
} else if err == ErrUnknownStream || err == ErrClosedStream {
|
||||
return r.streamError(header.StreamID, http2.ErrCodeStreamClosed)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Receives header frames from a stream. A non-nil error is a connection error.
|
||||
func (r *MuxReader) receiveHeaderData(frame *http2.MetaHeadersFrame) error {
|
||||
var stream *MuxedStream
|
||||
sid := frame.Header().StreamID
|
||||
if sid == 0 {
|
||||
return ErrUnexpectedFrameType
|
||||
}
|
||||
newStream := r.streams.IsPeerStreamID(sid)
|
||||
if newStream {
|
||||
// header request
|
||||
// TODO support trailers (if stream exists)
|
||||
ok, err := r.streams.AcquirePeerID(sid)
|
||||
if !ok {
|
||||
// ignore new streams while shutting down
|
||||
return r.streamError(sid, err)
|
||||
}
|
||||
stream = r.newMuxedStream(sid)
|
||||
// Set stream. Returns false if a stream already existed with that ID or we are shutting down, return false.
|
||||
if !r.streams.Set(stream) {
|
||||
// got HEADERS frame for an existing stream
|
||||
// TODO support trailers
|
||||
return r.streamError(sid, http2.ErrCodeInternal)
|
||||
}
|
||||
} else {
|
||||
// header response
|
||||
var err error
|
||||
if stream, err = r.getStreamForFrame(frame); err != nil {
|
||||
return r.defaultStreamErrorHandler(err, frame.Header())
|
||||
}
|
||||
}
|
||||
headers := make([]Header, len(frame.Fields))
|
||||
for i, header := range frame.Fields {
|
||||
headers[i].Name = header.Name
|
||||
headers[i].Value = header.Value
|
||||
}
|
||||
stream.Headers = headers
|
||||
if frame.Header().Flags.Has(http2.FlagHeadersEndStream) {
|
||||
stream.receiveEOF()
|
||||
return nil
|
||||
}
|
||||
if newStream {
|
||||
go r.handleStream(stream)
|
||||
} else {
|
||||
close(stream.responseHeadersReceived)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MuxReader) handleStream(stream *MuxedStream) {
|
||||
defer stream.Close()
|
||||
r.handler.ServeStream(stream)
|
||||
}
|
||||
|
||||
// Receives a data frame from a stream. A non-nil error is a connection error.
|
||||
func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.Entry) error {
|
||||
logger := parentLogger.WithField("stream", frame.Header().StreamID)
|
||||
stream, err := r.getStreamForFrame(frame)
|
||||
if err != nil {
|
||||
return r.defaultStreamErrorHandler(err, frame.Header())
|
||||
}
|
||||
data := frame.Data()
|
||||
if len(data) > 0 {
|
||||
_, err = stream.readBuffer.Write(data)
|
||||
if err != nil {
|
||||
return r.streamError(stream.streamID, http2.ErrCodeInternal)
|
||||
}
|
||||
}
|
||||
if frame.Header().Flags.Has(http2.FlagDataEndStream) {
|
||||
if stream.receiveEOF() {
|
||||
r.streams.Delete(stream.streamID)
|
||||
logger.Debug("stream closed")
|
||||
} else {
|
||||
logger.Debug("shutdown receive side")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if !stream.consumeReceiveWindow(uint32(len(data))) {
|
||||
return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Receive a PING from the peer. Update RTT and send/receive window metrics if it's an ACK.
|
||||
func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
|
||||
ts := int64(binary.LittleEndian.Uint64(frame.Data[:]))
|
||||
if !frame.IsAck() {
|
||||
r.pingTimestamp.Set(ts)
|
||||
return
|
||||
}
|
||||
r.rttMutex.Lock()
|
||||
r.rttMeasurement.Update(time.Unix(0, ts))
|
||||
r.rttMutex.Unlock()
|
||||
r.flowControlMetrics = r.streams.Metrics()
|
||||
}
|
||||
|
||||
// Receive a GOAWAY from the peer. Gracefully shut down our connection.
|
||||
func (r *MuxReader) receiveGoAway(frame *http2.GoAwayFrame) error {
|
||||
r.Shutdown()
|
||||
// Close all streams above the last processed stream
|
||||
lastStream := r.streams.LastLocalStreamID()
|
||||
for i := frame.LastStreamID + 2; i <= lastStream; i++ {
|
||||
if stream, ok := r.streams.Get(i); ok {
|
||||
stream.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Receives header frames from a stream. A non-nil error is a connection error.
|
||||
func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
|
||||
stream, err := r.getStreamForFrame(frame)
|
||||
if err != nil && err != ErrUnknownStream && err != ErrClosedStream {
|
||||
return err
|
||||
}
|
||||
if stream == nil {
|
||||
// ignore window updates on closed streams
|
||||
return nil
|
||||
}
|
||||
stream.replenishSendWindow(frame.Increment)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Raise a stream processing error, closing the stream. Runs on the write thread.
|
||||
func (r *MuxReader) streamError(streamID uint32, e http2.ErrCode) error {
|
||||
r.streamErrors.RaiseError(streamID, e)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MuxReader) connectionError(err error) error {
|
||||
http2Code := http2.ErrCodeInternal
|
||||
switch e := err.(type) {
|
||||
case http2.ConnectionError:
|
||||
http2Code = http2.ErrCode(e)
|
||||
case MuxerProtocolError:
|
||||
http2Code = e.h2code
|
||||
}
|
||||
log.Warnf("Connection error %v", http2Code)
|
||||
r.sendGoAway(http2Code)
|
||||
return err
|
||||
}
|
||||
|
||||
// Instruct the writer to send a GOAWAY message if possible. This may fail in
|
||||
// the case where an existing GOAWAY message is in flight or the writer event
|
||||
// loop already ended.
|
||||
func (r *MuxReader) sendGoAway(errCode http2.ErrCode) {
|
||||
select {
|
||||
case r.goAwayChan <- errCode:
|
||||
default:
|
||||
}
|
||||
}
|
|
@ -0,0 +1,238 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
type MuxWriter struct {
|
||||
// f is used to write HTTP2 frames.
|
||||
f *http2.Framer
|
||||
// streams tracks currently-open streams.
|
||||
streams *activeStreamMap
|
||||
// streamErrors receives stream errors raised by the MuxReader.
|
||||
streamErrors *StreamErrorMap
|
||||
// readyStreamChan is used to multiplex writable streams onto the single connection.
|
||||
// When a stream becomes writable its ID is sent on this channel.
|
||||
readyStreamChan <-chan uint32
|
||||
// newStreamChan is used to create new streams with a given set of headers.
|
||||
newStreamChan <-chan MuxedStreamRequest
|
||||
// goAwayChan is used to send a single GOAWAY message to the peer. The element received
|
||||
// is the HTTP/2 error code to send.
|
||||
goAwayChan <-chan http2.ErrCode
|
||||
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
|
||||
abortChan <-chan struct{}
|
||||
// pingTimestamp is an atomic value containing the latest received ping timestamp.
|
||||
pingTimestamp *PingTimestamp
|
||||
// A timer used to measure idle connection time. Reset after sending data.
|
||||
idleTimer *IdleTimer
|
||||
// connActiveChan receives a signal that the connection received some (read) activity.
|
||||
connActiveChan <-chan struct{}
|
||||
// Maximum size of all frames that can be sent on this connection.
|
||||
maxFrameSize uint32
|
||||
// headerEncoder is the stateful header encoder for this connection
|
||||
headerEncoder *hpack.Encoder
|
||||
// headerBuffer is the temporary buffer used by headerEncoder.
|
||||
headerBuffer bytes.Buffer
|
||||
}
|
||||
|
||||
type MuxedStreamRequest struct {
|
||||
stream *MuxedStream
|
||||
body io.Reader
|
||||
}
|
||||
|
||||
func (r *MuxedStreamRequest) flushBody() {
|
||||
io.Copy(r.stream, r.body)
|
||||
r.stream.CloseWrite()
|
||||
}
|
||||
|
||||
func tsToPingData(ts int64) [8]byte {
|
||||
pingData := [8]byte{}
|
||||
binary.LittleEndian.PutUint64(pingData[:], uint64(ts))
|
||||
return pingData
|
||||
}
|
||||
|
||||
func (w *MuxWriter) run(parentLogger *log.Entry) error {
|
||||
logger := parentLogger.WithFields(log.Fields{
|
||||
"subsystem": "mux",
|
||||
"dir": "write",
|
||||
})
|
||||
defer logger.Debug("event loop finished")
|
||||
for {
|
||||
select {
|
||||
case <-w.abortChan:
|
||||
logger.Debug("aborting writer thread")
|
||||
return nil
|
||||
case errCode := <-w.goAwayChan:
|
||||
logger.Debug("sending GOAWAY code ", errCode)
|
||||
err := w.f.WriteGoAway(w.streams.LastPeerStreamID(), errCode, []byte{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case <-w.pingTimestamp.GetUpdateChan():
|
||||
logger.Debug("sending PING ACK")
|
||||
err := w.f.WritePing(true, tsToPingData(w.pingTimestamp.Get()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case <-w.idleTimer.C:
|
||||
if !w.idleTimer.Retry() {
|
||||
return ErrConnectionDropped
|
||||
}
|
||||
logger.Debug("sending PING")
|
||||
err := w.f.WritePing(false, tsToPingData(time.Now().UnixNano()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.ResetTimer()
|
||||
case <-w.connActiveChan:
|
||||
w.idleTimer.MarkActive()
|
||||
case <-w.streamErrors.GetSignalChan():
|
||||
for streamID, errCode := range w.streamErrors.GetErrors() {
|
||||
logger.WithField("stream", streamID).WithField("code", errCode).Debug("resetting stream")
|
||||
err := w.f.WriteRSTStream(streamID, errCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case streamRequest := <-w.newStreamChan:
|
||||
streamID := w.streams.AcquireLocalID()
|
||||
streamRequest.stream.streamID = streamID
|
||||
if !w.streams.Set(streamRequest.stream) {
|
||||
// Race between OpenStream and Shutdown, and Shutdown won. Let Shutdown (and the eventual abort) take
|
||||
// care of this stream. Ideally we'd pass the error directly to the stream object somehow so the
|
||||
// caller can be unblocked sooner, but the value of that optimisation is minimal for most of the
|
||||
// reasons why you'd call Shutdown anyway.
|
||||
continue
|
||||
}
|
||||
if streamRequest.body != nil {
|
||||
go streamRequest.flushBody()
|
||||
}
|
||||
streamLogger := logger.WithField("stream", streamID)
|
||||
err := w.writeStreamData(streamRequest.stream, streamLogger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case streamID := <-w.readyStreamChan:
|
||||
streamLogger := logger.WithField("stream", streamID)
|
||||
stream, ok := w.streams.Get(streamID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
err := w.writeStreamData(stream, streamLogger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
|
||||
logger.Debug("writable")
|
||||
chunk := stream.getChunk()
|
||||
|
||||
if chunk.sendHeadersFrame() {
|
||||
err := w.writeHeaders(chunk.streamID, chunk.headers)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("error writing headers")
|
||||
return err
|
||||
}
|
||||
logger.Debug("output headers")
|
||||
}
|
||||
|
||||
if chunk.sendWindowUpdateFrame() {
|
||||
// Send a WINDOW_UPDATE frame to update our receive window.
|
||||
// If the Stream ID is zero, the window update applies to the connection as a whole
|
||||
// A WINDOW_UPDATE in a specific stream applies to the connection-level flow control as well.
|
||||
err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("error writing window update")
|
||||
return err
|
||||
}
|
||||
logger.Debugf("increment receive window by %d", chunk.windowUpdate)
|
||||
}
|
||||
|
||||
for chunk.sendDataFrame() {
|
||||
payload, sentEOF := chunk.nextDataFrame(int(w.maxFrameSize))
|
||||
err := w.f.WriteData(chunk.streamID, sentEOF, payload)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("error writing data")
|
||||
return err
|
||||
}
|
||||
logger.WithField("len", len(payload)).Debug("output data")
|
||||
|
||||
if sentEOF {
|
||||
if stream.readBuffer.Closed() {
|
||||
// transition into closed state
|
||||
if !stream.gotReceiveEOF() {
|
||||
// the peer may send data that we no longer want to receive. Force them into the
|
||||
// closed state.
|
||||
logger.Debug("resetting stream")
|
||||
w.f.WriteRSTStream(chunk.streamID, http2.ErrCodeNo)
|
||||
} else {
|
||||
// Half-open stream transitioned into closed
|
||||
logger.Debug("closing stream")
|
||||
}
|
||||
w.streams.Delete(chunk.streamID)
|
||||
} else {
|
||||
logger.Debug("closing stream write side")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *MuxWriter) encodeHeaders(headers []Header) ([]byte, error) {
|
||||
w.headerBuffer.Reset()
|
||||
for _, header := range headers {
|
||||
err := w.headerEncoder.WriteField(hpack.HeaderField{
|
||||
Name: header.Name,
|
||||
Value: header.Value,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return w.headerBuffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// writeHeaders writes a block of encoded headers, splitting it into multiple frames if necessary.
|
||||
func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
|
||||
encodedHeaders, err := w.encodeHeaders(headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
blockSize := int(w.maxFrameSize)
|
||||
continuation := false
|
||||
endHeaders := len(encodedHeaders) == 0
|
||||
for !endHeaders && err == nil {
|
||||
blockFragment := encodedHeaders
|
||||
if len(encodedHeaders) > blockSize {
|
||||
blockFragment = blockFragment[:blockSize]
|
||||
encodedHeaders = encodedHeaders[blockSize:]
|
||||
} else {
|
||||
endHeaders = true
|
||||
}
|
||||
if continuation {
|
||||
err = w.f.WriteContinuation(streamID, endHeaders, blockFragment)
|
||||
} else {
|
||||
err = w.f.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: streamID,
|
||||
EndHeaders: endHeaders,
|
||||
BlockFragment: blockFragment,
|
||||
})
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,140 @@
|
|||
package h2mux
|
||||
|
||||
// ReadyList multiplexes several event signals onto a single channel.
|
||||
type ReadyList struct {
|
||||
signalC chan uint32
|
||||
waitC chan uint32
|
||||
}
|
||||
|
||||
func NewReadyList() *ReadyList {
|
||||
rl := &ReadyList{
|
||||
signalC: make(chan uint32),
|
||||
waitC: make(chan uint32),
|
||||
}
|
||||
go rl.run()
|
||||
return rl
|
||||
}
|
||||
|
||||
// ID is the stream ID
|
||||
func (r *ReadyList) Signal(ID uint32) {
|
||||
r.signalC <- ID
|
||||
}
|
||||
|
||||
func (r *ReadyList) ReadyChannel() <-chan uint32 {
|
||||
return r.waitC
|
||||
}
|
||||
|
||||
func (r *ReadyList) Close() {
|
||||
close(r.signalC)
|
||||
}
|
||||
|
||||
func (r *ReadyList) run() {
|
||||
defer close(r.waitC)
|
||||
var queue readyDescriptorQueue
|
||||
var firstReady *readyDescriptor
|
||||
activeDescriptors := newReadyDescriptorMap()
|
||||
for {
|
||||
if firstReady == nil {
|
||||
// Wait for first ready descriptor
|
||||
i, ok := <-r.signalC
|
||||
if !ok {
|
||||
// closed
|
||||
return
|
||||
}
|
||||
firstReady = activeDescriptors.SetIfMissing(i)
|
||||
}
|
||||
select {
|
||||
case r.waitC <- firstReady.ID:
|
||||
activeDescriptors.Delete(firstReady.ID)
|
||||
firstReady = queue.Dequeue()
|
||||
case i, ok := <-r.signalC:
|
||||
if !ok {
|
||||
// closed
|
||||
return
|
||||
}
|
||||
newReady := activeDescriptors.SetIfMissing(i)
|
||||
if newReady != nil {
|
||||
// key doesn't exist
|
||||
queue.Enqueue(newReady)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type readyDescriptor struct {
|
||||
ID uint32
|
||||
Next *readyDescriptor
|
||||
}
|
||||
|
||||
// readyDescriptorQueue is a queue of readyDescriptors in the form of a singly-linked list.
|
||||
// The nil readyDescriptorQueue is an empty queue ready for use.
|
||||
type readyDescriptorQueue struct {
|
||||
Head *readyDescriptor
|
||||
Tail *readyDescriptor
|
||||
}
|
||||
|
||||
func (q *readyDescriptorQueue) Empty() bool {
|
||||
return q.Head == nil
|
||||
}
|
||||
|
||||
func (q *readyDescriptorQueue) Enqueue(x *readyDescriptor) {
|
||||
if x.Next != nil {
|
||||
panic("enqueued already queued item")
|
||||
}
|
||||
if q.Empty() {
|
||||
q.Head = x
|
||||
q.Tail = x
|
||||
} else {
|
||||
q.Tail.Next = x
|
||||
q.Tail = x
|
||||
}
|
||||
}
|
||||
|
||||
// Dequeue returns the first readyDescriptor in the queue, or nil if empty.
|
||||
func (q *readyDescriptorQueue) Dequeue() *readyDescriptor {
|
||||
if q.Empty() {
|
||||
return nil
|
||||
}
|
||||
x := q.Head
|
||||
q.Head = x.Next
|
||||
x.Next = nil
|
||||
return x
|
||||
}
|
||||
|
||||
// readyDescriptorQueue is a map of readyDescriptors keyed by ID.
|
||||
// It maintains a free list of deleted ready descriptors.
|
||||
type readyDescriptorMap struct {
|
||||
descriptors map[uint32]*readyDescriptor
|
||||
free []*readyDescriptor
|
||||
}
|
||||
|
||||
func newReadyDescriptorMap() *readyDescriptorMap {
|
||||
return &readyDescriptorMap{descriptors: make(map[uint32]*readyDescriptor)}
|
||||
}
|
||||
|
||||
// create or reuse a readyDescriptor if the stream is not in the queue.
|
||||
// This avoid stream starvation caused by a single high-bandwidth stream monopolising the writer goroutine
|
||||
func (m *readyDescriptorMap) SetIfMissing(key uint32) *readyDescriptor {
|
||||
if _, ok := m.descriptors[key]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var newDescriptor *readyDescriptor
|
||||
if len(m.free) > 0 {
|
||||
// reuse deleted ready descriptors
|
||||
newDescriptor = m.free[len(m.free)-1]
|
||||
m.free = m.free[:len(m.free)-1]
|
||||
} else {
|
||||
newDescriptor = &readyDescriptor{}
|
||||
}
|
||||
newDescriptor.ID = key
|
||||
m.descriptors[key] = newDescriptor
|
||||
return newDescriptor
|
||||
}
|
||||
|
||||
func (m *readyDescriptorMap) Delete(key uint32) {
|
||||
if descriptor, ok := m.descriptors[key]; ok {
|
||||
m.free = append(m.free, descriptor)
|
||||
delete(m.descriptors, key)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestReadyList(t *testing.T) {
|
||||
rl := NewReadyList()
|
||||
c := rl.ReadyChannel()
|
||||
// helper functions
|
||||
assertEmpty := func() {
|
||||
select {
|
||||
case <-c:
|
||||
t.Fatalf("Spurious wakeup")
|
||||
default:
|
||||
}
|
||||
}
|
||||
receiveWithTimeout := func() uint32 {
|
||||
select {
|
||||
case i := <-c:
|
||||
return i
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatalf("Timeout")
|
||||
return 0
|
||||
}
|
||||
}
|
||||
// no signals, receive should fail
|
||||
assertEmpty()
|
||||
rl.Signal(0)
|
||||
if receiveWithTimeout() != 0 {
|
||||
t.Fatalf("Received wrong ID of signalled event")
|
||||
}
|
||||
// no new signals, receive should fail
|
||||
assertEmpty()
|
||||
// Signals should not block;
|
||||
// Duplicate unhandled signals should not cause multiple wakeups
|
||||
signalled := [5]bool{}
|
||||
for i := range signalled {
|
||||
rl.Signal(uint32(i))
|
||||
rl.Signal(uint32(i))
|
||||
}
|
||||
// All signals should be received once (in any order)
|
||||
for range signalled {
|
||||
i := receiveWithTimeout()
|
||||
if signalled[i] {
|
||||
t.Fatalf("Received signal %d more than once", i)
|
||||
}
|
||||
signalled[i] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadyDescriptorQueue(t *testing.T) {
|
||||
var queue readyDescriptorQueue
|
||||
items := [4]readyDescriptor{}
|
||||
for i := range items {
|
||||
items[i].ID = uint32(i)
|
||||
}
|
||||
|
||||
if !queue.Empty() {
|
||||
t.Fatalf("nil queue should be empty")
|
||||
}
|
||||
queue.Enqueue(&items[3])
|
||||
queue.Enqueue(&items[1])
|
||||
queue.Enqueue(&items[0])
|
||||
queue.Enqueue(&items[2])
|
||||
if queue.Empty() {
|
||||
t.Fatalf("Empty should be false after enqueue")
|
||||
}
|
||||
i := queue.Dequeue().ID
|
||||
if i != 3 {
|
||||
t.Fatalf("item 3 should have been dequeued, got %d instead", i)
|
||||
}
|
||||
i = queue.Dequeue().ID
|
||||
if i != 1 {
|
||||
t.Fatalf("item 1 should have been dequeued, got %d instead", i)
|
||||
}
|
||||
i = queue.Dequeue().ID
|
||||
if i != 0 {
|
||||
t.Fatalf("item 0 should have been dequeued, got %d instead", i)
|
||||
}
|
||||
i = queue.Dequeue().ID
|
||||
if i != 2 {
|
||||
t.Fatalf("item 2 should have been dequeued, got %d instead", i)
|
||||
}
|
||||
if !queue.Empty() {
|
||||
t.Fatal("queue should be empty after dequeuing all items")
|
||||
}
|
||||
if queue.Dequeue() != nil {
|
||||
t.Fatal("dequeue on empty queue should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadyDescriptorMap(t *testing.T) {
|
||||
m := newReadyDescriptorMap()
|
||||
m.Delete(42)
|
||||
// (delete of missing key should be a noop)
|
||||
x := m.SetIfMissing(42)
|
||||
if x == nil {
|
||||
t.Fatal("SetIfMissing for new key returned nil")
|
||||
}
|
||||
if m.SetIfMissing(42) != nil {
|
||||
t.Fatal("SetIfMissing for existing key returned non-nil")
|
||||
}
|
||||
// this delete has effect
|
||||
m.Delete(42)
|
||||
// the next set should reuse the old object
|
||||
y := m.SetIfMissing(666)
|
||||
if y == nil {
|
||||
t.Fatal("SetIfMissing for new key returned nil")
|
||||
}
|
||||
if x != y {
|
||||
t.Fatal("SetIfMissing didn't reuse freed object")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PingTimestamp is an atomic interface around ping timestamping and signalling.
|
||||
type PingTimestamp struct {
|
||||
ts int64
|
||||
signal Signal
|
||||
}
|
||||
|
||||
func NewPingTimestamp() *PingTimestamp {
|
||||
return &PingTimestamp{signal: NewSignal()}
|
||||
}
|
||||
|
||||
func (pt *PingTimestamp) Set(v int64) {
|
||||
if atomic.SwapInt64(&pt.ts, v) != 0 {
|
||||
pt.signal.Signal()
|
||||
}
|
||||
}
|
||||
|
||||
func (pt *PingTimestamp) Get() int64 {
|
||||
return atomic.SwapInt64(&pt.ts, 0)
|
||||
}
|
||||
|
||||
func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} {
|
||||
return pt.signal.WaitChannel()
|
||||
}
|
||||
|
||||
// RTTMeasurement encapsulates a continuous round trip time measurement.
|
||||
type RTTMeasurement struct {
|
||||
Current, Min, Max time.Duration
|
||||
lastMeasurementTime time.Time
|
||||
}
|
||||
|
||||
// Update updates the computed values with a new measurement.
|
||||
// outgoingTime is the time that the probe was sent.
|
||||
// We assume that time.Now() is the time we received that probe.
|
||||
func (r *RTTMeasurement) Update(outgoingTime time.Time) {
|
||||
if !r.lastMeasurementTime.Before(outgoingTime) {
|
||||
return
|
||||
}
|
||||
r.lastMeasurementTime = outgoingTime
|
||||
r.Current = time.Since(outgoingTime)
|
||||
if r.Max < r.Current {
|
||||
r.Max = r.Current
|
||||
}
|
||||
if r.Min > r.Current {
|
||||
r.Min = r.Current
|
||||
}
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type SharedBuffer struct {
|
||||
cond *sync.Cond
|
||||
buffer bytes.Buffer
|
||||
eof bool
|
||||
}
|
||||
|
||||
func NewSharedBuffer() *SharedBuffer {
|
||||
return &SharedBuffer{
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SharedBuffer) Read(p []byte) (n int, err error) {
|
||||
totalRead := 0
|
||||
s.cond.L.Lock()
|
||||
for totalRead < len(p) {
|
||||
n, err = s.buffer.Read(p[totalRead:])
|
||||
totalRead += n
|
||||
if err == io.EOF {
|
||||
if s.eof {
|
||||
break
|
||||
}
|
||||
err = nil
|
||||
s.cond.Wait()
|
||||
}
|
||||
}
|
||||
s.cond.L.Unlock()
|
||||
return totalRead, err
|
||||
}
|
||||
|
||||
func (s *SharedBuffer) Write(p []byte) (n int, err error) {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
if s.eof {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n, err = s.buffer.Write(p)
|
||||
s.cond.Signal()
|
||||
return
|
||||
}
|
||||
|
||||
func (s *SharedBuffer) Close() error {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
if !s.eof {
|
||||
s.eof = true
|
||||
s.cond.Signal()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SharedBuffer) Closed() bool {
|
||||
s.cond.L.Lock()
|
||||
defer s.cond.L.Unlock()
|
||||
return s.eof
|
||||
}
|
|
@ -0,0 +1,120 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) {
|
||||
return func(actual int, err error) {
|
||||
if expected != actual {
|
||||
t.Fatalf("Expected %d bytes, got %d", expected, actual)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharedBuffer(t *testing.T) {
|
||||
b := NewSharedBuffer()
|
||||
testData := []byte("Hello world")
|
||||
AssertIOReturnIsGood(t, len(testData))(b.Write(testData))
|
||||
bytesRead := make([]byte, len(testData))
|
||||
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
|
||||
}
|
||||
|
||||
func TestSharedBufferBlockingRead(t *testing.T) {
|
||||
b := NewSharedBuffer()
|
||||
testData := []byte("Hello world")
|
||||
result := make(chan []byte)
|
||||
go func() {
|
||||
bytesRead := make([]byte, len(testData))
|
||||
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
|
||||
result <- bytesRead
|
||||
}()
|
||||
select {
|
||||
case <-result:
|
||||
t.Fatalf("read returned early")
|
||||
default:
|
||||
}
|
||||
AssertIOReturnIsGood(t, 5)(b.Write(testData[:5]))
|
||||
select {
|
||||
case <-result:
|
||||
t.Fatalf("read returned early")
|
||||
default:
|
||||
}
|
||||
AssertIOReturnIsGood(t, len(testData)-5)(b.Write(testData[5:]))
|
||||
select {
|
||||
case r := <-result:
|
||||
if string(r) != string(testData) {
|
||||
t.Fatalf("expected read to return %s, got %s", testData, r)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("read timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// This is quite slow under the race detector
|
||||
func TestSharedBufferConcurrentReadWrite(t *testing.T) {
|
||||
b := NewSharedBuffer()
|
||||
var expectedResult, actualResult bytes.Buffer
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
block := make([]byte, 256)
|
||||
for i := range block {
|
||||
block[i] = byte(i)
|
||||
}
|
||||
for blockSize := 1; blockSize <= 256; blockSize++ {
|
||||
for i := 0; i < 256; i++ {
|
||||
expectedResult.Write(block[:blockSize])
|
||||
n, err := b.Write(block[:blockSize])
|
||||
if n != blockSize || err != nil {
|
||||
t.Fatalf("write error: %d %s", n, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
block := make([]byte, 256)
|
||||
// Change block sizes in opposition to the write thread, to test blocking for new data.
|
||||
for blockSize := 256; blockSize > 0; blockSize-- {
|
||||
for i := 0; i < 256; i++ {
|
||||
n, err := b.Read(block[:blockSize])
|
||||
if n != blockSize || err != nil {
|
||||
t.Fatalf("read error: %d %s", n, err)
|
||||
}
|
||||
actualResult.Write(block[:blockSize])
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
if bytes.Compare(expectedResult.Bytes(), actualResult.Bytes()) != 0 {
|
||||
t.Fatal("Result diverged")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharedBufferClose(t *testing.T) {
|
||||
b := NewSharedBuffer()
|
||||
testData := []byte("Hello world")
|
||||
AssertIOReturnIsGood(t, len(testData))(b.Write(testData))
|
||||
err := b.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error from Close: %s", err)
|
||||
}
|
||||
bytesRead := make([]byte, len(testData))
|
||||
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
|
||||
n, err := b.Read(bytesRead)
|
||||
if n != 0 {
|
||||
t.Fatalf("extra bytes received: %d", n)
|
||||
}
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %s", err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
package h2mux
|
||||
|
||||
// Signal describes an event that can be waited on for at least one signal.
|
||||
// Signalling the event while it is in the signalled state is a noop.
|
||||
// When the waiter wakes up, the signal is set to unsignalled.
|
||||
// It is a way for any number of writers to inform a reader (without blocking)
|
||||
// that an event has happened.
|
||||
type Signal struct {
|
||||
c chan struct{}
|
||||
}
|
||||
|
||||
// NewSignal creates a new Signal.
|
||||
func NewSignal() Signal {
|
||||
return Signal{c: make(chan struct{}, 1)}
|
||||
}
|
||||
|
||||
// Signal signals the event.
|
||||
func (s Signal) Signal() {
|
||||
// This channel is buffered, so the nonblocking send will always succeed if the buffer is empty.
|
||||
select {
|
||||
case s.c <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for the event to be signalled.
|
||||
func (s Signal) Wait() {
|
||||
<-s.c
|
||||
}
|
||||
|
||||
// WaitChannel returns a channel that is readable after Signal is called.
|
||||
func (s Signal) WaitChannel() <-chan struct{} {
|
||||
return s.c
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// StreamErrorMap is used to track stream errors. This is a separate structure to ActiveStreamMap because
|
||||
// errors can be raised against non-existent or closed streams.
|
||||
type StreamErrorMap struct {
|
||||
sync.RWMutex
|
||||
// errors tracks per-stream errors
|
||||
errors map[uint32]http2.ErrCode
|
||||
// hasError is signaled whenever an error is raised.
|
||||
hasError Signal
|
||||
}
|
||||
|
||||
// NewStreamErrorMap creates a new StreamErrorMap.
|
||||
func NewStreamErrorMap() *StreamErrorMap {
|
||||
return &StreamErrorMap{
|
||||
errors: make(map[uint32]http2.ErrCode),
|
||||
hasError: NewSignal(),
|
||||
}
|
||||
}
|
||||
|
||||
// RaiseError raises a stream error.
|
||||
func (s *StreamErrorMap) RaiseError(streamID uint32, err http2.ErrCode) {
|
||||
s.Lock()
|
||||
s.errors[streamID] = err
|
||||
s.Unlock()
|
||||
s.hasError.Signal()
|
||||
}
|
||||
|
||||
// GetSignalChan returns a channel that is signalled when an error is raised.
|
||||
func (s *StreamErrorMap) GetSignalChan() <-chan struct{} {
|
||||
return s.hasError.WaitChannel()
|
||||
}
|
||||
|
||||
// GetErrors retrieves all errors currently raised. This resets the currently-tracked errors.
|
||||
func (s *StreamErrorMap) GetErrors() map[uint32]http2.ErrCode {
|
||||
s.Lock()
|
||||
errors := s.errors
|
||||
s.errors = make(map[uint32]http2.ErrCode)
|
||||
s.Unlock()
|
||||
return errors
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
func ServeMetrics(l net.Listener, shutdownC <-chan struct{}) error {
|
||||
server := &http.Server{
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
}
|
||||
go func() {
|
||||
<-shutdownC
|
||||
server.Shutdown(context.Background())
|
||||
}()
|
||||
http.Handle("/metrics", promhttp.Handler())
|
||||
log.WithField("addr", l.Addr()).Info("Starting metrics server")
|
||||
err := server.Serve(l)
|
||||
if err == http.ErrServerClosed {
|
||||
log.Info("Metrics server stopped")
|
||||
return nil
|
||||
}
|
||||
log.WithError(err).Error("Metrics server quit with error")
|
||||
return err
|
||||
}
|
||||
|
||||
func RegisterBuildInfo(buildTime string, version string) {
|
||||
buildInfo := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
// Don't namespace build_info, since we want it to be consistent across all Cloudflare services
|
||||
Name: "build_info",
|
||||
Help: "Build and version information",
|
||||
},
|
||||
[]string{"goversion", "revision", "version"},
|
||||
)
|
||||
prometheus.MustRegister(buildInfo)
|
||||
buildInfo.WithLabelValues(runtime.Version(), buildTime, version).Set(1)
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
package origin
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// Redeclare time functions so they can be overridden in tests.
|
||||
var (
|
||||
timeNow = time.Now
|
||||
timeAfter = time.After
|
||||
)
|
||||
|
||||
// BackoffHandler manages exponential backoff and limits the maximum number of retries.
|
||||
// The base time period is 1 second, doubling with each retry.
|
||||
// After initial success, a grace period can be set to reset the backoff timer if
|
||||
// a connection is maintained successfully for a long enough period. The base grace period
|
||||
// is 2 seconds, doubling with each retry.
|
||||
type BackoffHandler struct {
|
||||
// MaxRetries sets the maximum number of retries to perform. The default value
|
||||
// of 0 disables retry completely.
|
||||
MaxRetries uint
|
||||
|
||||
retries uint
|
||||
resetDeadline time.Time
|
||||
}
|
||||
|
||||
func (b BackoffHandler) GetBackoffDuration(ctx context.Context) (time.Duration, bool) {
|
||||
// Follows the same logic as Backoff, but without mutating the receiver.
|
||||
// This select has to happen first to reflect the actual behaviour of the Backoff function.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return time.Duration(0), false
|
||||
default:
|
||||
}
|
||||
if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) {
|
||||
// b.retries would be set to 0 at this point
|
||||
return time.Second, true
|
||||
}
|
||||
if b.retries >= b.MaxRetries {
|
||||
return time.Duration(0), false
|
||||
}
|
||||
return time.Duration(time.Second * 1 << b.retries), true
|
||||
}
|
||||
|
||||
// Backoff is used to wait according to exponential backoff. Returns false if the
|
||||
// maximum number of retries have been used or if the underlying context has been cancelled.
|
||||
func (b *BackoffHandler) Backoff(ctx context.Context) bool {
|
||||
if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) {
|
||||
b.retries = 0
|
||||
b.resetDeadline = time.Time{}
|
||||
}
|
||||
if b.retries >= b.MaxRetries {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-timeAfter(time.Duration(time.Second * 1 << b.retries)):
|
||||
b.retries++
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Sets a grace period within which the the backoff timer is maintained. After the grace
|
||||
// period expires, the number of retries & backoff duration is reset.
|
||||
func (b *BackoffHandler) SetGracePeriod() {
|
||||
b.resetDeadline = timeNow().Add(time.Duration(time.Second * 2 << b.retries))
|
||||
}
|
|
@ -0,0 +1,114 @@
|
|||
package origin
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func immediateTimeAfter(time.Duration) <-chan time.Time {
|
||||
c := make(chan time.Time, 1)
|
||||
c <- time.Now()
|
||||
return c
|
||||
}
|
||||
|
||||
func TestBackoffRetries(t *testing.T) {
|
||||
// make backoff return immediately
|
||||
timeAfter = immediateTimeAfter
|
||||
ctx := context.Background()
|
||||
backoff := BackoffHandler{MaxRetries: 3}
|
||||
if !backoff.Backoff(ctx) {
|
||||
t.Fatalf("backoff failed immediately")
|
||||
}
|
||||
if !backoff.Backoff(ctx) {
|
||||
t.Fatalf("backoff failed after 1 retry")
|
||||
}
|
||||
if !backoff.Backoff(ctx) {
|
||||
t.Fatalf("backoff failed after 2 retry")
|
||||
}
|
||||
if backoff.Backoff(ctx) {
|
||||
t.Fatalf("backoff allowed after 3 (max) retries")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackoffCancel(t *testing.T) {
|
||||
// prevent backoff from returning normally
|
||||
timeAfter = func(time.Duration) <-chan time.Time { return make(chan time.Time) }
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
backoff := BackoffHandler{MaxRetries: 3}
|
||||
cancelFunc()
|
||||
if backoff.Backoff(ctx) {
|
||||
t.Fatalf("backoff allowed after cancel")
|
||||
}
|
||||
if _, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||
t.Fatalf("backoff allowed after cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackoffGracePeriod(t *testing.T) {
|
||||
currentTime := time.Now()
|
||||
// make timeNow return whatever we like
|
||||
timeNow = func() time.Time { return currentTime }
|
||||
// make backoff return immediately
|
||||
timeAfter = immediateTimeAfter
|
||||
ctx := context.Background()
|
||||
backoff := BackoffHandler{MaxRetries: 1}
|
||||
if !backoff.Backoff(ctx) {
|
||||
t.Fatalf("backoff failed immediately")
|
||||
}
|
||||
// the next call to Backoff would fail unless it's after the grace period
|
||||
backoff.SetGracePeriod()
|
||||
// advance time to after the grace period (~4 seconds) and see what happens
|
||||
currentTime = currentTime.Add(time.Second * 5)
|
||||
if !backoff.Backoff(ctx) {
|
||||
t.Fatalf("backoff failed after the grace period expired")
|
||||
}
|
||||
// confirm we ignore grace period after backoff
|
||||
if backoff.Backoff(ctx) {
|
||||
t.Fatalf("backoff allowed after 1 (max) retry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBackoffDurationRetries(t *testing.T) {
|
||||
// make backoff return immediately
|
||||
timeAfter = immediateTimeAfter
|
||||
ctx := context.Background()
|
||||
backoff := BackoffHandler{MaxRetries: 3}
|
||||
if _, ok := backoff.GetBackoffDuration(ctx); !ok {
|
||||
t.Fatalf("backoff failed immediately")
|
||||
}
|
||||
backoff.Backoff(ctx) // noop
|
||||
if _, ok := backoff.GetBackoffDuration(ctx); !ok {
|
||||
t.Fatalf("backoff failed after 1 retry")
|
||||
}
|
||||
backoff.Backoff(ctx) // noop
|
||||
if _, ok := backoff.GetBackoffDuration(ctx); !ok {
|
||||
t.Fatalf("backoff failed after 2 retry")
|
||||
}
|
||||
backoff.Backoff(ctx) // noop
|
||||
if _, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||
t.Fatalf("backoff allowed after 3 (max) retries")
|
||||
}
|
||||
if backoff.Backoff(ctx) {
|
||||
t.Fatalf("backoff allowed after 3 (max) retries")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBackoffDuration(t *testing.T) {
|
||||
// make backoff return immediately
|
||||
timeAfter = immediateTimeAfter
|
||||
ctx := context.Background()
|
||||
backoff := BackoffHandler{MaxRetries: 3}
|
||||
if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second {
|
||||
t.Fatalf("backoff didn't return 1 second on first retry")
|
||||
}
|
||||
backoff.Backoff(ctx) // noop
|
||||
if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*2 {
|
||||
t.Fatalf("backoff didn't return 2 seconds on second retry")
|
||||
}
|
||||
backoff.Backoff(ctx) // noop
|
||||
if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*4 {
|
||||
t.Fatalf("backoff didn't return 4 seconds on third retry")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,360 @@
|
|||
package origin
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/cloudflare/cloudflare-warp/h2mux"
|
||||
"github.com/cloudflare/cloudflare-warp/tunnelrpc"
|
||||
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
|
||||
"github.com/cloudflare/cloudflare-warp/validation"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
raven "github.com/getsentry/raven-go"
|
||||
"github.com/pkg/errors"
|
||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
||||
)
|
||||
|
||||
const (
|
||||
dialTimeout = 15 * time.Second
|
||||
|
||||
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
||||
)
|
||||
|
||||
type TunnelConfig struct {
|
||||
EdgeAddr string
|
||||
OriginUrl string
|
||||
Hostname string
|
||||
APIKey string
|
||||
APIEmail string
|
||||
APICAKey string
|
||||
TlsConfig *tls.Config
|
||||
Retries uint
|
||||
HeartbeatInterval time.Duration
|
||||
MaxHeartbeats uint64
|
||||
ClientID string
|
||||
ReportedVersion string
|
||||
LBPool string
|
||||
Tags []tunnelpogs.Tag
|
||||
AccessInternalIP bool
|
||||
ConnectedSignal h2mux.Signal
|
||||
}
|
||||
|
||||
type dialError struct {
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e dialError) Error() string {
|
||||
return e.cause.Error()
|
||||
}
|
||||
|
||||
type printableRegisterTunnelError struct {
|
||||
cause error
|
||||
permanent bool
|
||||
}
|
||||
|
||||
func (e printableRegisterTunnelError) Error() string {
|
||||
return e.cause.Error()
|
||||
}
|
||||
|
||||
func (c *TunnelConfig) RegistrationOptions() *tunnelpogs.RegistrationOptions {
|
||||
policy := tunnelrpc.ExistingTunnelPolicy_disconnect
|
||||
if c.LBPool != "" {
|
||||
policy = tunnelrpc.ExistingTunnelPolicy_balance
|
||||
}
|
||||
return &tunnelpogs.RegistrationOptions{
|
||||
ClientID: c.ClientID,
|
||||
Version: c.ReportedVersion,
|
||||
OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH),
|
||||
ExistingTunnelPolicy: policy,
|
||||
PoolID: c.LBPool,
|
||||
Tags: c.Tags,
|
||||
ExposeInternalHostname: c.AccessInternalIP,
|
||||
}
|
||||
}
|
||||
|
||||
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
<-shutdownC
|
||||
cancel()
|
||||
}()
|
||||
backoff := BackoffHandler{MaxRetries: config.Retries}
|
||||
for {
|
||||
err, recoverable := ServeTunnel(ctx, config, &backoff)
|
||||
if recoverable {
|
||||
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||
log.Infof("Retrying in %s seconds", duration)
|
||||
backoff.Backoff(ctx)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func ServeTunnel(
|
||||
ctx context.Context,
|
||||
config *TunnelConfig,
|
||||
backoff *BackoffHandler,
|
||||
) (err error, recoverable bool) {
|
||||
// Returns error from parsing the origin URL or handshake errors
|
||||
handler, err := NewTunnelHandler(ctx, config)
|
||||
if err != nil {
|
||||
errLog := log.WithError(err)
|
||||
switch err.(type) {
|
||||
case dialError:
|
||||
errLog.Error("Unable to dial edge")
|
||||
case h2mux.MuxerHandshakeError:
|
||||
errLog.Error("Handshake failed with edge server")
|
||||
default:
|
||||
errLog.Error("Tunnel creation failure")
|
||||
return err, false
|
||||
}
|
||||
return err, true
|
||||
}
|
||||
serveCtx, serveCancel := context.WithCancel(ctx)
|
||||
registerErrC := make(chan error, 1)
|
||||
go func() {
|
||||
err := RegisterTunnel(serveCtx, handler.muxer, config)
|
||||
if err == nil {
|
||||
config.ConnectedSignal.Signal()
|
||||
backoff.SetGracePeriod()
|
||||
} else {
|
||||
serveCancel()
|
||||
}
|
||||
registerErrC <- err
|
||||
}()
|
||||
go func() {
|
||||
<-serveCtx.Done()
|
||||
handler.muxer.Shutdown()
|
||||
}()
|
||||
err = handler.muxer.Serve()
|
||||
serveCancel()
|
||||
registerErr := <-registerErrC
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Tunnel error")
|
||||
return err, true
|
||||
}
|
||||
if registerErr != nil {
|
||||
raven.CaptureError(registerErr, nil)
|
||||
// Don't retry on errors like entitlement failure or version too old
|
||||
if e, ok := registerErr.(printableRegisterTunnelError); ok {
|
||||
log.WithError(e).Error("Cannot register")
|
||||
if e.permanent {
|
||||
return nil, false
|
||||
}
|
||||
return e.cause, true
|
||||
}
|
||||
log.Error("Cannot register")
|
||||
return err, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func IsRPCStreamResponse(headers []h2mux.Header) bool {
|
||||
if len(headers) != 1 {
|
||||
return false
|
||||
}
|
||||
if headers[0].Name != ":status" || headers[0].Value != "200" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig) error {
|
||||
logger := log.WithField("subsystem", "rpc")
|
||||
logger.Debug("initiating RPC stream")
|
||||
stream, err := muxer.OpenStream([]h2mux.Header{
|
||||
{Name: ":method", Value: "RPC"},
|
||||
{Name: ":scheme", Value: "capnp"},
|
||||
{Name: ":path", Value: "*"},
|
||||
}, nil)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
raven.CaptureError(err, nil)
|
||||
return err
|
||||
}
|
||||
if !IsRPCStreamResponse(stream.Headers) {
|
||||
// stream response error
|
||||
raven.CaptureError(err, nil)
|
||||
return err
|
||||
}
|
||||
conn := rpc.NewConn(
|
||||
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
|
||||
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
|
||||
)
|
||||
defer conn.Close()
|
||||
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
|
||||
// Request server info without blocking tunnel registration; must use capnp library directly.
|
||||
tsClient := tunnelrpc.TunnelServer{Client: ts.Client}
|
||||
serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||
return nil
|
||||
})
|
||||
registration, err := ts.RegisterTunnel(
|
||||
ctx,
|
||||
&tunnelpogs.Authentication{Key: config.APIKey, Email: config.APIEmail, OriginCAKey: config.APICAKey},
|
||||
config.Hostname,
|
||||
config.RegistrationOptions(),
|
||||
)
|
||||
LogServerInfo(logger, serverInfoPromise.Result())
|
||||
if err != nil {
|
||||
// RegisterTunnel RPC failure
|
||||
return err
|
||||
}
|
||||
for _, logLine := range registration.LogLines {
|
||||
logger.Info(logLine)
|
||||
}
|
||||
if registration.Err != "" {
|
||||
return printableRegisterTunnelError{
|
||||
cause: fmt.Errorf("Server error: %s", registration.Err),
|
||||
permanent: registration.PermanentFailure,
|
||||
}
|
||||
}
|
||||
for _, url := range registration.Urls {
|
||||
log.Infof("Registered at %s", url)
|
||||
}
|
||||
for _, logLine := range registration.LogLines {
|
||||
log.Infof(logLine)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func LogServerInfo(logger *log.Entry, promise tunnelrpc.ServerInfo_Promise) {
|
||||
serverInfoMessage, err := promise.Struct()
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Failed to retrieve server information")
|
||||
return
|
||||
}
|
||||
serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Failed to retrieve server information")
|
||||
return
|
||||
}
|
||||
log.Infof("Connected to %s", serverInfo.LocationName)
|
||||
}
|
||||
|
||||
type TunnelHandler struct {
|
||||
originUrl string
|
||||
muxer *h2mux.Muxer
|
||||
httpClient *http.Client
|
||||
tags []tunnelpogs.Tag
|
||||
}
|
||||
|
||||
var dialer = net.Dialer{DualStack: true}
|
||||
|
||||
func NewTunnelHandler(ctx context.Context, config *TunnelConfig) (*TunnelHandler, error) {
|
||||
url, err := validation.ValidateUrl(config.OriginUrl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to parse origin url %#v", url)
|
||||
}
|
||||
h := &TunnelHandler{
|
||||
originUrl: url,
|
||||
httpClient: &http.Client{Timeout: time.Minute},
|
||||
tags: config.Tags,
|
||||
}
|
||||
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
|
||||
// TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one)
|
||||
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", config.EdgeAddr)
|
||||
dialCancel()
|
||||
if err != nil {
|
||||
return nil, dialError{cause: err}
|
||||
}
|
||||
edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig)
|
||||
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
|
||||
err = edgeConn.Handshake()
|
||||
if err != nil {
|
||||
return nil, dialError{cause: err}
|
||||
}
|
||||
// clear the deadline on the conn; h2mux has its own timeouts
|
||||
edgeConn.SetDeadline(time.Time{})
|
||||
// Establish a muxed connection with the edge
|
||||
// Client mux handshake with agent server
|
||||
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
|
||||
Timeout: 5 * time.Second,
|
||||
Handler: h,
|
||||
IsClient: true,
|
||||
HeartbeatInterval: config.HeartbeatInterval,
|
||||
MaxHeartbeats: config.MaxHeartbeats,
|
||||
})
|
||||
if err != nil {
|
||||
return h, errors.New("TLS handshake error")
|
||||
}
|
||||
return h, err
|
||||
}
|
||||
|
||||
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
|
||||
for _, header := range h2 {
|
||||
switch header.Name {
|
||||
case ":method":
|
||||
h1.Method = header.Value
|
||||
case ":scheme":
|
||||
case ":authority":
|
||||
// Otherwise the host header will be based on the origin URL
|
||||
h1.Host = header.Value
|
||||
case ":path":
|
||||
u, err := url.Parse(header.Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unparseable path")
|
||||
}
|
||||
resolved := h1.URL.ResolveReference(u)
|
||||
// prevent escaping base URL
|
||||
if !strings.HasPrefix(resolved.String(), h1.URL.String()) {
|
||||
return fmt.Errorf("invalid path")
|
||||
}
|
||||
h1.URL = resolved
|
||||
default:
|
||||
h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
|
||||
h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}}
|
||||
for headerName, headerValues := range h1.Header {
|
||||
for _, headerValue := range headerValues {
|
||||
h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
|
||||
for _, tag := range h.tags {
|
||||
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||
if err != nil {
|
||||
log.WithError(err).Panic("Unexpected error from http.NewRequest")
|
||||
}
|
||||
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("invalid request received")
|
||||
}
|
||||
h.AppendTagHeaders(req)
|
||||
response, err := h.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("HTTP request error")
|
||||
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
|
||||
stream.Write([]byte("502 Bad Gateway"))
|
||||
} else {
|
||||
defer response.Body.Close()
|
||||
stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||
io.Copy(stream, response.Body)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
// Package tlsconfig provides convenience functions for configuring TLS connections from the
|
||||
// command line.
|
||||
package tlsconfig
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io/ioutil"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
// CLIFlags names the flags used to configure TLS for a command or subsystem.
|
||||
// The nil value for a field means the flag is ignored.
|
||||
type CLIFlags struct {
|
||||
Cert string
|
||||
Key string
|
||||
ClientCert string
|
||||
RootCA string
|
||||
}
|
||||
|
||||
// GetConfig returns a TLS configuration according to the flags defined in f and
|
||||
// set by the user.
|
||||
func (f CLIFlags) GetConfig(c *cli.Context) *tls.Config {
|
||||
config := &tls.Config{}
|
||||
|
||||
if c.IsSet(f.Cert) && c.IsSet(f.Key) {
|
||||
cert, err := tls.LoadX509KeyPair(c.String(f.Cert), c.String(f.Key))
|
||||
if err != nil {
|
||||
log.WithError(err).Fatal("Error parsing X509 key pair")
|
||||
}
|
||||
config.Certificates = []tls.Certificate{cert}
|
||||
config.BuildNameToCertificate()
|
||||
}
|
||||
if c.IsSet(f.ClientCert) {
|
||||
// set of root certificate authorities that servers use if required to verify a client certificate
|
||||
// by the policy in ClientAuth
|
||||
config.ClientCAs = LoadCert(c.String(f.ClientCert))
|
||||
// server's policy for TLS Client Authentication. Default is no client cert
|
||||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
// set of root certificate authorities that clients use when verifying server certificates
|
||||
if c.IsSet(f.RootCA) {
|
||||
config.RootCAs = LoadCert(c.String(f.RootCA))
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// LoadCert creates a CertPool containing all certificates in a PEM-format file.
|
||||
func LoadCert(certPath string) *x509.CertPool {
|
||||
caCert, err := ioutil.ReadFile(certPath)
|
||||
if err != nil {
|
||||
log.WithError(err).Fatalf("Error reading certificate %s", certPath)
|
||||
}
|
||||
ca := x509.NewCertPool()
|
||||
if !ca.AppendCertsFromPEM(caCert) {
|
||||
log.WithError(err).Fatalf("Error parsing certificate %s", certPath)
|
||||
}
|
||||
return ca
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
# Generate go.capnp.out with:
|
||||
# capnp compile -o- go.capnp > go.capnp.out
|
||||
# Must run inside this directory to preserve paths.
|
||||
|
||||
@0xd12a1c51fedd6c88;
|
||||
|
||||
annotation package(file) :Text;
|
||||
annotation import(file) :Text;
|
||||
annotation doc(struct, field, enum) :Text;
|
||||
annotation tag(enumerant) :Text;
|
||||
annotation notag(enumerant) :Void;
|
||||
annotation customtype(field) :Text;
|
||||
annotation name(struct, field, union, enum, enumerant, interface, method, param, annotation, const, group) :Text;
|
||||
|
||||
$package("capnp");
|
|
@ -0,0 +1,26 @@
|
|||
package tunnelrpc
|
||||
|
||||
//go:generate capnp compile -ogo -I./tunnelrpc/ tunnelrpc.capnp
|
||||
|
||||
import (
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
)
|
||||
|
||||
// ConnLogger wraps a logrus *log.Entry for a connection.
|
||||
type ConnLogger struct {
|
||||
Entry *log.Entry
|
||||
}
|
||||
|
||||
func (c ConnLogger) Infof(ctx context.Context, format string, args ...interface{}) {
|
||||
c.Entry.Infof(format, args...)
|
||||
}
|
||||
|
||||
func (c ConnLogger) Errorf(ctx context.Context, format string, args ...interface{}) {
|
||||
c.Entry.Errorf(format, args...)
|
||||
}
|
||||
|
||||
func ConnLog(log *log.Entry) rpc.ConnOption {
|
||||
return rpc.ConnLog(ConnLogger{log})
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
// Package logtransport provides a transport that logs all of its messages.
|
||||
package tunnelrpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2/encoding/text"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
|
||||
)
|
||||
|
||||
type transport struct {
|
||||
rpc.Transport
|
||||
l *log.Entry
|
||||
}
|
||||
|
||||
// New creates a new logger that proxies messages to and from t and
|
||||
// logs them to l. If l is nil, then the log package's default
|
||||
// logger is used.
|
||||
func NewTransportLogger(l *log.Entry, t rpc.Transport) rpc.Transport {
|
||||
return &transport{Transport: t, l: l}
|
||||
}
|
||||
|
||||
func (t *transport) SendMessage(ctx context.Context, msg rpccapnp.Message) error {
|
||||
t.l.Debugf("tx %s", formatMsg(msg))
|
||||
return t.Transport.SendMessage(ctx, msg)
|
||||
}
|
||||
|
||||
func (t *transport) RecvMessage(ctx context.Context) (rpccapnp.Message, error) {
|
||||
msg, err := t.Transport.RecvMessage(ctx)
|
||||
if err != nil {
|
||||
t.l.WithError(err).Debug("rx error")
|
||||
return msg, err
|
||||
}
|
||||
t.l.Debugf("rx %s", formatMsg(msg))
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func formatMsg(m rpccapnp.Message) string {
|
||||
var buf bytes.Buffer
|
||||
text.NewEncoder(&buf).Encode(0x91b79f1f808db032, m.Struct)
|
||||
return buf.String()
|
||||
}
|
|
@ -0,0 +1,194 @@
|
|||
package pogs
|
||||
|
||||
import (
|
||||
"github.com/cloudflare/cloudflare-warp/tunnelrpc"
|
||||
"golang.org/x/net/context"
|
||||
"zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/pogs"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
)
|
||||
|
||||
type Authentication struct {
|
||||
Key string
|
||||
Email string
|
||||
OriginCAKey string
|
||||
}
|
||||
|
||||
func MarshalAuthentication(s tunnelrpc.Authentication, p *Authentication) error {
|
||||
return pogs.Insert(tunnelrpc.Authentication_TypeID, s.Struct, p)
|
||||
}
|
||||
|
||||
func UnmarshalAuthentication(s tunnelrpc.Authentication) (*Authentication, error) {
|
||||
p := new(Authentication)
|
||||
err := pogs.Extract(p, tunnelrpc.Authentication_TypeID, s.Struct)
|
||||
return p, err
|
||||
}
|
||||
|
||||
type TunnelRegistration struct {
|
||||
Err string
|
||||
Urls []string
|
||||
LogLines []string
|
||||
PermanentFailure bool
|
||||
}
|
||||
|
||||
func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error {
|
||||
return pogs.Insert(tunnelrpc.TunnelRegistration_TypeID, s.Struct, p)
|
||||
}
|
||||
|
||||
func UnmarshalTunnelRegistration(s tunnelrpc.TunnelRegistration) (*TunnelRegistration, error) {
|
||||
p := new(TunnelRegistration)
|
||||
err := pogs.Extract(p, tunnelrpc.TunnelRegistration_TypeID, s.Struct)
|
||||
return p, err
|
||||
}
|
||||
|
||||
type RegistrationOptions struct {
|
||||
ClientID string `capnp:"clientId"`
|
||||
Version string
|
||||
OS string `capnp:"os"`
|
||||
ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy
|
||||
PoolID string `capnp:"poolId"`
|
||||
ExposeInternalHostname bool
|
||||
Tags []Tag
|
||||
}
|
||||
|
||||
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {
|
||||
return pogs.Insert(tunnelrpc.RegistrationOptions_TypeID, s.Struct, p)
|
||||
}
|
||||
|
||||
func UnmarshalRegistrationOptions(s tunnelrpc.RegistrationOptions) (*RegistrationOptions, error) {
|
||||
p := new(RegistrationOptions)
|
||||
err := pogs.Extract(p, tunnelrpc.RegistrationOptions_TypeID, s.Struct)
|
||||
return p, err
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
type ServerInfo struct {
|
||||
LocationName string
|
||||
}
|
||||
|
||||
func MarshalServerInfo(s tunnelrpc.ServerInfo, p *ServerInfo) error {
|
||||
return pogs.Insert(tunnelrpc.ServerInfo_TypeID, s.Struct, p)
|
||||
}
|
||||
|
||||
func UnmarshalServerInfo(s tunnelrpc.ServerInfo) (*ServerInfo, error) {
|
||||
p := new(ServerInfo)
|
||||
err := pogs.Extract(p, tunnelrpc.ServerInfo_TypeID, s.Struct)
|
||||
return p, err
|
||||
}
|
||||
|
||||
type TunnelServer interface {
|
||||
RegisterTunnel(ctx context.Context, auth *Authentication, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
|
||||
GetServerInfo(ctx context.Context) (*ServerInfo, error)
|
||||
}
|
||||
|
||||
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
|
||||
return tunnelrpc.TunnelServer_ServerToClient(TunnelServer_PogsImpl{s})
|
||||
}
|
||||
|
||||
type TunnelServer_PogsImpl struct {
|
||||
impl TunnelServer
|
||||
}
|
||||
|
||||
func (i TunnelServer_PogsImpl) RegisterTunnel(p tunnelrpc.TunnelServer_registerTunnel) error {
|
||||
authentication, err := p.Params.Auth()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pogsAuthentication, err := UnmarshalAuthentication(authentication)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hostname, err := p.Params.Hostname()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
options, err := p.Params.Options()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pogsOptions, err := UnmarshalRegistrationOptions(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
server.Ack(p.Options)
|
||||
registration, err := i.impl.RegisterTunnel(p.Ctx, pogsAuthentication, hostname, pogsOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := p.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return MarshalTunnelRegistration(result, registration)
|
||||
}
|
||||
|
||||
func (i TunnelServer_PogsImpl) GetServerInfo(p tunnelrpc.TunnelServer_getServerInfo) error {
|
||||
server.Ack(p.Options)
|
||||
serverInfo, err := i.impl.GetServerInfo(p.Ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := p.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return MarshalServerInfo(result, serverInfo)
|
||||
}
|
||||
|
||||
type TunnelServer_PogsClient struct {
|
||||
Client capnp.Client
|
||||
Conn *rpc.Conn
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) Close() error {
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, auth *Authentication, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error {
|
||||
authentication, err := p.NewAuth()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = MarshalAuthentication(authentication, auth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.SetHostname(hostname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
registrationOptions, err := p.NewOptions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = MarshalRegistrationOptions(registrationOptions, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
retval, err := promise.Result().Struct()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return UnmarshalTunnelRegistration(retval)
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.GetServerInfo(ctx, func(p tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||
return nil
|
||||
})
|
||||
retval, err := promise.Result().Struct()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return UnmarshalServerInfo(retval)
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
using Go = import "go.capnp";
|
||||
@0xdb8274f9144abc7e;
|
||||
$Go.package("tunnelrpc");
|
||||
$Go.import("github.com/cloudflare/cloudflare-warp/tunnelrpc");
|
||||
|
||||
struct Authentication {
|
||||
key @0 :Text;
|
||||
email @1 :Text;
|
||||
originCAKey @2 :Text;
|
||||
}
|
||||
|
||||
struct TunnelRegistration {
|
||||
err @0 :Text;
|
||||
# A list of URLs that the tunnel is accessible from.
|
||||
urls @1 :List(Text);
|
||||
# Used to inform the client of actions taken.
|
||||
logLines @2 :List(Text);
|
||||
# In case of error, whether the client should attempt to reconnect.
|
||||
permanentFailure @3 :Bool;
|
||||
}
|
||||
|
||||
struct RegistrationOptions {
|
||||
# The tunnel client's unique identifier, used to verify a reconnection.
|
||||
clientId @0 :Text;
|
||||
# Information about the running binary.
|
||||
version @1 :Text;
|
||||
os @2 :Text;
|
||||
# What to do with existing tunnels for the given hostname.
|
||||
existingTunnelPolicy @3 :ExistingTunnelPolicy;
|
||||
# If using the balancing policy, identifies the LB pool to use.
|
||||
poolId @4 :Text;
|
||||
# Prevents the tunnel from being accessed at <subdomain>.cftunnel.com
|
||||
exposeInternalHostname @5 :Bool;
|
||||
# Client-defined tags to associate with the tunnel
|
||||
tags @6 :List(Tag);
|
||||
}
|
||||
|
||||
struct Tag {
|
||||
name @0 :Text;
|
||||
value @1 :Text;
|
||||
}
|
||||
|
||||
enum ExistingTunnelPolicy {
|
||||
ignore @0;
|
||||
disconnect @1;
|
||||
balance @2;
|
||||
}
|
||||
|
||||
struct ServerInfo {
|
||||
locationName @0 :Text;
|
||||
}
|
||||
|
||||
interface TunnelServer {
|
||||
registerTunnel @0 (auth :Authentication, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
|
||||
getServerInfo @1 () -> (result :ServerInfo);
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,136 @@
|
|||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
const defaultScheme = "http"
|
||||
|
||||
var supportedProtocol = [2]string{"http", "https"}
|
||||
|
||||
func ValidateHostname(hostname string) (string, error) {
|
||||
if hostname == "" {
|
||||
return "", fmt.Errorf("Hostname should not be empty")
|
||||
}
|
||||
// users gives url(contains schema) not just hostname
|
||||
if strings.Contains(hostname, ":") || strings.Contains(hostname, "%3A") {
|
||||
unescapeHostname, err := url.PathUnescape(hostname)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Hostname(actually a URL) %s has invalid escape characters %s", hostname, unescapeHostname)
|
||||
}
|
||||
hostnameToURL, err := url.Parse(unescapeHostname)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Hostname(actually a URL) %s has invalid format %s", hostname, hostnameToURL)
|
||||
}
|
||||
asciiHostname, err := idna.ToASCII(hostnameToURL.Hostname())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Hostname(actually a URL) %s has invalid ASCII encdoing %s", hostname, asciiHostname)
|
||||
}
|
||||
return asciiHostname, nil
|
||||
}
|
||||
|
||||
asciiHostname, err := idna.ToASCII(hostname)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Hostname %s has invalid ASCII encdoing %s", hostname, asciiHostname)
|
||||
}
|
||||
hostnameToURL, err := url.Parse(asciiHostname)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Hostname %s is not valid", hostnameToURL)
|
||||
}
|
||||
return hostnameToURL.RequestURI(), nil
|
||||
|
||||
}
|
||||
|
||||
func ValidateUrl(originUrl string) (string, error) {
|
||||
if originUrl == "" {
|
||||
return "", fmt.Errorf("Url should not be empty")
|
||||
}
|
||||
|
||||
if net.ParseIP(originUrl) != nil {
|
||||
return validateIP("", originUrl, "")
|
||||
} else if strings.HasPrefix(originUrl, "[") && strings.HasSuffix(originUrl, "]") {
|
||||
// ParseIP doesn't recoginze [::1]
|
||||
return validateIP("", originUrl[1:len(originUrl)-1], "")
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(originUrl)
|
||||
// user might pass in an ip address like 127.0.0.1
|
||||
if err == nil && net.ParseIP(host) != nil {
|
||||
return validateIP("", host, port)
|
||||
}
|
||||
|
||||
unescapedUrl, err := url.PathUnescape(originUrl)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("URL %s has invalid escape characters %s", originUrl, unescapedUrl)
|
||||
}
|
||||
|
||||
parsedUrl, err := url.Parse(unescapedUrl)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("URL %s has invalid format", originUrl)
|
||||
}
|
||||
|
||||
// if the url is in the form of host:port, IsAbs() will think host is the schema
|
||||
var hostname string
|
||||
hasScheme := parsedUrl.IsAbs() && parsedUrl.Host != ""
|
||||
if hasScheme {
|
||||
err := validateScheme(parsedUrl.Scheme)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// The earlier check for ip address will miss the case http://[::1]
|
||||
// and http://[::1]:8080
|
||||
if net.ParseIP(parsedUrl.Hostname()) != nil {
|
||||
return validateIP(parsedUrl.Scheme, parsedUrl.Hostname(), parsedUrl.Port())
|
||||
}
|
||||
hostname, err = ValidateHostname(parsedUrl.Hostname())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("URL %s has invalid format", originUrl)
|
||||
}
|
||||
if parsedUrl.Port() != "" {
|
||||
return fmt.Sprintf("%s://%s", parsedUrl.Scheme, net.JoinHostPort(hostname, parsedUrl.Port())), nil
|
||||
}
|
||||
return fmt.Sprintf("%s://%s", parsedUrl.Scheme, hostname), nil
|
||||
} else {
|
||||
if host == "" {
|
||||
hostname, err = ValidateHostname(originUrl)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("URL no %s has invalid format", originUrl)
|
||||
}
|
||||
return fmt.Sprintf("%s://%s", defaultScheme, hostname), nil
|
||||
} else {
|
||||
hostname, err = ValidateHostname(host)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("URL %s has invalid format", originUrl)
|
||||
}
|
||||
return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func validateScheme(scheme string) error {
|
||||
for _, protocol := range supportedProtocol {
|
||||
if scheme == protocol {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("Currently Cloudflare-Warp does not support %s protocol.", scheme)
|
||||
}
|
||||
|
||||
func validateIP(scheme, host, port string) (string, error) {
|
||||
if scheme == "" {
|
||||
scheme = defaultScheme
|
||||
}
|
||||
if port != "" {
|
||||
return fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(host, port)), nil
|
||||
} else if strings.Contains(host, ":") {
|
||||
// IPv6
|
||||
return fmt.Sprintf("%s://[%s]", scheme, host), nil
|
||||
}
|
||||
return fmt.Sprintf("%s://%s", scheme, host), nil
|
||||
}
|
|
@ -0,0 +1,136 @@
|
|||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestValidateHostname(t *testing.T) {
|
||||
var inputHostname string
|
||||
hostname, err := ValidateHostname(inputHostname)
|
||||
assert.Equal(t, err, fmt.Errorf("Hostname should not be empty"))
|
||||
assert.Empty(t, hostname)
|
||||
|
||||
inputHostname = "hello.example.com"
|
||||
hostname, err = ValidateHostname(inputHostname)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "hello.example.com", hostname)
|
||||
|
||||
inputHostname = "http://hello.example.com"
|
||||
hostname, err = ValidateHostname(inputHostname)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "hello.example.com", hostname)
|
||||
|
||||
inputHostname = "bücher.example.com"
|
||||
hostname, err = ValidateHostname(inputHostname)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "xn--bcher-kva.example.com", hostname)
|
||||
|
||||
inputHostname = "http://bücher.example.com"
|
||||
hostname, err = ValidateHostname(inputHostname)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "xn--bcher-kva.example.com", hostname)
|
||||
|
||||
inputHostname = "http%3A%2F%2Fhello.example.com"
|
||||
hostname, err = ValidateHostname(inputHostname)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "hello.example.com", hostname)
|
||||
|
||||
}
|
||||
|
||||
func TestValidateUrl(t *testing.T) {
|
||||
validUrl, err := ValidateUrl("")
|
||||
assert.Equal(t, fmt.Errorf("Url should not be empty"), err)
|
||||
assert.Empty(t, validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https://localhost:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://localhost:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("localhost:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://localhost:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("http://localhost")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://localhost", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("http://127.0.0.1:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://127.0.0.1:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("127.0.0.1:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://127.0.0.1:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("127.0.0.1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://127.0.0.1", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https://127.0.0.1:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://127.0.0.1:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("[::1]:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://[::1]:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("http://[::1]")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://[::1]", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("http://[::1]:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://[::1]:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("[::1]")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://[::1]", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https://example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("http://hello.example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://hello.example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("hello.example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://hello.example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("hello.example.com:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://hello.example.com:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https://hello.example.com:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://hello.example.com:8080", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https://bücher.example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://xn--bcher-kva.example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("bücher.example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "http://xn--bcher-kva.example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https%3A%2F%2Fhello.example.com")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://hello.example.com", validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
|
||||
assert.Equal(t, "Currently Cloudflare-Warp does not support ftp protocol.", err.Error())
|
||||
assert.Empty(t, validUrl)
|
||||
|
||||
validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "https://hello.example.com:8080", validUrl)
|
||||
|
||||
}
|
Loading…
Reference in New Issue