TUN-5141: Make sure websocket pinger returns before streaming returns
This commit is contained in:
parent
f985ed567f
commit
6238fd9022
|
@ -7,6 +7,7 @@ import (
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
@ -100,7 +101,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
connType := determineHTTP2Type(r)
|
connType := determineHTTP2Type(r)
|
||||||
handleMissingRequestParts(connType, r)
|
handleMissingRequestParts(connType, r)
|
||||||
|
|
||||||
respWriter, err := newHTTP2RespWriter(r, w, connType)
|
respWriter, err := NewHTTP2RespWriter(r, w, connType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.observer.log.Error().Msg(err.Error())
|
c.observer.log.Error().Msg(err.Error())
|
||||||
return
|
return
|
||||||
|
@ -159,7 +160,7 @@ type http2RespWriter struct {
|
||||||
shouldFlush bool
|
shouldFlush bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) (*http2RespWriter, error) {
|
func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) (*http2RespWriter, error) {
|
||||||
flusher, isFlusher := w.(http.Flusher)
|
flusher, isFlusher := w.(http.Flusher)
|
||||||
if !isFlusher {
|
if !isFlusher {
|
||||||
respWriter := &http2RespWriter{
|
respWriter := &http2RespWriter{
|
||||||
|
@ -231,7 +232,7 @@ func (rp *http2RespWriter) Write(p []byte) (n int, err error) {
|
||||||
// Implementer of OriginClient should make sure it doesn't write to the connection after Proxy returns
|
// Implementer of OriginClient should make sure it doesn't write to the connection after Proxy returns
|
||||||
// Register a recover routine just in case.
|
// Register a recover routine just in case.
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
println("Recover from http2 response writer panic, error", r)
|
println(fmt.Sprintf("Recover from http2 response writer panic, error %s", debug.Stack()))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
n, err = rp.w.Write(p)
|
n, err = rp.w.Write(p)
|
||||||
|
|
|
@ -48,7 +48,12 @@ type tcpOverWSConnection struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||||
wc.streamHandler(websocket.NewConn(ctx, tunnelConn, log), wc.conn, log)
|
wsCtx, cancel := context.WithCancel(ctx)
|
||||||
|
wsConn := websocket.NewConn(wsCtx, tunnelConn, log)
|
||||||
|
wc.streamHandler(wsConn, wc.conn, log)
|
||||||
|
cancel()
|
||||||
|
// Makes sure wsConn stops sending ping before terminating the stream
|
||||||
|
wsConn.WaitForShutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wc *tcpOverWSConnection) Close() {
|
func (wc *tcpOverWSConnection) Close() {
|
||||||
|
@ -63,7 +68,12 @@ type socksProxyOverWSConnection struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||||
socks.StreamNetHandler(websocket.NewConn(ctx, tunnelConn, log), sp.accessPolicy, log)
|
wsCtx, cancel := context.WithCancel(ctx)
|
||||||
|
wsConn := websocket.NewConn(wsCtx, tunnelConn, log)
|
||||||
|
socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
|
||||||
|
cancel()
|
||||||
|
// Makes sure wsConn stops sending ping before terminating the stream
|
||||||
|
wsConn.WaitForShutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sp *socksProxyOverWSConnection) Close() {
|
func (sp *socksProxyOverWSConnection) Close() {
|
||||||
|
|
|
@ -19,6 +19,7 @@ import (
|
||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/connection"
|
||||||
"github.com/cloudflare/cloudflared/logger"
|
"github.com/cloudflare/cloudflared/logger"
|
||||||
"github.com/cloudflare/cloudflared/socks"
|
"github.com/cloudflare/cloudflared/socks"
|
||||||
"github.com/cloudflare/cloudflared/websocket"
|
"github.com/cloudflare/cloudflared/websocket"
|
||||||
|
@ -189,6 +190,53 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
eyeballConn, err := connection.NewHTTP2RespWriter(r, w, connection.TypeWebsocket)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
cfdConn, originConn := net.Pipe()
|
||||||
|
tcpOverWSConn := tcpOverWSConnection{
|
||||||
|
conn: cfdConn,
|
||||||
|
streamHandler: DefaultStreamHandler,
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
// Simulate losing connection to origin
|
||||||
|
originConn.Close()
|
||||||
|
}()
|
||||||
|
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
|
||||||
|
tcpOverWSConn.Stream(ctx, eyeballConn, testLogger)
|
||||||
|
})
|
||||||
|
server := httptest.NewServer(handler)
|
||||||
|
defer server.Close()
|
||||||
|
client := server.Client()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
errGroup, ctx := errgroup.WithContext(ctx)
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
eyeballConn, edgeConn := net.Pipe()
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodConnect, server.URL, edgeConn)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
resp, err := client.Transport.RoundTrip(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
||||||
|
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
for {
|
||||||
|
if err := wsutil.WriteClientBinary(eyeballConn, testMessage); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, errGroup.Wait())
|
||||||
|
}
|
||||||
|
|
||||||
type wsEyeball struct {
|
type wsEyeball struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,12 +18,16 @@ const (
|
||||||
writeWait = 10 * time.Second
|
writeWait = 10 * time.Second
|
||||||
|
|
||||||
// Time allowed to read the next pong message from the peer.
|
// Time allowed to read the next pong message from the peer.
|
||||||
pongWait = 60 * time.Second
|
defaultPongWait = 60 * time.Second
|
||||||
|
|
||||||
// Send pings to peer with this period. Must be less than pongWait.
|
// Send pings to peer with this period. Must be less than pongWait.
|
||||||
pingPeriod = (pongWait * 9) / 10
|
defaultPingPeriod = (defaultPongWait * 9) / 10
|
||||||
|
|
||||||
|
PingPeriodContextKey = PingPeriodContext("pingPeriod")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type PingPeriodContext string
|
||||||
|
|
||||||
// GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter
|
// GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter
|
||||||
// This is still used by access carrier
|
// This is still used by access carrier
|
||||||
type GorillaConn struct {
|
type GorillaConn struct {
|
||||||
|
@ -77,7 +81,7 @@ func (c *GorillaConn) SetDeadline(t time.Time) error {
|
||||||
|
|
||||||
// pinger simulates the websocket connection to keep it alive
|
// pinger simulates the websocket connection to keep it alive
|
||||||
func (c *GorillaConn) pinger(ctx context.Context) {
|
func (c *GorillaConn) pinger(ctx context.Context) {
|
||||||
ticker := time.NewTicker(pingPeriod)
|
ticker := time.NewTicker(defaultPingPeriod)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -94,12 +98,15 @@ func (c *GorillaConn) pinger(ctx context.Context) {
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
rw io.ReadWriter
|
rw io.ReadWriter
|
||||||
log *zerolog.Logger
|
log *zerolog.Logger
|
||||||
|
// closed is a channel to indicate if Conn has been fully terminated
|
||||||
|
shutdownC chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
|
func NewConn(ctx context.Context, rw io.ReadWriter, log *zerolog.Logger) *Conn {
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
rw: rw,
|
rw: rw,
|
||||||
log: log,
|
log: log,
|
||||||
|
shutdownC: make(chan struct{}),
|
||||||
}
|
}
|
||||||
go c.pinger(ctx)
|
go c.pinger(ctx)
|
||||||
return c
|
return c
|
||||||
|
@ -123,23 +130,39 @@ func (c *Conn) Write(p []byte) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) pinger(ctx context.Context) {
|
func (c *Conn) pinger(ctx context.Context) {
|
||||||
|
defer close(c.shutdownC)
|
||||||
pongMessge := wsutil.Message{
|
pongMessge := wsutil.Message{
|
||||||
OpCode: gobwas.OpPong,
|
OpCode: gobwas.OpPong,
|
||||||
Payload: []byte{},
|
Payload: []byte{},
|
||||||
}
|
}
|
||||||
ticker := time.NewTicker(pingPeriod)
|
|
||||||
|
ticker := time.NewTicker(c.pingPeriod(ctx))
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
|
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
|
||||||
c.log.Err(err).Msgf("failed to write ping message")
|
c.log.Debug().Err(err).Msgf("failed to write ping message")
|
||||||
}
|
}
|
||||||
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
|
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
|
||||||
c.log.Err(err).Msgf("failed to write pong message")
|
c.log.Debug().Err(err).Msgf("failed to write pong message")
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) pingPeriod(ctx context.Context) time.Duration {
|
||||||
|
if val := ctx.Value(PingPeriodContextKey); val != nil {
|
||||||
|
if period, ok := val.(time.Duration); ok {
|
||||||
|
return period
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultPingPeriod
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close waits for pinger to terminate
|
||||||
|
func (c *Conn) WaitForShutdown() {
|
||||||
|
<-c.shutdownC
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue