TUN-2608: h2mux.Muxer.Shutdown always returns a non-nil channel

This commit is contained in:
Nick Vollmar 2019-12-03 15:01:28 -06:00
parent bbf31377c2
commit b499c0fdba
6 changed files with 188 additions and 29 deletions

View File

@ -13,26 +13,28 @@ type activeStreamMap struct {
sync.RWMutex
// streams tracks open streams.
streams map[uint32]*MuxedStream
// streamsEmpty is a chan that should be closed when no more streams are open.
streamsEmpty chan struct{}
// nextStreamID is the next ID to use on our side of the connection.
// This is odd for clients, even for servers.
nextStreamID uint32
// maxPeerStreamID is the ID of the most recent stream opened by the peer.
maxPeerStreamID uint32
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
activeStreams prometheus.Gauge
// ignoreNewStreams is true when the connection is being shut down. New streams
// cannot be registered.
ignoreNewStreams bool
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
activeStreams prometheus.Gauge
// streamsEmpty is a chan that will be closed when no more streams are open.
streamsEmptyChan chan struct{}
closeOnce sync.Once
}
func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
m := &activeStreamMap{
streams: make(map[uint32]*MuxedStream),
streamsEmpty: make(chan struct{}),
nextStreamID: 1,
activeStreams: activeStreams,
streams: make(map[uint32]*MuxedStream),
streamsEmptyChan: make(chan struct{}),
nextStreamID: 1,
activeStreams: activeStreams,
}
// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
if !useClientStreamNumbers {
@ -41,6 +43,12 @@ func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Ga
return m
}
func (m *activeStreamMap) notifyStreamsEmpty() {
m.closeOnce.Do(func() {
close(m.streamsEmptyChan)
})
}
// Len returns the number of active streams.
func (m *activeStreamMap) Len() int {
m.RLock()
@ -79,30 +87,27 @@ func (m *activeStreamMap) Delete(streamID uint32) {
delete(m.streams, streamID)
m.activeStreams.Dec()
}
if len(m.streams) == 0 && m.streamsEmpty != nil {
close(m.streamsEmpty)
m.streamsEmpty = nil
if len(m.streams) == 0 {
m.notifyStreamsEmpty()
}
}
// Shutdown blocks new streams from being created. It returns a channel that receives an event
// once the last stream has closed, or nil if a shutdown is in progress.
func (m *activeStreamMap) Shutdown() <-chan struct{} {
// Shutdown blocks new streams from being created.
// It returns `done`, a channel that is closed once the last stream has closed
// and `progress`, whether a shutdown was already in progress
func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bool) {
m.Lock()
defer m.Unlock()
if m.ignoreNewStreams {
// already shutting down
return nil
return m.streamsEmptyChan, true
}
m.ignoreNewStreams = true
done := make(chan struct{})
if len(m.streams) == 0 {
// nothing to shut down
close(done)
return done
m.notifyStreamsEmpty()
}
m.streamsEmpty = done
return done
return m.streamsEmptyChan, false
}
// AcquireLocalID acquires a new stream ID for a stream you're opening.
@ -170,4 +175,5 @@ func (m *activeStreamMap) Abort() {
stream.Close()
}
m.ignoreNewStreams = true
m.notifyStreamsEmpty()
}

View File

@ -0,0 +1,134 @@
package h2mux
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestShutdown(t *testing.T) {
const numStreams = 1000
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
// Add all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
stream := &MuxedStream{streamID: uint32(streamID)}
ok := m.Set(stream)
assert.True(t, ok)
}(i)
}
wg.Wait()
}
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
shutdownChan, alreadyInProgress := m.Shutdown()
select {
case <-shutdownChan:
assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed")
default:
}
assert.False(t, alreadyInProgress)
shutdownChan2, alreadyInProgress2 := m.Shutdown()
assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")
// Delete all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
m.Delete(uint32(streamID))
}(i)
}
wg.Wait()
}
assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")
select {
case <-shutdownChan:
default:
assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
}
}
type noopBuffer struct {
isClosed bool
}
func (t *noopBuffer) Read(p []byte) (n int, err error) { return len(p), nil }
func (t *noopBuffer) Write(p []byte) (n int, err error) { return len(p), nil }
func (t *noopBuffer) Reset() {}
func (t *noopBuffer) Len() int { return 0 }
func (t *noopBuffer) Close() error { t.isClosed = true; return nil }
func (t *noopBuffer) Closed() bool { return t.isClosed }
type noopReadyList struct{}
func (_ *noopReadyList) Signal(streamID uint32) {}
func TestAbort(t *testing.T) {
const numStreams = 1000
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
var openedStreams sync.Map
// Add all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
stream := &MuxedStream{
streamID: uint32(streamID),
readBuffer: &noopBuffer{},
writeBuffer: &noopBuffer{},
readyList: &noopReadyList{},
}
ok := m.Set(stream)
assert.True(t, ok)
openedStreams.Store(stream.streamID, stream)
}(i)
}
wg.Wait()
}
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
shutdownChan, alreadyInProgress := m.Shutdown()
select {
case <-shutdownChan:
assert.Fail(t, "before Abort(), shutdownChan shouldn't be closed")
default:
}
assert.False(t, alreadyInProgress)
m.Abort()
assert.Equal(t, numStreams, m.Len(), "Abort() shouldn't delete any streams")
openedStreams.Range(func(key interface{}, value interface{}) bool {
stream := value.(*MuxedStream)
readBuffer := stream.readBuffer.(*noopBuffer)
writeBuffer := stream.writeBuffer.(*noopBuffer)
return assert.True(t, readBuffer.isClosed && writeBuffer.isClosed, "Abort() should have closed all the streams")
})
select {
case <-shutdownChan:
default:
assert.Fail(t, "after Abort(), shutdownChan should have been closed")
}
// multiple aborts shouldn't cause any issues
m.Abort()
m.Abort()
m.Abort()
}

