TUN-3489: Add unit tests to cover proxy logic in connection package of cloudflared
This commit is contained in:
parent
5974fb4cfd
commit
d5769519b2
|
@ -44,7 +44,7 @@ type OriginClient interface {
|
||||||
|
|
||||||
type ResponseWriter interface {
|
type ResponseWriter interface {
|
||||||
WriteRespHeaders(*http.Response) error
|
WriteRespHeaders(*http.Response) error
|
||||||
WriteErrorResponse(error)
|
WriteErrorResponse()
|
||||||
io.ReadWriter
|
io.ReadWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,113 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
"github.com/gobwas/ws/wsutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
largeFileSize = 2 * 1024 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testConfig = &Config{
|
||||||
|
OriginClient: &mockOriginClient{},
|
||||||
|
GracePeriod: time.Millisecond * 100,
|
||||||
|
}
|
||||||
|
testLogger, _ = logger.New()
|
||||||
|
testOriginURL = &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "connectiontest.argotunnel.com",
|
||||||
|
}
|
||||||
|
testTunnelEventChan = make(chan ui.TunnelEvent)
|
||||||
|
testObserver = &Observer{
|
||||||
|
testLogger,
|
||||||
|
m,
|
||||||
|
testTunnelEventChan,
|
||||||
|
}
|
||||||
|
testLargeResp = make([]byte, largeFileSize)
|
||||||
|
)
|
||||||
|
|
||||||
|
type testRequest struct {
|
||||||
|
name string
|
||||||
|
endpoint string
|
||||||
|
expectedStatus int
|
||||||
|
expectedBody []byte
|
||||||
|
isProxyError bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockOriginClient struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (moc *mockOriginClient) Proxy(w ResponseWriter, r *http.Request, isWebsocket bool) error {
|
||||||
|
if isWebsocket {
|
||||||
|
return wsEndpoint(w, r)
|
||||||
|
}
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/ok":
|
||||||
|
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
|
||||||
|
case "/large_file":
|
||||||
|
originRespEndpoint(w, http.StatusOK, testLargeResp)
|
||||||
|
case "/400":
|
||||||
|
originRespEndpoint(w, http.StatusBadRequest, []byte(http.StatusText(http.StatusBadRequest)))
|
||||||
|
case "/500":
|
||||||
|
originRespEndpoint(w, http.StatusInternalServerError, []byte(http.StatusText(http.StatusInternalServerError)))
|
||||||
|
case "/error":
|
||||||
|
return fmt.Errorf("Failed to proxy to origin")
|
||||||
|
default:
|
||||||
|
originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type nowriter struct {
|
||||||
|
io.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nowriter) Write(p []byte) (int, error) {
|
||||||
|
return 0, fmt.Errorf("Writer not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func wsEndpoint(w ResponseWriter, r *http.Request) error {
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusSwitchingProtocols,
|
||||||
|
}
|
||||||
|
w.WriteRespHeaders(resp)
|
||||||
|
clientReader := nowriter{r.Body}
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
data, err := wsutil.ReadClientText(clientReader)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := wsutil.WriteServerText(w, data); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
<-r.Context().Done()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func originRespEndpoint(w ResponseWriter, status int, data []byte) {
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: status,
|
||||||
|
}
|
||||||
|
w.WriteRespHeaders(resp)
|
||||||
|
w.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockConnectedFuse struct{}
|
||||||
|
|
||||||
|
func (mcf mockConnectedFuse) Connected() {}
|
||||||
|
|
||||||
|
func (mcf mockConnectedFuse) IsConnected() bool {
|
||||||
|
return true
|
||||||
|
}
|
|
@ -88,9 +88,9 @@ func (h *h2muxConnection) ServeNamedTunnel(ctx context.Context, namedTunnel *Nam
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
rpcClient := newRegistrationRPCClient(ctx, stream, h.observer)
|
rpcClient := newRegistrationRPCClient(ctx, stream, h.observer)
|
||||||
defer rpcClient.close()
|
defer rpcClient.Close()
|
||||||
|
|
||||||
if err = registerConnection(serveCtx, rpcClient, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
|
if err = rpcClient.RegisterConnection(serveCtx, namedTunnel, connOptions, h.connIndex, h.observer); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
connectedFuse.Connected()
|
connectedFuse.Connected()
|
||||||
|
@ -177,11 +177,16 @@ func (h *h2muxConnection) ServeStream(stream *h2mux.MuxedStream) error {
|
||||||
|
|
||||||
req, reqErr := h.newRequest(stream)
|
req, reqErr := h.newRequest(stream)
|
||||||
if reqErr != nil {
|
if reqErr != nil {
|
||||||
respWriter.WriteErrorResponse(reqErr)
|
respWriter.WriteErrorResponse()
|
||||||
return reqErr
|
return reqErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
|
err := h.config.OriginClient.Proxy(respWriter, req, websocket.IsWebSocketUpgrade(req))
|
||||||
|
if err != nil {
|
||||||
|
respWriter.WriteErrorResponse()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
|
func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
|
||||||
|
@ -206,7 +211,7 @@ func (rp *h2muxRespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
return rp.WriteHeaders(headers)
|
return rp.WriteHeaders(headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *h2muxRespWriter) WriteErrorResponse(err error) {
|
func (rp *h2muxRespWriter) WriteErrorResponse() {
|
||||||
rp.WriteHeaders([]h2mux.Header{
|
rp.WriteHeaders([]h2mux.Header{
|
||||||
{Name: ":status", Value: "502"},
|
{Name: ":status", Value: "502"},
|
||||||
{Name: responseMetaHeaderField, Value: responseMetaHeaderCfd},
|
{Name: responseMetaHeaderField, Value: responseMetaHeaderCfd},
|
||||||
|
|
|
@ -0,0 +1,242 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/gobwas/ws/wsutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testMuxerConfig = &MuxerConfig{
|
||||||
|
HeartbeatInterval: time.Second * 5,
|
||||||
|
MaxHeartbeats: 5,
|
||||||
|
CompressionSetting: 0,
|
||||||
|
MetricsUpdateFreq: time.Second * 5,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func newH2MuxConnection(ctx context.Context, t require.TestingT) (*h2muxConnection, *h2mux.Muxer) {
|
||||||
|
edgeConn, originConn := net.Pipe()
|
||||||
|
edgeMuxChan := make(chan *h2mux.Muxer)
|
||||||
|
go func() {
|
||||||
|
edgeMuxConfig := h2mux.MuxerConfig{
|
||||||
|
Logger: testObserver,
|
||||||
|
}
|
||||||
|
edgeMux, err := h2mux.Handshake(edgeConn, edgeConn, edgeMuxConfig, h2mux.ActiveStreams)
|
||||||
|
require.NoError(t, err)
|
||||||
|
edgeMuxChan <- edgeMux
|
||||||
|
}()
|
||||||
|
var connIndex = uint8(0)
|
||||||
|
h2muxConn, err, _ := NewH2muxConnection(ctx, testConfig, testMuxerConfig, originConn, connIndex, testObserver)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return h2muxConn, <-edgeMuxChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeStreamHTTP(t *testing.T) {
|
||||||
|
tests := []testRequest{
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
endpoint: "/ok",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: []byte(http.StatusText(http.StatusOK)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large_file",
|
||||||
|
endpoint: "/large_file",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: testLargeResp,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bad request",
|
||||||
|
endpoint: "/400",
|
||||||
|
expectedStatus: http.StatusBadRequest,
|
||||||
|
expectedBody: []byte(http.StatusText(http.StatusBadRequest)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Internal server error",
|
||||||
|
endpoint: "/500",
|
||||||
|
expectedStatus: http.StatusInternalServerError,
|
||||||
|
expectedBody: []byte(http.StatusText(http.StatusInternalServerError)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Proxy error",
|
||||||
|
endpoint: "/error",
|
||||||
|
expectedStatus: http.StatusBadGateway,
|
||||||
|
expectedBody: nil,
|
||||||
|
isProxyError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
h2muxConn, edgeMux := newH2MuxConnection(ctx, t)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
edgeMux.Serve(ctx)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err := h2muxConn.serveMuxer(ctx)
|
||||||
|
require.Error(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
headers := []h2mux.Header{
|
||||||
|
{
|
||||||
|
Name: ":path",
|
||||||
|
Value: test.endpoint,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
stream, err := edgeMux.OpenStream(ctx, headers, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, hasHeader(stream, ":status", strconv.Itoa(test.expectedStatus)))
|
||||||
|
|
||||||
|
if test.isProxyError {
|
||||||
|
assert.True(t, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderCfd))
|
||||||
|
} else {
|
||||||
|
assert.True(t, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderOrigin))
|
||||||
|
body := make([]byte, len(test.expectedBody))
|
||||||
|
_, err = stream.Read(body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, test.expectedBody, body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeStreamWS(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
h2muxConn, edgeMux := newH2MuxConnection(ctx, t)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
edgeMux.Serve(ctx)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err := h2muxConn.serveMuxer(ctx)
|
||||||
|
require.Error(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
headers := []h2mux.Header{
|
||||||
|
{
|
||||||
|
Name: ":path",
|
||||||
|
Value: "/ws",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "connection",
|
||||||
|
Value: "upgrade",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "upgrade",
|
||||||
|
Value: "websocket",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
readPipe, writePipe := io.Pipe()
|
||||||
|
stream, err := edgeMux.OpenStream(ctx, headers, readPipe)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.True(t, hasHeader(stream, ":status", strconv.Itoa(http.StatusSwitchingProtocols)))
|
||||||
|
assert.True(t, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderOrigin))
|
||||||
|
|
||||||
|
data := []byte("test websocket")
|
||||||
|
err = wsutil.WriteClientText(writePipe, data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
respBody, err := wsutil.ReadServerText(stream)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasHeader(stream *h2mux.MuxedStream, name, val string) bool {
|
||||||
|
for _, header := range stream.Headers {
|
||||||
|
if header.Name == name && header.Value == val {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkServeStreamHTTPSimple(b *testing.B, test testRequest) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
h2muxConn, edgeMux := newH2MuxConnection(ctx, b)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
edgeMux.Serve(ctx)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err := h2muxConn.serveMuxer(ctx)
|
||||||
|
require.Error(b, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
headers := []h2mux.Header{
|
||||||
|
{
|
||||||
|
Name: ":path",
|
||||||
|
Value: test.endpoint,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
body := make([]byte, len(test.expectedBody))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
b.StartTimer()
|
||||||
|
stream, openstreamErr := edgeMux.OpenStream(ctx, headers, nil)
|
||||||
|
_, readBodyErr := stream.Read(body)
|
||||||
|
b.StopTimer()
|
||||||
|
|
||||||
|
require.NoError(b, openstreamErr)
|
||||||
|
assert.True(b, hasHeader(stream, responseMetaHeaderField, responseMetaHeaderOrigin))
|
||||||
|
require.True(b, hasHeader(stream, ":status", strconv.Itoa(http.StatusOK)))
|
||||||
|
require.NoError(b, readBodyErr)
|
||||||
|
require.Equal(b, test.expectedBody, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkServeStreamHTTPSimple(b *testing.B) {
|
||||||
|
test := testRequest{
|
||||||
|
name: "ok",
|
||||||
|
endpoint: "/ok",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: []byte(http.StatusText(http.StatusOK)),
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkServeStreamHTTPSimple(b, test)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkServeStreamHTTPLargeFile(b *testing.B) {
|
||||||
|
test := testRequest{
|
||||||
|
name: "large_file",
|
||||||
|
endpoint: "/large_file",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: testLargeResp,
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkServeStreamHTTPSimple(b, test)
|
||||||
|
}
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
|
@ -26,7 +27,7 @@ var (
|
||||||
errNotFlusher = errors.New("ResponseWriter doesn't implement http.Flusher")
|
errNotFlusher = errors.New("ResponseWriter doesn't implement http.Flusher")
|
||||||
)
|
)
|
||||||
|
|
||||||
type HTTP2Connection struct {
|
type http2Connection struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
server *http2.Server
|
server *http2.Server
|
||||||
config *Config
|
config *Config
|
||||||
|
@ -36,6 +37,8 @@ type HTTP2Connection struct {
|
||||||
connIndexStr string
|
connIndexStr string
|
||||||
connIndex uint8
|
connIndex uint8
|
||||||
wg *sync.WaitGroup
|
wg *sync.WaitGroup
|
||||||
|
// newRPCClientFunc allows us to mock RPCs during testing
|
||||||
|
newRPCClientFunc func(context.Context, io.ReadWriteCloser, logger.Service) NamedTunnelRPCClient
|
||||||
connectedFuse ConnectedFuse
|
connectedFuse ConnectedFuse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,8 +50,8 @@ func NewHTTP2Connection(
|
||||||
observer *Observer,
|
observer *Observer,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
connectedFuse ConnectedFuse,
|
connectedFuse ConnectedFuse,
|
||||||
) *HTTP2Connection {
|
) *http2Connection {
|
||||||
return &HTTP2Connection{
|
return &http2Connection{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
server: &http2.Server{
|
server: &http2.Server{
|
||||||
MaxConcurrentStreams: math.MaxUint32,
|
MaxConcurrentStreams: math.MaxUint32,
|
||||||
|
@ -60,11 +63,12 @@ func NewHTTP2Connection(
|
||||||
connIndexStr: uint8ToString(connIndex),
|
connIndexStr: uint8ToString(connIndex),
|
||||||
connIndex: connIndex,
|
connIndex: connIndex,
|
||||||
wg: &sync.WaitGroup{},
|
wg: &sync.WaitGroup{},
|
||||||
|
newRPCClientFunc: newRegistrationRPCClient,
|
||||||
connectedFuse: connectedFuse,
|
connectedFuse: connectedFuse,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HTTP2Connection) Serve(ctx context.Context) {
|
func (c *http2Connection) Serve(ctx context.Context) {
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
c.close()
|
c.close()
|
||||||
|
@ -75,7 +79,7 @@ func (c *HTTP2Connection) Serve(ctx context.Context) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
c.wg.Add(1)
|
c.wg.Add(1)
|
||||||
defer c.wg.Done()
|
defer c.wg.Done()
|
||||||
|
|
||||||
|
@ -86,65 +90,42 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
flusher, isFlusher := w.(http.Flusher)
|
flusher, isFlusher := w.(http.Flusher)
|
||||||
if !isFlusher {
|
if !isFlusher {
|
||||||
c.observer.Errorf("%T doesn't implement http.Flusher", w)
|
c.observer.Errorf("%T doesn't implement http.Flusher", w)
|
||||||
respWriter.WriteErrorResponse(errNotFlusher)
|
respWriter.WriteErrorResponse()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
respWriter.flusher = flusher
|
respWriter.flusher = flusher
|
||||||
|
var err error
|
||||||
if isControlStreamUpgrade(r) {
|
if isControlStreamUpgrade(r) {
|
||||||
respWriter.shouldFlush = true
|
respWriter.shouldFlush = true
|
||||||
err := c.serveControlStream(r.Context(), respWriter)
|
err = c.serveControlStream(r.Context(), respWriter)
|
||||||
if err != nil {
|
|
||||||
respWriter.WriteErrorResponse(err)
|
|
||||||
}
|
|
||||||
} else if isWebsocketUpgrade(r) {
|
} else if isWebsocketUpgrade(r) {
|
||||||
respWriter.shouldFlush = true
|
respWriter.shouldFlush = true
|
||||||
stripWebsocketUpgradeHeader(r)
|
stripWebsocketUpgradeHeader(r)
|
||||||
c.config.OriginClient.Proxy(respWriter, r, true)
|
err = c.config.OriginClient.Proxy(respWriter, r, true)
|
||||||
} else {
|
} else {
|
||||||
c.config.OriginClient.Proxy(respWriter, r, false)
|
err = c.config.OriginClient.Proxy(respWriter, r, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
respWriter.WriteErrorResponse()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
|
func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
|
||||||
rpcClient := newRegistrationRPCClient(ctx, respWriter, c.observer)
|
rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer)
|
||||||
defer rpcClient.close()
|
defer rpcClient.Close()
|
||||||
|
|
||||||
if err := registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
|
if err := rpcClient.RegisterConnection(ctx, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.connectedFuse.Connected()
|
c.connectedFuse.Connected()
|
||||||
|
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
c.gracefulShutdown(ctx, rpcClient)
|
rpcClient.GracefulShutdown(ctx, c.config.GracePeriod)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *HTTP2Connection) registerConnection(
|
func (c *http2Connection) close() {
|
||||||
ctx context.Context,
|
|
||||||
rpcClient tunnelpogs.RegistrationServer_PogsClient,
|
|
||||||
) error {
|
|
||||||
connDetail, err := rpcClient.RegisterConnection(
|
|
||||||
ctx,
|
|
||||||
c.namedTunnel.Auth,
|
|
||||||
c.namedTunnel.ID,
|
|
||||||
c.connIndex,
|
|
||||||
c.connOptions,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
c.observer.Errorf("Cannot register connection, err: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.observer.Infof("Connection %s registered with %s using ID %s", c.connIndexStr, connDetail.Location, connDetail.UUID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HTTP2Connection) gracefulShutdown(ctx context.Context, rpcClient *registrationServerClient) {
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, c.config.GracePeriod)
|
|
||||||
defer cancel()
|
|
||||||
rpcClient.client.UnregisterConnection(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HTTP2Connection) close() {
|
|
||||||
// Wait for all serve HTTP handlers to return
|
// Wait for all serve HTTP handlers to return
|
||||||
c.wg.Wait()
|
c.wg.Wait()
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
|
@ -195,7 +176,7 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *http2RespWriter) WriteErrorResponse(err error) {
|
func (rp *http2RespWriter) WriteErrorResponse() {
|
||||||
rp.setResponseMetaHeader(responseMetaHeaderCfd)
|
rp.setResponseMetaHeader(responseMetaHeaderCfd)
|
||||||
rp.w.WriteHeader(http.StatusBadGateway)
|
rp.w.WriteHeader(http.StatusBadGateway)
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,303 @@
|
||||||
|
package connection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||||
|
"github.com/gobwas/ws/wsutil"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testTransport = http2.Transport{}
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestHTTP2Connection() (*http2Connection, net.Conn) {
|
||||||
|
edgeConn, originConn := net.Pipe()
|
||||||
|
var connIndex = uint8(0)
|
||||||
|
return NewHTTP2Connection(
|
||||||
|
originConn,
|
||||||
|
testConfig,
|
||||||
|
&NamedTunnelConfig{},
|
||||||
|
&pogs.ConnectionOptions{},
|
||||||
|
testObserver,
|
||||||
|
connIndex,
|
||||||
|
mockConnectedFuse{},
|
||||||
|
), edgeConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeHTTP(t *testing.T) {
|
||||||
|
tests := []testRequest{
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
endpoint: "ok",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: []byte(http.StatusText(http.StatusOK)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large_file",
|
||||||
|
endpoint: "large_file",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: testLargeResp,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bad request",
|
||||||
|
endpoint: "400",
|
||||||
|
expectedStatus: http.StatusBadRequest,
|
||||||
|
expectedBody: []byte(http.StatusText(http.StatusBadRequest)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Internal server error",
|
||||||
|
endpoint: "500",
|
||||||
|
expectedStatus: http.StatusInternalServerError,
|
||||||
|
expectedBody: []byte(http.StatusText(http.StatusInternalServerError)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Proxy error",
|
||||||
|
endpoint: "error",
|
||||||
|
expectedStatus: http.StatusBadGateway,
|
||||||
|
expectedBody: nil,
|
||||||
|
isProxyError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
http2Conn.Serve(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, test.expectedStatus, resp.StatusCode)
|
||||||
|
if test.expectedBody != nil {
|
||||||
|
respBody, err := ioutil.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, test.expectedBody, respBody)
|
||||||
|
}
|
||||||
|
if test.isProxyError {
|
||||||
|
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(responseMetaHeaderField))
|
||||||
|
} else {
|
||||||
|
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(responseMetaHeaderField))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockNamedTunnelRPCClient struct {
|
||||||
|
registered chan struct{}
|
||||||
|
unregistered chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc mockNamedTunnelRPCClient) RegisterConnection(
|
||||||
|
c context.Context,
|
||||||
|
config *NamedTunnelConfig,
|
||||||
|
options *tunnelpogs.ConnectionOptions,
|
||||||
|
connIndex uint8,
|
||||||
|
observer *Observer,
|
||||||
|
) error {
|
||||||
|
close(mc.registered)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc mockNamedTunnelRPCClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
|
||||||
|
close(mc.unregistered)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mockNamedTunnelRPCClient) Close() {}
|
||||||
|
|
||||||
|
type mockRPCClientFactory struct {
|
||||||
|
registered chan struct{}
|
||||||
|
unregistered chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, logger.Service) NamedTunnelRPCClient {
|
||||||
|
return mockNamedTunnelRPCClient{
|
||||||
|
registered: mf.registered,
|
||||||
|
unregistered: mf.unregistered,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type wsRespWriter struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
readPipe *io.PipeReader
|
||||||
|
writePipe *io.PipeWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWSRespWriter() *wsRespWriter {
|
||||||
|
readPipe, writePipe := io.Pipe()
|
||||||
|
return &wsRespWriter{
|
||||||
|
httptest.NewRecorder(),
|
||||||
|
readPipe,
|
||||||
|
writePipe,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wsRespWriter) RespBody() io.ReadWriter {
|
||||||
|
return nowriter{w.readPipe}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
|
||||||
|
return w.writePipe.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeWS(t *testing.T) {
|
||||||
|
http2Conn, _ := newTestHTTP2Connection()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
http2Conn.Serve(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
respWriter := newWSRespWriter()
|
||||||
|
readPipe, writePipe := io.Pipe()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Header.Set(internalUpgradeHeader, websocketUpgrade)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
http2Conn.ServeHTTP(respWriter, req)
|
||||||
|
}()
|
||||||
|
|
||||||
|
data := []byte("test websocket")
|
||||||
|
err = wsutil.WriteClientText(writePipe, data)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
respBody, err := wsutil.ReadServerText(respWriter.RespBody())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
resp := respWriter.Result()
|
||||||
|
// http2RespWriter should rewrite status 101 to 200
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(responseMetaHeaderField))
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServeControlStream(t *testing.T) {
|
||||||
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
||||||
|
|
||||||
|
rpcClientFactory := mockRPCClientFactory{
|
||||||
|
registered: make(chan struct{}),
|
||||||
|
unregistered: make(chan struct{}),
|
||||||
|
}
|
||||||
|
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
http2Conn.Serve(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
|
||||||
|
|
||||||
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
edgeHTTP2Conn.RoundTrip(req)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-rpcClientFactory.registered
|
||||||
|
cancel()
|
||||||
|
<-rpcClientFactory.unregistered
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkServeHTTP(b *testing.B, test testRequest) {
|
||||||
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
http2Conn.Serve(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||||
|
require.NoError(b, err)
|
||||||
|
|
||||||
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||||
|
require.NoError(b, err)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
b.StartTimer()
|
||||||
|
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||||
|
b.StopTimer()
|
||||||
|
require.NoError(b, err)
|
||||||
|
require.Equal(b, test.expectedStatus, resp.StatusCode)
|
||||||
|
if test.expectedBody != nil {
|
||||||
|
respBody, err := ioutil.ReadAll(resp.Body)
|
||||||
|
require.NoError(b, err)
|
||||||
|
require.Equal(b, test.expectedBody, respBody)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
func BenchmarkServeHTTPSimple(b *testing.B) {
|
||||||
|
test := testRequest{
|
||||||
|
name: "ok",
|
||||||
|
endpoint: "ok",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: []byte(http.StatusText(http.StatusOK)),
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkServeHTTP(b, test)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkServeHTTPLargeFile(b *testing.B) {
|
||||||
|
test := testRequest{
|
||||||
|
name: "large_file",
|
||||||
|
endpoint: "large_file",
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedBody: testLargeResp,
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmarkServeHTTP(b, test)
|
||||||
|
}
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||||
|
@ -49,6 +50,18 @@ func (tsc *tunnelServerClient) Close() {
|
||||||
tsc.transport.Close()
|
tsc.transport.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NamedTunnelRPCClient interface {
|
||||||
|
RegisterConnection(
|
||||||
|
c context.Context,
|
||||||
|
config *NamedTunnelConfig,
|
||||||
|
options *tunnelpogs.ConnectionOptions,
|
||||||
|
connIndex uint8,
|
||||||
|
observer *Observer,
|
||||||
|
) error
|
||||||
|
GracefulShutdown(ctx context.Context, gracePeriod time.Duration)
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
type registrationServerClient struct {
|
type registrationServerClient struct {
|
||||||
client tunnelpogs.RegistrationServer_PogsClient
|
client tunnelpogs.RegistrationServer_PogsClient
|
||||||
transport rpc.Transport
|
transport rpc.Transport
|
||||||
|
@ -58,7 +71,7 @@ func newRegistrationRPCClient(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
stream io.ReadWriteCloser,
|
stream io.ReadWriteCloser,
|
||||||
logger logger.Service,
|
logger logger.Service,
|
||||||
) *registrationServerClient {
|
) NamedTunnelRPCClient {
|
||||||
transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream))
|
transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream))
|
||||||
conn := rpc.NewConn(
|
conn := rpc.NewConn(
|
||||||
transport,
|
transport,
|
||||||
|
@ -70,31 +83,14 @@ func newRegistrationRPCClient(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rsc *registrationServerClient) close() {
|
func (rsc *registrationServerClient) RegisterConnection(
|
||||||
// Closing the client will also close the connection
|
|
||||||
rsc.client.Close()
|
|
||||||
// Closing the transport also closes the stream
|
|
||||||
rsc.transport.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
type rpcName string
|
|
||||||
|
|
||||||
const (
|
|
||||||
register rpcName = "register"
|
|
||||||
reconnect rpcName = "reconnect"
|
|
||||||
unregister rpcName = "unregister"
|
|
||||||
authenticate rpcName = " authenticate"
|
|
||||||
)
|
|
||||||
|
|
||||||
func registerConnection(
|
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
rpcClient *registrationServerClient,
|
|
||||||
config *NamedTunnelConfig,
|
config *NamedTunnelConfig,
|
||||||
options *tunnelpogs.ConnectionOptions,
|
options *tunnelpogs.ConnectionOptions,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
observer *Observer,
|
observer *Observer,
|
||||||
) error {
|
) error {
|
||||||
conn, err := rpcClient.client.RegisterConnection(
|
conn, err := rsc.client.RegisterConnection(
|
||||||
ctx,
|
ctx,
|
||||||
config.Auth,
|
config.Auth,
|
||||||
config.ID,
|
config.ID,
|
||||||
|
@ -118,6 +114,28 @@ func registerConnection(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rsc *registrationServerClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, gracePeriod)
|
||||||
|
defer cancel()
|
||||||
|
rsc.client.UnregisterConnection(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rsc *registrationServerClient) Close() {
|
||||||
|
// Closing the client will also close the connection
|
||||||
|
rsc.client.Close()
|
||||||
|
// Closing the transport also closes the stream
|
||||||
|
rsc.transport.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpcName string
|
||||||
|
|
||||||
|
const (
|
||||||
|
register rpcName = "register"
|
||||||
|
reconnect rpcName = "reconnect"
|
||||||
|
unregister rpcName = "unregister"
|
||||||
|
authenticate rpcName = " authenticate"
|
||||||
|
)
|
||||||
|
|
||||||
func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
|
func (h *h2muxConnection) registerTunnel(ctx context.Context, credentialSetter CredentialManager, classicTunnel *ClassicTunnelConfig, registrationOptions *tunnelpogs.RegistrationOptions) error {
|
||||||
h.observer.sendRegisteringEvent()
|
h.observer.sendRegisteringEvent()
|
||||||
|
|
||||||
|
@ -264,9 +282,9 @@ func (h *h2muxConnection) unregister(isNamedTunnel bool) {
|
||||||
|
|
||||||
if isNamedTunnel {
|
if isNamedTunnel {
|
||||||
rpcClient := newRegistrationRPCClient(unregisterCtx, stream, h.observer)
|
rpcClient := newRegistrationRPCClient(unregisterCtx, stream, h.observer)
|
||||||
defer rpcClient.close()
|
defer rpcClient.Close()
|
||||||
|
|
||||||
rpcClient.client.UnregisterConnection(unregisterCtx)
|
rpcClient.GracefulShutdown(unregisterCtx, h.config.GracePeriod)
|
||||||
} else {
|
} else {
|
||||||
rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer)
|
rpcClient := NewTunnelServerClient(unregisterCtx, stream, h.observer)
|
||||||
defer rpcClient.Close()
|
defer rpcClient.Close()
|
||||||
|
|
|
@ -60,7 +60,7 @@ func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsock
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.logRequestError(err, cfRay, ruleNum)
|
c.logRequestError(err, cfRay, ruleNum)
|
||||||
w.WriteErrorResponse(err)
|
w.WriteErrorResponse()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
|
c.logOriginResponse(resp, cfRay, lbProbe, ruleNum)
|
||||||
|
|
|
@ -47,7 +47,7 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *mockHTTPRespWriter) WriteErrorResponse(err error) {
|
func (w *mockHTTPRespWriter) WriteErrorResponse() {
|
||||||
w.WriteHeader(http.StatusBadGateway)
|
w.WriteHeader(http.StatusBadGateway)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue