diff --git a/connection/http2.go b/connection/http2.go index c2ed6b17..44346371 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -7,6 +7,7 @@ import ( "math" "net" "net/http" + "runtime/debug" "strings" "sync" @@ -100,7 +101,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { connType := determineHTTP2Type(r) handleMissingRequestParts(connType, r) - respWriter, err := newHTTP2RespWriter(r, w, connType) + respWriter, err := NewHTTP2RespWriter(r, w, connType) if err != nil { c.observer.log.Error().Msg(err.Error()) return @@ -159,7 +160,7 @@ type http2RespWriter struct { shouldFlush bool } -func newHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) (*http2RespWriter, error) { +func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) (*http2RespWriter, error) { flusher, isFlusher := w.(http.Flusher) if !isFlusher { respWriter := &http2RespWriter{ @@ -231,7 +232,7 @@ func (rp *http2RespWriter) Write(p []byte) (n int, err error) { // Implementer of OriginClient should make sure it doesn't write to the connection after Proxy returns // Register a recover routine just in case. if r := recover(); r != nil { - println("Recover from http2 response writer panic, error", r) + println(fmt.Sprintf("Recover from http2 response writer panic, error %s", debug.Stack())) } }() n, err = rp.w.Write(p) diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index c97d42fd..9588ce36 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -48,7 +48,12 @@ type tcpOverWSConnection struct { } func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { - wc.streamHandler(websocket.NewConn(ctx, tunnelConn, log), wc.conn, log) + wsCtx, cancel := context.WithCancel(ctx) + wsConn := websocket.NewConn(wsCtx, tunnelConn, log) + wc.streamHandler(wsConn, wc.conn, log) + cancel() + // Makes sure wsConn stops sending ping before terminating the stream + wsConn.WaitForShutdown() } func (wc *tcpOverWSConnection) Close() { @@ -63,7 +68,12 @@ type socksProxyOverWSConnection struct { } func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { - socks.StreamNetHandler(websocket.NewConn(ctx, tunnelConn, log), sp.accessPolicy, log) + wsCtx, cancel := context.WithCancel(ctx) + wsConn := websocket.NewConn(wsCtx, tunnelConn, log) + socks.StreamNetHandler(wsConn, sp.accessPolicy, log) + cancel() + // Makes sure wsConn stops sending ping before terminating the stream + wsConn.WaitForShutdown() } func (sp *socksProxyOverWSConnection) Close() { diff --git a/ingress/origin_connection_test.go b/ingress/origin_connection_test.go index 040662a0..7dfff6f0 100644 --- a/ingress/origin_connection_test.go +++ b/ingress/origin_connection_test.go @@ -19,6 +19,7 @@ import ( "golang.org/x/net/proxy" "golang.org/x/sync/errgroup" + "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/socks" "github.com/cloudflare/cloudflared/websocket" @@ -189,6 +190,53 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) { } } +func TestWsConnReturnsBeforeStreamReturns(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + eyeballConn, err := connection.NewHTTP2RespWriter(r, w, connection.TypeWebsocket) + assert.NoError(t, err) + + cfdConn, originConn := net.Pipe() + tcpOverWSConn := tcpOverWSConnection{ + conn: cfdConn, + streamHandler: DefaultStreamHandler, + } + go func() { + time.Sleep(time.Millisecond * 10) + // Simulate losing connection to origin + originConn.Close() + }() + ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond) + tcpOverWSConn.Stream(ctx, eyeballConn, testLogger) + }) + server := httptest.NewServer(handler) + defer server.Close() + client := server.Client() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + errGroup, ctx := errgroup.WithContext(ctx) + for i := 0; i < 50; i++ { + eyeballConn, edgeConn := net.Pipe() + req, err := http.NewRequestWithContext(ctx, http.MethodConnect, server.URL, edgeConn) + assert.NoError(t, err) + + resp, err := client.Transport.RoundTrip(req) + assert.NoError(t, err) + assert.Equal(t, resp.StatusCode, http.StatusOK) + + errGroup.Go(func() error { + for { + if err := wsutil.WriteClientBinary(eyeballConn, testMessage); err != nil { + return nil + } + } + }) + } + + assert.NoError(t, errGroup.Wait()) +} + type wsEyeball struct { conn net.Conn } diff --git a/websocket/connection.go b/websocket/connection.go index 40a68fd3..79665902 100644 --- a/websocket/connection.go +++ b/websocket/connection.go @@ -18,12 +18,16 @@ const ( writeWait = 10 * time.Second // Time allowed to read the next pong message from the peer. - pongWait = 60 * time.Second + defaultPongWait = 60 * time.Second // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = (pongWait * 9) / 10 + defaultPingPeriod = (defaultPongWait * 9) / 10 + + PingPeriodContextKey = PingPeriodContext("pingPeriod") ) +type PingPeriodContext string + // GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter // This is still used by access carrier type GorillaConn struct { @@ -77,7 +81,7 @@ func (c *GorillaConn) SetDeadline(t time.Time) error { // pinger simulates the websocket connection to keep it alive func (c *GorillaConn) pinger(ctx context.Context) { - ticker := time.NewTicker(pingPeriod) + ticker := time.NewTicker(defaultPingPeriod) defer ticker.Stop() for { select { @@ -94,12 +98,15 @@ func (c *GorillaConn) pinger(ctx context.Context) { type Conn struct { rw io.ReadWriter log *zerolog.Logger + // closed is a channel to indicate if Conn has been fully terminated + shutdownC chan struct{} } func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn { c := &Conn{ - rw: rw, - log: log, + rw: rw, + log: log, + shutdownC: make(chan struct{}), } go c.pinger(ctx) return c @@ -123,23 +130,39 @@ func (c *Conn) Write(p []byte) (int, error) { } func (c *Conn) pinger(ctx context.Context) { + defer close(c.shutdownC) pongMessge := wsutil.Message{ OpCode: gobwas.OpPong, Payload: []byte{}, } - ticker := time.NewTicker(pingPeriod) + + ticker := time.NewTicker(c.pingPeriod(ctx)) defer ticker.Stop() for { select { case <-ticker.C: if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil { - c.log.Err(err).Msgf("failed to write ping message") + c.log.Debug().Err(err).Msgf("failed to write ping message") } if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil { - c.log.Err(err).Msgf("failed to write pong message") + c.log.Debug().Err(err).Msgf("failed to write pong message") } case <-ctx.Done(): return } } } + +func (c *Conn) pingPeriod(ctx context.Context) time.Duration { + if val := ctx.Value(PingPeriodContextKey); val != nil { + if period, ok := val.(time.Duration); ok { + return period + } + } + return defaultPingPeriod +} + +// Close waits for pinger to terminate +func (c *Conn) WaitForShutdown() { + <-c.shutdownC +}