View File

@ -353,9 +353,11 @@ func (m *Muxer) Serve(ctx context.Context) error {
}
// Shutdown is called to initiate the "happy path" of muxer termination.
func (m *Muxer) Shutdown() {
// It blocks new streams from being created.
// It returns a channel that is closed when the last stream has been closed.
func (m *Muxer) Shutdown() <-chan struct{} {
m.explicitShutdown.Fuse(true)
m.muxReader.Shutdown()
return m.muxReader.Shutdown()
}
// IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.

View File

@ -55,6 +55,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
DefaultWindowSize: (1 << 8) - 1,
MaxWindowSize: (1 << 15) - 1,
StreamWriteBufferMaxLen: 1024,
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
},
OriginConn: origin,
EdgeMuxConfig: MuxerConfig{
@ -65,6 +67,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
DefaultWindowSize: (1 << 8) - 1,
MaxWindowSize: (1 << 15) - 1,
StreamWriteBufferMaxLen: 1024,
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
},
EdgeConn: edge,
doneC: make(chan struct{}),
@ -83,6 +87,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
Name: "origin",
CompressionQuality: quality,
Logger: log.NewEntry(log.New()),
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
},
OriginConn: origin,
EdgeMuxConfig: MuxerConfig{
@ -91,6 +97,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
Name: "edge",
CompressionQuality: quality,
Logger: log.NewEntry(log.New()),
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
},
EdgeConn: edge,
doneC: make(chan struct{}),

View File

@ -17,6 +17,12 @@ type ReadWriteClosedCloser interface {
Closed() bool
}
// MuxedStreamDataSignaller is a write-only *ReadyList
type MuxedStreamDataSignaller interface {
// Non-blocking: call this when data is ready to be sent for the given stream ID.
Signal(ID uint32)
}
// MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data.
type MuxedStream struct {
streamID uint32
@ -55,8 +61,8 @@ type MuxedStream struct {
// This is the amount of bytes that are in the peer's receive window
// (how much data we can send from this stream).
sendWindow uint32
// Reference to the muxer's readyList; signal this for stream data to be sent.
readyList *ReadyList
// The muxer's readyList
readyList MuxedStreamDataSignaller
// The headers that should be sent, and a flag so we only send them once.
headersSent bool
writeHeaders []Header
@ -88,7 +94,7 @@ func (th TunnelHostname) IsSet() bool {
return th != ""
}
func NewStream(config MuxerConfig, writeHeaders []Header, readyList *ReadyList, dictionaries h2Dictionaries) *MuxedStream {
func NewStream(config MuxerConfig, writeHeaders []Header, readyList MuxedStreamDataSignaller, dictionaries h2Dictionaries) *MuxedStream {
return &MuxedStream{
responseHeadersReceived: make(chan struct{}),
readBuffer: NewSharedBuffer(),

View File

@ -51,10 +51,12 @@ type MuxReader struct {
dictionaries h2Dictionaries
}
func (r *MuxReader) Shutdown() {
done := r.streams.Shutdown()
if done == nil {
return
// Shutdown blocks new streams from being created.
// It returns a channel that is closed once the last stream has closed.
func (r *MuxReader) Shutdown() <-chan struct{} {
done, alreadyInProgress := r.streams.Shutdown()
if alreadyInProgress {
return done
}
r.sendGoAway(http2.ErrCodeNo)
go func() {
@ -62,6 +64,7 @@ func (r *MuxReader) Shutdown() {
<-done
r.r.Close()
}()
return done
}
func (r *MuxReader) run(parentLogger *log.Entry) error {