cloudflared-mirror/connection/h2mux_test.go

302 lines
7.0 KiB
Go

package connection
import (
"context"
"fmt"
"io"
"net"
"net/http"
"strconv"
"sync"
"testing"
"time"
"github.com/gobwas/ws/wsutil"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/h2mux"
)
var (
testMuxerConfig = &MuxerConfig{
HeartbeatInterval: time.Second * 5,
MaxHeartbeats: 5,
CompressionSetting: 0,
MetricsUpdateFreq: time.Second * 5,
}
)
func newH2MuxConnection(t require.TestingT) (*h2muxConnection, *h2mux.Muxer) {
edgeConn, originConn := net.Pipe()
edgeMuxChan := make(chan *h2mux.Muxer)
go func() {
edgeMuxConfig := h2mux.MuxerConfig{
Log: &log,
Handler: h2mux.MuxedStreamFunc(func(stream *h2mux.MuxedStream) error {
// we only expect RPC traffic in client->edge direction, provide minimal support for mocking
require.True(t, stream.IsRPCStream())
return stream.WriteHeaders([]h2mux.Header{
{Name: ":status", Value: "200"},
})
}),
}
edgeMux, err := h2mux.Handshake(edgeConn, edgeConn, edgeMuxConfig, h2mux.ActiveStreams)
require.NoError(t, err)
edgeMuxChan <- edgeMux
}()
var connIndex = uint8(0)
testObserver := NewObserver(&log, &log, false)
h2muxConn, err, _ := NewH2muxConnection(testConfig, testMuxerConfig, originConn, connIndex, testObserver, nil)
require.NoError(t, err)
return h2muxConn, <-edgeMuxChan
}
func TestServeStreamHTTP(t *testing.T) {
tests := []testRequest{
{
name: "ok",
endpoint: "/ok",
expectedStatus: http.StatusOK,
expectedBody: []byte(http.StatusText(http.StatusOK)),
},
{
name: "large_file",
endpoint: "/large_file",
expectedStatus: http.StatusOK,
expectedBody: testLargeResp,
},
{
name: "Bad request",
endpoint: "/400",
expectedStatus: http.StatusBadRequest,
expectedBody: []byte(http.StatusText(http.StatusBadRequest)),
},
{
name: "Internal server error",
endpoint: "/500",
expectedStatus: http.StatusInternalServerError,
expectedBody: []byte(http.StatusText(http.StatusInternalServerError)),
},
{
name: "Proxy error",
endpoint: "/error",
expectedStatus: http.StatusBadGateway,
expectedBody: nil,
isProxyError: true,
},
}
ctx, cancel := context.WithCancel(context.Background())
h2muxConn, edgeMux := newH2MuxConnection(t)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_ = edgeMux.Serve(ctx)
}()
go func() {
defer wg.Done()
err := h2muxConn.serveMuxer(ctx)
require.Error(t, err)
}()
for _, test := range tests {
headers := []h2mux.Header{
{
Name: ":path",
Value: test.endpoint,
},
}
stream, err := edgeMux.OpenStream(ctx, headers, nil)
require.NoError(t, err)
require.True(t, hasHeader(stream, ":status", strconv.Itoa(test.expectedStatus)))
if test.isProxyError {
assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderCfd))
} else {
assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin))
body := make([]byte, len(test.expectedBody))
_, err = stream.Read(body)
require.NoError(t, err)
require.Equal(t, test.expectedBody, body)
}
}
cancel()
wg.Wait()
}
func TestServeStreamWS(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
h2muxConn, edgeMux := newH2MuxConnection(t)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
edgeMux.Serve(ctx)
}()
go func() {
defer wg.Done()
err := h2muxConn.serveMuxer(ctx)
require.Error(t, err)
}()
headers := []h2mux.Header{
{
Name: ":path",
Value: "/ws",
},
{
Name: "connection",
Value: "upgrade",
},
{
Name: "upgrade",
Value: "websocket",
},
}
readPipe, writePipe := io.Pipe()
stream, err := edgeMux.OpenStream(ctx, headers, readPipe)
require.NoError(t, err)
require.True(t, hasHeader(stream, ":status", strconv.Itoa(http.StatusSwitchingProtocols)))
assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin))
data := []byte("test websocket")
err = wsutil.WriteClientText(writePipe, data)
require.NoError(t, err)
respBody, err := wsutil.ReadServerText(stream)
require.NoError(t, err)
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
cancel()
wg.Wait()
}
func TestGracefulShutdownH2Mux(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
h2muxConn, edgeMux := newH2MuxConnection(t)
shutdownC := make(chan struct{})
unregisteredC := make(chan struct{})
h2muxConn.gracefulShutdownC = shutdownC
h2muxConn.newRPCClientFunc = func(_ context.Context, _ io.ReadWriteCloser, _ *zerolog.Logger) NamedTunnelRPCClient {
return &mockNamedTunnelRPCClient{
registered: nil,
unregistered: unregisteredC,
}
}
var wg sync.WaitGroup
wg.Add(3)
go func() {
defer wg.Done()
_ = edgeMux.Serve(ctx)
}()
go func() {
defer wg.Done()
_ = h2muxConn.serveMuxer(ctx)
}()
go func() {
defer wg.Done()
h2muxConn.controlLoop(ctx, &mockConnectedFuse{}, true)
}()
time.Sleep(100 * time.Millisecond)
close(shutdownC)
select {
case <-unregisteredC:
break // ok
case <-time.Tick(time.Second):
assert.Fail(t, "timed out waiting for control loop to unregister")
}
cancel()
wg.Wait()
assert.True(t, h2muxConn.stoppedGracefully)
assert.Nil(t, h2muxConn.gracefulShutdownC)
}
func hasHeader(stream *h2mux.MuxedStream, name, val string) bool {
for _, header := range stream.Headers {
if header.Name == name && header.Value == val {
return true
}
}
return false
}
func benchmarkServeStreamHTTPSimple(b *testing.B, test testRequest) {
ctx, cancel := context.WithCancel(context.Background())
h2muxConn, edgeMux := newH2MuxConnection(b)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
edgeMux.Serve(ctx)
}()
go func() {
defer wg.Done()
err := h2muxConn.serveMuxer(ctx)
require.Error(b, err)
}()
headers := []h2mux.Header{
{
Name: ":path",
Value: test.endpoint,
},
}
body := make([]byte, len(test.expectedBody))
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StartTimer()
stream, openstreamErr := edgeMux.OpenStream(ctx, headers, nil)
_, readBodyErr := stream.Read(body)
b.StopTimer()
require.NoError(b, openstreamErr)
assert.True(b, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin))
require.True(b, hasHeader(stream, ":status", strconv.Itoa(http.StatusOK)))
require.NoError(b, readBodyErr)
require.Equal(b, test.expectedBody, body)
}
cancel()
wg.Wait()
}
func BenchmarkServeStreamHTTPSimple(b *testing.B) {
test := testRequest{
name: "ok",
endpoint: "/ok",
expectedStatus: http.StatusOK,
expectedBody: []byte(http.StatusText(http.StatusOK)),
}
benchmarkServeStreamHTTPSimple(b, test)
}
func BenchmarkServeStreamHTTPLargeFile(b *testing.B) {
test := testRequest{
name: "large_file",
endpoint: "/large_file",
expectedStatus: http.StatusOK,
expectedBody: testLargeResp,
}
benchmarkServeStreamHTTPSimple(b, test)
}