AUTH-1070: added SSH/protocol forwarding
This commit is contained in:
parent
41916365b6
commit
fa92441415
|
@ -0,0 +1,131 @@
|
|||
//Package carrier provides a WebSocket proxy to carry or proxy a connection
|
||||
//from the local client to the edge. See it as a wrapper around any protocol
|
||||
//that it packages up in a WebSocket connection to the edge.
|
||||
package carrier
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/access"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// StdinoutStream is empty struct for wrapping stdin/stdout
|
||||
// into a single ReadWriter
|
||||
type StdinoutStream struct {
|
||||
}
|
||||
|
||||
// Read will read from Stdin
|
||||
func (c *StdinoutStream) Read(p []byte) (int, error) {
|
||||
return os.Stdin.Read(p)
|
||||
|
||||
}
|
||||
|
||||
// Write will write to Stdout
|
||||
func (c *StdinoutStream) Write(p []byte) (int, error) {
|
||||
return os.Stdout.Write(p)
|
||||
}
|
||||
|
||||
// StartClient will copy the data from stdin/stdout over a WebSocket connection
|
||||
// to the edge (originURL)
|
||||
func StartClient(logger *logrus.Logger, originURL string, stream io.ReadWriter) error {
|
||||
return serveStream(logger, originURL, stream)
|
||||
}
|
||||
|
||||
// StartServer will setup a server on a specified port and copy data over a WebSocket connection
|
||||
// to the edge (originURL)
|
||||
func StartServer(logger *logrus.Logger, address, originURL string, shutdownC <-chan struct{}) error {
|
||||
listener, err := net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("failed to start forwarding server")
|
||||
return err
|
||||
}
|
||||
defer listener.Close()
|
||||
for {
|
||||
select {
|
||||
case <-shutdownC:
|
||||
return nil
|
||||
default:
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go serveConnection(logger, conn, originURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// serveConnection handles connections for the StartServer call
|
||||
func serveConnection(logger *logrus.Logger, c net.Conn, originURL string) {
|
||||
defer c.Close()
|
||||
serveStream(logger, originURL, c)
|
||||
}
|
||||
|
||||
// serveStream will serve the data over the WebSocket stream
|
||||
func serveStream(logger *logrus.Logger, originURL string, conn io.ReadWriter) error {
|
||||
wsConn, err := createWebsocketStream(originURL)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("failed to create websocket stream")
|
||||
return err
|
||||
}
|
||||
defer wsConn.Close()
|
||||
|
||||
websocket.Stream(wsConn, conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createWebsocketStream will create a WebSocket connection to stream data over
|
||||
// It also handles redirects from Access and will present that flow if
|
||||
// the token is not present on the request
|
||||
func createWebsocketStream(originURL string) (*websocket.Conn, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, originURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wsConn, resp, err := websocket.ClientConnect(req, nil)
|
||||
if err != nil && resp != nil && resp.StatusCode > 300 {
|
||||
location, err := resp.Location()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !strings.Contains(location.String(), "cdn-cgi/access/login") {
|
||||
return nil, errors.New("not an Access redirect")
|
||||
}
|
||||
req, err := buildAccessRequest(originURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wsConn, _, err = websocket.ClientConnect(req, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &websocket.Conn{Conn: wsConn}, nil
|
||||
}
|
||||
|
||||
// buildAccessRequest builds an HTTP request with the Access token set
|
||||
func buildAccessRequest(originURL string) (*http.Request, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, originURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token, err := access.FetchToken(req.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("cf-access-token", token)
|
||||
|
||||
return req, nil
|
||||
}
|
|
@ -0,0 +1,117 @@
|
|||
package carrier
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
ws "github.com/gorilla/websocket"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
// example in Sec-Websocket-Key in rfc6455
|
||||
testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ=="
|
||||
)
|
||||
|
||||
type testStreamer struct {
|
||||
buf *bytes.Buffer
|
||||
l sync.RWMutex
|
||||
}
|
||||
|
||||
func newTestStream() *testStreamer {
|
||||
return &testStreamer{buf: new(bytes.Buffer)}
|
||||
}
|
||||
|
||||
func (s *testStreamer) Read(p []byte) (int, error) {
|
||||
s.l.RLock()
|
||||
defer s.l.RUnlock()
|
||||
return s.buf.Read(p)
|
||||
|
||||
}
|
||||
|
||||
func (s *testStreamer) Write(p []byte) (int, error) {
|
||||
s.l.Lock()
|
||||
defer s.l.Unlock()
|
||||
return s.buf.Write(p)
|
||||
}
|
||||
|
||||
func TestStartClient(t *testing.T) {
|
||||
message := "Good morning Austin! Time for another sunny day in the great state of Texas."
|
||||
logger := logrus.New()
|
||||
ts := newTestWebSocketServer()
|
||||
defer ts.Close()
|
||||
|
||||
buf := newTestStream()
|
||||
err := StartClient(logger, "http://"+ts.Listener.Addr().String(), buf)
|
||||
assert.NoError(t, err)
|
||||
buf.Write([]byte(message))
|
||||
|
||||
readBuffer := make([]byte, len(message))
|
||||
buf.Read(readBuffer)
|
||||
assert.Equal(t, message, string(readBuffer))
|
||||
}
|
||||
|
||||
func TestStartServer(t *testing.T) {
|
||||
listenerAddress := "localhost:1117"
|
||||
message := "Good morning Austin! Time for another sunny day in the great state of Texas."
|
||||
logger := logrus.New()
|
||||
shutdownC := make(chan struct{})
|
||||
ts := newTestWebSocketServer()
|
||||
defer ts.Close()
|
||||
|
||||
go func() {
|
||||
StartServer(logger, listenerAddress, "http://"+ts.Listener.Addr().String(), shutdownC)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", listenerAddress)
|
||||
assert.NoError(t, err)
|
||||
conn.Write([]byte(message))
|
||||
|
||||
readBuffer := make([]byte, len(message))
|
||||
conn.Read(readBuffer)
|
||||
assert.Equal(t, string(readBuffer), message)
|
||||
}
|
||||
|
||||
func newTestWebSocketServer() *httptest.Server {
|
||||
upgrader := ws.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, _ := upgrader.Upgrade(w, r, nil)
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if err := conn.WriteMessage(mt, []byte(message)); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
|
||||
req, err := http.NewRequest("GET", url, stream)
|
||||
if err != nil {
|
||||
t.Fatalf("testRequestHeader error")
|
||||
}
|
||||
|
||||
req.Header.Add("Connection", "Upgrade")
|
||||
req.Header.Add("Upgrade", "WebSocket")
|
||||
req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
|
||||
req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
|
||||
req.Header.Add("Sec-Websocket-Version", "13")
|
||||
req.Header.Add("User-Agent", "curl/7.59.0")
|
||||
|
||||
return req
|
||||
}
|
|
@ -86,10 +86,12 @@ func login(c *cli.Context) error {
|
|||
logger.Errorf("Please provide the url of the Access application\n")
|
||||
return err
|
||||
}
|
||||
if _, err := fetchToken(c, appURL); err != nil {
|
||||
token, err := FetchToken(appURL)
|
||||
if err != nil {
|
||||
logger.Errorf("Failed to fetch token: %s\n", err)
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(os.Stdout, "Successfully fetched your token:\n\n%s\n\n", string(token))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -115,7 +117,7 @@ func curl(c *cli.Context) error {
|
|||
logger.Warn("You don't have an Access token set. Please run access token <access application> to fetch one.")
|
||||
return shell.Run("curl", cmdArgs...)
|
||||
}
|
||||
token, err = fetchToken(c, appURL)
|
||||
token, err = FetchToken(appURL)
|
||||
if err != nil {
|
||||
logger.Error("Failed to refresh token: ", err)
|
||||
return err
|
||||
|
|
|
@ -15,15 +15,13 @@ import (
|
|||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
var logger = log.CreateLogger()
|
||||
|
||||
// fetchToken will either load a stored token or generate a new one
|
||||
func fetchToken(c *cli.Context, appURL *url.URL) (string, error) {
|
||||
// FetchToken will either load a stored token or generate a new one
|
||||
func FetchToken(appURL *url.URL) (string, error) {
|
||||
if token, err := getTokenIfExists(appURL); token != "" && err == nil {
|
||||
fmt.Fprintf(os.Stdout, "You have an existing token:\n\n%s\n\n", token)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
|
@ -36,12 +34,11 @@ func fetchToken(c *cli.Context, appURL *url.URL) (string, error) {
|
|||
// we want to send to the transfer service. the key is token and the value
|
||||
// is blank (basically just the id generated in the transfer service)
|
||||
const resourceName, key, value = "token", "token", ""
|
||||
token, err := transfer.Run(c, appURL, resourceName, key, value, path, true)
|
||||
token, err := transfer.Run(appURL, resourceName, key, value, path, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stdout, "Successfully fetched your token:\n\n%s\n\n", string(token))
|
||||
return string(token), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@ import (
|
|||
"github.com/cloudflare/cloudflared/cmd/cloudflared/encrypter"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/shell"
|
||||
"github.com/cloudflare/cloudflared/log"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -32,7 +31,7 @@ var logger = log.CreateLogger()
|
|||
// The "dance" we refer to is building a HTTP request, opening that in a browser waiting for
|
||||
// the user to complete an action, while it long polls in the background waiting for an
|
||||
// action to be completed to download the resource.
|
||||
func Run(c *cli.Context, transferURL *url.URL, resourceName, key, value, path string, shouldEncrypt bool) ([]byte, error) {
|
||||
func Run(transferURL *url.URL, resourceName, key, value, path string, shouldEncrypt bool) ([]byte, error) {
|
||||
encrypterClient, err := encrypter.New("cloudflared_priv.pem", "cloudflared_pub.pem")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -49,16 +48,10 @@ func Run(c *cli.Context, transferURL *url.URL, resourceName, key, value, path st
|
|||
fmt.Fprintf(os.Stdout, "A browser window should have opened at the following URL:\n\n%s\n\nIf the browser failed to open, open it yourself and visit the URL above.\n", requestURL)
|
||||
}
|
||||
|
||||
// for local debugging
|
||||
baseURL := baseStoreURL
|
||||
if c.IsSet("url") {
|
||||
baseURL = c.String("url")
|
||||
}
|
||||
|
||||
var resourceData []byte
|
||||
|
||||
if shouldEncrypt {
|
||||
buf, key, err := transferRequest(baseURL + filepath.Join("transfer", encrypterClient.PublicKey()))
|
||||
buf, key, err := transferRequest(baseStoreURL + filepath.Join("transfer", encrypterClient.PublicKey()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -74,7 +67,7 @@ func Run(c *cli.Context, transferURL *url.URL, resourceName, key, value, path st
|
|||
|
||||
resourceData = decrypted
|
||||
} else {
|
||||
buf, _, err := transferRequest(baseURL + filepath.Join(encrypterClient.PublicKey()))
|
||||
buf, _, err := transferRequest(baseStoreURL + filepath.Join(encrypterClient.PublicKey()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
package tunnel
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
|
||||
"github.com/cloudflare/cloudflared/carrier"
|
||||
"github.com/cloudflare/cloudflared/validation"
|
||||
"github.com/pkg/errors"
|
||||
cli "gopkg.in/urfave/cli.v2"
|
||||
)
|
||||
|
||||
// ssh will start a WS proxy server for server mode
|
||||
// or copy from stdin/stdout for client mode
|
||||
// useful for proxying other protocols (like ssh) over websockets
|
||||
// (which you can put Access in front of)
|
||||
func ssh(c *cli.Context) error {
|
||||
hostname, err := validation.ValidateHostname(c.String("hostname"))
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Invalid hostname")
|
||||
return errors.Wrap(err, "invalid hostname")
|
||||
}
|
||||
|
||||
if c.NArg() > 0 || c.IsSet("url") {
|
||||
localForwarder, err := validateUrl(c)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Error validating origin URL")
|
||||
return errors.Wrap(err, "error validating origin URL")
|
||||
}
|
||||
forwarder, err := url.Parse(localForwarder)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Error validating origin URL")
|
||||
return errors.Wrap(err, "error validating origin URL")
|
||||
}
|
||||
return carrier.StartServer(logger, forwarder.Host, "https://"+hostname, shutdownC)
|
||||
}
|
||||
|
||||
return carrier.StartClient(logger, "https://"+hostname, &carrier.StdinoutStream{})
|
||||
}
|
|
@ -3,6 +3,7 @@ package tunnel
|
|||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"runtime/trace"
|
||||
"sync"
|
||||
|
@ -19,6 +20,7 @@ import (
|
|||
"github.com/cloudflare/cloudflared/metrics"
|
||||
"github.com/cloudflare/cloudflared/origin"
|
||||
"github.com/cloudflare/cloudflared/tunneldns"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
"github.com/coreos/go-systemd/daemon"
|
||||
"github.com/facebookgo/grace/gracenet"
|
||||
"github.com/pkg/errors"
|
||||
|
@ -137,6 +139,21 @@ func Commands() []*cli.Command {
|
|||
},
|
||||
Hidden: true,
|
||||
},
|
||||
{
|
||||
Name: "ssh",
|
||||
Action: ssh,
|
||||
Usage: `ssh -o ProxyCommand="cloudflared tunnel ssh --hostname %h" ssh.warptunnels.org`,
|
||||
ArgsUsage: "[origin-url]",
|
||||
Description: `The ssh subcommand wraps sends data over a WebSocket proxy to the Cloudflare edge.`,
|
||||
Flags: []cli.Flag{
|
||||
&cli.StringFlag{
|
||||
Name: "hostname",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "url",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var subcommands []*cli.Command
|
||||
|
@ -308,6 +325,20 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
|
|||
c.Set("url", "https://"+helloListener.Addr().String())
|
||||
}
|
||||
|
||||
if c.IsSet("ws-proxy-server") {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:")
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Cannot start Websocket Proxy Server")
|
||||
return errors.Wrap(err, "Cannot start Websocket Proxy Server")
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
errC <- websocket.StartProxyServer(logger, listener, c.String("remote"), shutdownC)
|
||||
}()
|
||||
c.Set("url", "http://"+listener.Addr().String())
|
||||
}
|
||||
|
||||
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, protoLogger)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -447,6 +478,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
EnvVars: []string{"TUNNEL_URL"},
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "remote",
|
||||
Value: "localhost:22",
|
||||
Usage: "Connect to the local server over tcp at `remote`.",
|
||||
EnvVars: []string{"TUNNEL_REMOTE"},
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "hostname",
|
||||
Usage: "Set a hostname on a Cloudflare zone to route traffic through this tunnel.",
|
||||
|
@ -549,6 +587,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "ws-proxy-server",
|
||||
Value: false,
|
||||
Usage: "Run WS proxy Server",
|
||||
EnvVars: []string{"TUNNEL_WS_PROXY"},
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "pidfile",
|
||||
Usage: "Write the application's PID to this file after first successful connection.",
|
||||
|
|
|
@ -33,7 +33,7 @@ func login(c *cli.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
_, err = transfer.Run(c, loginURL, "cert", "callback", callbackStoreURL, path, false)
|
||||
_, err = transfer.Run(loginURL, "cert", "callback", callbackStoreURL, path, false)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to write the certificate due to the following error:\n%v\n\nYour browser will download the certificate instead. You will have to manually\ncopy it to the following path:\n\n%s\n", err, path)
|
||||
return err
|
||||
|
|
|
@ -9,11 +9,24 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var stripWebsocketHeaders = []string {
|
||||
const (
|
||||
// Time allowed to write a message to the peer.
|
||||
writeWait = 10 * time.Second
|
||||
|
||||
// Time allowed to read the next pong message from the peer.
|
||||
pongWait = 60 * time.Second
|
||||
|
||||
// Send pings to peer with this period. Must be less than pongWait.
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
)
|
||||
|
||||
var stripWebsocketHeaders = []string{
|
||||
"Upgrade",
|
||||
"Connection",
|
||||
"Sec-Websocket-Key",
|
||||
|
@ -21,6 +34,32 @@ var stripWebsocketHeaders = []string {
|
|||
"Sec-Websocket-Extensions",
|
||||
}
|
||||
|
||||
// Conn is a wrapper around the standard gorilla websocket
|
||||
// but implements a ReadWriter
|
||||
type Conn struct {
|
||||
*websocket.Conn
|
||||
}
|
||||
|
||||
// Read will read messages from the websocket connection
|
||||
func (c *Conn) Read(p []byte) (int, error) {
|
||||
_, message, err := c.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return copy(p, message), nil
|
||||
|
||||
}
|
||||
|
||||
// Write will write messages to the websocket connection
|
||||
func (c *Conn) Write(p []byte) (int, error) {
|
||||
if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
|
||||
func IsWebSocketUpgrade(req *http.Request) bool {
|
||||
return websocket.IsWebSocketUpgrade(req)
|
||||
|
@ -36,7 +75,7 @@ func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.C
|
|||
d := &websocket.Dialer{TLSClientConfig: tlsClientConfig}
|
||||
conn, response, err := d.Dial(req.URL.String(), wsHeaders)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, response, err
|
||||
}
|
||||
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
|
||||
return conn, response, err
|
||||
|
@ -74,16 +113,58 @@ func Stream(conn, backendConn io.ReadWriter) {
|
|||
<-proxyDone
|
||||
}
|
||||
|
||||
// StartProxyServer will start a websocket server that will decode
|
||||
// the websocket data and write the resulting data to the provided
|
||||
// address
|
||||
func StartProxyServer(logger *logrus.Logger, listener net.Listener, remote string, shutdownC <-chan struct{}) error {
|
||||
upgrader := websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
httpServer := &http.Server{Addr: listener.Addr().String(), Handler: nil}
|
||||
go func() {
|
||||
<-shutdownC
|
||||
httpServer.Close()
|
||||
}()
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
stream, err := net.Dial("tcp", remote)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Cannot connect to remote.")
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("failed to upgrade")
|
||||
return
|
||||
}
|
||||
conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
|
||||
done := make(chan struct{})
|
||||
go pinger(logger, conn, done)
|
||||
defer func() {
|
||||
<-done
|
||||
conn.Close()
|
||||
}()
|
||||
Stream(&Conn{conn}, stream)
|
||||
})
|
||||
|
||||
return httpServer.Serve(listener)
|
||||
}
|
||||
|
||||
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
|
||||
// Sec-WebSocket-Version and Sec-Websocket-Extensions headers.
|
||||
// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194.
|
||||
func websocketHeaders(req *http.Request) http.Header {
|
||||
wsHeaders := make(http.Header)
|
||||
for key, val := range req.Header {
|
||||
wsHeaders[key] = val
|
||||
wsHeaders[key] = val
|
||||
}
|
||||
// Assume the header keys are in canonical format.
|
||||
for _, header := range stripWebsocketHeaders {
|
||||
for _, header := range stripWebsocketHeaders {
|
||||
wsHeaders.Del(header)
|
||||
}
|
||||
return wsHeaders
|
||||
|
@ -115,3 +196,19 @@ func changeRequestScheme(req *http.Request) string {
|
|||
return req.URL.Scheme
|
||||
}
|
||||
}
|
||||
|
||||
// pinger simulates the websocket connection to keep it alive
|
||||
func pinger(logger *logrus.Logger, ws *websocket.Conn, done chan struct{}) {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
|
||||
logger.WithError(err).Debug("failed to send ping message")
|
||||
}
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,100 +1,138 @@
|
|||
package websocket
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"testing"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
"golang.org/x/net/websocket"
|
||||
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
)
|
||||
|
||||
const (
|
||||
// example in Sec-Websocket-Key in rfc6455
|
||||
testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ=="
|
||||
// example Sec-Websocket-Accept in rfc6455
|
||||
testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
|
||||
// example in Sec-Websocket-Key in rfc6455
|
||||
testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ=="
|
||||
// example Sec-Websocket-Accept in rfc6455
|
||||
testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
|
||||
)
|
||||
|
||||
func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
|
||||
req, err := http.NewRequest("GET", url, stream)
|
||||
if err != nil {
|
||||
t.Fatalf("testRequestHeader error")
|
||||
}
|
||||
req, err := http.NewRequest("GET", url, stream)
|
||||
if err != nil {
|
||||
t.Fatalf("testRequestHeader error")
|
||||
}
|
||||
|
||||
req.Header.Add("Connection", "Upgrade")
|
||||
req.Header.Add("Upgrade", "WebSocket")
|
||||
req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
|
||||
req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
|
||||
req.Header.Add("Sec-Websocket-Version", "13")
|
||||
req.Header.Add("User-Agent", "curl/7.59.0")
|
||||
req.Header.Add("Connection", "Upgrade")
|
||||
req.Header.Add("Upgrade", "WebSocket")
|
||||
req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
|
||||
req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
|
||||
req.Header.Add("Sec-Websocket-Version", "13")
|
||||
req.Header.Add("User-Agent", "curl/7.59.0")
|
||||
|
||||
return req
|
||||
return req
|
||||
}
|
||||
|
||||
func websocketClientTLSConfig(t *testing.T) *tls.Config {
|
||||
certPool, err := tlsconfig.LoadOriginCertPool(nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, certPool)
|
||||
return &tls.Config{RootCAs: certPool}
|
||||
certPool, err := tlsconfig.LoadOriginCertPool(nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, certPool)
|
||||
return &tls.Config{RootCAs: certPool}
|
||||
}
|
||||
|
||||
func TestWebsocketHeaders(t *testing.T) {
|
||||
req := testRequest(t, "http://example.com", nil)
|
||||
wsHeaders := websocketHeaders(req)
|
||||
for _, header := range stripWebsocketHeaders {
|
||||
assert.Empty(t, wsHeaders[header])
|
||||
}
|
||||
assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
|
||||
req := testRequest(t, "http://example.com", nil)
|
||||
wsHeaders := websocketHeaders(req)
|
||||
for _, header := range stripWebsocketHeaders {
|
||||
assert.Empty(t, wsHeaders[header])
|
||||
}
|
||||
assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
|
||||
}
|
||||
|
||||
func TestGenerateAcceptKey(t *testing.T) {
|
||||
req := testRequest(t, "http://example.com", nil)
|
||||
assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(req))
|
||||
req := testRequest(t, "http://example.com", nil)
|
||||
assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(req))
|
||||
}
|
||||
|
||||
func TestServe(t *testing.T) {
|
||||
logger := logrus.New()
|
||||
shutdownC := make(chan struct{})
|
||||
errC := make(chan error)
|
||||
listener, err := hello.CreateTLSListener("localhost:1111")
|
||||
assert.NoError(t, err)
|
||||
defer listener.Close()
|
||||
logger := logrus.New()
|
||||
shutdownC := make(chan struct{})
|
||||
errC := make(chan error)
|
||||
listener, err := hello.CreateTLSListener("localhost:1111")
|
||||
assert.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
go func() {
|
||||
errC <- hello.StartHelloWorldServer(logger, listener, shutdownC)
|
||||
}()
|
||||
go func() {
|
||||
errC <- hello.StartHelloWorldServer(logger, listener, shutdownC)
|
||||
}()
|
||||
|
||||
req := testRequest(t, "https://localhost:1111/ws", nil)
|
||||
req := testRequest(t, "https://localhost:1111/ws", nil)
|
||||
|
||||
tlsConfig := websocketClientTLSConfig(t)
|
||||
assert.NotNil(t, tlsConfig)
|
||||
conn, resp, err := ClientConnect(req, tlsConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
|
||||
tlsConfig := websocketClientTLSConfig(t)
|
||||
assert.NotNil(t, tlsConfig)
|
||||
conn, resp, err := ClientConnect(req, tlsConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
messageSize := rand.Int() % 2048 + 1
|
||||
clientMessage := make([]byte, messageSize)
|
||||
// rand.Read always returns len(clientMessage) and a nil error
|
||||
rand.Read(clientMessage)
|
||||
err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
|
||||
assert.NoError(t, err)
|
||||
for i := 0; i < 1000; i++ {
|
||||
messageSize := rand.Int()%2048 + 1
|
||||
clientMessage := make([]byte, messageSize)
|
||||
// rand.Read always returns len(clientMessage) and a nil error
|
||||
rand.Read(clientMessage)
|
||||
err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
|
||||
assert.NoError(t, err)
|
||||
|
||||
messageType, message, err := conn.ReadMessage()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, websocket.BinaryFrame, messageType)
|
||||
assert.Equal(t, clientMessage, message)
|
||||
}
|
||||
messageType, message, err := conn.ReadMessage()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, websocket.BinaryFrame, messageType)
|
||||
assert.Equal(t, clientMessage, message)
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
close(shutdownC)
|
||||
<-errC
|
||||
conn.Close()
|
||||
close(shutdownC)
|
||||
<-errC
|
||||
}
|
||||
|
||||
// func TestStartProxyServer(t *testing.T) {
|
||||
// var wg sync.WaitGroup
|
||||
// remoteAddress := "localhost:1113"
|
||||
// listenerAddress := "localhost:1112"
|
||||
// message := "Good morning Austin! Time for another sunny day in the great state of Texas."
|
||||
// logger := logrus.New()
|
||||
// shutdownC := make(chan struct{})
|
||||
|
||||
// listener, err := net.Listen("tcp", listenerAddress)
|
||||
// assert.NoError(t, err)
|
||||
// defer listener.Close()
|
||||
|
||||
// remoteListener, err := net.Listen("tcp", remoteAddress)
|
||||
// assert.NoError(t, err)
|
||||
// defer remoteListener.Close()
|
||||
|
||||
// wg.Add(1)
|
||||
// go func() {
|
||||
// defer wg.Done()
|
||||
// conn, err := remoteListener.Accept()
|
||||
// assert.NoError(t, err)
|
||||
// buf := make([]byte, len(message))
|
||||
// conn.Read(buf)
|
||||
// assert.Equal(t, string(buf), message)
|
||||
// }()
|
||||
|
||||
// go func() {
|
||||
// StartProxyServer(logger, listener, remoteAddress, shutdownC)
|
||||
// }()
|
||||
|
||||
// req := testRequest(t, fmt.Sprintf("http://%s/", listenerAddress), nil)
|
||||
// conn, _, err := ClientConnect(req, nil)
|
||||
// assert.NoError(t, err)
|
||||
// err = conn.WriteMessage(1, []byte(message))
|
||||
// assert.NoError(t, err)
|
||||
// wg.Wait()
|
||||
// }
|
||||
|
|
Loading…
Reference in New Issue