diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index 94cf76ca..4212ae79 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -90,16 +90,16 @@ func (wsc *wsConnection) Type() connection.Type { return connection.TypeWebsocket } -func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) { +func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, *http.Response, error) { d := &gws.Dialer{ TLSClientConfig: transport.TLSClientConfig, } wsConn, resp, err := websocket.ClientConnect(r, d) if err != nil { - return nil, err + return nil, nil, err } return &wsConnection{ wsConn, resp, - }, nil + }, resp, nil } diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index c5a9dff3..affdc2c3 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -21,7 +21,7 @@ type HTTPOriginProxy interface { // StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level. type StreamBasedOriginProxy interface { - EstablishConnection(r *http.Request) (OriginConnection, error) + EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) } func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { @@ -29,8 +29,8 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { } // TODO: TUN-3636: establish connection to origins over UDS -func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, error) { - return nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections") +func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { + return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections") } func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { @@ -40,7 +40,7 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { return o.transport.RoundTrip(req) } -func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, error) { +func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) { req.URL.Host = o.url.Host req.URL.Scheme = websocket.ChangeRequestScheme(o.url) return newWSConnection(o.transport, req) @@ -53,7 +53,7 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { return o.transport.RoundTrip(req) } -func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, error) { +func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) { req.URL.Host = o.server.Addr().String() req.URL.Scheme = "wss" return newWSConnection(o.transport, req) @@ -63,12 +63,13 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { return o.resp, nil } -func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, error) { +func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { dest, err := o.destination(r) if err != nil { - return nil, err + return nil, nil, err } - return o.client.connect(r, dest) + conn, err := o.client.connect(r, dest) + return conn, nil, err } // getRequestHost returns the host of the http.Request. @@ -102,8 +103,10 @@ func removePath(dest string) string { return strings.SplitN(dest, "/", 2)[0] } -func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, error) { - return o.client.connect(r, o.dest) +func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { + conn, err := o.client.connect(r, o.dest) + return conn, nil, err + } type tcpClient struct { diff --git a/origin/proxy.go b/origin/proxy.go index e2316f41..2edaaf07 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -166,20 +166,22 @@ func (p *proxy) proxyConnection( sourceConnectionType connection.Type, connectionProxy ingress.StreamBasedOriginProxy, ) (*http.Response, error) { - originConn, err := connectionProxy.EstablishConnection(req) + originConn, connectionResp, err := connectionProxy.EstablishConnection(req) if err != nil { return nil, err } var eyeballConn io.ReadWriter = w respHeader := http.Header{} + if connectionResp != nil { + respHeader = connectionResp.Header + } if sourceConnectionType == connection.TypeWebsocket { wsReadWriter := websocket.NewConn(serveCtx, w, p.log) // If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames if originConn.Type() != sourceConnectionType { eyeballConn = wsReadWriter } - respHeader = websocket.NewResponseHeader(req) } status := http.StatusSwitchingProtocols resp := &http.Response{ diff --git a/origin/proxy_test.go b/origin/proxy_test.go index 44b09d5f..04ca922e 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -411,7 +411,9 @@ func TestConnections(t *testing.T) { originService func(*testing.T, net.Listener) eyeballService connection.ResponseWriter connectionType connection.Type + requestHeaders http.Header wantMessage []byte + wantHeaders http.Header }{ { name: "ws-ws proxy", @@ -419,7 +421,16 @@ func TestConnections(t *testing.T) { originService: runEchoWSService, eyeballService: newWSRespWriter([]byte("test1"), replayer), connectionType: connection.TypeWebsocket, - wantMessage: []byte("test1"), + requestHeaders: map[string][]string{ + "Test-Cloudflared-Echo": []string{"Echo"}, + }, + wantMessage: []byte("echo-test1"), + wantHeaders: map[string][]string{ + "Connection": []string{"Upgrade"}, + "Sec-Websocket-Accept": []string{"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="}, + "Upgrade": []string{"websocket"}, + "Test-Cloudflared-Echo": []string{"Echo"}, + }, }, { name: "tcp-tcp proxy", @@ -430,15 +441,25 @@ func TestConnections(t *testing.T) { replayer, ), connectionType: connection.TypeTCP, - wantMessage: []byte("echo-test2"), + requestHeaders: map[string][]string{ + "Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"}, + }, + wantMessage: []byte("echo-test2"), + wantHeaders: http.Header{}, }, { name: "tcp-ws proxy", ingressServicePrefix: "ws://", originService: runEchoWSService, eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")), - connectionType: connection.TypeTCP, - wantMessage: []byte("test3"), + requestHeaders: map[string][]string{ + "Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"}, + }, + connectionType: connection.TypeTCP, + wantMessage: []byte("echo-test3"), + // We expect no headers here because they are sent back via + // the stream. + wantHeaders: http.Header{}, }, { name: "ws-tcp proxy", @@ -447,14 +468,12 @@ func TestConnections(t *testing.T) { eyeballService: newWSRespWriter([]byte("test4"), replayer), connectionType: connection.TypeWebsocket, wantMessage: []byte("echo-test4"), + wantHeaders: http.Header{}, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if test.skip { - t.Skip("todo: skipping a failing test. THis should be fixed before merge") - } ctx, cancel := context.WithCancel(context.Background()) ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -466,7 +485,11 @@ func TestConnections(t *testing.T) { proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger) req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil) require.NoError(t, err) - req.Header.Set("Cf-Cloudflared-Proxy-Src", "non-blank-value") + reqHeaders := make(http.Header) + for k, vs := range test.requestHeaders { + reqHeaders[k] = vs + } + req.Header = reqHeaders if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok { go func() { @@ -474,21 +497,29 @@ func TestConnections(t *testing.T) { replayer.Write(resp) }() } + err = proxy.Proxy(test.eyeballService, req, test.connectionType) require.NoError(t, err) cancel() assert.Equal(t, test.wantMessage, replayer.Bytes()) + respPrinter := test.eyeballService.(responsePrinter) + assert.Equal(t, test.wantHeaders, respPrinter.printRespHeaders()) replayer.rw.Reset() }) } } +type responsePrinter interface { + printRespHeaders() http.Header +} + type pipedWSWriter struct { dialer gorillaWS.Dialer wsConn net.Conn pipedConn net.Conn respWriter connection.ResponseWriter + respHeaders http.Header messageToWrite []byte } @@ -547,14 +578,21 @@ func (p *pipedWSWriter) WriteErrorResponse() { } func (p *pipedWSWriter) WriteRespHeaders(status int, header http.Header) error { + p.respHeaders = header return nil } +// printRespHeaders is a test function to read respHeaders +func (p *pipedWSWriter) printRespHeaders() http.Header { + return p.respHeaders +} + type wsRespWriter struct { - w io.Writer - pr *io.PipeReader - pw *io.PipeWriter - code int + w io.Writer + pr *io.PipeReader + pw *io.PipeWriter + respHeaders http.Header + code int } // newWSRespWriter uses wsutil.WriteClientText to generate websocket frames. @@ -589,6 +627,7 @@ func (w *wsRespWriter) Write(p []byte) (int, error) { } func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error { + w.respHeaders = header w.code = status return nil } @@ -596,6 +635,11 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error { func (w *wsRespWriter) WriteErrorResponse() { } +// printRespHeaders is a test function to read respHeaders +func (w *wsRespWriter) printRespHeaders() http.Header { + return w.respHeaders +} + func runEchoTCPService(t *testing.T, l net.Listener) { go func() { for { @@ -628,7 +672,13 @@ func runEchoWSService(t *testing.T, l net.Listener) { } var ws = func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) + header := make(http.Header) + for k, vs := range r.Header { + if k == "Test-Cloudflared-Echo" { + header[k] = vs + } + } + conn, err := upgrader.Upgrade(w, r, header) require.NoError(t, err) defer conn.Close() @@ -637,8 +687,9 @@ func runEchoWSService(t *testing.T, l net.Listener) { if err != nil { return } - - if err := conn.WriteMessage(messageType, p); err != nil { + data := []byte("echo-") + data = append(data, p...) + if err := conn.WriteMessage(messageType, data); err != nil { return } } @@ -672,10 +723,11 @@ type tcpWrappedWs struct { } type mockTCPRespWriter struct { - w io.Writer - pr io.Reader - pw *io.PipeWriter - code int + w io.Writer + pr io.Reader + pw *io.PipeWriter + respHeaders http.Header + code int } func newTCPRespWriter(data []byte, w io.Writer) *mockTCPRespWriter { @@ -701,6 +753,12 @@ func (m *mockTCPRespWriter) WriteErrorResponse() { } func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error { + m.respHeaders = header m.code = status return nil } + +// printRespHeaders is a test function to read respHeaders +func (m *mockTCPRespWriter) printRespHeaders() http.Header { + return m.respHeaders +}