498 lines
14 KiB
Go
498 lines
14 KiB
Go
|
package h2mux
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"encoding/binary"
|
||
|
"io"
|
||
|
"net/url"
|
||
|
"time"
|
||
|
|
||
|
log "github.com/sirupsen/logrus"
|
||
|
"golang.org/x/net/http2"
|
||
|
)
|
||
|
|
||
|
type MuxReader struct {
|
||
|
// f is used to read HTTP2 frames.
|
||
|
f *http2.Framer
|
||
|
// handler provides a callback to receive new streams. if nil, new streams cannot be accepted.
|
||
|
handler MuxedStreamHandler
|
||
|
// streams tracks currently-open streams.
|
||
|
streams *activeStreamMap
|
||
|
// readyList is used to signal writable streams.
|
||
|
readyList *ReadyList
|
||
|
// streamErrors lets us report stream errors to the MuxWriter.
|
||
|
streamErrors *StreamErrorMap
|
||
|
// goAwayChan is used to tell the writer to send a GOAWAY message.
|
||
|
goAwayChan chan<- http2.ErrCode
|
||
|
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
|
||
|
abortChan <-chan struct{}
|
||
|
// pingTimestamp is an atomic value containing the latest received ping timestamp.
|
||
|
pingTimestamp *PingTimestamp
|
||
|
// connActive is used to signal to the writer that something happened on the connection.
|
||
|
// This is used to clear idle timeout disconnection deadlines.
|
||
|
connActive Signal
|
||
|
// The initial value for the send and receive window of a new stream.
|
||
|
initialStreamWindow uint32
|
||
|
// The max value for the send window of a stream.
|
||
|
streamWindowMax uint32
|
||
|
// r is a reference to the underlying connection used when shutting down.
|
||
|
r io.Closer
|
||
|
// updateRTTChan is the channel to send new RTT measurement to muxerMetricsUpdater
|
||
|
updateRTTChan chan<- *roundTripMeasurement
|
||
|
// updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
|
||
|
updateReceiveWindowChan chan<- uint32
|
||
|
// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
|
||
|
updateSendWindowChan chan<- uint32
|
||
|
// bytesRead is the amount of bytes read from data frame since the last time we send bytes read to metrics
|
||
|
bytesRead *AtomicCounter
|
||
|
// updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
|
||
|
updateInBoundBytesChan chan<- uint64
|
||
|
// dictionaries holds the h2 cross-stream compression dictionaries
|
||
|
dictionaries h2Dictionaries
|
||
|
}
|
||
|
|
||
|
func (r *MuxReader) Shutdown() {
|
||
|
done := r.streams.Shutdown()
|
||
|
if done == nil {
|
||
|
return
|
||
|
}
|
||
|
r.sendGoAway(http2.ErrCodeNo)
|
||
|
go func() {
|
||
|
// close reader side when last stream ends; this will cause the writer to abort
|
||
|
<-done
|
||
|
r.r.Close()
|
||
|
}()
|
||
|
}
|
||
|
|
||
|
func (r *MuxReader) run(parentLogger *log.Entry) error {
|
||
|
logger := parentLogger.WithFields(log.Fields{
|
||
|
"subsystem": "mux",
|
||
|
"dir": "read",
|
||
|
})
|
||
|
defer logger.Debug("event loop finished")
|
||
|
|
||
|
// routine to periodically update bytesRead
|
||
|
go func() {
|
||
|
tickC := time.Tick(updateFreq)
|
||
|
for {
|
||
|
select {
|
||
|
case <-r.abortChan:
|
||
|
return
|
||
|
case <-tickC:
|
||
|
r.updateInBoundBytesChan <- r.bytesRead.Count()
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
for {
|
||
|
frame, err := r.f.ReadFrame()
|
||
|
if err != nil {
|
||
|
switch e := err.(type) {
|
||
|
case http2.StreamError:
|
||
|
logger.WithError(err).Warn("stream error")
|
||
|
r.streamError(e.StreamID, e.Code)
|
||
|
case http2.ConnectionError:
|
||
|
logger.WithError(err).Warn("connection error")
|
||
|
return r.connectionError(err)
|
||
|
default:
|
||
|
if isConnectionClosedError(err) {
|
||
|
if r.streams.Len() == 0 {
|
||
|
logger.Debug("shutting down")
|
||
|
return nil
|
||
|
}
|
||
|
logger.Warn("connection closed unexpectedly")
|
||
|
return err
|
||
|
} else {
|
||
|
logger.WithError(err).Warn("frame read error")
|
||
|
return r.connectionError(err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
r.connActive.Signal()
|
||
|
logger.WithField("data", frame).Debug("read frame")
|
||
|
switch f := frame.(type) {
|
||
|
case *http2.DataFrame:
|
||
|
err = r.receiveFrameData(f, logger)
|
||
|
case *http2.MetaHeadersFrame:
|
||
|
err = r.receiveHeaderData(f)
|
||
|
case *http2.RSTStreamFrame:
|
||
|
streamID := f.Header().StreamID
|
||
|
if streamID == 0 {
|
||
|
return ErrInvalidStream
|
||
|
}
|
||
|
r.streams.Delete(streamID)
|
||
|
case *http2.PingFrame:
|
||
|
r.receivePingData(f)
|
||
|
case *http2.GoAwayFrame:
|
||
|
err = r.receiveGoAway(f)
|
||
|
// The receiver of a flow-controlled frame sends a WINDOW_UPDATE frame as it
|
||
|
// consumes data and frees up space in flow-control windows
|
||
|
case *http2.WindowUpdateFrame:
|
||
|
err = r.updateStreamWindow(f)
|
||
|
case *http2.UnknownFrame:
|
||
|
switch f.Header().Type {
|
||
|
case FrameUseDictionary:
|
||
|
err = r.receiveUseDictionary(f)
|
||
|
case FrameSetDictionary:
|
||
|
err = r.receiveSetDictionary(f)
|
||
|
default:
|
||
|
err = ErrUnexpectedFrameType
|
||
|
}
|
||
|
default:
|
||
|
err = ErrUnexpectedFrameType
|
||
|
}
|
||
|
if err != nil {
|
||
|
logger.WithField("data", frame).WithError(err).Debug("frame error")
|
||
|
return r.connectionError(err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (r *MuxReader) newMuxedStream(streamID uint32) *MuxedStream {
|
||
|
return &MuxedStream{
|
||
|
streamID: streamID,
|
||
|
readBuffer: NewSharedBuffer(),
|
||
|
writeBuffer: &bytes.Buffer{},
|
||
|
receiveWindow: r.initialStreamWindow,
|
||
|
receiveWindowCurrentMax: r.initialStreamWindow,
|
||
|
receiveWindowMax: r.streamWindowMax,
|
||
|
sendWindow: r.initialStreamWindow,
|
||
|
readyList: r.readyList,
|
||
|
dictionaries: r.dictionaries,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// getStreamForFrame returns a stream if valid, or an error describing why the stream could not be returned.
|
||
|
func (r *MuxReader) getStreamForFrame(frame http2.Frame) (*MuxedStream, error) {
|
||
|
sid := frame.Header().StreamID
|
||
|
if sid == 0 {
|
||
|
return nil, ErrUnexpectedFrameType
|
||
|
}
|
||
|
if stream, ok := r.streams.Get(sid); ok {
|
||
|
return stream, nil
|
||
|
}
|
||
|
if r.streams.IsLocalStreamID(sid) {
|
||
|
// no stream available, but no error
|
||
|
return nil, ErrClosedStream
|
||
|
}
|
||
|
if sid < r.streams.LastPeerStreamID() {
|
||
|
// no stream available, stream closed error
|
||
|
return nil, ErrClosedStream
|
||
|
}
|
||
|
return nil, ErrUnknownStream
|
||
|
}
|
||
|
|
||
|
func (r *MuxReader) defaultStreamErrorHandler(err error, header http2.FrameHeader) error {
|
||
|
if header.Flags.Has(http2.FlagHeadersEndStream) {
|
||
|
return nil
|
||
|
} else if err == ErrUnknownStream || err == ErrClosedStream {
|
||
|
return r.streamError(header.StreamID, http2.ErrCodeStreamClosed)
|
||
|
} else {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Receives header frames from a stream. A non-nil error is a connection error.
|
||
|
func (r *MuxReader) receiveHeaderData(frame *http2.MetaHeadersFrame) error {
|
||
|
var stream *MuxedStream
|
||
|
sid := frame.Header().StreamID
|
||
|
if sid == 0 {
|
||
|
return ErrUnexpectedFrameType
|
||
|
}
|
||
|
newStream := r.streams.IsPeerStreamID(sid)
|
||
|
if newStream {
|
||
|
// header request
|
||
|
// TODO support trailers (if stream exists)
|
||
|
ok, err := r.streams.AcquirePeerID(sid)
|
||
|
if !ok {
|
||
|
// ignore new streams while shutting down
|
||
|
return r.streamError(sid, err)
|
||
|
}
|
||
|
stream = r.newMuxedStream(sid)
|
||
|
// Set stream. Returns false if a stream already existed with that ID or we are shutting down, return false.
|
||
|
if !r.streams.Set(stream) {
|
||
|
// got HEADERS frame for an existing stream
|
||
|
// TODO support trailers
|
||
|
return r.streamError(sid, http2.ErrCodeInternal)
|
||
|
}
|
||
|
} else {
|
||
|
// header response
|
||
|
var err error
|
||
|
if stream, err = r.getStreamForFrame(frame); err != nil {
|
||
|
return r.defaultStreamErrorHandler(err, frame.Header())
|
||
|
}
|
||
|
}
|
||
|
headers := make([]Header, 0, len(frame.Fields))
|
||
|
for _, header := range frame.Fields {
|
||
|
switch header.Name {
|
||
|
case ":method":
|
||
|
stream.method = header.Value
|
||
|
case ":path":
|
||
|
u, err := url.Parse(header.Value)
|
||
|
if err == nil {
|
||
|
stream.path = u.Path
|
||
|
}
|
||
|
case "accept-encoding":
|
||
|
// remove accept-encoding if dictionaries are enabled
|
||
|
if r.dictionaries.write != nil {
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
headers = append(headers, Header{Name: header.Name, Value: header.Value})
|
||
|
}
|
||
|
stream.Headers = headers
|
||
|
if frame.Header().Flags.Has(http2.FlagHeadersEndStream) {
|
||
|
stream.receiveEOF()
|
||
|
return nil
|
||
|
}
|
||
|
if newStream {
|
||
|
go r.handleStream(stream)
|
||
|
} else {
|
||
|
close(stream.responseHeadersReceived)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (r *MuxReader) handleStream(stream *MuxedStream) {
|
||
|
defer stream.Close()
|
||
|
r.handler.ServeStream(stream)
|
||
|
}
|
||
|
|
||
|
// Receives a data frame from a stream. A non-nil error is a connection error.
|
||
|
func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.Entry) error {
|
||
|
logger := parentLogger.WithField("stream", frame.Header().StreamID)
|
||
|
stream, err := r.getStreamForFrame(frame)
|
||
|
if err != nil {
|
||
|
return r.defaultStreamErrorHandler(err, frame.Header())
|
||
|
}
|
||
|
data := frame.Data()
|
||
|
if len(data) > 0 {
|
||
|
n, err := stream.readBuffer.Write(data)
|
||
|
if err != nil {
|
||
|
return r.streamError(stream.streamID, http2.ErrCodeInternal)
|
||
|
}
|
||
|
r.bytesRead.IncrementBy(uint64(n))
|
||
|
}
|
||
|
if frame.Header().Flags.Has(http2.FlagDataEndStream) {
|
||
|
if stream.receiveEOF() {
|
||
|
r.streams.Delete(stream.streamID)
|
||
|
logger.Debug("stream closed")
|
||
|
} else {
|
||
|
logger.Debug("shutdown receive side")
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
if !stream.consumeReceiveWindow(uint32(len(data))) {
|
||
|
return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
|
||
|
}
|
||
|
r.updateReceiveWindowChan <- stream.getReceiveWindow()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Receive a PING from the peer. Update RTT and send/receive window metrics if it's an ACK.
|
||
|
func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
|
||
|
ts := int64(binary.LittleEndian.Uint64(frame.Data[:]))
|
||
|
if !frame.IsAck() {
|
||
|
r.pingTimestamp.Set(ts)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Update updates the computed values with a new measurement.
|
||
|
// outgoingTime is the time that the probe was sent.
|
||
|
// We assume that time.Now() is the time we received that probe.
|
||
|
r.updateRTTChan <- &roundTripMeasurement{
|
||
|
receiveTime: time.Now(),
|
||
|
sendTime: time.Unix(0, ts),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Receive a GOAWAY from the peer. Gracefully shut down our connection.
|
||
|
func (r *MuxReader) receiveGoAway(frame *http2.GoAwayFrame) error {
|
||
|
r.Shutdown()
|
||
|
// Close all streams above the last processed stream
|
||
|
lastStream := r.streams.LastLocalStreamID()
|
||
|
for i := frame.LastStreamID + 2; i <= lastStream; i++ {
|
||
|
if stream, ok := r.streams.Get(i); ok {
|
||
|
stream.Close()
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Receive a USE_DICTIONARY from the peer. Setup dictionary for stream.
|
||
|
func (r *MuxReader) receiveUseDictionary(frame *http2.UnknownFrame) error {
|
||
|
payload := frame.Payload()
|
||
|
streamID := frame.StreamID
|
||
|
|
||
|
// Check frame is formatted properly
|
||
|
if len(payload) != 1 {
|
||
|
return r.streamError(streamID, http2.ErrCodeProtocol)
|
||
|
}
|
||
|
|
||
|
stream, err := r.getStreamForFrame(frame)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if stream.receivedUseDict == true || stream.dictionaries.read == nil {
|
||
|
return r.streamError(streamID, http2.ErrCodeInternal)
|
||
|
}
|
||
|
|
||
|
stream.receivedUseDict = true
|
||
|
dictID := payload[0]
|
||
|
|
||
|
dictReader := stream.dictionaries.read.newReader(stream.readBuffer.(*SharedBuffer), dictID)
|
||
|
if dictReader == nil {
|
||
|
return r.streamError(streamID, http2.ErrCodeInternal)
|
||
|
}
|
||
|
|
||
|
stream.readBufferLock.Lock()
|
||
|
stream.readBuffer = dictReader
|
||
|
stream.readBufferLock.Unlock()
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Receive a SET_DICTIONARY from the peer. Update dictionaries accordingly.
|
||
|
func (r *MuxReader) receiveSetDictionary(frame *http2.UnknownFrame) (err error) {
|
||
|
|
||
|
payload := frame.Payload()
|
||
|
flags := frame.Flags
|
||
|
|
||
|
stream, err := r.getStreamForFrame(frame)
|
||
|
if err != nil && err != ErrClosedStream {
|
||
|
return err
|
||
|
}
|
||
|
reader, ok := stream.readBuffer.(*h2DictionaryReader)
|
||
|
if !ok {
|
||
|
return r.streamError(frame.StreamID, http2.ErrCodeProtocol)
|
||
|
}
|
||
|
|
||
|
// A SetDictionary frame consists of several
|
||
|
// Dictionary-Entries that specify how existing dictionaries
|
||
|
// are to be updated using the current stream data
|
||
|
// +---------------+---------------+
|
||
|
// | Dictionary-Entry (+) ...
|
||
|
// +---------------+---------------+
|
||
|
|
||
|
for {
|
||
|
// Each Dictionary-Entry is formatted as follows:
|
||
|
// +-------------------------------+
|
||
|
// | Dictionary-ID (8) |
|
||
|
// +---+---------------------------+
|
||
|
// | P | Size (7+) |
|
||
|
// +---+---------------------------+
|
||
|
// | E?| D?| Truncate? (6+) |
|
||
|
// +---+---------------------------+
|
||
|
// | Offset? (8+) |
|
||
|
// +-------------------------------+
|
||
|
|
||
|
var size, truncate, offset uint64
|
||
|
var p, e, d bool
|
||
|
|
||
|
// Parse a single Dictionary-Entry
|
||
|
if len(payload) < 2 { // Must have at least id and size
|
||
|
return MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
|
||
|
}
|
||
|
|
||
|
dictID := uint8(payload[0])
|
||
|
p = (uint8(payload[1]) >> 7) == 1
|
||
|
payload, size, err = http2ReadVarInt(7, payload[1:])
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if flags.Has(FlagSetDictionaryAppend) {
|
||
|
// Presence of FlagSetDictionaryAppend means we expect e, d and truncate
|
||
|
if len(payload) < 1 {
|
||
|
return MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
|
||
|
}
|
||
|
e = (uint8(payload[0]) >> 7) == 1
|
||
|
d = (uint8((payload[0])>>6) & 1) == 1
|
||
|
payload, truncate, err = http2ReadVarInt(6, payload)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if flags.Has(FlagSetDictionaryOffset) {
|
||
|
// Presence of FlagSetDictionaryOffset means we expect offset
|
||
|
if len(payload) < 1 {
|
||
|
return MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
|
||
|
}
|
||
|
payload, offset, err = http2ReadVarInt(8, payload)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
setdict := setDictRequest{streamID: stream.streamID,
|
||
|
dictID: dictID,
|
||
|
dictSZ: size,
|
||
|
truncate: truncate,
|
||
|
offset: offset,
|
||
|
P: p,
|
||
|
E: e,
|
||
|
D: d}
|
||
|
|
||
|
// Find the right dictionary
|
||
|
dict, err := r.dictionaries.read.getDictByID(dictID)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Register a dictionary update order for the dictionary and reader
|
||
|
updateEntry := &dictUpdate{reader: reader, dictionary: dict, s: setdict}
|
||
|
dict.queue = append(dict.queue, updateEntry)
|
||
|
reader.queue = append(reader.queue, updateEntry)
|
||
|
// End of frame
|
||
|
if len(payload) == 0 {
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Receives header frames from a stream. A non-nil error is a connection error.
|
||
|
func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
|
||
|
stream, err := r.getStreamForFrame(frame)
|
||
|
if err != nil && err != ErrUnknownStream && err != ErrClosedStream {
|
||
|
return err
|
||
|
}
|
||
|
if stream == nil {
|
||
|
// ignore window updates on closed streams
|
||
|
return nil
|
||
|
}
|
||
|
stream.replenishSendWindow(frame.Increment)
|
||
|
r.updateSendWindowChan <- stream.getSendWindow()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Raise a stream processing error, closing the stream. Runs on the write thread.
|
||
|
func (r *MuxReader) streamError(streamID uint32, e http2.ErrCode) error {
|
||
|
r.streamErrors.RaiseError(streamID, e)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (r *MuxReader) connectionError(err error) error {
|
||
|
http2Code := http2.ErrCodeInternal
|
||
|
switch e := err.(type) {
|
||
|
case http2.ConnectionError:
|
||
|
http2Code = http2.ErrCode(e)
|
||
|
case MuxerProtocolError:
|
||
|
http2Code = e.h2code
|
||
|
}
|
||
|
r.sendGoAway(http2Code)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Instruct the writer to send a GOAWAY message if possible. This may fail in
|
||
|
// the case where an existing GOAWAY message is in flight or the writer event
|
||
|
// loop already ended.
|
||
|
func (r *MuxReader) sendGoAway(errCode http2.ErrCode) {
|
||
|
select {
|
||
|
case r.goAwayChan <- errCode:
|
||
|
default:
|
||
|
}
|
||
|
}
|