package h2mux

import (
	"context"
	"errors"
	"io"
	"strings"
	"sync"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/net/http2"
	"golang.org/x/net/http2/hpack"
	"golang.org/x/sync/errgroup"

	"github.com/cloudflare/cloudflared/logger"
)

// Protocol and tuning defaults used by Handshake.
const (
	// defaultFrameSize is 2^14, the initial and smallest permitted
	// SETTINGS_MAX_FRAME_SIZE in the HTTP/2 spec.
	defaultFrameSize uint32 = 1 << 14
	// defaultWindowSize is 2^16-1 (65535), the initial flow-control window
	// size in the HTTP/2 spec.
	defaultWindowSize uint32 = (1 << 16) - 1
	// maxWindowSize is 2^31-1 (2147483647), the largest flow-control window
	// the HTTP/2 spec permits.
	maxWindowSize uint32 = (1 << 31) - 1
	// defaultTimeout is the fallback handshake timeout and the minimum idle
	// interval between heartbeats applied by Handshake.
	defaultTimeout time.Duration = 5 * time.Second
	// defaultRetries is the fallback for MuxerConfig.MaxHeartbeats.
	defaultRetries uint64 = 5
	// defaultWriteBufferMaxLen is the fallback for
	// MuxerConfig.StreamWriteBufferMaxLen: 1 MiB.
	defaultWriteBufferMaxLen int = 1024 * 1024
	// writeBufferInitialSize is 16 KiB.
	writeBufferInitialSize int = 16 * 1024
)

// Handshake identification. Each side advertises a magic value under the
// custom SETTINGS identifier SettingMuxerMagic: a client sends
// MuxerMagicOrigin and expects MuxerMagicEdge back, a non-client does the
// reverse.
const (
	SettingMuxerMagic http2.SettingID = 0x42db
	MuxerMagicOrigin  uint32          = 0xa2e43c8b
	MuxerMagicEdge    uint32          = 0x1088ebf9
)

// MuxedStreamHandler handles a single MuxedStream, returning an error if
// serving it fails.
type MuxedStreamHandler interface {
	ServeStream(stream *MuxedStream) error
}

// MuxedStreamFunc lets an ordinary function be used as a MuxedStreamHandler.
type MuxedStreamFunc func(stream *MuxedStream) error

// ServeStream calls fn with stream and returns its result.
func (fn MuxedStreamFunc) ServeStream(stream *MuxedStream) error {
	return fn(stream)
}

// MuxerConfig configures a Muxer created by Handshake.
type MuxerConfig struct {
	// Timeout is the handshake timeout used by Handshake; zero selects
	// defaultTimeout (5s).
	Timeout time.Duration
	// Handler is passed to the MuxReader; presumably it serves streams opened
	// by the peer (confirm in MuxReader).
	Handler MuxedStreamHandler
	// IsClient selects which muxer magic Handshake sends and expects, and is
	// passed to the active stream map.
	IsClient bool
	// Name is used to identify this muxer instance when logging.
	Name string
	// HeartbeatInterval is the minimum time this connection can be idle before
	// sending a heartbeat. Handshake raises values below defaultTimeout (5s),
	// including zero, to defaultTimeout.
	HeartbeatInterval time.Duration
	// MaxHeartbeats is the minimum number of heartbeats to send before
	// terminating the connection; zero selects defaultRetries (5).
	MaxHeartbeats uint64
	// Logger receives the muxer's log output.
	Logger logger.Service
	// CompressionQuality is the requested compression setting. It is only
	// honoured when CompressionIsSupported reports true.
	CompressionQuality CompressionSetting
	// DefaultWindowSize is the initial size of HTTP/2 flow-control windows;
	// zero selects defaultWindowSize (65535).
	DefaultWindowSize uint32
	// MaxWindowSize is the largest allowable HTTP/2 flow-control window; zero
	// selects maxWindowSize (2^31-1).
	MaxWindowSize uint32
	// StreamWriteBufferMaxLen is the largest allowable capacity for the buffer
	// of data to be sent; zero selects defaultWriteBufferMaxLen (1 MiB).
	StreamWriteBufferMaxLen int
}

