TUN-3853: Respond with ws headers from the origin service rather than generating our own
This commit is contained in:
		
							parent
							
								
									9c298e4851
								
							
						
					
					
						commit
						ed57ee64e8
					
				|  | @ -90,16 +90,16 @@ func (wsc *wsConnection) Type() connection.Type { | ||||||
| 	return connection.TypeWebsocket | 	return connection.TypeWebsocket | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) { | func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, *http.Response, error) { | ||||||
| 	d := &gws.Dialer{ | 	d := &gws.Dialer{ | ||||||
| 		TLSClientConfig: transport.TLSClientConfig, | 		TLSClientConfig: transport.TLSClientConfig, | ||||||
| 	} | 	} | ||||||
| 	wsConn, resp, err := websocket.ClientConnect(r, d) | 	wsConn, resp, err := websocket.ClientConnect(r, d) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
| 	return &wsConnection{ | 	return &wsConnection{ | ||||||
| 		wsConn, | 		wsConn, | ||||||
| 		resp, | 		resp, | ||||||
| 	}, nil | 	}, resp, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -21,7 +21,7 @@ type HTTPOriginProxy interface { | ||||||
| 
 | 
 | ||||||
| // StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
 | // StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
 | ||||||
| type StreamBasedOriginProxy interface { | type StreamBasedOriginProxy interface { | ||||||
| 	EstablishConnection(r *http.Request) (OriginConnection, error) | 	EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { | func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { | ||||||
|  | @ -29,8 +29,8 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // TODO: TUN-3636: establish connection to origins over UDS
 | // TODO: TUN-3636: establish connection to origins over UDS
 | ||||||
| func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, error) { | func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { | ||||||
| 	return nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections") | 	return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { | func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { | ||||||
|  | @ -40,7 +40,7 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { | ||||||
| 	return o.transport.RoundTrip(req) | 	return o.transport.RoundTrip(req) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, error) { | func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) { | ||||||
| 	req.URL.Host = o.url.Host | 	req.URL.Host = o.url.Host | ||||||
| 	req.URL.Scheme = websocket.ChangeRequestScheme(o.url) | 	req.URL.Scheme = websocket.ChangeRequestScheme(o.url) | ||||||
| 	return newWSConnection(o.transport, req) | 	return newWSConnection(o.transport, req) | ||||||
|  | @ -53,7 +53,7 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) { | ||||||
| 	return o.transport.RoundTrip(req) | 	return o.transport.RoundTrip(req) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, error) { | func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) { | ||||||
| 	req.URL.Host = o.server.Addr().String() | 	req.URL.Host = o.server.Addr().String() | ||||||
| 	req.URL.Scheme = "wss" | 	req.URL.Scheme = "wss" | ||||||
| 	return newWSConnection(o.transport, req) | 	return newWSConnection(o.transport, req) | ||||||
|  | @ -63,12 +63,13 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { | ||||||
| 	return o.resp, nil | 	return o.resp, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, error) { | func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { | ||||||
| 	dest, err := o.destination(r) | 	dest, err := o.destination(r) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
| 	return o.client.connect(r, dest) | 	conn, err := o.client.connect(r, dest) | ||||||
|  | 	return conn, nil, err | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // getRequestHost returns the host of the http.Request.
 | // getRequestHost returns the host of the http.Request.
 | ||||||
|  | @ -102,8 +103,10 @@ func removePath(dest string) string { | ||||||
| 	return strings.SplitN(dest, "/", 2)[0] | 	return strings.SplitN(dest, "/", 2)[0] | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, error) { | func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) { | ||||||
| 	return o.client.connect(r, o.dest) | 	conn, err := o.client.connect(r, o.dest) | ||||||
|  | 	return conn, nil, err | ||||||
|  | 
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type tcpClient struct { | type tcpClient struct { | ||||||
|  |  | ||||||
|  | @ -166,20 +166,22 @@ func (p *proxy) proxyConnection( | ||||||
| 	sourceConnectionType connection.Type, | 	sourceConnectionType connection.Type, | ||||||
| 	connectionProxy ingress.StreamBasedOriginProxy, | 	connectionProxy ingress.StreamBasedOriginProxy, | ||||||
| ) (*http.Response, error) { | ) (*http.Response, error) { | ||||||
| 	originConn, err := connectionProxy.EstablishConnection(req) | 	originConn, connectionResp, err := connectionProxy.EstablishConnection(req) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var eyeballConn io.ReadWriter = w | 	var eyeballConn io.ReadWriter = w | ||||||
| 	respHeader := http.Header{} | 	respHeader := http.Header{} | ||||||
|  | 	if connectionResp != nil { | ||||||
|  | 		respHeader = connectionResp.Header | ||||||
|  | 	} | ||||||
| 	if sourceConnectionType == connection.TypeWebsocket { | 	if sourceConnectionType == connection.TypeWebsocket { | ||||||
| 		wsReadWriter := websocket.NewConn(serveCtx, w, p.log) | 		wsReadWriter := websocket.NewConn(serveCtx, w, p.log) | ||||||
| 		// If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames
 | 		// If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames
 | ||||||
| 		if originConn.Type() != sourceConnectionType { | 		if originConn.Type() != sourceConnectionType { | ||||||
| 			eyeballConn = wsReadWriter | 			eyeballConn = wsReadWriter | ||||||
| 		} | 		} | ||||||
| 		respHeader = websocket.NewResponseHeader(req) |  | ||||||
| 	} | 	} | ||||||
| 	status := http.StatusSwitchingProtocols | 	status := http.StatusSwitchingProtocols | ||||||
| 	resp := &http.Response{ | 	resp := &http.Response{ | ||||||
|  |  | ||||||
|  | @ -411,7 +411,9 @@ func TestConnections(t *testing.T) { | ||||||
| 		originService  func(*testing.T, net.Listener) | 		originService  func(*testing.T, net.Listener) | ||||||
| 		eyeballService connection.ResponseWriter | 		eyeballService connection.ResponseWriter | ||||||
| 		connectionType connection.Type | 		connectionType connection.Type | ||||||
|  | 		requestHeaders http.Header | ||||||
| 		wantMessage    []byte | 		wantMessage    []byte | ||||||
|  | 		wantHeaders    http.Header | ||||||
| 	}{ | 	}{ | ||||||
| 		{ | 		{ | ||||||
| 			name:                 "ws-ws proxy", | 			name:                 "ws-ws proxy", | ||||||
|  | @ -419,7 +421,16 @@ func TestConnections(t *testing.T) { | ||||||
| 			originService:        runEchoWSService, | 			originService:        runEchoWSService, | ||||||
| 			eyeballService:       newWSRespWriter([]byte("test1"), replayer), | 			eyeballService:       newWSRespWriter([]byte("test1"), replayer), | ||||||
| 			connectionType:       connection.TypeWebsocket, | 			connectionType:       connection.TypeWebsocket, | ||||||
| 			wantMessage:          []byte("test1"), | 			requestHeaders: map[string][]string{ | ||||||
|  | 				"Test-Cloudflared-Echo": []string{"Echo"}, | ||||||
|  | 			}, | ||||||
|  | 			wantMessage: []byte("echo-test1"), | ||||||
|  | 			wantHeaders: map[string][]string{ | ||||||
|  | 				"Connection":            []string{"Upgrade"}, | ||||||
|  | 				"Sec-Websocket-Accept":  []string{"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="}, | ||||||
|  | 				"Upgrade":               []string{"websocket"}, | ||||||
|  | 				"Test-Cloudflared-Echo": []string{"Echo"}, | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:                 "tcp-tcp proxy", | 			name:                 "tcp-tcp proxy", | ||||||
|  | @ -430,15 +441,25 @@ func TestConnections(t *testing.T) { | ||||||
| 				replayer, | 				replayer, | ||||||
| 			), | 			), | ||||||
| 			connectionType: connection.TypeTCP, | 			connectionType: connection.TypeTCP, | ||||||
|  | 			requestHeaders: map[string][]string{ | ||||||
|  | 				"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"}, | ||||||
|  | 			}, | ||||||
| 			wantMessage: []byte("echo-test2"), | 			wantMessage: []byte("echo-test2"), | ||||||
|  | 			wantHeaders: http.Header{}, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:                 "tcp-ws proxy", | 			name:                 "tcp-ws proxy", | ||||||
| 			ingressServicePrefix: "ws://", | 			ingressServicePrefix: "ws://", | ||||||
| 			originService:        runEchoWSService, | 			originService:        runEchoWSService, | ||||||
| 			eyeballService:       newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")), | 			eyeballService:       newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")), | ||||||
|  | 			requestHeaders: map[string][]string{ | ||||||
|  | 				"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"}, | ||||||
|  | 			}, | ||||||
| 			connectionType: connection.TypeTCP, | 			connectionType: connection.TypeTCP, | ||||||
| 			wantMessage:          []byte("test3"), | 			wantMessage:    []byte("echo-test3"), | ||||||
|  | 			// We expect no headers here because they are sent back via
 | ||||||
|  | 			// the stream.
 | ||||||
|  | 			wantHeaders: http.Header{}, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:                 "ws-tcp proxy", | 			name:                 "ws-tcp proxy", | ||||||
|  | @ -447,14 +468,12 @@ func TestConnections(t *testing.T) { | ||||||
| 			eyeballService:       newWSRespWriter([]byte("test4"), replayer), | 			eyeballService:       newWSRespWriter([]byte("test4"), replayer), | ||||||
| 			connectionType:       connection.TypeWebsocket, | 			connectionType:       connection.TypeWebsocket, | ||||||
| 			wantMessage:          []byte("echo-test4"), | 			wantMessage:          []byte("echo-test4"), | ||||||
|  | 			wantHeaders:          http.Header{}, | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, test := range tests { | 	for _, test := range tests { | ||||||
| 		t.Run(test.name, func(t *testing.T) { | 		t.Run(test.name, func(t *testing.T) { | ||||||
| 			if test.skip { |  | ||||||
| 				t.Skip("todo: skipping a failing test. THis should be fixed before merge") |  | ||||||
| 			} |  | ||||||
| 			ctx, cancel := context.WithCancel(context.Background()) | 			ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 			ln, err := net.Listen("tcp", "127.0.0.1:0") | 			ln, err := net.Listen("tcp", "127.0.0.1:0") | ||||||
| 			require.NoError(t, err) | 			require.NoError(t, err) | ||||||
|  | @ -466,7 +485,11 @@ func TestConnections(t *testing.T) { | ||||||
| 			proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger) | 			proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger) | ||||||
| 			req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil) | 			req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil) | ||||||
| 			require.NoError(t, err) | 			require.NoError(t, err) | ||||||
| 			req.Header.Set("Cf-Cloudflared-Proxy-Src", "non-blank-value") | 			reqHeaders := make(http.Header) | ||||||
|  | 			for k, vs := range test.requestHeaders { | ||||||
|  | 				reqHeaders[k] = vs | ||||||
|  | 			} | ||||||
|  | 			req.Header = reqHeaders | ||||||
| 
 | 
 | ||||||
| 			if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok { | 			if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok { | ||||||
| 				go func() { | 				go func() { | ||||||
|  | @ -474,21 +497,29 @@ func TestConnections(t *testing.T) { | ||||||
| 					replayer.Write(resp) | 					replayer.Write(resp) | ||||||
| 				}() | 				}() | ||||||
| 			} | 			} | ||||||
|  | 
 | ||||||
| 			err = proxy.Proxy(test.eyeballService, req, test.connectionType) | 			err = proxy.Proxy(test.eyeballService, req, test.connectionType) | ||||||
| 			require.NoError(t, err) | 			require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 			cancel() | 			cancel() | ||||||
| 			assert.Equal(t, test.wantMessage, replayer.Bytes()) | 			assert.Equal(t, test.wantMessage, replayer.Bytes()) | ||||||
|  | 			respPrinter := test.eyeballService.(responsePrinter) | ||||||
|  | 			assert.Equal(t, test.wantHeaders, respPrinter.printRespHeaders()) | ||||||
| 			replayer.rw.Reset() | 			replayer.rw.Reset() | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | type responsePrinter interface { | ||||||
|  | 	printRespHeaders() http.Header | ||||||
|  | } | ||||||
|  | 
 | ||||||
| type pipedWSWriter struct { | type pipedWSWriter struct { | ||||||
| 	dialer         gorillaWS.Dialer | 	dialer         gorillaWS.Dialer | ||||||
| 	wsConn         net.Conn | 	wsConn         net.Conn | ||||||
| 	pipedConn      net.Conn | 	pipedConn      net.Conn | ||||||
| 	respWriter     connection.ResponseWriter | 	respWriter     connection.ResponseWriter | ||||||
|  | 	respHeaders    http.Header | ||||||
| 	messageToWrite []byte | 	messageToWrite []byte | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -547,13 +578,20 @@ func (p *pipedWSWriter) WriteErrorResponse() { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *pipedWSWriter) WriteRespHeaders(status int, header http.Header) error { | func (p *pipedWSWriter) WriteRespHeaders(status int, header http.Header) error { | ||||||
|  | 	p.respHeaders = header | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // printRespHeaders is a test function to read respHeaders
 | ||||||
|  | func (p *pipedWSWriter) printRespHeaders() http.Header { | ||||||
|  | 	return p.respHeaders | ||||||
|  | } | ||||||
|  | 
 | ||||||
| type wsRespWriter struct { | type wsRespWriter struct { | ||||||
| 	w           io.Writer | 	w           io.Writer | ||||||
| 	pr          *io.PipeReader | 	pr          *io.PipeReader | ||||||
| 	pw          *io.PipeWriter | 	pw          *io.PipeWriter | ||||||
|  | 	respHeaders http.Header | ||||||
| 	code        int | 	code        int | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -589,6 +627,7 @@ func (w *wsRespWriter) Write(p []byte) (int, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error { | func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error { | ||||||
|  | 	w.respHeaders = header | ||||||
| 	w.code = status | 	w.code = status | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  | @ -596,6 +635,11 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error { | ||||||
| func (w *wsRespWriter) WriteErrorResponse() { | func (w *wsRespWriter) WriteErrorResponse() { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // printRespHeaders is a test function to read respHeaders
 | ||||||
|  | func (w *wsRespWriter) printRespHeaders() http.Header { | ||||||
|  | 	return w.respHeaders | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func runEchoTCPService(t *testing.T, l net.Listener) { | func runEchoTCPService(t *testing.T, l net.Listener) { | ||||||
| 	go func() { | 	go func() { | ||||||
| 		for { | 		for { | ||||||
|  | @ -628,7 +672,13 @@ func runEchoWSService(t *testing.T, l net.Listener) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var ws = func(w http.ResponseWriter, r *http.Request) { | 	var ws = func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		conn, err := upgrader.Upgrade(w, r, nil) | 		header := make(http.Header) | ||||||
|  | 		for k, vs := range r.Header { | ||||||
|  | 			if k == "Test-Cloudflared-Echo" { | ||||||
|  | 				header[k] = vs | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		conn, err := upgrader.Upgrade(w, r, header) | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 		defer conn.Close() | 		defer conn.Close() | ||||||
| 
 | 
 | ||||||
|  | @ -637,8 +687,9 @@ func runEchoWSService(t *testing.T, l net.Listener) { | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 
 | 			data := []byte("echo-") | ||||||
| 			if err := conn.WriteMessage(messageType, p); err != nil { | 			data = append(data, p...) | ||||||
|  | 			if err := conn.WriteMessage(messageType, data); err != nil { | ||||||
| 				return | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | @ -675,6 +726,7 @@ type mockTCPRespWriter struct { | ||||||
| 	w           io.Writer | 	w           io.Writer | ||||||
| 	pr          io.Reader | 	pr          io.Reader | ||||||
| 	pw          *io.PipeWriter | 	pw          *io.PipeWriter | ||||||
|  | 	respHeaders http.Header | ||||||
| 	code        int | 	code        int | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -701,6 +753,12 @@ func (m *mockTCPRespWriter) WriteErrorResponse() { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error { | func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error { | ||||||
|  | 	m.respHeaders = header | ||||||
| 	m.code = status | 	m.code = status | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // printRespHeaders is a test function to read respHeaders
 | ||||||
|  | func (m *mockTCPRespWriter) printRespHeaders() http.Header { | ||||||
|  | 	return m.respHeaders | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue