TUN-3838: ResponseWriter no longer reads and origin error tests
This commit is contained in:
		
							parent
							
								
									ab4dda5427
								
							
						
					
					
						commit
						e20c4f8752
					
				|  | @ -93,8 +93,7 @@ type OriginProxy interface { | ||||||
| 
 | 
 | ||||||
| type ResponseWriter interface { | type ResponseWriter interface { | ||||||
| 	WriteRespHeaders(status int, header http.Header) error | 	WriteRespHeaders(status int, header http.Header) error | ||||||
| 	WriteErrorResponse() | 	io.Writer | ||||||
| 	io.ReadWriter |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type ConnectedFuse interface { | type ConnectedFuse interface { | ||||||
|  |  | ||||||
|  | @ -177,11 +177,28 @@ func (p *proxy) proxyStreamRequest( | ||||||
| 		originConn.Close() | 		originConn.Close() | ||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
| 	originConn.Stream(serveCtx, w, p.log) | 	eyeballStream := &bidirectionalStream{ | ||||||
|  | 		writer: w, | ||||||
|  | 		reader: req.Body, | ||||||
|  | 	} | ||||||
|  | 	originConn.Stream(serveCtx, eyeballStream, p.log) | ||||||
| 	p.logOriginResponse(resp, fields) | 	p.logOriginResponse(resp, fields) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | type bidirectionalStream struct { | ||||||
|  | 	reader io.Reader | ||||||
|  | 	writer io.Writer | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (wr *bidirectionalStream) Read(p []byte) (n int, err error) { | ||||||
|  | 	return wr.reader.Read(p) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (wr *bidirectionalStream) Write(p []byte) (n int, err error) { | ||||||
|  | 	return wr.writer.Write(p) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { | func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { | ||||||
| 	reader := bufio.NewReader(respBody) | 	reader := bufio.NewReader(respBody) | ||||||
| 	for { | 	for { | ||||||
|  |  | ||||||
|  | @ -52,11 +52,6 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) er | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (w *mockHTTPRespWriter) WriteErrorResponse() { |  | ||||||
| 	w.WriteHeader(http.StatusBadGateway) |  | ||||||
| 	_, _ = w.Write([]byte("http response error")) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (w *mockHTTPRespWriter) Read(data []byte) (int, error) { | func (w *mockHTTPRespWriter) Read(data []byte) (int, error) { | ||||||
| 	return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader") | 	return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader") | ||||||
| } | } | ||||||
|  | @ -140,14 +135,14 @@ func TestProxySingleOrigin(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) { | func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) { | ||||||
| 	return func(t *testing.T) { | 	return func(t *testing.T) { | ||||||
| 		respWriter := newMockHTTPRespWriter() | 		responseWriter := newMockHTTPRespWriter() | ||||||
| 		req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) | 		req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 		err = proxy.Proxy(respWriter, req, connection.TypeHTTP) | 		err = proxy.Proxy(responseWriter, req, connection.TypeHTTP) | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 		assert.Equal(t, http.StatusOK, respWriter.Code) | 		assert.Equal(t, http.StatusOK, responseWriter.Code) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -155,19 +150,18 @@ func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *test | ||||||
| 	return func(t *testing.T) { | 	return func(t *testing.T) { | ||||||
| 		// WSRoute is a websocket echo handler
 | 		// WSRoute is a websocket echo handler
 | ||||||
| 		ctx, cancel := context.WithCancel(context.Background()) | 		ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), nil) |  | ||||||
| 
 |  | ||||||
| 		readPipe, writePipe := io.Pipe() | 		readPipe, writePipe := io.Pipe() | ||||||
| 		respWriter := newMockWSRespWriter(readPipe) | 		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), readPipe) | ||||||
|  | 		responseWriter := newMockWSRespWriter(readPipe) | ||||||
| 
 | 
 | ||||||
| 		var wg sync.WaitGroup | 		var wg sync.WaitGroup | ||||||
| 		wg.Add(1) | 		wg.Add(1) | ||||||
| 		go func() { | 		go func() { | ||||||
| 			defer wg.Done() | 			defer wg.Done() | ||||||
| 			err = proxy.Proxy(respWriter, req, connection.TypeWebsocket) | 			err = proxy.Proxy(responseWriter, req, connection.TypeWebsocket) | ||||||
| 			require.NoError(t, err) | 			require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 			require.Equal(t, http.StatusSwitchingProtocols, respWriter.Code) | 			require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code) | ||||||
| 		}() | 		}() | ||||||
| 
 | 
 | ||||||
| 		msg := []byte("test websocket") | 		msg := []byte("test websocket") | ||||||
|  | @ -175,14 +169,14 @@ func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *test | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 		// ReadServerText reads next data message from rw, considering that caller represents proxy side.
 | 		// ReadServerText reads next data message from rw, considering that caller represents proxy side.
 | ||||||
| 		returnedMsg, err := wsutil.ReadServerText(respWriter.respBody()) | 		returnedMsg, err := wsutil.ReadServerText(responseWriter.respBody()) | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 		require.Equal(t, msg, returnedMsg) | 		require.Equal(t, msg, returnedMsg) | ||||||
| 
 | 
 | ||||||
| 		err = wsutil.WriteClientBinary(writePipe, msg) | 		err = wsutil.WriteClientBinary(writePipe, msg) | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 		returnedMsg, err = wsutil.ReadServerBinary(respWriter.respBody()) | 		returnedMsg, err = wsutil.ReadServerBinary(responseWriter.respBody()) | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 		require.Equal(t, msg, returnedMsg) | 		require.Equal(t, msg, returnedMsg) | ||||||
| 
 | 
 | ||||||
|  | @ -197,7 +191,7 @@ func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) | ||||||
| 			pushCount = 50 | 			pushCount = 50 | ||||||
| 			pushFreq  = time.Millisecond * 10 | 			pushFreq  = time.Millisecond * 10 | ||||||
| 		) | 		) | ||||||
| 		respWriter := newMockSSERespWriter() | 		responseWriter := newMockSSERespWriter() | ||||||
| 		ctx, cancel := context.WithCancel(context.Background()) | 		ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s?freq=%s", hello.SSERoute, pushFreq), nil) | 		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s?freq=%s", hello.SSERoute, pushFreq), nil) | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
|  | @ -206,18 +200,18 @@ func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) | ||||||
| 		wg.Add(1) | 		wg.Add(1) | ||||||
| 		go func() { | 		go func() { | ||||||
| 			defer wg.Done() | 			defer wg.Done() | ||||||
| 			err = proxy.Proxy(respWriter, req, connection.TypeHTTP) | 			err = proxy.Proxy(responseWriter, req, connection.TypeHTTP) | ||||||
| 			require.NoError(t, err) | 			require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 			require.Equal(t, http.StatusOK, respWriter.Code) | 			require.Equal(t, http.StatusOK, responseWriter.Code) | ||||||
| 		}() | 		}() | ||||||
| 
 | 
 | ||||||
| 		for i := 0; i < pushCount; i++ { | 		for i := 0; i < pushCount; i++ { | ||||||
| 			line := respWriter.ReadBytes() | 			line := responseWriter.ReadBytes() | ||||||
| 			expect := fmt.Sprintf("%d\n", i) | 			expect := fmt.Sprintf("%d\n", i) | ||||||
| 			require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line)) | 			require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line)) | ||||||
| 
 | 
 | ||||||
| 			line = respWriter.ReadBytes() | 			line = responseWriter.ReadBytes() | ||||||
| 			require.Equal(t, []byte("\n"), line, fmt.Sprintf("Expect to read '\n', got %v", line)) | 			require.Equal(t, []byte("\n"), line, fmt.Sprintf("Expect to read '\n', got %v", line)) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -295,18 +289,18 @@ func TestProxyMultipleOrigins(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, test := range tests { | 	for _, test := range tests { | ||||||
| 		respWriter := newMockHTTPRespWriter() | 		responseWriter := newMockHTTPRespWriter() | ||||||
| 		req, err := http.NewRequest(http.MethodGet, test.url, nil) | 		req, err := http.NewRequest(http.MethodGet, test.url, nil) | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 		err = proxy.Proxy(respWriter, req, connection.TypeHTTP) | 		err = proxy.Proxy(responseWriter, req, connection.TypeHTTP) | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 		assert.Equal(t, test.expectedStatus, respWriter.Code) | 		assert.Equal(t, test.expectedStatus, responseWriter.Code) | ||||||
| 		if test.expectedBody != nil { | 		if test.expectedBody != nil { | ||||||
| 			assert.Equal(t, test.expectedBody, respWriter.Body.Bytes()) | 			assert.Equal(t, test.expectedBody, responseWriter.Body.Bytes()) | ||||||
| 		} else { | 		} else { | ||||||
| 			assert.Equal(t, 0, respWriter.Body.Len()) | 			assert.Equal(t, 0, responseWriter.Body.Len()) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	cancel() | 	cancel() | ||||||
|  | @ -343,11 +337,11 @@ func TestProxyError(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log) | 	proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log) | ||||||
| 
 | 
 | ||||||
| 	respWriter := newMockHTTPRespWriter() | 	responseWriter := newMockHTTPRespWriter() | ||||||
| 	req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) | 	req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) | ||||||
| 	assert.NoError(t, err) | 	assert.NoError(t, err) | ||||||
| 
 | 
 | ||||||
| 	assert.Error(t, proxy.Proxy(respWriter, req, connection.TypeHTTP)) | 	assert.Error(t, proxy.Proxy(responseWriter, req, connection.TypeHTTP)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type replayer struct { | type replayer struct { | ||||||
|  | @ -399,82 +393,171 @@ func (r *replayer) Bytes() []byte { | ||||||
| func TestConnections(t *testing.T) { | func TestConnections(t *testing.T) { | ||||||
| 	logger := logger.Create(nil) | 	logger := logger.Create(nil) | ||||||
| 	replayer := &replayer{rw: &bytes.Buffer{}} | 	replayer := &replayer{rw: &bytes.Buffer{}} | ||||||
|  | 	type args struct { | ||||||
|  | 		ingressServiceScheme  string | ||||||
|  | 		originService         func(*testing.T, net.Listener) | ||||||
|  | 		eyeballResponseWriter connection.ResponseWriter | ||||||
|  | 		eyeballRequestBody    io.ReadCloser | ||||||
|  | 
 | ||||||
|  | 		// Can be set to nil to show warp routing is not enabled.
 | ||||||
|  | 		warpRoutingService *ingress.WarpRoutingService | ||||||
|  | 
 | ||||||
|  | 		// eyeball connection type.
 | ||||||
|  | 		connectionType connection.Type | ||||||
|  | 
 | ||||||
|  | 		//requestheaders to be sent in the call to proxy.Proxy
 | ||||||
|  | 		requestHeaders http.Header | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	type want struct { | ||||||
|  | 		message []byte | ||||||
|  | 		headers http.Header | ||||||
|  | 		err     bool | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| 	var tests = []struct { | 	var tests = []struct { | ||||||
| 		name string | 		name string | ||||||
| 		skip                 bool | 		args args | ||||||
| 		ingressServicePrefix string | 		want want | ||||||
| 
 |  | ||||||
| 		originService  func(*testing.T, net.Listener) |  | ||||||
| 		eyeballService connection.ResponseWriter |  | ||||||
| 		connectionType connection.Type |  | ||||||
| 		requestHeaders http.Header |  | ||||||
| 		wantMessage    []byte |  | ||||||
| 		wantHeaders    http.Header |  | ||||||
| 	}{ | 	}{ | ||||||
| 		{ | 		{ | ||||||
| 			name: "ws-ws proxy", | 			name: "ws-ws proxy", | ||||||
| 			ingressServicePrefix: "ws://", | 			args: args{ | ||||||
|  | 				ingressServiceScheme:  "ws://", | ||||||
| 				originService:         runEchoWSService, | 				originService:         runEchoWSService, | ||||||
| 			eyeballService:       newWSRespWriter([]byte("test1"), replayer), | 				eyeballResponseWriter: newWSRespWriter(replayer), | ||||||
|  | 				eyeballRequestBody:    newWSRequestBody([]byte("test1")), | ||||||
| 				connectionType:        connection.TypeWebsocket, | 				connectionType:        connection.TypeWebsocket, | ||||||
| 			requestHeaders: http.Header{ | 				requestHeaders: map[string][]string{ | ||||||
| 					// Example key from https://tools.ietf.org/html/rfc6455#section-1.2
 | 					// Example key from https://tools.ietf.org/html/rfc6455#section-1.2
 | ||||||
| 					"Sec-Websocket-Key":     {"dGhlIHNhbXBsZSBub25jZQ=="}, | 					"Sec-Websocket-Key":     {"dGhlIHNhbXBsZSBub25jZQ=="}, | ||||||
| 					"Test-Cloudflared-Echo": {"Echo"}, | 					"Test-Cloudflared-Echo": {"Echo"}, | ||||||
| 				}, | 				}, | ||||||
| 			wantMessage: []byte("echo-test1"), | 			}, | ||||||
| 			wantHeaders: http.Header{ | 			want: want{ | ||||||
|  | 				message: []byte("echo-test1"), | ||||||
|  | 				headers: map[string][]string{ | ||||||
| 					"Connection":            {"Upgrade"}, | 					"Connection":            {"Upgrade"}, | ||||||
| 					"Sec-Websocket-Accept":  {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, | 					"Sec-Websocket-Accept":  {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, | ||||||
| 					"Upgrade":               {"websocket"}, | 					"Upgrade":               {"websocket"}, | ||||||
| 					"Test-Cloudflared-Echo": {"Echo"}, | 					"Test-Cloudflared-Echo": {"Echo"}, | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
|  | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name: "tcp-tcp proxy", | 			name: "tcp-tcp proxy", | ||||||
| 			ingressServicePrefix: "tcp://", | 			args: args{ | ||||||
|  | 				ingressServiceScheme:  "tcp://", | ||||||
| 				originService:         runEchoTCPService, | 				originService:         runEchoTCPService, | ||||||
| 			eyeballService: newTCPRespWriter( | 				eyeballResponseWriter: newTCPRespWriter(replayer), | ||||||
| 				[]byte(`test2`), | 				eyeballRequestBody:    newTCPRequestBody([]byte("test2")), | ||||||
| 				replayer, | 				warpRoutingService:    ingress.NewWarpRoutingService(), | ||||||
| 			), |  | ||||||
| 				connectionType:        connection.TypeTCP, | 				connectionType:        connection.TypeTCP, | ||||||
| 			requestHeaders: http.Header{ | 				requestHeaders: map[string][]string{ | ||||||
| 					"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, | 					"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, | ||||||
| 				}, | 				}, | ||||||
| 			wantMessage: []byte("echo-test2"), | 			}, | ||||||
|  | 			want: want{ | ||||||
|  | 				message: []byte("echo-test2"), | ||||||
|  | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name: "tcp-ws proxy", | 			name: "tcp-ws proxy", | ||||||
| 			ingressServicePrefix: "ws://", | 			args: args{ | ||||||
|  | 				ingressServiceScheme: "ws://", | ||||||
| 				originService:        runEchoWSService, | 				originService:        runEchoWSService, | ||||||
| 			eyeballService:       newPipedWSWriter(&mockTCPRespWriter{}, []byte("test3")), | 				//eyeballResponseWriter gets set after roundtrip dial.
 | ||||||
| 			requestHeaders: http.Header{ | 				eyeballRequestBody: newPipedWSRequestBody([]byte("test3")), | ||||||
|  | 				warpRoutingService: ingress.NewWarpRoutingService(), | ||||||
|  | 				requestHeaders: map[string][]string{ | ||||||
| 					"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, | 					"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, | ||||||
| 				}, | 				}, | ||||||
| 				connectionType: connection.TypeTCP, | 				connectionType: connection.TypeTCP, | ||||||
| 			wantMessage:    []byte("echo-test3"), | 			}, | ||||||
|  | 			want: want{ | ||||||
|  | 				message: []byte("echo-test3"), | ||||||
| 				// We expect no headers here because they are sent back via
 | 				// We expect no headers here because they are sent back via
 | ||||||
| 				// the stream.
 | 				// the stream.
 | ||||||
| 			}, | 			}, | ||||||
|  | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name: "ws-tcp proxy", | 			name: "ws-tcp proxy", | ||||||
| 			ingressServicePrefix: "tcp://", | 			args: args{ | ||||||
|  | 				ingressServiceScheme:  "tcp://", | ||||||
| 				originService:         runEchoTCPService, | 				originService:         runEchoTCPService, | ||||||
| 			eyeballService:       newWSRespWriter([]byte("test4"), replayer), | 				eyeballResponseWriter: newWSRespWriter(replayer), | ||||||
|  | 				eyeballRequestBody:    newWSRequestBody([]byte("test4")), | ||||||
| 				connectionType:        connection.TypeWebsocket, | 				connectionType:        connection.TypeWebsocket, | ||||||
| 			requestHeaders: http.Header{ | 				requestHeaders: map[string][]string{ | ||||||
| 					// Example key from https://tools.ietf.org/html/rfc6455#section-1.2
 | 					// Example key from https://tools.ietf.org/html/rfc6455#section-1.2
 | ||||||
| 					"Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="}, | 					"Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="}, | ||||||
| 				}, | 				}, | ||||||
| 			wantMessage: []byte("echo-test4"), | 			}, | ||||||
| 			wantHeaders: http.Header{ | 			want: want{ | ||||||
|  | 				message: []byte("echo-test4"), | ||||||
|  | 				headers: map[string][]string{ | ||||||
| 					"Connection":           {"Upgrade"}, | 					"Connection":           {"Upgrade"}, | ||||||
| 					"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, | 					"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="}, | ||||||
| 					"Upgrade":              {"websocket"}, | 					"Upgrade":              {"websocket"}, | ||||||
| 				}, | 				}, | ||||||
| 			}, | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "tcp-tcp proxy without warpRoutingService enabled", | ||||||
|  | 			args: args{ | ||||||
|  | 				ingressServiceScheme:  "tcp://", | ||||||
|  | 				originService:         runEchoTCPService, | ||||||
|  | 				eyeballResponseWriter: newTCPRespWriter(replayer), | ||||||
|  | 				eyeballRequestBody:    newTCPRequestBody([]byte("test2")), | ||||||
|  | 				connectionType:        connection.TypeTCP, | ||||||
|  | 				requestHeaders: map[string][]string{ | ||||||
|  | 					"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			want: want{ | ||||||
|  | 				message: []byte{}, | ||||||
|  | 				err:     true, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "ws-ws proxy when origin is different", | ||||||
|  | 			args: args{ | ||||||
|  | 				ingressServiceScheme:  "ws://", | ||||||
|  | 				originService:         runEchoWSService, | ||||||
|  | 				eyeballResponseWriter: newWSRespWriter(replayer), | ||||||
|  | 				eyeballRequestBody:    newWSRequestBody([]byte("test1")), | ||||||
|  | 				connectionType:        connection.TypeWebsocket, | ||||||
|  | 				requestHeaders: map[string][]string{ | ||||||
|  | 					// Example key from https://tools.ietf.org/html/rfc6455#section-1.2
 | ||||||
|  | 					"Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="}, | ||||||
|  | 					"Origin":            {"Different origin"}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			want: want{ | ||||||
|  | 				message: []byte{}, | ||||||
|  | 				err:     true, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "tcp-* proxy when origin service has already closed the connection/ is no longer running", | ||||||
|  | 			args: args{ | ||||||
|  | 				ingressServiceScheme: "tcp://", | ||||||
|  | 				originService: func(t *testing.T, ln net.Listener) { | ||||||
|  | 					// closing the listener created by the test.
 | ||||||
|  | 					ln.Close() | ||||||
|  | 				}, | ||||||
|  | 				eyeballResponseWriter: newTCPRespWriter(replayer), | ||||||
|  | 				eyeballRequestBody:    newTCPRequestBody([]byte("test2")), | ||||||
|  | 				connectionType:        connection.TypeTCP, | ||||||
|  | 				requestHeaders: map[string][]string{ | ||||||
|  | 					"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			want: want{ | ||||||
|  | 				message: []byte{}, | ||||||
|  | 				err:     true, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, test := range tests { | 	for _, test := range tests { | ||||||
|  | @ -483,69 +566,99 @@ func TestConnections(t *testing.T) { | ||||||
| 			ln, err := net.Listen("tcp", "127.0.0.1:0") | 			ln, err := net.Listen("tcp", "127.0.0.1:0") | ||||||
| 			require.NoError(t, err) | 			require.NoError(t, err) | ||||||
| 			// Starts origin service
 | 			// Starts origin service
 | ||||||
| 			test.originService(t, ln) | 			test.args.originService(t, ln) | ||||||
| 
 | 
 | ||||||
| 			ingressRule := createSingleIngressConfig(t, test.ingressServicePrefix+ln.Addr().String()) | 			ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String()) | ||||||
| 			var wg sync.WaitGroup | 			var wg sync.WaitGroup | ||||||
| 			errC := make(chan error) | 			errC := make(chan error) | ||||||
| 			ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) | 			ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC) | ||||||
| 			proxy := NewOriginProxy(ingressRule, ingress.NewWarpRoutingService(), testTags, logger) | 			proxy := NewOriginProxy(ingressRule, test.args.warpRoutingService, testTags, logger) | ||||||
| 
 | 
 | ||||||
| 			req, err := http.NewRequest(http.MethodGet, test.ingressServicePrefix+ln.Addr().String(), nil) | 			req, err := http.NewRequest( | ||||||
|  | 				http.MethodGet, | ||||||
|  | 				test.args.ingressServiceScheme+ln.Addr().String(), | ||||||
|  | 				test.args.eyeballRequestBody, | ||||||
|  | 			) | ||||||
| 			require.NoError(t, err) | 			require.NoError(t, err) | ||||||
| 			req.Header = test.requestHeaders |  | ||||||
| 
 | 
 | ||||||
| 			if pipedWS, ok := test.eyeballService.(*pipedWSWriter); ok { | 			req.Header = test.args.requestHeaders | ||||||
|  | 			respWriter := test.args.eyeballResponseWriter | ||||||
|  | 
 | ||||||
|  | 			if pipedReqBody, ok := test.args.eyeballRequestBody.(*pipedRequestBody); ok { | ||||||
|  | 				respWriter = newTCPRespWriter(pipedReqBody.pipedConn) | ||||||
| 				go func() { | 				go func() { | ||||||
| 					resp := pipedWS.roundtrip(test.ingressServicePrefix + ln.Addr().String()) | 					resp := pipedReqBody.roundtrip(test.args.ingressServiceScheme + ln.Addr().String()) | ||||||
| 					replayer.Write(resp) | 					replayer.Write(resp) | ||||||
| 				}() | 				}() | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			err = proxy.Proxy(test.eyeballService, req, test.connectionType) | 			err = proxy.Proxy(respWriter, req, test.args.connectionType) | ||||||
| 			require.NoError(t, err) |  | ||||||
| 
 | 
 | ||||||
| 			cancel() | 			cancel() | ||||||
| 			assert.Equal(t, test.wantMessage, replayer.Bytes()) | 			assert.Equal(t, test.want.err, err != nil) | ||||||
| 			respPrinter := test.eyeballService.(responsePrinter) | 			assert.Equal(t, test.want.message, replayer.Bytes()) | ||||||
| 			assert.Equal(t, test.wantHeaders, respPrinter.printRespHeaders()) | 			respPrinter := respWriter.(responsePrinter) | ||||||
|  | 			assert.Equal(t, test.want.headers, respPrinter.headers()) | ||||||
| 			replayer.rw.Reset() | 			replayer.rw.Reset() | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type responsePrinter interface { | type requestBody struct { | ||||||
| 	printRespHeaders() http.Header | 	pw *io.PipeWriter | ||||||
|  | 	pr *io.PipeReader | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type pipedWSWriter struct { | func newWSRequestBody(data []byte) *requestBody { | ||||||
|  | 	pr, pw := io.Pipe() | ||||||
|  | 	go wsutil.WriteClientBinary(pw, data) | ||||||
|  | 	return &requestBody{ | ||||||
|  | 		pr: pr, | ||||||
|  | 		pw: pw, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | func newTCPRequestBody(data []byte) *requestBody { | ||||||
|  | 	pr, pw := io.Pipe() | ||||||
|  | 	go pw.Write(data) | ||||||
|  | 	return &requestBody{ | ||||||
|  | 		pr: pr, | ||||||
|  | 		pw: pw, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (r *requestBody) Read(p []byte) (n int, err error) { | ||||||
|  | 	return r.pr.Read(p) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (r *requestBody) Close() error { | ||||||
|  | 	r.pw.Close() | ||||||
|  | 	r.pr.Close() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type pipedRequestBody struct { | ||||||
| 	dialer         gorillaWS.Dialer | 	dialer         gorillaWS.Dialer | ||||||
| 	wsConn         net.Conn |  | ||||||
| 	pipedConn      net.Conn | 	pipedConn      net.Conn | ||||||
| 	respWriter     connection.ResponseWriter | 	wsConn         net.Conn | ||||||
| 	respHeaders    http.Header |  | ||||||
| 	messageToWrite []byte | 	messageToWrite []byte | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func newPipedWSWriter(rw *mockTCPRespWriter, messageToWrite []byte) *pipedWSWriter { | func newPipedWSRequestBody(data []byte) *pipedRequestBody { | ||||||
| 	conn1, conn2 := net.Pipe() | 	conn1, conn2 := net.Pipe() | ||||||
| 	dialer := gorillaWS.Dialer{ | 	dialer := gorillaWS.Dialer{ | ||||||
| 		NetDial: func(network, addr string) (net.Conn, error) { | 		NetDial: func(network, addr string) (net.Conn, error) { | ||||||
| 			return conn2, nil | 			return conn2, nil | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	rw.pr = conn1 | 	return &pipedRequestBody{ | ||||||
| 	rw.w = conn1 |  | ||||||
| 	return &pipedWSWriter{ |  | ||||||
| 		dialer:         dialer, | 		dialer:         dialer, | ||||||
| 		pipedConn:      conn1, | 		pipedConn:      conn1, | ||||||
| 		wsConn:         conn2, | 		wsConn:         conn2, | ||||||
| 		messageToWrite: messageToWrite, | 		messageToWrite: data, | ||||||
| 		respWriter:     rw, |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *pipedWSWriter) roundtrip(addr string) []byte { | func (p *pipedRequestBody) roundtrip(addr string) []byte { | ||||||
| 	header := http.Header{} | 	header := http.Header{} | ||||||
| 	conn, resp, err := p.dialer.Dial(addr, header) | 	conn, resp, err := p.dialer.Dial(addr, header) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -570,56 +683,35 @@ func (p *pipedWSWriter) roundtrip(addr string) []byte { | ||||||
| 	return data | 	return data | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *pipedWSWriter) Read(data []byte) (int, error) { | func (p *pipedRequestBody) Read(data []byte) (n int, err error) { | ||||||
| 	return p.pipedConn.Read(data) | 	return p.pipedConn.Read(data) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *pipedWSWriter) Write(data []byte) (int, error) { | func (p *pipedRequestBody) Close() error { | ||||||
| 	return p.pipedConn.Write(data) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *pipedWSWriter) WriteErrorResponse() { |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *pipedWSWriter) WriteRespHeaders(status int, header http.Header) error { |  | ||||||
| 	p.respHeaders = header |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // printRespHeaders is a test function to read respHeaders
 | type responsePrinter interface { | ||||||
| func (p *pipedWSWriter) printRespHeaders() http.Header { | 	headers() http.Header | ||||||
| 	return p.respHeaders |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type wsRespWriter struct { | type wsRespWriter struct { | ||||||
| 	w               io.Writer | 	w               io.Writer | ||||||
| 	pr          *io.PipeReader | 	responseHeaders http.Header | ||||||
| 	pw          *io.PipeWriter |  | ||||||
| 	respHeaders http.Header |  | ||||||
| 	code            int | 	code            int | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // newWSRespWriter uses wsutil.WriteClientText to generate websocket frames.
 | // newWSRespWriter uses wsutil.WriteClientText to generate websocket frames.
 | ||||||
| // and wsutil.ReadClientText to translate frames from server to byte data.
 | // and wsutil.ReadClientText to translate frames from server to byte data.
 | ||||||
| // In essence, this acts as a wsClient.
 | // In essence, this acts as a wsClient.
 | ||||||
| func newWSRespWriter(data []byte, w io.Writer) *wsRespWriter { | func newWSRespWriter(w io.Writer) *wsRespWriter { | ||||||
| 	pr, pw := io.Pipe() |  | ||||||
| 	go wsutil.WriteClientBinary(pw, data) |  | ||||||
| 	return &wsRespWriter{ | 	return &wsRespWriter{ | ||||||
| 		w: w, | 		w: w, | ||||||
| 		pr: pr, |  | ||||||
| 		pw: pw, |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Read is read by ingress.Stream and serves as the input from the client.
 |  | ||||||
| func (w *wsRespWriter) Read(p []byte) (int, error) { |  | ||||||
| 	return w.pr.Read(p) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Write is written to by ingress.Stream and serves as the output to the client.
 | // Write is written to by ingress.Stream and serves as the output to the client.
 | ||||||
| func (w *wsRespWriter) Write(p []byte) (int, error) { | func (w *wsRespWriter) Write(p []byte) (int, error) { | ||||||
| 	defer w.pw.Close() |  | ||||||
| 	returnedMsg, err := wsutil.ReadServerBinary(bytes.NewBuffer(p)) | 	returnedMsg, err := wsutil.ReadServerBinary(bytes.NewBuffer(p)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		// The data was not returned by a websocket connecton.
 | 		// The data was not returned by a websocket connecton.
 | ||||||
|  | @ -631,17 +723,55 @@ func (w *wsRespWriter) Write(p []byte) (int, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error { | func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error { | ||||||
| 	w.respHeaders = header | 	w.responseHeaders = header | ||||||
| 	w.code = status | 	w.code = status | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (w *wsRespWriter) WriteErrorResponse() { | // respHeaders is a test function to read respHeaders
 | ||||||
|  | func (w *wsRespWriter) headers() http.Header { | ||||||
|  | 	return w.responseHeaders | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // printRespHeaders is a test function to read respHeaders
 | type mockTCPRespWriter struct { | ||||||
| func (w *wsRespWriter) printRespHeaders() http.Header { | 	w               io.Writer | ||||||
| 	return w.respHeaders | 	responseHeaders http.Header | ||||||
|  | 	code            int | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func newTCPRespWriter(w io.Writer) *mockTCPRespWriter { | ||||||
|  | 	return &mockTCPRespWriter{ | ||||||
|  | 		w: w, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) { | ||||||
|  | 	return m.w.Write(p) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error { | ||||||
|  | 	m.responseHeaders = header | ||||||
|  | 	m.code = status | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // respHeaders is a test function to read respHeaders
 | ||||||
|  | func (m *mockTCPRespWriter) headers() http.Header { | ||||||
|  | 	return m.responseHeaders | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func createSingleIngressConfig(t *testing.T, service string) ingress.Ingress { | ||||||
|  | 	ingressConfig := &config.Configuration{ | ||||||
|  | 		Ingress: []config.UnvalidatedIngressRule{ | ||||||
|  | 			{ | ||||||
|  | 				Hostname: "*", | ||||||
|  | 				Service:  service, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	ingressRule, err := ingress.ParseIngress(ingressConfig) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	return ingressRule | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func runEchoTCPService(t *testing.T, l net.Listener) { | func runEchoTCPService(t *testing.T, l net.Listener) { | ||||||
|  | @ -662,8 +792,8 @@ func runEchoTCPService(t *testing.T, l net.Listener) { | ||||||
| 				_, err = conn.Write(data) | 				_, err = conn.Write(data) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					t.Log(err) | 					t.Log(err) | ||||||
| 					return |  | ||||||
| 				} | 				} | ||||||
|  | 				return | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
|  | @ -683,7 +813,10 @@ func runEchoWSService(t *testing.T, l net.Listener) { | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		conn, err := upgrader.Upgrade(w, r, header) | 		conn, err := upgrader.Upgrade(w, r, header) | ||||||
| 		require.NoError(t, err) | 		if err != nil { | ||||||
|  | 			t.Log(err) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
| 		defer conn.Close() | 		defer conn.Close() | ||||||
| 
 | 
 | ||||||
| 		for { | 		for { | ||||||
|  | @ -708,61 +841,3 @@ func runEchoWSService(t *testing.T, l net.Listener) { | ||||||
| 		require.NoError(t, err) | 		require.NoError(t, err) | ||||||
| 	}() | 	}() | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func createSingleIngressConfig(t *testing.T, service string) ingress.Ingress { |  | ||||||
| 	ingressConfig := &config.Configuration{ |  | ||||||
| 		Ingress: []config.UnvalidatedIngressRule{ |  | ||||||
| 			{ |  | ||||||
| 				Hostname: "*", |  | ||||||
| 				Service:  service, |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
| 	ingressRule, err := ingress.ParseIngress(ingressConfig) |  | ||||||
| 	require.NoError(t, err) |  | ||||||
| 	return ingressRule |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type tcpWrappedWs struct { |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type mockTCPRespWriter struct { |  | ||||||
| 	w           io.Writer |  | ||||||
| 	pr          io.Reader |  | ||||||
| 	pw          *io.PipeWriter |  | ||||||
| 	respHeaders http.Header |  | ||||||
| 	code        int |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func newTCPRespWriter(data []byte, w io.Writer) *mockTCPRespWriter { |  | ||||||
| 	pr, pw := io.Pipe() |  | ||||||
| 	go pw.Write(data) |  | ||||||
| 	return &mockTCPRespWriter{ |  | ||||||
| 		w:  w, |  | ||||||
| 		pr: pr, |  | ||||||
| 		pw: pw, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m *mockTCPRespWriter) Read(p []byte) (n int, err error) { |  | ||||||
| 	return m.pr.Read(p) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) { |  | ||||||
| 	defer m.pw.Close() |  | ||||||
| 	return m.w.Write(p) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m *mockTCPRespWriter) WriteErrorResponse() { |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error { |  | ||||||
| 	m.respHeaders = header |  | ||||||
| 	m.code = status |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // printRespHeaders is a test function to read respHeaders
 |  | ||||||
| func (m *mockTCPRespWriter) printRespHeaders() http.Header { |  | ||||||
| 	return m.respHeaders |  | ||||||
| } |  | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue