Initial import

This commit is contained in:
Chris Branch 2017-10-16 12:44:03 +01:00
commit 82cb539fbe
46 changed files with 6720 additions and 0 deletions

9
README.md Normal file
View File

@ -0,0 +1,9 @@
# Cloudflare Warp client
Contains the command-line client and its libraries for Cloudflare Warp, a tunneling daemon that proxies any local webserver through the Cloudflare network.
## Getting started
go install github.com/cloudflare/cloudflare-warp/cmd/cloudflare-warp
User documentation for Warp can be found at https://warp.cloudflare.com

View File

@ -0,0 +1,39 @@
package main
import (
"crypto/x509"
)
const cloudflareRootCA = `-----BEGIN CERTIFICATE-----
MIID/DCCAuagAwIBAgIID+rOSdTGfGcwCwYJKoZIhvcNAQELMIGLMQswCQYDVQQG
EwJVUzEZMBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRG
bGFyZSBPcmlnaW4gU1NMIENlcnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMN
U2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZvcm5pYTAeFw0xNDExMTMyMDM4
NTBaFw0xOTExMTQwMTQzNTBaMIGLMQswCQYDVQQGEwJVUzEZMBcGA1UEChMQQ2xv
dWRGbGFyZSwgSW5jLjE0MDIGA1UECxMrQ2xvdWRGbGFyZSBPcmlnaW4gU1NMIENl
cnRpZmljYXRlIEF1dGhvcml0eTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEG
A1UECBMKQ2FsaWZvcm5pYTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
AMBIlWf1KEKR5hbB75OYrAcUXobpD/AxvSYRXr91mbRu+lqE7YbyyRUShQh15lem
ef+umeEtPZoLFLhcLyczJxOhI+siLGDQm/a/UDkWvAXYa5DZ+pHU5ct5nZ8pGzqJ
p8G1Hy5RMVYDXZT9F6EaHjMG0OOffH6Ih25TtgfyyrjXycwDH0u6GXt+G/rywcqz
/9W4Aki3XNQMUHNQAtBLEEIYHMkyTYJxuL2tXO6ID5cCsoWw8meHufTeZW2DyUpl
yP3AHt4149RQSyWZMJ6AyntL9d8Xhfpxd9rJkh9Kge2iV9rQTFuE1rRT5s7OSJcK
xUsklgHcGHYMcNfNMilNHb8CAwEAAaNmMGQwDgYDVR0PAQH/BAQDAgAGMBIGA1Ud
EwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFCToU1ddfDRAh6nrlNu64RZ4/CmkMB8G
A1UdIwQYMBaAFCToU1ddfDRAh6nrlNu64RZ4/CmkMAsGCSqGSIb3DQEBCwOCAQEA
cQDBVAoRrhhsGegsSFsv1w8v27zzHKaJNv6ffLGIRvXK8VKKK0gKXh2zQtN9SnaD
gYNe7Pr4C3I8ooYKRJJWLsmEHdGdnYYmj0OJfGrfQf6MLIc/11bQhLepZTxdhFYh
QGgDl6gRmb8aDwk7Q92BPvek5nMzaWlP82ixavvYI+okoSY8pwdcVKobx6rWzMWz
ZEC9M6H3F0dDYE23XcCFIdgNSAmmGyXPBstOe0aAJXwJTxOEPn36VWr0PKIQJy5Y
4o1wpMpqCOIwWc8J9REV/REzN6Z1LXImdUgXIXOwrz56gKUJzPejtBQyIGj0mveX
Fu6q54beR89jDc+oABmOgg==
-----END CERTIFICATE-----`
func GetCloudflareRootCA() *x509.CertPool {
ca := x509.NewCertPool()
if !ca.AppendCertsFromPEM([]byte(cloudflareRootCA)) {
// should never happen
panic("failure loading Cloudflare origin CA pem")
}
return ca
}

View File

@ -0,0 +1,13 @@
// +build !windows,!darwin,!linux
package main
import (
"os"
cli "gopkg.in/urfave/cli.v2"
)
func runApp(app *cli.App) {
app.Run(os.Args)
}

View File

@ -0,0 +1,155 @@
package main
import (
"bytes"
"fmt"
"html/template"
"net"
"net/http"
"os"
"strings"
"github.com/pkg/errors"
log "github.com/Sirupsen/logrus"
"github.com/cloudflare/cloudflare-warp/origin"
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
cli "gopkg.in/urfave/cli.v2"
)
type templateData struct {
ServerName string
Request *http.Request
Tags []tunnelpogs.Tag
}
const defaultServerName = "the Cloudflare Warp test server"
const indexTemplate = `
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=Edge">
<title>
Cloudflare Warp Connection
</title>
<meta name="author" content="">
<meta name="description" content="Cloudflare Warp Connection">
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>
html{line-height:1.15;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}body{margin:0}section{display:block}h1{font-size:2em;margin:.67em 0}a{background-color:transparent;-webkit-text-decoration-skip:objects}/* 1 */::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}/* 1 */a,body,dd,div,dl,dt,h1,h4,html,p,section{box-sizing:border-box}.bt{border-top-style:solid;border-top-width:1px}.bl{border-left-style:solid;border-left-width:1px}.b--orange{border-color:#f38020}.br1{border-radius:.125rem}.bw2{border-width:.25rem}.dib{display:inline-block}.sans-serif{font-family:open sans,-apple-system,BlinkMacSystemFont,avenir next,avenir,helvetica neue,helvetica,ubuntu,roboto,noto,segoe ui,arial,sans-serif}.code{font-family:Consolas,monaco,monospace}.b{font-weight:700}.fw3{font-weight:300}.fw4{font-weight:400}.fw5{font-weight:500}.fw6{font-weight:600}.lh-copy{line-height:1.5}.link{text-decoration:none}.link,.link:active,.link:focus,.link:hover,.link:link,.link:visited{transition:color .15s ease-in}.link:focus{outline:1px dotted currentColor}.mw-100{max-width:100%}.mw4{max-width:8rem}.mw7{max-width:48rem}.bg-light-gray{background-color:#f7f7f7}.link-hover:hover{background-color:#1f679e}.white{color:#fff}.bg-white{background-color:#fff}.bg-blue{background-color:#408bc9}.pb2{padding-bottom:.5rem}.pb6{padding-bottom:8rem}.pt3{padding-top:1rem}.pt5{padding-top:4rem}.pv2{padding-top:.5rem;padding-bottom:.5rem}.ph3{padding-left:1rem;padding-right:1rem}.ph4{padding-left:2rem;padding-right:2rem}.ml0{margin-left:0}.mb1{margin-bottom:.25rem}.mb2{margin-bottom:.5rem}.mb3{margin-bottom:1rem}.mt5{margin-top:4rem}.ttu{text-transform:uppercase}.f4{font-size:1.25rem}.f5{font-size:1rem}.f6{font-size:.875rem}.f7{font-size:.75rem}.measure{max-width:30em}.center{margin-left:auto}.center{margin-right:auto}@media screen and (min-width:30em){.f2-ns{font-size:2.25rem}}@media screen and (min-width:30em) and (max-width:60em){.f5-m{font-size:1rem}}@media screen and (min-width:60em){.f4-l{font-size:1.25rem}}
.st0{fill:#FFF}.st1{fill:#f48120}.st2{fill:#faad3f}.st3{fill:#404041}
</style>
</head>
<body class="sans-serif black">
<div class="bt bw2 b--orange bg-white pb6">
<div class="mw7 center ph4 pt3">
<svg id="Layer_2" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 109 40.5" class="mw4">
<path class="st0" d="M98.6 14.2L93 12.9l-1-.4-25.7.2v12.4l32.3.1z"/>
<path class="st1" d="M88.1 24c.3-1 .2-2-.3-2.6-.5-.6-1.2-1-2.1-1.1l-17.4-.2c-.1 0-.2-.1-.3-.1-.1-.1-.1-.2 0-.3.1-.2.2-.3.4-.3l17.5-.2c2.1-.1 4.3-1.8 5.1-3.8l1-2.6c0-.1.1-.2 0-.3-1.1-5.1-5.7-8.9-11.1-8.9-5 0-9.3 3.2-10.8 7.7-1-.7-2.2-1.1-3.6-1-2.4.2-4.3 2.2-4.6 4.6-.1.6 0 1.2.1 1.8-3.9.1-7.1 3.3-7.1 7.3 0 .4 0 .7.1 1.1 0 .2.2.3.3.3h32.1c.2 0 .4-.1.4-.3l.3-1.1z"/>
<path class="st2" d="M93.6 12.8h-.5c-.1 0-.2.1-.3.2l-.7 2.4c-.3 1-.2 2 .3 2.6.5.6 1.2 1 2.1 1.1l3.7.2c.1 0 .2.1.3.1.1.1.1.2 0 .3-.1.2-.2.3-.4.3l-3.8.2c-2.1.1-4.3 1.8-5.1 3.8l-.2.9c-.1.1 0 .3.2.3h13.2c.2 0 .3-.1.3-.3.2-.8.4-1.7.4-2.6 0-5.2-4.3-9.5-9.5-9.5"/>
<path class="st3" d="M104.4 30.8c-.5 0-.9-.4-.9-.9s.4-.9.9-.9.9.4.9.9-.4.9-.9.9m0-1.6c-.4 0-.7.3-.7.7 0 .4.3.7.7.7.4 0 .7-.3.7-.7 0-.4-.3-.7-.7-.7m.4 1.2h-.2l-.2-.3h-.2v.3h-.2v-.9h.5c.2 0 .3.1.3.3 0 .1-.1.2-.2.3l.2.3zm-.3-.5c.1 0 .1 0 .1-.1s-.1-.1-.1-.1h-.3v.3h.3zM14.8 29H17v6h3.8v1.9h-6zM23.1 32.9c0-2.3 1.8-4.1 4.3-4.1s4.2 1.8 4.2 4.1-1.8 4.1-4.3 4.1c-2.4 0-4.2-1.8-4.2-4.1m6.3 0c0-1.2-.8-2.2-2-2.2s-2 1-2 2.1.8 2.1 2 2.1c1.2.2 2-.8 2-2M34.3 33.4V29h2.2v4.4c0 1.1.6 1.7 1.5 1.7s1.5-.5 1.5-1.6V29h2.2v4.4c0 2.6-1.5 3.7-3.7 3.7-2.3-.1-3.7-1.2-3.7-3.7M45 29h3.1c2.8 0 4.5 1.6 4.5 3.9s-1.7 4-4.5 4h-3V29zm3.1 5.9c1.3 0 2.2-.7 2.2-2s-.9-2-2.2-2h-.9v4h.9zM55.7 29H62v1.9h-4.1v1.3h3.7V34h-3.7v2.9h-2.2zM65.1 29h2.2v6h3.8v1.9h-6zM76.8 28.9H79l3.4 8H80l-.6-1.4h-3.1l-.6 1.4h-2.3l3.4-8zm2 4.9l-.9-2.2-.9 2.2h1.8zM85.2 29h3.7c1.2 0 2 .3 2.6.9.5.5.7 1.1.7 1.8 0 1.2-.6 2-1.6 2.4l1.9 2.8H90l-1.6-2.4h-1v2.4h-2.2V29zm3.6 3.8c.7 0 1.2-.4 1.2-.9 0-.6-.5-.9-1.2-.9h-1.4v1.9h1.4zM95.3 29h6.4v1.8h-4.2V32h3.8v1.8h-3.8V35h4.3v1.9h-6.5zM10 33.9c-.3.7-1 1.2-1.8 1.2-1.2 0-2-1-2-2.1s.8-2.1 2-2.1c.9 0 1.6.6 1.9 1.3h2.3c-.4-1.9-2-3.3-4.2-3.3-2.4 0-4.3 1.8-4.3 4.1s1.8 4.1 4.2 4.1c2.1 0 3.7-1.4 4.2-3.2H10z"/>
</svg>
<h1 class="f4 f2-ns mt5 fw5">Congrats! You created your first tunnel!</h1>
<p class="f6 f5-m f4-l measure lh-copy fw3">
Cloudflare Warp exposes locally running applications to the internet by
running an encrypted, virtual tunnel from your laptop or server to
Cloudflare's edge network.
</p>
<p class="b f5 mt5 fw6">Ready for the next step?</p>
<a
class="fw6 link white bg-blue ph4 pv2 br1 dib f5 link-hover"
style="border-bottom: 1px solid #1f679e"
href="https://warp.cloudflare.com">
Get started here
</a>
{{if .Tags}} <section>
<h4 class="f6 fw4 pt5 mb2">Connection</h4>
<dl class="bl bw2 b--orange ph3 pt3 pb2 bg-light-gray f7 code overflow-x-auto mw-100">
{{range .Tags}} <dt class="ttu mb1">{{.Name}}</dt>
<dd class="ml0 mb3 f5">{{.Value}}</dd>
{{end}} </dl>
</section>
{{end}} </div>
</div>
</body>
</html>
`
func hello(c *cli.Context) error {
address := fmt.Sprintf(":%d", c.Int("port"))
server := NewHelloWorldServer()
if hostname, err := os.Hostname(); err != nil {
server.serverName = hostname
}
err := server.ListenAndServe(address)
return errors.Wrap(err, "Fail to start Hello World Server")
}
func startHelloWorldServer(listener net.Listener, shutdownC <-chan struct{}) error {
server := NewHelloWorldServer()
if hostname, err := os.Hostname(); err != nil {
server.serverName = hostname
}
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: server}
go func() {
<-shutdownC
httpServer.Close()
}()
err := httpServer.Serve(listener)
return err
}
type HelloWorldServer struct {
responseTemplate *template.Template
serverName string
}
func NewHelloWorldServer() *HelloWorldServer {
return &HelloWorldServer{
responseTemplate: template.Must(template.New("index").Parse(indexTemplate)),
serverName: defaultServerName,
}
}
func findAvailablePort() (net.Listener, error) {
// If the port in address is empty, a port number is automatically chosen.
listener, err := net.Listen("tcp", "127.0.0.1:")
return listener, err
}
func (s *HelloWorldServer) ListenAndServe(address string) error {
log.Infof("Starting Hello World server on %s", address)
err := http.ListenAndServe(address, s)
return err
}
func (s *HelloWorldServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.WithField("client", r.RemoteAddr).Infof("%s %s %s", r.Method, r.URL, r.Proto)
var buffer bytes.Buffer
err := s.responseTemplate.Execute(&buffer, &templateData{
ServerName: s.serverName,
Request: r,
Tags: tagsFromHeaders(r.Header),
})
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "error: %v", err)
} else {
buffer.WriteTo(w)
}
}
func tagsFromHeaders(header http.Header) []tunnelpogs.Tag {
var tags []tunnelpogs.Tag
for headerName, headerValues := range header {
trimmed := strings.TrimPrefix(headerName, origin.TagHeaderNamePrefix)
if trimmed == headerName {
continue
}
for _, value := range headerValues {
tags = append(tags, tunnelpogs.Tag{Name: trimmed, Value: value})
}
}
return tags
}

View File

@ -0,0 +1,23 @@
package main
import (
"testing"
)
const testPort = "8080"
func TestNewHelloWorldServer(t *testing.T) {
if NewHelloWorldServer() == nil {
t.Fatal("NewHelloWorldServer returned nil")
}
}
func TestFindAvailablePort(t *testing.T) {
listener, err := findAvailablePort()
if err != nil {
t.Fatal("Fail to find available port")
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}

View File

@ -0,0 +1,264 @@
// +build linux
package main
import (
"fmt"
"os"
cli "gopkg.in/urfave/cli.v2"
)
func runApp(app *cli.App) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Cloudflare Warp system service",
Subcommands: []*cli.Command{
&cli.Command{
Name: "install",
Usage: "Install Cloudflare Warp as a system service",
Action: installLinuxService,
},
&cli.Command{
Name: "uninstall",
Usage: "Uninstall the Cloudflare Warp service",
Action: uninstallLinuxService,
},
},
})
app.Run(os.Args)
}
var systemdTemplates = []ServiceTemplate{
{
Path: "/etc/systemd/system/cloudflare-warp.service",
Content: `[Unit]
Description=Cloudflare Warp
After=network.target
[Service]
TimeoutStartSec=0
Type=notify
ExecStart={{ .Path }} --config /etc/cloudflare-warp.yml --autoupdate 0s
User=nobody
[Install]
WantedBy=multi-user.target
`,
},
{
Path: "/etc/systemd/system/cloudflare-warp-update.service",
Content: `[Unit]
Description=Update Cloudflare Warp
After=network.target
[Service]
ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 64 ]; then systemctl restart cloudflare-warp; exit 0; fi; exit $code'
`,
},
{
Path: "/etc/systemd/system/cloudflare-warp-update.timer",
Content: `[Unit]
Description=Update Cloudflare Warp
[Timer]
OnUnitActiveSec=1d
[Install]
WantedBy=timers.target
`,
},
}
var sysvTemplate = ServiceTemplate{
Path: "/etc/init.d/cloudflare-warp",
FileMode: 0755,
Content: `# For RedHat and cousins:
# chkconfig: 2345 99 01
# description: Cloudflare Warp agent
# processname: {{.Path}}
### BEGIN INIT INFO
# Provides: {{.Path}}
# Required-Start:
# Required-Stop:
# Default-Start: 2 3 4 5
# Default-Stop: 0 1 6
# Short-Description: Cloudflare Warp
# Description: Cloudflare Warp agent
### END INIT INFO
cmd="{{.Path}} --config /etc/cloudflare-warp.yml --pidfile /var/run/$name.pid"
name=$(basename $(readlink -f $0))
pid_file="/var/run/$name.pid"
stdout_log="/var/log/$name.log"
stderr_log="/var/log/$name.err"
[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
get_pid() {
cat "$pid_file"
}
is_running() {
[ -f "$pid_file" ] && ps $(get_pid) > /dev/null 2>&1
}
case "$1" in
start)
if is_running; then
echo "Already started"
else
echo "Starting $name"
$cmd >> "$stdout_log" 2>> "$stderr_log" &
echo $! > "$pid_file"
if ! is_running; then
echo "Unable to start, see $stdout_log and $stderr_log"
exit 1
fi
fi
;;
stop)
if is_running; then
echo -n "Stopping $name.."
kill $(get_pid)
for i in {1..10}
do
if ! is_running; then
break
fi
echo -n "."
sleep 1
done
echo
if is_running; then
echo "Not stopped; may still be shutting down or shutdown may have failed"
exit 1
else
echo "Stopped"
if [ -f "$pid_file" ]; then
rm "$pid_file"
fi
fi
else
echo "Not running"
fi
;;
restart)
$0 stop
if is_running; then
echo "Unable to stop, will not attempt to start"
exit 1
fi
$0 start
;;
status)
if is_running; then
echo "Running"
else
echo "Stopped"
exit 1
fi
;;
*)
echo "Usage: $0 {start|stop|restart|status}"
exit 1
;;
esac
exit 0
`,
}
func isSystemd() bool {
if _, err := os.Stat("/run/systemd/system"); err == nil {
return true
}
return false
}
func installLinuxService(c *cli.Context) error {
etPath, err := os.Executable()
if err != nil {
return fmt.Errorf("error determining executable path: %v", err)
}
templateArgs := ServiceTemplateArgs{Path: etPath}
switch {
case isSystemd():
return installSystemd(&templateArgs)
default:
return installSysv(&templateArgs)
}
}
func installSystemd(templateArgs *ServiceTemplateArgs) error {
for _, serviceTemplate := range systemdTemplates {
err := serviceTemplate.Generate(templateArgs)
if err != nil {
return err
}
}
if err := runCommand("systemctl", "enable", "cloudflare-warp.service"); err != nil {
return err
}
if err := runCommand("systemctl", "start", "cloudflare-warp-update.timer"); err != nil {
return err
}
return runCommand("systemctl", "daemon-reload")
}
func installSysv(templateArgs *ServiceTemplateArgs) error {
confPath, err := sysvTemplate.ResolvePath()
if err != nil {
return err
}
if err := sysvTemplate.Generate(templateArgs); err != nil {
return err
}
for _, i := range [...]string{"2", "3", "4", "5"} {
if err := os.Symlink(confPath, "/etc/rc"+i+".d/S50et"); err != nil {
continue
}
}
for _, i := range [...]string{"0", "1", "6"} {
if err := os.Symlink(confPath, "/etc/rc"+i+".d/K02et"); err != nil {
continue
}
}
return nil
}
func uninstallLinuxService(c *cli.Context) error {
switch {
case isSystemd():
return uninstallSystemd()
default:
return uninstallSysv()
}
}
func uninstallSystemd() error {
if err := runCommand("systemctl", "disable", "cloudflare-warp.service"); err != nil {
return err
}
if err := runCommand("systemctl", "stop", "cloudflare-warp-update.timer"); err != nil {
return err
}
for _, serviceTemplate := range systemdTemplates {
if err := serviceTemplate.Remove(); err != nil {
return err
}
}
return nil
}
func uninstallSysv() error {
if err := sysvTemplate.Remove(); err != nil {
return err
}
for _, i := range [...]string{"2", "3", "4", "5"} {
if err := os.Remove("/etc/rc" + i + ".d/S50et"); err != nil {
continue
}
}
for _, i := range [...]string{"0", "1", "6"} {
if err := os.Remove("/etc/rc" + i + ".d/K02et"); err != nil {
continue
}
}
return nil
}

