diff --git a/stream/stream.go b/stream/stream.go index 3e5d8569..c9491ec6 100644 --- a/stream/stream.go +++ b/stream/stream.go @@ -15,6 +15,36 @@ import ( "github.com/cloudflare/cloudflared/cfio" ) +type Stream interface { + Reader + WriterCloser +} + +type Reader interface { + io.Reader +} + +type WriterCloser interface { + io.Writer + WriteCloser +} + +type WriteCloser interface { + CloseWrite() error +} + +type nopCloseWriterAdapter struct { + io.ReadWriter +} + +func NopCloseWriterAdapter(stream io.ReadWriter) *nopCloseWriterAdapter { + return &nopCloseWriterAdapter{stream} +} + +func (n *nopCloseWriterAdapter) CloseWrite() error { + return nil +} + type bidirectionalStreamStatus struct { doneChan chan struct{} anyDone uint32 @@ -32,8 +62,24 @@ func (s *bidirectionalStreamStatus) markUniStreamDone() { s.doneChan <- struct{}{} } -func (s *bidirectionalStreamStatus) waitAnyDone() { +func (s *bidirectionalStreamStatus) wait(maxWaitForSecondStream time.Duration) error { <-s.doneChan + + // Only wait for second stream to finish if maxWait is greater than zero + if maxWaitForSecondStream > 0 { + + timer := time.NewTimer(maxWaitForSecondStream) + defer timer.Stop() + + select { + case <-timer.C: + return fmt.Errorf("timeout waiting for second stream to finish") + case <-s.doneChan: + return nil + } + } + + return nil } func (s *bidirectionalStreamStatus) isAnyDone() bool { return atomic.LoadUint32(&s.anyDone) > 0 @@ -41,16 +87,28 @@ func (s *bidirectionalStreamStatus) isAnyDone() bool { // Pipe copies copy data to & from provided io.ReadWriters. func Pipe(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) { - status := newBiStreamStatus() - - go unidirectionalStream(tunnelConn, originConn, "origin->tunnel", status, log) - go unidirectionalStream(originConn, tunnelConn, "tunnel->origin", status, log) - - // If one side is done, we are done. - status.waitAnyDone() + PipeBidirectional(NopCloseWriterAdapter(tunnelConn), NopCloseWriterAdapter(originConn), 0, log) } -func unidirectionalStream(dst io.Writer, src io.Reader, dir string, status *bidirectionalStreamStatus, log *zerolog.Logger) { +// PipeBidirectional copies data two BidirectionStreams. It is a special case of Pipe where it receives a concept that allows for Read and Write side to be closed independently. +// The main difference is that when piping data from a reader to a writer, if EOF is read, then this implementation propagates the EOF signal to the destination/writer by closing the write side of the +// Bidirectional Stream. +// Finally, depending on once EOF is ready from one of the provided streams, the other direction of streaming data will have a configured time period to also finish, otherwise, +// the method will return immediately with a timeout error. It is however, the responsability of the caller to close the associated streams in both ends in order to free all the resources/go-routines. +func PipeBidirectional(downstream, upstream Stream, maxWaitForSecondStream time.Duration, log *zerolog.Logger) error { + status := newBiStreamStatus() + + go unidirectionalStream(downstream, upstream, "upstream->downstream", status, log) + go unidirectionalStream(upstream, downstream, "downstream->upstream", status, log) + + if err := status.wait(maxWaitForSecondStream); err != nil { + return errors.Wrap(err, "unable to wait for both streams while proxying") + } + + return nil +} + +func unidirectionalStream(dst WriterCloser, src Reader, dir string, status *bidirectionalStreamStatus, log *zerolog.Logger) { defer func() { // The bidirectional streaming spawns 2 goroutines to stream each direction. // If any ends, the callstack returns, meaning the Tunnel request/stream (depending on http2 vs quic) will @@ -71,6 +129,8 @@ func unidirectionalStream(dst io.Writer, src io.Reader, dir string, status *bidi } }() + defer dst.CloseWrite() + _, err := copyData(dst, src, dir) if err != nil { log.Debug().Msgf("%s copy: %v", dir, err) diff --git a/stream/stream_test.go b/stream/stream_test.go new file mode 100644 index 00000000..6db372b3 --- /dev/null +++ b/stream/stream_test.go @@ -0,0 +1,122 @@ +package stream + +import ( + "fmt" + "io" + "sync" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" +) + +func TestPipeBidirectionalFinishBothSides(t *testing.T) { + fun := func(upstream, downstream *mockedStream) { + downstream.closeReader() + upstream.closeReader() + } + + testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, false) +} + +func TestPipeBidirectionalFinishOneSideTimeout(t *testing.T) { + fun := func(upstream, downstream *mockedStream) { + downstream.closeReader() + } + + testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, true) +} + +func TestPipeBidirectionalClosingWriteBothSidesAlsoExists(t *testing.T) { + fun := func(upstream, downstream *mockedStream) { + downstream.CloseWrite() + upstream.CloseWrite() + + downstream.writeToReader("abc") + upstream.writeToReader("abc") + } + + testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, false) +} + +func TestPipeBidirectionalClosingWriteSingleSideAlsoExists(t *testing.T) { + fun := func(upstream, downstream *mockedStream) { + downstream.CloseWrite() + + downstream.writeToReader("abc") + upstream.writeToReader("abc") + } + + testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, true) +} + +func testPipeBidirectionalUnblocking(t *testing.T, afterFun func(*mockedStream, *mockedStream), timeout time.Duration, expectTimeout bool) { + logger := zerolog.Nop() + + downstream := newMockedStream() + upstream := newMockedStream() + + resultCh := make(chan error) + go func() { + resultCh <- PipeBidirectional(downstream, upstream, timeout, &logger) + }() + + afterFun(upstream, downstream) + + select { + case err := <-resultCh: + if expectTimeout { + require.NotNil(t, err) + } else { + require.Nil(t, err) + } + + case <-time.After(timeout * 2): + require.Fail(t, "test timeout") + } +} + +func newMockedStream() *mockedStream { + return &mockedStream{ + readCh: make(chan *string), + writeCh: make(chan struct{}), + } +} + +type mockedStream struct { + readCh chan *string + writeCh chan struct{} + + writeCloseOnce sync.Once +} + +func (m *mockedStream) Read(p []byte) (n int, err error) { + result := <-m.readCh + if result == nil { + return 0, io.EOF + } + + return len(*result), nil +} + +func (m *mockedStream) Write(p []byte) (n int, err error) { + <-m.writeCh + + return 0, fmt.Errorf("closed") +} + +func (m *mockedStream) CloseWrite() error { + m.writeCloseOnce.Do(func() { + close(m.writeCh) + }) + + return nil +} + +func (m *mockedStream) closeReader() { + close(m.readCh) +} +func (m *mockedStream) writeToReader(content string) { + m.readCh <- &content +}