TUN-3853: Respond with ws headers from the origin service rather than generating our own

This commit is contained in:
Sudarsan Reddy 2021-02-04 18:03:34 +00:00 committed by Nuno Diegues
parent 9c298e4851
commit ed57ee64e8
4 changed files with 97 additions and 34 deletions

View File

@ -90,16 +90,16 @@ func (wsc *wsConnection) Type() connection.Type {
return connection.TypeWebsocket
}
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) {
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, *http.Response, error) {
d := &gws.Dialer{
TLSClientConfig: transport.TLSClientConfig,
}
wsConn, resp, err := websocket.ClientConnect(r, d)
if err != nil {
return nil, err
return nil, nil, err
}
return &wsConnection{
wsConn,
resp,
}, nil
}, resp, nil
}

View File

@ -21,7 +21,7 @@ type HTTPOriginProxy interface {
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
type StreamBasedOriginProxy interface {
EstablishConnection(r *http.Request) (OriginConnection, error)
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
}
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
@ -29,8 +29,8 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
}
// TODO: TUN-3636: establish connection to origins over UDS
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, error) {
return nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
}
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
@ -40,7 +40,7 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req)
}
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, error) {
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req.URL.Host = o.url.Host
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
return newWSConnection(o.transport, req)
@ -53,7 +53,7 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req)
}
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, error) {
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "wss"
return newWSConnection(o.transport, req)
@ -63,12 +63,13 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil
}
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, error) {
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
dest, err := o.destination(r)
if err != nil {
return nil, err
return nil, nil, err
}
return o.client.connect(r, dest)
conn, err := o.client.connect(r, dest)
return conn, nil, err
}
// getRequestHost returns the host of the http.Request.
@ -102,8 +103,10 @@ func removePath(dest string) string {
return strings.SplitN(dest, "/", 2)[0]
}
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, error) {
return o.client.connect(r, o.dest)
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
conn, err := o.client.connect(r, o.dest)
return conn, nil, err
}
type tcpClient struct {

View File

@ -166,20 +166,22 @@ func (p *proxy) proxyConnection(
sourceConnectionType connection.Type,
connectionProxy ingress.StreamBasedOriginProxy,
) (*http.Response, error) {
originConn, err := connectionProxy.EstablishConnection(req)
originConn, connectionResp, err := connectionProxy.EstablishConnection(req)
if err != nil {
return nil, err
}
var eyeballConn io.ReadWriter = w
respHeader := http.Header{}
if connectionResp != nil {
respHeader = connectionResp.Header
}
if sourceConnectionType == connection.TypeWebsocket {
wsReadWriter := websocket.NewConn(serveCtx, w, p.log)
// If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames
if originConn.Type() != sourceConnectionType {
eyeballConn = wsReadWriter
}
respHeader = websocket.NewResponseHeader(req)
}
status := http.StatusSwitchingProtocols
resp := &http.Response{

View File

@ -411,7 +411,9 @@ func TestConnections(t *testing.T) {
originService func(*testing.T, net.Listener)
eyeballService connection.ResponseWriter
connectionType connection.Type
requestHeaders http.Header
wantMessage []byte
wantHeaders http.Header
}{
{
name: "ws-ws proxy",
@ -419,7 +421,16 @@ func TestConnections(t *testing.T) {
originService: runEchoWSService,
eyeballService: newWSRespWriter([]byte("test1"), replayer),
connectionType: connection.TypeWebsocket,
wantMessage: []byte("test1"),
requestHeaders: map[string][]string{
"Test-Cloudflared-Echo": []string{"Echo"},
},
wantMessage: []byte("echo-test1"),
wantHeaders: map[string][]string{
"Connection": []string{"Upgrade"},
"Sec-Websocket-Accept": []string{"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
"Upgrade": []string{"websocket"},
"Test-Cloudflared-Echo": []string{"Echo"},
},
},
{
name: "tcp-tcp proxy",
@ -430,15 +441,25 @@ func TestConnections(t *testing.T) {
replayer,
),
connectionType: connection.TypeTCP,
wantMessage: []byte("echo-test2"),
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"},
},
wantMessage: []byte("echo-test2"),
wantHeaders: http.Header{},
},
{
name: "tcp-ws proxy",
ingressServicePrefix: "ws://",
originService: runEchoWSService,
eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")),
connectionType: connection.TypeTCP,
wantMessage: []byte("test3"),
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"},
},
connectionType: connection.TypeTCP,
wantMessage: []byte("echo-test3"),
// We expect no headers here because they are sent back via
// the stream.
wantHeaders: http.Header{},
},
{
name: "ws-tcp proxy",
@ -447,14 +468,12 @@ func TestConnections(t *testing.T) {
eyeballService: newWSRespWriter([]byte("test4"), replayer),
connectionType: connection.TypeWebsocket,
wantMessage: []byte("echo-test4"),
wantHeaders: http.Header{},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.skip {
t.Skip("todo: skipping a failing test. THis should be fixed before merge")
}
ctx, cancel := context.WithCancel(context.Background())
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@ -466,7 +485,11 @@ func TestConnections(t *testing.T) {
proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger)
req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil)
require.NoError(t, err)
req.Header.Set("Cf-Cloudflared-Proxy-Src", "non-blank-value")
reqHeaders := make(http.Header)
for k, vs := range test.requestHeaders {
reqHeaders[k] = vs
}
req.Header = reqHeaders
if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok {
go func() {
@ -474,21 +497,29 @@ func TestConnections(t *testing.T) {
replayer.Write(resp)
}()
}
err = proxy.Proxy(test.eyeballService, req, test.connectionType)
require.NoError(t, err)
cancel()
assert.Equal(t, test.wantMessage, replayer.Bytes())
respPrinter := test.eyeballService.(responsePrinter)
assert.Equal(t, test.wantHeaders, respPrinter.printRespHeaders())
replayer.rw.Reset()
})
}
}
type responsePrinter interface {
printRespHeaders() http.Header
}
type pipedWSWriter struct {
dialer gorillaWS.Dialer
wsConn net.Conn
pipedConn net.Conn
respWriter connection.ResponseWriter
respHeaders http.Header
messageToWrite []byte
}
@ -547,14 +578,21 @@ func (p *pipedWSWriter) WriteErrorResponse() {
}
func (p *pipedWSWriter) WriteRespHeaders(status int, header http.Header) error {
p.respHeaders = header
return nil
}
// printRespHeaders is a test function to read respHeaders
func (p *pipedWSWriter) printRespHeaders() http.Header {
return p.respHeaders
}
type wsRespWriter struct {
w io.Writer
pr *io.PipeReader
pw *io.PipeWriter
code int
w io.Writer
pr *io.PipeReader
pw *io.PipeWriter
respHeaders http.Header
code int
}
// newWSRespWriter uses wsutil.WriteClientText to generate websocket frames.
@ -589,6 +627,7 @@ func (w *wsRespWriter) Write(p []byte) (int, error) {
}
func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
w.respHeaders = header
w.code = status
return nil
}
@ -596,6 +635,11 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
func (w *wsRespWriter) WriteErrorResponse() {
}
// printRespHeaders is a test function to read respHeaders
func (w *wsRespWriter) printRespHeaders() http.Header {
return w.respHeaders
}
func runEchoTCPService(t *testing.T, l net.Listener) {
go func() {
for {
@ -628,7 +672,13 @@ func runEchoWSService(t *testing.T, l net.Listener) {
}
var ws = func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
header := make(http.Header)
for k, vs := range r.Header {
if k == "Test-Cloudflared-Echo" {
header[k] = vs
}
}
conn, err := upgrader.Upgrade(w, r, header)
require.NoError(t, err)
defer conn.Close()
@ -637,8 +687,9 @@ func runEchoWSService(t *testing.T, l net.Listener) {
if err != nil {
return
}
if err := conn.WriteMessage(messageType, p); err != nil {
data := []byte("echo-")
data = append(data, p...)
if err := conn.WriteMessage(messageType, data); err != nil {
return
}
}
@ -672,10 +723,11 @@ type tcpWrappedWs struct {
}
type mockTCPRespWriter struct {
w io.Writer
pr io.Reader
pw *io.PipeWriter
code int
w io.Writer
pr io.Reader
pw *io.PipeWriter
respHeaders http.Header
code int
}
func newTCPRespWriter(data []byte, w io.Writer) *mockTCPRespWriter {
@ -701,6 +753,12 @@ func (m *mockTCPRespWriter) WriteErrorResponse() {
}
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
m.respHeaders = header
m.code = status
return nil
}
// printRespHeaders is a test function to read respHeaders
func (m *mockTCPRespWriter) printRespHeaders() http.Header {
return m.respHeaders
}