View File

@ -0,0 +1,31 @@
package main
import (
"fmt"
"os"
"syscall"
homedir "github.com/mitchellh/go-homedir"
cli "gopkg.in/urfave/cli.v2"
)
func login(c *cli.Context) error {
path, err := homedir.Expand(defaultConfigPath)
if err != nil {
return err
}
fileInfo, err := os.Stat(path)
if err == nil && fileInfo.Size() > 0 {
fmt.Fprintf(os.Stderr, `You have an existing config file at %s which login would overwrite.
If this is intentional, please move or delete that file then run this command again.
`, defaultConfigPath)
return nil
}
if err != nil && err.(*os.PathError).Err != syscall.ENOENT {
return err
}
fmt.Fprintln(os.Stderr, "Please visit https://www.cloudflare.com/a/warp to obtain a certificate.")
return nil
}

View File

@ -0,0 +1,82 @@
// +build darwin
package main
import (
"fmt"
"os"
cli "gopkg.in/urfave/cli.v2"
)
func runApp(app *cli.App) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Cloudflare Warp launch agent",
Subcommands: []*cli.Command{
&cli.Command{
Name: "install",
Usage: "Install Cloudflare Warp as an user launch agent",
Action: installLaunchd,
},
&cli.Command{
Name: "uninstall",
Usage: "Uninstall the Cloudflare Warp launch agent",
Action: uninstallLaunchd,
},
},
})
app.Run(os.Args)
}
var launchdTemplate = ServiceTemplate{
Path: "~/Library/LaunchAgents/com.cloudflare.warp.plist",
Content: `<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>com.cloudflare.warp</string>
<key>Program</key>
<string>{{ .Path }}</string>
<key>RunAtLoad</key>
<true/>
<key>KeepAlive</key>
<dict>
<key>NetworkState</key>
<true/>
</dict>
<key>ThrottleInterval</key>
<integer>20</integer>
</dict>
</plist>`,
}
func installLaunchd(c *cli.Context) error {
etPath, err := os.Executable()
if err != nil {
return fmt.Errorf("error determining executable path: %v", err)
}
templateArgs := ServiceTemplateArgs{Path: etPath}
err = launchdTemplate.Generate(&templateArgs)
if err != nil {
return err
}
plistPath, err := launchdTemplate.ResolvePath()
if err != nil {
return err
}
return runCommand("launchctl", "load", plistPath)
}
func uninstallLaunchd(c *cli.Context) error {
plistPath, err := launchdTemplate.ResolvePath()
if err != nil {
return err
}
err = runCommand("launchctl", "unload", plistPath)
if err != nil {
return err
}
return launchdTemplate.Remove()
}

473
cmd/cloudflare-warp/main.go Normal file
View File

@ -0,0 +1,473 @@
package main
import (
"crypto/tls"
"encoding/hex"
"fmt"
"math/rand"
"net"
"os"
"os/signal"
"sync"
"syscall"
"time"
"github.com/cloudflare/cloudflare-warp/h2mux"
"github.com/cloudflare/cloudflare-warp/metrics"
"github.com/cloudflare/cloudflare-warp/origin"
"github.com/cloudflare/cloudflare-warp/tlsconfig"
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
"github.com/cloudflare/cloudflare-warp/validation"
log "github.com/Sirupsen/logrus"
"github.com/facebookgo/grace/gracenet"
raven "github.com/getsentry/raven-go"
homedir "github.com/mitchellh/go-homedir"
cli "gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
"github.com/coreos/go-systemd/daemon"
"github.com/pkg/errors"
)
const sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
const defaultConfigPath = "~/.cloudflare-warp.yml"
var listeners = gracenet.Net{}
var Version = "DEV"
var BuildTime = "unknown"
// Shutdown channel used by the app. When closed, app must terminate.
// May be closed by the Windows service runner.
var shutdownC chan struct{}
func main() {
metrics.RegisterBuildInfo(BuildTime, Version)
raven.SetDSN(sentryDSN)
raven.SetRelease(Version)
shutdownC = make(chan struct{})
app := &cli.App{}
app.Name = "cloudflare-warp"
app.Copyright = `(c) 2017 Cloudflare Inc.
Use is subject to the license agreement at https://warp.cloudflare.com/licence/`
app.Usage = "Cloudflare reverse tunnelling proxy agent \033[1;31m*BETA*\033[0m"
app.ArgsUsage = "origin-url"
app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime)
app.Description = `A reverse tunnel proxy agent that connects to Cloudflare's infrastructure.
Upon connecting, you are assigned a unique subdomain on cftunnel.com.
Alternatively, you can specify a hostname on a zone you control.
Requests made to Cloudflare's servers for your hostname will be proxied
through the tunnel to your local webserver.
WARNING:
` + "\033[1;31m*** THIS IS A BETA VERSION OF THE CLOUDFLARE WARP AGENT ***\033[0m" + `
At this time, do not use Cloudflare Warp for connecting production servers to Cloudflare.
Availability and reliability of this service is not guaranteed through the beta period.`
app.Flags = []cli.Flag{
&cli.StringFlag{
Name: "config",
Usage: "Specifies a config file in YAML format.",
},
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "autoupdate",
Usage: "Periodically check for updates, restarting the server with the new version.",
Value: time.Hour * 24,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "edge",
Value: "cftunnel.com:7844",
Usage: "Address of the Cloudflare tunnel server.",
EnvVars: []string{"TUNNEL_EDGE"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "cacert",
Usage: "Certificate Authority authenticating the Cloudflare tunnel connection.",
EnvVars: []string{"TUNNEL_CACERT"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "url",
Value: "http://localhost:8080",
Usage: "Connect to the local webserver at `URL`.",
EnvVars: []string{"TUNNEL_URL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "hostname",
Usage: "Set a hostname on a Cloudflare zone to route traffic through this tunnel.",
EnvVars: []string{"TUNNEL_HOSTNAME"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "id",
Usage: "A unique identifier used to tie connections to this tunnel instance.",
EnvVars: []string{"TUNNEL_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "lb-pool",
Usage: "The name of a (new/existing) load balancing pool to add this origin to.",
EnvVars: []string{"TUNNEL_LB_POOL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "api-key",
Usage: "A Cloudflare API key. Required(can be in the config file) unless you are only running the hello command or login command.",
EnvVars: []string{"TUNNEL_API_KEY"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "api-email",
Usage: "The Cloudflare user's email address associated with the API key. Required(can be in the config file) unless you are only running the hello command or login command.",
EnvVars: []string{"TUNNEL_API_EMAIL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "api-ca-key",
Usage: "The Origin CA service key associated with the user. Required(can be in the config file) unless you are only running the hello command or login command.",
EnvVars: []string{"TUNNEL_API_CA_KEY"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "metrics",
Value: "localhost:",
Usage: "Listen address for metrics reporting.",
EnvVars: []string{"TUNNEL_METRICS"},
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "tag",
Usage: "Custom tags used to identify this tunnel, in format `KEY=VALUE`. Multiple tags may be specified",
EnvVars: []string{"TUNNEL_TAG"},
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "heartbeat-interval",
Usage: "Minimum idle time before sending a heartbeat.",
Value: time.Second * 5,
Hidden: true,
}),
altsrc.NewUint64Flag(&cli.Uint64Flag{
Name: "heartbeat-count",
Usage: "Minimum number of unacked heartbeats to send before closing the connection.",
Value: 5,
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "loglevel",
Value: "info",
Usage: "Logging level {panic, fatal, error, warn, info, debug}",
EnvVars: []string{"TUNNEL_LOGLEVEL"},
}),
altsrc.NewUintFlag(&cli.UintFlag{
Name: "retries",
Value: 5,
Usage: "Maximum number of retries for connection/protocol errors.",
EnvVars: []string{"TUNNEL_RETRIES"},
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "debug",
Value: false,
Usage: "Enable HTTP requests to the autogenerated cftunnel.com domain.",
EnvVars: []string{"TUNNEL_DEBUG"},
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "hello-world",
Usage: "Run Hello World Server",
Value: false,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "pidfile",
Usage: "Write the application's PID to this file after first successful connection.",
EnvVars: []string{"TUNNEL_PIDFILE"},
}),
}
app.Action = func(c *cli.Context) error {
raven.CapturePanic(func() { startServer(c) }, nil)
return nil
}
app.Before = func(context *cli.Context) error {
inputSource, err := findInputSourceContext(context)
if err != nil {
return err
} else if inputSource != nil {
return altsrc.ApplyInputSourceValues(context, inputSource, app.Flags)
}
return nil
}
app.Commands = []*cli.Command{
&cli.Command{
Name: "update",
Action: update,
Usage: "Update the agent if a new version exists",
ArgsUsage: " ",
Description: `Looks for a new version on the offical download server.
If a new version exists, updates the agent binary and quits.
Otherwise, does nothing.
To determine if an update happened in a script, check for error code 64.`,
},
&cli.Command{
Name: "login",
Action: login,
Usage: "Generate a configuration file with your login details",
ArgsUsage: " ",
},
&cli.Command{
Name: "hello",
Action: hello,
Usage: "Run a simple \"Hello World\" server for testing Cloudflare Warp.",
Flags: []cli.Flag{
&cli.IntFlag{
Name: "port",
Usage: "Listen on the selected port.",
Value: 8080,
},
},
ArgsUsage: " ", // can't be the empty string or we get the default output
},
}
runApp(app)
}
func startServer(c *cli.Context) {
var wg sync.WaitGroup
errC := make(chan error)
wg.Add(2)
if c.NumFlags() == 0 && c.NArg() == 0 {
cli.ShowAppHelp(c)
return
}
logLevel, err := log.ParseLevel(c.String("loglevel"))
if err != nil {
log.WithError(err).Fatal("Unknown logging level specified")
}
log.SetLevel(logLevel)
hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil {
log.WithError(err).Fatal("Invalid hostname")
}
clientID := c.String("id")
if !c.IsSet("id") {
clientID = generateRandomClientID()
}
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil {
log.WithError(err).Fatal("Tag parse failure")
}
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
if c.IsSet("hello-world") {
wg.Add(1)
listener, err := findAvailablePort()
if err != nil {
listener.Close()
log.WithError(err).Fatal("Cannot start Hello World Server")
}
go func() {
startHelloWorldServer(listener, shutdownC)
wg.Done()
}()
c.Set("url", "http://"+listener.Addr().String())
log.Infof("Starting Hello World Server at %s", c.String("url"))
}
url, err := validateUrl(c)
if err != nil {
log.WithError(err).Fatal("Error validating url")
}
// User must have api-key, api-email and api-ca-key
if !c.IsSet("api-key") {
log.Fatal("You need to give us your api-key either via the --api-key option or put it in the configuration file. You will also need to give us your api-email and api-ca-key.")
}
if !c.IsSet("api-email") {
log.Fatal("You need to give us your api-email either via the --api-email option or put it in the configuration file. You will also need to give us your api-ca-key.")
}
if !c.IsSet("api-ca-key") {
log.Fatal("You need to give us your api-ca-key either via the --api-ca-key option or put it in the configuration file.")
}
log.Infof("Proxying tunnel requests to %s", url)
tunnelConfig := &origin.TunnelConfig{
EdgeAddr: c.String("edge"),
OriginUrl: url,
Hostname: hostname,
APIKey: c.String("api-key"),
APIEmail: c.String("api-email"),
APICAKey: c.String("api-ca-key"),
TlsConfig: &tls.Config{},
Retries: c.Uint("retries"),
HeartbeatInterval: c.Duration("heartbeat-interval"),
MaxHeartbeats: c.Uint64("heartbeat-count"),
ClientID: clientID,
ReportedVersion: Version,
LBPool: c.String("lb-pool"),
Tags: tags,
AccessInternalIP: c.Bool("debug"),
ConnectedSignal: h2mux.NewSignal(),
}
tunnelConfig.TlsConfig = tlsconfig.CLIFlags{RootCA: "cacert"}.GetConfig(c)
if tunnelConfig.TlsConfig.RootCAs == nil {
tunnelConfig.TlsConfig.RootCAs = GetCloudflareRootCA()
tunnelConfig.TlsConfig.ServerName = "cftunnel.com"
} else {
tunnelConfig.TlsConfig.ServerName, _, _ = net.SplitHostPort(tunnelConfig.EdgeAddr)
}
go writePidFile(tunnelConfig.ConnectedSignal, c.String("pidfile"))
go func() {
errC <- origin.StartTunnelDaemon(tunnelConfig, shutdownC)
wg.Done()
}()
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
if err != nil {
log.WithError(err).Fatal("Error opening metrics server listener")
}
go func() {
errC <- metrics.ServeMetrics(metricsListener, shutdownC)
wg.Done()
}()
go autoupdate(c.Duration("autoupdate"), shutdownC)
err = WaitForSignal(errC, shutdownC)
if err != nil {
log.WithError(err).Error("Quitting due to error")
raven.CaptureErrorAndWait(err, nil)
} else {
log.Info("Quitting...")
}
// Wait for clean exit, discarding all errors
go func() {
for range errC {
}
}()
wg.Wait()
}
func WaitForSignal(errC chan error, shutdownC chan struct{}) error {
signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signals)
select {
case err := <-errC:
close(shutdownC)
return err
case <-signals:
close(shutdownC)
case <-shutdownC:
}
return nil
}
func update(c *cli.Context) error {
if updateApplied() {
os.Exit(64)
}
return nil
}
func autoupdate(frequency time.Duration, shutdownC chan struct{}) {
if int64(frequency) == 0 {
return
}
for {
if updateApplied() {
if _, err := listeners.StartProcess(); err != nil {
log.WithError(err).Error("Unable to restart server automatically")
}
close(shutdownC)
return
}
time.Sleep(frequency)
}
}
func updateApplied() bool {
releaseInfo := checkForUpdates()
if releaseInfo.Updated {
log.Infof("Updated to version %s", releaseInfo.Version)
return true
}
if releaseInfo.Error != nil {
log.WithError(releaseInfo.Error).Error("Update check failed")
}
return false
}
func fileExists(path string) (bool, error) {
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
// ignore missing files
return false, nil
}
return false, err
}
f.Close()
return true, nil
}
func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
if context.IsSet("config") {
return altsrc.NewYamlSourceFromFile(context.String("config"))
}
for _, tryPath := range []string{
defaultConfigPath,
"~/.cloudflare-warp.yaml",
"~/cloudflare-warp.yaml",
"~/cloudflare-warp.yml",
"~/.et.yaml",
"~/et.yml",
"~/et.yaml",
"~/.cftunnel.yaml", // for existing users
"~/cftunnel.yaml",
} {
path, err := homedir.Expand(tryPath)
if err != nil {
continue
}
ok, err := fileExists(path)
if ok {
return altsrc.NewYamlSourceFromFile(path)
} else if err != nil {
return nil, err
}
}
return nil, nil
}
func generateRandomClientID() string {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
id := make([]byte, 32)
r.Read(id)
return hex.EncodeToString(id)
}
func writePidFile(waitForSignal h2mux.Signal, pidFile string) {
waitForSignal.Wait()
daemon.SdNotify(false, "READY=1")
if pidFile == "" {
return
}
file, err := os.Create(pidFile)
if err != nil {
log.WithError(err).Errorf("Unable to write pid to %s", pidFile)
}
defer file.Close()
fmt.Fprintf(file, "%d", os.Getpid())
}
// validate url. It can be either from --url or argument
func validateUrl(c *cli.Context) (string, error) {
var url = c.String("url")
if c.NArg() > 0 {
if c.IsSet("url") {
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
}
url = c.Args().Get(0)
}
validUrl, err := validation.ValidateUrl(url)
return validUrl, err
}

View File

@ -0,0 +1,88 @@
package main
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"os/exec"
"text/template"
homedir "github.com/mitchellh/go-homedir"
)
type ServiceTemplate struct {
Path string
Content string
FileMode os.FileMode
}
type ServiceTemplateArgs struct {
Path string
}
func (st *ServiceTemplate) ResolvePath() (string, error) {
resolvedPath, err := homedir.Expand(st.Path)
if err != nil {
return "", fmt.Errorf("error resolving path %s: %v", st.Path, err)
}
return resolvedPath, nil
}
func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
tmpl, err := template.New(st.Path).Parse(st.Content)
if err != nil {
return fmt.Errorf("error generating %s template: %v", st.Path, err)
}
resolvedPath, err := st.ResolvePath()
if err != nil {
return err
}
var buffer bytes.Buffer
err = tmpl.Execute(&buffer, args)
if err != nil {
return fmt.Errorf("error generating %s: %v", st.Path, err)
}
fileMode := os.FileMode(0644)
if st.FileMode != 0 {
fileMode = st.FileMode
}
err = ioutil.WriteFile(resolvedPath, buffer.Bytes(), fileMode)
if err != nil {
return fmt.Errorf("error writing %s: %v", resolvedPath, err)
}
return nil
}
func (st *ServiceTemplate) Remove() error {
resolvedPath, err := st.ResolvePath()
if err != nil {
return err
}
err = os.Remove(resolvedPath)
if err != nil {
return fmt.Errorf("error deleting %s: %v", resolvedPath, err)
}
return nil
}
func runCommand(command string, args ...string) error {
cmd := exec.Command(command, args...)
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("error getting stderr pipe: %v", err)
}
err = cmd.Start()
if err != nil {
return fmt.Errorf("error starting %s: %v", command, err)
}
commandErr, _ := ioutil.ReadAll(stderr)
if len(commandErr) > 0 {
return fmt.Errorf("%s error: %s", command, commandErr)
}
err = cmd.Wait()
if err != nil {
return fmt.Errorf("%s returned with error: %v", command, err)
}
return nil
}

View File

@ -0,0 +1,32 @@
package main
import (
"fmt"
"regexp"
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
)
// Restrict key names to characters allowed in an HTTP header name.
// Restrict key values to printable characters (what is recognised as data in an HTTP header value).
var tagRegexp = regexp.MustCompile("^([a-zA-Z0-9!#$%&'*+\\-.^_`|~]+)=([[:print:]]+)$")
func NewTagFromCLI(compoundTag string) (tunnelpogs.Tag, bool) {
matches := tagRegexp.FindStringSubmatch(compoundTag)
if len(matches) == 0 {
return tunnelpogs.Tag{}, false
}
return tunnelpogs.Tag{Name: matches[1], Value: matches[2]}, true
}
func NewTagSliceFromCLI(tags []string) ([]tunnelpogs.Tag, error) {
var tagSlice []tunnelpogs.Tag
for _, compoundTag := range tags {
if tag, ok := NewTagFromCLI(compoundTag); ok {
tagSlice = append(tagSlice, tag)
} else {
return nil, fmt.Errorf("Cannot parse tag value %s", compoundTag)
}
}
return tagSlice, nil
}

View File

@ -0,0 +1,46 @@
package main
import (
"testing"
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
"github.com/stretchr/testify/assert"
)
func TestSingleTag(t *testing.T) {
testCases := []struct {
Input string
Output tunnelpogs.Tag
Fail bool
}{
{Input: "x=y", Output: tunnelpogs.Tag{Name: "x", Value: "y"}},
{Input: "More-Complex=Tag Values", Output: tunnelpogs.Tag{Name: "More-Complex", Value: "Tag Values"}},
{Input: "First=Equals=Wins", Output: tunnelpogs.Tag{Name: "First", Value: "Equals=Wins"}},
{Input: "x=", Fail: true},
{Input: "=y", Fail: true},
{Input: "=", Fail: true},
{Input: "No spaces allowed=in key names", Fail: true},
{Input: "omg\nwtf=bbq", Fail: true},
}
for i, testCase := range testCases {
tag, ok := NewTagFromCLI(testCase.Input)
assert.Equalf(t, !testCase.Fail, ok, "mismatched success for test case %d", i)
assert.Equalf(t, testCase.Output, tag, "mismatched output for test case %d", i)
}
}
func TestTagSlice(t *testing.T) {
tagSlice, err := NewTagSliceFromCLI([]string{"a=b", "c=d", "e=f"})
assert.NoError(t, err)
assert.Len(t, tagSlice, 3)
assert.Equal(t, "a", tagSlice[0].Name)
assert.Equal(t, "b", tagSlice[0].Value)
assert.Equal(t, "c", tagSlice[1].Name)
assert.Equal(t, "d", tagSlice[1].Value)
assert.Equal(t, "e", tagSlice[2].Name)
assert.Equal(t, "f", tagSlice[2].Value)
tagSlice, err = NewTagSliceFromCLI([]string{"a=b", "=", "e=f"})
assert.Error(t, err)
}

View File

@ -0,0 +1,41 @@
package main
import "github.com/equinox-io/equinox"
const appID = "app_cwbQae3Tpea"
var publicKey = []byte(`
-----BEGIN ECDSA PUBLIC KEY-----
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE4OWZocTVZ8Do/L6ScLdkV+9A0IYMHoOf
dsCmJ/QZ6aw0w9qkkwEpne1Lmo6+0pGexZzFZOH6w5amShn+RXt7qkSid9iWlzGq
EKx0BZogHSor9Wy5VztdFaAaVbsJiCbO
-----END ECDSA PUBLIC KEY-----
`)
type ReleaseInfo struct {
Updated bool
Version string
Error error
}
func checkForUpdates() ReleaseInfo {
var opts equinox.Options
if err := opts.SetPublicKeyPEM(publicKey); err != nil {
return ReleaseInfo{Error: err}
}
resp, err := equinox.Check(appID, opts)
switch {
case err == equinox.NotAvailableErr:
return ReleaseInfo{}
case err != nil:
return ReleaseInfo{Error: err}
}
err = resp.Apply()
if err != nil {
return ReleaseInfo{Error: err}
}
return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion}
}

View File

@ -0,0 +1,145 @@
// +build windows
package main
// Copypasta from the example files:
// https://github.com/golang/sys/blob/master/windows/svc/example
import (
"fmt"
"os"
log "github.com/Sirupsen/logrus"
cli "gopkg.in/urfave/cli.v2"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/eventlog"
"golang.org/x/sys/windows/svc/mgr"
)
const (
windowsServiceName = "CloudflareWarp"
windowsServiceDescription = "Cloudflare Warp agent"
)
func runApp(app *cli.App) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Cloudflare Warp Windows service",
Subcommands: []*cli.Command{
&cli.Command{
Name: "install",
Usage: "Install Cloudflare Warp as a Windows service",
Action: installWindowsService,
},
&cli.Command{
Name: "uninstall",
Usage: "Uninstall the Cloudflare Warp service",
Action: uninstallWindowsService,
},
},
})
isIntSess, err := svc.IsAnInteractiveSession()
if err != nil {
log.Fatalf("failed to determine if we are running in an interactive session: %v", err)
}
if isIntSess {
app.Run(os.Args)
return
}
elog, err := eventlog.Open(windowsServiceName)
if err != nil {
return
}
defer elog.Close()
elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName))
err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog})
if err != nil {
elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
return
}
elog.Info(1, fmt.Sprintf("%s service stopped", windowsServiceName))
}
type windowsService struct {
app *cli.App
elog *eventlog.Log
}
func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, errno uint32) {
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
changes <- svc.Status{State: svc.StartPending}
go s.app.Run(args)
changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
loop:
for {
select {
case c := <-r:
switch c.Cmd {
case svc.Interrogate:
changes <- c.CurrentStatus
case svc.Stop, svc.Shutdown:
break loop
default:
s.elog.Error(1, fmt.Sprintf("unexpected control request #%d", c))
}
}
}
close(shutdownC)
changes <- svc.Status{State: svc.StopPending}
return
}
func installWindowsService(c *cli.Context) error {
exepath, err := os.Executable()
if err != nil {
return err
}
m, err := mgr.Connect()
if err != nil {
return err
}
defer m.Disconnect()
s, err := m.OpenService(windowsServiceName)
if err == nil {
s.Close()
return fmt.Errorf("service %s already exists", windowsServiceName)
}
s, err = m.CreateService(windowsServiceName, exepath, mgr.Config{DisplayName: windowsServiceDescription}, "is", "auto-started")
if err != nil {
return err
}
defer s.Close()
err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info)
if err != nil {
s.Delete()
return fmt.Errorf("SetupEventLogSource() failed: %s", err)
}
return nil
}
func uninstallWindowsService(c *cli.Context) error {
m, err := mgr.Connect()
if err != nil {
return err
}
defer m.Disconnect()
s, err := m.OpenService(windowsServiceName)
if err != nil {
return fmt.Errorf("service %s is not installed", windowsServiceName)
}
defer s.Close()
err = s.Delete()
if err != nil {
return err
}
err = eventlog.Remove(windowsServiceName)
if err != nil {
return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
}
return nil
}

213
h2mux/activestreammap.go Normal file
View File

@ -0,0 +1,213 @@
package h2mux
import (
"sync"
"golang.org/x/net/http2"
)
// activeStreamMap is used to moderate access to active streams between the read and write
// threads, and deny access to new peer streams while shutting down.
type activeStreamMap struct {
sync.RWMutex
// streams tracks open streams.
streams map[uint32]*MuxedStream
// streamsEmpty is a chan that should be closed when no more streams are open.
streamsEmpty chan struct{}
// nextStreamID is the next ID to use on our side of the connection.
// This is odd for clients, even for servers.
nextStreamID uint32
// maxPeerStreamID is the ID of the most recent stream opened by the peer.
maxPeerStreamID uint32
// ignoreNewStreams is true when the connection is being shut down. New streams
// cannot be registered.
ignoreNewStreams bool
}
type FlowControlMetrics struct {
AverageReceiveWindowSize, AverageSendWindowSize float64
MinReceiveWindowSize, MaxReceiveWindowSize uint32
MinSendWindowSize, MaxSendWindowSize uint32
}
func newActiveStreamMap(useClientStreamNumbers bool) *activeStreamMap {
m := &activeStreamMap{
streams: make(map[uint32]*MuxedStream),
streamsEmpty: make(chan struct{}),
nextStreamID: 1,
}
// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
if !useClientStreamNumbers {
m.nextStreamID = 2
}
return m
}
// Len returns the number of active streams.
func (m *activeStreamMap) Len() int {
m.RLock()
defer m.RUnlock()
return len(m.streams)
}
func (m *activeStreamMap) Get(streamID uint32) (*MuxedStream, bool) {
m.RLock()
defer m.RUnlock()
stream, ok := m.streams[streamID]
return stream, ok
}
// Set returns true if the stream was assigned successfully. If a stream
// already existed with that ID or we are shutting down, return false.
func (m *activeStreamMap) Set(newStream *MuxedStream) bool {
m.Lock()
defer m.Unlock()
if _, ok := m.streams[newStream.streamID]; ok {
return false
}
if m.ignoreNewStreams {
return false
}
m.streams[newStream.streamID] = newStream
return true
}
// Delete stops tracking the stream. It should be called only after it is closed and resetted.
func (m *activeStreamMap) Delete(streamID uint32) {
m.Lock()
defer m.Unlock()
delete(m.streams, streamID)
if len(m.streams) == 0 && m.streamsEmpty != nil {
close(m.streamsEmpty)
m.streamsEmpty = nil
}
}
// Shutdown blocks new streams from being created. It returns a channel that receives an event
// once the last stream has closed, or nil if a shutdown is in progress.
func (m *activeStreamMap) Shutdown() <-chan struct{} {
m.Lock()
defer m.Unlock()
if m.ignoreNewStreams {
// already shutting down
return nil
}
m.ignoreNewStreams = true
done := make(chan struct{})
if len(m.streams) == 0 {
// nothing to shut down
close(done)
return done
}
m.streamsEmpty = done
return done
}
// AcquireLocalID acquires a new stream ID for a stream you're opening.
func (m *activeStreamMap) AcquireLocalID() uint32 {
m.Lock()
defer m.Unlock()
x := m.nextStreamID
m.nextStreamID += 2
return x
}
// ObservePeerID observes the ID of a stream opened by the peer. It returns true if we should accept
// the new stream, or false to reject it. The ErrCode gives the reason why.
func (m *activeStreamMap) AcquirePeerID(streamID uint32) (bool, http2.ErrCode) {
m.Lock()
defer m.Unlock()
switch {
case m.ignoreNewStreams:
return false, http2.ErrCodeStreamClosed
case streamID > m.maxPeerStreamID:
m.maxPeerStreamID = streamID
return true, http2.ErrCodeNo
default:
return false, http2.ErrCodeStreamClosed
}
}
// IsPeerStreamID is true if the stream ID belongs to the peer.
func (m *activeStreamMap) IsPeerStreamID(streamID uint32) bool {
m.RLock()
defer m.RUnlock()
return (streamID % 2) != (m.nextStreamID % 2)
}
// IsLocalStreamID is true if it is a stream we have opened, even if it is now closed.
func (m *activeStreamMap) IsLocalStreamID(streamID uint32) bool {
m.RLock()
defer m.RUnlock()
return (streamID%2) == (m.nextStreamID%2) && streamID < m.nextStreamID
}
// LastPeerStreamID returns the most recently opened peer stream ID.
func (m *activeStreamMap) LastPeerStreamID() uint32 {
m.RLock()
defer m.RUnlock()
return m.maxPeerStreamID
}
// LastLocalStreamID returns the most recently opened local stream ID.
func (m *activeStreamMap) LastLocalStreamID() uint32 {
m.RLock()
defer m.RUnlock()
if m.nextStreamID > 1 {
return m.nextStreamID - 2
}
return 0
}
// Abort closes every active stream and prevents new ones being created. This should be used to
// return errors in pending read/writes when the underlying connection goes away.
func (m *activeStreamMap) Abort() {
m.Lock()
defer m.Unlock()
for _, stream := range m.streams {
stream.Close()
}
m.ignoreNewStreams = true
}
func (m *activeStreamMap) Metrics() *FlowControlMetrics {
m.Lock()
defer m.Unlock()
var averageReceiveWindowSize, averageSendWindowSize float64
var minReceiveWindowSize, maxReceiveWindowSize, minSendWindowSize, maxSendWindowSize uint32
i := 0
// The first variable in the range expression for map is the key, not index.
for _, stream := range m.streams {
// iterative mean: a(t+1) = a(t) + (a(t)-x)/(t+1)
windows := stream.FlowControlWindow()
averageReceiveWindowSize += (float64(windows.receiveWindow) - averageReceiveWindowSize) / float64(i+1)
averageSendWindowSize += (float64(windows.sendWindow) - averageSendWindowSize) / float64(i+1)
if i == 0 {
maxReceiveWindowSize = windows.receiveWindow
minReceiveWindowSize = windows.receiveWindow
maxSendWindowSize = windows.sendWindow
minSendWindowSize = windows.sendWindow
} else {
if windows.receiveWindow > maxReceiveWindowSize {
maxReceiveWindowSize = windows.receiveWindow
} else if windows.receiveWindow < minReceiveWindowSize {
minReceiveWindowSize = windows.receiveWindow
}
if windows.sendWindow > maxSendWindowSize {
maxSendWindowSize = windows.sendWindow
} else if windows.sendWindow < minSendWindowSize {
minSendWindowSize = windows.sendWindow
}
}
i++
}
return &FlowControlMetrics{
MinReceiveWindowSize: minReceiveWindowSize,
MaxReceiveWindowSize: maxReceiveWindowSize,
AverageReceiveWindowSize: averageReceiveWindowSize,
MinSendWindowSize: minSendWindowSize,
MaxSendWindowSize: maxSendWindowSize,
AverageSendWindowSize: averageSendWindowSize,
}
}

View File

@ -0,0 +1,28 @@
package h2mux
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestMetrics(t *testing.T) {
streamMap := newActiveStreamMap(false)
for i := 1; i <= 7; i++ {
stream := new(MuxedStream)
*stream = MuxedStream{
streamID: defaultWindowSize * uint32(i),
receiveWindow: defaultWindowSize * uint32(i),
sendWindow: defaultWindowSize * uint32(i),
}
assert.True(t, streamMap.Set(stream))
}
metrics := streamMap.Metrics()
assert.Equal(t, float64(defaultWindowSize*4), metrics.AverageReceiveWindowSize)
assert.Equal(t, float64(defaultWindowSize*4), metrics.AverageSendWindowSize)
assert.Equal(t, defaultWindowSize, metrics.MinReceiveWindowSize)
assert.Equal(t, defaultWindowSize, metrics.MinSendWindowSize)
assert.Equal(t, defaultWindowSize*7, metrics.MaxReceiveWindowSize)
assert.Equal(t, defaultWindowSize*7, metrics.MaxSendWindowSize)
}

25
h2mux/booleanfuse.go Normal file
View File

@ -0,0 +1,25 @@
package h2mux
import "sync/atomic"
// BooleanFuse is a data structure that can be set once to a particular value using Fuse(value).
// Subsequent calls to Fuse() will have no effect.
type BooleanFuse struct {
value int32
}
// Value gets the value
func (f *BooleanFuse) Value() bool {
// 0: unset
// 1: set true
// 2: set false
return atomic.LoadInt32(&f.value) == 1
}
func (f *BooleanFuse) Fuse(result bool) {
newValue := int32(2)
if result {
newValue = 1
}
atomic.CompareAndSwapInt32(&f.value, 0, newValue)
}

61
h2mux/error.go Normal file
View File

@ -0,0 +1,61 @@
package h2mux
import (
"fmt"
"golang.org/x/net/http2"
)
var (
ErrHandshakeTimeout = MuxerHandshakeError{"1000 handshake timeout"}
ErrBadHandshakeNotSettings = MuxerHandshakeError{"1001 unexpected response"}
ErrBadHandshakeUnexpectedAck = MuxerHandshakeError{"1002 unexpected response"}
ErrBadHandshakeNoMagic = MuxerHandshakeError{"1003 unexpected response"}
ErrBadHandshakeWrongMagic = MuxerHandshakeError{"1004 connected to endpoint of wrong type"}
ErrBadHandshakeNotSettingsAck = MuxerHandshakeError{"1005 unexpected response"}
ErrBadHandshakeUnexpectedSettings = MuxerHandshakeError{"1006 unexpected response"}
ErrUnexpectedFrameType = MuxerProtocolError{"2001 unexpected frame type", http2.ErrCodeProtocol}
ErrUnknownStream = MuxerProtocolError{"2002 unknown stream", http2.ErrCodeProtocol}
ErrInvalidStream = MuxerProtocolError{"2003 invalid stream", http2.ErrCodeProtocol}
ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"}
ErrConnectionClosed = MuxerApplicationError{"3001 connection closed"}
ErrConnectionDropped = MuxerApplicationError{"3002 connection dropped"}
ErrClosedStream = MuxerStreamError{"4000 stream closed", http2.ErrCodeStreamClosed}
)
type MuxerHandshakeError struct {
cause string
}
func (e MuxerHandshakeError) Error() string {
return fmt.Sprintf("Handshake error: %s", e.cause)
}
type MuxerProtocolError struct {
cause string
h2code http2.ErrCode
}
func (e MuxerProtocolError) Error() string {
return fmt.Sprintf("Protocol error: %s", e.cause)
}
type MuxerApplicationError struct {
cause string
}
func (e MuxerApplicationError) Error() string {
return fmt.Sprintf("Application error: %s", e.cause)
}
type MuxerStreamError struct {
cause string
h2code http2.ErrCode
}
func (e MuxerStreamError) Error() string {
return fmt.Sprintf("Stream error: %s", e.cause)
}

335
h2mux/h2mux.go Normal file
View File

@ -0,0 +1,335 @@
package h2mux
import (
"io"
"strings"
"sync"
"time"
log "github.com/Sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
const (
defaultFrameSize uint32 = 1 << 14 // Minimum frame size in http2 spec
defaultWindowSize uint32 = 65535
maxWindowSize uint32 = (1 << 31) - 1 // 2^31-1 = 2147483647, max window size specified in http2 spec
defaultTimeout time.Duration = 5 * time.Second
defaultRetries uint64 = 5
SettingMuxerMagic http2.SettingID = 0x42db
MuxerMagicOrigin uint32 = 0xa2e43c8b
MuxerMagicEdge uint32 = 0x1088ebf9
)
type MuxedStreamHandler interface {
ServeStream(*MuxedStream) error
}
type MuxedStreamFunc func(stream *MuxedStream) error
func (f MuxedStreamFunc) ServeStream(stream *MuxedStream) error {
return f(stream)
}
type MuxerConfig struct {
Timeout time.Duration
Handler MuxedStreamHandler
IsClient bool
// Name is used to identify this muxer instance when logging.
Name string
// The minimum time this connection can be idle before sending a heartbeat.
HeartbeatInterval time.Duration
// The minimum number of heartbeats to send before terminating the connection.
MaxHeartbeats uint64
}
type Muxer struct {
// f is used to read and write HTTP2 frames on the wire.
f *http2.Framer
// config is the MuxerConfig given in Handshake.
config MuxerConfig
// w, r are references to the underlying connection used.
w io.WriteCloser
r io.ReadCloser
// muxReader is the read process.
muxReader *MuxReader
// muxWriter is the write process.
muxWriter *MuxWriter
// newStreamChan is used to create new streams on the writer thread.
// The writer will assign the next available stream ID.
newStreamChan chan MuxedStreamRequest
// abortChan is used to abort the writer event loop.
abortChan chan struct{}
// abortOnce is used to ensure abortChan is closed once only.
abortOnce sync.Once
// readyList is used to signal writable streams.
readyList *ReadyList
// streams tracks currently-open streams.
streams *activeStreamMap
// explicitShutdown records whether the Muxer is closing because Shutdown was called, or due to another
// error.
explicitShutdown BooleanFuse
}
type Header struct {
Name, Value string
}
// Handshake establishes a muxed connection with the peer.
// After the handshake completes, it is possible to open and accept streams.
func Handshake(
w io.WriteCloser,
r io.ReadCloser,
config MuxerConfig,
) (*Muxer, error) {
// Initialise connection state fields
m := &Muxer{
f: http2.NewFramer(w, r), // A framer that writes to w and reads from r
config: config,
w: w,
r: r,
newStreamChan: make(chan MuxedStreamRequest),
abortChan: make(chan struct{}),
readyList: NewReadyList(),
streams: newActiveStreamMap(config.IsClient),
}
m.f.ReadMetaHeaders = hpack.NewDecoder(4096, func(hpack.HeaderField) {})
if config.Timeout == 0 {
config.Timeout = defaultTimeout
}
// Initialise the settings to identify this connection and confirm the other end is sane.
handshakeSetting := http2.Setting{ID: SettingMuxerMagic, Val: MuxerMagicEdge}
expectedMagic := MuxerMagicOrigin
if config.IsClient {
handshakeSetting.Val = MuxerMagicOrigin
expectedMagic = MuxerMagicEdge
}
errChan := make(chan error, 2)
// Simultaneously send our settings and verify the peer's settings.
go func() { errChan <- m.f.WriteSettings(handshakeSetting) }()
go func() { errChan <- m.readPeerSettings(expectedMagic) }()
err := joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
if err != nil {
return nil, err
}
// Confirm sanity by ACKing the frame and expecting an ACK for our frame.
// Not strictly necessary, but let's pretend to be H2-like.
go func() { errChan <- m.f.WriteSettingsAck() }()
go func() { errChan <- m.readPeerSettingsAck() }()
err = joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
if err != nil {
return nil, err
}
// set up reader/writer pair ready for serve
streamErrors := NewStreamErrorMap()
goAwayChan := make(chan http2.ErrCode, 1)
pingTimestamp := NewPingTimestamp()
connActive := NewSignal()
idleDuration := config.HeartbeatInterval
// Sanity check to enusre idelDuration is sane
if idleDuration == 0 || idleDuration < defaultTimeout {
idleDuration = defaultTimeout
log.Warn("Minimum idle time has been adjusted to ", defaultTimeout)
}
maxRetries := config.MaxHeartbeats
if maxRetries == 0 {
maxRetries = defaultRetries
log.Warn("Minimum number of unacked heartbeats to send before closing the connection has been adjusted to ", maxRetries)
}
m.muxReader = &MuxReader{
f: m.f,
handler: m.config.Handler,
streams: m.streams,
readyList: m.readyList,
streamErrors: streamErrors,
goAwayChan: goAwayChan,
abortChan: m.abortChan,
pingTimestamp: pingTimestamp,
connActive: connActive,
initialStreamWindow: defaultWindowSize,
streamWindowMax: maxWindowSize,
r: m.r,
}
m.muxWriter = &MuxWriter{
f: m.f,
streams: m.streams,
streamErrors: streamErrors,
readyStreamChan: m.readyList.ReadyChannel(),
newStreamChan: m.newStreamChan,
goAwayChan: goAwayChan,
abortChan: m.abortChan,
pingTimestamp: pingTimestamp,
idleTimer: NewIdleTimer(idleDuration, maxRetries),
connActiveChan: connActive.WaitChannel(),
maxFrameSize: defaultFrameSize,
}
m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
return m, nil
}
func (m *Muxer) readPeerSettings(magic uint32) error {
frame, err := m.f.ReadFrame()
if err != nil {
return err
}
settingsFrame, ok := frame.(*http2.SettingsFrame)
if !ok {
return ErrBadHandshakeNotSettings
}
if settingsFrame.Header().Flags != 0 {
return ErrBadHandshakeUnexpectedAck
}
peerMagic, ok := settingsFrame.Value(SettingMuxerMagic)
if !ok {
return ErrBadHandshakeNoMagic
}
if magic != peerMagic {
return ErrBadHandshakeWrongMagic
}
return nil
}
func (m *Muxer) readPeerSettingsAck() error {
frame, err := m.f.ReadFrame()
if err != nil {
return err
}
settingsFrame, ok := frame.(*http2.SettingsFrame)
if !ok {
return ErrBadHandshakeNotSettingsAck
}
if settingsFrame.Header().Flags != http2.FlagSettingsAck {
return ErrBadHandshakeUnexpectedSettings
}
return nil
}
func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time.Duration, timeoutError error) error {
for i := 0; i < receiveCount; i++ {
select {
case err := <-errChan:
if err != nil {
return err
}
case <-time.After(timeout):
return timeoutError
}
}
return nil
}
func (m *Muxer) Serve() error {
logger := log.WithField("name", m.config.Name)
errChan := make(chan error)
go func() {
errChan <- m.muxReader.run(logger)
m.explicitShutdown.Fuse(false)
m.r.Close()
m.abort()
}()
go func() {
errChan <- m.muxWriter.run(logger)
m.explicitShutdown.Fuse(false)
m.w.Close()
m.abort()
}()
err := <-errChan
go func() {
// discard error as other handler closes
<-errChan
close(errChan)
}()
if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) {
return err
}
return nil
}
func (m *Muxer) Shutdown() {
m.explicitShutdown.Fuse(true)
m.muxReader.Shutdown()
}
// IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
// The set of expected errors change depending on whether we initiated shutdown or not.
func isUnexpectedTunnelError(err error, expectedShutdown bool) bool {
if err == nil {
return false
}
if !expectedShutdown {
return true
}
return !isConnectionClosedError(err)
}
func isConnectionClosedError(err error) bool {
if err == io.EOF {
return true
}
if err == io.ErrClosedPipe {
return true
}
if err.Error() == "tls: use of closed connection" {
return true
}
if strings.HasSuffix(err.Error(), "use of closed network connection") {
return true
}
return false
}
// OpenStream opens a new data stream with the given headers.
// Called by proxy server and tunnel
func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, error) {
stream := &MuxedStream{
responseHeadersReceived: make(chan struct{}),
readBuffer: NewSharedBuffer(),
receiveWindow: defaultWindowSize,
receiveWindowCurrentMax: defaultWindowSize, // Initial window size limit. exponentially increase it when receiveWindow is exhausted
receiveWindowMax: maxWindowSize,
sendWindow: defaultWindowSize,
readyList: m.readyList,
writeHeaders: headers,
}
select {
// Will be received by mux writer
case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: body}:
case <-m.abortChan:
return nil, ErrConnectionClosed
}
select {
case <-stream.responseHeadersReceived:
return stream, nil
case <-m.abortChan:
return nil, ErrConnectionClosed
}
}
// Return the estimated round-trip time.
func (m *Muxer) RTT() RTTMeasurement {
return m.muxReader.RTT()
}
// Return min/max/average of send/receive window for all streams on this connection
func (m *Muxer) FlowControlMetrics() *FlowControlMetrics {
return m.muxReader.FlowControlMetrics()
}
func (m *Muxer) abort() {
m.abortOnce.Do(func() {
close(m.abortChan)
m.streams.Abort()
})
}
// Return how many retries/ticks since the connection was last marked active
func (m *Muxer) TimerRetries() uint64 {
return m.muxWriter.idleTimer.RetryCount()
}

