TUN-5141: Make sure websocket pinger returns before streaming returns

This commit is contained in:
cthuang 2021-09-22 17:33:05 +01:00
parent f985ed567f
commit 6238fd9022
4 changed files with 95 additions and 13 deletions

View File

@ -7,6 +7,7 @@ import (
"math" "math"
"net" "net"
"net/http" "net/http"
"runtime/debug"
"strings" "strings"
"sync" "sync"
@ -100,7 +101,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
connType := determineHTTP2Type(r) connType := determineHTTP2Type(r)
handleMissingRequestParts(connType, r) handleMissingRequestParts(connType, r)
respWriter, err := newHTTP2RespWriter(r, w, connType) respWriter, err := NewHTTP2RespWriter(r, w, connType)
if err != nil { if err != nil {
c.observer.log.Error().Msg(err.Error()) c.observer.log.Error().Msg(err.Error())
return return
@ -159,7 +160,7 @@ type http2RespWriter struct {
shouldFlush bool shouldFlush bool
} }
func newHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) (*http2RespWriter, error) { func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) (*http2RespWriter, error) {
flusher, isFlusher := w.(http.Flusher) flusher, isFlusher := w.(http.Flusher)
if !isFlusher { if !isFlusher {
respWriter := &http2RespWriter{ respWriter := &http2RespWriter{
@ -231,7 +232,7 @@ func (rp *http2RespWriter) Write(p []byte) (n int, err error) {
// Implementer of OriginClient should make sure it doesn't write to the connection after Proxy returns // Implementer of OriginClient should make sure it doesn't write to the connection after Proxy returns
// Register a recover routine just in case. // Register a recover routine just in case.
if r := recover(); r != nil { if r := recover(); r != nil {
println("Recover from http2 response writer panic, error", r) println(fmt.Sprintf("Recover from http2 response writer panic, error %s", debug.Stack()))
} }
}() }()
n, err = rp.w.Write(p) n, err = rp.w.Write(p)

View File

@ -48,7 +48,12 @@ type tcpOverWSConnection struct {
} }
func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
wc.streamHandler(websocket.NewConn(ctx, tunnelConn, log), wc.conn, log) wsCtx, cancel := context.WithCancel(ctx)
wsConn := websocket.NewConn(wsCtx, tunnelConn, log)
wc.streamHandler(wsConn, wc.conn, log)
cancel()
// Makes sure wsConn stops sending ping before terminating the stream
wsConn.WaitForShutdown()
} }
func (wc *tcpOverWSConnection) Close() { func (wc *tcpOverWSConnection) Close() {
@ -63,7 +68,12 @@ type socksProxyOverWSConnection struct {
} }
func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
socks.StreamNetHandler(websocket.NewConn(ctx, tunnelConn, log), sp.accessPolicy, log) wsCtx, cancel := context.WithCancel(ctx)
wsConn := websocket.NewConn(wsCtx, tunnelConn, log)
socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
cancel()
// Makes sure wsConn stops sending ping before terminating the stream
wsConn.WaitForShutdown()
} }
func (sp *socksProxyOverWSConnection) Close() { func (sp *socksProxyOverWSConnection) Close() {

View File

@ -19,6 +19,7 @@ import (
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/socks" "github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
@ -189,6 +190,53 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
} }
} }
func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
eyeballConn, err := connection.NewHTTP2RespWriter(r, w, connection.TypeWebsocket)
assert.NoError(t, err)
cfdConn, originConn := net.Pipe()
tcpOverWSConn := tcpOverWSConnection{
conn: cfdConn,
streamHandler: DefaultStreamHandler,
}
go func() {
time.Sleep(time.Millisecond * 10)
// Simulate losing connection to origin
originConn.Close()
}()
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
tcpOverWSConn.Stream(ctx, eyeballConn, testLogger)
})
server := httptest.NewServer(handler)
defer server.Close()
client := server.Client()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
errGroup, ctx := errgroup.WithContext(ctx)
for i := 0; i < 50; i++ {
eyeballConn, edgeConn := net.Pipe()
req, err := http.NewRequestWithContext(ctx, http.MethodConnect, server.URL, edgeConn)
assert.NoError(t, err)
resp, err := client.Transport.RoundTrip(req)
assert.NoError(t, err)
assert.Equal(t, resp.StatusCode, http.StatusOK)
errGroup.Go(func() error {
for {
if err := wsutil.WriteClientBinary(eyeballConn, testMessage); err != nil {
return nil
}
}
})
}
assert.NoError(t, errGroup.Wait())
}
type wsEyeball struct { type wsEyeball struct {
conn net.Conn conn net.Conn
} }

View File

@ -18,12 +18,16 @@ const (
writeWait = 10 * time.Second writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer. // Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second defaultPongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait. // Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10 defaultPingPeriod = (defaultPongWait * 9) / 10
PingPeriodContextKey = PingPeriodContext("pingPeriod")
) )
type PingPeriodContext string
// GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter // GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter
// This is still used by access carrier // This is still used by access carrier
type GorillaConn struct { type GorillaConn struct {
@ -77,7 +81,7 @@ func (c *GorillaConn) SetDeadline(t time.Time) error {
// pinger simulates the websocket connection to keep it alive // pinger simulates the websocket connection to keep it alive
func (c *GorillaConn) pinger(ctx context.Context) { func (c *GorillaConn) pinger(ctx context.Context) {
ticker := time.NewTicker(pingPeriod) ticker := time.NewTicker(defaultPingPeriod)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
@ -94,12 +98,15 @@ func (c *GorillaConn) pinger(ctx context.Context) {
type Conn struct { type Conn struct {
rw io.ReadWriter rw io.ReadWriter
log *zerolog.Logger log *zerolog.Logger
// closed is a channel to indicate if Conn has been fully terminated
shutdownC chan struct{}
} }
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn { func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
c := &Conn{ c := &Conn{
rw: rw, rw: rw,
log: log, log: log,
shutdownC: make(chan struct{}),
} }
go c.pinger(ctx) go c.pinger(ctx)
return c return c
@ -123,23 +130,39 @@ func (c *Conn) Write(p []byte) (int, error) {
} }
func (c *Conn) pinger(ctx context.Context) { func (c *Conn) pinger(ctx context.Context) {
defer close(c.shutdownC)
pongMessge := wsutil.Message{ pongMessge := wsutil.Message{
OpCode: gobwas.OpPong, OpCode: gobwas.OpPong,
Payload: []byte{}, Payload: []byte{},
} }
ticker := time.NewTicker(pingPeriod)
ticker := time.NewTicker(c.pingPeriod(ctx))
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil { if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
c.log.Err(err).Msgf("failed to write ping message") c.log.Debug().Err(err).Msgf("failed to write ping message")
} }
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil { if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
c.log.Err(err).Msgf("failed to write pong message") c.log.Debug().Err(err).Msgf("failed to write pong message")
} }
case <-ctx.Done(): case <-ctx.Done():
return return
} }
} }
} }
func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
if val := ctx.Value(PingPeriodContextKey); val != nil {
if period, ok := val.(time.Duration); ok {
return period
}
}
return defaultPingPeriod
}
// Close waits for pinger to terminate
func (c *Conn) WaitForShutdown() {
<-c.shutdownC
}