TUN-3489: Add unit tests to cover proxy logic in connection package of cloudflared

This commit is contained in:
cthuang 2020-10-27 22:27:15 +00:00
parent 5974fb4cfd
commit d5769519b2
9 changed files with 754 additions and 92 deletions

View File

@ -44,7 +44,7 @@ type OriginClient interface {
type ResponseWriter interface { type ResponseWriter interface {
WriteRespHeaders(*http.Response) error WriteRespHeaders(*http.Response) error
WriteErrorResponse(error) WriteErrorResponse()
io.ReadWriter io.ReadWriter
} }

View File

@ -0,0 +1,113 @@
package connection
import (
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
"github.com/cloudflare/cloudflared/logger"
"github.com/gobwas/ws/wsutil"
)
const (
largeFileSize = 2 * 1024 * 1024
)
var (
testConfig = &Config{
OriginClient: &mockOriginClient{},
GracePeriod: time.Millisecond * 100,
}
testLogger, _ = logger.New()
testOriginURL = &url.URL{
Scheme: "https",
Host: "connectiontest.argotunnel.com",
}
testTunnelEventChan = make(chan ui.TunnelEvent)
testObserver = &Observer{
testLogger,
m,
testTunnelEventChan,
}
testLargeResp = make([]byte, largeFileSize)
)
type testRequest struct {
name string
endpoint string
expectedStatus int
expectedBody []byte
isProxyError bool
}
type mockOriginClient struct {
}
func (moc *mockOriginClient) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error {
if isWebsocket {
return wsEndpoint(w, r)
}
switch r.URL.Path {
case "/ok":
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
case "/large_file":
originRespEndpoint(w, http.StatusOK, testLargeResp)
case "/400":
originRespEndpoint(w, http.StatusBadRequest, []byte(http.StatusText(http.StatusBadRequest)))
case "/500":
originRespEndpoint(w, http.StatusInternalServerError, []byte(http.StatusText(http.StatusInternalServerError)))
case "/error":
return fmt.Errorf("Failed to proxy to origin")
default:
originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
}
return nil
}
type nowriter struct {
io.Reader
}
func (nowriter) Write(p []byte) (int, error) {
return 0, fmt.Errorf("Writer not implemented")
}
func wsEndpoint(w ResponseWriter, r *http.Request) error {
resp := &http.Response{
StatusCode: http.StatusSwitchingProtocols,
}
w.WriteRespHeaders(resp)
clientReader := nowriter{r.Body}
go func() {
for {
data, err := wsutil.ReadClientText(clientReader)
if err != nil {
return
}
if err := wsutil.WriteServerText(w, data); err != nil {
return
}
}
}()
<-r.Context().Done()
return nil
}
func originRespEndpoint(w ResponseWriter, status int, data []byte) {
resp := &http.Response{
StatusCode: status,
}
w.WriteRespHeaders(resp)
w.Write(data)
}
type mockConnectedFuse struct{}
func (mcf mockConnectedFuse) Connected() {}
func (mcf mockConnectedFuse) IsConnected() bool {
return true
}

View File

@ -88,9 +88,9 @@ func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *Nam
return err return err
} }
rpcClient := newRegistrationRPCClient(ctx, stream, h.observer) rpcClient := newRegistrationRPCClient(ctx, stream, h.observer)
defer rpcClient.close() defer rpcClient.Close()
if err = registerConnection(serveCtx, rpcClient, namedTunnel, connOptions, h.connIndex, h.observer); err != nil { if err = rpcClient.RegisterConnection(serveCtx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
return err return err
} }
connectedFuse.Connected() connectedFuse.Connected()
@ -177,11 +177,16 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
req, reqErr := h.newRequest(stream) req, reqErr := h.newRequest(stream)
if reqErr != nil { if reqErr != nil {
respWriter.WriteErrorResponse(reqErr) respWriter.WriteErrorResponse()
return reqErr return reqErr
} }
return h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req)) err := h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
if err != nil {
respWriter.WriteErrorResponse()
return err
}
return nil
} }
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) { func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
@ -206,7 +211,7 @@ func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
return rp.WriteHeaders(headers) return rp.WriteHeaders(headers)
} }
func (rp *h2muxRespWriter) WriteErrorResponse(err error) { func (rp *h2muxRespWriter) WriteErrorResponse() {
rp.WriteHeaders([]h2mux.Header{ rp.WriteHeaders([]h2mux.Header{
{Name: ":status", Value: "502"}, {Name: ":status", Value: "502"},
{Name: responseMetaHeaderField, Value: responseMetaHeaderCfd}, {Name: responseMetaHeaderField, Value: responseMetaHeaderCfd},

242
connection/h2mux_test.go Normal file
View File

@ -0,0 +1,242 @@
package connection
import (
"context"
"fmt"
"io"
"net"
"net/http"
"strconv"
"sync"
"testing"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/gobwas/ws/wsutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
testMuxerConfig = &MuxerConfig{
HeartbeatInterval: time.Second * 5,
MaxHeartbeats: 5,
CompressionSetting: 0,
MetricsUpdateFreq: time.Second * 5,
}
)
func newH2MuxConnection(ctx context.Context, t require.TestingT) (*h2muxConnection, *h2mux.Muxer) {
edgeConn, originConn := net.Pipe()
edgeMuxChan := make(chan *h2mux.Muxer)
go func() {
edgeMuxConfig := h2mux.MuxerConfig{
Logger: testObserver,
}
edgeMux, err := h2mux.Handshake(edgeConn, edgeConn, edgeMuxConfig, h2mux.ActiveStreams)
require.NoError(t, err)
edgeMuxChan <- edgeMux
}()
var connIndex = uint8(0)
h2muxConn, err, _ := NewH2muxConnection(ctx, testConfig, testMuxerConfig, originConn, connIndex, testObserver)
require.NoError(t, err)
return h2muxConn, <-edgeMuxChan
}
func TestServeStreamHTTP(t *testing.T) {
tests := []testRequest{
{
name: "ok",
endpoint: "/ok",
expectedStatus: http.StatusOK,
expectedBody: []byte(http.StatusText(http.StatusOK)),
},
{
name: "large_file",
endpoint: "/large_file",
expectedStatus: http.StatusOK,
expectedBody: testLargeResp,
},
{
name: "Bad request",
endpoint: "/400",
expectedStatus: http.StatusBadRequest,
expectedBody: []byte(http.StatusText(http.StatusBadRequest)),
},
{
name: "Internal server error",
endpoint: "/500",
expectedStatus: http.StatusInternalServerError,
expectedBody: []byte(http.StatusText(http.StatusInternalServerError)),
},
{
name: "Proxy error",
endpoint: "/error",
expectedStatus: http.StatusBadGateway,
expectedBody: nil,
isProxyError: true,
},
}
ctx, cancel := context.WithCancel(context.Background())
h2muxConn, edgeMux := newH2MuxConnection(ctx, t)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
edgeMux.Serve(ctx)
}()
go func() {
defer wg.Done()
err := h2muxConn.serveMuxer(ctx)
require.Error(t, err)
}()
for _, test := range tests {
headers := []h2mux.Header{
{
Name: ":path",
Value: test.endpoint,
},
}
stream, err := edgeMux.OpenStream(ctx, headers, nil)
require.NoError(t, err)
require.True(t, hasHeader(stream, ":status", strconv.Itoa(test.expectedStatus)))
if test.isProxyError {
assert.True(t, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderCfd))
} else {
assert.True(t, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderOrigin))
body := make([]byte, len(test.expectedBody))
_, err = stream.Read(body)
require.NoError(t, err)
require.Equal(t, test.expectedBody, body)
}
}
cancel()
wg.Wait()
}
func TestServeStreamWS(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
h2muxConn, edgeMux := newH2MuxConnection(ctx, t)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
edgeMux.Serve(ctx)
}()
go func() {
defer wg.Done()
err := h2muxConn.serveMuxer(ctx)
require.Error(t, err)
}()
headers := []h2mux.Header{
{
Name: ":path",
Value: "/ws",
},
{
Name: "connection",
Value: "upgrade",
},
{
Name: "upgrade",
Value: "websocket",
},
}
readPipe, writePipe := io.Pipe()
stream, err := edgeMux.OpenStream(ctx, headers, readPipe)
require.NoError(t, err)
require.True(t, hasHeader(stream, ":status", strconv.Itoa(http.StatusSwitchingProtocols)))
assert.True(t, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderOrigin))
data := []byte("test websocket")
err = wsutil.WriteClientText(writePipe, data)
require.NoError(t, err)
respBody, err := wsutil.ReadServerText(stream)
require.NoError(t, err)
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
cancel()
wg.Wait()
}
func hasHeader(stream *h2mux.MuxedStream, name, val string) bool {
for _, header := range stream.Headers {
if header.Name == name && header.Value == val {
return true
}
}
return false
}
func benchmarkServeStreamHTTPSimple(b *testing.B, test testRequest) {
ctx, cancel := context.WithCancel(context.Background())
h2muxConn, edgeMux := newH2MuxConnection(ctx, b)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
edgeMux.Serve(ctx)
}()
go func() {
defer wg.Done()
err := h2muxConn.serveMuxer(ctx)
require.Error(b, err)
}()
headers := []h2mux.Header{
{
Name: ":path",
Value: test.endpoint,
},
}
body := make([]byte, len(test.expectedBody))
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StartTimer()
stream, openstreamErr := edgeMux.OpenStream(ctx, headers, nil)
_, readBodyErr := stream.Read(body)
b.StopTimer()
require.NoError(b, openstreamErr)
assert.True(b, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderOrigin))
require.True(b, hasHeader(stream, ":status", strconv.Itoa(http.StatusOK)))
require.NoError(b, readBodyErr)
require.Equal(b, test.expectedBody, body)
}
cancel()
wg.Wait()
}
func BenchmarkServeStreamHTTPSimple(b *testing.B) {
test := testRequest{
name: "ok",
endpoint: "/ok",
expectedStatus: http.StatusOK,
expectedBody: []byte(http.StatusText(http.StatusOK)),
}
benchmarkServeStreamHTTPSimple(b, test)
}
func BenchmarkServeStreamHTTPLargeFile(b *testing.B) {
test := testRequest{
name: "large_file",
endpoint: "/large_file",
expectedStatus: http.StatusOK,
expectedBody: testLargeResp,
}
benchmarkServeStreamHTTPSimple(b, test)
}