646
h2mux/h2mux_test.go Normal file
View File

@ -0,0 +1,646 @@
package h2mux
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"math/rand"
"os"
"strconv"
"sync"
"testing"
"time"
log "github.com/Sirupsen/logrus"
)
func TestMain(m *testing.M) {
if os.Getenv("VERBOSE") == "1" {
log.SetLevel(log.DebugLevel)
}
os.Exit(m.Run())
}
type DefaultMuxerPair struct {
OriginMuxConfig MuxerConfig
OriginMux *Muxer
OriginWriter *io.PipeWriter
OriginReader *io.PipeReader
EdgeMuxConfig MuxerConfig
EdgeMux *Muxer
EdgeWriter *io.PipeWriter
EdgeReader *io.PipeReader
doneC chan struct{}
}
func NewDefaultMuxerPair() *DefaultMuxerPair {
originReader, edgeWriter := io.Pipe()
edgeReader, originWriter := io.Pipe()
return &DefaultMuxerPair{
OriginMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: true, Name: "origin"},
OriginWriter: originWriter,
OriginReader: originReader,
EdgeMuxConfig: MuxerConfig{Timeout: time.Second, IsClient: false, Name: "edge"},
EdgeWriter: edgeWriter,
EdgeReader: edgeReader,
doneC: make(chan struct{}),
}
}
func (p *DefaultMuxerPair) Handshake(t *testing.T) {
edgeErrC := make(chan error)
originErrC := make(chan error)
go func() {
var err error
p.EdgeMux, err = Handshake(p.EdgeWriter, p.EdgeReader, p.EdgeMuxConfig)
edgeErrC <- err
}()
go func() {
var err error
p.OriginMux, err = Handshake(p.OriginWriter, p.OriginReader, p.OriginMuxConfig)
originErrC <- err
}()
select {
case err := <-edgeErrC:
if err != nil {
t.Fatalf("edge handshake failure: %s", err)
}
case <-time.After(time.Second * 5):
t.Fatalf("edge handshake timeout")
}
select {
case err := <-originErrC:
if err != nil {
t.Fatalf("origin handshake failure: %s", err)
}
case <-time.After(time.Second * 5):
t.Fatalf("origin handshake timeout")
}
}
func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) {
p.Handshake(t)
var wg sync.WaitGroup
wg.Add(2)
go func() {
err := p.EdgeMux.Serve()
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
t.Errorf("error in edge muxer Serve(): %s", err)
}
p.OriginMux.Shutdown()
wg.Done()
}()
go func() {
err := p.OriginMux.Serve()
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
t.Errorf("error in origin muxer Serve(): %s", err)
}
p.EdgeMux.Shutdown()
wg.Done()
}()
go func() {
// notify when both muxes have stopped serving
wg.Wait()
close(p.doneC)
}()
}
func (p *DefaultMuxerPair) Wait(t *testing.T) {
select {
case <-p.doneC:
return
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for shutdown")
}
}
func TestHandshake(t *testing.T) {
muxPair := NewDefaultMuxerPair()
muxPair.Handshake(t)
AssertIfPipeReadable(t, muxPair.OriginReader)
AssertIfPipeReadable(t, muxPair.EdgeReader)
}
func TestSingleStream(t *testing.T) {
closeC := make(chan struct{})
muxPair := NewDefaultMuxerPair()
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
defer close(closeC)
if len(stream.Headers) != 1 {
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
}
if stream.Headers[0].Name != "test-header" {
t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
}
if stream.Headers[0].Value != "headerValue" {
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
}
stream.WriteHeaders([]Header{
Header{Name: "response-header", Value: "responseValue"},
})
buf := []byte("Hello world")
stream.Write(buf)
// after this receive, the edge closed the stream
<-closeC
n, err := stream.Read(buf)
if n > 0 {
t.Fatalf("read %d bytes after EOF", n)
}
if err != io.EOF {
t.Fatalf("expected EOF, got %s", err)
}
return nil
})
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != nil {
t.Fatalf("error in OpenStream: %s", err)
}
if len(stream.Headers) != 1 {
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
}
if stream.Headers[0].Name != "response-header" {
t.Fatalf("expected header name %s, got %s", "response-header", stream.Headers[0].Name)
}
if stream.Headers[0].Value != "responseValue" {
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
}
responseBody := make([]byte, 11)
n, err := stream.Read(responseBody)
if err != nil {
t.Fatalf("error from (*MuxedStream).Read: %s", err)
}
if n != len(responseBody) {
t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
}
if string(responseBody) != "Hello world" {
t.Fatalf("expected response body %s, got %s", "Hello world", responseBody)
}
stream.Close()
closeC <- struct{}{}
n, err = stream.Write([]byte("aaaaa"))
if n > 0 {
t.Fatalf("wrote %d bytes after EOF", n)
}
if err != io.EOF {
t.Fatalf("expected EOF, got %s", err)
}
<-closeC
}
func TestSingleStreamLargeResponseBody(t *testing.T) {
muxPair := NewDefaultMuxerPair()
bodySize := 1 << 24
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
if len(stream.Headers) != 1 {
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
}
if stream.Headers[0].Name != "test-header" {
t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
}
if stream.Headers[0].Value != "headerValue" {
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
}
stream.WriteHeaders([]Header{
Header{Name: "response-header", Value: "responseValue"},
})
payload := make([]byte, bodySize)
for i := range payload {
payload[i] = byte(i % 256)
}
n, err := stream.Write(payload)
if err != nil {
t.Fatalf("origin write error: %s", err)
}
if n != len(payload) {
t.Fatalf("origin short write: %d/%d bytes", n, len(payload))
}
return nil
})
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != nil {
t.Fatalf("error in OpenStream: %s", err)
}
if len(stream.Headers) != 1 {
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
}
if stream.Headers[0].Name != "response-header" {
t.Fatalf("expected header name %s, got %s", "response-header", stream.Headers[0].Name)
}
if stream.Headers[0].Value != "responseValue" {
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
}
responseBody := make([]byte, bodySize)
n, err := stream.Read(responseBody)
if err != nil {
t.Fatalf("error from (*MuxedStream).Read: %s", err)
}
if n != len(responseBody) {
t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
}
}
func TestMultipleStreams(t *testing.T) {
muxPair := NewDefaultMuxerPair()
maxStreams := 64
errorsC := make(chan error, maxStreams)
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
if len(stream.Headers) != 1 {
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
}
if stream.Headers[0].Name != "client-token" {
t.Fatalf("expected header name %s, got %s", "client-token", stream.Headers[0].Name)
}
log.Debugf("Got request for stream %s", stream.Headers[0].Value)
stream.WriteHeaders([]Header{
Header{Name: "response-token", Value: stream.Headers[0].Value},
})
log.Debugf("Wrote headers for stream %s", stream.Headers[0].Value)
stream.Write([]byte("OK"))
log.Debugf("Wrote body for stream %s", stream.Headers[0].Value)
return nil
})
muxPair.HandshakeAndServe(t)
var wg sync.WaitGroup
wg.Add(maxStreams)
for i := 0; i < maxStreams; i++ {
go func(tokenId int) {
defer wg.Done()
tokenString := fmt.Sprintf("%d", tokenId)
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "client-token", Value: tokenString}},
nil,
)
log.Debugf("Got headers for stream %d", tokenId)
if err != nil {
errorsC <- err
return
}
if len(stream.Headers) != 1 {
errorsC <- fmt.Errorf("stream %d has error: expected %d headers, got %d", stream.streamID, 1, len(stream.Headers))
return
}
if stream.Headers[0].Name != "response-token" {
errorsC <- fmt.Errorf("stream %d has error: expected header name %s, got %s", stream.streamID, "response-token", stream.Headers[0].Name)
return
}
if stream.Headers[0].Value != tokenString {
errorsC <- fmt.Errorf("stream %d has error: expected header value %s, got %s", stream.streamID, tokenString, stream.Headers[0].Value)
return
}
responseBody := make([]byte, 2)
n, err := stream.Read(responseBody)
if err != nil {
errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err)
return
}
if n != len(responseBody) {
errorsC <- fmt.Errorf("stream %d has error: expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n)
return
}
if string(responseBody) != "OK" {
errorsC <- fmt.Errorf("stream %d has error: expected response body %s, got %s", stream.streamID, "OK", responseBody)
return
}
}(i)
}
wg.Wait()
close(errorsC)
testFail := false
for err := range errorsC {
testFail = true
log.Error(err)
}
if testFail {
t.Fatalf("TestMultipleStreamsFlowControl failed")
}
}
func TestMultipleStreamsFlowControl(t *testing.T) {
maxStreams := 32
errorsC := make(chan error, maxStreams)
responseSizes := make([]int32, maxStreams)
for i := 0; i < maxStreams; i++ {
responseSizes[i] = rand.Int31n(int32(defaultWindowSize << 4))
}
muxPair := NewDefaultMuxerPair()
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
if len(stream.Headers) != 1 {
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
}
if stream.Headers[0].Name != "test-header" {
t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
}
if stream.Headers[0].Value != "headerValue" {
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
}
stream.WriteHeaders([]Header{
Header{Name: "response-header", Value: "responseValue"},
})
payload := make([]byte, responseSizes[(stream.streamID-2)/2])
for i := range payload {
payload[i] = byte(i % 256)
}
n, err := stream.Write(payload)
if err != nil {
t.Fatalf("origin write error: %s", err)
}
if n != len(payload) {
t.Fatalf("origin short write: %d/%d bytes", n, len(payload))
}
return nil
})
muxPair.HandshakeAndServe(t)
var wg sync.WaitGroup
wg.Add(maxStreams)
for i := 0; i < maxStreams; i++ {
go func(tokenId int) {
defer wg.Done()
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != nil {
errorsC <- fmt.Errorf("stream %d error in OpenStream: %s", stream.streamID, err)
return
}
if len(stream.Headers) != 1 {
errorsC <- fmt.Errorf("stream %d expected %d headers, got %d", stream.streamID, 1, len(stream.Headers))
return
}
if stream.Headers[0].Name != "response-header" {
errorsC <- fmt.Errorf("stream %d expected header name %s, got %s", stream.streamID, "response-header", stream.Headers[0].Name)
return
}
if stream.Headers[0].Value != "responseValue" {
errorsC <- fmt.Errorf("stream %d expected header value %s, got %s", stream.streamID, "responseValue", stream.Headers[0].Value)
return
}
responseBody := make([]byte, responseSizes[(stream.streamID-2)/2])
n, err := stream.Read(responseBody)
if err != nil {
errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err)
return
}
if n != len(responseBody) {
errorsC <- fmt.Errorf("stream %d expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n)
return
}
}(i)
}
wg.Wait()
close(errorsC)
testFail := false
for err := range errorsC {
testFail = true
log.Error(err)
}
if testFail {
t.Fatalf("TestMultipleStreamsFlowControl failed")
}
}
func TestGracefulShutdown(t *testing.T) {
sendC := make(chan struct{})
responseBuf := bytes.Repeat([]byte("Hello world"), 65536)
muxPair := NewDefaultMuxerPair()
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
stream.WriteHeaders([]Header{
Header{Name: "response-header", Value: "responseValue"},
})
<-sendC
log.Debugf("Writing %d bytes", len(responseBuf))
stream.Write(responseBuf)
stream.CloseWrite()
log.Debugf("Wrote %d bytes", len(responseBuf))
// Reading from the stream will block until the edge closes its end of the stream.
// Otherwise, we'll close the whole connection before receiving the 'stream closed'
// message from the edge.
// Graceful shutdown works if you omit this, it just gives spurious errors for now -
// TODO ignore errors when writing 'stream closed' and we're shutting down.
stream.Read([]byte{0})
log.Debugf("Handler ends")
return nil
})
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "test-header", Value: "headerValue"}},
nil,
)
// Start graceful shutdown of the edge mux - this should also close the origin mux when done
muxPair.EdgeMux.Shutdown()
close(sendC)
if err != nil {
t.Fatalf("error in OpenStream: %s", err)
}
responseBody := make([]byte, len(responseBuf))
log.Debugf("Waiting for %d bytes", len(responseBuf))
n, err := stream.Read(responseBody)
if err != nil {
t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err)
}
if n != len(responseBody) {
t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
}
if !bytes.Equal(responseBuf, responseBody) {
t.Fatalf("response body mismatch")
}
stream.Close()
muxPair.Wait(t)
}
func TestUnexpectedShutdown(t *testing.T) {
sendC := make(chan struct{})
handlerFinishC := make(chan struct{})
responseBuf := bytes.Repeat([]byte("Hello world"), 65536)
muxPair := NewDefaultMuxerPair()
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
defer close(handlerFinishC)
stream.WriteHeaders([]Header{
Header{Name: "response-header", Value: "responseValue"},
})
<-sendC
n, err := stream.Read([]byte{0})
if err != io.EOF {
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
}
if n != 0 {
t.Fatalf("expected empty read, got %d bytes", n)
}
// Write comes after read, because write buffers data before it is flushed. It wouldn't know about EOF
// until some time later. Calling read first forces it to know about EOF now.
_, err = stream.Write(responseBuf)
if err != io.EOF {
t.Fatalf("unexpected error from (*MuxedStream).Write: %s", err)
}
return nil
})
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "test-header", Value: "headerValue"}},
nil,
)
// Close the underlying connection before telling the origin to write.
muxPair.EdgeReader.Close()
close(sendC)
if err != nil {
t.Fatalf("error in OpenStream: %s", err)
}
responseBody := make([]byte, len(responseBuf))
n, err := stream.Read(responseBody)
if err != io.EOF {
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
}
if n != 0 {
t.Fatalf("expected response body to have %d bytes, got %d", 0, n)
}
// The write ordering requirement explained in the origin handler applies here too.
_, err = stream.Write(responseBuf)
if err != io.EOF {
t.Fatalf("unexpected error from (*MuxedStream).Write: %s", err)
}
<-handlerFinishC
}
func EchoHandler(stream *MuxedStream) error {
var buf bytes.Buffer
fmt.Fprintf(&buf, "Hello, world!\n\n# REQUEST HEADERS:\n\n")
for _, header := range stream.Headers {
fmt.Fprintf(&buf, "[%s] = %s\n", header.Name, header.Value)
}
stream.WriteHeaders([]Header{
{Name: ":status", Value: "200"},
{Name: "server", Value: "Echo-server/1.0"},
{Name: "date", Value: time.Now().Format(time.RFC850)},
{Name: "content-type", Value: "text/html; charset=utf-8"},
{Name: "content-length", Value: strconv.Itoa(buf.Len())},
})
buf.WriteTo(stream)
return nil
}
func TestOpenAfterDisconnect(t *testing.T) {
for i := 0; i < 3; i++ {
muxPair := NewDefaultMuxerPair()
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(EchoHandler)
muxPair.HandshakeAndServe(t)
switch i {
case 0:
// Close both directions of the connection to cause EOF on both peers.
muxPair.OriginReader.Close()
muxPair.OriginWriter.Close()
case 1:
// Close origin reader (edge writer) to cause EOF on origin only.
muxPair.OriginReader.Close()
case 2:
// Close origin writer (edge reader) to cause EOF on edge only.
muxPair.OriginWriter.Close()
}
_, err := muxPair.EdgeMux.OpenStream(
[]Header{Header{Name: "test-header", Value: "headerValue"}},
nil,
)
if err != ErrConnectionClosed {
t.Fatalf("unexpected error in OpenStream: %s", err)
}
}
}
func TestHPACK(t *testing.T) {
muxPair := NewDefaultMuxerPair()
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(EchoHandler)
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{
{Name: ":method", Value: "RPC"},
{Name: ":scheme", Value: "capnp"},
{Name: ":path", Value: "*"},
},
nil,
)
if err != nil {
t.Fatalf("error in OpenStream: %s", err)
}
stream.Close()
for i := 0; i < 3; i++ {
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{
{Name: ":method", Value: "GET"},
{Name: ":scheme", Value: "https"},
{Name: ":authority", Value: "tunnel.otterlyadorable.co.uk"},
{Name: ":path", Value: "/get"},
{Name: "accept-encoding", Value: "gzip"},
{Name: "cf-ray", Value: "378948953f044408-SFO-DOG"},
{Name: "cf-visitor", Value: "{\"scheme\":\"https\"}"},
{Name: "cf-connecting-ip", Value: "2400:cb00:0025:010d:0000:0000:0000:0001"},
{Name: "x-forwarded-for", Value: "2400:cb00:0025:010d:0000:0000:0000:0001"},
{Name: "x-forwarded-proto", Value: "https"},
{Name: "accept-language", Value: "en-gb"},
{Name: "referer", Value: "https://tunnel.otterlyadorable.co.uk/"},
{Name: "cookie", Value: "__cfduid=d4555095065f92daedc059490771967d81493032162"},
{Name: "connection", Value: "Keep-Alive"},
{Name: "cf-ipcountry", Value: "US"},
{Name: "accept", Value: "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
{Name: "user-agent", Value: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/603.2.4 (KHTML, like Gecko) Version/10.1.1 Safari/603.2.4"},
},
nil,
)
if err != nil {
t.Fatalf("error in OpenStream: %s", err)
}
if len(stream.Headers) == 0 {
t.Fatal("response has no headers")
}
if stream.Headers[0].Name != ":status" {
t.Fatalf("first header should be status, found %s instead", stream.Headers[0].Name)
}
if stream.Headers[0].Value != "200" {
t.Fatalf("expected status 200, got %s", stream.Headers[0].Value)
}
ioutil.ReadAll(stream)
stream.Close()
}
}
func AssertIfPipeReadable(t *testing.T, pipe *io.PipeReader) {
errC := make(chan error)
go func() {
b := []byte{0}
n, err := pipe.Read(b)
if n > 0 {
t.Fatalf("read pipe was not empty")
}
errC <- err
}()
select {
case err := <-errC:
if err != nil {
t.Fatalf("read error: %s", err)
}
case <-time.After(100 * time.Millisecond):
// nothing to read
pipe.Close()
<-errC
}
}

