package h2mux

import (
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestShutdown fills the active stream map from many goroutines, requests a
// shutdown while the streams are still registered, and checks that the
// shutdown channel is closed only once every stream has been deleted.
// It also checks that a repeated Shutdown() call returns the same channel and
// reports that a shutdown is already in progress.
func TestShutdown(t *testing.T) {
	const numStreams = 1000
	m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))

	// Add all the streams
	{
		var wg sync.WaitGroup
		wg.Add(numStreams)
		for i := 0; i < numStreams; i++ {
			go func(streamID int) {
				defer wg.Done()
				stream := &MuxedStream{streamID: uint32(streamID)}
				ok := m.Set(stream)
				assert.True(t, ok)
			}(i)
		}
		wg.Wait()
	}
	assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")

	shutdownChan, alreadyInProgress := m.Shutdown()
	// Non-blocking receive: the channel must still be open because every
	// stream is still registered.
	select {
	case <-shutdownChan:
		assert.Fail(t, "after Shutdown() with streams still open, shutdownChan shouldn't be closed")
	default:
	}
	assert.False(t, alreadyInProgress)

	shutdownChan2, alreadyInProgress2 := m.Shutdown()
	assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
	assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")

	// Delete all the streams
	{
		var wg sync.WaitGroup
		wg.Add(numStreams)
		for i := 0; i < numStreams; i++ {
			go func(streamID int) {
				defer wg.Done()
				m.Delete(uint32(streamID))
			}(i)
		}
		wg.Wait()
	}
	assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")

	// Non-blocking receive: by now the channel must have been closed.
	select {
	case <-shutdownChan:
	default:
		assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
	}
}

// TestEmptyBeforeShutdown drains the map to size zero before any shutdown has
// been requested, adds a single stream back, and only then calls Shutdown().
// The shutdown channel must stay open while that stream is registered and
// must be closed once it is deleted.
func TestEmptyBeforeShutdown(t *testing.T) {
	const numStreams = 1000
	m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))

	// Add all the streams
	{
		var wg sync.WaitGroup
		wg.Add(numStreams)
		for i := 0; i < numStreams; i++ {
			go func(streamID int) {
				defer wg.Done()
				stream := &MuxedStream{streamID: uint32(streamID)}
				ok := m.Set(stream)
				assert.True(t, ok)
			}(i)
		}
		wg.Wait()
	}
	assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")

	// Delete all the streams, bringing m to size 0
	{
		var wg sync.WaitGroup
		wg.Add(numStreams)
		for i := 0; i < numStreams; i++ {
			go func(streamID int) {
				defer wg.Done()
				m.Delete(uint32(streamID))
			}(i)
		}
		wg.Wait()
	}
	assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")

	// Add one stream back
	const soloStreamID = uint32(0)
	ok := m.Set(&MuxedStream{streamID: soloStreamID})
	assert.True(t, ok)

	shutdownChan, alreadyInProgress := m.Shutdown()
	// Non-blocking receive: one stream is still registered, so the channel
	// must still be open.
	select {
	case <-shutdownChan:
		assert.Fail(t, "after Shutdown() with streams still open, shutdownChan shouldn't be closed")
	default:
	}
	assert.False(t, alreadyInProgress)

	shutdownChan2, alreadyInProgress2 := m.Shutdown()
	assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
	assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")

	// Remove the remaining stream
	m.Delete(soloStreamID)

	select {
	case <-shutdownChan:
	default:
		assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
	}
}

// noopBuffer is a test double for a stream buffer. It stores no data; the
// only state it keeps is whether Close has been called.
type noopBuffer struct {
	isClosed bool
}

// Read reports len(p) bytes read without modifying p.
func (b *noopBuffer) Read(p []byte) (n int, err error) { return len(p), nil }

// Write reports len(p) bytes written and discards them.
func (b *noopBuffer) Write(p []byte) (n int, err error) { return len(p), nil }

// Reset does nothing.
func (b *noopBuffer) Reset() {}

// Len always returns 0.
func (b *noopBuffer) Len() int { return 0 }

// Close marks the buffer as closed and never fails.
func (b *noopBuffer) Close() error { b.isClosed = true; return nil }

// Closed reports whether Close has been called.
func (b *noopBuffer) Closed() bool { return b.isClosed }

// noopReadyList is a test double whose Signal method does nothing.
type noopReadyList struct{}

// Signal ignores the stream ID.
func (*noopReadyList) Signal(streamID uint32) {}

// TestAbort registers streams backed by noopBuffers, requests a shutdown, and
// then calls Abort(). It checks that Abort() leaves every stream in the map,
// closes each stream's read and write buffers, and closes the shutdown
// channel. Abort() is then called several more times in a row.
func TestAbort(t *testing.T) {
	const numStreams = 1000
	m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))

	// Keeps a handle on every stream so its buffers can be inspected after
	// Abort(); sync.Map because the streams are stored from many goroutines.
	var openedStreams sync.Map

	// Add all the streams
	{
		var wg sync.WaitGroup
		wg.Add(numStreams)
		for i := 0; i < numStreams; i++ {
			go func(streamID int) {
				defer wg.Done()
				stream := &MuxedStream{
					streamID:    uint32(streamID),
					readBuffer:  &noopBuffer{},
					writeBuffer: &noopBuffer{},
					readyList:   &noopReadyList{},
				}
				ok := m.Set(stream)
				assert.True(t, ok)

				openedStreams.Store(stream.streamID, stream)
			}(i)
		}
		wg.Wait()
	}
	assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")

	shutdownChan, alreadyInProgress := m.Shutdown()
	select {
	case <-shutdownChan:
		assert.Fail(t, "before Abort(), shutdownChan shouldn't be closed")
	default:
	}
	assert.False(t, alreadyInProgress)

	m.Abort()
	assert.Equal(t, numStreams, m.Len(), "Abort() shouldn't delete any streams")
	// Returning the assertion result stops the iteration at the first stream
	// whose buffers were not closed.
	openedStreams.Range(func(key interface{}, value interface{}) bool {
		stream := value.(*MuxedStream)
		readBuffer := stream.readBuffer.(*noopBuffer)
		writeBuffer := stream.writeBuffer.(*noopBuffer)
		return assert.True(t, readBuffer.isClosed && writeBuffer.isClosed, "Abort() should have closed all the streams")
	})

	select {
	case <-shutdownChan:
	default:
		assert.Fail(t, "after Abort(), shutdownChan should have been closed")
	}

	// multiple aborts shouldn't cause any issues
	m.Abort()
	m.Abort()
	m.Abort()
}