diff --git a/connection/http2.go b/connection/http2.go index 8e3e8f99..a2f00e2c 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -190,6 +190,13 @@ func (rp *http2RespWriter) Read(p []byte) (n int, err error) { } func (rp *http2RespWriter) Write(p []byte) (n int, err error) { + defer func() { + // Implementer of OriginClient should make sure it doesn't write to the connection after Proxy returns + // Register a recover routine just in case. + if r := recover(); r != nil { + println("Recover from http2 response writer panic, error", r) + } + }() n, err = rp.w.Write(p) if err == nil && rp.shouldFlush { rp.flusher.Flush() diff --git a/ingress/origin_service.go b/ingress/origin_service.go index 0993d202..f35e9bb1 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -318,3 +318,20 @@ func newHTTPTransport(service OriginService, cfg OriginRequestConfig) (*http.Tra return &httpTransport, nil } + +// MockOriginService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior. +type MockOriginService struct { + Transport http.RoundTripper +} + +func (mos MockOriginService) RoundTrip(req *http.Request) (*http.Response, error) { + return mos.Transport.RoundTrip(req) +} + +func (mos MockOriginService) String() string { + return "MockOriginService" +} + +func (mos MockOriginService) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error { + return nil +} diff --git a/origin/proxy.go b/origin/proxy.go index fe3ddc94..ddefdbf4 100644 --- a/origin/proxy.go +++ b/origin/proxy.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "io" + "net" "net/http" "strconv" "strings" @@ -124,19 +125,31 @@ func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request, } serveCtx, cancel := context.WithCancel(req.Context()) - defer cancel() + connClosedChan := make(chan struct{}) go func() { + // serveCtx is done if req is cancelled, or streamWebsocket returns <-serveCtx.Done() conn.Close() + close(connClosedChan) }() - err = w.WriteRespHeaders(resp) - if err != nil { - return nil, errors.Wrap(err, "Error writing response header") - } + // Copy to/from stream to the undelying connection. Use the underlying // connection because cloudflared doesn't operate on the message themselves - websocket.Stream(conn.UnderlyingConn(), w) - return resp, nil + err = c.streamWebsocket(w, conn.UnderlyingConn(), resp) + cancel() + + // We need to make sure conn is closed before returning, otherwise we might write to conn after Proxy returns + <-connClosedChan + return resp, err +} + +func (c *client) streamWebsocket(w connection.ResponseWriter, conn net.Conn, resp *http.Response) error { + err := w.WriteRespHeaders(resp) + if err != nil { + return errors.Wrap(err, "Error writing websocket response header") + } + websocket.Stream(conn, w) + return nil } func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) { diff --git a/origin/proxy_test.go b/origin/proxy_test.go index 7a85f6d2..d20be70a 100644 --- a/origin/proxy_test.go +++ b/origin/proxy_test.go @@ -49,6 +49,7 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error { func (w *mockHTTPRespWriter) WriteErrorResponse() { w.WriteHeader(http.StatusBadGateway) + w.Write([]byte("http response error")) } func (w *mockHTTPRespWriter) Read(data []byte) (int, error) { @@ -315,3 +316,37 @@ func (ma mockAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusCreated) w.Write([]byte("Created")) } + +type errorOriginTransport struct{} + +func (errorOriginTransport) RoundTrip(*http.Request) (*http.Response, error) { + return nil, fmt.Errorf("Proxy error") +} + +func TestProxyError(t *testing.T) { + ingress := ingress.Ingress{ + Rules: []ingress.Rule{ + { + Hostname: "*", + Path: nil, + Service: ingress.MockOriginService{ + Transport: errorOriginTransport{}, + }, + }, + }, + } + + logger, err := logger.New() + require.NoError(t, err) + + client := NewClient(ingress, testTags, logger) + + respWriter := newMockHTTPRespWriter() + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) + assert.NoError(t, err) + + err = client.Proxy(respWriter, req, false) + assert.Error(t, err) + assert.Equal(t, http.StatusBadGateway, respWriter.Code) + assert.Equal(t, "http response error", respWriter.Body.String()) +}