81
h2mux/idletimer.go Normal file
View File

@ -0,0 +1,81 @@
package h2mux
import (
"math/rand"
"sync"
"time"
)
// IdleTimer is a type of Timer designed for managing heartbeats on an idle connection.
// The timer ticks on an interval with added jitter to avoid accidental synchronisation
// between two endpoints. It tracks the number of retries/ticks since the connection was
// last marked active.
//
// The methods of IdleTimer must not be called while a goroutine is reading from C.
type IdleTimer struct {
// The channel on which ticks are delivered.
C <-chan time.Time
// A timer used to measure idle connection time. Reset after sending data.
idleTimer *time.Timer
// The maximum length of time a connection is idle before sending a ping.
idleDuration time.Duration
// A pseudorandom source used to add jitter to the idle duration.
randomSource *rand.Rand
// The maximum number of retries allowed.
maxRetries uint64
// The number of retries since the connection was last marked active.
retries uint64
// A lock to prevent race condition while checking retries
stateLock sync.RWMutex
}
func NewIdleTimer(idleDuration time.Duration, maxRetries uint64) *IdleTimer {
t := &IdleTimer{
idleTimer: time.NewTimer(idleDuration),
idleDuration: idleDuration,
randomSource: rand.New(rand.NewSource(time.Now().Unix())),
maxRetries: maxRetries,
}
t.C = t.idleTimer.C
return t
}
// Retry should be called when retrying the idle timeout. If the maximum number of retries
// has been met, returns false.
// After calling this function and sending a heartbeat, call ResetTimer. Since sending the
// heartbeat could be a blocking operation, we resetting the timer after the write completes
// to avoid it expiring during the write.
func (t *IdleTimer) Retry() bool {
t.stateLock.Lock()
defer t.stateLock.Unlock()
if t.retries >= t.maxRetries {
return false
}
t.retries++
return true
}
func (t *IdleTimer) RetryCount() uint64 {
t.stateLock.RLock()
defer t.stateLock.RUnlock()
return t.retries
}
// MarkActive resets the idle connection timer and suppresses any outstanding idle events.
func (t *IdleTimer) MarkActive() {
if !t.idleTimer.Stop() {
// eat the timer event to prevent spurious pings
<-t.idleTimer.C
}
t.stateLock.Lock()
t.retries = 0
t.stateLock.Unlock()
t.ResetTimer()
}
// Reset the idle timer according to the configured duration, with some added jitter.
func (t *IdleTimer) ResetTimer() {
jitter := time.Duration(t.randomSource.Int63n(int64(t.idleDuration)))
t.idleTimer.Reset(t.idleDuration + jitter)
}

31
h2mux/idletimer_test.go Normal file
View File

@ -0,0 +1,31 @@
package h2mux
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestRetry(t *testing.T) {
timer := NewIdleTimer(time.Second, 2)
assert.Equal(t, uint64(0), timer.RetryCount())
ok := timer.Retry()
assert.True(t, ok)
assert.Equal(t, uint64(1), timer.RetryCount())
ok = timer.Retry()
assert.True(t, ok)
assert.Equal(t, uint64(2), timer.RetryCount())
ok = timer.Retry()
assert.False(t, ok)
}
func TestMarkActive(t *testing.T) {
timer := NewIdleTimer(time.Second, 2)
assert.Equal(t, uint64(0), timer.RetryCount())
ok := timer.Retry()
assert.True(t, ok)
assert.Equal(t, uint64(1), timer.RetryCount())
timer.MarkActive()
assert.Equal(t, uint64(0), timer.RetryCount())
}

250
h2mux/muxedstream.go Normal file
View File

@ -0,0 +1,250 @@
package h2mux
import (
"bytes"
"io"
"sync"
)
type MuxedStream struct {
Headers []Header
streamID uint32
responseHeadersReceived chan struct{}
readBuffer *SharedBuffer
receiveWindow uint32
// current window size limit. Exponentially increase it when it's exhausted
receiveWindowCurrentMax uint32
// limit set in http2 spec. 2^31-1
receiveWindowMax uint32
// nonzero if a WINDOW_UPDATE frame for a stream needs to be sent
windowUpdate uint32
writeLock sync.Mutex
// The zero value for Buffer is an empty buffer ready to use.
writeBuffer bytes.Buffer
sendWindow uint32
readyList *ReadyList
headersSent bool
writeHeaders []Header
// true if the write end of this stream has been closed
writeEOF bool
// true if we have sent EOF to the peer
sentEOF bool
// true if the peer sent us an EOF
receivedEOF bool
}
type flowControlWindow struct {
receiveWindow, sendWindow uint32
}
func (s *MuxedStream) Read(p []byte) (n int, err error) {
return s.readBuffer.Read(p)
}
func (s *MuxedStream) Write(p []byte) (n int, err error) {
s.writeLock.Lock()
defer s.writeLock.Unlock()
if s.writeEOF {
return 0, io.EOF
}
n, err = s.writeBuffer.Write(p)
if n != len(p) || err != nil {
return n, err
}
s.writeNotify()
return n, nil
}
func (s *MuxedStream) Close() error {
// TUN-115: Close the write buffer before the read buffer.
// In the case of shutdown, read will not get new data, but the write buffer can still receive
// new data. Closing read before write allows application to race between a failed read and a
// successful write, even though this close should appear to be atomic.
// This can't happen the other way because reads may succeed after a failed write; if we read
// past EOF the application will block until we close the buffer.
err := s.CloseWrite()
if err != nil {
if s.CloseRead() == nil {
// don't bother the caller with errors if at least one close succeeded
return nil
}
return err
}
return s.CloseRead()
}
func (s *MuxedStream) CloseRead() error {
return s.readBuffer.Close()
}
func (s *MuxedStream) CloseWrite() error {
s.writeLock.Lock()
defer s.writeLock.Unlock()
if s.writeEOF {
return io.EOF
}
s.writeEOF = true
s.writeNotify()
return nil
}
func (s *MuxedStream) WriteHeaders(headers []Header) error {
s.writeLock.Lock()
defer s.writeLock.Unlock()
if s.writeHeaders != nil {
return ErrStreamHeadersSent
}
s.writeHeaders = headers
s.writeNotify()
return nil
}
func (s *MuxedStream) FlowControlWindow() *flowControlWindow {
s.writeLock.Lock()
defer s.writeLock.Unlock()
return &flowControlWindow{
receiveWindow: s.receiveWindow,
sendWindow: s.sendWindow,
}
}
// writeNotify must happen while holding writeLock.
func (s *MuxedStream) writeNotify() {
s.readyList.Signal(s.streamID)
}
// Call by muxreader when it gets a WindowUpdateFrame. This is an update of the peer's
// receive window (how much data we can send).
func (s *MuxedStream) replenishSendWindow(bytes uint32) {
s.writeLock.Lock()
s.sendWindow += bytes
s.writeNotify()
s.writeLock.Unlock()
}
// Call by muxreader when it receives a data frame
func (s *MuxedStream) consumeReceiveWindow(bytes uint32) bool {
s.writeLock.Lock()
defer s.writeLock.Unlock()
// received data size is greater than receive window/buffer
if s.receiveWindow < bytes {
return false
}
s.receiveWindow -= bytes
if s.receiveWindow < s.receiveWindowCurrentMax/2 {
// exhausting client send window (how much data client can send)
if s.receiveWindowCurrentMax < s.receiveWindowMax {
s.receiveWindowCurrentMax <<= 1
}
s.windowUpdate += s.receiveWindowCurrentMax - s.receiveWindow
s.writeNotify()
}
return true
}
// receiveEOF should be called when the peer indicates no more data will be sent.
// Returns true if the socket is now closed (i.e. the write side is already closed).
func (s *MuxedStream) receiveEOF() (closed bool) {
s.writeLock.Lock()
defer s.writeLock.Unlock()
s.receivedEOF = true
s.CloseRead()
return s.writeEOF && s.writeBuffer.Len() == 0
}
func (s *MuxedStream) gotReceiveEOF() bool {
s.writeLock.Lock()
defer s.writeLock.Unlock()
return s.receivedEOF
}
// MuxedStreamReader implements io.ReadCloser for the read end of the stream.
// This is useful for passing to functions that close the object after it is done reading,
// but you still want to be able to write data afterwards (e.g. http.Client).
type MuxedStreamReader struct {
*MuxedStream
}
func (s MuxedStreamReader) Read(p []byte) (n int, err error) {
return s.MuxedStream.Read(p)
}
func (s MuxedStreamReader) Close() error {
return s.MuxedStream.CloseRead()
}
// streamChunk represents a chunk of data to be written.
type streamChunk struct {
streamID uint32
// true if a HEADERS frame should be sent
sendHeaders bool
headers []Header
// nonzero if a WINDOW_UPDATE frame should be sent
windowUpdate uint32
// true if data frames should be sent
sendData bool
eof bool
buffer bytes.Buffer
}
// getChunk atomically extracts a chunk of data to be written by MuxWriter.
// The data returned will not exceed the send window for this stream.
func (s *MuxedStream) getChunk() *streamChunk {
s.writeLock.Lock()
defer s.writeLock.Unlock()
chunk := &streamChunk{
streamID: s.streamID,
sendHeaders: !s.headersSent,
headers: s.writeHeaders,
windowUpdate: s.windowUpdate,
sendData: !s.sentEOF,
eof: s.writeEOF && uint32(s.writeBuffer.Len()) <= s.sendWindow,
}
// Copies at most s.sendWindow bytes
//log.Infof("writeBuffer len %d stream %d", s.writeBuffer.Len(), s.streamID)
writeLen, _ := io.CopyN(&chunk.buffer, &s.writeBuffer, int64(s.sendWindow))
//log.Infof("writeLen %d stream %d", writeLen, s.streamID)
s.sendWindow -= uint32(writeLen)
s.receiveWindow += s.windowUpdate
s.windowUpdate = 0
s.headersSent = true
// if this chunk contains the end of the stream, close the stream now
if chunk.sendData && chunk.eof {
s.sentEOF = true
}
return chunk
}
func (c *streamChunk) sendHeadersFrame() bool {
return c.sendHeaders
}
func (c *streamChunk) sendWindowUpdateFrame() bool {
return c.windowUpdate > 0
}
func (c *streamChunk) sendDataFrame() bool {
return c.sendData
}
func (c *streamChunk) nextDataFrame(frameSize int) (payload []byte, endStream bool) {
payload = c.buffer.Next(frameSize)
if c.buffer.Len() == 0 {
// this is the last data frame in this chunk
c.sendData = false
if c.eof {
endStream = true
}
}
return
}

65
h2mux/muxedstream_test.go Normal file
View File

@ -0,0 +1,65 @@
package h2mux
import (
"testing"
"github.com/stretchr/testify/assert"
)
const testWindowSize uint32 = 65535
const testMaxWindowSize uint32 = testWindowSize << 2
// Only sending WINDOW_UPDATE frame, so sendWindow should never change
func TestFlowControlSingleStream(t *testing.T) {
stream := &MuxedStream{
responseHeadersReceived: make(chan struct{}),
readBuffer: NewSharedBuffer(),
receiveWindow: testWindowSize,
receiveWindowCurrentMax: testWindowSize,
receiveWindowMax: testMaxWindowSize,
sendWindow: testWindowSize,
readyList: NewReadyList(),
}
assert.True(t, stream.consumeReceiveWindow(testWindowSize/2))
dataSent := testWindowSize / 2
assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow)
assert.Equal(t, testWindowSize, stream.receiveWindowCurrentMax)
assert.Equal(t, uint32(0), stream.windowUpdate)
tempWindowUpdate := stream.windowUpdate
streamChunk := stream.getChunk()
assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate)
assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow)
assert.Equal(t, uint32(0), stream.windowUpdate)
assert.Equal(t, testWindowSize, stream.sendWindow)
assert.True(t, stream.consumeReceiveWindow(2))
dataSent += 2
assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow)
assert.Equal(t, testWindowSize<<1, stream.receiveWindowCurrentMax)
assert.Equal(t, (testWindowSize<<1)-stream.receiveWindow, stream.windowUpdate)
tempWindowUpdate = stream.windowUpdate
streamChunk = stream.getChunk()
assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate)
assert.Equal(t, testWindowSize<<1, stream.receiveWindow)
assert.Equal(t, uint32(0), stream.windowUpdate)
assert.Equal(t, testWindowSize, stream.sendWindow)
assert.True(t, stream.consumeReceiveWindow(testWindowSize+10))
dataSent = testWindowSize + 10
assert.Equal(t, (testWindowSize<<1)-dataSent, stream.receiveWindow)
assert.Equal(t, testWindowSize<<2, stream.receiveWindowCurrentMax)
assert.Equal(t, (testWindowSize<<2)-stream.receiveWindow, stream.windowUpdate)
tempWindowUpdate = stream.windowUpdate
streamChunk = stream.getChunk()
assert.Equal(t, tempWindowUpdate, streamChunk.windowUpdate)
assert.Equal(t, testWindowSize<<2, stream.receiveWindow)
assert.Equal(t, uint32(0), stream.windowUpdate)
assert.Equal(t, testWindowSize, stream.sendWindow)
assert.False(t, stream.consumeReceiveWindow(testMaxWindowSize+1))
assert.Equal(t, testWindowSize<<2, stream.receiveWindow)
assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax)
}

326
h2mux/muxreader.go Normal file
View File

@ -0,0 +1,326 @@
package h2mux
import (
"encoding/binary"
"io"
"sync"
"time"
log "github.com/Sirupsen/logrus"
"golang.org/x/net/http2"
)
type MuxReader struct {
// f is used to read HTTP2 frames.
f *http2.Framer
// handler provides a callback to receive new streams. if nil, new streams cannot be accepted.
handler MuxedStreamHandler
// streams tracks currently-open streams.
streams *activeStreamMap
// readyList is used to signal writable streams.
readyList *ReadyList
// streamErrors lets us report stream errors to the MuxWriter.
streamErrors *StreamErrorMap
// goAwayChan is used to tell the writer to send a GOAWAY message.
goAwayChan chan<- http2.ErrCode
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
abortChan <-chan struct{}
// pingTimestamp is an atomic value containing the latest received ping timestamp.
pingTimestamp *PingTimestamp
// connActive is used to signal to the writer that something happened on the connection.
// This is used to clear idle timeout disconnection deadlines.
connActive Signal
// The initial value for the send and receive window of a new stream.
initialStreamWindow uint32
// The max value for the send window of a stream.
streamWindowMax uint32
// windowMetrics keeps track of min/max/average of send/receive windows for all streams
flowControlMetrics *FlowControlMetrics
metricsMutex sync.Mutex
// r is a reference to the underlying connection used when shutting down.
r io.Closer
// rttMeasurement measures RTT based on ping timestamps.
rttMeasurement RTTMeasurement
rttMutex sync.Mutex
}
func (r *MuxReader) Shutdown() {
done := r.streams.Shutdown()
if done == nil {
return
}
r.sendGoAway(http2.ErrCodeNo)
go func() {
// close reader side when last stream ends; this will cause the writer to abort
<-done
r.r.Close()
}()
}
func (r *MuxReader) RTT() RTTMeasurement {
r.rttMutex.Lock()
defer r.rttMutex.Unlock()
return r.rttMeasurement
}
func (r *MuxReader) FlowControlMetrics() *FlowControlMetrics {
r.metricsMutex.Lock()
defer r.metricsMutex.Unlock()
if r.flowControlMetrics != nil {
return r.flowControlMetrics
}
// No metrics available yet
return &FlowControlMetrics{}
}
func (r *MuxReader) run(parentLogger *log.Entry) error {
logger := parentLogger.WithFields(log.Fields{
"subsystem": "mux",
"dir": "read",
})
defer logger.Debug("event loop finished")
for {
frame, err := r.f.ReadFrame()
if err != nil {
switch e := err.(type) {
case http2.StreamError:
logger.WithError(err).Warn("stream error")
r.streamError(e.StreamID, e.Code)
case http2.ConnectionError:
logger.WithError(err).Warn("connection error")
return r.connectionError(err)
default:
if isConnectionClosedError(err) {
if r.streams.Len() == 0 {
logger.Debug("shutting down")
return nil
}
logger.Warn("connection closed unexpectedly")
return err
} else {
logger.WithError(err).Warn("frame read error")
return r.connectionError(err)
}
}
}
r.connActive.Signal()
logger.WithField("data", frame).Debug("read frame")
switch f := frame.(type) {
case *http2.DataFrame:
err = r.receiveFrameData(f, logger)
case *http2.MetaHeadersFrame:
err = r.receiveHeaderData(f)
case *http2.RSTStreamFrame:
streamID := f.Header().StreamID
if streamID == 0 {
return ErrInvalidStream
}
r.streams.Delete(streamID)
case *http2.PingFrame:
r.receivePingData(f)
case *http2.GoAwayFrame:
err = r.receiveGoAway(f)
case *http2.WindowUpdateFrame:
err = r.updateStreamWindow(f)
default:
err = ErrUnexpectedFrameType
}
if err != nil {
logger.WithField("data", frame).WithError(err).Debug("frame error")
return r.connectionError(err)
}
}
}
func (r *MuxReader) newMuxedStream(streamID uint32) *MuxedStream {
return &MuxedStream{
streamID: streamID,
readBuffer: NewSharedBuffer(),
receiveWindow: r.initialStreamWindow,
receiveWindowCurrentMax: r.initialStreamWindow,
receiveWindowMax: r.streamWindowMax,
sendWindow: r.initialStreamWindow,
readyList: r.readyList,
}
}
// getStreamForFrame returns a stream if valid, or an error describing why the stream could not be returned.
func (r *MuxReader) getStreamForFrame(frame http2.Frame) (*MuxedStream, error) {
sid := frame.Header().StreamID
if sid == 0 {
return nil, ErrUnexpectedFrameType
}
if stream, ok := r.streams.Get(sid); ok {
return stream, nil
}
if r.streams.IsLocalStreamID(sid) {
// no stream available, but no error
return nil, ErrClosedStream
}
if sid < r.streams.LastPeerStreamID() {
// no stream available, stream closed error
return nil, ErrClosedStream
}
return nil, ErrUnknownStream
}
func (r *MuxReader) defaultStreamErrorHandler(err error, header http2.FrameHeader) error {
if header.Flags.Has(http2.FlagHeadersEndStream) {
return nil
} else if err == ErrUnknownStream || err == ErrClosedStream {
return r.streamError(header.StreamID, http2.ErrCodeStreamClosed)
} else {
return err
}
}
// Receives header frames from a stream. A non-nil error is a connection error.
func (r *MuxReader) receiveHeaderData(frame *http2.MetaHeadersFrame) error {
var stream *MuxedStream
sid := frame.Header().StreamID
if sid == 0 {
return ErrUnexpectedFrameType
}
newStream := r.streams.IsPeerStreamID(sid)
if newStream {
// header request
// TODO support trailers (if stream exists)
ok, err := r.streams.AcquirePeerID(sid)
if !ok {
// ignore new streams while shutting down
return r.streamError(sid, err)
}
stream = r.newMuxedStream(sid)
// Set stream. Returns false if a stream already existed with that ID or we are shutting down, return false.
if !r.streams.Set(stream) {
// got HEADERS frame for an existing stream
// TODO support trailers
return r.streamError(sid, http2.ErrCodeInternal)
}
} else {
// header response
var err error
if stream, err = r.getStreamForFrame(frame); err != nil {
return r.defaultStreamErrorHandler(err, frame.Header())
}
}
headers := make([]Header, len(frame.Fields))
for i, header := range frame.Fields {
headers[i].Name = header.Name
headers[i].Value = header.Value
}
stream.Headers = headers
if frame.Header().Flags.Has(http2.FlagHeadersEndStream) {
stream.receiveEOF()
return nil
}
if newStream {
go r.handleStream(stream)
} else {
close(stream.responseHeadersReceived)
}
return nil
}
func (r *MuxReader) handleStream(stream *MuxedStream) {
defer stream.Close()
r.handler.ServeStream(stream)
}
// Receives a data frame from a stream. A non-nil error is a connection error.
func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.Entry) error {
logger := parentLogger.WithField("stream", frame.Header().StreamID)
stream, err := r.getStreamForFrame(frame)
if err != nil {
return r.defaultStreamErrorHandler(err, frame.Header())
}
data := frame.Data()
if len(data) > 0 {
_, err = stream.readBuffer.Write(data)
if err != nil {
return r.streamError(stream.streamID, http2.ErrCodeInternal)
}
}
if frame.Header().Flags.Has(http2.FlagDataEndStream) {
if stream.receiveEOF() {
r.streams.Delete(stream.streamID)
logger.Debug("stream closed")
} else {
logger.Debug("shutdown receive side")
}
return nil
}
if !stream.consumeReceiveWindow(uint32(len(data))) {
return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
}
return nil
}
// Receive a PING from the peer. Update RTT and send/receive window metrics if it's an ACK.
func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
ts := int64(binary.LittleEndian.Uint64(frame.Data[:]))
if !frame.IsAck() {
r.pingTimestamp.Set(ts)
return
}
r.rttMutex.Lock()
r.rttMeasurement.Update(time.Unix(0, ts))
r.rttMutex.Unlock()
r.flowControlMetrics = r.streams.Metrics()
}
// Receive a GOAWAY from the peer. Gracefully shut down our connection.
func (r *MuxReader) receiveGoAway(frame *http2.GoAwayFrame) error {
r.Shutdown()
// Close all streams above the last processed stream
lastStream := r.streams.LastLocalStreamID()
for i := frame.LastStreamID + 2; i <= lastStream; i++ {
if stream, ok := r.streams.Get(i); ok {
stream.Close()
}
}
return nil
}
// Receives header frames from a stream. A non-nil error is a connection error.
func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
stream, err := r.getStreamForFrame(frame)
if err != nil && err != ErrUnknownStream && err != ErrClosedStream {
return err
}
if stream == nil {
// ignore window updates on closed streams
return nil
}
stream.replenishSendWindow(frame.Increment)
return nil
}
// Raise a stream processing error, closing the stream. Runs on the write thread.
func (r *MuxReader) streamError(streamID uint32, e http2.ErrCode) error {
r.streamErrors.RaiseError(streamID, e)
return nil
}
func (r *MuxReader) connectionError(err error) error {
http2Code := http2.ErrCodeInternal
switch e := err.(type) {
case http2.ConnectionError:
http2Code = http2.ErrCode(e)
case MuxerProtocolError:
http2Code = e.h2code
}
log.Warnf("Connection error %v", http2Code)
r.sendGoAway(http2Code)
return err
}
// Instruct the writer to send a GOAWAY message if possible. This may fail in
// the case where an existing GOAWAY message is in flight or the writer event
// loop already ended.
func (r *MuxReader) sendGoAway(errCode http2.ErrCode) {
select {
case r.goAwayChan <- errCode:
default:
}
}

238
h2mux/muxwriter.go Normal file
View File

@ -0,0 +1,238 @@
package h2mux
import (
"bytes"
"encoding/binary"
"io"
"time"
log "github.com/Sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
type MuxWriter struct {
// f is used to write HTTP2 frames.
f *http2.Framer
// streams tracks currently-open streams.
streams *activeStreamMap
// streamErrors receives stream errors raised by the MuxReader.
streamErrors *StreamErrorMap
// readyStreamChan is used to multiplex writable streams onto the single connection.
// When a stream becomes writable its ID is sent on this channel.
readyStreamChan <-chan uint32
// newStreamChan is used to create new streams with a given set of headers.
newStreamChan <-chan MuxedStreamRequest
// goAwayChan is used to send a single GOAWAY message to the peer. The element received
// is the HTTP/2 error code to send.
goAwayChan <-chan http2.ErrCode
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
abortChan <-chan struct{}
// pingTimestamp is an atomic value containing the latest received ping timestamp.
pingTimestamp *PingTimestamp
// A timer used to measure idle connection time. Reset after sending data.
idleTimer *IdleTimer
// connActiveChan receives a signal that the connection received some (read) activity.
connActiveChan <-chan struct{}
// Maximum size of all frames that can be sent on this connection.
maxFrameSize uint32
// headerEncoder is the stateful header encoder for this connection
headerEncoder *hpack.Encoder
// headerBuffer is the temporary buffer used by headerEncoder.
headerBuffer bytes.Buffer
}
type MuxedStreamRequest struct {
stream *MuxedStream
body io.Reader
}
func (r *MuxedStreamRequest) flushBody() {
io.Copy(r.stream, r.body)
r.stream.CloseWrite()
}
func tsToPingData(ts int64) [8]byte {
pingData := [8]byte{}
binary.LittleEndian.PutUint64(pingData[:], uint64(ts))
return pingData
}
func (w *MuxWriter) run(parentLogger *log.Entry) error {
logger := parentLogger.WithFields(log.Fields{
"subsystem": "mux",
"dir": "write",
})
defer logger.Debug("event loop finished")
for {
select {
case <-w.abortChan:
logger.Debug("aborting writer thread")
return nil
case errCode := <-w.goAwayChan:
logger.Debug("sending GOAWAY code ", errCode)
err := w.f.WriteGoAway(w.streams.LastPeerStreamID(), errCode, []byte{})
if err != nil {
return err
}
w.idleTimer.MarkActive()
case <-w.pingTimestamp.GetUpdateChan():
logger.Debug("sending PING ACK")
err := w.f.WritePing(true, tsToPingData(w.pingTimestamp.Get()))
if err != nil {
return err
}
w.idleTimer.MarkActive()
case <-w.idleTimer.C:
if !w.idleTimer.Retry() {
return ErrConnectionDropped
}
logger.Debug("sending PING")
err := w.f.WritePing(false, tsToPingData(time.Now().UnixNano()))
if err != nil {
return err
}
w.idleTimer.ResetTimer()
case <-w.connActiveChan:
w.idleTimer.MarkActive()
case <-w.streamErrors.GetSignalChan():
for streamID, errCode := range w.streamErrors.GetErrors() {
logger.WithField("stream", streamID).WithField("code", errCode).Debug("resetting stream")
err := w.f.WriteRSTStream(streamID, errCode)
if err != nil {
return err
}
}
w.idleTimer.MarkActive()
case streamRequest := <-w.newStreamChan:
streamID := w.streams.AcquireLocalID()
streamRequest.stream.streamID = streamID
if !w.streams.Set(streamRequest.stream) {
// Race between OpenStream and Shutdown, and Shutdown won. Let Shutdown (and the eventual abort) take
// care of this stream. Ideally we'd pass the error directly to the stream object somehow so the
// caller can be unblocked sooner, but the value of that optimisation is minimal for most of the
// reasons why you'd call Shutdown anyway.
continue
}
if streamRequest.body != nil {
go streamRequest.flushBody()
}
streamLogger := logger.WithField("stream", streamID)
err := w.writeStreamData(streamRequest.stream, streamLogger)
if err != nil {
return err
}
w.idleTimer.MarkActive()
case streamID := <-w.readyStreamChan:
streamLogger := logger.WithField("stream", streamID)
stream, ok := w.streams.Get(streamID)
if !ok {
continue
}
err := w.writeStreamData(stream, streamLogger)
if err != nil {
return err
}
w.idleTimer.MarkActive()
}
}
}
func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
logger.Debug("writable")
chunk := stream.getChunk()
if chunk.sendHeadersFrame() {
err := w.writeHeaders(chunk.streamID, chunk.headers)
if err != nil {
logger.WithError(err).Warn("error writing headers")
return err
}
logger.Debug("output headers")
}
if chunk.sendWindowUpdateFrame() {
// Send a WINDOW_UPDATE frame to update our receive window.
// If the Stream ID is zero, the window update applies to the connection as a whole
// A WINDOW_UPDATE in a specific stream applies to the connection-level flow control as well.
err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
if err != nil {
logger.WithError(err).Warn("error writing window update")
return err
}
logger.Debugf("increment receive window by %d", chunk.windowUpdate)
}
for chunk.sendDataFrame() {
payload, sentEOF := chunk.nextDataFrame(int(w.maxFrameSize))
err := w.f.WriteData(chunk.streamID, sentEOF, payload)
if err != nil {
logger.WithError(err).Warn("error writing data")
return err
}
logger.WithField("len", len(payload)).Debug("output data")
if sentEOF {
if stream.readBuffer.Closed() {
// transition into closed state
if !stream.gotReceiveEOF() {
// the peer may send data that we no longer want to receive. Force them into the
// closed state.
logger.Debug("resetting stream")
w.f.WriteRSTStream(chunk.streamID, http2.ErrCodeNo)
} else {
// Half-open stream transitioned into closed
logger.Debug("closing stream")
}
w.streams.Delete(chunk.streamID)
} else {
logger.Debug("closing stream write side")
}
}
}
return nil
}
func (w *MuxWriter) encodeHeaders(headers []Header) ([]byte, error) {
w.headerBuffer.Reset()
for _, header := range headers {
err := w.headerEncoder.WriteField(hpack.HeaderField{
Name: header.Name,
Value: header.Value,
})
if err != nil {
return nil, err
}
}
return w.headerBuffer.Bytes(), nil
}
// writeHeaders writes a block of encoded headers, splitting it into multiple frames if necessary.
func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
encodedHeaders, err := w.encodeHeaders(headers)
if err != nil {
return err
}
blockSize := int(w.maxFrameSize)
continuation := false
endHeaders := len(encodedHeaders) == 0
for !endHeaders && err == nil {
blockFragment := encodedHeaders
if len(encodedHeaders) > blockSize {
blockFragment = blockFragment[:blockSize]
encodedHeaders = encodedHeaders[blockSize:]
} else {
endHeaders = true
}
if continuation {
err = w.f.WriteContinuation(streamID, endHeaders, blockFragment)
} else {
err = w.f.WriteHeaders(http2.HeadersFrameParam{
StreamID: streamID,
EndHeaders: endHeaders,
BlockFragment: blockFragment,
})
}
}
return err
}

140
h2mux/readylist.go Normal file
View File

@ -0,0 +1,140 @@
package h2mux
// ReadyList multiplexes several event signals onto a single channel.
type ReadyList struct {
signalC chan uint32
waitC chan uint32
}
func NewReadyList() *ReadyList {
rl := &ReadyList{
signalC: make(chan uint32),
waitC: make(chan uint32),
}
go rl.run()
return rl
}
// ID is the stream ID
func (r *ReadyList) Signal(ID uint32) {
r.signalC <- ID
}
func (r *ReadyList) ReadyChannel() <-chan uint32 {
return r.waitC
}
func (r *ReadyList) Close() {
close(r.signalC)
}
func (r *ReadyList) run() {
defer close(r.waitC)
var queue readyDescriptorQueue
var firstReady *readyDescriptor
activeDescriptors := newReadyDescriptorMap()
for {
if firstReady == nil {
// Wait for first ready descriptor
i, ok := <-r.signalC
if !ok {
// closed
return
}
firstReady = activeDescriptors.SetIfMissing(i)
}
select {
case r.waitC <- firstReady.ID:
activeDescriptors.Delete(firstReady.ID)
firstReady = queue.Dequeue()
case i, ok := <-r.signalC:
if !ok {
// closed
return
}
newReady := activeDescriptors.SetIfMissing(i)
if newReady != nil {
// key doesn't exist
queue.Enqueue(newReady)
}
}
}
}
type readyDescriptor struct {
ID uint32
Next *readyDescriptor
}
// readyDescriptorQueue is a queue of readyDescriptors in the form of a singly-linked list.
// The nil readyDescriptorQueue is an empty queue ready for use.
type readyDescriptorQueue struct {
Head *readyDescriptor
Tail *readyDescriptor
}
func (q *readyDescriptorQueue) Empty() bool {
return q.Head == nil
}
func (q *readyDescriptorQueue) Enqueue(x *readyDescriptor) {
if x.Next != nil {
panic("enqueued already queued item")
}
if q.Empty() {
q.Head = x
q.Tail = x
} else {
q.Tail.Next = x
q.Tail = x
}
}
// Dequeue returns the first readyDescriptor in the queue, or nil if empty.
func (q *readyDescriptorQueue) Dequeue() *readyDescriptor {
if q.Empty() {
return nil
}
x := q.Head
q.Head = x.Next
x.Next = nil
return x
}
// readyDescriptorQueue is a map of readyDescriptors keyed by ID.
// It maintains a free list of deleted ready descriptors.
type readyDescriptorMap struct {
descriptors map[uint32]*readyDescriptor
free []*readyDescriptor
}
func newReadyDescriptorMap() *readyDescriptorMap {
return &readyDescriptorMap{descriptors: make(map[uint32]*readyDescriptor)}
}
// create or reuse a readyDescriptor if the stream is not in the queue.
// This avoid stream starvation caused by a single high-bandwidth stream monopolising the writer goroutine
func (m *readyDescriptorMap) SetIfMissing(key uint32) *readyDescriptor {
if _, ok := m.descriptors[key]; ok {
return nil
}
var newDescriptor *readyDescriptor
if len(m.free) > 0 {
// reuse deleted ready descriptors
newDescriptor = m.free[len(m.free)-1]
m.free = m.free[:len(m.free)-1]
} else {
newDescriptor = &readyDescriptor{}
}
newDescriptor.ID = key
m.descriptors[key] = newDescriptor
return newDescriptor
}
func (m *readyDescriptorMap) Delete(key uint32) {
if descriptor, ok := m.descriptors[key]; ok {
m.free = append(m.free, descriptor)
delete(m.descriptors, key)
}
}

115
h2mux/readylist_test.go Normal file
View File

@ -0,0 +1,115 @@
package h2mux
import (
"testing"
"time"
)
func TestReadyList(t *testing.T) {
rl := NewReadyList()
c := rl.ReadyChannel()
// helper functions
assertEmpty := func() {
select {
case <-c:
t.Fatalf("Spurious wakeup")
default:
}
}
receiveWithTimeout := func() uint32 {
select {
case i := <-c:
return i
case <-time.After(100 * time.Millisecond):
t.Fatalf("Timeout")
return 0
}
}
// no signals, receive should fail
assertEmpty()
rl.Signal(0)
if receiveWithTimeout() != 0 {
t.Fatalf("Received wrong ID of signalled event")
}
// no new signals, receive should fail
assertEmpty()
// Signals should not block;
// Duplicate unhandled signals should not cause multiple wakeups
signalled := [5]bool{}
for i := range signalled {
rl.Signal(uint32(i))
rl.Signal(uint32(i))
}
// All signals should be received once (in any order)
for range signalled {
i := receiveWithTimeout()
if signalled[i] {
t.Fatalf("Received signal %d more than once", i)
}
signalled[i] = true
}
}
func TestReadyDescriptorQueue(t *testing.T) {
var queue readyDescriptorQueue
items := [4]readyDescriptor{}
for i := range items {
items[i].ID = uint32(i)
}
if !queue.Empty() {
t.Fatalf("nil queue should be empty")
}
queue.Enqueue(&items[3])
queue.Enqueue(&items[1])
queue.Enqueue(&items[0])
queue.Enqueue(&items[2])
if queue.Empty() {
t.Fatalf("Empty should be false after enqueue")
}
i := queue.Dequeue().ID
if i != 3 {
t.Fatalf("item 3 should have been dequeued, got %d instead", i)
}
i = queue.Dequeue().ID
if i != 1 {
t.Fatalf("item 1 should have been dequeued, got %d instead", i)
}
i = queue.Dequeue().ID
if i != 0 {
t.Fatalf("item 0 should have been dequeued, got %d instead", i)
}
i = queue.Dequeue().ID
if i != 2 {
t.Fatalf("item 2 should have been dequeued, got %d instead", i)
}
if !queue.Empty() {
t.Fatal("queue should be empty after dequeuing all items")
}
if queue.Dequeue() != nil {
t.Fatal("dequeue on empty queue should return nil")
}
}
func TestReadyDescriptorMap(t *testing.T) {
m := newReadyDescriptorMap()
m.Delete(42)
// (delete of missing key should be a noop)
x := m.SetIfMissing(42)
if x == nil {
t.Fatal("SetIfMissing for new key returned nil")
}
if m.SetIfMissing(42) != nil {
t.Fatal("SetIfMissing for existing key returned non-nil")
}
// this delete has effect
m.Delete(42)
// the next set should reuse the old object
y := m.SetIfMissing(666)
if y == nil {
t.Fatal("SetIfMissing for new key returned nil")
}
if x != y {
t.Fatal("SetIfMissing didn't reuse freed object")
}
}

53
h2mux/rtt.go Normal file
View File

@ -0,0 +1,53 @@
package h2mux
import (
"sync/atomic"
"time"
)
// PingTimestamp is an atomic interface around ping timestamping and signalling.
type PingTimestamp struct {
ts int64
signal Signal
}
func NewPingTimestamp() *PingTimestamp {
return &PingTimestamp{signal: NewSignal()}
}
func (pt *PingTimestamp) Set(v int64) {
if atomic.SwapInt64(&pt.ts, v) != 0 {
pt.signal.Signal()
}
}
func (pt *PingTimestamp) Get() int64 {
return atomic.SwapInt64(&pt.ts, 0)
}
func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} {
return pt.signal.WaitChannel()
}
// RTTMeasurement encapsulates a continuous round trip time measurement.
type RTTMeasurement struct {
Current, Min, Max time.Duration
lastMeasurementTime time.Time
}
// Update updates the computed values with a new measurement.
// outgoingTime is the time that the probe was sent.
// We assume that time.Now() is the time we received that probe.
func (r *RTTMeasurement) Update(outgoingTime time.Time) {
if !r.lastMeasurementTime.Before(outgoingTime) {
return
}
r.lastMeasurementTime = outgoingTime
r.Current = time.Since(outgoingTime)
if r.Max < r.Current {
r.Max = r.Current
}
if r.Min > r.Current {
r.Min = r.Current
}
}

64
h2mux/shared_buffer.go Normal file
View File

@ -0,0 +1,64 @@
package h2mux
import (
"bytes"
"io"
"sync"
)
type SharedBuffer struct {
cond *sync.Cond
buffer bytes.Buffer
eof bool
}
func NewSharedBuffer() *SharedBuffer {
return &SharedBuffer{
cond: sync.NewCond(&sync.Mutex{}),
}
}
func (s *SharedBuffer) Read(p []byte) (n int, err error) {
totalRead := 0
s.cond.L.Lock()
for totalRead < len(p) {
n, err = s.buffer.Read(p[totalRead:])
totalRead += n
if err == io.EOF {
if s.eof {
break
}
err = nil
s.cond.Wait()
}
}
s.cond.L.Unlock()
return totalRead, err
}
func (s *SharedBuffer) Write(p []byte) (n int, err error) {
s.cond.L.Lock()
defer s.cond.L.Unlock()
if s.eof {
return 0, io.EOF
}
n, err = s.buffer.Write(p)
s.cond.Signal()
return
}
func (s *SharedBuffer) Close() error {
s.cond.L.Lock()
defer s.cond.L.Unlock()
if !s.eof {
s.eof = true
s.cond.Signal()
}
return nil
}
func (s *SharedBuffer) Closed() bool {
s.cond.L.Lock()
defer s.cond.L.Unlock()
return s.eof
}

120
h2mux/shared_buffer_test.go Normal file
View File

