TUN-3853: Respond with ws headers from the origin service rather than generating our own
This commit is contained in:
parent
9c298e4851
commit
ed57ee64e8
|
@ -90,16 +90,16 @@ func (wsc *wsConnection) Type() connection.Type {
|
||||||
return connection.TypeWebsocket
|
return connection.TypeWebsocket
|
||||||
}
|
}
|
||||||
|
|
||||||
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) {
|
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, *http.Response, error) {
|
||||||
d := &gws.Dialer{
|
d := &gws.Dialer{
|
||||||
TLSClientConfig: transport.TLSClientConfig,
|
TLSClientConfig: transport.TLSClientConfig,
|
||||||
}
|
}
|
||||||
wsConn, resp, err := websocket.ClientConnect(r, d)
|
wsConn, resp, err := websocket.ClientConnect(r, d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return &wsConnection{
|
return &wsConnection{
|
||||||
wsConn,
|
wsConn,
|
||||||
resp,
|
resp,
|
||||||
}, nil
|
}, resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ type HTTPOriginProxy interface {
|
||||||
|
|
||||||
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
|
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
|
||||||
type StreamBasedOriginProxy interface {
|
type StreamBasedOriginProxy interface {
|
||||||
EstablishConnection(r *http.Request) (OriginConnection, error)
|
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
@ -29,8 +29,8 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: TUN-3636: establish connection to origins over UDS
|
// TODO: TUN-3636: establish connection to origins over UDS
|
||||||
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, error) {
|
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||||
return nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
|
return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
@ -40,7 +40,7 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return o.transport.RoundTrip(req)
|
return o.transport.RoundTrip(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, error) {
|
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||||
req.URL.Host = o.url.Host
|
req.URL.Host = o.url.Host
|
||||||
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
|
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
|
||||||
return newWSConnection(o.transport, req)
|
return newWSConnection(o.transport, req)
|
||||||
|
@ -53,7 +53,7 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
return o.transport.RoundTrip(req)
|
return o.transport.RoundTrip(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, error) {
|
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||||
req.URL.Host = o.server.Addr().String()
|
req.URL.Host = o.server.Addr().String()
|
||||||
req.URL.Scheme = "wss"
|
req.URL.Scheme = "wss"
|
||||||
return newWSConnection(o.transport, req)
|
return newWSConnection(o.transport, req)
|
||||||
|
@ -63,12 +63,13 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||||
return o.resp, nil
|
return o.resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, error) {
|
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||||
dest, err := o.destination(r)
|
dest, err := o.destination(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return o.client.connect(r, dest)
|
conn, err := o.client.connect(r, dest)
|
||||||
|
return conn, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRequestHost returns the host of the http.Request.
|
// getRequestHost returns the host of the http.Request.
|
||||||
|
@ -102,8 +103,10 @@ func removePath(dest string) string {
|
||||||
return strings.SplitN(dest, "/", 2)[0]
|
return strings.SplitN(dest, "/", 2)[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, error) {
|
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||||
return o.client.connect(r, o.dest)
|
conn, err := o.client.connect(r, o.dest)
|
||||||
|
return conn, nil, err
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type tcpClient struct {
|
type tcpClient struct {
|
||||||
|
|
|
@ -166,20 +166,22 @@ func (p *proxy) proxyConnection(
|
||||||
sourceConnectionType connection.Type,
|
sourceConnectionType connection.Type,
|
||||||
connectionProxy ingress.StreamBasedOriginProxy,
|
connectionProxy ingress.StreamBasedOriginProxy,
|
||||||
) (*http.Response, error) {
|
) (*http.Response, error) {
|
||||||
originConn, err := connectionProxy.EstablishConnection(req)
|
originConn, connectionResp, err := connectionProxy.EstablishConnection(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var eyeballConn io.ReadWriter = w
|
var eyeballConn io.ReadWriter = w
|
||||||
respHeader := http.Header{}
|
respHeader := http.Header{}
|
||||||
|
if connectionResp != nil {
|
||||||
|
respHeader = connectionResp.Header
|
||||||
|
}
|
||||||
if sourceConnectionType == connection.TypeWebsocket {
|
if sourceConnectionType == connection.TypeWebsocket {
|
||||||
wsReadWriter := websocket.NewConn(serveCtx, w, p.log)
|
wsReadWriter := websocket.NewConn(serveCtx, w, p.log)
|
||||||
// If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames
|
// If cloudflared <-> origin is not websocket, we need to decode TCP data out of WS frames
|
||||||
if originConn.Type() != sourceConnectionType {
|
if originConn.Type() != sourceConnectionType {
|
||||||
eyeballConn = wsReadWriter
|
eyeballConn = wsReadWriter
|
||||||
}
|
}
|
||||||
respHeader = websocket.NewResponseHeader(req)
|
|
||||||
}
|
}
|
||||||
status := http.StatusSwitchingProtocols
|
status := http.StatusSwitchingProtocols
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
|
|
|
@ -411,7 +411,9 @@ func TestConnections(t *testing.T) {
|
||||||
originService func(*testing.T, net.Listener)
|
originService func(*testing.T, net.Listener)
|
||||||
eyeballService connection.ResponseWriter
|
eyeballService connection.ResponseWriter
|
||||||
connectionType connection.Type
|
connectionType connection.Type
|
||||||
|
requestHeaders http.Header
|
||||||
wantMessage []byte
|
wantMessage []byte
|
||||||
|
wantHeaders http.Header
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "ws-ws proxy",
|
name: "ws-ws proxy",
|
||||||
|
@ -419,7 +421,16 @@ func TestConnections(t *testing.T) {
|
||||||
originService: runEchoWSService,
|
originService: runEchoWSService,
|
||||||
eyeballService: newWSRespWriter([]byte("test1"), replayer),
|
eyeballService: newWSRespWriter([]byte("test1"), replayer),
|
||||||
connectionType: connection.TypeWebsocket,
|
connectionType: connection.TypeWebsocket,
|
||||||
wantMessage: []byte("test1"),
|
requestHeaders: map[string][]string{
|
||||||
|
"Test-Cloudflared-Echo": []string{"Echo"},
|
||||||
|
},
|
||||||
|
wantMessage: []byte("echo-test1"),
|
||||||
|
wantHeaders: map[string][]string{
|
||||||
|
"Connection": []string{"Upgrade"},
|
||||||
|
"Sec-Websocket-Accept": []string{"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
|
||||||
|
"Upgrade": []string{"websocket"},
|
||||||
|
"Test-Cloudflared-Echo": []string{"Echo"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "tcp-tcp proxy",
|
name: "tcp-tcp proxy",
|
||||||
|
@ -430,15 +441,25 @@ func TestConnections(t *testing.T) {
|
||||||
replayer,
|
replayer,
|
||||||
),
|
),
|
||||||
connectionType: connection.TypeTCP,
|
connectionType: connection.TypeTCP,
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"},
|
||||||
|
},
|
||||||
wantMessage: []byte("echo-test2"),
|
wantMessage: []byte("echo-test2"),
|
||||||
|
wantHeaders: http.Header{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "tcp-ws proxy",
|
name: "tcp-ws proxy",
|
||||||
ingressServicePrefix: "ws://",
|
ingressServicePrefix: "ws://",
|
||||||
originService: runEchoWSService,
|
originService: runEchoWSService,
|
||||||
eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")),
|
eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")),
|
||||||
|
requestHeaders: map[string][]string{
|
||||||
|
"Cf-Cloudflared-Proxy-Src": []string{"non-blank-value"},
|
||||||
|
},
|
||||||
connectionType: connection.TypeTCP,
|
connectionType: connection.TypeTCP,
|
||||||
wantMessage: []byte("test3"),
|
wantMessage: []byte("echo-test3"),
|
||||||
|
// We expect no headers here because they are sent back via
|
||||||
|
// the stream.
|
||||||
|
wantHeaders: http.Header{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "ws-tcp proxy",
|
name: "ws-tcp proxy",
|
||||||
|
@ -447,14 +468,12 @@ func TestConnections(t *testing.T) {
|
||||||
eyeballService: newWSRespWriter([]byte("test4"), replayer),
|
eyeballService: newWSRespWriter([]byte("test4"), replayer),
|
||||||
connectionType: connection.TypeWebsocket,
|
connectionType: connection.TypeWebsocket,
|
||||||
wantMessage: []byte("echo-test4"),
|
wantMessage: []byte("echo-test4"),
|
||||||
|
wantHeaders: http.Header{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
if test.skip {
|
|
||||||
t.Skip("todo: skipping a failing test. THis should be fixed before merge")
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -466,7 +485,11 @@ func TestConnections(t *testing.T) {
|
||||||
proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger)
|
proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger)
|
||||||
req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil)
|
req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
req.Header.Set("Cf-Cloudflared-Proxy-Src", "non-blank-value")
|
reqHeaders := make(http.Header)
|
||||||
|
for k, vs := range test.requestHeaders {
|
||||||
|
reqHeaders[k] = vs
|
||||||
|
}
|
||||||
|
req.Header = reqHeaders
|
||||||
|
|
||||||
if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok {
|
if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok {
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -474,21 +497,29 @@ func TestConnections(t *testing.T) {
|
||||||
replayer.Write(resp)
|
replayer.Write(resp)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = proxy.Proxy(test.eyeballService, req, test.connectionType)
|
err = proxy.Proxy(test.eyeballService, req, test.connectionType)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
assert.Equal(t, test.wantMessage, replayer.Bytes())
|
assert.Equal(t, test.wantMessage, replayer.Bytes())
|
||||||
|
respPrinter := test.eyeballService.(responsePrinter)
|
||||||
|
assert.Equal(t, test.wantHeaders, respPrinter.printRespHeaders())
|
||||||
replayer.rw.Reset()
|
replayer.rw.Reset()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type responsePrinter interface {
|
||||||
|
printRespHeaders() http.Header
|
||||||
|
}
|
||||||
|
|
||||||
type pipedWSWriter struct {
|
type pipedWSWriter struct {
|
||||||
dialer gorillaWS.Dialer
|
dialer gorillaWS.Dialer
|
||||||
wsConn net.Conn
|
wsConn net.Conn
|
||||||
pipedConn net.Conn
|
pipedConn net.Conn
|
||||||
respWriter connection.ResponseWriter
|
respWriter connection.ResponseWriter
|
||||||
|
respHeaders http.Header
|
||||||
messageToWrite []byte
|
messageToWrite []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -547,13 +578,20 @@ func (p *pipedWSWriter) WriteErrorResponse() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *pipedWSWriter) WriteRespHeaders(status int, header http.Header) error {
|
func (p *pipedWSWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
|
p.respHeaders = header
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// printRespHeaders is a test function to read respHeaders
|
||||||
|
func (p *pipedWSWriter) printRespHeaders() http.Header {
|
||||||
|
return p.respHeaders
|
||||||
|
}
|
||||||
|
|
||||||
type wsRespWriter struct {
|
type wsRespWriter struct {
|
||||||
w io.Writer
|
w io.Writer
|
||||||
pr *io.PipeReader
|
pr *io.PipeReader
|
||||||
pw *io.PipeWriter
|
pw *io.PipeWriter
|
||||||
|
respHeaders http.Header
|
||||||
code int
|
code int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -589,6 +627,7 @@ func (w *wsRespWriter) Write(p []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
|
w.respHeaders = header
|
||||||
w.code = status
|
w.code = status
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -596,6 +635,11 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
func (w *wsRespWriter) WriteErrorResponse() {
|
func (w *wsRespWriter) WriteErrorResponse() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// printRespHeaders is a test function to read respHeaders
|
||||||
|
func (w *wsRespWriter) printRespHeaders() http.Header {
|
||||||
|
return w.respHeaders
|
||||||
|
}
|
||||||
|
|
||||||
func runEchoTCPService(t *testing.T, l net.Listener) {
|
func runEchoTCPService(t *testing.T, l net.Listener) {
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
|
@ -628,7 +672,13 @@ func runEchoWSService(t *testing.T, l net.Listener) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var ws = func(w http.ResponseWriter, r *http.Request) {
|
var ws = func(w http.ResponseWriter, r *http.Request) {
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
header := make(http.Header)
|
||||||
|
for k, vs := range r.Header {
|
||||||
|
if k == "Test-Cloudflared-Echo" {
|
||||||
|
header[k] = vs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn, err := upgrader.Upgrade(w, r, header)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
@ -637,8 +687,9 @@ func runEchoWSService(t *testing.T, l net.Listener) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
data := []byte("echo-")
|
||||||
if err := conn.WriteMessage(messageType, p); err != nil {
|
data = append(data, p...)
|
||||||
|
if err := conn.WriteMessage(messageType, data); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -675,6 +726,7 @@ type mockTCPRespWriter struct {
|
||||||
w io.Writer
|
w io.Writer
|
||||||
pr io.Reader
|
pr io.Reader
|
||||||
pw *io.PipeWriter
|
pw *io.PipeWriter
|
||||||
|
respHeaders http.Header
|
||||||
code int
|
code int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -701,6 +753,12 @@ func (m *mockTCPRespWriter) WriteErrorResponse() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||||
|
m.respHeaders = header
|
||||||
m.code = status
|
m.code = status
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// printRespHeaders is a test function to read respHeaders
|
||||||
|
func (m *mockTCPRespWriter) printRespHeaders() http.Header {
|
||||||
|
return m.respHeaders
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue