TUN-2608: h2mux.Muxer.Shutdown always returns a non-nil channel
commit b499c0fdba
parent bbf31377c2
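Before this change, activeStreamMap.Shutdown (and therefore MuxReader.Shutdown) returned nil when a shutdown was already in progress, so callers could not reliably wait for in-flight streams to drain. With this commit, Muxer.Shutdown always returns a non-nil channel that is closed once the last stream has closed. A minimal caller-side sketch of the new contract (the wrapper function, package name and timeout handling below are hypothetical; only Shutdown's signature comes from this diff):

// Sketch: waiting for a graceful h2mux shutdown under the new contract.
package example

import (
	"time"

	"github.com/cloudflare/cloudflared/h2mux"
)

func gracefulClose(muxer *h2mux.Muxer, timeout time.Duration) {
	// Shutdown now always returns a usable channel, even if a shutdown was
	// already started elsewhere, so callers can select on it unconditionally.
	drained := muxer.Shutdown()
	select {
	case <-drained:
		// all streams have closed
	case <-time.After(timeout):
		// stop waiting; proceed with teardown anyway
	}
}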
h2mux/activestreammap.go
@@ -13,24 +13,26 @@ type activeStreamMap struct {
 	sync.RWMutex
 	// streams tracks open streams.
 	streams map[uint32]*MuxedStream
-	// streamsEmpty is a chan that should be closed when no more streams are open.
-	streamsEmpty chan struct{}
 	// nextStreamID is the next ID to use on our side of the connection.
 	// This is odd for clients, even for servers.
 	nextStreamID uint32
 	// maxPeerStreamID is the ID of the most recent stream opened by the peer.
 	maxPeerStreamID uint32
+	// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
+	activeStreams prometheus.Gauge
+
 	// ignoreNewStreams is true when the connection is being shut down. New streams
 	// cannot be registered.
 	ignoreNewStreams bool
-	// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
-	activeStreams prometheus.Gauge
+	// streamsEmpty is a chan that will be closed when no more streams are open.
+	streamsEmptyChan chan struct{}
+	closeOnce sync.Once
 }
 
 func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
 	m := &activeStreamMap{
 		streams: make(map[uint32]*MuxedStream),
-		streamsEmpty: make(chan struct{}),
+		streamsEmptyChan: make(chan struct{}),
 		nextStreamID: 1,
 		activeStreams: activeStreams,
 	}
@@ -41,6 +43,12 @@ func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
 	return m
 }
 
+func (m *activeStreamMap) notifyStreamsEmpty() {
+	m.closeOnce.Do(func() {
+		close(m.streamsEmptyChan)
+	})
+}
+
 // Len returns the number of active streams.
 func (m *activeStreamMap) Len() int {
 	m.RLock()
@@ -79,30 +87,27 @@ func (m *activeStreamMap) Delete(streamID uint32) {
 		delete(m.streams, streamID)
 		m.activeStreams.Dec()
 	}
-	if len(m.streams) == 0 && m.streamsEmpty != nil {
-		close(m.streamsEmpty)
-		m.streamsEmpty = nil
+	if len(m.streams) == 0 {
+		m.notifyStreamsEmpty()
 	}
 }
 
-// Shutdown blocks new streams from being created. It returns a channel that receives an event
-// once the last stream has closed, or nil if a shutdown is in progress.
-func (m *activeStreamMap) Shutdown() <-chan struct{} {
+// Shutdown blocks new streams from being created.
+// It returns `done`, a channel that is closed once the last stream has closed
+// and `progress`, whether a shutdown was already in progress
+func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bool) {
 	m.Lock()
 	defer m.Unlock()
 	if m.ignoreNewStreams {
 		// already shutting down
-		return nil
+		return m.streamsEmptyChan, true
 	}
 	m.ignoreNewStreams = true
-	done := make(chan struct{})
 	if len(m.streams) == 0 {
 		// nothing to shut down
-		close(done)
-		return done
+		m.notifyStreamsEmpty()
 	}
-	m.streamsEmpty = done
-	return done
+	return m.streamsEmptyChan, false
 }
 
 // AcquireLocalID acquires a new stream ID for a stream you're opening.
@@ -170,4 +175,5 @@ func (m *activeStreamMap) Abort() {
 		stream.Close()
 	}
 	m.ignoreNewStreams = true
+	m.notifyStreamsEmpty()
 }
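The new notifyStreamsEmpty helper wraps close(streamsEmptyChan) in a sync.Once because Delete, Shutdown and Abort can each report "no streams left", and closing an already-closed Go channel panics. A standalone sketch of that pattern (illustrative only, not code from this commit):

// Why the close is guarded by sync.Once: repeated notifications must not
// close the same channel twice.
package main

import (
	"fmt"
	"sync"
)

type notifier struct {
	done      chan struct{}
	closeOnce sync.Once
}

func (n *notifier) notify() {
	// Do runs the function at most once, so repeated calls are safe.
	n.closeOnce.Do(func() { close(n.done) })
}

func main() {
	n := &notifier{done: make(chan struct{})}
	n.notify()
	n.notify() // no panic: the channel is closed exactly once
	<-n.done   // receiving from a closed channel returns immediately
	fmt.Println("done channel closed exactly once")
}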
h2mux/activestreammap_test.go (new file)
@@ -0,0 +1,134 @@
+package h2mux
+
+import (
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestShutdown(t *testing.T) {
+	const numStreams = 1000
+	m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
+
+	// Add all the streams
+	{
+		var wg sync.WaitGroup
+		wg.Add(numStreams)
+		for i := 0; i < numStreams; i++ {
+			go func(streamID int) {
+				defer wg.Done()
+				stream := &MuxedStream{streamID: uint32(streamID)}
+				ok := m.Set(stream)
+				assert.True(t, ok)
+			}(i)
+		}
+		wg.Wait()
+	}
+	assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
+
+	shutdownChan, alreadyInProgress := m.Shutdown()
+	select {
+	case <-shutdownChan:
+		assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed")
+	default:
+	}
+	assert.False(t, alreadyInProgress)
+
+	shutdownChan2, alreadyInProgress2 := m.Shutdown()
+	assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
+	assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")
+
+	// Delete all the streams
+	{
+		var wg sync.WaitGroup
+		wg.Add(numStreams)
+		for i := 0; i < numStreams; i++ {
+			go func(streamID int) {
+				defer wg.Done()
+				m.Delete(uint32(streamID))
+			}(i)
+		}
+		wg.Wait()
+	}
+	assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")
+
+	select {
+	case <-shutdownChan:
+	default:
+		assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
+	}
+}
+
+type noopBuffer struct {
+	isClosed bool
+}
+
+func (t *noopBuffer) Read(p []byte) (n int, err error) { return len(p), nil }
+func (t *noopBuffer) Write(p []byte) (n int, err error) { return len(p), nil }
+func (t *noopBuffer) Reset() {}
+func (t *noopBuffer) Len() int { return 0 }
+func (t *noopBuffer) Close() error { t.isClosed = true; return nil }
+func (t *noopBuffer) Closed() bool { return t.isClosed }
+
+type noopReadyList struct{}
+
+func (_ *noopReadyList) Signal(streamID uint32) {}
+
+func TestAbort(t *testing.T) {
+	const numStreams = 1000
+	m := newActiveStreamMap(true, NewActiveStreamsMetrics("test", t.Name()))
+
+	var openedStreams sync.Map
+
+	// Add all the streams
+	{
+		var wg sync.WaitGroup
+		wg.Add(numStreams)
+		for i := 0; i < numStreams; i++ {
+			go func(streamID int) {
+				defer wg.Done()
+				stream := &MuxedStream{
+					streamID: uint32(streamID),
+					readBuffer: &noopBuffer{},
+					writeBuffer: &noopBuffer{},
+					readyList: &noopReadyList{},
+				}
+				ok := m.Set(stream)
+				assert.True(t, ok)
+
+				openedStreams.Store(stream.streamID, stream)
+			}(i)
+		}
+		wg.Wait()
+	}
+	assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
+
+	shutdownChan, alreadyInProgress := m.Shutdown()
+	select {
+	case <-shutdownChan:
+		assert.Fail(t, "before Abort(), shutdownChan shouldn't be closed")
+	default:
+	}
+	assert.False(t, alreadyInProgress)
+
+	m.Abort()
+	assert.Equal(t, numStreams, m.Len(), "Abort() shouldn't delete any streams")
+	openedStreams.Range(func(key interface{}, value interface{}) bool {
+		stream := value.(*MuxedStream)
+		readBuffer := stream.readBuffer.(*noopBuffer)
+		writeBuffer := stream.writeBuffer.(*noopBuffer)
+		return assert.True(t, readBuffer.isClosed && writeBuffer.isClosed, "Abort() should have closed all the streams")
+	})
+
+	select {
+	case <-shutdownChan:
+	default:
+		assert.Fail(t, "after Abort(), shutdownChan should have been closed")
+	}
+
+	// multiple aborts shouldn't cause any issues
+	m.Abort()
+	m.Abort()
+	m.Abort()
+}
h2mux/muxer.go
@@ -353,9 +353,11 @@ func (m *Muxer) Serve(ctx context.Context) error {
 }
 
 // Shutdown is called to initiate the "happy path" of muxer termination.
-func (m *Muxer) Shutdown() {
+// It blocks new streams from being created.
+// It returns a channel that is closed when the last stream has been closed.
+func (m *Muxer) Shutdown() <-chan struct{} {
 	m.explicitShutdown.Fuse(true)
-	m.muxReader.Shutdown()
+	return m.muxReader.Shutdown()
 }
 
 // IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
h2mux/muxer_test.go
@@ -55,6 +55,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
 			DefaultWindowSize: (1 << 8) - 1,
 			MaxWindowSize: (1 << 15) - 1,
 			StreamWriteBufferMaxLen: 1024,
+			HeartbeatInterval: defaultTimeout,
+			MaxHeartbeats: defaultRetries,
 		},
 		OriginConn: origin,
 		EdgeMuxConfig: MuxerConfig{
@@ -65,6 +67,8 @@ func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc)
 			DefaultWindowSize: (1 << 8) - 1,
 			MaxWindowSize: (1 << 15) - 1,
 			StreamWriteBufferMaxLen: 1024,
+			HeartbeatInterval: defaultTimeout,
+			MaxHeartbeats: defaultRetries,
 		},
 		EdgeConn: edge,
 		doneC: make(chan struct{}),
@@ -83,6 +87,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
 			Name: "origin",
 			CompressionQuality: quality,
 			Logger: log.NewEntry(log.New()),
+			HeartbeatInterval: defaultTimeout,
+			MaxHeartbeats: defaultRetries,
 		},
 		OriginConn: origin,
 		EdgeMuxConfig: MuxerConfig{
@@ -91,6 +97,8 @@ func NewCompressedMuxerPair(t assert.TestingT, testName string, quality Compress
 			Name: "edge",
 			CompressionQuality: quality,
 			Logger: log.NewEntry(log.New()),
+			HeartbeatInterval: defaultTimeout,
+			MaxHeartbeats: defaultRetries,
 		},
 		EdgeConn: edge,
 		doneC: make(chan struct{}),
h2mux/muxedstream.go
@@ -17,6 +17,12 @@ type ReadWriteClosedCloser interface {
 	Closed() bool
 }
 
+// MuxedStreamDataSignaller is a write-only *ReadyList
+type MuxedStreamDataSignaller interface {
+	// Non-blocking: call this when data is ready to be sent for the given stream ID.
+	Signal(ID uint32)
+}
+
 // MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data.
 type MuxedStream struct {
 	streamID uint32
@@ -55,8 +61,8 @@ type MuxedStream struct {
 	// This is the amount of bytes that are in the peer's receive window
 	// (how much data we can send from this stream).
 	sendWindow uint32
-	// Reference to the muxer's readyList; signal this for stream data to be sent.
-	readyList *ReadyList
+	// The muxer's readyList
+	readyList MuxedStreamDataSignaller
 	// The headers that should be sent, and a flag so we only send them once.
 	headersSent bool
 	writeHeaders []Header
@@ -88,7 +94,7 @@ func (th TunnelHostname) IsSet() bool {
 	return th != ""
 }
 
-func NewStream(config MuxerConfig, writeHeaders []Header, readyList *ReadyList, dictionaries h2Dictionaries) *MuxedStream {
+func NewStream(config MuxerConfig, writeHeaders []Header, readyList MuxedStreamDataSignaller, dictionaries h2Dictionaries) *MuxedStream {
 	return &MuxedStream{
 		responseHeadersReceived: make(chan struct{}),
 		readBuffer: NewSharedBuffer(),
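Narrowing the readyList field from the concrete *ReadyList to the MuxedStreamDataSignaller interface is what lets the new activeStreamMap tests stub it out (see noopReadyList above). A hypothetical in-package test helper, assuming only the single Signal method from this diff, could likewise record which streams signalled readiness:

package h2mux

import "sync"

// recordingSignaller is a hypothetical test stub (not part of this commit).
// It satisfies MuxedStreamDataSignaller, so a test can observe which stream
// IDs signalled readiness instead of wiring up a real *ReadyList.
type recordingSignaller struct {
	mu        sync.Mutex
	signalled []uint32
}

func (r *recordingSignaller) Signal(ID uint32) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.signalled = append(r.signalled, ID)
}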
h2mux/muxreader.go
@@ -51,10 +51,12 @@ type MuxReader struct {
 	dictionaries h2Dictionaries
 }
 
-func (r *MuxReader) Shutdown() {
-	done := r.streams.Shutdown()
-	if done == nil {
-		return
+// Shutdown blocks new streams from being created.
+// It returns a channel that is closed once the last stream has closed.
+func (r *MuxReader) Shutdown() <-chan struct{} {
+	done, alreadyInProgress := r.streams.Shutdown()
+	if alreadyInProgress {
+		return done
 	}
 	r.sendGoAway(http2.ErrCodeNo)
 	go func() {
@@ -62,6 +64,7 @@ func (r *MuxReader) Shutdown() {
 		<-done
 		r.r.Close()
 	}()
+	return done
 }
 
 func (r *MuxReader) run(parentLogger *log.Entry) error {