From b499c0fdbad0b3cfb3035b4cacba423804909673 Mon Sep 17 00:00:00 2001 From: Nick Vollmar Date: Tue, 3 Dec 2019 15:01:28 -0600 Subject: [PATCH] TUN-2608: h2mux.Muxer.Shutdown always returns a non-nil channel --- h2mux/activestreammap.go | 46 +++++++----- h2mux/activestreammap_test.go | 134 ++++++++++++++++++++++++++++++++++ h2mux/h2mux.go | 6 +- h2mux/h2mux_test.go | 8 ++ h2mux/muxedstream.go | 12 ++- h2mux/muxreader.go | 11 ++- 6 files changed, 188 insertions(+), 29 deletions(-) create mode 100644 h2mux/activestreammap_test.go diff --git a/h2mux/activestreammap.go b/h2mux/activestreammap.go index a15bee89..1138bea4 100644 --- a/h2mux/activestreammap.go +++ b/h2mux/activestreammap.go @@ -13,26 +13,28 @@ type activeStreamMap struct { sync.RWMutex // streams tracks open streams. streams map[uint32]*MuxedStream - // streamsEmpty is a chan that should be closed when no more streams are open. - streamsEmpty chan struct{} // nextStreamID is the next ID to use on our side of the connection. // This is odd for clients, even for servers. nextStreamID uint32 // maxPeerStreamID is the ID of the most recent stream opened by the peer. maxPeerStreamID uint32 + // activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams + activeStreams prometheus.Gauge + // ignoreNewStreams is true when the connection is being shut down. New streams // cannot be registered. ignoreNewStreams bool - // activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams - activeStreams prometheus.Gauge + // streamsEmpty is a chan that will be closed when no more streams are open. + streamsEmptyChan chan struct{} + closeOnce sync.Once } func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap { m := &activeStreamMap{ - streams: make(map[uint32]*MuxedStream), - streamsEmpty: make(chan struct{}), - nextStreamID: 1, - activeStreams: activeStreams, + streams: make(map[uint32]*MuxedStream), + streamsEmptyChan: make(chan struct{}), + nextStreamID: 1, + activeStreams: activeStreams, } // Client initiated stream uses odd stream ID, server initiated stream uses even stream ID if !useClientStreamNumbers { @@ -41,6 +43,12 @@ func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Ga return m } +func (m *activeStreamMap) notifyStreamsEmpty() { + m.closeOnce.Do(func() { + close(m.streamsEmptyChan) + }) +} + // Len returns the number of active streams. func (m *activeStreamMap) Len() int { m.RLock() @@ -79,30 +87,27 @@ func (m *activeStreamMap) Delete(streamID uint32) { delete(m.streams, streamID) m.activeStreams.Dec() } - if len(m.streams) == 0 && m.streamsEmpty != nil { - close(m.streamsEmpty) - m.streamsEmpty = nil + if len(m.streams) == 0 { + m.notifyStreamsEmpty() } } -// Shutdown blocks new streams from being created. It returns a channel that receives an event -// once the last stream has closed, or nil if a shutdown is in progress. -func (m *activeStreamMap) Shutdown() <-chan struct{} { +// Shutdown blocks new streams from being created. 
+// It returns `done`, a channel that is closed once the last stream has closed +// and `progress`, whether a shutdown was already in progress +func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bool) { m.Lock() defer m.Unlock() if m.ignoreNewStreams { // already shutting down - return nil + return m.streamsEmptyChan, true } m.ignoreNewStreams = true - done := make(chan struct{}) if len(m.streams) == 0 { // nothing to shut down - close(done) - return done + m.notifyStreamsEmpty() } - m.streamsEmpty = done - return done + return m.streamsEmptyChan, false } // AcquireLocalID acquires a new stream ID for a stream you're opening. @@ -170,4 +175,5 @@ func (m *activeStreamMap) Abort() { stream.Close() } m.ignoreNewStreams = true + m.notifyStreamsEmpty() } diff --git a/h2mux/activestreammap_test.go b/h2mux/activestreammap_test.go new file mode 100644 index 00000000..5f7cd2cc --- /dev/null +++ b/h2mux/activestreammap_test.go @@ -0,0 +1,134 @@ +package h2mux + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestShutdown(t *testing.T) { + const numStreams = 1000 + m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name())) + + // Add all the streams + { + var wg sync.WaitGroup + wg.Add(numStreams) + for i := 0; i < numStreams; i++ { + go func(streamID int) { + defer wg.Done() + stream := &MuxedStream{streamID: uint32(streamID)} + ok := m.Set(stream) + assert.True(t, ok) + }(i) + } + wg.Wait() + } + assert.Equal(t, numStreams, m.Len(), "All the streams should have been added") + + shutdownChan, alreadyInProgress := m.Shutdown() + select { + case <-shutdownChan: + assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed") + default: + } + assert.False(t, alreadyInProgress) + + shutdownChan2, alreadyInProgress2 := m.Shutdown() + assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel") + assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'") + + // Delete all the streams + { + var wg sync.WaitGroup + wg.Add(numStreams) + for i := 0; i < numStreams; i++ { + go func(streamID int) { + defer wg.Done() + m.Delete(uint32(streamID)) + }(i) + } + wg.Wait() + } + assert.Equal(t, 0, m.Len(), "All the streams should have been deleted") + + select { + case <-shutdownChan: + default: + assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed") + } +} + +type noopBuffer struct { + isClosed bool +} + +func (t *noopBuffer) Read(p []byte) (n int, err error) { return len(p), nil } +func (t *noopBuffer) Write(p []byte) (n int, err error) { return len(p), nil } +func (t *noopBuffer) Reset() {} +func (t *noopBuffer) Len() int { return 0 } +func (t *noopBuffer) Close() error { t.isClosed = true; return nil } +func (t *noopBuffer) Closed() bool { return t.isClosed } + +type noopReadyList struct{} + +func (_ *noopReadyList) Signal(streamID uint32) {} + +func TestAbort(t *testing.T) { + const numStreams = 1000 + m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name())) + + var openedStreams sync.Map + + // Add all the streams + { + var wg sync.WaitGroup + wg.Add(numStreams) + for i := 0; i < numStreams; i++ { + go func(streamID int) { + defer wg.Done() + stream := &MuxedStream{ + streamID: uint32(streamID), + readBuffer: &noopBuffer{}, + writeBuffer: &noopBuffer{}, + readyList: &noopReadyList{}, + } + ok := m.Set(stream) + assert.True(t, ok) + + openedStreams.Store(stream.streamID, stream) + }(i) + } 
+ wg.Wait() + } + assert.Equal(t, numStreams, m.Len(), "All the streams should have been added") + + shutdownChan, alreadyInProgress := m.Shutdown() + select { + case <-shutdownChan: + assert.Fail(t, "before Abort(), shutdownChan shouldn't be closed") + default: + } + assert.False(t, alreadyInProgress) + + m.Abort() + assert.Equal(t, numStreams, m.Len(), "Abort() shouldn't delete any streams") + openedStreams.Range(func(key interface{}, value interface{}) bool { + stream := value.(*MuxedStream) + readBuffer := stream.readBuffer.(*noopBuffer) + writeBuffer := stream.writeBuffer.(*noopBuffer) + return assert.True(t, readBuffer.isClosed && writeBuffer.isClosed, "Abort() should have closed all the streams") + }) + + select { + case <-shutdownChan: + default: + assert.Fail(t, "after Abort(), shutdownChan should have been closed") + } + + // multiple aborts shouldn't cause any issues + m.Abort() + m.Abort() + m.Abort() +} diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index 59f722cc..8a3330d3 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -353,9 +353,11 @@ func (m *Muxer) Serve(ctx context.Context) error { } // Shutdown is called to initiate the "happy path" of muxer termination. -func (m *Muxer) Shutdown() { +// It blocks new streams from being created. +// It returns a channel that is closed when the last stream has been closed. +func (m *Muxer) Shutdown() <-chan struct{} { m.explicitShutdown.Fuse(true) - m.muxReader.Shutdown() + return m.muxReader.Shutdown() } // IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel. diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go index 9b9ce13c..b7995232 100644 --- a/h2mux/h2mux_test.go +++ b/h2mux/h2mux_test.go @@ -55,6 +55,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc) DefaultWindowSize: (1 << 8) - 1, MaxWindowSize: (1 << 15) - 1, StreamWriteBufferMaxLen: 1024, + HeartbeatInterval: defaultTimeout, + MaxHeartbeats: defaultRetries, }, OriginConn: origin, EdgeMuxConfig: MuxerConfig{ @@ -65,6 +67,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc) DefaultWindowSize: (1 << 8) - 1, MaxWindowSize: (1 << 15) - 1, StreamWriteBufferMaxLen: 1024, + HeartbeatInterval: defaultTimeout, + MaxHeartbeats: defaultRetries, }, EdgeConn: edge, doneC: make(chan struct{}), @@ -83,6 +87,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress Name: "origin", CompressionQuality: quality, Logger: log.NewEntry(log.New()), + HeartbeatInterval: defaultTimeout, + MaxHeartbeats: defaultRetries, }, OriginConn: origin, EdgeMuxConfig: MuxerConfig{ @@ -91,6 +97,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress Name: "edge", CompressionQuality: quality, Logger: log.NewEntry(log.New()), + HeartbeatInterval: defaultTimeout, + MaxHeartbeats: defaultRetries, }, EdgeConn: edge, doneC: make(chan struct{}), diff --git a/h2mux/muxedstream.go b/h2mux/muxedstream.go index 6bafa19d..a37270cc 100644 --- a/h2mux/muxedstream.go +++ b/h2mux/muxedstream.go @@ -17,6 +17,12 @@ type ReadWriteClosedCloser interface { Closed() bool } +// MuxedStreamDataSignaller is a write-only *ReadyList +type MuxedStreamDataSignaller interface { + // Non-blocking: call this when data is ready to be sent for the given stream ID. + Signal(ID uint32) +} + // MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data. 
type MuxedStream struct { streamID uint32 @@ -55,8 +61,8 @@ type MuxedStream struct { // This is the amount of bytes that are in the peer's receive window // (how much data we can send from this stream). sendWindow uint32 - // Reference to the muxer's readyList; signal this for stream data to be sent. - readyList *ReadyList + // The muxer's readyList + readyList MuxedStreamDataSignaller // The headers that should be sent, and a flag so we only send them once. headersSent bool writeHeaders []Header @@ -88,7 +94,7 @@ func (th TunnelHostname) IsSet() bool { return th != "" } -func NewStream(config MuxerConfig, writeHeaders []Header, readyList *ReadyList, dictionaries h2Dictionaries) *MuxedStream { +func NewStream(config MuxerConfig, writeHeaders []Header, readyList MuxedStreamDataSignaller, dictionaries h2Dictionaries) *MuxedStream { return &MuxedStream{ responseHeadersReceived: make(chan struct{}), readBuffer: NewSharedBuffer(), diff --git a/h2mux/muxreader.go b/h2mux/muxreader.go index 728c94c4..c9b4dff7 100644 --- a/h2mux/muxreader.go +++ b/h2mux/muxreader.go @@ -51,10 +51,12 @@ type MuxReader struct { dictionaries h2Dictionaries } -func (r *MuxReader) Shutdown() { - done := r.streams.Shutdown() - if done == nil { - return +// Shutdown blocks new streams from being created. +// It returns a channel that is closed once the last stream has closed. +func (r *MuxReader) Shutdown() <-chan struct{} { + done, alreadyInProgress := r.streams.Shutdown() + if alreadyInProgress { + return done } r.sendGoAway(http2.ErrCodeNo) go func() { @@ -62,6 +64,7 @@ func (r *MuxReader) Shutdown() { <-done r.r.Close() }() + return done } func (r *MuxReader) run(parentLogger *log.Entry) error {
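
With this change, callers no longer need to special-case a nil return from Muxer.Shutdown(). The sketch below shows one way a caller might drain a muxer under a deadline; the helper name, the timeout handling, and the cloudflared import path are illustrative assumptions, not code contained in this patch.

package main

import (
	"errors"
	"time"

	"github.com/cloudflare/cloudflared/h2mux"
)

// drainAndClose initiates the muxer's "happy path" shutdown and waits on the
// channel it now returns, which is closed once the last open stream has closed.
// The channel is always non-nil, even on repeated Shutdown() calls, so the
// select below is safe without a nil check. The timeout guards against streams
// that never finish.
func drainAndClose(mux *h2mux.Muxer, timeout time.Duration) error {
	done := mux.Shutdown() // blocks new streams from being created
	select {
	case <-done:
		return nil // all streams drained
	case <-time.After(timeout):
		return errors.New("h2mux: timed out waiting for open streams to drain")
	}
}

func main() {}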