cloudflared-mirror/h2mux/activestreammap.go

180 lines
5.0 KiB
Go

package h2mux
import (
"sync"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/net/http2"
)
// activeStreamMap is used to moderate access to active streams between the read and write
// threads, and deny access to new peer streams while shutting down.
type activeStreamMap struct {
sync.RWMutex
// streams tracks open streams.
streams map[uint32]*MuxedStream
// nextStreamID is the next ID to use on our side of the connection.
// This is odd for clients, even for servers.
nextStreamID uint32
// maxPeerStreamID is the ID of the most recent stream opened by the peer.
maxPeerStreamID uint32
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
activeStreams prometheus.Gauge
// ignoreNewStreams is true when the connection is being shut down. New streams
// cannot be registered.
ignoreNewStreams bool
// streamsEmpty is a chan that will be closed when no more streams are open.
streamsEmptyChan chan struct{}
closeOnce sync.Once
}
func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
m := &activeStreamMap{
streams: make(map[uint32]*MuxedStream),
streamsEmptyChan: make(chan struct{}),
nextStreamID: 1,
activeStreams: activeStreams,
}
// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
if !useClientStreamNumbers {
m.nextStreamID = 2
}
return m
}
func (m *activeStreamMap) notifyStreamsEmpty() {
m.closeOnce.Do(func() {
close(m.streamsEmptyChan)
})
}
// Len returns the number of active streams.
func (m *activeStreamMap) Len() int {
m.RLock()
defer m.RUnlock()
return len(m.streams)
}
func (m *activeStreamMap) Get(streamID uint32) (*MuxedStream, bool) {
m.RLock()
defer m.RUnlock()
stream, ok := m.streams[streamID]
return stream, ok
}
// Set returns true if the stream was assigned successfully. If a stream
// already existed with that ID or we are shutting down, return false.
func (m *activeStreamMap) Set(newStream *MuxedStream) bool {
m.Lock()
defer m.Unlock()
if _, ok := m.streams[newStream.streamID]; ok {
return false
}
if m.ignoreNewStreams {
return false
}
m.streams[newStream.streamID] = newStream
m.activeStreams.Inc()
return true
}
// Delete stops tracking the stream. It should be called only after it is closed and resetted.
func (m *activeStreamMap) Delete(streamID uint32) {
m.Lock()
defer m.Unlock()
if _, ok := m.streams[streamID]; ok {
delete(m.streams, streamID)
m.activeStreams.Dec()
}
if len(m.streams) == 0 {
m.notifyStreamsEmpty()
}
}
// Shutdown blocks new streams from being created.
// It returns `done`, a channel that is closed once the last stream has closed
// and `progress`, whether a shutdown was already in progress
func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bool) {
m.Lock()
defer m.Unlock()
if m.ignoreNewStreams {
// already shutting down
return m.streamsEmptyChan, true
}
m.ignoreNewStreams = true
if len(m.streams) == 0 {
// nothing to shut down
m.notifyStreamsEmpty()
}
return m.streamsEmptyChan, false
}
// AcquireLocalID acquires a new stream ID for a stream you're opening.
func (m *activeStreamMap) AcquireLocalID() uint32 {
m.Lock()
defer m.Unlock()
x := m.nextStreamID
m.nextStreamID += 2
return x
}
// ObservePeerID observes the ID of a stream opened by the peer. It returns true if we should accept
// the new stream, or false to reject it. The ErrCode gives the reason why.
func (m *activeStreamMap) AcquirePeerID(streamID uint32) (bool, http2.ErrCode) {
m.Lock()
defer m.Unlock()
switch {
case m.ignoreNewStreams:
return false, http2.ErrCodeStreamClosed
case streamID > m.maxPeerStreamID:
m.maxPeerStreamID = streamID
return true, http2.ErrCodeNo
default:
return false, http2.ErrCodeStreamClosed
}
}
// IsPeerStreamID is true if the stream ID belongs to the peer.
func (m *activeStreamMap) IsPeerStreamID(streamID uint32) bool {
m.RLock()
defer m.RUnlock()
return (streamID % 2) != (m.nextStreamID % 2)
}
// IsLocalStreamID is true if it is a stream we have opened, even if it is now closed.
func (m *activeStreamMap) IsLocalStreamID(streamID uint32) bool {
m.RLock()
defer m.RUnlock()
return (streamID%2) == (m.nextStreamID%2) && streamID < m.nextStreamID
}
// LastPeerStreamID returns the most recently opened peer stream ID.
func (m *activeStreamMap) LastPeerStreamID() uint32 {
m.RLock()
defer m.RUnlock()
return m.maxPeerStreamID
}
// LastLocalStreamID returns the most recently opened local stream ID.
func (m *activeStreamMap) LastLocalStreamID() uint32 {
m.RLock()
defer m.RUnlock()
if m.nextStreamID > 1 {
return m.nextStreamID - 2
}
return 0
}
// Abort closes every active stream and prevents new ones being created. This should be used to
// return errors in pending read/writes when the underlying connection goes away.
func (m *activeStreamMap) Abort() {
m.Lock()
defer m.Unlock()
for _, stream := range m.streams {
stream.Close()
}
m.ignoreNewStreams = true
m.notifyStreamsEmpty()
}