package h2mux

import (
	"sync"

	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/net/http2"
)

// activeStreamMap is used to moderate access to active streams between the read and write
// threads, and deny access to new peer streams while shutting down.
//
// Methods in this file read and write the fields below only while holding the
// embedded RWMutex; closeOnce provides its own synchronization.
type activeStreamMap struct {
	sync.RWMutex
	// streams tracks open streams, keyed by stream ID.
	streams map[uint32]*MuxedStream
	// nextStreamID is the next ID to use on our side of the connection.
	// This is odd for clients, even for servers.
	nextStreamID uint32
	// maxPeerStreamID is the highest stream ID accepted from the peer so far
	// (0 until the peer opens its first stream).
	maxPeerStreamID uint32
	// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
	activeStreams prometheus.Gauge

	// ignoreNewStreams is true when the connection is being shut down or has
	// been aborted. New streams cannot be registered.
	ignoreNewStreams bool
	// streamsEmptyChan is closed once no more streams are open after shutdown
	// has begun, or immediately when the map is aborted.
	streamsEmptyChan chan struct{}
	// closeOnce guarantees streamsEmptyChan is closed at most once.
	closeOnce sync.Once
}

// newActiveStreamMap builds an empty stream map. When useClientStreamNumbers
// is true, locally opened streams get odd IDs starting at 1; otherwise they get
// even IDs starting at 2. activeStreams is the gauge updated as streams are
// added and removed.
func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
	// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
	firstLocalID := uint32(2)
	if useClientStreamNumbers {
		firstLocalID = 1
	}
	return &activeStreamMap{
		streams:          make(map[uint32]*MuxedStream),
		streamsEmptyChan: make(chan struct{}),
		nextStreamID:     firstLocalID,
		activeStreams:    activeStreams,
	}
}

// notifyStreamsEmpty closes streamsEmptyChan, releasing anyone waiting on the
// channel handed out by Shutdown. closeOnce makes repeated calls harmless.
// It should be called while `m` is locked.
func (m *activeStreamMap) notifyStreamsEmpty() {
	closeStreamsEmpty := func() {
		close(m.streamsEmptyChan)
	}
	m.closeOnce.Do(closeStreamsEmpty)
}

// Len reports how many streams the map is currently tracking.
func (m *activeStreamMap) Len() int {
	m.RLock()
	count := len(m.streams)
	m.RUnlock()
	return count
}

// Get looks up the tracked stream with the given ID. The boolean result is
// false (and the stream nil) when no such stream is tracked.
func (m *activeStreamMap) Get(streamID uint32) (*MuxedStream, bool) {
	m.RLock()
	found, exists := m.streams[streamID]
	m.RUnlock()
	return found, exists
}

// Set starts tracking newStream under its stream ID and increments the
// activeStreams gauge, reporting true. It reports false, leaving both the map
// and the gauge untouched, if a stream with that ID is already tracked or if
// the map is shutting down.
func (m *activeStreamMap) Set(newStream *MuxedStream) bool {
	m.Lock()
	defer m.Unlock()
	id := newStream.streamID
	_, taken := m.streams[id]
	if taken || m.ignoreNewStreams {
		return false
	}
	m.streams[id] = newStream
	m.activeStreams.Inc()
	return true
}

// Delete stops tracking the stream with the given ID and decrements the
// activeStreams gauge if it was tracked; unknown IDs are ignored. It should be
// called only after the stream has been closed and reset.
//
// If the map is shutting down and no streams remain afterwards, the channel
// returned by Shutdown is closed.
func (m *activeStreamMap) Delete(streamID uint32) {
	m.Lock()
	defer m.Unlock()
	_, tracked := m.streams[streamID]
	if tracked {
		delete(m.streams, streamID)
		m.activeStreams.Dec()
	}
	if len(m.streams) == 0 && m.ignoreNewStreams {
		m.notifyStreamsEmpty()
	}
}

