cloudflared-mirror/streamhandler/stream_handler_test.go

260 lines
7.8 KiB
Go

package streamhandler
import (
"context"
"io"
"net"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
)
// Timeouts shared by the tests in this file.
const (
	// testOpenStreamTimeout bounds how long a test waits for OpenStream.
	testOpenStreamTimeout = 5 * time.Second
	// testHandshakeTimeout is used both as the muxer config Timeout and as
	// the deadline of the context created in Handshake.
	testHandshakeTimeout = time.Second
)
// Fixtures shared by the tests in this file.
var (
	// testTunnelHostname is the hostname the tests register a reverse proxy
	// config for and send in tunnelHostnameHeader.
	testTunnelHostname = h2mux.TunnelHostname("123.cftunnel.com")

	// baseHeaders holds the pseudo-headers of a GET http://example.com/
	// request followed by an empty user-headers field. Tests append
	// tunnelHostnameHeader (and other headers) to it as needed.
	baseHeaders = []h2mux.Header{
		{Name: ":method", Value: "GET"},
		{Name: ":scheme", Value: "http"},
		{Name: ":authority", Value: "example.com"},
		{Name: ":path", Value: "/"},
		// Regular headers must always come after the pseudoheaders
		{Name: h2mux.RequestUserHeadersField, Value: ""},
	}

	// tunnelHostnameHeader carries testTunnelHostname under the
	// h2mux.CloudflaredProxyTunnelHostnameHeader name.
	tunnelHostnameHeader = h2mux.Header{Name: h2mux.CloudflaredProxyTunnelHostnameHeader, Value: testTunnelHostname.String()}
)
// TestServeRequest registers a reverse proxy config that maps
// testTunnelHostname to a local HTTP test server, opens a stream carrying the
// tunnel hostname header, and expects a 200 status with the server's message
// as the response body.
func TestServeRequest(t *testing.T) {
	configChan := make(chan *pogs.ClientConfig)
	useConfigResultChan := make(chan *pogs.UseConfigurationResult)
	streamHandler := NewStreamHandler(configChan, useConfigResultChan, logrus.New())

	message := []byte("Hello cloudflared")
	httpServer := httptest.NewServer(&mockHTTPHandler{message})
	// Release the test server's listener and goroutines when the test ends.
	defer httpServer.Close()

	reverseProxyConfigs := []*pogs.ReverseProxyConfig{
		{
			TunnelHostname: testTunnelHostname,
			OriginConfig: &pogs.HTTPOriginConfig{
				URLString: httpServer.URL,
			},
		},
	}
	streamHandler.UpdateConfig(reverseProxyConfigs)

	muxPair := NewDefaultMuxerPair(t, streamHandler)
	muxPair.Serve(t)

	ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
	defer cancel()

	headers := append(baseHeaders, tunnelHostnameHeader)
	stream, err := muxPair.EdgeMux.OpenStream(ctx, headers, nil)
	if !assert.NoError(t, err) {
		// assert does not stop the test; bail out instead of dereferencing
		// a stream that was never opened.
		return
	}
	assertStatusHeader(t, http.StatusOK, stream.Headers)
	assertRespBody(t, message, stream)
}
// createStreamHandler builds a StreamHandler wired to freshly created,
// unbuffered config and config-result channels and a new logrus logger.
func createStreamHandler() *StreamHandler {
	return NewStreamHandler(
		make(chan *pogs.ClientConfig),
		make(chan *pogs.UseConfigurationResult),
		logrus.New(),
	)
}
// createRequestMuxPair creates a muxer pair whose origin side dispatches
// streams to streamHandler, starts serving on both muxers, and returns the
// pair.
func createRequestMuxPair(t *testing.T, streamHandler *StreamHandler) *DefaultMuxerPair {
	pair := NewDefaultMuxerPair(t, streamHandler)
	pair.Serve(t)
	return pair
}
// TestServeStatusBadRequest opens a stream without the tunnel hostname header
// and expects the handler to answer with 400 Bad Request.
func TestServeStatusBadRequest(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
	defer cancel()

	// No tunnel hostname header, expect to get 400 Bad Request
	stream, err := createRequestMuxPair(t, createStreamHandler()).EdgeMux.OpenStream(ctx, baseHeaders, nil)
	if !assert.NoError(t, err) {
		// assert does not stop the test; avoid dereferencing a nil stream.
		return
	}
	assertStatusHeader(t, http.StatusBadRequest, stream.Headers)
	assertRespBody(t, statusBadRequest.text, stream)
}
// TestServeInvalidContentLength sends a request whose content-length header is
// not a number and expects the handler to answer with 400 Bad Request.
func TestServeInvalidContentLength(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
	defer cancel()

	// Invalid content-length, wouldn't be able to create a request
	// Expect to get 400 Bad Request
	headers := append(baseHeaders, tunnelHostnameHeader)
	headers = append(headers, h2mux.Header{
		Name:  "content-length",
		Value: "x",
	})

	streamHandler := createStreamHandler()
	streamHandler.UpdateConfig([]*pogs.ReverseProxyConfig{
		{
			TunnelHostname: testTunnelHostname,
			OriginConfig: &pogs.HTTPOriginConfig{
				URLString: "",
			},
		},
	})

	mux := createRequestMuxPair(t, streamHandler).EdgeMux
	stream, err := mux.OpenStream(ctx, headers, nil)
	if !assert.NoError(t, err) {
		// assert does not stop the test; avoid dereferencing a nil stream.
		return
	}
	assertStatusHeader(t, http.StatusBadRequest, stream.Headers)
	assertRespBody(t, statusBadRequest.text, stream)
}
// TestServeStatusNotFound sends the tunnel hostname header to a handler that
// has no reverse proxy config for it and expects 404 Not Found.
func TestServeStatusNotFound(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
	defer cancel()

	// No mapping for the tunnel hostname, expect to get 404 Not Found
	headers := append(baseHeaders, tunnelHostnameHeader)
	stream, err := createRequestMuxPair(t, createStreamHandler()).EdgeMux.OpenStream(ctx, headers, nil)
	if !assert.NoError(t, err) {
		// assert does not stop the test; avoid dereferencing a nil stream.
		return
	}
	assertStatusHeader(t, http.StatusNotFound, stream.Headers)
	assertRespBody(t, statusNotFound.text, stream)
}
// TestServeStatusBadGateway maps the tunnel hostname to an origin with an
// empty URL, so proxying cannot succeed, and expects 502 Bad Gateway.
func TestServeStatusBadGateway(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
	defer cancel()

	// Nothing listening on empty url, so proxy would fail. Expect to get 502 Bad Gateway
	reverseProxyConfigs := []*pogs.ReverseProxyConfig{
		{
			TunnelHostname: testTunnelHostname,
			OriginConfig: &pogs.HTTPOriginConfig{
				URLString: "",
			},
		},
	}
	streamHandler := createStreamHandler()
	streamHandler.UpdateConfig(reverseProxyConfigs)

	headers := append(baseHeaders, tunnelHostnameHeader)
	stream, err := createRequestMuxPair(t, streamHandler).EdgeMux.OpenStream(ctx, headers, nil)
	if !assert.NoError(t, err) {
		// assert does not stop the test; avoid dereferencing a nil stream.
		return
	}
	assertStatusHeader(t, http.StatusBadGateway, stream.Headers)
	assertRespBody(t, statusBadGateway.text, stream)
}
// assertStatusHeader checks that the first response header is the status
// pseudo-header and that its value is expectedStatus in decimal form.
// An empty header list is reported as a test failure instead of panicking.
func assertStatusHeader(t *testing.T, expectedStatus int, headers []h2mux.Header) {
	t.Helper()
	if !assert.NotEmpty(t, headers, "response has no headers") {
		return
	}
	status := headers[0]
	assert.Equal(t, statusPseudoHeader, status.Name)
	assert.Equal(t, strconv.Itoa(expectedStatus), status.Value)
}
// assertRespBody reads len(expectedRespBody) bytes from stream and checks that
// they equal expectedRespBody.
//
// io.ReadFull is used because a single Read may legitimately return fewer
// bytes than requested, which would make the comparison fail spuriously.
func assertRespBody(t *testing.T, expectedRespBody []byte, stream *h2mux.MuxedStream) {
	t.Helper()
	respBody := make([]byte, len(expectedRespBody))
	_, err := io.ReadFull(stream, respBody)
	assert.NoError(t, err)
	assert.Equal(t, expectedRespBody, respBody)
}
// DefaultMuxerPair bundles the two ends ("origin" and "edge") of an h2mux
// connection used by the tests, together with their configuration and
// underlying connections.
type DefaultMuxerPair struct {
	// Origin side: configuration, the muxer set by Handshake, and its connection.
	OriginMuxConfig h2mux.MuxerConfig
	OriginMux       *h2mux.Muxer
	OriginConn      net.Conn
	// Edge side: configuration, the muxer set by Handshake, and its connection.
	EdgeMuxConfig h2mux.MuxerConfig
	EdgeMux       *h2mux.Muxer
	EdgeConn      net.Conn
	// doneC is closed by Serve once both muxers have stopped serving.
	doneC chan struct{}
}
// NewDefaultMuxerPair connects an "origin" muxer (client side, dispatching
// incoming streams to h) and an "edge" muxer over an in-memory net.Pipe and
// performs the handshake between them.
//
// If the handshake fails the test is failed and stopped immediately: the
// muxers would otherwise be nil and any later use of the pair would panic.
// Because it may call t.FailNow, it must be called from the test goroutine.
func NewDefaultMuxerPair(t *testing.T, h h2mux.MuxedStreamHandler) *DefaultMuxerPair {
	origin, edge := net.Pipe()
	p := &DefaultMuxerPair{
		OriginMuxConfig: h2mux.MuxerConfig{
			Timeout:                 testHandshakeTimeout,
			Handler:                 h,
			IsClient:                true,
			Name:                    "origin",
			Logger:                  logrus.NewEntry(logrus.New()),
			DefaultWindowSize:       (1 << 8) - 1,
			MaxWindowSize:           (1 << 15) - 1,
			StreamWriteBufferMaxLen: 1024,
		},
		OriginConn: origin,
		EdgeMuxConfig: h2mux.MuxerConfig{
			Timeout:                 testHandshakeTimeout,
			IsClient:                false,
			Name:                    "edge",
			Logger:                  logrus.NewEntry(logrus.New()),
			DefaultWindowSize:       (1 << 8) - 1,
			MaxWindowSize:           (1 << 15) - 1,
			StreamWriteBufferMaxLen: 1024,
		},
		EdgeConn: edge,
		doneC:    make(chan struct{}),
	}
	if !assert.NoError(t, p.Handshake(t.Name())) {
		// Without a successful handshake EdgeMux/OriginMux are unusable.
		t.FailNow()
	}
	return p
}
// Handshake runs the h2mux handshake for the edge and origin ends concurrently
// and stores the resulting muxers in p.EdgeMux and p.OriginMux. testName is
// passed to the active-streams metrics created for each side. The error
// returned is whatever the error group's Wait reports, wrapped with the name
// of the side that failed.
func (p *DefaultMuxerPair) Handshake(testName string) error {
	ctx, cancel := context.WithTimeout(context.Background(), testHandshakeTimeout)
	defer cancel()

	group, _ := errgroup.WithContext(ctx)

	group.Go(func() error {
		metrics := h2mux.NewActiveStreamsMetrics(testName, "edge")
		mux, err := h2mux.Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig, metrics)
		p.EdgeMux = mux
		return errors.Wrap(err, "edge handshake failure")
	})

	group.Go(func() error {
		metrics := h2mux.NewActiveStreamsMetrics(testName, "origin")
		mux, err := h2mux.Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig, metrics)
		p.OriginMux = mux
		return errors.Wrap(err, "origin handshake failure")
	})

	return group.Wait()
}
// Serve starts both muxers in background goroutines and returns immediately.
// When one muxer stops serving, the other one is shut down. Errors other than
// nil, io.EOF and io.ErrClosedPipe are reported through t. Once both Serve
// calls have returned, p.doneC is closed.
func (p *DefaultMuxerPair) Serve(t assert.TestingT) {
	var (
		ctx     = context.Background()
		running sync.WaitGroup
	)
	running.Add(2)

	go func() {
		defer running.Done()
		switch err := p.EdgeMux.Serve(ctx); err {
		case nil, io.EOF, io.ErrClosedPipe:
			// Normal termination; nothing to report.
		default:
			t.Errorf("error in edge muxer Serve(): %s", err)
		}
		p.OriginMux.Shutdown()
	}()

	go func() {
		defer running.Done()
		switch err := p.OriginMux.Serve(ctx); err {
		case nil, io.EOF, io.ErrClosedPipe:
			// Normal termination; nothing to report.
		default:
			t.Errorf("error in origin muxer Serve(): %s", err)
		}
		p.EdgeMux.Shutdown()
	}()

	// Signal on doneC once both muxers have stopped serving.
	go func() {
		running.Wait()
		close(p.doneC)
	}()
}
// mockHTTPHandler is an http.Handler that answers every request with a fixed
// payload.
type mockHTTPHandler struct {
	// message is written verbatim as the response body.
	message []byte
}

// ServeHTTP writes the configured message to w and ignores the request.
func (h *mockHTTPHandler) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	// The write result is deliberately discarded: this is a test stub and
	// the tests check the response on the receiving side.
	payload := h.message
	_, _ = w.Write(payload)
}