diff --git a/h2mux/activestreammap.go b/h2mux/activestreammap.go index 1138bea4..15203423 100644 --- a/h2mux/activestreammap.go +++ b/h2mux/activestreammap.go @@ -43,6 +43,7 @@ func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Ga return m } +// This function should be called while `m` is locked. func (m *activeStreamMap) notifyStreamsEmpty() { m.closeOnce.Do(func() { close(m.streamsEmptyChan) @@ -87,7 +88,9 @@ func (m *activeStreamMap) Delete(streamID uint32) { delete(m.streams, streamID) m.activeStreams.Dec() } - if len(m.streams) == 0 { + + // shutting down, and now the map is empty + if m.ignoreNewStreams && len(m.streams) == 0 { m.notifyStreamsEmpty() } } @@ -104,7 +107,7 @@ func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bo } m.ignoreNewStreams = true if len(m.streams) == 0 { - // nothing to shut down + // there are no streams to wait for m.notifyStreamsEmpty() } return m.streamsEmptyChan, false diff --git a/h2mux/activestreammap_test.go b/h2mux/activestreammap_test.go index 5f7cd2cc..f961bcaf 100644 --- a/h2mux/activestreammap_test.go +++ b/h2mux/activestreammap_test.go @@ -60,6 +60,67 @@ func TestShutdown(t *testing.T) { } } +func TestEmptyBeforeShutdown(t *testing.T) { + const numStreams = 1000 + m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name())) + + // Add all the streams + { + var wg sync.WaitGroup + wg.Add(numStreams) + for i := 0; i < numStreams; i++ { + go func(streamID int) { + defer wg.Done() + stream := &MuxedStream{streamID: uint32(streamID)} + ok := m.Set(stream) + assert.True(t, ok) + }(i) + } + wg.Wait() + } + assert.Equal(t, numStreams, m.Len(), "All the streams should have been added") + + // Delete all the streams, bringing m to size 0 + { + var wg sync.WaitGroup + wg.Add(numStreams) + for i := 0; i < numStreams; i++ { + go func(streamID int) { + defer wg.Done() + m.Delete(uint32(streamID)) + }(i) + } + wg.Wait() + } + assert.Equal(t, 0, m.Len(), "All the streams should have been deleted") + + // Add one stream back + const soloStreamID = uint32(0) + ok := m.Set(&MuxedStream{streamID: soloStreamID}) + assert.True(t, ok) + + shutdownChan, alreadyInProgress := m.Shutdown() + select { + case <-shutdownChan: + assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed") + default: + } + assert.False(t, alreadyInProgress) + + shutdownChan2, alreadyInProgress2 := m.Shutdown() + assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel") + assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'") + + // Remove the remaining stream + m.Delete(soloStreamID) + + select { + case <-shutdownChan: + default: + assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed") + } +} + type noopBuffer struct { isClosed bool }