View File

@ -11,6 +11,7 @@ import (
"sync" "sync"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"golang.org/x/net/http2" "golang.org/x/net/http2"
@ -26,7 +27,7 @@ var (
errNotFlusher = errors.New("ResponseWriter doesn't implement http.Flusher") errNotFlusher = errors.New("ResponseWriter doesn't implement http.Flusher")
) )
type HTTP2Connection struct { type http2Connection struct {
conn net.Conn conn net.Conn
server *http2.Server server *http2.Server
config *Config config *Config
@ -36,6 +37,8 @@ type HTTP2Connection struct {
connIndexStr string connIndexStr string
connIndex uint8 connIndex uint8
wg *sync.WaitGroup wg *sync.WaitGroup
// newRPCClientFunc allows us to mock RPCs during testing
newRPCClientFunc func(context.Context, io.ReadWriteCloser, logger.Service) NamedTunnelRPCClient
connectedFuse ConnectedFuse connectedFuse ConnectedFuse
} }
@ -47,8 +50,8 @@ func NewHTTP2Connection(
observer *Observer, observer *Observer,
connIndex uint8, connIndex uint8,
connectedFuse ConnectedFuse, connectedFuse ConnectedFuse,
) *HTTP2Connection { ) *http2Connection {
return &HTTP2Connection{ return &http2Connection{
conn: conn, conn: conn,
server: &http2.Server{ server: &http2.Server{
MaxConcurrentStreams: math.MaxUint32, MaxConcurrentStreams: math.MaxUint32,
@ -60,11 +63,12 @@ func NewHTTP2Connection(
connIndexStr: uint8ToString(connIndex), connIndexStr: uint8ToString(connIndex),
connIndex: connIndex, connIndex: connIndex,
wg: &sync.WaitGroup{}, wg: &sync.WaitGroup{},
newRPCClientFunc: newRegistrationRPCClient,
connectedFuse: connectedFuse, connectedFuse: connectedFuse,
} }
} }
func (c *HTTP2Connection) Serve(ctx context.Context) { func (c *http2Connection) Serve(ctx context.Context) {
go func() { go func() {
<-ctx.Done() <-ctx.Done()
c.close() c.close()
@ -75,7 +79,7 @@ func (c *HTTP2Connection) Serve(ctx context.Context) {
}) })
} }
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.wg.Add(1) c.wg.Add(1)
defer c.wg.Done() defer c.wg.Done()
@ -86,65 +90,42 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
flusher, isFlusher := w.(http.Flusher) flusher, isFlusher := w.(http.Flusher)
if !isFlusher { if !isFlusher {
c.observer.Errorf("%T doesn't implement http.Flusher", w) c.observer.Errorf("%T doesn't implement http.Flusher", w)
respWriter.WriteErrorResponse(errNotFlusher) respWriter.WriteErrorResponse()
return return
} }
respWriter.flusher = flusher respWriter.flusher = flusher
var err error
if isControlStreamUpgrade(r) { if isControlStreamUpgrade(r) {
respWriter.shouldFlush = true respWriter.shouldFlush = true
err := c.serveControlStream(r.Context(), respWriter) err = c.serveControlStream(r.Context(), respWriter)
if err != nil {
respWriter.WriteErrorResponse(err)
}
} else if isWebsocketUpgrade(r) { } else if isWebsocketUpgrade(r) {
respWriter.shouldFlush = true respWriter.shouldFlush = true
stripWebsocketUpgradeHeader(r) stripWebsocketUpgradeHeader(r)
c.config.OriginClient.Proxy(respWriter, r, true) err = c.config.OriginClient.Proxy(respWriter, r, true)
} else { } else {
c.config.OriginClient.Proxy(respWriter, r, false) err = c.config.OriginClient.Proxy(respWriter, r, false)
}
if err != nil {
respWriter.WriteErrorResponse()
} }
} }
func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error { func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
rpcClient := newRegistrationRPCClient(ctx, respWriter, c.observer) rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer)
defer rpcClient.close() defer rpcClient.Close()
if err := registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil { if err := rpcClient.RegisterConnection(ctx, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
return err return err
} }
c.connectedFuse.Connected() c.connectedFuse.Connected()
<-ctx.Done() <-ctx.Done()
c.gracefulShutdown(ctx, rpcClient) rpcClient.GracefulShutdown(ctx, c.config.GracePeriod)
return nil return nil
} }
func (c *HTTP2Connection) registerConnection( func (c *http2Connection) close() {
ctx context.Context,
rpcClient tunnelpogs.RegistrationServer_PogsClient,
) error {
connDetail, err := rpcClient.RegisterConnection(
ctx,
c.namedTunnel.Auth,
c.namedTunnel.ID,
c.connIndex,
c.connOptions,
)
if err != nil {
c.observer.Errorf("Cannot register connection, err: %v", err)
return err
}
c.observer.Infof("Connection %s registered with %s using ID %s", c.connIndexStr, connDetail.Location, connDetail.UUID)
return nil
}
func (c *HTTP2Connection) gracefulShutdown(ctx context.Context, rpcClient *registrationServerClient) {
ctx, cancel := context.WithTimeout(ctx, c.config.GracePeriod)
defer cancel()
rpcClient.client.UnregisterConnection(ctx)
}
func (c *HTTP2Connection) close() {
// Wait for all serve HTTP handlers to return // Wait for all serve HTTP handlers to return
c.wg.Wait() c.wg.Wait()
c.conn.Close() c.conn.Close()
@ -195,7 +176,7 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
return nil return nil
} }
func (rp *http2RespWriter) WriteErrorResponse(err error) { func (rp *http2RespWriter) WriteErrorResponse() {
rp.setResponseMetaHeader(responseMetaHeaderCfd) rp.setResponseMetaHeader(responseMetaHeaderCfd)
rp.w.WriteHeader(http.StatusBadGateway) rp.w.WriteHeader(http.StatusBadGateway)
} }

