diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go index 5f6b8b5b..97312047 100644 --- a/h2mux/h2mux_test.go +++ b/h2mux/h2mux_test.go @@ -15,10 +15,11 @@ import ( "testing" "time" - "github.com/cloudflare/cloudflared/logger" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "golang.org/x/sync/errgroup" + + "github.com/cloudflare/cloudflared/logger" ) const ( @@ -1032,3 +1033,108 @@ func openStreams(b *testing.B, muxPair *DefaultMuxerPair, n int) { } assert.NoError(b, errGroup.Wait()) } + +func BenchmarkSingleStreamLargeResponseBody(b *testing.B) { + const bodySize = 1 << 24 + + const writeBufferSize = 16 << 10 + const writeN = bodySize / writeBufferSize + payload := make([]byte, writeBufferSize) + for i := range payload { + payload[i] = byte(i % 256) + } + + const readBufferSize = 16 << 10 + const readN = bodySize / readBufferSize + responseBody := make([]byte, readBufferSize) + + f := MuxedStreamFunc(func(stream *MuxedStream) error { + if len(stream.Headers) != 1 { + b.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) + } + if stream.Headers[0].Name != "test-header" { + b.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name) + } + if stream.Headers[0].Value != "headerValue" { + b.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value) + } + stream.WriteHeaders([]Header{ + {Name: "response-header", Value: "responseValue"}, + }) + for i := 0; i < writeN; i++ { + n, err := stream.Write(payload) + if err != nil { + b.Fatalf("origin write error: %s", err) + } + if n != len(payload) { + b.Fatalf("origin short write: %d/%d bytes", n, len(payload)) + } + } + + return nil + }) + + name := fmt.Sprintf("%s_%d", b.Name(), rand.Int()) + origin, edge := net.Pipe() + + muxPair := &DefaultMuxerPair{ + OriginMuxConfig: MuxerConfig{ + Timeout: testHandshakeTimeout, + Handler: f, + IsClient: true, + Name: "origin", + Logger: logger.NewOutputWriter(logger.NewMockWriteManager()), + DefaultWindowSize: defaultWindowSize, + MaxWindowSize: maxWindowSize, + StreamWriteBufferMaxLen: defaultWriteBufferMaxLen, + HeartbeatInterval: defaultTimeout, + MaxHeartbeats: defaultRetries, + }, + OriginConn: origin, + EdgeMuxConfig: MuxerConfig{ + Timeout: testHandshakeTimeout, + IsClient: false, + Name: "edge", + Logger: logger.NewOutputWriter(logger.NewMockWriteManager()), + DefaultWindowSize: defaultWindowSize, + MaxWindowSize: maxWindowSize, + StreamWriteBufferMaxLen: defaultWriteBufferMaxLen, + HeartbeatInterval: defaultTimeout, + MaxHeartbeats: defaultRetries, + }, + EdgeConn: edge, + doneC: make(chan struct{}), + } + assert.NoError(b, muxPair.Handshake(name)) + muxPair.Serve(b) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + stream, err := muxPair.OpenEdgeMuxStream( + []Header{{Name: "test-header", Value: "headerValue"}}, + nil, + ) + if err != nil { + b.Fatalf("error in OpenStream: %s", err) + } + if len(stream.Headers) != 1 { + b.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) + } + if stream.Headers[0].Name != "response-header" { + b.Fatalf("expected header name %s, got %s", "response-header", stream.Headers[0].Name) + } + if stream.Headers[0].Value != "responseValue" { + b.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value) + } + + for k := 0; k < readN; k++ { + n, err := io.ReadFull(stream, responseBody) + if err != nil { + b.Fatalf("error from (*MuxedStream).Read: %s", err) + } + if n != len(responseBody) { + b.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n) + } + } + } +}