TUN-3490: Make sure OriginClient implementation doesn't write after Proxy return
This commit is contained in:
parent
d5769519b2
commit
543169c893
|
@ -190,6 +190,13 @@ func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *http2RespWriter) Write(p []byte) (n int, err error) {
|
func (rp *http2RespWriter) Write(p []byte) (n int, err error) {
|
||||||
|
defer func() {
|
||||||
|
// Implementer of OriginClient should make sure it doesn't write to the connection after Proxy returns
|
||||||
|
// Register a recover routine just in case.
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
println("Recover from http2 response writer panic, error", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
n, err = rp.w.Write(p)
|
n, err = rp.w.Write(p)
|
||||||
if err == nil && rp.shouldFlush {
|
if err == nil && rp.shouldFlush {
|
||||||
rp.flusher.Flush()
|
rp.flusher.Flush()
|
||||||
|
|
|
@ -318,3 +318,20 @@ func newHTTPTransport(service OriginService, cfg OriginRequestConfig) (*http.Tra
|
||||||
|
|
||||||
return &httpTransport, nil
|
return &httpTransport, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MockOriginService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
|
||||||
|
type MockOriginService struct {
|
||||||
|
Transport http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mos MockOriginService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return mos.Transport.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mos MockOriginService) String() string {
|
||||||
|
return "MockOriginService"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mos MockOriginService) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -124,19 +125,31 @@ func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request,
|
||||||
}
|
}
|
||||||
|
|
||||||
serveCtx, cancel := context.WithCancel(req.Context())
|
serveCtx, cancel := context.WithCancel(req.Context())
|
||||||
defer cancel()
|
connClosedChan := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
|
// serveCtx is done if req is cancelled, or streamWebsocket returns
|
||||||
<-serveCtx.Done()
|
<-serveCtx.Done()
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
close(connClosedChan)
|
||||||
}()
|
}()
|
||||||
err = w.WriteRespHeaders(resp)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "Error writing response header")
|
|
||||||
}
|
|
||||||
// Copy to/from stream to the undelying connection. Use the underlying
|
// Copy to/from stream to the undelying connection. Use the underlying
|
||||||
// connection because cloudflared doesn't operate on the message themselves
|
// connection because cloudflared doesn't operate on the message themselves
|
||||||
websocket.Stream(conn.UnderlyingConn(), w)
|
err = c.streamWebsocket(w, conn.UnderlyingConn(), resp)
|
||||||
return resp, nil
|
cancel()
|
||||||
|
|
||||||
|
// We need to make sure conn is closed before returning, otherwise we might write to conn after Proxy returns
|
||||||
|
<-connClosedChan
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) streamWebsocket(w connection.ResponseWriter, conn net.Conn, resp *http.Response) error {
|
||||||
|
err := w.WriteRespHeaders(resp)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Error writing websocket response header")
|
||||||
|
}
|
||||||
|
websocket.Stream(conn, w)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
|
func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
|
||||||
|
|
|
@ -49,6 +49,7 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(resp *http.Response) error {
|
||||||
|
|
||||||
func (w *mockHTTPRespWriter) WriteErrorResponse() {
|
func (w *mockHTTPRespWriter) WriteErrorResponse() {
|
||||||
w.WriteHeader(http.StatusBadGateway)
|
w.WriteHeader(http.StatusBadGateway)
|
||||||
|
w.Write([]byte("http response error"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
|
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
|
||||||
|
@ -315,3 +316,37 @@ func (ma mockAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusCreated)
|
w.WriteHeader(http.StatusCreated)
|
||||||
w.Write([]byte("Created"))
|
w.Write([]byte("Created"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type errorOriginTransport struct{}
|
||||||
|
|
||||||
|
func (errorOriginTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||||
|
return nil, fmt.Errorf("Proxy error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyError(t *testing.T) {
|
||||||
|
ingress := ingress.Ingress{
|
||||||
|
Rules: []ingress.Rule{
|
||||||
|
{
|
||||||
|
Hostname: "*",
|
||||||
|
Path: nil,
|
||||||
|
Service: ingress.MockOriginService{
|
||||||
|
Transport: errorOriginTransport{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
logger, err := logger.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
client := NewClient(ingress, testTags, logger)
|
||||||
|
|
||||||
|
respWriter := newMockHTTPRespWriter()
|
||||||
|
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = client.Proxy(respWriter, req, false)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, http.StatusBadGateway, respWriter.Code)
|
||||||
|
assert.Equal(t, "http response error", respWriter.Body.String())
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue