TUN-2608: h2mux.Muxer.Shutdown always returns a non-nil channel

This commit is contained in:
Nick Vollmar 2019-12-03 15:01:28 -06:00
parent bbf31377c2
commit b499c0fdba
6 changed files with 188 additions and 29 deletions

View File

@ -13,24 +13,26 @@ type activeStreamMap struct {
sync.RWMutex sync.RWMutex
// streams tracks open streams. // streams tracks open streams.
streams map[uint32]*MuxedStream streams map[uint32]*MuxedStream
// streamsEmpty is a chan that should be closed when no more streams are open.
streamsEmpty chan struct{}
// nextStreamID is the next ID to use on our side of the connection. // nextStreamID is the next ID to use on our side of the connection.
// This is odd for clients, even for servers. // This is odd for clients, even for servers.
nextStreamID uint32 nextStreamID uint32
// maxPeerStreamID is the ID of the most recent stream opened by the peer. // maxPeerStreamID is the ID of the most recent stream opened by the peer.
maxPeerStreamID uint32 maxPeerStreamID uint32
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
activeStreams prometheus.Gauge
// ignoreNewStreams is true when the connection is being shut down. New streams // ignoreNewStreams is true when the connection is being shut down. New streams
// cannot be registered. // cannot be registered.
ignoreNewStreams bool ignoreNewStreams bool
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams // streamsEmpty is a chan that will be closed when no more streams are open.
activeStreams prometheus.Gauge streamsEmptyChan chan struct{}
closeOnce sync.Once
} }
func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap { func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
m := &activeStreamMap{ m := &activeStreamMap{
streams: make(map[uint32]*MuxedStream), streams: make(map[uint32]*MuxedStream),
streamsEmpty: make(chan struct{}), streamsEmptyChan: make(chan struct{}),
nextStreamID: 1, nextStreamID: 1,
activeStreams: activeStreams, activeStreams: activeStreams,
} }
@ -41,6 +43,12 @@ func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Ga
return m return m
} }
func (m *activeStreamMap) notifyStreamsEmpty() {
m.closeOnce.Do(func() {
close(m.streamsEmptyChan)
})
}
// Len returns the number of active streams. // Len returns the number of active streams.
func (m *activeStreamMap) Len() int { func (m *activeStreamMap) Len() int {
m.RLock() m.RLock()
@ -79,30 +87,27 @@ func (m *activeStreamMap) Delete(streamID uint32) {
delete(m.streams, streamID) delete(m.streams, streamID)
m.activeStreams.Dec() m.activeStreams.Dec()
} }
if len(m.streams) == 0 && m.streamsEmpty != nil { if len(m.streams) == 0 {
close(m.streamsEmpty) m.notifyStreamsEmpty()
m.streamsEmpty = nil
} }
} }
// Shutdown blocks new streams from being created. It returns a channel that receives an event // Shutdown blocks new streams from being created.
// once the last stream has closed, or nil if a shutdown is in progress. // It returns `done`, a channel that is closed once the last stream has closed
func (m *activeStreamMap) Shutdown() <-chan struct{} { // and `progress`, whether a shutdown was already in progress
func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bool) {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
if m.ignoreNewStreams { if m.ignoreNewStreams {
// already shutting down // already shutting down
return nil return m.streamsEmptyChan, true
} }
m.ignoreNewStreams = true m.ignoreNewStreams = true
done := make(chan struct{})
if len(m.streams) == 0 { if len(m.streams) == 0 {
// nothing to shut down // nothing to shut down
close(done) m.notifyStreamsEmpty()
return done
} }
m.streamsEmpty = done return m.streamsEmptyChan, false
return done
} }
// AcquireLocalID acquires a new stream ID for a stream you're opening. // AcquireLocalID acquires a new stream ID for a stream you're opening.
@ -170,4 +175,5 @@ func (m *activeStreamMap) Abort() {
stream.Close() stream.Close()
} }
m.ignoreNewStreams = true m.ignoreNewStreams = true
m.notifyStreamsEmpty()
} }