// Shutdown blocks new streams from being created: afterwards Set and
// AcquirePeerID reject every new stream.
//
// It returns done, a channel that is closed once the last tracked stream has
// been deleted (immediately, if none are tracked), and alreadyInProgress, which
// reports whether an earlier Shutdown or Abort had already stopped the map from
// accepting streams. Calling Shutdown repeatedly is safe; each call returns the
// same channel.
func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bool) {
	m.Lock()
	defer m.Unlock()
	if m.ignoreNewStreams {
		// A shutdown or abort is already under way; hand back its channel.
		return m.streamsEmptyChan, true
	}
	m.ignoreNewStreams = true
	if len(m.streams) == 0 {
		// There are no streams to wait for, so signal completion right away.
		m.notifyStreamsEmpty()
	}
	return m.streamsEmptyChan, false
}

// AcquireLocalID reserves and returns the next ID for a stream we are opening.
// IDs advance in steps of two, so every ID handed out keeps the parity chosen
// in newActiveStreamMap (odd for clients, even for servers).
//
// NOTE(review): there is no exhaustion check. The uint32 counter wraps
// silently, and nothing here enforces the HTTP/2 31-bit stream ID limit.
func (m *activeStreamMap) AcquireLocalID() uint32 {
	m.Lock()
	defer m.Unlock()
	id := m.nextStreamID
	m.nextStreamID = id + 2
	return id
}

// AcquirePeerID records the ID of a stream opened by the peer and reports
// whether the new stream should be accepted. On acceptance the ErrCode is
// http2.ErrCodeNo; on rejection it gives the reason.
//
// The stream is rejected with http2.ErrCodeStreamClosed when the map is
// shutting down, or when streamID is not strictly greater than every peer
// stream ID accepted so far (peer stream IDs must only increase). A rejected
// ID does not change the recorded maximum. The parity of streamID is not
// checked here; see IsPeerStreamID.
func (m *activeStreamMap) AcquirePeerID(streamID uint32) (bool, http2.ErrCode) {
	m.Lock()
	defer m.Unlock()
	if m.ignoreNewStreams || streamID <= m.maxPeerStreamID {
		return false, http2.ErrCodeStreamClosed
	}
	m.maxPeerStreamID = streamID
	return true, http2.ErrCodeNo
}

// IsPeerStreamID reports whether streamID has the peer's parity, i.e. the
// opposite parity to the IDs we allocate locally. It does not check whether
// the peer has actually opened such a stream.
func (m *activeStreamMap) IsPeerStreamID(streamID uint32) bool {
	m.RLock()
	localParity := m.nextStreamID & 1
	m.RUnlock()
	return streamID&1 != localParity
}

// IsLocalStreamID reports whether streamID belongs to a stream we have opened,
// even if that stream is now closed. Such an ID shares our parity and is lower
// than the next ID we would hand out.
//
// NOTE(review): on the even-numbered (server) side this is also true for stream
// ID 0, which AcquireLocalID never returns. Confirm callers filter out 0.
func (m *activeStreamMap) IsLocalStreamID(streamID uint32) bool {
	m.RLock()
	next := m.nextStreamID
	m.RUnlock()
	if streamID >= next {
		return false
	}
	return streamID&1 == next&1
}

// LastPeerStreamID reports the highest stream ID accepted from the peer through
// AcquirePeerID, or 0 if the peer has not opened any stream yet.
func (m *activeStreamMap) LastPeerStreamID() uint32 {
	m.RLock()
	last := m.maxPeerStreamID
	m.RUnlock()
	return last
}

// LastLocalStreamID reports the ID most recently handed out by AcquireLocalID,
// or 0 if no local stream has been opened yet.
func (m *activeStreamMap) LastLocalStreamID() uint32 {
	m.RLock()
	next := m.nextStreamID
	m.RUnlock()
	if next <= 1 {
		return 0
	}
	return next - 2
}

// Abort closes every tracked stream and prevents new ones from being created.
// It is meant to make pending reads and writes return errors when the
// underlying connection goes away.
//
// Unlike Shutdown, it does not wait for streams to drain: the channel returned
// by Shutdown is closed right away. The streams stay in the map (and are still
// counted by the activeStreams gauge) until Delete is called for each of them.
//
// NOTE(review): each stream's Close runs while the map's write lock is held, so
// Close must not call back into this map. Confirm this against MuxedStream.Close.
func (m *activeStreamMap) Abort() {
	m.Lock()
	defer m.Unlock()
	for _, s := range m.streams {
		s.Close()
	}
	m.ignoreNewStreams = true
	m.notifyStreamsEmpty()
}