package stream import ( "fmt" "io" "sync" "testing" "time" "github.com/rs/zerolog" "github.com/stretchr/testify/require" ) func TestPipeBidirectionalFinishBothSides(t *testing.T) { fun := func(upstream, downstream *mockedStream) { downstream.closeReader() upstream.closeReader() } testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, false) } func TestPipeBidirectionalFinishOneSideTimeout(t *testing.T) { fun := func(upstream, downstream *mockedStream) { downstream.closeReader() } testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, true) } func TestPipeBidirectionalClosingWriteBothSidesAlsoExists(t *testing.T) { fun := func(upstream, downstream *mockedStream) { downstream.CloseWrite() upstream.CloseWrite() downstream.writeToReader("abc") upstream.writeToReader("abc") } testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, false) } func TestPipeBidirectionalClosingWriteSingleSideAlsoExists(t *testing.T) { fun := func(upstream, downstream *mockedStream) { downstream.CloseWrite() downstream.writeToReader("abc") upstream.writeToReader("abc") } testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, true) } func testPipeBidirectionalUnblocking(t *testing.T, afterFun func(*mockedStream, *mockedStream), timeout time.Duration, expectTimeout bool) { logger := zerolog.Nop() downstream := newMockedStream() upstream := newMockedStream() resultCh := make(chan error) go func() { resultCh <- PipeBidirectional(downstream, upstream, timeout, &logger) }() afterFun(upstream, downstream) select { case err := <-resultCh: if expectTimeout { require.NotNil(t, err) } else { require.Nil(t, err) } case <-time.After(timeout * 2): require.Fail(t, "test timeout") } } func newMockedStream() *mockedStream { return &mockedStream{ readCh: make(chan *string), writeCh: make(chan struct{}), } } type mockedStream struct { readCh chan *string writeCh chan struct{} writeCloseOnce sync.Once } func (m *mockedStream) Read(p []byte) (n int, err error) { result := <-m.readCh if result == nil { return 0, io.EOF } return len(*result), nil } func (m *mockedStream) Write(p []byte) (n int, err error) { <-m.writeCh return 0, fmt.Errorf("closed") } func (m *mockedStream) CloseWrite() error { m.writeCloseOnce.Do(func() { close(m.writeCh) }) return nil } func (m *mockedStream) closeReader() { close(m.readCh) } func (m *mockedStream) writeToReader(content string) { m.readCh <- &content }