package connection import ( "context" "fmt" "io" "net" "net/http" "strconv" "sync" "testing" "time" "github.com/gobwas/ws/wsutil" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/cloudflare/cloudflared/h2mux" ) var ( testMuxerConfig = &MuxerConfig{ HeartbeatInterval: time.Second * 5, MaxHeartbeats: 5, CompressionSetting: 0, MetricsUpdateFreq: time.Second * 5, } ) func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) { edgeConn, originConn := net.Pipe() edgeMuxChan := make(chan *h2mux.Muxer) go func() { edgeMuxConfig := h2mux.MuxerConfig{ Log: &log, Handler: h2mux.MuxedStreamFunc(func(stream *h2mux.MuxedStream) error { // we only expect RPC traffic in client->edge direction, provide minimal support for mocking require.True(t, stream.IsRPCStream()) return stream.WriteHeaders([]h2mux.Header{ {Name: ":status", Value: "200"}, }) }), } edgeMux, err := h2mux.Handshake(edgeConn, edgeConn, edgeMuxConfig, h2mux.ActiveStreams) require.NoError(t, err) edgeMuxChan <- edgeMux }() var connIndex = uint8(0) testObserver := NewObserver(&log, &log, false) h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver, nil) require.NoError(t, err) return h2muxConn, <-edgeMuxChan } func TestServeStreamHTTP(t *testing.T) { tests := []testRequest{ { name: "ok", endpoint: "/ok", expectedStatus: http.StatusOK, expectedBody: []byte(http.StatusText(http.StatusOK)), }, { name: "large_file", endpoint: "/large_file", expectedStatus: http.StatusOK, expectedBody: testLargeResp, }, { name: "Bad request", endpoint: "/400", expectedStatus: http.StatusBadRequest, expectedBody: []byte(http.StatusText(http.StatusBadRequest)), }, { name: "Internal server error", endpoint: "/500", expectedStatus: http.StatusInternalServerError, expectedBody: []byte(http.StatusText(http.StatusInternalServerError)), }, { name: "Proxy error", endpoint: "/error", expectedStatus: http.StatusBadGateway, expectedBody: nil, isProxyError: true, }, } ctx, cancel := context.WithCancel(context.Background()) h2muxConn, edgeMux := newH2MuxConnection(t) var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() _ = edgeMux.Serve(ctx) }() go func() { defer wg.Done() err := h2muxConn.serveMuxer(ctx) require.Error(t, err) }() for _, test := range tests { headers := []h2mux.Header{ { Name: ":path", Value: test.endpoint, }, } stream, err := edgeMux.OpenStream(ctx, headers, nil) require.NoError(t, err) require.True(t, hasHeader(stream, ":status", strconv.Itoa(test.expectedStatus))) if test.isProxyError { assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderCfd)) } else { assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin)) body := make([]byte, len(test.expectedBody)) _, err = stream.Read(body) require.NoError(t, err) require.Equal(t, test.expectedBody, body) } } cancel() wg.Wait() } func TestServeStreamWS(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) h2muxConn, edgeMux := newH2MuxConnection(t) var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() edgeMux.Serve(ctx) }() go func() { defer wg.Done() err := h2muxConn.serveMuxer(ctx) require.Error(t, err) }() headers := []h2mux.Header{ { Name: ":path", Value: "/ws", }, { Name: "connection", Value: "upgrade", }, { Name: "upgrade", Value: "websocket", }, } readPipe, writePipe := io.Pipe() stream, err := edgeMux.OpenStream(ctx, headers, readPipe) require.NoError(t, err) require.True(t, hasHeader(stream, ":status", strconv.Itoa(http.StatusSwitchingProtocols))) assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin)) data := []byte("test websocket") err = wsutil.WriteClientText(writePipe, data) require.NoError(t, err) respBody, err := wsutil.ReadServerText(stream) require.NoError(t, err) require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody))) cancel() wg.Wait() } func TestGracefulShutdownH2Mux(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() h2muxConn, edgeMux := newH2MuxConnection(t) shutdownC := make(chan struct{}) unregisteredC := make(chan struct{}) h2muxConn.gracefulShutdownC = shutdownC h2muxConn.newRPCClientFunc = func(_ context.Context, _ io.ReadWriteCloser, _ *zerolog.Logger) NamedTunnelRPCClient { return &mockNamedTunnelRPCClient{ registered: nil, unregistered: unregisteredC, } } var wg sync.WaitGroup wg.Add(3) go func() { defer wg.Done() _ = edgeMux.Serve(ctx) }() go func() { defer wg.Done() _ = h2muxConn.serveMuxer(ctx) }() go func() { defer wg.Done() h2muxConn.controlLoop(ctx, &mockConnectedFuse{}, true) }() time.Sleep(100 * time.Millisecond) close(shutdownC) select { case <-unregisteredC: break // ok case <-time.Tick(time.Second): assert.Fail(t, "timed out waiting for control loop to unregister") } cancel() wg.Wait() assert.True(t, h2muxConn.stoppedGracefully) assert.Nil(t, h2muxConn.gracefulShutdownC) } func hasHeader(stream *h2mux.MuxedStream, name, val string) bool { for _, header := range stream.Headers { if header.Name == name && header.Value == val { return true } } return false } func benchmarkServeStreamHTTPSimple(b *testing.B, test testRequest) { ctx, cancel := context.WithCancel(context.Background()) h2muxConn, edgeMux := newH2MuxConnection(b) var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() edgeMux.Serve(ctx) }() go func() { defer wg.Done() err := h2muxConn.serveMuxer(ctx) require.Error(b, err) }() headers := []h2mux.Header{ { Name: ":path", Value: test.endpoint, }, } body := make([]byte, len(test.expectedBody)) b.ResetTimer() for i := 0; i < b.N; i++ { b.StartTimer() stream, openstreamErr := edgeMux.OpenStream(ctx, headers, nil) _, readBodyErr := stream.Read(body) b.StopTimer() require.NoError(b, openstreamErr) assert.True(b, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin)) require.True(b, hasHeader(stream, ":status", strconv.Itoa(http.StatusOK))) require.NoError(b, readBodyErr) require.Equal(b, test.expectedBody, body) } cancel() wg.Wait() } func BenchmarkServeStreamHTTPSimple(b *testing.B) { test := testRequest{ name: "ok", endpoint: "/ok", expectedStatus: http.StatusOK, expectedBody: []byte(http.StatusText(http.StatusOK)), } benchmarkServeStreamHTTPSimple(b, test) } func BenchmarkServeStreamHTTPLargeFile(b *testing.B) { test := testRequest{ name: "large_file", endpoint: "/large_file", expectedStatus: http.StatusOK, expectedBody: testLargeResp, } benchmarkServeStreamHTTPSimple(b, test) }