TUN-7545: Add support for full bidirectionally streaming with close signal propagation
This commit is contained in:
parent
1b9f55a002
commit
286addc102
|
@ -15,6 +15,36 @@ import (
|
|||
"github.com/cloudflare/cloudflared/cfio"
|
||||
)
|
||||
|
||||
type Stream interface {
|
||||
Reader
|
||||
WriterCloser
|
||||
}
|
||||
|
||||
type Reader interface {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
type WriterCloser interface {
|
||||
io.Writer
|
||||
WriteCloser
|
||||
}
|
||||
|
||||
type WriteCloser interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
type nopCloseWriterAdapter struct {
|
||||
io.ReadWriter
|
||||
}
|
||||
|
||||
func NopCloseWriterAdapter(stream io.ReadWriter) *nopCloseWriterAdapter {
|
||||
return &nopCloseWriterAdapter{stream}
|
||||
}
|
||||
|
||||
func (n *nopCloseWriterAdapter) CloseWrite() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type bidirectionalStreamStatus struct {
|
||||
doneChan chan struct{}
|
||||
anyDone uint32
|
||||
|
@ -32,8 +62,24 @@ func (s *bidirectionalStreamStatus) markUniStreamDone() {
|
|||
s.doneChan <- struct{}{}
|
||||
}
|
||||
|
||||
func (s *bidirectionalStreamStatus) waitAnyDone() {
|
||||
func (s *bidirectionalStreamStatus) wait(maxWaitForSecondStream time.Duration) error {
|
||||
<-s.doneChan
|
||||
|
||||
// Only wait for second stream to finish if maxWait is greater than zero
|
||||
if maxWaitForSecondStream > 0 {
|
||||
|
||||
timer := time.NewTimer(maxWaitForSecondStream)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
return fmt.Errorf("timeout waiting for second stream to finish")
|
||||
case <-s.doneChan:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func (s *bidirectionalStreamStatus) isAnyDone() bool {
|
||||
return atomic.LoadUint32(&s.anyDone) > 0
|
||||
|
@ -41,16 +87,28 @@ func (s *bidirectionalStreamStatus) isAnyDone() bool {
|
|||
|
||||
// Pipe copies copy data to & from provided io.ReadWriters.
|
||||
func Pipe(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) {
|
||||
status := newBiStreamStatus()
|
||||
|
||||
go unidirectionalStream(tunnelConn, originConn, "origin->tunnel", status, log)
|
||||
go unidirectionalStream(originConn, tunnelConn, "tunnel->origin", status, log)
|
||||
|
||||
// If one side is done, we are done.
|
||||
status.waitAnyDone()
|
||||
PipeBidirectional(NopCloseWriterAdapter(tunnelConn), NopCloseWriterAdapter(originConn), 0, log)
|
||||
}
|
||||
|
||||
func unidirectionalStream(dst io.Writer, src io.Reader, dir string, status *bidirectionalStreamStatus, log *zerolog.Logger) {
|
||||
// PipeBidirectional copies data two BidirectionStreams. It is a special case of Pipe where it receives a concept that allows for Read and Write side to be closed independently.
|
||||
// The main difference is that when piping data from a reader to a writer, if EOF is read, then this implementation propagates the EOF signal to the destination/writer by closing the write side of the
|
||||
// Bidirectional Stream.
|
||||
// Finally, depending on once EOF is ready from one of the provided streams, the other direction of streaming data will have a configured time period to also finish, otherwise,
|
||||
// the method will return immediately with a timeout error. It is however, the responsability of the caller to close the associated streams in both ends in order to free all the resources/go-routines.
|
||||
func PipeBidirectional(downstream, upstream Stream, maxWaitForSecondStream time.Duration, log *zerolog.Logger) error {
|
||||
status := newBiStreamStatus()
|
||||
|
||||
go unidirectionalStream(downstream, upstream, "upstream->downstream", status, log)
|
||||
go unidirectionalStream(upstream, downstream, "downstream->upstream", status, log)
|
||||
|
||||
if err := status.wait(maxWaitForSecondStream); err != nil {
|
||||
return errors.Wrap(err, "unable to wait for both streams while proxying")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func unidirectionalStream(dst WriterCloser, src Reader, dir string, status *bidirectionalStreamStatus, log *zerolog.Logger) {
|
||||
defer func() {
|
||||
// The bidirectional streaming spawns 2 goroutines to stream each direction.
|
||||
// If any ends, the callstack returns, meaning the Tunnel request/stream (depending on http2 vs quic) will
|
||||
|
@ -71,6 +129,8 @@ func unidirectionalStream(dst io.Writer, src io.Reader, dir string, status *bidi
|
|||
}
|
||||
}()
|
||||
|
||||
defer dst.CloseWrite()
|
||||
|
||||
_, err := copyData(dst, src, dir)
|
||||
if err != nil {
|
||||
log.Debug().Msgf("%s copy: %v", dir, err)
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
package stream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPipeBidirectionalFinishBothSides(t *testing.T) {
|
||||
fun := func(upstream, downstream *mockedStream) {
|
||||
downstream.closeReader()
|
||||
upstream.closeReader()
|
||||
}
|
||||
|
||||
testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, false)
|
||||
}
|
||||
|
||||
func TestPipeBidirectionalFinishOneSideTimeout(t *testing.T) {
|
||||
fun := func(upstream, downstream *mockedStream) {
|
||||
downstream.closeReader()
|
||||
}
|
||||
|
||||
testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, true)
|
||||
}
|
||||
|
||||
func TestPipeBidirectionalClosingWriteBothSidesAlsoExists(t *testing.T) {
|
||||
fun := func(upstream, downstream *mockedStream) {
|
||||
downstream.CloseWrite()
|
||||
upstream.CloseWrite()
|
||||
|
||||
downstream.writeToReader("abc")
|
||||
upstream.writeToReader("abc")
|
||||
}
|
||||
|
||||
testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, false)
|
||||
}
|
||||
|
||||
func TestPipeBidirectionalClosingWriteSingleSideAlsoExists(t *testing.T) {
|
||||
fun := func(upstream, downstream *mockedStream) {
|
||||
downstream.CloseWrite()
|
||||
|
||||
downstream.writeToReader("abc")
|
||||
upstream.writeToReader("abc")
|
||||
}
|
||||
|
||||
testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, true)
|
||||
}
|
||||
|
||||
func testPipeBidirectionalUnblocking(t *testing.T, afterFun func(*mockedStream, *mockedStream), timeout time.Duration, expectTimeout bool) {
|
||||
logger := zerolog.Nop()
|
||||
|
||||
downstream := newMockedStream()
|
||||
upstream := newMockedStream()
|
||||
|
||||
resultCh := make(chan error)
|
||||
go func() {
|
||||
resultCh <- PipeBidirectional(downstream, upstream, timeout, &logger)
|
||||
}()
|
||||
|
||||
afterFun(upstream, downstream)
|
||||
|
||||
select {
|
||||
case err := <-resultCh:
|
||||
if expectTimeout {
|
||||
require.NotNil(t, err)
|
||||
} else {
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
case <-time.After(timeout * 2):
|
||||
require.Fail(t, "test timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func newMockedStream() *mockedStream {
|
||||
return &mockedStream{
|
||||
readCh: make(chan *string),
|
||||
writeCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
type mockedStream struct {
|
||||
readCh chan *string
|
||||
writeCh chan struct{}
|
||||
|
||||
writeCloseOnce sync.Once
|
||||
}
|
||||
|
||||
func (m *mockedStream) Read(p []byte) (n int, err error) {
|
||||
result := <-m.readCh
|
||||
if result == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
return len(*result), nil
|
||||
}
|
||||
|
||||
func (m *mockedStream) Write(p []byte) (n int, err error) {
|
||||
<-m.writeCh
|
||||
|
||||
return 0, fmt.Errorf("closed")
|
||||
}
|
||||
|
||||
func (m *mockedStream) CloseWrite() error {
|
||||
m.writeCloseOnce.Do(func() {
|
||||
close(m.writeCh)
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockedStream) closeReader() {
|
||||
close(m.readCh)
|
||||
}
|
||||
func (m *mockedStream) writeToReader(content string) {
|
||||
m.readCh <- &content
|
||||
}
|
Loading…
Reference in New Issue