@ -0,0 +1,120 @@
package h2mux
import (
"bytes"
"io"
"sync"
"testing"
"time"
)
func AssertIOReturnIsGood(t *testing.T, expected int) func(int, error) {
return func(actual int, err error) {
if expected != actual {
t.Fatalf("Expected %d bytes, got %d", expected, actual)
}
if err != nil {
t.Fatalf("Unexpected error %s", err)
}
}
}
func TestSharedBuffer(t *testing.T) {
b := NewSharedBuffer()
testData := []byte("Hello world")
AssertIOReturnIsGood(t, len(testData))(b.Write(testData))
bytesRead := make([]byte, len(testData))
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
}
func TestSharedBufferBlockingRead(t *testing.T) {
b := NewSharedBuffer()
testData := []byte("Hello world")
result := make(chan []byte)
go func() {
bytesRead := make([]byte, len(testData))
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
result <- bytesRead
}()
select {
case <-result:
t.Fatalf("read returned early")
default:
}
AssertIOReturnIsGood(t, 5)(b.Write(testData[:5]))
select {
case <-result:
t.Fatalf("read returned early")
default:
}
AssertIOReturnIsGood(t, len(testData)-5)(b.Write(testData[5:]))
select {
case r := <-result:
if string(r) != string(testData) {
t.Fatalf("expected read to return %s, got %s", testData, r)
}
case <-time.After(time.Second):
t.Fatalf("read timed out")
}
}
// This is quite slow under the race detector
func TestSharedBufferConcurrentReadWrite(t *testing.T) {
b := NewSharedBuffer()
var expectedResult, actualResult bytes.Buffer
var wg sync.WaitGroup
wg.Add(2)
go func() {
block := make([]byte, 256)
for i := range block {
block[i] = byte(i)
}
for blockSize := 1; blockSize <= 256; blockSize++ {
for i := 0; i < 256; i++ {
expectedResult.Write(block[:blockSize])
n, err := b.Write(block[:blockSize])
if n != blockSize || err != nil {
t.Fatalf("write error: %d %s", n, err)
}
}
}
wg.Done()
}()
go func() {
block := make([]byte, 256)
// Change block sizes in opposition to the write thread, to test blocking for new data.
for blockSize := 256; blockSize > 0; blockSize-- {
for i := 0; i < 256; i++ {
n, err := b.Read(block[:blockSize])
if n != blockSize || err != nil {
t.Fatalf("read error: %d %s", n, err)
}
actualResult.Write(block[:blockSize])
}
}
wg.Done()
}()
wg.Wait()
if bytes.Compare(expectedResult.Bytes(), actualResult.Bytes()) != 0 {
t.Fatal("Result diverged")
}
}
func TestSharedBufferClose(t *testing.T) {
b := NewSharedBuffer()
testData := []byte("Hello world")
AssertIOReturnIsGood(t, len(testData))(b.Write(testData))
err := b.Close()
if err != nil {
t.Fatalf("unexpected error from Close: %s", err)
}
bytesRead := make([]byte, len(testData))
AssertIOReturnIsGood(t, len(testData))(b.Read(bytesRead))
n, err := b.Read(bytesRead)
if n != 0 {
t.Fatalf("extra bytes received: %d", n)
}
if err != io.EOF {
t.Fatalf("expected EOF, got %s", err)
}
}

34
h2mux/signal.go Normal file
View File

@ -0,0 +1,34 @@
package h2mux
// Signal describes an event that can be waited on for at least one signal.
// Signalling the event while it is in the signalled state is a noop.
// When the waiter wakes up, the signal is set to unsignalled.
// It is a way for any number of writers to inform a reader (without blocking)
// that an event has happened.
type Signal struct {
c chan struct{}
}
// NewSignal creates a new Signal.
func NewSignal() Signal {
return Signal{c: make(chan struct{}, 1)}
}
// Signal signals the event.
func (s Signal) Signal() {
// This channel is buffered, so the nonblocking send will always succeed if the buffer is empty.
select {
case s.c <- struct{}{}:
default:
}
}
// Wait for the event to be signalled.
func (s Signal) Wait() {
<-s.c
}
// WaitChannel returns a channel that is readable after Signal is called.
func (s Signal) WaitChannel() <-chan struct{} {
return s.c
}

47
h2mux/streamerrormap.go Normal file
View File

@ -0,0 +1,47 @@
package h2mux
import (
"sync"
"golang.org/x/net/http2"
)
// StreamErrorMap is used to track stream errors. This is a separate structure to ActiveStreamMap because
// errors can be raised against non-existent or closed streams.
type StreamErrorMap struct {
sync.RWMutex
// errors tracks per-stream errors
errors map[uint32]http2.ErrCode
// hasError is signaled whenever an error is raised.
hasError Signal
}
// NewStreamErrorMap creates a new StreamErrorMap.
func NewStreamErrorMap() *StreamErrorMap {
return &StreamErrorMap{
errors: make(map[uint32]http2.ErrCode),
hasError: NewSignal(),
}
}
// RaiseError raises a stream error.
func (s *StreamErrorMap) RaiseError(streamID uint32, err http2.ErrCode) {
s.Lock()
s.errors[streamID] = err
s.Unlock()
s.hasError.Signal()
}
// GetSignalChan returns a channel that is signalled when an error is raised.
func (s *StreamErrorMap) GetSignalChan() <-chan struct{} {
return s.hasError.WaitChannel()
}
// GetErrors retrieves all errors currently raised. This resets the currently-tracked errors.
func (s *StreamErrorMap) GetErrors() map[uint32]http2.ErrCode {
s.Lock()
errors := s.errors
s.errors = make(map[uint32]http2.ErrCode)
s.Unlock()
return errors
}

48
metrics/metrics.go Normal file
View File

@ -0,0 +1,48 @@
package metrics
import (
"net"
"net/http"
_ "net/http/pprof"
"runtime"
"time"
"golang.org/x/net/context"
log "github.com/Sirupsen/logrus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
func ServeMetrics(l net.Listener, shutdownC <-chan struct{}) error {
server := &http.Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
}
go func() {
<-shutdownC
server.Shutdown(context.Background())
}()
http.Handle("/metrics", promhttp.Handler())
log.WithField("addr", l.Addr()).Info("Starting metrics server")
err := server.Serve(l)
if err == http.ErrServerClosed {
log.Info("Metrics server stopped")
return nil
}
log.WithError(err).Error("Metrics server quit with error")
return err
}
func RegisterBuildInfo(buildTime string, version string) {
buildInfo := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
// Don't namespace build_info, since we want it to be consistent across all Cloudflare services
Name: "build_info",
Help: "Build and version information",
},
[]string{"goversion", "revision", "version"},
)
prometheus.MustRegister(buildInfo)
buildInfo.WithLabelValues(runtime.Version(), buildTime, version).Set(1)
}

70
origin/backoffhandler.go Normal file
View File

@ -0,0 +1,70 @@
package origin
import (
"time"
"golang.org/x/net/context"
)
// Redeclare time functions so they can be overridden in tests.
var (
timeNow = time.Now
timeAfter = time.After
)
// BackoffHandler manages exponential backoff and limits the maximum number of retries.
// The base time period is 1 second, doubling with each retry.
// After initial success, a grace period can be set to reset the backoff timer if
// a connection is maintained successfully for a long enough period. The base grace period
// is 2 seconds, doubling with each retry.
type BackoffHandler struct {
// MaxRetries sets the maximum number of retries to perform. The default value
// of 0 disables retry completely.
MaxRetries uint
retries uint
resetDeadline time.Time
}
func (b BackoffHandler) GetBackoffDuration(ctx context.Context) (time.Duration, bool) {
// Follows the same logic as Backoff, but without mutating the receiver.
// This select has to happen first to reflect the actual behaviour of the Backoff function.
select {
case <-ctx.Done():
return time.Duration(0), false
default:
}
if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) {
// b.retries would be set to 0 at this point
return time.Second, true
}
if b.retries >= b.MaxRetries {
return time.Duration(0), false
}
return time.Duration(time.Second * 1 << b.retries), true
}
// Backoff is used to wait according to exponential backoff. Returns false if the
// maximum number of retries have been used or if the underlying context has been cancelled.
func (b *BackoffHandler) Backoff(ctx context.Context) bool {
if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) {
b.retries = 0
b.resetDeadline = time.Time{}
}
if b.retries >= b.MaxRetries {
return false
}
select {
case <-timeAfter(time.Duration(time.Second * 1 << b.retries)):
b.retries++
return true
case <-ctx.Done():
return false
}
}
// Sets a grace period within which the the backoff timer is maintained. After the grace
// period expires, the number of retries & backoff duration is reset.
func (b *BackoffHandler) SetGracePeriod() {
b.resetDeadline = timeNow().Add(time.Duration(time.Second * 2 << b.retries))
}

View File

@ -0,0 +1,114 @@
package origin
import (
"testing"
"time"
"golang.org/x/net/context"
)
func immediateTimeAfter(time.Duration) <-chan time.Time {
c := make(chan time.Time, 1)
c <- time.Now()
return c
}
func TestBackoffRetries(t *testing.T) {
// make backoff return immediately
timeAfter = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 3}
if !backoff.Backoff(ctx) {
t.Fatalf("backoff failed immediately")
}
if !backoff.Backoff(ctx) {
t.Fatalf("backoff failed after 1 retry")
}
if !backoff.Backoff(ctx) {
t.Fatalf("backoff failed after 2 retry")
}
if backoff.Backoff(ctx) {
t.Fatalf("backoff allowed after 3 (max) retries")
}
}
func TestBackoffCancel(t *testing.T) {
// prevent backoff from returning normally
timeAfter = func(time.Duration) <-chan time.Time { return make(chan time.Time) }
ctx, cancelFunc := context.WithCancel(context.Background())
backoff := BackoffHandler{MaxRetries: 3}
cancelFunc()
if backoff.Backoff(ctx) {
t.Fatalf("backoff allowed after cancel")
}
if _, ok := backoff.GetBackoffDuration(ctx); ok {
t.Fatalf("backoff allowed after cancel")
}
}
func TestBackoffGracePeriod(t *testing.T) {
currentTime := time.Now()
// make timeNow return whatever we like
timeNow = func() time.Time { return currentTime }
// make backoff return immediately
timeAfter = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 1}
if !backoff.Backoff(ctx) {
t.Fatalf("backoff failed immediately")
}
// the next call to Backoff would fail unless it's after the grace period
backoff.SetGracePeriod()
// advance time to after the grace period (~4 seconds) and see what happens
currentTime = currentTime.Add(time.Second * 5)
if !backoff.Backoff(ctx) {
t.Fatalf("backoff failed after the grace period expired")
}
// confirm we ignore grace period after backoff
if backoff.Backoff(ctx) {
t.Fatalf("backoff allowed after 1 (max) retry")
}
}
func TestGetBackoffDurationRetries(t *testing.T) {
// make backoff return immediately
timeAfter = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 3}
if _, ok := backoff.GetBackoffDuration(ctx); !ok {
t.Fatalf("backoff failed immediately")
}
backoff.Backoff(ctx) // noop
if _, ok := backoff.GetBackoffDuration(ctx); !ok {
t.Fatalf("backoff failed after 1 retry")
}
backoff.Backoff(ctx) // noop
if _, ok := backoff.GetBackoffDuration(ctx); !ok {
t.Fatalf("backoff failed after 2 retry")
}
backoff.Backoff(ctx) // noop
if _, ok := backoff.GetBackoffDuration(ctx); ok {
t.Fatalf("backoff allowed after 3 (max) retries")
}
if backoff.Backoff(ctx) {
t.Fatalf("backoff allowed after 3 (max) retries")
}
}
func TestGetBackoffDuration(t *testing.T) {
// make backoff return immediately
timeAfter = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 3}
if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second {
t.Fatalf("backoff didn't return 1 second on first retry")
}
backoff.Backoff(ctx) // noop
if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*2 {
t.Fatalf("backoff didn't return 2 seconds on second retry")
}
backoff.Backoff(ctx) // noop
if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*4 {
t.Fatalf("backoff didn't return 4 seconds on third retry")
}
}

360
origin/tunnel.go Normal file
View File

@ -0,0 +1,360 @@
package origin
import (
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/url"
"runtime"
"strings"
"time"
"golang.org/x/net/context"
"github.com/cloudflare/cloudflare-warp/h2mux"
"github.com/cloudflare/cloudflare-warp/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflare-warp/tunnelrpc/pogs"
"github.com/cloudflare/cloudflare-warp/validation"
log "github.com/Sirupsen/logrus"
raven "github.com/getsentry/raven-go"
"github.com/pkg/errors"
rpc "zombiezen.com/go/capnproto2/rpc"
)
const (
dialTimeout = 15 * time.Second
TagHeaderNamePrefix = "Cf-Warp-Tag-"
)
type TunnelConfig struct {
EdgeAddr string
OriginUrl string
Hostname string
APIKey string
APIEmail string
APICAKey string
TlsConfig *tls.Config
Retries uint
HeartbeatInterval time.Duration
MaxHeartbeats uint64
ClientID string
ReportedVersion string
LBPool string
Tags []tunnelpogs.Tag
AccessInternalIP bool
ConnectedSignal h2mux.Signal
}
type dialError struct {
cause error
}
func (e dialError) Error() string {
return e.cause.Error()
}
type printableRegisterTunnelError struct {
cause error
permanent bool
}
func (e printableRegisterTunnelError) Error() string {
return e.cause.Error()
}
func (c *TunnelConfig) RegistrationOptions() *tunnelpogs.RegistrationOptions {
policy := tunnelrpc.ExistingTunnelPolicy_disconnect
if c.LBPool != "" {
policy = tunnelrpc.ExistingTunnelPolicy_balance
}
return &tunnelpogs.RegistrationOptions{
ClientID: c.ClientID,
Version: c.ReportedVersion,
OS: fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH),
ExistingTunnelPolicy: policy,
PoolID: c.LBPool,
Tags: c.Tags,
ExposeInternalHostname: c.AccessInternalIP,
}
}
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}) error {
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-shutdownC
cancel()
}()
backoff := BackoffHandler{MaxRetries: config.Retries}
for {
err, recoverable := ServeTunnel(ctx, config, &backoff)
if recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
log.Infof("Retrying in %s seconds", duration)
backoff.Backoff(ctx)
continue
}
}
return err
}
}
func ServeTunnel(
ctx context.Context,
config *TunnelConfig,
backoff *BackoffHandler,
) (err error, recoverable bool) {
// Returns error from parsing the origin URL or handshake errors
handler, err := NewTunnelHandler(ctx, config)
if err != nil {
errLog := log.WithError(err)
switch err.(type) {
case dialError:
errLog.Error("Unable to dial edge")
case h2mux.MuxerHandshakeError:
errLog.Error("Handshake failed with edge server")
default:
errLog.Error("Tunnel creation failure")
return err, false
}
return err, true
}
serveCtx, serveCancel := context.WithCancel(ctx)
registerErrC := make(chan error, 1)
go func() {
err := RegisterTunnel(serveCtx, handler.muxer, config)
if err == nil {
config.ConnectedSignal.Signal()
backoff.SetGracePeriod()
} else {
serveCancel()
}
registerErrC <- err
}()
go func() {
<-serveCtx.Done()
handler.muxer.Shutdown()
}()
err = handler.muxer.Serve()
serveCancel()
registerErr := <-registerErrC
if err != nil {
log.WithError(err).Error("Tunnel error")
return err, true
}
if registerErr != nil {
raven.CaptureError(registerErr, nil)
// Don't retry on errors like entitlement failure or version too old
if e, ok := registerErr.(printableRegisterTunnelError); ok {
log.WithError(e).Error("Cannot register")
if e.permanent {
return nil, false
}
return e.cause, true
}
log.Error("Cannot register")
return err, true
}
return nil, false
}
func IsRPCStreamResponse(headers []h2mux.Header) bool {
if len(headers) != 1 {
return false
}
if headers[0].Name != ":status" || headers[0].Value != "200" {
return false
}
return true
}
func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig) error {
logger := log.WithField("subsystem", "rpc")
logger.Debug("initiating RPC stream")
stream, err := muxer.OpenStream([]h2mux.Header{
{Name: ":method", Value: "RPC"},
{Name: ":scheme", Value: "capnp"},
{Name: ":path", Value: "*"},
}, nil)
if err != nil {
// RPC stream open error
raven.CaptureError(err, nil)
return err
}
if !IsRPCStreamResponse(stream.Headers) {
// stream response error
raven.CaptureError(err, nil)
return err
}
conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
)
defer conn.Close()
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
// Request server info without blocking tunnel registration; must use capnp library directly.
tsClient := tunnelrpc.TunnelServer{Client: ts.Client}
serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
return nil
})
registration, err := ts.RegisterTunnel(
ctx,
&tunnelpogs.Authentication{Key: config.APIKey, Email: config.APIEmail, OriginCAKey: config.APICAKey},
config.Hostname,
config.RegistrationOptions(),
)
LogServerInfo(logger, serverInfoPromise.Result())
if err != nil {
// RegisterTunnel RPC failure
return err
}
for _, logLine := range registration.LogLines {
logger.Info(logLine)
}
if registration.Err != "" {
return printableRegisterTunnelError{
cause: fmt.Errorf("Server error: %s", registration.Err),
permanent: registration.PermanentFailure,
}
}
for _, url := range registration.Urls {
log.Infof("Registered at %s", url)
}
for _, logLine := range registration.LogLines {
log.Infof(logLine)
}
return nil
}
func LogServerInfo(logger *log.Entry, promise tunnelrpc.ServerInfo_Promise) {
serverInfoMessage, err := promise.Struct()
if err != nil {
logger.WithError(err).Warn("Failed to retrieve server information")
return
}
serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage)
if err != nil {
logger.WithError(err).Warn("Failed to retrieve server information")
return
}
log.Infof("Connected to %s", serverInfo.LocationName)
}
type TunnelHandler struct {
originUrl string
muxer *h2mux.Muxer
httpClient *http.Client
tags []tunnelpogs.Tag
}
var dialer = net.Dialer{DualStack: true}
func NewTunnelHandler(ctx context.Context, config *TunnelConfig) (*TunnelHandler, error) {
url, err := validation.ValidateUrl(config.OriginUrl)
if err != nil {
return nil, fmt.Errorf("Unable to parse origin url %#v", url)
}
h := &TunnelHandler{
originUrl: url,
httpClient: &http.Client{Timeout: time.Minute},
tags: config.Tags,
}
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
// TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one)
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", config.EdgeAddr)
dialCancel()
if err != nil {
return nil, dialError{cause: err}
}
edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig)
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
err = edgeConn.Handshake()
if err != nil {
return nil, dialError{cause: err}
}
// clear the deadline on the conn; h2mux has its own timeouts
edgeConn.SetDeadline(time.Time{})
// Establish a muxed connection with the edge
// Client mux handshake with agent server
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
Timeout: 5 * time.Second,
Handler: h,
IsClient: true,
HeartbeatInterval: config.HeartbeatInterval,
MaxHeartbeats: config.MaxHeartbeats,
})
if err != nil {
return h, errors.New("TLS handshake error")
}
return h, err
}
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
for _, header := range h2 {
switch header.Name {
case ":method":
h1.Method = header.Value
case ":scheme":
case ":authority":
// Otherwise the host header will be based on the origin URL
h1.Host = header.Value
case ":path":
u, err := url.Parse(header.Value)
if err != nil {
return fmt.Errorf("unparseable path")
}
resolved := h1.URL.ResolveReference(u)
// prevent escaping base URL
if !strings.HasPrefix(resolved.String(), h1.URL.String()) {
return fmt.Errorf("invalid path")
}
h1.URL = resolved
default:
h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value)
}
}
return nil
}
func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}}
for headerName, headerValues := range h1.Header {
for _, headerValue := range headerValues {
h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue})
}
}
return
}
func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
for _, tag := range h.tags {
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
}
}
func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
if err != nil {
log.WithError(err).Panic("Unexpected error from http.NewRequest")
}
err = H2RequestHeadersToH1Request(stream.Headers, req)
if err != nil {
log.WithError(err).Error("invalid request received")
}
h.AppendTagHeaders(req)
response, err := h.httpClient.Do(req)
if err != nil {
log.WithError(err).Error("HTTP request error")
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
stream.Write([]byte("502 Bad Gateway"))
} else {
defer response.Body.Close()
stream.WriteHeaders(H1ResponseToH2Response(response))
io.Copy(stream, response.Body)
}
return nil
}

62
tlsconfig/tlsconfig.go Normal file
View File