View File

@ -0,0 +1,134 @@
package h2mux
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestShutdown(t *testing.T) {
const numStreams = 1000
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
// Add all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
stream := &MuxedStream{streamID: uint32(streamID)}
ok := m.Set(stream)
assert.True(t, ok)
}(i)
}
wg.Wait()
}
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
shutdownChan, alreadyInProgress := m.Shutdown()
select {
case <-shutdownChan:
assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed")
default:
}
assert.False(t, alreadyInProgress)
shutdownChan2, alreadyInProgress2 := m.Shutdown()
assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")
// Delete all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
m.Delete(uint32(streamID))
}(i)
}
wg.Wait()
}
assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")
select {
case <-shutdownChan:
default:
assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
}
}
type noopBuffer struct {
isClosed bool
}
func (t *noopBuffer) Read(p []byte) (n int, err error) { return len(p), nil }
func (t *noopBuffer) Write(p []byte) (n int, err error) { return len(p), nil }
func (t *noopBuffer) Reset() {}
func (t *noopBuffer) Len() int { return 0 }
func (t *noopBuffer) Close() error { t.isClosed = true; return nil }
func (t *noopBuffer) Closed() bool { return t.isClosed }
type noopReadyList struct{}
func (_ *noopReadyList) Signal(streamID uint32) {}
func TestAbort(t *testing.T) {
const numStreams = 1000
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
var openedStreams sync.Map
// Add all the streams
{
var wg sync.WaitGroup
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(streamID int) {
defer wg.Done()
stream := &MuxedStream{
streamID: uint32(streamID),
readBuffer: &noopBuffer{},
writeBuffer: &noopBuffer{},
readyList: &noopReadyList{},
}
ok := m.Set(stream)
assert.True(t, ok)
openedStreams.Store(stream.streamID, stream)
}(i)
}
wg.Wait()
}
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
shutdownChan, alreadyInProgress := m.Shutdown()
select {
case <-shutdownChan:
assert.Fail(t, "before Abort(), shutdownChan shouldn't be closed")
default:
}
assert.False(t, alreadyInProgress)
m.Abort()
assert.Equal(t, numStreams, m.Len(), "Abort() shouldn't delete any streams")
openedStreams.Range(func(key interface{}, value interface{}) bool {
stream := value.(*MuxedStream)
readBuffer := stream.readBuffer.(*noopBuffer)
writeBuffer := stream.writeBuffer.(*noopBuffer)
return assert.True(t, readBuffer.isClosed && writeBuffer.isClosed, "Abort() should have closed all the streams")
})
select {
case <-shutdownChan:
default:
assert.Fail(t, "after Abort(), shutdownChan should have been closed")
}
// multiple aborts shouldn't cause any issues
m.Abort()
m.Abort()
m.Abort()
}

View File

@ -353,9 +353,11 @@ func (m *Muxer) Serve(ctx context.Context) error {
} }
// Shutdown is called to initiate the "happy path" of muxer termination. // Shutdown is called to initiate the "happy path" of muxer termination.
func (m *Muxer) Shutdown() { // It blocks new streams from being created.
// It returns a channel that is closed when the last stream has been closed.
func (m *Muxer) Shutdown() <-chan struct{} {
m.explicitShutdown.Fuse(true) m.explicitShutdown.Fuse(true)
m.muxReader.Shutdown() return m.muxReader.Shutdown()
} }
// IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel. // IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.

View File

