diff --git a/streamhandler/request.go b/streamhandler/request.go index 6d2004bd..3b61ac26 100644 --- a/streamhandler/request.go +++ b/streamhandler/request.go @@ -23,8 +23,8 @@ func IsLBProbeRequest(req *http.Request) bool { return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix) } -func CreateRequest(stream *h2mux.MuxedStream, originAddr string) (*http.Request, error) { - req, err := http.NewRequest(http.MethodGet, originAddr, h2mux.MuxedStreamReader{MuxedStream: stream}) +func createRequest(stream *h2mux.MuxedStream, url string) (*http.Request, error) { + req, err := http.NewRequest(http.MethodGet, url, h2mux.MuxedStreamReader{MuxedStream: stream}) if err != nil { return nil, errors.Wrap(err, "unexpected error from http.NewRequest") } diff --git a/streamhandler/stream_handler.go b/streamhandler/stream_handler.go index 955ccca9..73d5d474 100644 --- a/streamhandler/stream_handler.go +++ b/streamhandler/stream_handler.go @@ -4,15 +4,39 @@ import ( "context" "fmt" "net/http" + "strconv" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/tunnelhostnamemapper" "github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "zombiezen.com/go/capnproto2/rpc" ) +const ( + statusPseudoHeader = ":status" +) + +type httpErrorStatus struct { + status string + text []byte +} + +var ( + statusBadRequest = newHTTPErrorStatus(http.StatusBadRequest) + statusNotFound = newHTTPErrorStatus(http.StatusNotFound) + statusBadGateway = newHTTPErrorStatus(http.StatusBadGateway) +) + +func newHTTPErrorStatus(status int) *httpErrorStatus { + return &httpErrorStatus{ + status: strconv.Itoa(status), + text: []byte(http.StatusText(status)), + } +} + // StreamHandler handles new stream opened by the edge. The streams can be used to proxy requests or make RPC. type StreamHandler struct { // newConfigChan is a send-only channel to notify Supervisor of a new ClientConfig @@ -82,7 +106,11 @@ func (s *StreamHandler) ServeStream(stream *h2mux.MuxedStream) error { if stream.IsRPCStream() { return s.serveRPC(stream) } - return s.serveRequest(stream) + if err := s.serveRequest(stream); err != nil { + s.logger.Error(err) + return err + } + return nil } func (s *StreamHandler) serveRPC(stream *h2mux.MuxedStream) error { @@ -100,21 +128,20 @@ func (s *StreamHandler) serveRPC(stream *h2mux.MuxedStream) error { func (s *StreamHandler) serveRequest(stream *h2mux.MuxedStream) error { tunnelHostname := stream.TunnelHostname() if !tunnelHostname.IsSet() { - err := fmt.Errorf("stream doesn't have tunnelHostname") - s.logger.Error(err) - return err + s.writeErrorStatus(stream, statusBadRequest) + return fmt.Errorf("stream doesn't have tunnelHostname") } originService, ok := s.tunnelHostnameMapper.Get(tunnelHostname) if !ok { - err := fmt.Errorf("cannot map tunnel hostname %s to origin", tunnelHostname) - s.logger.Error(err) - return err + s.writeErrorStatus(stream, statusNotFound) + return fmt.Errorf("cannot map tunnel hostname %s to origin", tunnelHostname) } - req, err := CreateRequest(stream, originService.OriginAddr()) + req, err := createRequest(stream, originService.OriginAddr()) if err != nil { - return err + s.writeErrorStatus(stream, statusBadRequest) + return errors.Wrap(err, "cannot create request") } logger := s.requestLogger(req, tunnelHostname) @@ -122,8 +149,8 @@ func (s *StreamHandler) serveRequest(stream *h2mux.MuxedStream) error { resp, err := originService.Proxy(stream, req) if err != nil { - logger.WithError(err).Error("Request error") - return err + s.writeErrorStatus(stream, statusBadGateway) + return errors.Wrap(err, "cannot proxy request") } logger.WithField("status", resp.Status).Debugf("Response Headers %+v", resp.Header) @@ -144,3 +171,13 @@ func (s *StreamHandler) requestLogger(req *http.Request, tunnelHostname h2mux.Tu } return logger } + +func (s *StreamHandler) writeErrorStatus(stream *h2mux.MuxedStream, status *httpErrorStatus) { + stream.WriteHeaders([]h2mux.Header{ + { + Name: statusPseudoHeader, + Value: status.status, + }, + }) + stream.Write(status.text) +} diff --git a/streamhandler/stream_handler_test.go b/streamhandler/stream_handler_test.go new file mode 100644 index 00000000..7e777709 --- /dev/null +++ b/streamhandler/stream_handler_test.go @@ -0,0 +1,230 @@ +package streamhandler + +import ( + "context" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "sync" + "testing" + "time" + + "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" +) + +const ( + testOpenStreamTimeout = time.Millisecond * 5000 + testHandshakeTimeout = time.Millisecond * 1000 +) + +var ( + testTunnelHostname = h2mux.TunnelHostname("123.cftunnel.com") + baseHeaders = []h2mux.Header{ + {Name: ":method", Value: "GET"}, + {Name: ":scheme", Value: "http"}, + {Name: ":authority", Value: "example.com"}, + {Name: ":path", Value: "/"}, + } + tunnelHostnameHeader = h2mux.Header{ + Name: h2mux.CloudflaredProxyTunnelHostnameHeader, + Value: testTunnelHostname.String(), + } +) + +func TestServeRequest(t *testing.T) { + configChan := make(chan *pogs.ClientConfig) + useConfigResultChan := make(chan *pogs.UseConfigurationResult) + streamHandler := NewStreamHandler(configChan, useConfigResultChan, logrus.New()) + + message := []byte("Hello cloudflared") + httpServer := httptest.NewServer(&mockHTTPHandler{message}) + url, err := url.Parse(httpServer.URL) + assert.NoError(t, err) + + reverseProxyConfigs := []*pogs.ReverseProxyConfig{ + { + TunnelHostname: testTunnelHostname, + Origin: &pogs.HTTPOriginConfig{ + URL: &pogs.HTTPURL{ + URL: url, + }, + }, + }, + } + streamHandler.UpdateConfig(reverseProxyConfigs) + + muxPair := NewDefaultMuxerPair(t, streamHandler) + muxPair.Serve(t) + + ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout) + defer cancel() + + headers := append(baseHeaders, tunnelHostnameHeader) + stream, err := muxPair.EdgeMux.OpenStream(ctx, headers, nil) + assert.NoError(t, err) + assertStatusHeader(t, http.StatusOK, stream.Headers) + assertRespBody(t, message, stream) +} + +func TestServeBadRequest(t *testing.T) { + configChan := make(chan *pogs.ClientConfig) + useConfigResultChan := make(chan *pogs.UseConfigurationResult) + streamHandler := NewStreamHandler(configChan, useConfigResultChan, logrus.New()) + + muxPair := NewDefaultMuxerPair(t, streamHandler) + muxPair.Serve(t) + + ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout) + defer cancel() + + // No tunnel hostname header, expect to get 400 Bad Request + stream, err := muxPair.EdgeMux.OpenStream(ctx, baseHeaders, nil) + assert.NoError(t, err) + assertStatusHeader(t, http.StatusBadRequest, stream.Headers) + assertRespBody(t, statusBadRequest.text, stream) + + // No mapping for the tunnel hostname, expect to get 404 Not Found + headers := append(baseHeaders, tunnelHostnameHeader) + stream, err = muxPair.EdgeMux.OpenStream(ctx, headers, nil) + assert.NoError(t, err) + assertStatusHeader(t, http.StatusNotFound, stream.Headers) + assertRespBody(t, statusNotFound.text, stream) + + // Nothing listening on empty url, so proxy would fail. Expect to get 502 Bad Gateway + reverseProxyConfigs := []*pogs.ReverseProxyConfig{ + { + TunnelHostname: testTunnelHostname, + Origin: &pogs.HTTPOriginConfig{ + URL: &pogs.HTTPURL{ + URL: &url.URL{}, + }, + }, + }, + } + streamHandler.UpdateConfig(reverseProxyConfigs) + stream, err = muxPair.EdgeMux.OpenStream(ctx, headers, nil) + assert.NoError(t, err) + assertStatusHeader(t, http.StatusBadGateway, stream.Headers) + assertRespBody(t, statusBadGateway.text, stream) + + // Invalid content-length, wouldn't not be able to create a request. Expect to get 400 Bad Request + headers = append(headers, h2mux.Header{ + Name: "content-length", + Value: "x", + }) + stream, err = muxPair.EdgeMux.OpenStream(ctx, headers, nil) + assert.NoError(t, err) + assertStatusHeader(t, http.StatusBadRequest, stream.Headers) + assertRespBody(t, statusBadRequest.text, stream) +} + +func assertStatusHeader(t *testing.T, expectedStatus int, headers []h2mux.Header) { + assert.Equal(t, statusPseudoHeader, headers[0].Name) + assert.Equal(t, strconv.Itoa(expectedStatus), headers[0].Value) +} + +func assertRespBody(t *testing.T, expectedRespBody []byte, stream *h2mux.MuxedStream) { + respBody := make([]byte, len(expectedRespBody)) + _, err := stream.Read(respBody) + assert.NoError(t, err) + assert.Equal(t, expectedRespBody, respBody) +} + +type DefaultMuxerPair struct { + OriginMuxConfig h2mux.MuxerConfig + OriginMux *h2mux.Muxer + OriginConn net.Conn + EdgeMuxConfig h2mux.MuxerConfig + EdgeMux *h2mux.Muxer + EdgeConn net.Conn + doneC chan struct{} +} + +func NewDefaultMuxerPair(t assert.TestingT, h h2mux.MuxedStreamHandler) *DefaultMuxerPair { + origin, edge := net.Pipe() + p := &DefaultMuxerPair{ + OriginMuxConfig: h2mux.MuxerConfig{ + Timeout: testHandshakeTimeout, + Handler: h, + IsClient: true, + Name: "origin", + Logger: logrus.NewEntry(logrus.New()), + DefaultWindowSize: (1 << 8) - 1, + MaxWindowSize: (1 << 15) - 1, + StreamWriteBufferMaxLen: 1024, + }, + OriginConn: origin, + EdgeMuxConfig: h2mux.MuxerConfig{ + Timeout: testHandshakeTimeout, + IsClient: false, + Name: "edge", + Logger: logrus.NewEntry(logrus.New()), + DefaultWindowSize: (1 << 8) - 1, + MaxWindowSize: (1 << 15) - 1, + StreamWriteBufferMaxLen: 1024, + }, + EdgeConn: edge, + doneC: make(chan struct{}), + } + assert.NoError(t, p.Handshake()) + return p +} + +func (p *DefaultMuxerPair) Handshake() error { + ctx, cancel := context.WithTimeout(context.Background(), testHandshakeTimeout) + defer cancel() + errGroup, _ := errgroup.WithContext(ctx) + errGroup.Go(func() (err error) { + p.EdgeMux, err = h2mux.Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig) + return errors.Wrap(err, "edge handshake failure") + }) + errGroup.Go(func() (err error) { + p.OriginMux, err = h2mux.Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig) + return errors.Wrap(err, "origin handshake failure") + }) + + return errGroup.Wait() +} + +func (p *DefaultMuxerPair) Serve(t assert.TestingT) { + ctx := context.Background() + var wg sync.WaitGroup + wg.Add(2) + go func() { + err := p.EdgeMux.Serve(ctx) + if err != nil && err != io.EOF && err != io.ErrClosedPipe { + t.Errorf("error in edge muxer Serve(): %s", err) + } + p.OriginMux.Shutdown() + wg.Done() + }() + go func() { + err := p.OriginMux.Serve(ctx) + if err != nil && err != io.EOF && err != io.ErrClosedPipe { + t.Errorf("error in origin muxer Serve(): %s", err) + } + p.EdgeMux.Shutdown() + wg.Done() + }() + go func() { + // notify when both muxes have stopped serving + wg.Wait() + close(p.doneC) + }() +} + +type mockHTTPHandler struct { + message []byte +} + +func (mth *mockHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Write(mth.message) +}