// Muxer multiplexes streams over one connection using HTTP/2 framing.
// It is created by Handshake and its event loops are run by Serve.
type Muxer struct {
	// f reads and writes HTTP/2 frames on the wire.
	f *http2.Framer
	// config is the MuxerConfig given to Handshake, with defaults filled in.
	config MuxerConfig
	// w and r are the write and read halves of the underlying connection.
	w io.WriteCloser
	r io.ReadCloser
	// muxReader runs the read loop.
	muxReader *MuxReader
	// muxWriter runs the write loop.
	muxWriter *MuxWriter
	// muxMetricsUpdater runs the metrics update loop.
	muxMetricsUpdater muxMetricsUpdater
	// newStreamChan hands new stream requests to the writer loop, which
	// assigns the next available stream ID.
	newStreamChan chan MuxedStreamRequest
	// abortChan is closed by abort; it is shared with the reader, writer and
	// metrics updater, and unblocks callers waiting on stream operations.
	abortChan chan struct{}
	// abortOnce guarantees abortChan is closed at most once.
	abortOnce sync.Once
	// readyList signals streams that have data ready to write.
	readyList *ReadyList
	// streams tracks the currently open streams.
	streams *activeStreamMap
	// explicitShutdown records whether the Muxer is closing because Shutdown
	// was called (true) or because of some other error (false).
	explicitShutdown *BooleanFuse

	// compressionQuality is the compression preset negotiated in Handshake.
	compressionQuality CompressionPreset
}

// RPCHeaders returns the pseudo-headers used to open an RPC stream
// (:method "RPC", :scheme "capnp", :path "*"). A fresh slice is built on
// every call, so callers may modify the result.
func RPCHeaders() []Header {
	method := Header{Name: ":method", Value: "RPC"}
	scheme := Header{Name: ":scheme", Value: "capnp"}
	path := Header{Name: ":path", Value: "*"}
	return []Header{method, scheme, path}
}