@ -55,6 +55,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
DefaultWindowSize: (1 << 8) - 1, DefaultWindowSize: (1 << 8) - 1,
MaxWindowSize: (1 << 15) - 1, MaxWindowSize: (1 << 15) - 1,
StreamWriteBufferMaxLen: 1024, StreamWriteBufferMaxLen: 1024,
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
}, },
OriginConn: origin, OriginConn: origin,
EdgeMuxConfig: MuxerConfig{ EdgeMuxConfig: MuxerConfig{
@ -65,6 +67,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
DefaultWindowSize: (1 << 8) - 1, DefaultWindowSize: (1 << 8) - 1,
MaxWindowSize: (1 << 15) - 1, MaxWindowSize: (1 << 15) - 1,
StreamWriteBufferMaxLen: 1024, StreamWriteBufferMaxLen: 1024,
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
}, },
EdgeConn: edge, EdgeConn: edge,
doneC: make(chan struct{}), doneC: make(chan struct{}),
@ -83,6 +87,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
Name: "origin", Name: "origin",
CompressionQuality: quality, CompressionQuality: quality,
Logger: log.NewEntry(log.New()), Logger: log.NewEntry(log.New()),
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
}, },
OriginConn: origin, OriginConn: origin,
EdgeMuxConfig: MuxerConfig{ EdgeMuxConfig: MuxerConfig{
@ -91,6 +97,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
Name: "edge", Name: "edge",
CompressionQuality: quality, CompressionQuality: quality,
Logger: log.NewEntry(log.New()), Logger: log.NewEntry(log.New()),
HeartbeatInterval: defaultTimeout,
MaxHeartbeats: defaultRetries,
}, },
EdgeConn: edge, EdgeConn: edge,
doneC: make(chan struct{}), doneC: make(chan struct{}),

View File

@ -17,6 +17,12 @@ type ReadWriteClosedCloser interface {
Closed() bool Closed() bool
} }
// MuxedStreamDataSignaller is a write-only *ReadyList
type MuxedStreamDataSignaller interface {
// Non-blocking: call this when data is ready to be sent for the given stream ID.
Signal(ID uint32)
}
// MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data. // MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data.
type MuxedStream struct { type MuxedStream struct {
streamID uint32 streamID uint32
@ -55,8 +61,8 @@ type MuxedStream struct {
// This is the amount of bytes that are in the peer's receive window // This is the amount of bytes that are in the peer's receive window
// (how much data we can send from this stream). // (how much data we can send from this stream).
sendWindow uint32 sendWindow uint32
// Reference to the muxer's readyList; signal this for stream data to be sent. // The muxer's readyList
readyList *ReadyList readyList MuxedStreamDataSignaller
// The headers that should be sent, and a flag so we only send them once. // The headers that should be sent, and a flag so we only send them once.
headersSent bool headersSent bool
writeHeaders []Header writeHeaders []Header
@ -88,7 +94,7 @@ func (th TunnelHostname) IsSet() bool {
return th != "" return th != ""
} }
func NewStream(config MuxerConfig, writeHeaders []Header, readyList *ReadyList, dictionaries h2Dictionaries) *MuxedStream { func NewStream(config MuxerConfig, writeHeaders []Header, readyList MuxedStreamDataSignaller, dictionaries h2Dictionaries) *MuxedStream {
return &MuxedStream{ return &MuxedStream{
responseHeadersReceived: make(chan struct{}), responseHeadersReceived: make(chan struct{}),
readBuffer: NewSharedBuffer(), readBuffer: NewSharedBuffer(),

View File

@ -51,10 +51,12 @@ type MuxReader struct {
dictionaries h2Dictionaries dictionaries h2Dictionaries
} }
func (r *MuxReader) Shutdown() { // Shutdown blocks new streams from being created.
done := r.streams.Shutdown() // It returns a channel that is closed once the last stream has closed.
if done == nil { func (r *MuxReader) Shutdown() <-chan struct{} {
return done, alreadyInProgress := r.streams.Shutdown()
if alreadyInProgress {
return done
} }
r.sendGoAway(http2.ErrCodeNo) r.sendGoAway(http2.ErrCodeNo)
go func() { go func() {
@ -62,6 +64,7 @@ func (r *MuxReader) Shutdown() {
<-done <-done
r.r.Close() r.r.Close()
}() }()
return done
} }
func (r *MuxReader) run(parentLogger *log.Entry) error { func (r *MuxReader) run(parentLogger *log.Entry) error {