326 lines
9.3 KiB
Go
326 lines
9.3 KiB
Go
package h2mux
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"io"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
type MuxReader struct {
|
|
// f is used to read HTTP2 frames.
|
|
f *http2.Framer
|
|
// handler provides a callback to receive new streams. if nil, new streams cannot be accepted.
|
|
handler MuxedStreamHandler
|
|
// streams tracks currently-open streams.
|
|
streams *activeStreamMap
|
|
// readyList is used to signal writable streams.
|
|
readyList *ReadyList
|
|
// streamErrors lets us report stream errors to the MuxWriter.
|
|
streamErrors *StreamErrorMap
|
|
// goAwayChan is used to tell the writer to send a GOAWAY message.
|
|
goAwayChan chan<- http2.ErrCode
|
|
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
|
|
abortChan <-chan struct{}
|
|
// pingTimestamp is an atomic value containing the latest received ping timestamp.
|
|
pingTimestamp *PingTimestamp
|
|
// connActive is used to signal to the writer that something happened on the connection.
|
|
// This is used to clear idle timeout disconnection deadlines.
|
|
connActive Signal
|
|
// The initial value for the send and receive window of a new stream.
|
|
initialStreamWindow uint32
|
|
// The max value for the send window of a stream.
|
|
streamWindowMax uint32
|
|
// windowMetrics keeps track of min/max/average of send/receive windows for all streams
|
|
flowControlMetrics *FlowControlMetrics
|
|
metricsMutex sync.Mutex
|
|
// r is a reference to the underlying connection used when shutting down.
|
|
r io.Closer
|
|
// rttMeasurement measures RTT based on ping timestamps.
|
|
rttMeasurement RTTMeasurement
|
|
rttMutex sync.Mutex
|
|
}
|
|
|
|
func (r *MuxReader) Shutdown() {
|
|
done := r.streams.Shutdown()
|
|
if done == nil {
|
|
return
|
|
}
|
|
r.sendGoAway(http2.ErrCodeNo)
|
|
go func() {
|
|
// close reader side when last stream ends; this will cause the writer to abort
|
|
<-done
|
|
r.r.Close()
|
|
}()
|
|
}
|
|
|
|
func (r *MuxReader) RTT() RTTMeasurement {
|
|
r.rttMutex.Lock()
|
|
defer r.rttMutex.Unlock()
|
|
return r.rttMeasurement
|
|
}
|
|
|
|
func (r *MuxReader) FlowControlMetrics() *FlowControlMetrics {
|
|
r.metricsMutex.Lock()
|
|
defer r.metricsMutex.Unlock()
|
|
if r.flowControlMetrics != nil {
|
|
return r.flowControlMetrics
|
|
}
|
|
// No metrics available yet
|
|
return &FlowControlMetrics{}
|
|
}
|
|
|
|
func (r *MuxReader) run(parentLogger *log.Entry) error {
|
|
logger := parentLogger.WithFields(log.Fields{
|
|
"subsystem": "mux",
|
|
"dir": "read",
|
|
})
|
|
defer logger.Debug("event loop finished")
|
|
for {
|
|
frame, err := r.f.ReadFrame()
|
|
if err != nil {
|
|
switch e := err.(type) {
|
|
case http2.StreamError:
|
|
logger.WithError(err).Warn("stream error")
|
|
r.streamError(e.StreamID, e.Code)
|
|
case http2.ConnectionError:
|
|
logger.WithError(err).Warn("connection error")
|
|
return r.connectionError(err)
|
|
default:
|
|
if isConnectionClosedError(err) {
|
|
if r.streams.Len() == 0 {
|
|
logger.Debug("shutting down")
|
|
return nil
|
|
}
|
|
logger.Warn("connection closed unexpectedly")
|
|
return err
|
|
} else {
|
|
logger.WithError(err).Warn("frame read error")
|
|
return r.connectionError(err)
|
|
}
|
|
}
|
|
}
|
|
r.connActive.Signal()
|
|
logger.WithField("data", frame).Debug("read frame")
|
|
switch f := frame.(type) {
|
|
case *http2.DataFrame:
|
|
err = r.receiveFrameData(f, logger)
|
|
case *http2.MetaHeadersFrame:
|
|
err = r.receiveHeaderData(f)
|
|
case *http2.RSTStreamFrame:
|
|
streamID := f.Header().StreamID
|
|
if streamID == 0 {
|
|
return ErrInvalidStream
|
|
}
|
|
r.streams.Delete(streamID)
|
|
case *http2.PingFrame:
|
|
r.receivePingData(f)
|
|
case *http2.GoAwayFrame:
|
|
err = r.receiveGoAway(f)
|
|
case *http2.WindowUpdateFrame:
|
|
err = r.updateStreamWindow(f)
|
|
default:
|
|
err = ErrUnexpectedFrameType
|
|
}
|
|
if err != nil {
|
|
logger.WithField("data", frame).WithError(err).Debug("frame error")
|
|
return r.connectionError(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *MuxReader) newMuxedStream(streamID uint32) *MuxedStream {
|
|
return &MuxedStream{
|
|
streamID: streamID,
|
|
readBuffer: NewSharedBuffer(),
|
|
receiveWindow: r.initialStreamWindow,
|
|
receiveWindowCurrentMax: r.initialStreamWindow,
|
|
receiveWindowMax: r.streamWindowMax,
|
|
sendWindow: r.initialStreamWindow,
|
|
readyList: r.readyList,
|
|
}
|
|
}
|
|
|
|
// getStreamForFrame returns a stream if valid, or an error describing why the stream could not be returned.
|
|
func (r *MuxReader) getStreamForFrame(frame http2.Frame) (*MuxedStream, error) {
|
|
sid := frame.Header().StreamID
|
|
if sid == 0 {
|
|
return nil, ErrUnexpectedFrameType
|
|
}
|
|
if stream, ok := r.streams.Get(sid); ok {
|
|
return stream, nil
|
|
}
|
|
if r.streams.IsLocalStreamID(sid) {
|
|
// no stream available, but no error
|
|
return nil, ErrClosedStream
|
|
}
|
|
if sid < r.streams.LastPeerStreamID() {
|
|
// no stream available, stream closed error
|
|
return nil, ErrClosedStream
|
|
}
|
|
return nil, ErrUnknownStream
|
|
}
|
|
|
|
func (r *MuxReader) defaultStreamErrorHandler(err error, header http2.FrameHeader) error {
|
|
if header.Flags.Has(http2.FlagHeadersEndStream) {
|
|
return nil
|
|
} else if err == ErrUnknownStream || err == ErrClosedStream {
|
|
return r.streamError(header.StreamID, http2.ErrCodeStreamClosed)
|
|
} else {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Receives header frames from a stream. A non-nil error is a connection error.
|
|
func (r *MuxReader) receiveHeaderData(frame *http2.MetaHeadersFrame) error {
|
|
var stream *MuxedStream
|
|
sid := frame.Header().StreamID
|
|
if sid == 0 {
|
|
return ErrUnexpectedFrameType
|
|
}
|
|
newStream := r.streams.IsPeerStreamID(sid)
|
|
if newStream {
|
|
// header request
|
|
// TODO support trailers (if stream exists)
|
|
ok, err := r.streams.AcquirePeerID(sid)
|
|
if !ok {
|
|
// ignore new streams while shutting down
|
|
return r.streamError(sid, err)
|
|
}
|
|
stream = r.newMuxedStream(sid)
|
|
// Set stream. Returns false if a stream already existed with that ID or we are shutting down, return false.
|
|
if !r.streams.Set(stream) {
|
|
// got HEADERS frame for an existing stream
|
|
// TODO support trailers
|
|
return r.streamError(sid, http2.ErrCodeInternal)
|
|
}
|
|
} else {
|
|
// header response
|
|
var err error
|
|
if stream, err = r.getStreamForFrame(frame); err != nil {
|
|
return r.defaultStreamErrorHandler(err, frame.Header())
|
|
}
|
|
}
|
|
headers := make([]Header, len(frame.Fields))
|
|
for i, header := range frame.Fields {
|
|
headers[i].Name = header.Name
|
|
headers[i].Value = header.Value
|
|
}
|
|
stream.Headers = headers
|
|
if frame.Header().Flags.Has(http2.FlagHeadersEndStream) {
|
|
stream.receiveEOF()
|
|
return nil
|
|
}
|
|
if newStream {
|
|
go r.handleStream(stream)
|
|
} else {
|
|
close(stream.responseHeadersReceived)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *MuxReader) handleStream(stream *MuxedStream) {
|
|
defer stream.Close()
|
|
r.handler.ServeStream(stream)
|
|
}
|
|
|
|
// Receives a data frame from a stream. A non-nil error is a connection error.
|
|
func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.Entry) error {
|
|
logger := parentLogger.WithField("stream", frame.Header().StreamID)
|
|
stream, err := r.getStreamForFrame(frame)
|
|
if err != nil {
|
|
return r.defaultStreamErrorHandler(err, frame.Header())
|
|
}
|
|
data := frame.Data()
|
|
if len(data) > 0 {
|
|
_, err = stream.readBuffer.Write(data)
|
|
if err != nil {
|
|
return r.streamError(stream.streamID, http2.ErrCodeInternal)
|
|
}
|
|
}
|
|
if frame.Header().Flags.Has(http2.FlagDataEndStream) {
|
|
if stream.receiveEOF() {
|
|
r.streams.Delete(stream.streamID)
|
|
logger.Debug("stream closed")
|
|
} else {
|
|
logger.Debug("shutdown receive side")
|
|
}
|
|
return nil
|
|
}
|
|
if !stream.consumeReceiveWindow(uint32(len(data))) {
|
|
return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Receive a PING from the peer. Update RTT and send/receive window metrics if it's an ACK.
|
|
func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
|
|
ts := int64(binary.LittleEndian.Uint64(frame.Data[:]))
|
|
if !frame.IsAck() {
|
|
r.pingTimestamp.Set(ts)
|
|
return
|
|
}
|
|
r.rttMutex.Lock()
|
|
r.rttMeasurement.Update(time.Unix(0, ts))
|
|
r.rttMutex.Unlock()
|
|
r.flowControlMetrics = r.streams.Metrics()
|
|
}
|
|
|
|
// Receive a GOAWAY from the peer. Gracefully shut down our connection.
|
|
func (r *MuxReader) receiveGoAway(frame *http2.GoAwayFrame) error {
|
|
r.Shutdown()
|
|
// Close all streams above the last processed stream
|
|
lastStream := r.streams.LastLocalStreamID()
|
|
for i := frame.LastStreamID + 2; i <= lastStream; i++ {
|
|
if stream, ok := r.streams.Get(i); ok {
|
|
stream.Close()
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Receives header frames from a stream. A non-nil error is a connection error.
|
|
func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
|
|
stream, err := r.getStreamForFrame(frame)
|
|
if err != nil && err != ErrUnknownStream && err != ErrClosedStream {
|
|
return err
|
|
}
|
|
if stream == nil {
|
|
// ignore window updates on closed streams
|
|
return nil
|
|
}
|
|
stream.replenishSendWindow(frame.Increment)
|
|
return nil
|
|
}
|
|
|
|
// Raise a stream processing error, closing the stream. Runs on the write thread.
|
|
func (r *MuxReader) streamError(streamID uint32, e http2.ErrCode) error {
|
|
r.streamErrors.RaiseError(streamID, e)
|
|
return nil
|
|
}
|
|
|
|
func (r *MuxReader) connectionError(err error) error {
|
|
http2Code := http2.ErrCodeInternal
|
|
switch e := err.(type) {
|
|
case http2.ConnectionError:
|
|
http2Code = http2.ErrCode(e)
|
|
case MuxerProtocolError:
|
|
http2Code = e.h2code
|
|
}
|
|
r.sendGoAway(http2Code)
|
|
return err
|
|
}
|
|
|
|
// Instruct the writer to send a GOAWAY message if possible. This may fail in
|
|
// the case where an existing GOAWAY message is in flight or the writer event
|
|
// loop already ended.
|
|
func (r *MuxReader) sendGoAway(errCode http2.ErrCode) {
|
|
select {
|
|
case r.goAwayChan <- errCode:
|
|
default:
|
|
}
|
|
}
|