// Handshake establishes a muxed connection with the peer.
// After the handshake completes, it is possible to open and accept streams.
//
// w and r are the write and read halves of the underlying connection; an
// HTTP/2 framer is built on top of them. Zero values in config for Timeout,
// DefaultWindowSize, MaxWindowSize and StreamWriteBufferMaxLen are replaced by
// package defaults. activeStreamsMetrics is handed to the active stream map.
//
// The handshake runs in two phases, with waits bounded by config.Timeout via
// joinErrorsWithTimeout. First, both sides exchange SETTINGS frames carrying
// the muxer magic and compression parameters. Then both sides exchange
// SETTINGS ACKs. The first error encountered (or ErrHandshakeTimeout) is
// returned with a nil Muxer.
//
// NOTE(review): on error, a helper goroutine may remain blocked on the
// connection until the caller closes it.
func Handshake(
	w io.WriteCloser,
	r io.ReadCloser,
	config MuxerConfig,
	activeStreamsMetrics prometheus.Gauge,
) (*Muxer, error) {
	// Set default config values
	if config.Timeout == 0 {
		config.Timeout = defaultTimeout
	}
	if config.DefaultWindowSize == 0 {
		config.DefaultWindowSize = defaultWindowSize
	}
	if config.MaxWindowSize == 0 {
		config.MaxWindowSize = maxWindowSize
	}
	if config.StreamWriteBufferMaxLen == 0 {
		config.StreamWriteBufferMaxLen = defaultWriteBufferMaxLen
	}
	// Initialise connection state fields
	m := &Muxer{
		f:             http2.NewFramer(w, r), // A framer that writes to w and reads from r
		config:        config,
		w:             w,
		r:             r,
		newStreamChan: make(chan MuxedStreamRequest),
		abortChan:     make(chan struct{}),
		readyList:     NewReadyList(),
		streams:       newActiveStreamMap(config.IsClient, activeStreamsMetrics),
	}

	m.f.ReadMetaHeaders = hpack.NewDecoder(4096, func(hpack.HeaderField) {})
	// Initialise the settings to identify this connection and confirm the other end is sane.
	handshakeSetting := http2.Setting{ID: SettingMuxerMagic, Val: MuxerMagicEdge}
	compressionSetting := http2.Setting{ID: SettingCompression, Val: config.CompressionQuality.toH2Setting()}
	if CompressionIsSupported() {
		config.Logger.Debug("muxer: Compression is supported")
		m.compressionQuality = config.CompressionQuality.getPreset()
	} else {
		config.Logger.Debug("muxer: Compression is not supported")
		compressionSetting = http2.Setting{ID: SettingCompression, Val: 0}
	}

	// A client advertises the origin magic and expects the edge magic back;
	// a non-client does the reverse.
	expectedMagic := MuxerMagicOrigin
	if config.IsClient {
		handshakeSetting.Val = MuxerMagicOrigin
		expectedMagic = MuxerMagicEdge
	}
	// Buffered for both results of a phase so neither goroutine blocks on send.
	errChan := make(chan error, 2)
	// Simultaneously send our settings and verify the peer's settings.
	go func() { errChan <- m.f.WriteSettings(handshakeSetting, compressionSetting) }()
	go func() { errChan <- m.readPeerSettings(expectedMagic) }()
	err := joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
	if err != nil {
		return nil, err
	}
	// Confirm sanity by ACKing the frame and expecting an ACK for our frame.
	// Not strictly necessary, but let's pretend to be H2-like.
	go func() { errChan <- m.f.WriteSettingsAck() }()
	go func() { errChan <- m.readPeerSettingsAck() }()
	err = joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
	if err != nil {
		return nil, err
	}

	// set up reader/writer pair ready for serve
	streamErrors := NewStreamErrorMap()
	goAwayChan := make(chan http2.ErrCode, 1)
	inBoundCounter := NewAtomicCounter(0)
	outBoundCounter := NewAtomicCounter(0)
	pingTimestamp := NewPingTimestamp()
	connActive := NewSignal()
	idleDuration := config.HeartbeatInterval
	// Sanity check to ensure idleDuration is sane: never heartbeat more often
	// than every defaultTimeout. This also covers an unset (zero) interval.
	if idleDuration < defaultTimeout {
		idleDuration = defaultTimeout
		// %s renders the Duration as "5s"; %d would print raw nanoseconds.
		config.Logger.Infof("muxer: Minimum idle time has been adjusted to %s", defaultTimeout)
	}
	maxRetries := config.MaxHeartbeats
	if maxRetries == 0 {
		maxRetries = defaultRetries
		config.Logger.Infof("muxer: Minimum number of unacked heartbeats to send before closing the connection has been adjusted to %d", maxRetries)
	}

	compBytesBefore, compBytesAfter := NewAtomicCounter(0), NewAtomicCounter(0)

	m.muxMetricsUpdater = newMuxMetricsUpdater(
		m.abortChan,
		compBytesBefore,
		compBytesAfter,
	)

	m.explicitShutdown = NewBooleanFuse()
	m.muxReader = &MuxReader{
		f:                       m.f,
		handler:                 m.config.Handler,
		streams:                 m.streams,
		readyList:               m.readyList,
		streamErrors:            streamErrors,
		goAwayChan:              goAwayChan,
		abortChan:               m.abortChan,
		pingTimestamp:           pingTimestamp,
		connActive:              connActive,
		initialStreamWindow:     m.config.DefaultWindowSize,
		streamWindowMax:         m.config.MaxWindowSize,
		streamWriteBufferMaxLen: m.config.StreamWriteBufferMaxLen,
		r:                       m.r,
		metricsUpdater:          m.muxMetricsUpdater,
		bytesRead:               inBoundCounter,
	}
	m.muxWriter = &MuxWriter{
		f:               m.f,
		streams:         m.streams,
		streamErrors:    streamErrors,
		readyStreamChan: m.readyList.ReadyChannel(),
		newStreamChan:   m.newStreamChan,
		goAwayChan:      goAwayChan,
		abortChan:       m.abortChan,
		pingTimestamp:   pingTimestamp,
		idleTimer:       NewIdleTimer(idleDuration, maxRetries),
		connActiveChan:  connActive.WaitChannel(),
		maxFrameSize:    defaultFrameSize,
		metricsUpdater:  m.muxMetricsUpdater,
		bytesWrote:      outBoundCounter,
	}
	m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)

	// Compression dictionaries are only set up when the negotiated preset
	// (see readPeerSettings) still has a non-zero dictionary size and count.
	if m.compressionQuality.dictSize > 0 && m.compressionQuality.nDicts > 0 {
		nd, sz := m.compressionQuality.nDicts, m.compressionQuality.dictSize
		writeDicts, dictChan := newH2WriteDictionaries(
			nd,
			sz,
			m.compressionQuality.quality,
			compBytesBefore,
			compBytesAfter,
		)
		readDicts := newH2ReadDictionaries(nd, sz)
		m.muxReader.dictionaries = h2Dictionaries{read: &readDicts, write: writeDicts}
		m.muxWriter.useDictChan = dictChan
	}

	return m, nil
}

