TUN-3838: ResponseWriter no longer reads and origin error tests

This commit is contained in:
Sudarsan Reddy 2021-02-11 14:36:42 +00:00 committed by Nuno Diegues
parent ab4dda5427
commit e20c4f8752
3 changed files with 300 additions and 209 deletions

View File

@ -93,8 +93,7 @@ type OriginProxy interface {
type ResponseWriter interface {
WriteRespHeaders(status int, header http.Header) error
WriteErrorResponse()
io.ReadWriter
io.Writer
}
type ConnectedFuse interface {

View File

@ -177,11 +177,28 @@ func (p *proxy) proxyStreamRequest(
originConn.Close()
}()
originConn.Stream(serveCtx, w, p.log)
eyeballStream := &bidirectionalStream{
writer: w,
reader: req.Body,
}
originConn.Stream(serveCtx, eyeballStream, p.log)
p.logOriginResponse(resp, fields)
return nil
}
type bidirectionalStream struct {
reader io.Reader
writer io.Writer
}
func (wr *bidirectionalStream) Read(p []byte) (n int, err error) {
return wr.reader.Read(p)
}
func (wr *bidirectionalStream) Write(p []byte) (n int, err error) {
return wr.writer.Write(p)
}
func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
reader := bufio.NewReader(respBody)
for {

View File

@ -52,11 +52,6 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) er
return nil
}
func (w *mockHTTPRespWriter) WriteErrorResponse() {
w.WriteHeader(http.StatusBadGateway)
_, _ = w.Write([]byte("http response error"))
}
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
}
@ -140,14 +135,14 @@ func TestProxySingleOrigin(t *testing.T) {
func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
return func(t *testing.T) {
respWriter := newMockHTTPRespWriter()
responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
require.NoError(t, err)
err = proxy.Proxy(respWriter, req, connection.TypeHTTP)
err = proxy.Proxy(responseWriter, req, connection.TypeHTTP)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, respWriter.Code)
assert.Equal(t, http.StatusOK, responseWriter.Code)
}
}
@ -155,19 +150,18 @@ func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *test
return func(t *testing.T) {
// WSRoute is a websocket echo handler
ctx, cancel := context.WithCancel(context.Background())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), nil)
readPipe, writePipe := io.Pipe()
respWriter := newMockWSRespWriter(readPipe)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), readPipe)
responseWriter := newMockWSRespWriter(readPipe)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
err = proxy.Proxy(respWriter, req, connection.TypeWebsocket)
err = proxy.Proxy(responseWriter, req, connection.TypeWebsocket)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code)
require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code)
}()
msg := []byte("test websocket")
@ -175,14 +169,14 @@ func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *test
require.NoError(t, err)
// ReadServerText reads next data message from rw, considering that caller represents proxy side.
returnedMsg, err := wsutil.ReadServerText(respWriter.respBody())
returnedMsg, err := wsutil.ReadServerText(responseWriter.respBody())
require.NoError(t, err)
require.Equal(t, msg, returnedMsg)
err = wsutil.WriteClientBinary(writePipe, msg)
require.NoError(t, err)
returnedMsg, err = wsutil.ReadServerBinary(respWriter.respBody())
returnedMsg, err = wsutil.ReadServerBinary(responseWriter.respBody())
require.NoError(t, err)
require.Equal(t, msg, returnedMsg)
@ -197,7 +191,7 @@ func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T)
pushCount = 50
pushFreq = time.Millisecond * 10
)
respWriter := newMockSSERespWriter()
responseWriter := newMockSSERespWriter()
ctx, cancel := context.WithCancel(context.Background())
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s?freq=%s", hello.SSERoute, pushFreq), nil)
require.NoError(t, err)
@ -206,18 +200,18 @@ func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T)
wg.Add(1)
go func() {
defer wg.Done()
err = proxy.Proxy(respWriter, req, connection.TypeHTTP)
err = proxy.Proxy(responseWriter, req, connection.TypeHTTP)
require.NoError(t, err)
require.Equal(t, http.StatusOK, respWriter.Code)
require.Equal(t, http.StatusOK, responseWriter.Code)
}()
for i := 0; i < pushCount; i++ {
line := respWriter.ReadBytes()
line := responseWriter.ReadBytes()
expect := fmt.Sprintf("%d\n", i)
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
line = respWriter.ReadBytes()
line = responseWriter.ReadBytes()
require.Equal(t, []byte("\n"), line, fmt.Sprintf("Expect to read '\n', got %v", line))
}
@ -295,18 +289,18 @@ func TestProxyMultipleOrigins(t *testing.T) {
}
for _, test := range tests {
respWriter := newMockHTTPRespWriter()
responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, test.url, nil)
require.NoError(t, err)
err = proxy.Proxy(respWriter, req, connection.TypeHTTP)
err = proxy.Proxy(responseWriter, req, connection.TypeHTTP)
require.NoError(t, err)
assert.Equal(t, test.expectedStatus, respWriter.Code)
assert.Equal(t, test.expectedStatus, responseWriter.Code)
if test.expectedBody != nil {
assert.Equal(t, test.expectedBody, respWriter.Body.Bytes())
assert.Equal(t, test.expectedBody, responseWriter.Body.Bytes())
} else {
assert.Equal(t, 0, respWriter.Body.Len())
assert.Equal(t, 0, responseWriter.Body.Len())
}
}
cancel()
@ -343,11 +337,11 @@ func TestProxyError(t *testing.T) {
proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log)
respWriter := newMockHTTPRespWriter()
responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
assert.NoError(t, err)
assert.Error(t, proxy.Proxy(respWriter, req, connection.TypeHTTP))
assert.Error(t, proxy.Proxy(responseWriter, req, connection.TypeHTTP))
}
type replayer struct {
@ -399,82 +393,171 @@ func (r *replayer) Bytes() []byte {
func TestConnections(t *testing.T) {
logger := logger.Create(nil)
replayer := &replayer{rw: &bytes.Buffer{}}
type args struct {
ingressServiceScheme string
originService func(*testing.T, net.Listener)
eyeballResponseWriter connection.ResponseWriter
eyeballRequestBody io.ReadCloser
// Can be set to nil to show warp routing is not enabled.
warpRoutingService *ingress.WarpRoutingService
// eyeball connection type.
connectionType connection.Type
//requestheaders to be sent in the call to proxy.Proxy
requestHeaders http.Header
}
type want struct {
message []byte
headers http.Header
err bool
}
var tests = []struct {
name string
skip bool
ingressServicePrefix string
originService func(*testing.T, net.Listener)
eyeballService connection.ResponseWriter
connectionType connection.Type
requestHeaders http.Header
wantMessage []byte
wantHeaders http.Header
args args
want want
}{
{
name: "ws-ws proxy",
ingressServicePrefix: "ws://",
args: args{
ingressServiceScheme: "ws://",
originService: runEchoWSService,
eyeballService: newWSRespWriter([]byte("test1"), replayer),
eyeballResponseWriter: newWSRespWriter(replayer),
eyeballRequestBody: newWSRequestBody([]byte("test1")),
connectionType: connection.TypeWebsocket,
requestHeaders: http.Header{
requestHeaders: map[string][]string{
// Example key from https://tools.ietf.org/html/rfc6455#section-1.2
"Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="},
"Test-Cloudflared-Echo": {"Echo"},
},
wantMessage: []byte("echo-test1"),
wantHeaders: http.Header{
},
want: want{
message: []byte("echo-test1"),
headers: map[string][]string{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
"Test-Cloudflared-Echo": {"Echo"},
},
},
},
{
name: "tcp-tcp proxy",
ingressServicePrefix: "tcp://",
args: args{
ingressServiceScheme: "tcp://",
originService: runEchoTCPService,
eyeballService: newTCPRespWriter(
[]byte(`test2`),
replayer,
),
eyeballResponseWriter: newTCPRespWriter(replayer),
eyeballRequestBody: newTCPRequestBody([]byte("test2")),
warpRoutingService: ingress.NewWarpRoutingService(),
connectionType: connection.TypeTCP,
requestHeaders: http.Header{
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
},
wantMessage: []byte("echo-test2"),
},
want: want{
message: []byte("echo-test2"),
},
},
{
name: "tcp-ws proxy",
ingressServicePrefix: "ws://",
args: args{
ingressServiceScheme: "ws://",
originService: runEchoWSService,
eyeballService: newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")),
requestHeaders: http.Header{
//eyeballResponseWriter gets set after roundtrip dial.
eyeballRequestBody: newPipedWSRequestBody([]byte("test3")),
warpRoutingService: ingress.NewWarpRoutingService(),
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
},
connectionType: connection.TypeTCP,
wantMessage: []byte("echo-test3"),
},
want: want{
message: []byte("echo-test3"),
// We expect no headers here because they are sent back via
// the stream.
},
},
{
name: "ws-tcp proxy",
ingressServicePrefix: "tcp://",
args: args{
ingressServiceScheme: "tcp://",
originService: runEchoTCPService,
eyeballService: newWSRespWriter([]byte("test4"), replayer),
eyeballResponseWriter: newWSRespWriter(replayer),
eyeballRequestBody: newWSRequestBody([]byte("test4")),
connectionType: connection.TypeWebsocket,
requestHeaders: http.Header{
requestHeaders: map[string][]string{
// Example key from https://tools.ietf.org/html/rfc6455#section-1.2
"Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="},
},
wantMessage: []byte("echo-test4"),
wantHeaders: http.Header{
},
want: want{
message: []byte("echo-test4"),
headers: map[string][]string{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
},
},
},
{
name: "tcp-tcp proxy without warpRoutingService enabled",
args: args{
ingressServiceScheme: "tcp://",
originService: runEchoTCPService,
eyeballResponseWriter: newTCPRespWriter(replayer),
eyeballRequestBody: newTCPRequestBody([]byte("test2")),
connectionType: connection.TypeTCP,
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
},
},
want: want{
message: []byte{},
err: true,
},
},
{
name: "ws-ws proxy when origin is different",
args: args{
ingressServiceScheme: "ws://",
originService: runEchoWSService,
eyeballResponseWriter: newWSRespWriter(replayer),
eyeballRequestBody: newWSRequestBody([]byte("test1")),
connectionType: connection.TypeWebsocket,
requestHeaders: map[string][]string{
// Example key from https://tools.ietf.org/html/rfc6455#section-1.2
"Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="},
"Origin": {"Different origin"},
},
},
want: want{
message: []byte{},
err: true,
},
},
{
name: "tcp-* proxy when origin service has already closed the connection/ is no longer running",
args: args{
ingressServiceScheme: "tcp://",
originService: func(t *testing.T, ln net.Listener) {
// closing the listener created by the test.
ln.Close()
},
eyeballResponseWriter: newTCPRespWriter(replayer),
eyeballRequestBody: newTCPRequestBody([]byte("test2")),
connectionType: connection.TypeTCP,
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
},
},
want: want{
message: []byte{},
err: true,
},
},
}
for _, test := range tests {
@ -483,69 +566,99 @@ func TestConnections(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
// Starts origin service
test.originService(t, ln)
test.args.originService(t, ln)
ingressRule := createSingleIngressConfig(t, test.ingressServicePrefix+ln.Addr().String())
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
var wg sync.WaitGroup
errC := make(chan error)
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger)
proxy := NewOriginProxy(ingressRule, test.args.warpRoutingService, testTags, logger)
req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil)
req, err := http.NewRequest(
http.MethodGet,
test.args.ingressServiceScheme+ln.Addr().String(),
test.args.eyeballRequestBody,
)
require.NoError(t, err)
req.Header = test.requestHeaders
if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok {
req.Header = test.args.requestHeaders
respWriter := test.args.eyeballResponseWriter
if pipedReqBody, ok := test.args.eyeballRequestBody.(*pipedRequestBody); ok {
respWriter = newTCPRespWriter(pipedReqBody.pipedConn)
go func() {
resp := pipedWS.roundtrip(test.ingressServicePrefix + ln.Addr().String())
resp := pipedReqBody.roundtrip(test.args.ingressServiceScheme + ln.Addr().String())
replayer.Write(resp)
}()
}
err = proxy.Proxy(test.eyeballService, req, test.connectionType)
require.NoError(t, err)
err = proxy.Proxy(respWriter, req, test.args.connectionType)
cancel()
assert.Equal(t, test.wantMessage, replayer.Bytes())
respPrinter := test.eyeballService.(responsePrinter)
assert.Equal(t, test.wantHeaders, respPrinter.printRespHeaders())
assert.Equal(t, test.want.err, err != nil)
assert.Equal(t, test.want.message, replayer.Bytes())
respPrinter := respWriter.(responsePrinter)
assert.Equal(t, test.want.headers, respPrinter.headers())
replayer.rw.Reset()
})
}
}
type responsePrinter interface {
printRespHeaders() http.Header
type requestBody struct {
pw *io.PipeWriter
pr *io.PipeReader
}
type pipedWSWriter struct {
func newWSRequestBody(data []byte) *requestBody {
pr, pw := io.Pipe()
go wsutil.WriteClientBinary(pw, data)
return &requestBody{
pr: pr,
pw: pw,
}
}
func newTCPRequestBody(data []byte) *requestBody {
pr, pw := io.Pipe()
go pw.Write(data)
return &requestBody{
pr: pr,
pw: pw,
}
}
func (r *requestBody) Read(p []byte) (n int, err error) {
return r.pr.Read(p)
}
func (r *requestBody) Close() error {
r.pw.Close()
r.pr.Close()
return nil
}
type pipedRequestBody struct {
dialer gorillaWS.Dialer
wsConn net.Conn
pipedConn net.Conn
respWriter connection.ResponseWriter
respHeaders http.Header
wsConn net.Conn
messageToWrite []byte
}
func newPipedWSWriter(rw *mockTCPRespWriter, messageToWrite []byte) *pipedWSWriter {
func newPipedWSRequestBody(data []byte) *pipedRequestBody {
conn1, conn2 := net.Pipe()
dialer := gorillaWS.Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
return conn2, nil
},
}
rw.pr = conn1
rw.w = conn1
return &pipedWSWriter{
return &pipedRequestBody{
dialer: dialer,
pipedConn: conn1,
wsConn: conn2,
messageToWrite: messageToWrite,
respWriter: rw,
messageToWrite: data,
}
}
func (p *pipedWSWriter) roundtrip(addr string) []byte {
func (p *pipedRequestBody) roundtrip(addr string) []byte {
header := http.Header{}
conn, resp, err := p.dialer.Dial(addr, header)
if err != nil {
@ -570,56 +683,35 @@ func (p *pipedWSWriter) roundtrip(addr string) []byte {
return data
}
func (p *pipedWSWriter) Read(data []byte) (int, error) {
func (p *pipedRequestBody) Read(data []byte) (n int, err error) {
return p.pipedConn.Read(data)
}
func (p *pipedWSWriter) Write(data []byte) (int, error) {
return p.pipedConn.Write(data)
}
func (p *pipedWSWriter) WriteErrorResponse() {
}
func (p *pipedWSWriter) WriteRespHeaders(status int, header http.Header) error {
p.respHeaders = header
func (p *pipedRequestBody) Close() error {
return nil
}
// printRespHeaders is a test function to read respHeaders
func (p *pipedWSWriter) printRespHeaders() http.Header {
return p.respHeaders
type responsePrinter interface {
headers() http.Header
}
type wsRespWriter struct {
w io.Writer
pr *io.PipeReader
pw *io.PipeWriter
respHeaders http.Header
responseHeaders http.Header
code int
}
// newWSRespWriter uses wsutil.WriteClientText to generate websocket frames.
// and wsutil.ReadClientText to translate frames from server to byte data.
// In essence, this acts as a wsClient.
func newWSRespWriter(data []byte, w io.Writer) *wsRespWriter {
pr, pw := io.Pipe()
go wsutil.WriteClientBinary(pw, data)
func newWSRespWriter(w io.Writer) *wsRespWriter {
return &wsRespWriter{
w: w,
pr: pr,
pw: pw,
}
}
// Read is read by ingress.Stream and serves as the input from the client.
func (w *wsRespWriter) Read(p []byte) (int, error) {
return w.pr.Read(p)
}
// Write is written to by ingress.Stream and serves as the output to the client.
func (w *wsRespWriter) Write(p []byte) (int, error) {
defer w.pw.Close()
returnedMsg, err := wsutil.ReadServerBinary(bytes.NewBuffer(p))
if err != nil {
// The data was not returned by a websocket connecton.
@ -631,17 +723,55 @@ func (w *wsRespWriter) Write(p []byte) (int, error) {
}
func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
w.respHeaders = header
w.responseHeaders = header
w.code = status
return nil
}
func (w *wsRespWriter) WriteErrorResponse() {
// respHeaders is a test function to read respHeaders
func (w *wsRespWriter) headers() http.Header {
return w.responseHeaders
}
// printRespHeaders is a test function to read respHeaders
func (w *wsRespWriter) printRespHeaders() http.Header {
return w.respHeaders
type mockTCPRespWriter struct {
w io.Writer
responseHeaders http.Header
code int
}
func newTCPRespWriter(w io.Writer) *mockTCPRespWriter {
return &mockTCPRespWriter{
w: w,
}
}
func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
return m.w.Write(p)
}
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
m.responseHeaders = header
m.code = status
return nil
}
// respHeaders is a test function to read respHeaders
func (m *mockTCPRespWriter) headers() http.Header {
return m.responseHeaders
}
func createSingleIngressConfig(t *testing.T, service string) ingress.Ingress {
ingressConfig := &config.Configuration{
Ingress: []config.UnvalidatedIngressRule{
{
Hostname: "*",
Service: service,
},
},
}
ingressRule, err := ingress.ParseIngress(ingressConfig)
require.NoError(t, err)
return ingressRule
}
func runEchoTCPService(t *testing.T, l net.Listener) {
@ -662,8 +792,8 @@ func runEchoTCPService(t *testing.T, l net.Listener) {
_, err = conn.Write(data)
if err != nil {
t.Log(err)
return
}
return
}
}
}()
@ -683,7 +813,10 @@ func runEchoWSService(t *testing.T, l net.Listener) {
}
}
conn, err := upgrader.Upgrade(w, r, header)
require.NoError(t, err)
if err != nil {
t.Log(err)
return
}
defer conn.Close()
for {
@ -708,61 +841,3 @@ func runEchoWSService(t *testing.T, l net.Listener) {
require.NoError(t, err)
}()
}
func createSingleIngressConfig(t *testing.T, service string) ingress.Ingress {
ingressConfig := &config.Configuration{
Ingress: []config.UnvalidatedIngressRule{
{
Hostname: "*",
Service: service,
},
},
}
ingressRule, err := ingress.ParseIngress(ingressConfig)
require.NoError(t, err)
return ingressRule
}
type tcpWrappedWs struct {
}
type mockTCPRespWriter struct {
w io.Writer
pr io.Reader
pw *io.PipeWriter
respHeaders http.Header
code int
}
func newTCPRespWriter(data []byte, w io.Writer) *mockTCPRespWriter {
pr, pw := io.Pipe()
go pw.Write(data)
return &mockTCPRespWriter{
w: w,
pr: pr,
pw: pw,
}
}
func (m *mockTCPRespWriter) Read(p []byte) (n int, err error) {
return m.pr.Read(p)
}
func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
defer m.pw.Close()
return m.w.Write(p)
}
func (m *mockTCPRespWriter) WriteErrorResponse() {
}
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
m.respHeaders = header
m.code = status
return nil
}
// printRespHeaders is a test function to read respHeaders
func (m *mockTCPRespWriter) printRespHeaders() http.Header {
return m.respHeaders
}