diff --git a/ingress/origin_connection.go b/ingress/origin_connection.go index bdfb19b6..2e8b946e 100644 --- a/ingress/origin_connection.go +++ b/ingress/origin_connection.go @@ -78,16 +78,3 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io. func (sp *socksProxyOverWSConnection) Close() { } - -// wsProxyConnection represents a bidirectional stream for a websocket connection to the origin -type wsProxyConnection struct { - rwc io.ReadWriteCloser -} - -func (conn *wsProxyConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { - websocket.Stream(tunnelConn, conn.rwc, log) -} - -func (conn *wsProxyConnection) Close() { - conn.rwc.Close() -} diff --git a/websocket/websocket.go b/websocket/websocket.go index b94b4f54..67c8916b 100644 --- a/websocket/websocket.go +++ b/websocket/websocket.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -28,28 +29,64 @@ func NewResponseHeader(req *http.Request) http.Header { return header } +type bidirectionalStreamStatus struct { + doneChan chan struct{} + anyDone uint32 +} + +func newBiStreamStatus() *bidirectionalStreamStatus { + return &bidirectionalStreamStatus{ + doneChan: make(chan struct{}, 2), + anyDone: 0, + } +} + +func (s *bidirectionalStreamStatus) markUniStreamDone() { + atomic.StoreUint32(&s.anyDone, 1) + s.doneChan <- struct{}{} +} + +func (s *bidirectionalStreamStatus) waitAnyDone() { + <-s.doneChan +} +func (s *bidirectionalStreamStatus) isAnyDone() bool { + return atomic.LoadUint32(&s.anyDone) > 0 +} + // Stream copies copy data to & from provided io.ReadWriters. func Stream(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) { - proxyDone := make(chan struct{}, 2) + status := newBiStreamStatus() - go func() { - _, err := copyData(tunnelConn, originConn, "origin->tunnel") - if err != nil { - log.Debug().Msgf("origin to tunnel copy: %v", err) - } - proxyDone <- struct{}{} - }() - - go func() { - _, err := copyData(originConn, tunnelConn, "tunnel->origin") - if err != nil { - log.Debug().Msgf("tunnel to origin copy: %v", err) - } - proxyDone <- struct{}{} - }() + go unidirectionalStream(tunnelConn, originConn, "origin->tunnel", status, log) + go unidirectionalStream(originConn, tunnelConn, "tunnel->origin", status, log) // If one side is done, we are done. - <-proxyDone + status.waitAnyDone() +} + +func unidirectionalStream(dst io.Writer, src io.Reader, dir string, status *bidirectionalStreamStatus, log *zerolog.Logger) { + defer func() { + // The bidirectional streaming spawns 2 goroutines to stream each direction. + // If any ends, the callstack returns, meaning the Tunnel request/stream (depending on http2 vs quic) will + // close. In such case, if the other direction did not stop (due to application level stopping, e.g., if a + // server/origin listens forever until closure), it may read/write from the underlying ReadWriter (backed by + // the Edge<->cloudflared transport) in an unexpected state. + + if status.isAnyDone() { + // Because of this, we set this recover() logic, which kicks-in *only* if any stream is known to have + // exited. In such case, we stop a possible panic from propagating upstream. + if r := recover(); r != nil { + // We handle such unexpected errors only when we detect that one side of the streaming is done. + log.Debug().Msgf("Handled gracefully error %v in Streaming for %s", r, dir) + } + } + }() + + _, err := copyData(dst, src, dir) + if err != nil { + log.Debug().Msgf("%s copy: %v", dir, err) + } + status.markUniStreamDone() } // when set to true, enables logging of content copied to/from origin and tunnel