303
connection/http2_test.go Normal file
View File

@ -0,0 +1,303 @@
package connection
import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/gobwas/ws/wsutil"
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"
)
var (
testTransport = http2.Transport{}
)
func newTestHTTP2Connection() (*http2Connection, net.Conn) {
edgeConn, originConn := net.Pipe()
var connIndex = uint8(0)
return NewHTTP2Connection(
originConn,
testConfig,
&NamedTunnelConfig{},
&pogs.ConnectionOptions{},
testObserver,
connIndex,
mockConnectedFuse{},
), edgeConn
}
func TestServeHTTP(t *testing.T) {
tests := []testRequest{
{
name: "ok",
endpoint: "ok",
expectedStatus: http.StatusOK,
expectedBody: []byte(http.StatusText(http.StatusOK)),
},
{
name: "large_file",
endpoint: "large_file",
expectedStatus: http.StatusOK,
expectedBody: testLargeResp,
},
{
name: "Bad request",
endpoint: "400",
expectedStatus: http.StatusBadRequest,
expectedBody: []byte(http.StatusText(http.StatusBadRequest)),
},
{
name: "Internal server error",
endpoint: "500",
expectedStatus: http.StatusInternalServerError,
expectedBody: []byte(http.StatusText(http.StatusInternalServerError)),
},
{
name: "Proxy error",
endpoint: "error",
expectedStatus: http.StatusBadGateway,
expectedBody: nil,
isProxyError: true,
},
}
http2Conn, edgeConn := newTestHTTP2Connection()
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
}()
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
require.NoError(t, err)
for _, test := range tests {
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
require.NoError(t, err)
resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
require.Equal(t, test.expectedStatus, resp.StatusCode)
if test.expectedBody != nil {
respBody, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, test.expectedBody, respBody)
}
if test.isProxyError {
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(responseMetaHeaderField))
} else {
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(responseMetaHeaderField))
}
}
cancel()
wg.Wait()
}
type mockNamedTunnelRPCClient struct {
registered chan struct{}
unregistered chan struct{}
}
func (mc mockNamedTunnelRPCClient) RegisterConnection(
c context.Context,
config *NamedTunnelConfig,
options *tunnelpogs.ConnectionOptions,
connIndex uint8,
observer *Observer,
) error {
close(mc.registered)
return nil
}
func (mc mockNamedTunnelRPCClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
close(mc.unregistered)
}
func (mockNamedTunnelRPCClient) Close() {}
type mockRPCClientFactory struct {
registered chan struct{}
unregistered chan struct{}
}
func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, logger.Service) NamedTunnelRPCClient {
return mockNamedTunnelRPCClient{
registered: mf.registered,
unregistered: mf.unregistered,
}
}
type wsRespWriter struct {
*httptest.ResponseRecorder
readPipe *io.PipeReader
writePipe *io.PipeWriter
}
func newWSRespWriter() *wsRespWriter {
readPipe, writePipe := io.Pipe()
return &wsRespWriter{
httptest.NewRecorder(),
readPipe,
writePipe,
}
}
func (w *wsRespWriter) RespBody() io.ReadWriter {
return nowriter{w.readPipe}
}
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
return w.writePipe.Write(data)
}
func TestServeWS(t *testing.T) {
http2Conn, _ := newTestHTTP2Connection()
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
}()
respWriter := newWSRespWriter()
readPipe, writePipe := io.Pipe()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
require.NoError(t, err)
req.Header.Set(internalUpgradeHeader, websocketUpgrade)
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.ServeHTTP(respWriter, req)
}()
data := []byte("test websocket")
err = wsutil.WriteClientText(writePipe, data)
require.NoError(t, err)
respBody, err := wsutil.ReadServerText(respWriter.RespBody())
require.NoError(t, err)
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
cancel()
resp := respWriter.Result()
// http2RespWriter should rewrite status 101 to 200
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(responseMetaHeaderField))
wg.Wait()
}
func TestServeControlStream(t *testing.T) {
http2Conn, edgeConn := newTestHTTP2Connection()
rpcClientFactory := mockRPCClientFactory{
registered: make(chan struct{}),
unregistered: make(chan struct{}),
}
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
}()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
require.NoError(t, err)
req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
require.NoError(t, err)
wg.Add(1)
go func() {
defer wg.Done()
edgeHTTP2Conn.RoundTrip(req)
}()
<-rpcClientFactory.registered
cancel()
<-rpcClientFactory.unregistered
wg.Wait()
}
func benchmarkServeHTTP(b *testing.B, test testRequest) {
http2Conn, edgeConn := newTestHTTP2Connection()
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
}()
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
require.NoError(b, err)
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
require.NoError(b, err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StartTimer()
resp, err := edgeHTTP2Conn.RoundTrip(req)
b.StopTimer()
require.NoError(b, err)
require.Equal(b, test.expectedStatus, resp.StatusCode)
if test.expectedBody != nil {
respBody, err := ioutil.ReadAll(resp.Body)
require.NoError(b, err)
require.Equal(b, test.expectedBody, respBody)
}
resp.Body.Close()
}
cancel()
wg.Wait()
}
func BenchmarkServeHTTPSimple(b *testing.B) {
test := testRequest{
name: "ok",
endpoint: "ok",
expectedStatus: http.StatusOK,
expectedBody: []byte(http.StatusText(http.StatusOK)),
}
benchmarkServeHTTP(b, test)
}
func BenchmarkServeHTTPLargeFile(b *testing.B) {
test := testRequest{
name: "large_file",
endpoint: "large_file",
expectedStatus: http.StatusOK,
expectedBody: testLargeResp,
}
benchmarkServeHTTP(b, test)
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"time"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
@ -49,6 +50,18 @@ func (tsc *tunnelServerClient) Close() {
tsc.transport.Close() tsc.transport.Close()
} }
type NamedTunnelRPCClient interface {
RegisterConnection(
c context.Context,
config *NamedTunnelConfig,
options *tunnelpogs.ConnectionOptions,
connIndex uint8,
observer *Observer,
) error
GracefulShutdown(ctx context.Context, gracePeriod time.Duration)
Close()
}
type registrationServerClient struct { type registrationServerClient struct {
client tunnelpogs.RegistrationServer_PogsClient client tunnelpogs.RegistrationServer_PogsClient
transport rpc.Transport transport rpc.Transport
@ -58,7 +71,7 @@ func newRegistrationRPCClient(
ctx context.Context, ctx context.Context,
stream io.ReadWriteCloser, stream io.ReadWriteCloser,
logger logger.Service, logger logger.Service,
) *registrationServerClient { ) NamedTunnelRPCClient {
transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream)) transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream))
conn := rpc.NewConn( conn := rpc.NewConn(
transport, transport,
@ -70,31 +83,14 @@ func newRegistrationRPCClient(
} }
} }
func (rsc *registrationServerClient) close() { func (rsc *registrationServerClient) RegisterConnection(
// Closing the client will also close the connection
rsc.client.Close()
// Closing the transport also closes the stream
rsc.transport.Close()
}
type rpcName string
const (
register rpcName = "register"
reconnect rpcName = "reconnect"
unregister rpcName = "unregister"
authenticate rpcName = " authenticate"
)
func registerConnection(
ctx context.Context, ctx context.Context,
rpcClient *registrationServerClient,
config *NamedTunnelConfig, config *NamedTunnelConfig,
options *tunnelpogs.ConnectionOptions, options *tunnelpogs.ConnectionOptions,
connIndex uint8, connIndex uint8,
observer *Observer, observer *Observer,
) error { ) error {
conn, err := rpcClient.client.RegisterConnection( conn, err := rsc.client.RegisterConnection(
ctx, ctx,
config.Auth, config.Auth,
config.ID, config.ID,
@ -118,6 +114,28 @@ func registerConnection(
return nil return nil
} }
func (rsc *registrationServerClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
ctx, cancel := context.WithTimeout(ctx, gracePeriod)
defer cancel()
rsc.client.UnregisterConnection(ctx)
}
func (rsc *registrationServerClient) Close() {
// Closing the client will also close the connection
rsc.client.Close()
// Closing the transport also closes the stream
rsc.transport.Close()
}
type rpcName string
const (
register rpcName = "register"
reconnect rpcName = "reconnect"
unregister rpcName = "unregister"
authenticate rpcName = " authenticate"
)
func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error { func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
h.observer.sendRegisteringEvent() h.observer.sendRegisteringEvent()
@ -264,9 +282,9 @@ func (h *h2muxConnection) unregister(isNamedTunnel bool) {
if isNamedTunnel { if isNamedTunnel {
rpcClient := newRegistrationRPCClient(unregisterCtx, stream, h.observer) rpcClient := newRegistrationRPCClient(unregisterCtx, stream, h.observer)
defer rpcClient.close() defer rpcClient.Close()
rpcClient.client.UnregisterConnection(unregisterCtx) rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod)
} else { } else {
rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer) rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer)
defer rpcClient.Close() defer rpcClient.Close()

View File

@ -60,7 +60,7 @@ func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsock
} }
if err != nil { if err != nil {
c.logRequestError(err, cfRay, ruleNum) c.logRequestError(err, cfRay, ruleNum)
w.WriteErrorResponse(err) w.WriteErrorResponse()
return err return err
} }
c.logOriginResponse(resp, cfRay, lbProbe, ruleNum) c.logOriginResponse(resp, cfRay, lbProbe, ruleNum)

View File

@ -47,7 +47,7 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error {
return nil return nil
} }
func (w *mockHTTPRespWriter) WriteErrorResponse(err error) { func (w *mockHTTPRespWriter) WriteErrorResponse() {
w.WriteHeader(http.StatusBadGateway) w.WriteHeader(http.StatusBadGateway)
} }