TUN-3817: Adds tests for websocket based streaming regression
This commit is contained in:
parent
6681d179dc
commit
a6c2348127
|
@ -5,6 +5,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
gws "github.com/gorilla/websocket"
|
gws "github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
@ -15,6 +16,7 @@ type OriginConnection interface {
|
||||||
// Stream should generally be implemented as a bidirectional io.Copy.
|
// Stream should generally be implemented as a bidirectional io.Copy.
|
||||||
Stream(tunnelConn io.ReadWriter)
|
Stream(tunnelConn io.ReadWriter)
|
||||||
Close()
|
Close()
|
||||||
|
Type() connection.Type
|
||||||
}
|
}
|
||||||
|
|
||||||
type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn)
|
type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn)
|
||||||
|
@ -57,6 +59,10 @@ func (tc *tcpConnection) Close() {
|
||||||
tc.conn.Close()
|
tc.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (*tcpConnection) Type() connection.Type {
|
||||||
|
return connection.TypeTCP
|
||||||
|
}
|
||||||
|
|
||||||
// wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets.
|
// wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets.
|
||||||
// TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does.
|
// TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does.
|
||||||
type wsConnection struct {
|
type wsConnection struct {
|
||||||
|
@ -73,6 +79,10 @@ func (wsc *wsConnection) Close() {
|
||||||
wsc.wsConn.Close()
|
wsc.wsConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (wsc *wsConnection) Type() connection.Type {
|
||||||
|
return connection.TypeWebsocket
|
||||||
|
}
|
||||||
|
|
||||||
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) {
|
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) {
|
||||||
d := &gws.Dialer{
|
d := &gws.Dialer{
|
||||||
TLSClientConfig: transport.TLSClientConfig,
|
TLSClientConfig: transport.TLSClientConfig,
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -39,6 +40,12 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return o.transport.RoundTrip(req)
|
return o.transport.RoundTrip(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, error) {
|
||||||
|
req.URL.Host = o.url.Host
|
||||||
|
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
|
||||||
|
return newWSConnection(o.transport, req)
|
||||||
|
}
|
||||||
|
|
||||||
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
// Rewrite the request URL so that it goes to the Hello World server.
|
// Rewrite the request URL so that it goes to the Hello World server.
|
||||||
req.URL.Host = o.server.Addr().String()
|
req.URL.Host = o.server.Addr().String()
|
||||||
|
|
|
@ -52,6 +52,9 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
cfRay := findCfRayHeader(req)
|
cfRay := findCfRayHeader(req)
|
||||||
lbProbe := isLBProbeRequest(req)
|
lbProbe := isLBProbeRequest(req)
|
||||||
|
|
||||||
|
serveCtx, cancel := context.WithCancel(req.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
p.appendTagHeaders(req)
|
p.appendTagHeaders(req)
|
||||||
if sourceConnectionType == connection.TypeTCP {
|
if sourceConnectionType == connection.TypeTCP {
|
||||||
if p.warpRouting == nil {
|
if p.warpRouting == nil {
|
||||||
|
@ -59,7 +62,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
p.log.Error().Msg(err.Error())
|
p.log.Error().Msg(err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
resp, err := p.handleProxyConn(w, req, nil, p.warpRouting.Proxy)
|
resp, err := p.proxyConnection(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.logRequestError(err, cfRay, ingress.ServiceWarpRouting)
|
p.logRequestError(err, cfRay, ingress.ServiceWarpRouting)
|
||||||
w.WriteErrorResponse()
|
w.WriteErrorResponse()
|
||||||
|
@ -83,12 +86,6 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
respHeader := http.Header{}
|
|
||||||
if sourceConnectionType == connection.TypeWebsocket {
|
|
||||||
go websocket.NewConn(w, p.log).Pinger(req.Context())
|
|
||||||
respHeader = websocket.NewResponseHeader(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
|
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
|
||||||
req.Header.Set("Host", hostHeader)
|
req.Header.Set("Host", hostHeader)
|
||||||
req.Host = hostHeader
|
req.Host = hostHeader
|
||||||
|
@ -99,7 +96,8 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service)
|
p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service)
|
||||||
return fmt.Errorf("Not a connection-oriented service")
|
return fmt.Errorf("Not a connection-oriented service")
|
||||||
}
|
}
|
||||||
resp, err := p.handleProxyConn(w, req, respHeader, connectionProxy)
|
|
||||||
|
resp, err := p.proxyConnection(serveCtx, w, req, sourceConnectionType, connectionProxy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.logErrorAndWriteResponse(w, err, cfRay, ruleNum)
|
p.logErrorAndWriteResponse(w, err, cfRay, ruleNum)
|
||||||
return err
|
return err
|
||||||
|
@ -109,31 +107,6 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *proxy) handleProxyConn(
|
|
||||||
w connection.ResponseWriter,
|
|
||||||
req *http.Request,
|
|
||||||
respHeader http.Header,
|
|
||||||
connectionProxy ingress.StreamBasedOriginProxy) (*http.Response, error) {
|
|
||||||
connClosedChan := make(chan struct{})
|
|
||||||
err := p.proxyConnection(connClosedChan, w, req, connectionProxy)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
status := http.StatusSwitchingProtocols
|
|
||||||
resp := &http.Response{
|
|
||||||
Status: http.StatusText(status),
|
|
||||||
StatusCode: status,
|
|
||||||
Header: respHeader,
|
|
||||||
ContentLength: -1,
|
|
||||||
}
|
|
||||||
w.WriteRespHeaders(http.StatusSwitchingProtocols, nil)
|
|
||||||
|
|
||||||
<-connClosedChan
|
|
||||||
return resp, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *proxy) logErrorAndWriteResponse(w connection.ResponseWriter, err error, cfRay string, ruleNum int) {
|
func (p *proxy) logErrorAndWriteResponse(w connection.ResponseWriter, err error, cfRay string, ruleNum int) {
|
||||||
p.logRequestError(err, cfRay, ruleNum)
|
p.logRequestError(err, cfRay, ruleNum)
|
||||||
w.WriteErrorResponse()
|
w.WriteErrorResponse()
|
||||||
|
@ -186,27 +159,51 @@ func (p *proxy) proxyHTTP(w connection.ResponseWriter, req *http.Request, rule *
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *proxy) proxyConnection(connClosedChan chan struct{},
|
func (p *proxy) proxyConnection(
|
||||||
conn io.ReadWriter, req *http.Request, connectionProxy ingress.StreamBasedOriginProxy) error {
|
serveCtx context.Context,
|
||||||
|
w connection.ResponseWriter,
|
||||||
|
req *http.Request,
|
||||||
|
sourceConnectionType connection.Type,
|
||||||
|
connectionProxy ingress.StreamBasedOriginProxy,
|
||||||
|
) (*http.Response, error) {
|
||||||
originConn, err := connectionProxy.EstablishConnection(req)
|
originConn, err := connectionProxy.EstablishConnection(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
serveCtx, cancel := context.WithCancel(req.Context())
|
var eyeballConn io.ReadWriter = w
|
||||||
|
respHeader := http.Header{}
|
||||||
|
if sourceConnectionType == connection.TypeWebsocket {
|
||||||
|
wsReadWriter := websocket.NewConn(serveCtx, w, p.log)
|
||||||
|
// If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames
|
||||||
|
if originConn.Type() != sourceConnectionType {
|
||||||
|
eyeballConn = wsReadWriter
|
||||||
|
}
|
||||||
|
respHeader = websocket.NewResponseHeader(req)
|
||||||
|
}
|
||||||
|
status := http.StatusSwitchingProtocols
|
||||||
|
resp := &http.Response{
|
||||||
|
Status: http.StatusText(status),
|
||||||
|
StatusCode: status,
|
||||||
|
Header: respHeader,
|
||||||
|
ContentLength: -1,
|
||||||
|
}
|
||||||
|
w.WriteRespHeaders(http.StatusSwitchingProtocols, respHeader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Error writing response header")
|
||||||
|
}
|
||||||
|
|
||||||
|
streamCtx, cancel := context.WithCancel(serveCtx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// serveCtx is done if req is cancelled, or streamWebsocket returns
|
// streamCtx is done if req is cancelled or if Stream returns
|
||||||
<-serveCtx.Done()
|
<-streamCtx.Done()
|
||||||
originConn.Close()
|
originConn.Close()
|
||||||
close(connClosedChan)
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
originConn.Stream(eyeballConn)
|
||||||
originConn.Stream(conn)
|
return resp, nil
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
|
func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
|
||||||
|
|
|
@ -17,11 +17,10 @@ import (
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||||
"github.com/cloudflare/cloudflared/connection"
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
|
||||||
"github.com/cloudflare/cloudflared/hello"
|
"github.com/cloudflare/cloudflared/hello"
|
||||||
"github.com/cloudflare/cloudflared/ingress"
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
gorillaWS "github.com/gorilla/websocket"
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
"github.com/gobwas/ws/wsutil"
|
"github.com/gobwas/ws/wsutil"
|
||||||
|
@ -354,112 +353,347 @@ func TestProxyError(t *testing.T) {
|
||||||
assert.Equal(t, "http response error", respWriter.Body.String())
|
assert.Equal(t, "http response error", respWriter.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyBastionMode(t *testing.T) {
|
type replayer struct {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
sync.RWMutex
|
||||||
flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError)
|
writeDone chan struct{}
|
||||||
flagSet.Bool("bastion", true, "")
|
rw *bytes.Buffer
|
||||||
|
|
||||||
cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil)
|
|
||||||
err := cliCtx.Set(config.BastionFlag, "true")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
allowURLFromArgs := false
|
|
||||||
ingressRule, err := ingress.NewSingleOrigin(cliCtx, allowURLFromArgs)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
errC := make(chan error)
|
|
||||||
|
|
||||||
log := logger.Create(nil)
|
|
||||||
|
|
||||||
ingressRule.StartOrigins(&wg, log, ctx.Done(), errC)
|
|
||||||
|
|
||||||
proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, log)
|
|
||||||
|
|
||||||
t.Run("testBastionWebsocket", testBastionWebsocket(proxy))
|
|
||||||
cancel()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testBastionWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
func newReplayer(buffer *bytes.Buffer) {
|
||||||
return func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
readPipe, _ := io.Pipe()
|
|
||||||
respWriter := newMockWSRespWriter(readPipe)
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
}
|
||||||
msgFromConn := []byte("data from websocket proxy")
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
func (r *replayer) Read(p []byte) (int, error) {
|
||||||
require.NoError(t, err)
|
r.RLock()
|
||||||
wg.Add(1)
|
defer r.RUnlock()
|
||||||
go func() {
|
return r.rw.Read(p)
|
||||||
defer wg.Done()
|
}
|
||||||
defer ln.Close()
|
|
||||||
conn, err := ln.Accept()
|
func (r *replayer) Write(p []byte) (int, error) {
|
||||||
|
r.Lock()
|
||||||
|
defer r.Unlock()
|
||||||
|
n, err := r.rw.Write(p)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *replayer) String() string {
|
||||||
|
r.Lock()
|
||||||
|
defer r.Unlock()
|
||||||
|
return r.rw.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *replayer) Bytes() []byte {
|
||||||
|
r.Lock()
|
||||||
|
defer r.Unlock()
|
||||||
|
return r.rw.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConnections tests every possible permutation of connection protocols
|
||||||
|
// proxied by cloudflared.
|
||||||
|
//
|
||||||
|
// WS - WS : When a websocket based ingress is configured on the origin and
|
||||||
|
// the eyeball is also a websocket client streaming data.
|
||||||
|
// TCP - TCP : When teamnet is enabled and an http or tcp service is running
|
||||||
|
// on the origin.
|
||||||
|
// TCP - WS: When teamnet is enabled and a websocket based service is running
|
||||||
|
// on the origin.
|
||||||
|
// WS - TCP: When a tcp based ingress is configured on the origin and the
|
||||||
|
// eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access).
|
||||||
|
func TestConnections(t *testing.T) {
|
||||||
|
logger := logger.Create(nil)
|
||||||
|
replayer := &replayer{rw: &bytes.Buffer{}}
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
name string
|
||||||
|
skip bool
|
||||||
|
ingressServicePrefix string
|
||||||
|
|
||||||
|
originService func(*testing.T, net.Listener)
|
||||||
|
eyeballService connection.ResponseWriter
|
||||||
|
connectionType connection.Type
|
||||||
|
wantMessage []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ws-ws proxy",
|
||||||
|
ingressServicePrefix: "ws://",
|
||||||
|
originService: runEchoWSService,
|
||||||
|
eyeballService: newWSRespWriter([]byte("test1"), replayer),
|
||||||
|
connectionType: connection.TypeWebsocket,
|
||||||
|
wantMessage: []byte("test1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp-tcp proxy",
|
||||||
|
ingressServicePrefix: "tcp://",
|
||||||
|
originService: runEchoTCPService,
|
||||||
|
eyeballService: newTCPRespWriter(
|
||||||
|
[]byte(`test2`),
|
||||||
|
replayer,
|
||||||
|
),
|
||||||
|
connectionType: connection.TypeTCP,
|
||||||
|
wantMessage: []byte("echo-test2"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp-ws proxy",
|
||||||
|
ingressServicePrefix: "ws://",
|
||||||
|
originService: runEchoWSService,
|
||||||
|
eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")),
|
||||||
|
connectionType: connection.TypeTCP,
|
||||||
|
wantMessage: []byte("test3"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ws-tcp proxy",
|
||||||
|
ingressServicePrefix: "tcp://",
|
||||||
|
originService: runEchoTCPService,
|
||||||
|
eyeballService: newWSRespWriter([]byte("test4"), replayer),
|
||||||
|
connectionType: connection.TypeWebsocket,
|
||||||
|
wantMessage: []byte("echo-test4"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
if test.skip {
|
||||||
|
t.Skip("todo: skipping a failing test. THis should be fixed before merge")
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
wsConn := websocket.NewConn(conn, nil)
|
test.originService(t, ln)
|
||||||
wsConn.Write(msgFromConn)
|
ingressRule := createSingleIngressConfig(t, test.ingressServicePrefix+ln.Addr().String())
|
||||||
}()
|
var wg sync.WaitGroup
|
||||||
|
errC := make(chan error)
|
||||||
|
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
|
||||||
|
proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger)
|
||||||
|
req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Header.Set("Cf-Cloudflared-Proxy-Src", "non-blank-value")
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://dummy", nil)
|
if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok {
|
||||||
req.Header.Set(h2mux.CFJumpDestinationHeader, ln.Addr().String())
|
go func() {
|
||||||
|
resp := pipedWS.roundtrip(test.ingressServicePrefix + ln.Addr().String())
|
||||||
wg.Add(1)
|
replayer.Write(resp)
|
||||||
go func() {
|
}()
|
||||||
defer wg.Done()
|
}
|
||||||
err = proxy.Proxy(respWriter, req, connection.TypeWebsocket)
|
err = proxy.Proxy(test.eyeballService, req, test.connectionType)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
|
cancel()
|
||||||
}()
|
assert.Equal(t, test.wantMessage, replayer.Bytes())
|
||||||
|
replayer.rw.Reset()
|
||||||
// ReadServerText reads next data message from rw, considering that caller represents proxy side.
|
})
|
||||||
returnedMsg, err := wsutil.ReadServerText(respWriter.respBody())
|
|
||||||
if err != io.EOF {
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, msgFromConn, returnedMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
cancel()
|
|
||||||
wg.Wait()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTCPStream(t *testing.T) {
|
type pipedWSWriter struct {
|
||||||
logger := logger.Create(nil)
|
dialer gorillaWS.Dialer
|
||||||
|
wsConn net.Conn
|
||||||
|
pipedConn net.Conn
|
||||||
|
respWriter connection.ResponseWriter
|
||||||
|
messageToWrite []byte
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
func newPipedWSWriter(rw *mockTCPRespWriter, messageToWrite []byte) *pipedWSWriter {
|
||||||
|
conn1, conn2 := net.Pipe()
|
||||||
|
dialer := gorillaWS.Dialer{
|
||||||
|
NetDial: func(network, addr string) (net.Conn, error) {
|
||||||
|
return conn2, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
rw.pr = conn1
|
||||||
|
rw.w = conn1
|
||||||
|
return &pipedWSWriter{
|
||||||
|
dialer: dialer,
|
||||||
|
pipedConn: conn1,
|
||||||
|
wsConn: conn2,
|
||||||
|
messageToWrite: messageToWrite,
|
||||||
|
respWriter: rw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipedWSWriter) roundtrip(addr string) []byte {
|
||||||
|
header := http.Header{}
|
||||||
|
conn, resp, err := p.dialer.Dial(addr, header)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||||
|
panic(fmt.Errorf("resp returned status code: %d", resp.StatusCode))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.WriteMessage(gorillaWS.TextMessage, p.messageToWrite)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, data, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipedWSWriter) Read(data []byte) (int, error) {
|
||||||
|
return p.pipedConn.Read(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipedWSWriter) Write(data []byte) (int, error) {
|
||||||
|
return p.pipedConn.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipedWSWriter) WriteErrorResponse() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipedWSWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type wsRespWriter struct {
|
||||||
|
w io.Writer
|
||||||
|
pr *io.PipeReader
|
||||||
|
pw *io.PipeWriter
|
||||||
|
code int
|
||||||
|
}
|
||||||
|
|
||||||
|
// newWSRespWriter uses wsutil.WriteClientText to generate websocket frames.
|
||||||
|
// and wsutil.ReadClientText to translate frames from server to byte data.
|
||||||
|
// In essence, this acts as a wsClient.
|
||||||
|
func newWSRespWriter(data []byte, w io.Writer) *wsRespWriter {
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
go wsutil.WriteClientBinary(pw, data)
|
||||||
|
return &wsRespWriter{
|
||||||
|
w: w,
|
||||||
|
pr: pr,
|
||||||
|
pw: pw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read is read by ingress.Stream and serves as the input from the client.
|
||||||
|
func (w *wsRespWriter) Read(p []byte) (int, error) {
|
||||||
|
return w.pr.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is written to by ingress.Stream and serves as the output to the client.
|
||||||
|
func (w *wsRespWriter) Write(p []byte) (int, error) {
|
||||||
|
defer w.pw.Close()
|
||||||
|
returnedMsg, err := wsutil.ReadServerBinary(bytes.NewBuffer(p))
|
||||||
|
if err != nil {
|
||||||
|
// The data was not returned by a websocket connecton.
|
||||||
|
if err != io.ErrUnexpectedEOF {
|
||||||
|
return w.w.Write(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return w.w.Write(returnedMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
|
w.code = status
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wsRespWriter) WriteErrorResponse() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func runEchoTCPService(t *testing.T, l net.Listener) {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := l.Accept()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
size, err := conn.Read(buf)
|
||||||
|
if err == io.EOF {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data := []byte("echo-")
|
||||||
|
data = append(data, buf[:size]...)
|
||||||
|
_, err = conn.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func runEchoWSService(t *testing.T, l net.Listener) {
|
||||||
|
var upgrader = gorillaWS.Upgrader{
|
||||||
|
ReadBufferSize: 10,
|
||||||
|
WriteBufferSize: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
var ws = func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
for {
|
||||||
|
messageType, p, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.WriteMessage(messageType, p); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
server := http.Server{
|
||||||
|
Handler: http.HandlerFunc(ws),
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
err := server.Serve(l)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSingleIngressConfig(t *testing.T, service string) ingress.Ingress {
|
||||||
ingressConfig := &config.Configuration{
|
ingressConfig := &config.Configuration{
|
||||||
Ingress: []config.UnvalidatedIngressRule{
|
Ingress: []config.UnvalidatedIngressRule{
|
||||||
{
|
{
|
||||||
Hostname: "*",
|
Hostname: "*",
|
||||||
Service: "bastion",
|
Service: service,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
ingressRule, err := ingress.ParseIngress(ingressConfig)
|
ingressRule, err := ingress.ParseIngress(ingressConfig)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
return ingressRule
|
||||||
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
type tcpWrappedWs struct {
|
||||||
errC := make(chan error)
|
|
||||||
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
|
|
||||||
|
|
||||||
proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger)
|
|
||||||
|
|
||||||
t.Run("testTCPStream", testTCPStreamProxy(proxy))
|
|
||||||
cancel()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockTCPRespWriter struct {
|
type mockTCPRespWriter struct {
|
||||||
w io.Writer
|
w io.Writer
|
||||||
|
pr io.Reader
|
||||||
|
pw *io.PipeWriter
|
||||||
code int
|
code int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newTCPRespWriter(data []byte, w io.Writer) *mockTCPRespWriter {
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
go pw.Write(data)
|
||||||
|
return &mockTCPRespWriter{
|
||||||
|
w: w,
|
||||||
|
pr: pr,
|
||||||
|
pw: pw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockTCPRespWriter) Read(p []byte) (n int, err error) {
|
func (m *mockTCPRespWriter) Read(p []byte) (n int, err error) {
|
||||||
return len(p), nil
|
return m.pr.Read(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
|
func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
|
||||||
|
defer m.pw.Close()
|
||||||
return m.w.Write(p)
|
return m.w.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -470,44 +704,3 @@ func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) err
|
||||||
m.code = status
|
m.code = status
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func testTCPStreamProxy(proxy connection.OriginProxy) func(t *testing.T) {
|
|
||||||
return func(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
readPipe, writePipe := io.Pipe()
|
|
||||||
respWriter := &mockTCPRespWriter{
|
|
||||||
w: writePipe,
|
|
||||||
}
|
|
||||||
msgFromConn := []byte("data from tcp proxy")
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
go func() {
|
|
||||||
defer ln.Close()
|
|
||||||
conn, err := ln.Accept()
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer conn.Close()
|
|
||||||
_, err = conn.Write(msgFromConn)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://dummy", nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
req.Header.Set("Cf-Cloudflared-Proxy-Src", "non-blank-value")
|
|
||||||
req.Host = ln.Addr().String()
|
|
||||||
err = proxy.Proxy(respWriter, req, connection.TypeTCP)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
require.Equal(t, http.StatusSwitchingProtocols, respWriter.code)
|
|
||||||
|
|
||||||
returnedMsg := make([]byte, len(msgFromConn))
|
|
||||||
|
|
||||||
_, err = readPipe.Read(returnedMsg)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, msgFromConn, returnedMsg)
|
|
||||||
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -36,7 +36,6 @@ func (c *GorillaConn) Read(p []byte) (int, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return copy(p, message), nil
|
return copy(p, message), nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -71,11 +70,13 @@ type Conn struct {
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(rw io.ReadWriter, log *zerolog.Logger) *Conn {
|
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
|
||||||
return &Conn{
|
c := &Conn{
|
||||||
rw: rw,
|
rw: rw,
|
||||||
log: log,
|
log: log,
|
||||||
}
|
}
|
||||||
|
go c.pinger(ctx)
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read will read messages from the websocket connection
|
// Read will read messages from the websocket connection
|
||||||
|
@ -92,11 +93,10 @@ func (c *Conn) Write(p []byte) (int, error) {
|
||||||
if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
|
if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return len(p), nil
|
return len(p), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Pinger(ctx context.Context) {
|
func (c *Conn) pinger(ctx context.Context) {
|
||||||
pongMessge := wsutil.Message{
|
pongMessge := wsutil.Message{
|
||||||
OpCode: gobwas.OpPong,
|
OpCode: gobwas.OpPong,
|
||||||
Payload: []byte{},
|
Payload: []byte{},
|
||||||
|
|
Loading…
Reference in New Issue