@ -0,0 +1,62 @@
// Package tlsconfig provides convenience functions for configuring TLS connections from the
// command line.
package tlsconfig
import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
log "github.com/Sirupsen/logrus"
cli "gopkg.in/urfave/cli.v2"
)
// CLIFlags names the flags used to configure TLS for a command or subsystem.
// The nil value for a field means the flag is ignored.
type CLIFlags struct {
Cert string
Key string
ClientCert string
RootCA string
}
// GetConfig returns a TLS configuration according to the flags defined in f and
// set by the user.
func (f CLIFlags) GetConfig(c *cli.Context) *tls.Config {
config := &tls.Config{}
if c.IsSet(f.Cert) && c.IsSet(f.Key) {
cert, err := tls.LoadX509KeyPair(c.String(f.Cert), c.String(f.Key))
if err != nil {
log.WithError(err).Fatal("Error parsing X509 key pair")
}
config.Certificates = []tls.Certificate{cert}
config.BuildNameToCertificate()
}
if c.IsSet(f.ClientCert) {
// set of root certificate authorities that servers use if required to verify a client certificate
// by the policy in ClientAuth
config.ClientCAs = LoadCert(c.String(f.ClientCert))
// server's policy for TLS Client Authentication. Default is no client cert
config.ClientAuth = tls.RequireAndVerifyClientCert
}
// set of root certificate authorities that clients use when verifying server certificates
if c.IsSet(f.RootCA) {
config.RootCAs = LoadCert(c.String(f.RootCA))
}
return config
}
// LoadCert creates a CertPool containing all certificates in a PEM-format file.
func LoadCert(certPath string) *x509.CertPool {
caCert, err := ioutil.ReadFile(certPath)
if err != nil {
log.WithError(err).Fatalf("Error reading certificate %s", certPath)
}
ca := x509.NewCertPool()
if !ca.AppendCertsFromPEM(caCert) {
log.WithError(err).Fatalf("Error parsing certificate %s", certPath)
}
return ca
}

15
tunnelrpc/go.capnp Normal file
View File

@ -0,0 +1,15 @@
# Generate go.capnp.out with:
# capnp compile -o- go.capnp > go.capnp.out
# Must run inside this directory to preserve paths.
@0xd12a1c51fedd6c88;
annotation package(file) :Text;
annotation import(file) :Text;
annotation doc(struct, field, enum) :Text;
annotation tag(enumerant) :Text;
annotation notag(enumerant) :Void;
annotation customtype(field) :Text;
annotation name(struct, field, union, enum, enumerant, interface, method, param, annotation, const, group) :Text;
$package("capnp");

26
tunnelrpc/log.go Normal file
View File

@ -0,0 +1,26 @@
package tunnelrpc
//go:generate capnp compile -ogo -I./tunnelrpc/ tunnelrpc.capnp
import (
log "github.com/Sirupsen/logrus"
"golang.org/x/net/context"
"zombiezen.com/go/capnproto2/rpc"
)
// ConnLogger wraps a logrus *log.Entry for a connection.
type ConnLogger struct {
Entry *log.Entry
}
func (c ConnLogger) Infof(ctx context.Context, format string, args ...interface{}) {
c.Entry.Infof(format, args...)
}
func (c ConnLogger) Errorf(ctx context.Context, format string, args ...interface{}) {
c.Entry.Errorf(format, args...)
}
func ConnLog(log *log.Entry) rpc.ConnOption {
return rpc.ConnLog(ConnLogger{log})
}

45
tunnelrpc/logtransport.go Normal file
View File

@ -0,0 +1,45 @@
// Package logtransport provides a transport that logs all of its messages.
package tunnelrpc
import (
"bytes"
log "github.com/Sirupsen/logrus"
"golang.org/x/net/context"
"zombiezen.com/go/capnproto2/encoding/text"
"zombiezen.com/go/capnproto2/rpc"
rpccapnp "zombiezen.com/go/capnproto2/std/capnp/rpc"
)
type transport struct {
rpc.Transport
l *log.Entry
}
// New creates a new logger that proxies messages to and from t and
// logs them to l. If l is nil, then the log package's default
// logger is used.
func NewTransportLogger(l *log.Entry, t rpc.Transport) rpc.Transport {
return &transport{Transport: t, l: l}
}
func (t *transport) SendMessage(ctx context.Context, msg rpccapnp.Message) error {
t.l.Debugf("tx %s", formatMsg(msg))
return t.Transport.SendMessage(ctx, msg)
}
func (t *transport) RecvMessage(ctx context.Context) (rpccapnp.Message, error) {
msg, err := t.Transport.RecvMessage(ctx)
if err != nil {
t.l.WithError(err).Debug("rx error")
return msg, err
}
t.l.Debugf("rx %s", formatMsg(msg))
return msg, nil
}
func formatMsg(m rpccapnp.Message) string {
var buf bytes.Buffer
text.NewEncoder(&buf).Encode(0x91b79f1f808db032, m.Struct)
return buf.String()
}

194
tunnelrpc/pogs/tunnelrpc.go Normal file
View File

@ -0,0 +1,194 @@
package pogs
import (
"github.com/cloudflare/cloudflare-warp/tunnelrpc"
"golang.org/x/net/context"
"zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs"
"zombiezen.com/go/capnproto2/rpc"
"zombiezen.com/go/capnproto2/server"
)
type Authentication struct {
Key string
Email string
OriginCAKey string
}
func MarshalAuthentication(s tunnelrpc.Authentication, p *Authentication) error {
return pogs.Insert(tunnelrpc.Authentication_TypeID, s.Struct, p)
}
func UnmarshalAuthentication(s tunnelrpc.Authentication) (*Authentication, error) {
p := new(Authentication)
err := pogs.Extract(p, tunnelrpc.Authentication_TypeID, s.Struct)
return p, err
}
type TunnelRegistration struct {
Err string
Urls []string
LogLines []string
PermanentFailure bool
}
func MarshalTunnelRegistration(s tunnelrpc.TunnelRegistration, p *TunnelRegistration) error {
return pogs.Insert(tunnelrpc.TunnelRegistration_TypeID, s.Struct, p)
}
func UnmarshalTunnelRegistration(s tunnelrpc.TunnelRegistration) (*TunnelRegistration, error) {
p := new(TunnelRegistration)
err := pogs.Extract(p, tunnelrpc.TunnelRegistration_TypeID, s.Struct)
return p, err
}
type RegistrationOptions struct {
ClientID string `capnp:"clientId"`
Version string
OS string `capnp:"os"`
ExistingTunnelPolicy tunnelrpc.ExistingTunnelPolicy
PoolID string `capnp:"poolId"`
ExposeInternalHostname bool
Tags []Tag
}
func MarshalRegistrationOptions(s tunnelrpc.RegistrationOptions, p *RegistrationOptions) error {
return pogs.Insert(tunnelrpc.RegistrationOptions_TypeID, s.Struct, p)
}
func UnmarshalRegistrationOptions(s tunnelrpc.RegistrationOptions) (*RegistrationOptions, error) {
p := new(RegistrationOptions)
err := pogs.Extract(p, tunnelrpc.RegistrationOptions_TypeID, s.Struct)
return p, err
}
type Tag struct {
Name string
Value string
}
type ServerInfo struct {
LocationName string
}
func MarshalServerInfo(s tunnelrpc.ServerInfo, p *ServerInfo) error {
return pogs.Insert(tunnelrpc.ServerInfo_TypeID, s.Struct, p)
}
func UnmarshalServerInfo(s tunnelrpc.ServerInfo) (*ServerInfo, error) {
p := new(ServerInfo)
err := pogs.Extract(p, tunnelrpc.ServerInfo_TypeID, s.Struct)
return p, err
}
type TunnelServer interface {
RegisterTunnel(ctx context.Context, auth *Authentication, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
GetServerInfo(ctx context.Context) (*ServerInfo, error)
}
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
return tunnelrpc.TunnelServer_ServerToClient(TunnelServer_PogsImpl{s})
}
type TunnelServer_PogsImpl struct {
impl TunnelServer
}
func (i TunnelServer_PogsImpl) RegisterTunnel(p tunnelrpc.TunnelServer_registerTunnel) error {
authentication, err := p.Params.Auth()
if err != nil {
return err
}
pogsAuthentication, err := UnmarshalAuthentication(authentication)
if err != nil {
return err
}
hostname, err := p.Params.Hostname()
if err != nil {
return err
}
options, err := p.Params.Options()
if err != nil {
return err
}
pogsOptions, err := UnmarshalRegistrationOptions(options)
if err != nil {
return err
}
server.Ack(p.Options)
registration, err := i.impl.RegisterTunnel(p.Ctx, pogsAuthentication, hostname, pogsOptions)
if err != nil {
return err
}
result, err := p.Results.NewResult()
if err != nil {
return err
}
return MarshalTunnelRegistration(result, registration)
}
func (i TunnelServer_PogsImpl) GetServerInfo(p tunnelrpc.TunnelServer_getServerInfo) error {
server.Ack(p.Options)
serverInfo, err := i.impl.GetServerInfo(p.Ctx)
if err != nil {
return err
}
result, err := p.Results.NewResult()
if err != nil {
return err
}
return MarshalServerInfo(result, serverInfo)
}
type TunnelServer_PogsClient struct {
Client capnp.Client
Conn *rpc.Conn
}
func (c TunnelServer_PogsClient) Close() error {
return c.Conn.Close()
}
func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, auth *Authentication, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.RegisterTunnel(ctx, func(p tunnelrpc.TunnelServer_registerTunnel_Params) error {
authentication, err := p.NewAuth()
if err != nil {
return err
}
err = MarshalAuthentication(authentication, auth)
if err != nil {
return err
}
err = p.SetHostname(hostname)
if err != nil {
return err
}
registrationOptions, err := p.NewOptions()
if err != nil {
return err
}
err = MarshalRegistrationOptions(registrationOptions, options)
if err != nil {
return err
}
return nil
})
retval, err := promise.Result().Struct()
if err != nil {
return nil, err
}
return UnmarshalTunnelRegistration(retval)
}
func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.GetServerInfo(ctx, func(p tunnelrpc.TunnelServer_getServerInfo_Params) error {
return nil
})
retval, err := promise.Result().Struct()
if err != nil {
return nil, err
}
return UnmarshalServerInfo(retval)
}

56
tunnelrpc/tunnelrpc.capnp Normal file
View File

@ -0,0 +1,56 @@
using Go = import "go.capnp";
@0xdb8274f9144abc7e;
$Go.package("tunnelrpc");
$Go.import("github.com/cloudflare/cloudflare-warp/tunnelrpc");
struct Authentication {
key @0 :Text;
email @1 :Text;
originCAKey @2 :Text;
}
struct TunnelRegistration {
err @0 :Text;
# A list of URLs that the tunnel is accessible from.
urls @1 :List(Text);
# Used to inform the client of actions taken.
logLines @2 :List(Text);
# In case of error, whether the client should attempt to reconnect.
permanentFailure @3 :Bool;
}
struct RegistrationOptions {
# The tunnel client's unique identifier, used to verify a reconnection.
clientId @0 :Text;
# Information about the running binary.
version @1 :Text;
os @2 :Text;
# What to do with existing tunnels for the given hostname.
existingTunnelPolicy @3 :ExistingTunnelPolicy;
# If using the balancing policy, identifies the LB pool to use.
poolId @4 :Text;
# Prevents the tunnel from being accessed at <subdomain>.cftunnel.com
exposeInternalHostname @5 :Bool;
# Client-defined tags to associate with the tunnel
tags @6 :List(Tag);
}
struct Tag {
name @0 :Text;
value @1 :Text;
}
enum ExistingTunnelPolicy {
ignore @0;
disconnect @1;
balance @2;
}
struct ServerInfo {
locationName @0 :Text;
}
interface TunnelServer {
registerTunnel @0 (auth :Authentication, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
getServerInfo @1 () -> (result :ServerInfo);
}

1145
tunnelrpc/tunnelrpc.capnp.go Normal file

File diff suppressed because it is too large Load Diff

136
validation/validation.go Normal file
View File

@ -0,0 +1,136 @@
package validation
import (
"fmt"
"net"
"net/url"
"strings"
"golang.org/x/net/idna"
)
const defaultScheme = "http"
var supportedProtocol = [2]string{"http", "https"}
func ValidateHostname(hostname string) (string, error) {
if hostname == "" {
return "", fmt.Errorf("Hostname should not be empty")
}
// users gives url(contains schema) not just hostname
if strings.Contains(hostname, ":") || strings.Contains(hostname, "%3A") {
unescapeHostname, err := url.PathUnescape(hostname)
if err != nil {
return "", fmt.Errorf("Hostname(actually a URL) %s has invalid escape characters %s", hostname, unescapeHostname)
}
hostnameToURL, err := url.Parse(unescapeHostname)
if err != nil {
return "", fmt.Errorf("Hostname(actually a URL) %s has invalid format %s", hostname, hostnameToURL)
}
asciiHostname, err := idna.ToASCII(hostnameToURL.Hostname())
if err != nil {
return "", fmt.Errorf("Hostname(actually a URL) %s has invalid ASCII encdoing %s", hostname, asciiHostname)
}
return asciiHostname, nil
}
asciiHostname, err := idna.ToASCII(hostname)
if err != nil {
return "", fmt.Errorf("Hostname %s has invalid ASCII encdoing %s", hostname, asciiHostname)
}
hostnameToURL, err := url.Parse(asciiHostname)
if err != nil {
return "", fmt.Errorf("Hostname %s is not valid", hostnameToURL)
}
return hostnameToURL.RequestURI(), nil
}
func ValidateUrl(originUrl string) (string, error) {
if originUrl == "" {
return "", fmt.Errorf("Url should not be empty")
}
if net.ParseIP(originUrl) != nil {
return validateIP("", originUrl, "")
} else if strings.HasPrefix(originUrl, "[") && strings.HasSuffix(originUrl, "]") {
// ParseIP doesn't recoginze [::1]
return validateIP("", originUrl[1:len(originUrl)-1], "")
}
host, port, err := net.SplitHostPort(originUrl)
// user might pass in an ip address like 127.0.0.1
if err == nil && net.ParseIP(host) != nil {
return validateIP("", host, port)
}
unescapedUrl, err := url.PathUnescape(originUrl)
if err != nil {
return "", fmt.Errorf("URL %s has invalid escape characters %s", originUrl, unescapedUrl)
}
parsedUrl, err := url.Parse(unescapedUrl)
if err != nil {
return "", fmt.Errorf("URL %s has invalid format", originUrl)
}
// if the url is in the form of host:port, IsAbs() will think host is the schema
var hostname string
hasScheme := parsedUrl.IsAbs() && parsedUrl.Host != ""
if hasScheme {
err := validateScheme(parsedUrl.Scheme)
if err != nil {
return "", err
}
// The earlier check for ip address will miss the case http://[::1]
// and http://[::1]:8080
if net.ParseIP(parsedUrl.Hostname()) != nil {
return validateIP(parsedUrl.Scheme, parsedUrl.Hostname(), parsedUrl.Port())
}
hostname, err = ValidateHostname(parsedUrl.Hostname())
if err != nil {
return "", fmt.Errorf("URL %s has invalid format", originUrl)
}
if parsedUrl.Port() != "" {
return fmt.Sprintf("%s://%s", parsedUrl.Scheme, net.JoinHostPort(hostname, parsedUrl.Port())), nil
}
return fmt.Sprintf("%s://%s", parsedUrl.Scheme, hostname), nil
} else {
if host == "" {
hostname, err = ValidateHostname(originUrl)
if err != nil {
return "", fmt.Errorf("URL no %s has invalid format", originUrl)
}
return fmt.Sprintf("%s://%s", defaultScheme, hostname), nil
} else {
hostname, err = ValidateHostname(host)
if err != nil {
return "", fmt.Errorf("URL %s has invalid format", originUrl)
}
return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
}
}
}
func validateScheme(scheme string) error {
for _, protocol := range supportedProtocol {
if scheme == protocol {
return nil
}
}
return fmt.Errorf("Currently Cloudflare-Warp does not support %s protocol.", scheme)
}
func validateIP(scheme, host, port string) (string, error) {
if scheme == "" {
scheme = defaultScheme
}
if port != "" {
return fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(host, port)), nil
} else if strings.Contains(host, ":") {
// IPv6
return fmt.Sprintf("%s://[%s]", scheme, host), nil
}
return fmt.Sprintf("%s://%s", scheme, host), nil
}

View File

@ -0,0 +1,136 @@
package validation
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestValidateHostname(t *testing.T) {
var inputHostname string
hostname, err := ValidateHostname(inputHostname)
assert.Equal(t, err, fmt.Errorf("Hostname should not be empty"))
assert.Empty(t, hostname)
inputHostname = "hello.example.com"
hostname, err = ValidateHostname(inputHostname)
assert.Nil(t, err)
assert.Equal(t, "hello.example.com", hostname)
inputHostname = "http://hello.example.com"
hostname, err = ValidateHostname(inputHostname)
assert.Nil(t, err)
assert.Equal(t, "hello.example.com", hostname)
inputHostname = "bücher.example.com"
hostname, err = ValidateHostname(inputHostname)
assert.Nil(t, err)
assert.Equal(t, "xn--bcher-kva.example.com", hostname)
inputHostname = "http://bücher.example.com"
hostname, err = ValidateHostname(inputHostname)
assert.Nil(t, err)
assert.Equal(t, "xn--bcher-kva.example.com", hostname)
inputHostname = "http%3A%2F%2Fhello.example.com"
hostname, err = ValidateHostname(inputHostname)
assert.Nil(t, err)
assert.Equal(t, "hello.example.com", hostname)
}
func TestValidateUrl(t *testing.T) {
validUrl, err := ValidateUrl("")
assert.Equal(t, fmt.Errorf("Url should not be empty"), err)
assert.Empty(t, validUrl)
validUrl, err = ValidateUrl("https://localhost:8080")
assert.Nil(t, err)
assert.Equal(t, "https://localhost:8080", validUrl)
validUrl, err = ValidateUrl("localhost:8080")
assert.Nil(t, err)
assert.Equal(t, "http://localhost:8080", validUrl)
validUrl, err = ValidateUrl("http://localhost")
assert.Nil(t, err)
assert.Equal(t, "http://localhost", validUrl)
validUrl, err = ValidateUrl("http://127.0.0.1:8080")
assert.Nil(t, err)
assert.Equal(t, "http://127.0.0.1:8080", validUrl)
validUrl, err = ValidateUrl("127.0.0.1:8080")
assert.Nil(t, err)
assert.Equal(t, "http://127.0.0.1:8080", validUrl)
validUrl, err = ValidateUrl("127.0.0.1")
assert.Nil(t, err)
assert.Equal(t, "http://127.0.0.1", validUrl)
validUrl, err = ValidateUrl("https://127.0.0.1:8080")
assert.Nil(t, err)
assert.Equal(t, "https://127.0.0.1:8080", validUrl)
validUrl, err = ValidateUrl("[::1]:8080")
assert.Nil(t, err)
assert.Equal(t, "http://[::1]:8080", validUrl)
validUrl, err = ValidateUrl("http://[::1]")
assert.Nil(t, err)
assert.Equal(t, "http://[::1]", validUrl)
validUrl, err = ValidateUrl("http://[::1]:8080")
assert.Nil(t, err)
assert.Equal(t, "http://[::1]:8080", validUrl)
validUrl, err = ValidateUrl("[::1]")
assert.Nil(t, err)
assert.Equal(t, "http://[::1]", validUrl)
validUrl, err = ValidateUrl("https://example.com")
assert.Nil(t, err)
assert.Equal(t, "https://example.com", validUrl)
validUrl, err = ValidateUrl("example.com")
assert.Nil(t, err)
assert.Equal(t, "http://example.com", validUrl)
validUrl, err = ValidateUrl("http://hello.example.com")
assert.Nil(t, err)
assert.Equal(t, "http://hello.example.com", validUrl)
validUrl, err = ValidateUrl("hello.example.com")
assert.Nil(t, err)
assert.Equal(t, "http://hello.example.com", validUrl)
validUrl, err = ValidateUrl("hello.example.com:8080")
assert.Nil(t, err)
assert.Equal(t, "http://hello.example.com:8080", validUrl)
validUrl, err = ValidateUrl("https://hello.example.com:8080")
assert.Nil(t, err)
assert.Equal(t, "https://hello.example.com:8080", validUrl)
validUrl, err = ValidateUrl("https://bücher.example.com")
assert.Nil(t, err)
assert.Equal(t, "https://xn--bcher-kva.example.com", validUrl)
validUrl, err = ValidateUrl("bücher.example.com")
assert.Nil(t, err)
assert.Equal(t, "http://xn--bcher-kva.example.com", validUrl)
validUrl, err = ValidateUrl("https%3A%2F%2Fhello.example.com")
assert.Nil(t, err)
assert.Equal(t, "https://hello.example.com", validUrl)
validUrl, err = ValidateUrl("ftp://alex:12345@hello.example.com:8080/robot.txt")
assert.Equal(t, "Currently Cloudflare-Warp does not support ftp protocol.", err.Error())
assert.Empty(t, validUrl)
validUrl, err = ValidateUrl("https://alex:12345@hello.example.com:8080")
assert.Nil(t, err)
assert.Equal(t, "https://hello.example.com:8080", validUrl)
}