// readPeerSettings reads the peer's initial SETTINGS frame and validates it.
// The frame must be a SETTINGS frame with no flags set, carrying a
// SettingMuxerMagic value equal to magic.
//
// It then negotiates compression. If the peer omits SettingCompression, or
// advertises an unknown version or format or a zero dictionary size or count,
// compression is disabled (CompressionNone preset). Otherwise the local
// dictionary size and count are lowered to the peer's values where smaller.
func (m *Muxer) readPeerSettings(magic uint32) error {
	frame, err := m.f.ReadFrame()
	if err != nil {
		return err
	}
	settings, isSettings := frame.(*http2.SettingsFrame)
	switch {
	case !isSettings:
		return ErrBadHandshakeNotSettings
	case settings.Header().Flags != 0:
		return ErrBadHandshakeUnexpectedAck
	}

	peerMagic, found := settings.Value(SettingMuxerMagic)
	if !found {
		return ErrBadHandshakeNoMagic
	}
	if peerMagic != magic {
		return ErrBadHandshakeWrongMagic
	}

	rawCompression, found := settings.Value(SettingCompression)
	if !found {
		m.compressionQuality = compressionPresets[CompressionNone]
		return nil
	}
	version, format, dictSize, dictCount := parseCompressionSettingVal(rawCompression)
	usable := version == compressionVersion && format == compressionFormat && dictSize != 0 && dictCount != 0
	if !usable {
		m.compressionQuality = compressionPresets[CompressionNone]
		return nil
	}
	// Each compression parameter is the smaller of the two peers' values.
	if dictSize < m.compressionQuality.dictSize {
		m.compressionQuality.dictSize = dictSize
	}
	if dictCount < m.compressionQuality.nDicts {
		m.compressionQuality.nDicts = dictCount
	}
	return nil
}

// readPeerSettingsAck reads the next frame and checks that it is a SETTINGS
// frame whose flags are exactly the ACK flag.
func (m *Muxer) readPeerSettingsAck() error {
	frame, err := m.f.ReadFrame()
	if err != nil {
		return err
	}
	settings, isSettings := frame.(*http2.SettingsFrame)
	switch {
	case !isSettings:
		return ErrBadHandshakeNotSettingsAck
	case settings.Header().Flags != http2.FlagSettingsAck:
		return ErrBadHandshakeUnexpectedSettings
	default:
		return nil
	}
}

// joinErrorsWithTimeout receives up to receiveCount results from errChan and
// returns the first non-nil one. If all receiveCount results have not arrived
// within timeout of the call (a single deadline across all receives), it
// returns timeoutError. It returns nil once receiveCount nil results have been
// received, and immediately when receiveCount <= 0.
//
// Results that arrive after an early return stay in errChan, so callers should
// buffer errChan to keep senders from blocking.
func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time.Duration, timeoutError error) error {
	// One timer for the whole join. Re-arming time.After on every iteration
	// would let the total wait grow to receiveCount*timeout.
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	for i := 0; i < receiveCount; i++ {
		select {
		case err := <-errChan:
			if err != nil {
				return err
			}
		case <-timer.C:
			return timeoutError
		}
	}
	return nil
}

// Serve runs the event loops that comprise h2mux:
// - MuxReader.run()
// - MuxWriter.run()
// - muxMetricsUpdater.run()
// In the normal case, Shutdown() is called concurrently with Serve() to stop
// these loops.
//
// When the reader or writer loop exits, the corresponding half of the
// connection is closed and the muxer is aborted. Serve waits for all three
// loop wrappers and returns the first error they report, or ctx.Err() for a
// wrapper that observed ctx being done first. The error is suppressed (nil is
// returned) when isUnexpectedTunnelError considers it expected, e.g. a
// connection-closed error after an explicit Shutdown.
func (m *Muxer) Serve(ctx context.Context) error {
	errGroup, _ := errgroup.WithContext(ctx)

	// runLoop runs loop on its own goroutine, then onExit (if non-nil), and
	// reports the loop's error to the errgroup unless ctx is done first.
	runLoop := func(loop func() error, onExit func()) func() error {
		return func() error {
			// Buffered so the loop goroutine can always hand off its result
			// without blocking, even if this wrapper already returned on
			// ctx.Done(). An unbuffered channel with a non-blocking send (the
			// previous approach) dropped the result when the loop finished
			// before this wrapper reached the select, hanging Serve until ctx
			// was cancelled.
			ch := make(chan error, 1)
			go func() {
				err := loop()
				if onExit != nil {
					onExit()
				}
				ch <- err
			}()
			select {
			case err := <-ch:
				return err
			case <-ctx.Done():
				return ctx.Err()
			}
		}
	}

	errGroup.Go(runLoop(
		func() error { return m.muxReader.run(m.config.Logger) },
		func() {
			m.explicitShutdown.Fuse(false)
			m.r.Close()
			m.abort()
		},
	))
	errGroup.Go(runLoop(
		func() error { return m.muxWriter.run(m.config.Logger) },
		func() {
			m.explicitShutdown.Fuse(false)
			m.w.Close()
			m.abort()
		},
	))
	errGroup.Go(runLoop(
		func() error { return m.muxMetricsUpdater.run(m.config.Logger) },
		nil,
	))

	err := errGroup.Wait()
	if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) {
		return err
	}
	return nil
}

// Shutdown initiates the "happy path" of muxer termination: it blocks new
// streams from being created and returns a channel that is closed when the
// last stream has been closed (both via MuxReader.Shutdown).
//
// The shutdown is first recorded as explicit, so that Serve treats the
// resulting connection-closed errors as expected.
func (m *Muxer) Shutdown() <-chan struct{} {
	m.explicitShutdown.Fuse(true)
	done := m.muxReader.Shutdown()
	return done
}

// isUnexpectedTunnelError reports whether err, returned when the h2mux tunnel
// stops, should be surfaced to the caller. The rules are:
//   - a nil error is never unexpected;
//   - without an explicit shutdown (expectedShutdown false), every non-nil
//     error is unexpected;
//   - after an explicit shutdown, only errors other than connection-closed
//     errors are unexpected.
func isUnexpectedTunnelError(err error, expectedShutdown bool) bool {
	switch {
	case err == nil:
		return false
	case expectedShutdown:
		return !isConnectionClosedError(err)
	default:
		return true
	}
}

