TUN-3853: Respond with ws headers from the origin service rather than generating our own
This commit is contained in:
parent
9c298e4851
commit
ed57ee64e8
|
@ -90,16 +90,16 @@ func (wsc *wsConnection) Type() connection.Type {
|
|||
return connection.TypeWebsocket
|
||||
}
|
||||
|
||||
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) {
|
||||
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
d := &gws.Dialer{
|
||||
TLSClientConfig: transport.TLSClientConfig,
|
||||
}
|
||||
wsConn, resp, err := websocket.ClientConnect(r, d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
return &wsConnection{
|
||||
wsConn,
|
||||
resp,
|
||||
}, nil
|
||||
}, resp, nil
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ type HTTPOriginProxy interface {
|
|||
|
||||
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
|
||||
type StreamBasedOriginProxy interface {
|
||||
EstablishConnection(r *http.Request) (OriginConnection, error)
|
||||
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
|
||||
}
|
||||
|
||||
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
|
@ -29,8 +29,8 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
}
|
||||
|
||||
// TODO: TUN-3636: establish connection to origins over UDS
|
||||
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, error) {
|
||||
return nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
|
||||
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
|
||||
}
|
||||
|
||||
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
|
@ -40,7 +40,7 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, error) {
|
||||
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req.URL.Host = o.url.Host
|
||||
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
|
||||
return newWSConnection(o.transport, req)
|
||||
|
@ -53,7 +53,7 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, error) {
|
||||
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "wss"
|
||||
return newWSConnection(o.transport, req)
|
||||
|
@ -63,12 +63,13 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
|||
return o.resp, nil
|
||||
}
|
||||
|
||||
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, error) {
|
||||
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
dest, err := o.destination(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
return o.client.connect(r, dest)
|
||||
conn, err := o.client.connect(r, dest)
|
||||
return conn, nil, err
|
||||
}
|
||||
|
||||
// getRequestHost returns the host of the http.Request.
|
||||
|
@ -102,8 +103,10 @@ func removePath(dest string) string {
|
|||
return strings.SplitN(dest, "/", 2)[0]
|
||||
}
|
||||
|
||||
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, error) {
|
||||
return o.client.connect(r, o.dest)
|
||||
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
conn, err := o.client.connect(r, o.dest)
|
||||
return conn, nil, err
|
||||
|
||||
}
|
||||
|
||||
type tcpClient struct {
|
||||
|
|
|
@ -166,20 +166,22 @@ func (p *proxy) proxyConnection(
|
|||
sourceConnectionType connection.Type,
|
||||
connectionProxy ingress.StreamBasedOriginProxy,
|
||||
) (*http.Response, error) {
|
||||
originConn, err := connectionProxy.EstablishConnection(req)
|
||||
originConn, connectionResp, err := connectionProxy.EstablishConnection(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var eyeballConn io.ReadWriter = w
|
||||
respHeader := http.Header{}
|
||||
if connectionResp != nil {
|
||||
respHeader = connectionResp.Header
|
||||
}
|
||||
if sourceConnectionType == connection.TypeWebsocket {
|
||||
wsReadWriter := websocket.NewConn(serveCtx, w, p.log)
|
||||
// If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames
|
||||
if originConn.Type() != sourceConnectionType {
|
||||
eyeballConn = wsReadWriter
|
||||
}
|
||||
respHeader = websocket.NewResponseHeader(req)
|
||||
}
|
||||
status := http.StatusSwitchingProtocols
|
||||
resp := &http.Response{
|
||||
|
|
|
@ -411,7 +411,9 @@ func TestConnections(t *testing.T) {
|
|||
originService func(*testing.T, net.Listener)
|
||||
eyeballService connection.ResponseWriter
|
||||
connectionType connection.Type
|
||||
requestHeaders http.Header
|
||||
wantMessage []byte
|
||||
wantHeaders http.Header
|
||||
}{
|
||||
{
|
||||
name: "ws-ws proxy",
|
||||
|
@ -419,7 +421,16 @@ func TestConnections(t *testing.T) {
|
|||
originService: runEchoWSService,
|
||||
eyeballService: newWSRespWriter([]byte("test1"), replayer),
|
||||
connectionType: connection.TypeWebsocket,
|
||||
wantMessage: []byte("test1"),
|
||||
requestHeaders: map[string][]string{
|
||||
"Test-Cloudflared-Echo": []string{"Echo"},
|
||||
},
|
||||
wantMessage: []byte("echo-test1"),
|
||||
wantHeaders: map[string][]string{
|
||||
"Connection": []string{"Upgrade"},
|
||||
"Sec-Websocket-Accept": []string{"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
|
||||
"Upgrade": []string{"websocket"},
|
||||
"Test-Cloudflared-Echo": []string{"Echo"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tcp-tcp proxy",
|
||||
|
@ -430,15 +441,25 @@ func TestConnections(t *testing.T) {
|
|||
replayer,
|
||||
),
|
||||
connectionType: connection.TypeTCP,
|
||||
wantMessage: []byte("echo-test2"),
|
||||
requestHeaders: map[string][]string{
|
||||
"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"},
|
||||
},
|
||||
wantMessage: []byte("echo-test2"),
|
||||
wantHeaders: http.Header{},
|
||||
},
|
||||
{
|
||||
name: "tcp-ws proxy",
|
||||
ingressServicePrefix: "ws://",
|
||||
originService: runEchoWSService,
|
||||
eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")),
|
||||
connectionType: connection.TypeTCP,
|
||||
wantMessage: []byte("test3"),
|
||||
requestHeaders: map[string][]string{
|
||||
"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"},
|
||||
},
|
||||
connectionType: connection.TypeTCP,
|
||||
wantMessage: []byte("echo-test3"),
|
||||
// We expect no headers here because they are sent back via
|
||||
// the stream.
|
||||
wantHeaders: http.Header{},
|
||||
},
|
||||
{
|
||||
name: "ws-tcp proxy",
|
||||
|
@ -447,14 +468,12 @@ func TestConnections(t *testing.T) {
|
|||
eyeballService: newWSRespWriter([]byte("test4"), replayer),
|
||||
connectionType: connection.TypeWebsocket,
|
||||
wantMessage: []byte("echo-test4"),
|
||||
wantHeaders: http.Header{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if test.skip {
|
||||
t.Skip("todo: skipping a failing test. THis should be fixed before merge")
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
@ -466,7 +485,11 @@ func TestConnections(t *testing.T) {
|
|||
proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger)
|
||||
req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Cf-Cloudflared-Proxy-Src", "non-blank-value")
|
||||
reqHeaders := make(http.Header)
|
||||
for k, vs := range test.requestHeaders {
|
||||
reqHeaders[k] = vs
|
||||
}
|
||||
req.Header = reqHeaders
|
||||
|
||||
if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok {
|
||||
go func() {
|
||||
|
@ -474,21 +497,29 @@ func TestConnections(t *testing.T) {
|
|||
replayer.Write(resp)
|
||||
}()
|
||||
}
|
||||
|
||||
err = proxy.Proxy(test.eyeballService, req, test.connectionType)
|
||||
require.NoError(t, err)
|
||||
|
||||
cancel()
|
||||
assert.Equal(t, test.wantMessage, replayer.Bytes())
|
||||
respPrinter := test.eyeballService.(responsePrinter)
|
||||
assert.Equal(t, test.wantHeaders, respPrinter.printRespHeaders())
|
||||
replayer.rw.Reset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type responsePrinter interface {
|
||||
printRespHeaders() http.Header
|
||||
}
|
||||
|
||||
type pipedWSWriter struct {
|
||||
dialer gorillaWS.Dialer
|
||||
wsConn net.Conn
|
||||
pipedConn net.Conn
|
||||
respWriter connection.ResponseWriter
|
||||
respHeaders http.Header
|
||||
messageToWrite []byte
|
||||
}
|
||||
|
||||
|
@ -547,14 +578,21 @@ func (p *pipedWSWriter) WriteErrorResponse() {
|
|||
}
|
||||
|
||||
func (p *pipedWSWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||
p.respHeaders = header
|
||||
return nil
|
||||
}
|
||||
|
||||
// printRespHeaders is a test function to read respHeaders
|
||||
func (p *pipedWSWriter) printRespHeaders() http.Header {
|
||||
return p.respHeaders
|
||||
}
|
||||
|
||||
type wsRespWriter struct {
|
||||
w io.Writer
|
||||
pr *io.PipeReader
|
||||
pw *io.PipeWriter
|
||||
code int
|
||||
w io.Writer
|
||||
pr *io.PipeReader
|
||||
pw *io.PipeWriter
|
||||
respHeaders http.Header
|
||||
code int
|
||||
}
|
||||
|
||||
// newWSRespWriter uses wsutil.WriteClientText to generate websocket frames.
|
||||
|
@ -589,6 +627,7 @@ func (w *wsRespWriter) Write(p []byte) (int, error) {
|
|||
}
|
||||
|
||||
func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||
w.respHeaders = header
|
||||
w.code = status
|
||||
return nil
|
||||
}
|
||||
|
@ -596,6 +635,11 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
|||
func (w *wsRespWriter) WriteErrorResponse() {
|
||||
}
|
||||
|
||||
// printRespHeaders is a test function to read respHeaders
|
||||
func (w *wsRespWriter) printRespHeaders() http.Header {
|
||||
return w.respHeaders
|
||||
}
|
||||
|
||||
func runEchoTCPService(t *testing.T, l net.Listener) {
|
||||
go func() {
|
||||
for {
|
||||
|
@ -628,7 +672,13 @@ func runEchoWSService(t *testing.T, l net.Listener) {
|
|||
}
|
||||
|
||||
var ws = func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
header := make(http.Header)
|
||||
for k, vs := range r.Header {
|
||||
if k == "Test-Cloudflared-Echo" {
|
||||
header[k] = vs
|
||||
}
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, header)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
|
@ -637,8 +687,9 @@ func runEchoWSService(t *testing.T, l net.Listener) {
|
|||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.WriteMessage(messageType, p); err != nil {
|
||||
data := []byte("echo-")
|
||||
data = append(data, p...)
|
||||
if err := conn.WriteMessage(messageType, data); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -672,10 +723,11 @@ type tcpWrappedWs struct {
|
|||
}
|
||||
|
||||
type mockTCPRespWriter struct {
|
||||
w io.Writer
|
||||
pr io.Reader
|
||||
pw *io.PipeWriter
|
||||
code int
|
||||
w io.Writer
|
||||
pr io.Reader
|
||||
pw *io.PipeWriter
|
||||
respHeaders http.Header
|
||||
code int
|
||||
}
|
||||
|
||||
func newTCPRespWriter(data []byte, w io.Writer) *mockTCPRespWriter {
|
||||
|
@ -701,6 +753,12 @@ func (m *mockTCPRespWriter) WriteErrorResponse() {
|
|||
}
|
||||
|
||||
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||
m.respHeaders = header
|
||||
m.code = status
|
||||
return nil
|
||||
}
|
||||
|
||||
// printRespHeaders is a test function to read respHeaders
|
||||
func (m *mockTCPRespWriter) printRespHeaders() http.Header {
|
||||
return m.respHeaders
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue