TUN-2608: h2mux.Muxer.Shutdown always returns a non-nil channel
This commit is contained in:
parent
bbf31377c2
commit
b499c0fdba
|
@ -13,26 +13,28 @@ type activeStreamMap struct {
|
|||
sync.RWMutex
|
||||
// streams tracks open streams.
|
||||
streams map[uint32]*MuxedStream
|
||||
// streamsEmpty is a chan that should be closed when no more streams are open.
|
||||
streamsEmpty chan struct{}
|
||||
// nextStreamID is the next ID to use on our side of the connection.
|
||||
// This is odd for clients, even for servers.
|
||||
nextStreamID uint32
|
||||
// maxPeerStreamID is the ID of the most recent stream opened by the peer.
|
||||
maxPeerStreamID uint32
|
||||
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
|
||||
activeStreams prometheus.Gauge
|
||||
|
||||
// ignoreNewStreams is true when the connection is being shut down. New streams
|
||||
// cannot be registered.
|
||||
ignoreNewStreams bool
|
||||
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
|
||||
activeStreams prometheus.Gauge
|
||||
// streamsEmpty is a chan that will be closed when no more streams are open.
|
||||
streamsEmptyChan chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
|
||||
m := &activeStreamMap{
|
||||
streams: make(map[uint32]*MuxedStream),
|
||||
streamsEmpty: make(chan struct{}),
|
||||
nextStreamID: 1,
|
||||
activeStreams: activeStreams,
|
||||
streams: make(map[uint32]*MuxedStream),
|
||||
streamsEmptyChan: make(chan struct{}),
|
||||
nextStreamID: 1,
|
||||
activeStreams: activeStreams,
|
||||
}
|
||||
// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
|
||||
if !useClientStreamNumbers {
|
||||
|
@ -41,6 +43,12 @@ func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Ga
|
|||
return m
|
||||
}
|
||||
|
||||
func (m *activeStreamMap) notifyStreamsEmpty() {
|
||||
m.closeOnce.Do(func() {
|
||||
close(m.streamsEmptyChan)
|
||||
})
|
||||
}
|
||||
|
||||
// Len returns the number of active streams.
|
||||
func (m *activeStreamMap) Len() int {
|
||||
m.RLock()
|
||||
|
@ -79,30 +87,27 @@ func (m *activeStreamMap) Delete(streamID uint32) {
|
|||
delete(m.streams, streamID)
|
||||
m.activeStreams.Dec()
|
||||
}
|
||||
if len(m.streams) == 0 && m.streamsEmpty != nil {
|
||||
close(m.streamsEmpty)
|
||||
m.streamsEmpty = nil
|
||||
if len(m.streams) == 0 {
|
||||
m.notifyStreamsEmpty()
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown blocks new streams from being created. It returns a channel that receives an event
|
||||
// once the last stream has closed, or nil if a shutdown is in progress.
|
||||
func (m *activeStreamMap) Shutdown() <-chan struct{} {
|
||||
// Shutdown blocks new streams from being created.
|
||||
// It returns `done`, a channel that is closed once the last stream has closed
|
||||
// and `progress`, whether a shutdown was already in progress
|
||||
func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bool) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
if m.ignoreNewStreams {
|
||||
// already shutting down
|
||||
return nil
|
||||
return m.streamsEmptyChan, true
|
||||
}
|
||||
m.ignoreNewStreams = true
|
||||
done := make(chan struct{})
|
||||
if len(m.streams) == 0 {
|
||||
// nothing to shut down
|
||||
close(done)
|
||||
return done
|
||||
m.notifyStreamsEmpty()
|
||||
}
|
||||
m.streamsEmpty = done
|
||||
return done
|
||||
return m.streamsEmptyChan, false
|
||||
}
|
||||
|
||||
// AcquireLocalID acquires a new stream ID for a stream you're opening.
|
||||
|
@ -170,4 +175,5 @@ func (m *activeStreamMap) Abort() {
|
|||
stream.Close()
|
||||
}
|
||||
m.ignoreNewStreams = true
|
||||
m.notifyStreamsEmpty()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
const numStreams = 1000
|
||||
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
|
||||
|
||||
// Add all the streams
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func(streamID int) {
|
||||
defer wg.Done()
|
||||
stream := &MuxedStream{streamID: uint32(streamID)}
|
||||
ok := m.Set(stream)
|
||||
assert.True(t, ok)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
|
||||
|
||||
shutdownChan, alreadyInProgress := m.Shutdown()
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed")
|
||||
default:
|
||||
}
|
||||
assert.False(t, alreadyInProgress)
|
||||
|
||||
shutdownChan2, alreadyInProgress2 := m.Shutdown()
|
||||
assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
|
||||
assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")
|
||||
|
||||
// Delete all the streams
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func(streamID int) {
|
||||
defer wg.Done()
|
||||
m.Delete(uint32(streamID))
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")
|
||||
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
default:
|
||||
assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
|
||||
}
|
||||
}
|
||||
|
||||
type noopBuffer struct {
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func (t *noopBuffer) Read(p []byte) (n int, err error) { return len(p), nil }
|
||||
func (t *noopBuffer) Write(p []byte) (n int, err error) { return len(p), nil }
|
||||
func (t *noopBuffer) Reset() {}
|
||||
func (t *noopBuffer) Len() int { return 0 }
|
||||
func (t *noopBuffer) Close() error { t.isClosed = true; return nil }
|
||||
func (t *noopBuffer) Closed() bool { return t.isClosed }
|
||||
|
||||
type noopReadyList struct{}
|
||||
|
||||
func (_ *noopReadyList) Signal(streamID uint32) {}
|
||||
|
||||
func TestAbort(t *testing.T) {
|
||||
const numStreams = 1000
|
||||
m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
|
||||
|
||||
var openedStreams sync.Map
|
||||
|
||||
// Add all the streams
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func(streamID int) {
|
||||
defer wg.Done()
|
||||
stream := &MuxedStream{
|
||||
streamID: uint32(streamID),
|
||||
readBuffer: &noopBuffer{},
|
||||
writeBuffer: &noopBuffer{},
|
||||
readyList: &noopReadyList{},
|
||||
}
|
||||
ok := m.Set(stream)
|
||||
assert.True(t, ok)
|
||||
|
||||
openedStreams.Store(stream.streamID, stream)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
|
||||
|
||||
shutdownChan, alreadyInProgress := m.Shutdown()
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
assert.Fail(t, "before Abort(), shutdownChan shouldn't be closed")
|
||||
default:
|
||||
}
|
||||
assert.False(t, alreadyInProgress)
|
||||
|
||||
m.Abort()
|
||||
assert.Equal(t, numStreams, m.Len(), "Abort() shouldn't delete any streams")
|
||||
openedStreams.Range(func(key interface{}, value interface{}) bool {
|
||||
stream := value.(*MuxedStream)
|
||||
readBuffer := stream.readBuffer.(*noopBuffer)
|
||||
writeBuffer := stream.writeBuffer.(*noopBuffer)
|
||||
return assert.True(t, readBuffer.isClosed && writeBuffer.isClosed, "Abort() should have closed all the streams")
|
||||
})
|
||||
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
default:
|
||||
assert.Fail(t, "after Abort(), shutdownChan should have been closed")
|
||||
}
|
||||
|
||||
// multiple aborts shouldn't cause any issues
|
||||
m.Abort()
|
||||
m.Abort()
|
||||
m.Abort()
|
||||
}
|
|
@ -353,9 +353,11 @@ func (m *Muxer) Serve(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// Shutdown is called to initiate the "happy path" of muxer termination.
|
||||
func (m *Muxer) Shutdown() {
|
||||
// It blocks new streams from being created.
|
||||
// It returns a channel that is closed when the last stream has been closed.
|
||||
func (m *Muxer) Shutdown() <-chan struct{} {
|
||||
m.explicitShutdown.Fuse(true)
|
||||
m.muxReader.Shutdown()
|
||||
return m.muxReader.Shutdown()
|
||||
}
|
||||
|
||||
// IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
|
||||
|
|
|
@ -55,6 +55,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
|
|||
DefaultWindowSize: (1 << 8) - 1,
|
||||
MaxWindowSize: (1 << 15) - 1,
|
||||
StreamWriteBufferMaxLen: 1024,
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
OriginConn: origin,
|
||||
EdgeMuxConfig: MuxerConfig{
|
||||
|
@ -65,6 +67,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
|
|||
DefaultWindowSize: (1 << 8) - 1,
|
||||
MaxWindowSize: (1 << 15) - 1,
|
||||
StreamWriteBufferMaxLen: 1024,
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
EdgeConn: edge,
|
||||
doneC: make(chan struct{}),
|
||||
|
@ -83,6 +87,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
|
|||
Name: "origin",
|
||||
CompressionQuality: quality,
|
||||
Logger: log.NewEntry(log.New()),
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
OriginConn: origin,
|
||||
EdgeMuxConfig: MuxerConfig{
|
||||
|
@ -91,6 +97,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
|
|||
Name: "edge",
|
||||
CompressionQuality: quality,
|
||||
Logger: log.NewEntry(log.New()),
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
EdgeConn: edge,
|
||||
doneC: make(chan struct{}),
|
||||
|
|
|
@ -17,6 +17,12 @@ type ReadWriteClosedCloser interface {
|
|||
Closed() bool
|
||||
}
|
||||
|
||||
// MuxedStreamDataSignaller is a write-only *ReadyList
|
||||
type MuxedStreamDataSignaller interface {
|
||||
// Non-blocking: call this when data is ready to be sent for the given stream ID.
|
||||
Signal(ID uint32)
|
||||
}
|
||||
|
||||
// MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data.
|
||||
type MuxedStream struct {
|
||||
streamID uint32
|
||||
|
@ -55,8 +61,8 @@ type MuxedStream struct {
|
|||
// This is the amount of bytes that are in the peer's receive window
|
||||
// (how much data we can send from this stream).
|
||||
sendWindow uint32
|
||||
// Reference to the muxer's readyList; signal this for stream data to be sent.
|
||||
readyList *ReadyList
|
||||
// The muxer's readyList
|
||||
readyList MuxedStreamDataSignaller
|
||||
// The headers that should be sent, and a flag so we only send them once.
|
||||
headersSent bool
|
||||
writeHeaders []Header
|
||||
|
@ -88,7 +94,7 @@ func (th TunnelHostname) IsSet() bool {
|
|||
return th != ""
|
||||
}
|
||||
|
||||
func NewStream(config MuxerConfig, writeHeaders []Header, readyList *ReadyList, dictionaries h2Dictionaries) *MuxedStream {
|
||||
func NewStream(config MuxerConfig, writeHeaders []Header, readyList MuxedStreamDataSignaller, dictionaries h2Dictionaries) *MuxedStream {
|
||||
return &MuxedStream{
|
||||
responseHeadersReceived: make(chan struct{}),
|
||||
readBuffer: NewSharedBuffer(),
|
||||
|
|
|
@ -51,10 +51,12 @@ type MuxReader struct {
|
|||
dictionaries h2Dictionaries
|
||||
}
|
||||
|
||||
func (r *MuxReader) Shutdown() {
|
||||
done := r.streams.Shutdown()
|
||||
if done == nil {
|
||||
return
|
||||
// Shutdown blocks new streams from being created.
|
||||
// It returns a channel that is closed once the last stream has closed.
|
||||
func (r *MuxReader) Shutdown() <-chan struct{} {
|
||||
done, alreadyInProgress := r.streams.Shutdown()
|
||||
if alreadyInProgress {
|
||||
return done
|
||||
}
|
||||
r.sendGoAway(http2.ErrCodeNo)
|
||||
go func() {
|
||||
|
@ -62,6 +64,7 @@ func (r *MuxReader) Shutdown() {
|
|||
<-done
|
||||
r.r.Close()
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
func (r *MuxReader) run(parentLogger *log.Entry) error {
|
||||
|
|
Loading…
Reference in New Issue