// isConnectionClosedError reports whether err indicates that the underlying
// connection was closed. Two kinds of error qualify:
//   - io.EOF or io.ErrClosedPipe, including when wrapped (checked with
//     errors.Is);
//   - errors whose message is "tls: use of closed connection" or ends with
//     "use of closed network connection".
//
// A nil error is never a connection-closed error.
func isConnectionClosedError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
		return true
	}
	// These errors are only identifiable by their text.
	msg := err.Error()
	return msg == "tls: use of closed connection" ||
		strings.HasSuffix(msg, "use of closed network connection")
}

// OpenStream opens a new data stream with the given headers and request body,
// then waits for the peer's response headers.
// Called by proxy server and tunnel
//
// It fails in two situations:
//   - the writer does not accept the stream request before ctx is done or the
//     muxer aborts (see MakeMuxedStreamRequest);
//   - the response headers do not arrive before ctx is done or the muxer
//     aborts (see AwaitResponseHeaders).
func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader) (*MuxedStream, error) {
	stream := m.NewStream(headers)
	request := NewMuxedStreamRequest(stream, body)

	err := m.MakeMuxedStreamRequest(ctx, request)
	if err == nil {
		err = m.AwaitResponseHeaders(ctx, stream)
	}
	if err != nil {
		return nil, err
	}
	return stream, nil
}

// OpenRPCStream opens a stream with RPCHeaders() and no request body, and
// waits for its response headers. It is OpenStream specialised for RPC.
func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) {
	return m.OpenStream(ctx, RPCHeaders(), nil)
}

// NewStream builds a MuxedStream with the given headers, wired to this
// muxer's config, ready list and compression dictionaries. The stream is
// not sent to the peer until it is passed to MakeMuxedStreamRequest.
func (m *Muxer) NewStream(headers []Header) *MuxedStream {
	dictionaries := m.muxReader.dictionaries
	return NewStream(m.config, headers, m.readyList, dictionaries)
}

// MakeMuxedStreamRequest hands request to the writer loop, which assigns the
// stream its ID. It returns ErrStreamRequestTimeout if ctx is done first, or
// ErrStreamRequestConnectionClosed if the muxer has been aborted.
func (m *Muxer) MakeMuxedStreamRequest(ctx context.Context, request MuxedStreamRequest) error {
	select {
	case m.newStreamChan <- request:
		// Received by the mux writer.
		return nil
	case <-m.abortChan:
		return ErrStreamRequestConnectionClosed
	case <-ctx.Done():
		return ErrStreamRequestTimeout
	}
}

// CloseStreamRead closes the read side of stream. If its write side is
// already closed too, the stream is removed from the active stream map.
func (m *Muxer) CloseStreamRead(stream *MuxedStream) {
	stream.CloseRead()
	if !stream.WriteClosed() {
		return
	}
	m.streams.Delete(stream.streamID)
}

// AwaitResponseHeaders blocks until stream's response headers have been
// received. It returns ErrResponseHeadersTimeout if ctx is done first, or
// ErrResponseHeadersConnectionClosed if the muxer has been aborted.
func (m *Muxer) AwaitResponseHeaders(ctx context.Context, stream *MuxedStream) error {
	select {
	case <-stream.responseHeadersReceived:
		return nil
	case <-m.abortChan:
		return ErrResponseHeadersConnectionClosed
	case <-ctx.Done():
		return ErrResponseHeadersTimeout
	}
}

// Metrics returns the metrics reported by the muxer's metrics updater.
func (m *Muxer) Metrics() *MuxerMetrics {
	metrics := m.muxMetricsUpdater.metrics()
	return metrics
}

// abort tears the muxer down at most once, however many times or from however
// many goroutines it is called. In order, it:
//   - closes abortChan, unblocking everything selecting on it;
//   - closes the ready list;
//   - aborts all active streams.
func (m *Muxer) abort() {
	teardown := func() {
		close(m.abortChan)
		m.readyList.Close()
		m.streams.Abort()
	}
	m.abortOnce.Do(teardown)
}

// TimerRetries reports how many idle-timer retries (ticks) have elapsed since
// the connection was last marked active.
func (m *Muxer) TimerRetries() uint64 {
	retries := m.muxWriter.idleTimer.RetryCount()
	return retries
}