diff --git a/connection/connection.go b/connection/connection.go index 41fd03ab..d1b1081e 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -44,7 +44,7 @@ type OriginClient interface { type ResponseWriter interface { WriteRespHeaders(*http.Response) error - WriteErrorResponse(error) + WriteErrorResponse() io.ReadWriter } diff --git a/connection/connection_test.go b/connection/connection_test.go new file mode 100644 index 00000000..f15dacfb --- /dev/null +++ b/connection/connection_test.go @@ -0,0 +1,113 @@ +package connection + +import ( + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" + "github.com/cloudflare/cloudflared/logger" + "github.com/gobwas/ws/wsutil" +) + +const ( + largeFileSize = 2 * 1024 * 1024 +) + +var ( + testConfig = &Config{ + OriginClient: &mockOriginClient{}, + GracePeriod: time.Millisecond * 100, + } + testLogger, _ = logger.New() + testOriginURL = &url.URL{ + Scheme: "https", + Host: "connectiontest.argotunnel.com", + } + testTunnelEventChan = make(chan ui.TunnelEvent) + testObserver = &Observer{ + testLogger, + m, + testTunnelEventChan, + } + testLargeResp = make([]byte, largeFileSize) +) + +type testRequest struct { + name string + endpoint string + expectedStatus int + expectedBody []byte + isProxyError bool +} + +type mockOriginClient struct { +} + +func (moc *mockOriginClient) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error { + if isWebsocket { + return wsEndpoint(w, r) + } + switch r.URL.Path { + case "/ok": + originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK))) + case "/large_file": + originRespEndpoint(w, http.StatusOK, testLargeResp) + case "/400": + originRespEndpoint(w, http.StatusBadRequest, []byte(http.StatusText(http.StatusBadRequest))) + case "/500": + originRespEndpoint(w, http.StatusInternalServerError, []byte(http.StatusText(http.StatusInternalServerError))) + case "/error": + return fmt.Errorf("Failed to proxy to origin") + default: + originRespEndpoint(w, http.StatusNotFound, []byte("page not found")) + } + return nil +} + +type nowriter struct { + io.Reader +} + +func (nowriter) Write(p []byte) (int, error) { + return 0, fmt.Errorf("Writer not implemented") +} + +func wsEndpoint(w ResponseWriter, r *http.Request) error { + resp := &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + } + w.WriteRespHeaders(resp) + clientReader := nowriter{r.Body} + go func() { + for { + data, err := wsutil.ReadClientText(clientReader) + if err != nil { + return + } + if err := wsutil.WriteServerText(w, data); err != nil { + return + } + } + }() + <-r.Context().Done() + return nil +} + +func originRespEndpoint(w ResponseWriter, status int, data []byte) { + resp := &http.Response{ + StatusCode: status, + } + w.WriteRespHeaders(resp) + w.Write(data) +} + +type mockConnectedFuse struct{} + +func (mcf mockConnectedFuse) Connected() {} + +func (mcf mockConnectedFuse) IsConnected() bool { + return true +} diff --git a/connection/h2mux.go b/connection/h2mux.go index a85a44c8..4f397b68 100644 --- a/connection/h2mux.go +++ b/connection/h2mux.go @@ -88,9 +88,9 @@ func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *Nam return err } rpcClient := newRegistrationRPCClient(ctx, stream, h.observer) - defer rpcClient.close() + defer rpcClient.Close() - if err = registerConnection(serveCtx, rpcClient, namedTunnel, connOptions, h.connIndex, h.observer); err != nil { + if err = rpcClient.RegisterConnection(serveCtx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil { return err } connectedFuse.Connected() @@ -177,11 +177,16 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error { req, reqErr := h.newRequest(stream) if reqErr != nil { - respWriter.WriteErrorResponse(reqErr) + respWriter.WriteErrorResponse() return reqErr } - return h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req)) + err := h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req)) + if err != nil { + respWriter.WriteErrorResponse() + return err + } + return nil } func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) { @@ -206,7 +211,7 @@ func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error { return rp.WriteHeaders(headers) } -func (rp *h2muxRespWriter) WriteErrorResponse(err error) { +func (rp *h2muxRespWriter) WriteErrorResponse() { rp.WriteHeaders([]h2mux.Header{ {Name: ":status", Value: "502"}, {Name: responseMetaHeaderField, Value: responseMetaHeaderCfd}, diff --git a/connection/h2mux_test.go b/connection/h2mux_test.go new file mode 100644 index 00000000..4f1019e0 --- /dev/null +++ b/connection/h2mux_test.go @@ -0,0 +1,242 @@ +package connection + +import ( + "context" + "fmt" + "io" + "net" + "net/http" + "strconv" + "sync" + "testing" + "time" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/gobwas/ws/wsutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + testMuxerConfig = &MuxerConfig{ + HeartbeatInterval: time.Second * 5, + MaxHeartbeats: 5, + CompressionSetting: 0, + MetricsUpdateFreq: time.Second * 5, + } +) + +func newH2MuxConnection(ctx context.Context, t require.TestingT) (*h2muxConnection, *h2mux.Muxer) { + edgeConn, originConn := net.Pipe() + edgeMuxChan := make(chan *h2mux.Muxer) + go func() { + edgeMuxConfig := h2mux.MuxerConfig{ + Logger: testObserver, + } + edgeMux, err := h2mux.Handshake(edgeConn, edgeConn, edgeMuxConfig, h2mux.ActiveStreams) + require.NoError(t, err) + edgeMuxChan <- edgeMux + }() + var connIndex = uint8(0) + h2muxConn, err, _ := NewH2muxConnection(ctx, testConfig, testMuxerConfig, originConn, connIndex, testObserver) + require.NoError(t, err) + return h2muxConn, <-edgeMuxChan +} + +func TestServeStreamHTTP(t *testing.T) { + tests := []testRequest{ + { + name: "ok", + endpoint: "/ok", + expectedStatus: http.StatusOK, + expectedBody: []byte(http.StatusText(http.StatusOK)), + }, + { + name: "large_file", + endpoint: "/large_file", + expectedStatus: http.StatusOK, + expectedBody: testLargeResp, + }, + { + name: "Bad request", + endpoint: "/400", + expectedStatus: http.StatusBadRequest, + expectedBody: []byte(http.StatusText(http.StatusBadRequest)), + }, + { + name: "Internal server error", + endpoint: "/500", + expectedStatus: http.StatusInternalServerError, + expectedBody: []byte(http.StatusText(http.StatusInternalServerError)), + }, + { + name: "Proxy error", + endpoint: "/error", + expectedStatus: http.StatusBadGateway, + expectedBody: nil, + isProxyError: true, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + h2muxConn, edgeMux := newH2MuxConnection(ctx, t) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + edgeMux.Serve(ctx) + }() + go func() { + defer wg.Done() + err := h2muxConn.serveMuxer(ctx) + require.Error(t, err) + }() + + for _, test := range tests { + headers := []h2mux.Header{ + { + Name: ":path", + Value: test.endpoint, + }, + } + stream, err := edgeMux.OpenStream(ctx, headers, nil) + require.NoError(t, err) + require.True(t, hasHeader(stream, ":status", strconv.Itoa(test.expectedStatus))) + + if test.isProxyError { + assert.True(t, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderCfd)) + } else { + assert.True(t, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderOrigin)) + body := make([]byte, len(test.expectedBody)) + _, err = stream.Read(body) + require.NoError(t, err) + require.Equal(t, test.expectedBody, body) + } + } + cancel() + wg.Wait() +} + +func TestServeStreamWS(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + h2muxConn, edgeMux := newH2MuxConnection(ctx, t) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + edgeMux.Serve(ctx) + }() + go func() { + defer wg.Done() + err := h2muxConn.serveMuxer(ctx) + require.Error(t, err) + }() + + headers := []h2mux.Header{ + { + Name: ":path", + Value: "/ws", + }, + { + Name: "connection", + Value: "upgrade", + }, + { + Name: "upgrade", + Value: "websocket", + }, + } + + readPipe, writePipe := io.Pipe() + stream, err := edgeMux.OpenStream(ctx, headers, readPipe) + require.NoError(t, err) + + require.True(t, hasHeader(stream, ":status", strconv.Itoa(http.StatusSwitchingProtocols))) + assert.True(t, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderOrigin)) + + data := []byte("test websocket") + err = wsutil.WriteClientText(writePipe, data) + require.NoError(t, err) + + respBody, err := wsutil.ReadServerText(stream) + require.NoError(t, err) + require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody))) + + cancel() + wg.Wait() +} + +func hasHeader(stream *h2mux.MuxedStream, name, val string) bool { + for _, header := range stream.Headers { + if header.Name == name && header.Value == val { + return true + } + } + return false +} + +func benchmarkServeStreamHTTPSimple(b *testing.B, test testRequest) { + ctx, cancel := context.WithCancel(context.Background()) + h2muxConn, edgeMux := newH2MuxConnection(ctx, b) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + edgeMux.Serve(ctx) + }() + go func() { + defer wg.Done() + err := h2muxConn.serveMuxer(ctx) + require.Error(b, err) + }() + + headers := []h2mux.Header{ + { + Name: ":path", + Value: test.endpoint, + }, + } + + body := make([]byte, len(test.expectedBody)) + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StartTimer() + stream, openstreamErr := edgeMux.OpenStream(ctx, headers, nil) + _, readBodyErr := stream.Read(body) + b.StopTimer() + + require.NoError(b, openstreamErr) + assert.True(b, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderOrigin)) + require.True(b, hasHeader(stream, ":status", strconv.Itoa(http.StatusOK))) + require.NoError(b, readBodyErr) + require.Equal(b, test.expectedBody, body) + } + + cancel() + wg.Wait() +} + +func BenchmarkServeStreamHTTPSimple(b *testing.B) { + test := testRequest{ + name: "ok", + endpoint: "/ok", + expectedStatus: http.StatusOK, + expectedBody: []byte(http.StatusText(http.StatusOK)), + } + + benchmarkServeStreamHTTPSimple(b, test) +} + +func BenchmarkServeStreamHTTPLargeFile(b *testing.B) { + test := testRequest{ + name: "large_file", + endpoint: "/large_file", + expectedStatus: http.StatusOK, + expectedBody: testLargeResp, + } + + benchmarkServeStreamHTTPSimple(b, test) +} diff --git a/connection/http2.go b/connection/http2.go index 7145caab..8e3e8f99 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -11,6 +11,7 @@ import ( "sync" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/logger" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" "golang.org/x/net/http2" @@ -26,17 +27,19 @@ var ( errNotFlusher = errors.New("ResponseWriter doesn't implement http.Flusher") ) -type HTTP2Connection struct { - conn net.Conn - server *http2.Server - config *Config - namedTunnel *NamedTunnelConfig - connOptions *tunnelpogs.ConnectionOptions - observer *Observer - connIndexStr string - connIndex uint8 - wg *sync.WaitGroup - connectedFuse ConnectedFuse +type http2Connection struct { + conn net.Conn + server *http2.Server + config *Config + namedTunnel *NamedTunnelConfig + connOptions *tunnelpogs.ConnectionOptions + observer *Observer + connIndexStr string + connIndex uint8 + wg *sync.WaitGroup + // newRPCClientFunc allows us to mock RPCs during testing + newRPCClientFunc func(context.Context, io.ReadWriteCloser, logger.Service) NamedTunnelRPCClient + connectedFuse ConnectedFuse } func NewHTTP2Connection( @@ -47,24 +50,25 @@ func NewHTTP2Connection( observer *Observer, connIndex uint8, connectedFuse ConnectedFuse, -) *HTTP2Connection { - return &HTTP2Connection{ +) *http2Connection { + return &http2Connection{ conn: conn, server: &http2.Server{ MaxConcurrentStreams: math.MaxUint32, }, - config: config, - namedTunnel: namedTunnelConfig, - connOptions: connOptions, - observer: observer, - connIndexStr: uint8ToString(connIndex), - connIndex: connIndex, - wg: &sync.WaitGroup{}, - connectedFuse: connectedFuse, + config: config, + namedTunnel: namedTunnelConfig, + connOptions: connOptions, + observer: observer, + connIndexStr: uint8ToString(connIndex), + connIndex: connIndex, + wg: &sync.WaitGroup{}, + newRPCClientFunc: newRegistrationRPCClient, + connectedFuse: connectedFuse, } } -func (c *HTTP2Connection) Serve(ctx context.Context) { +func (c *http2Connection) Serve(ctx context.Context) { go func() { <-ctx.Done() c.close() @@ -75,7 +79,7 @@ func (c *HTTP2Connection) Serve(ctx context.Context) { }) } -func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { c.wg.Add(1) defer c.wg.Done() @@ -86,65 +90,42 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { flusher, isFlusher := w.(http.Flusher) if !isFlusher { c.observer.Errorf("%T doesn't implement http.Flusher", w) - respWriter.WriteErrorResponse(errNotFlusher) + respWriter.WriteErrorResponse() return } respWriter.flusher = flusher + var err error if isControlStreamUpgrade(r) { respWriter.shouldFlush = true - err := c.serveControlStream(r.Context(), respWriter) - if err != nil { - respWriter.WriteErrorResponse(err) - } + err = c.serveControlStream(r.Context(), respWriter) } else if isWebsocketUpgrade(r) { respWriter.shouldFlush = true stripWebsocketUpgradeHeader(r) - c.config.OriginClient.Proxy(respWriter, r, true) + err = c.config.OriginClient.Proxy(respWriter, r, true) } else { - c.config.OriginClient.Proxy(respWriter, r, false) + err = c.config.OriginClient.Proxy(respWriter, r, false) + } + + if err != nil { + respWriter.WriteErrorResponse() } } -func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error { - rpcClient := newRegistrationRPCClient(ctx, respWriter, c.observer) - defer rpcClient.close() +func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error { + rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer) + defer rpcClient.Close() - if err := registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil { + if err := rpcClient.RegisterConnection(ctx, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil { return err } c.connectedFuse.Connected() <-ctx.Done() - c.gracefulShutdown(ctx, rpcClient) + rpcClient.GracefulShutdown(ctx, c.config.GracePeriod) return nil } -func (c *HTTP2Connection) registerConnection( - ctx context.Context, - rpcClient tunnelpogs.RegistrationServer_PogsClient, -) error { - connDetail, err := rpcClient.RegisterConnection( - ctx, - c.namedTunnel.Auth, - c.namedTunnel.ID, - c.connIndex, - c.connOptions, - ) - if err != nil { - c.observer.Errorf("Cannot register connection, err: %v", err) - return err - } - c.observer.Infof("Connection %s registered with %s using ID %s", c.connIndexStr, connDetail.Location, connDetail.UUID) - return nil -} - -func (c *HTTP2Connection) gracefulShutdown(ctx context.Context, rpcClient *registrationServerClient) { - ctx, cancel := context.WithTimeout(ctx, c.config.GracePeriod) - defer cancel() - rpcClient.client.UnregisterConnection(ctx) -} - -func (c *HTTP2Connection) close() { +func (c *http2Connection) close() { // Wait for all serve HTTP handlers to return c.wg.Wait() c.conn.Close() @@ -195,7 +176,7 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error { return nil } -func (rp *http2RespWriter) WriteErrorResponse(err error) { +func (rp *http2RespWriter) WriteErrorResponse() { rp.setResponseMetaHeader(responseMetaHeaderCfd) rp.w.WriteHeader(http.StatusBadGateway) } diff --git a/connection/http2_test.go b/connection/http2_test.go new file mode 100644 index 00000000..06bced61 --- /dev/null +++ b/connection/http2_test.go @@ -0,0 +1,303 @@ +package connection + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/cloudflare/cloudflared/logger" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/gobwas/ws/wsutil" + "github.com/stretchr/testify/require" + "golang.org/x/net/http2" +) + +var ( + testTransport = http2.Transport{} +) + +func newTestHTTP2Connection() (*http2Connection, net.Conn) { + edgeConn, originConn := net.Pipe() + var connIndex = uint8(0) + return NewHTTP2Connection( + originConn, + testConfig, + &NamedTunnelConfig{}, + &pogs.ConnectionOptions{}, + testObserver, + connIndex, + mockConnectedFuse{}, + ), edgeConn +} + +func TestServeHTTP(t *testing.T) { + tests := []testRequest{ + { + name: "ok", + endpoint: "ok", + expectedStatus: http.StatusOK, + expectedBody: []byte(http.StatusText(http.StatusOK)), + }, + { + name: "large_file", + endpoint: "large_file", + expectedStatus: http.StatusOK, + expectedBody: testLargeResp, + }, + { + name: "Bad request", + endpoint: "400", + expectedStatus: http.StatusBadRequest, + expectedBody: []byte(http.StatusText(http.StatusBadRequest)), + }, + { + name: "Internal server error", + endpoint: "500", + expectedStatus: http.StatusInternalServerError, + expectedBody: []byte(http.StatusText(http.StatusInternalServerError)), + }, + { + name: "Proxy error", + endpoint: "error", + expectedStatus: http.StatusBadGateway, + expectedBody: nil, + isProxyError: true, + }, + } + + http2Conn, edgeConn := newTestHTTP2Connection() + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + http2Conn.Serve(ctx) + }() + + edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn) + require.NoError(t, err) + + for _, test := range tests { + endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + require.NoError(t, err) + + resp, err := edgeHTTP2Conn.RoundTrip(req) + require.NoError(t, err) + require.Equal(t, test.expectedStatus, resp.StatusCode) + if test.expectedBody != nil { + respBody, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, test.expectedBody, respBody) + } + if test.isProxyError { + require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(responseMetaHeaderField)) + } else { + require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(responseMetaHeaderField)) + } + } + cancel() + wg.Wait() +} + +type mockNamedTunnelRPCClient struct { + registered chan struct{} + unregistered chan struct{} +} + +func (mc mockNamedTunnelRPCClient) RegisterConnection( + c context.Context, + config *NamedTunnelConfig, + options *tunnelpogs.ConnectionOptions, + connIndex uint8, + observer *Observer, +) error { + close(mc.registered) + return nil +} + +func (mc mockNamedTunnelRPCClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) { + close(mc.unregistered) +} + +func (mockNamedTunnelRPCClient) Close() {} + +type mockRPCClientFactory struct { + registered chan struct{} + unregistered chan struct{} +} + +func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, logger.Service) NamedTunnelRPCClient { + return mockNamedTunnelRPCClient{ + registered: mf.registered, + unregistered: mf.unregistered, + } +} + +type wsRespWriter struct { + *httptest.ResponseRecorder + readPipe *io.PipeReader + writePipe *io.PipeWriter +} + +func newWSRespWriter() *wsRespWriter { + readPipe, writePipe := io.Pipe() + return &wsRespWriter{ + httptest.NewRecorder(), + readPipe, + writePipe, + } +} + +func (w *wsRespWriter) RespBody() io.ReadWriter { + return nowriter{w.readPipe} +} + +func (w *wsRespWriter) Write(data []byte) (n int, err error) { + return w.writePipe.Write(data) +} + +func TestServeWS(t *testing.T) { + http2Conn, _ := newTestHTTP2Connection() + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + http2Conn.Serve(ctx) + }() + + respWriter := newWSRespWriter() + readPipe, writePipe := io.Pipe() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe) + require.NoError(t, err) + req.Header.Set(internalUpgradeHeader, websocketUpgrade) + + wg.Add(1) + go func() { + defer wg.Done() + http2Conn.ServeHTTP(respWriter, req) + }() + + data := []byte("test websocket") + err = wsutil.WriteClientText(writePipe, data) + require.NoError(t, err) + + respBody, err := wsutil.ReadServerText(respWriter.RespBody()) + require.NoError(t, err) + require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody))) + + cancel() + resp := respWriter.Result() + // http2RespWriter should rewrite status 101 to 200 + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(responseMetaHeaderField)) + + wg.Wait() +} + +func TestServeControlStream(t *testing.T) { + http2Conn, edgeConn := newTestHTTP2Connection() + + rpcClientFactory := mockRPCClientFactory{ + registered: make(chan struct{}), + unregistered: make(chan struct{}), + } + http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + http2Conn.Serve(ctx) + }() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil) + require.NoError(t, err) + req.Header.Set(internalUpgradeHeader, controlStreamUpgrade) + + edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn) + require.NoError(t, err) + + wg.Add(1) + go func() { + defer wg.Done() + edgeHTTP2Conn.RoundTrip(req) + }() + + <-rpcClientFactory.registered + cancel() + <-rpcClientFactory.unregistered + + wg.Wait() +} + +func benchmarkServeHTTP(b *testing.B, test testRequest) { + http2Conn, edgeConn := newTestHTTP2Connection() + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + http2Conn.Serve(ctx) + }() + + endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + require.NoError(b, err) + + edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn) + require.NoError(b, err) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StartTimer() + resp, err := edgeHTTP2Conn.RoundTrip(req) + b.StopTimer() + require.NoError(b, err) + require.Equal(b, test.expectedStatus, resp.StatusCode) + if test.expectedBody != nil { + respBody, err := ioutil.ReadAll(resp.Body) + require.NoError(b, err) + require.Equal(b, test.expectedBody, respBody) + } + resp.Body.Close() + } + + cancel() + wg.Wait() +} +func BenchmarkServeHTTPSimple(b *testing.B) { + test := testRequest{ + name: "ok", + endpoint: "ok", + expectedStatus: http.StatusOK, + expectedBody: []byte(http.StatusText(http.StatusOK)), + } + + benchmarkServeHTTP(b, test) +} + +func BenchmarkServeHTTPLargeFile(b *testing.B) { + test := testRequest{ + name: "large_file", + endpoint: "large_file", + expectedStatus: http.StatusOK, + expectedBody: testLargeResp, + } + + benchmarkServeHTTP(b, test) +} diff --git a/connection/rpc.go b/connection/rpc.go index 9da7630d..fd756fd5 100644 --- a/connection/rpc.go +++ b/connection/rpc.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "time" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/tunnelrpc" @@ -49,6 +50,18 @@ func (tsc *tunnelServerClient) Close() { tsc.transport.Close() } +type NamedTunnelRPCClient interface { + RegisterConnection( + c context.Context, + config *NamedTunnelConfig, + options *tunnelpogs.ConnectionOptions, + connIndex uint8, + observer *Observer, + ) error + GracefulShutdown(ctx context.Context, gracePeriod time.Duration) + Close() +} + type registrationServerClient struct { client tunnelpogs.RegistrationServer_PogsClient transport rpc.Transport @@ -58,7 +71,7 @@ func newRegistrationRPCClient( ctx context.Context, stream io.ReadWriteCloser, logger logger.Service, -) *registrationServerClient { +) NamedTunnelRPCClient { transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)) conn := rpc.NewConn( transport, @@ -70,31 +83,14 @@ func newRegistrationRPCClient( } } -func (rsc *registrationServerClient) close() { - // Closing the client will also close the connection - rsc.client.Close() - // Closing the transport also closes the stream - rsc.transport.Close() -} - -type rpcName string - -const ( - register rpcName = "register" - reconnect rpcName = "reconnect" - unregister rpcName = "unregister" - authenticate rpcName = " authenticate" -) - -func registerConnection( +func (rsc *registrationServerClient) RegisterConnection( ctx context.Context, - rpcClient *registrationServerClient, config *NamedTunnelConfig, options *tunnelpogs.ConnectionOptions, connIndex uint8, observer *Observer, ) error { - conn, err := rpcClient.client.RegisterConnection( + conn, err := rsc.client.RegisterConnection( ctx, config.Auth, config.ID, @@ -118,6 +114,28 @@ func registerConnection( return nil } +func (rsc *registrationServerClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) { + ctx, cancel := context.WithTimeout(ctx, gracePeriod) + defer cancel() + rsc.client.UnregisterConnection(ctx) +} + +func (rsc *registrationServerClient) Close() { + // Closing the client will also close the connection + rsc.client.Close() + // Closing the transport also closes the stream + rsc.transport.Close() +} + +type rpcName string + +const ( + register rpcName = "register" + reconnect rpcName = "reconnect" + unregister rpcName = "unregister" + authenticate rpcName = " authenticate" +) + func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error { h.observer.sendRegisteringEvent() @@ -264,9 +282,9 @@ func (h *h2muxConnection) unregister(isNamedTunnel bool) { if isNamedTunnel { rpcClient := newRegistrationRPCClient(unregisterCtx, stream, h.observer) - defer rpcClient.close() + defer rpcClient.Close() - rpcClient.client.UnregisterConnection(unregisterCtx) + rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod) } else { rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer) defer rpcClient.Close() diff --git a/origin/proxy.go b/origin/proxy.go index 589d601a..fe3ddc94 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -60,7 +60,7 @@ func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsock } if err != nil { c.logRequestError(err, cfRay, ruleNum) - w.WriteErrorResponse(err) + w.WriteErrorResponse() return err } c.logOriginResponse(resp, cfRay, lbProbe, ruleNum) diff --git a/origin/proxy_test.go b/origin/proxy_test.go index 09b02ea0..7a85f6d2 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -47,7 +47,7 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error { return nil } -func (w *mockHTTPRespWriter) WriteErrorResponse(err error) { +func (w *mockHTTPRespWriter) WriteErrorResponse() { w.WriteHeader(http.StatusBadGateway) }