513 lines
15 KiB
Go
513 lines
15 KiB
Go
package h2mux
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"io"
|
|
"net/url"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
// CloudflaredProxyTunnelHostnameHeader is the header name that
// receiveHeaderData looks for on incoming HEADERS frames; its value is
// recorded on the stream as its tunnel hostname.
const CloudflaredProxyTunnelHostnameHeader = "cf-cloudflared-proxy-tunnel-hostname"
|
|
|
|
type MuxReader struct {
|
|
// f is used to read HTTP2 frames.
|
|
f *http2.Framer
|
|
// handler provides a callback to receive new streams. if nil, new streams cannot be accepted.
|
|
handler MuxedStreamHandler
|
|
// streams tracks currently-open streams.
|
|
streams *activeStreamMap
|
|
// readyList is used to signal writable streams.
|
|
readyList *ReadyList
|
|
// streamErrors lets us report stream errors to the MuxWriter.
|
|
streamErrors *StreamErrorMap
|
|
// goAwayChan is used to tell the writer to send a GOAWAY message.
|
|
goAwayChan chan<- http2.ErrCode
|
|
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
|
|
abortChan <-chan struct{}
|
|
// pingTimestamp is an atomic value containing the latest received ping timestamp.
|
|
pingTimestamp *PingTimestamp
|
|
// connActive is used to signal to the writer that something happened on the connection.
|
|
// This is used to clear idle timeout disconnection deadlines.
|
|
connActive Signal
|
|
// The initial value for the send and receive window of a new stream.
|
|
initialStreamWindow uint32
|
|
// The max value for the send window of a stream.
|
|
streamWindowMax uint32
|
|
// The max size for the write buffer of a stream
|
|
streamWriteBufferMaxLen int
|
|
// r is a reference to the underlying connection used when shutting down.
|
|
r io.Closer
|
|
// metricsUpdater is used to report metrics
|
|
metricsUpdater muxMetricsUpdater
|
|
// bytesRead is the amount of bytes read from data frames since the last time we called metricsUpdater.updateInBoundBytes()
|
|
bytesRead *AtomicCounter
|
|
// dictionaries holds the h2 cross-stream compression dictionaries
|
|
dictionaries h2Dictionaries
|
|
}
|
|
|
|
// Shutdown blocks new streams from being created.
|
|
// It returns a channel that is closed once the last stream has closed.
|
|
func (r *MuxReader) Shutdown() <-chan struct{} {
|
|
done, alreadyInProgress := r.streams.Shutdown()
|
|
if alreadyInProgress {
|
|
return done
|
|
}
|
|
r.sendGoAway(http2.ErrCodeNo)
|
|
go func() {
|
|
// close reader side when last stream ends; this will cause the writer to abort
|
|
<-done
|
|
r.r.Close()
|
|
}()
|
|
return done
|
|
}
|
|
|
|
func (r *MuxReader) run(parentLogger *log.Entry) error {
|
|
logger := parentLogger.WithFields(log.Fields{
|
|
"subsystem": "mux",
|
|
"dir": "read",
|
|
})
|
|
defer logger.Debug("event loop finished")
|
|
|
|
// routine to periodically update bytesRead
|
|
go func() {
|
|
tickC := time.Tick(updateFreq)
|
|
for {
|
|
select {
|
|
case <-r.abortChan:
|
|
return
|
|
case <-tickC:
|
|
r.metricsUpdater.updateInBoundBytes(r.bytesRead.Count())
|
|
}
|
|
}
|
|
}()
|
|
|
|
for {
|
|
frame, err := r.f.ReadFrame()
|
|
if err != nil {
|
|
errLogger := logger.WithError(err)
|
|
if errorDetail := r.f.ErrorDetail(); errorDetail != nil {
|
|
errLogger = errLogger.WithField("errorDetail", errorDetail)
|
|
}
|
|
switch e := err.(type) {
|
|
case http2.StreamError:
|
|
errLogger.Warn("stream error")
|
|
r.streamError(e.StreamID, e.Code)
|
|
case http2.ConnectionError:
|
|
errLogger.Warn("connection error")
|
|
return r.connectionError(err)
|
|
default:
|
|
if isConnectionClosedError(err) {
|
|
if r.streams.Len() == 0 {
|
|
// don't log the error here -- that would just be extra noise
|
|
logger.Debug("shutting down")
|
|
return nil
|
|
}
|
|
errLogger.Warn("connection closed unexpectedly")
|
|
return err
|
|
} else {
|
|
errLogger.Warn("frame read error")
|
|
return r.connectionError(err)
|
|
}
|
|
}
|
|
}
|
|
r.connActive.Signal()
|
|
logger.WithField("data", frame).Debug("read frame")
|
|
switch f := frame.(type) {
|
|
case *http2.DataFrame:
|
|
err = r.receiveFrameData(f, logger)
|
|
case *http2.MetaHeadersFrame:
|
|
err = r.receiveHeaderData(f)
|
|
case *http2.RSTStreamFrame:
|
|
streamID := f.Header().StreamID
|
|
if streamID == 0 {
|
|
return ErrInvalidStream
|
|
}
|
|
if stream, ok := r.streams.Get(streamID); ok {
|
|
stream.Close()
|
|
}
|
|
r.streams.Delete(streamID)
|
|
case *http2.PingFrame:
|
|
r.receivePingData(f)
|
|
case *http2.GoAwayFrame:
|
|
err = r.receiveGoAway(f)
|
|
// The receiver of a flow-controlled frame sends a WINDOW_UPDATE frame as it
|
|
// consumes data and frees up space in flow-control windows
|
|
case *http2.WindowUpdateFrame:
|
|
err = r.updateStreamWindow(f)
|
|
case *http2.UnknownFrame:
|
|
switch f.Header().Type {
|
|
case FrameUseDictionary:
|
|
err = r.receiveUseDictionary(f)
|
|
case FrameSetDictionary:
|
|
err = r.receiveSetDictionary(f)
|
|
default:
|
|
err = ErrUnexpectedFrameType
|
|
}
|
|
default:
|
|
err = ErrUnexpectedFrameType
|
|
}
|
|
if err != nil {
|
|
logger.WithField("data", frame).WithError(err).Debug("frame error")
|
|
return r.connectionError(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *MuxReader) newMuxedStream(streamID uint32) *MuxedStream {
|
|
return &MuxedStream{
|
|
streamID: streamID,
|
|
readBuffer: NewSharedBuffer(),
|
|
writeBuffer: &bytes.Buffer{},
|
|
writeBufferMaxLen: r.streamWriteBufferMaxLen,
|
|
writeBufferHasSpace: make(chan struct{}, 1),
|
|
receiveWindow: r.initialStreamWindow,
|
|
receiveWindowCurrentMax: r.initialStreamWindow,
|
|
receiveWindowMax: r.streamWindowMax,
|
|
sendWindow: r.initialStreamWindow,
|
|
readyList: r.readyList,
|
|
dictionaries: r.dictionaries,
|
|
}
|
|
}
|
|
|
|
// getStreamForFrame returns a stream if valid, or an error describing why the stream could not be returned.
|
|
func (r *MuxReader) getStreamForFrame(frame http2.Frame) (*MuxedStream, error) {
|
|
sid := frame.Header().StreamID
|
|
if sid == 0 {
|
|
return nil, ErrUnexpectedFrameType
|
|
}
|
|
if stream, ok := r.streams.Get(sid); ok {
|
|
return stream, nil
|
|
}
|
|
if r.streams.IsLocalStreamID(sid) {
|
|
// no stream available, but no error
|
|
return nil, ErrClosedStream
|
|
}
|
|
if sid < r.streams.LastPeerStreamID() {
|
|
// no stream available, stream closed error
|
|
return nil, ErrClosedStream
|
|
}
|
|
return nil, ErrUnknownStream
|
|
}
|
|
|
|
func (r *MuxReader) defaultStreamErrorHandler(err error, header http2.FrameHeader) error {
|
|
if header.Flags.Has(http2.FlagHeadersEndStream) {
|
|
return nil
|
|
} else if err == ErrUnknownStream || err == ErrClosedStream {
|
|
return r.streamError(header.StreamID, http2.ErrCodeStreamClosed)
|
|
} else {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Receives header frames from a stream. A non-nil error is a connection error.
|
|
func (r *MuxReader) receiveHeaderData(frame *http2.MetaHeadersFrame) error {
|
|
var stream *MuxedStream
|
|
sid := frame.Header().StreamID
|
|
if sid == 0 {
|
|
return ErrUnexpectedFrameType
|
|
}
|
|
newStream := r.streams.IsPeerStreamID(sid)
|
|
if newStream {
|
|
// header request
|
|
// TODO support trailers (if stream exists)
|
|
ok, err := r.streams.AcquirePeerID(sid)
|
|
if !ok {
|
|
// ignore new streams while shutting down
|
|
return r.streamError(sid, err)
|
|
}
|
|
stream = r.newMuxedStream(sid)
|
|
// Set stream. Returns false if a stream already existed with that ID or we are shutting down, return false.
|
|
if !r.streams.Set(stream) {
|
|
// got HEADERS frame for an existing stream
|
|
// TODO support trailers
|
|
return r.streamError(sid, http2.ErrCodeInternal)
|
|
}
|
|
} else {
|
|
// header response
|
|
var err error
|
|
if stream, err = r.getStreamForFrame(frame); err != nil {
|
|
return r.defaultStreamErrorHandler(err, frame.Header())
|
|
}
|
|
}
|
|
headers := make([]Header, 0, len(frame.Fields))
|
|
for _, header := range frame.Fields {
|
|
switch header.Name {
|
|
case ":method":
|
|
stream.method = header.Value
|
|
case ":path":
|
|
u, err := url.Parse(header.Value)
|
|
if err == nil {
|
|
stream.path = u.Path
|
|
}
|
|
case "accept-encoding":
|
|
// remove accept-encoding if dictionaries are enabled
|
|
if r.dictionaries.write != nil {
|
|
continue
|
|
}
|
|
case CloudflaredProxyTunnelHostnameHeader:
|
|
stream.tunnelHostname = TunnelHostname(header.Value)
|
|
}
|
|
headers = append(headers, Header{Name: header.Name, Value: header.Value})
|
|
}
|
|
stream.Headers = headers
|
|
if frame.Header().Flags.Has(http2.FlagHeadersEndStream) {
|
|
stream.receiveEOF()
|
|
return nil
|
|
}
|
|
if newStream {
|
|
go r.handleStream(stream)
|
|
} else {
|
|
close(stream.responseHeadersReceived)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *MuxReader) handleStream(stream *MuxedStream) {
|
|
defer stream.Close()
|
|
r.handler.ServeStream(stream)
|
|
}
|
|
|
|
// Receives a data frame from a stream. A non-nil error is a connection error.
|
|
func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.Entry) error {
|
|
logger := parentLogger.WithField("stream", frame.Header().StreamID)
|
|
stream, err := r.getStreamForFrame(frame)
|
|
if err != nil {
|
|
return r.defaultStreamErrorHandler(err, frame.Header())
|
|
}
|
|
data := frame.Data()
|
|
if len(data) > 0 {
|
|
n, err := stream.readBuffer.Write(data)
|
|
if err != nil {
|
|
return r.streamError(stream.streamID, http2.ErrCodeInternal)
|
|
}
|
|
r.bytesRead.IncrementBy(uint64(n))
|
|
}
|
|
if frame.Header().Flags.Has(http2.FlagDataEndStream) {
|
|
if stream.receiveEOF() {
|
|
r.streams.Delete(stream.streamID)
|
|
logger.Debug("stream closed")
|
|
} else {
|
|
logger.Debug("shutdown receive side")
|
|
}
|
|
return nil
|
|
}
|
|
if !stream.consumeReceiveWindow(uint32(len(data))) {
|
|
return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
|
|
}
|
|
r.metricsUpdater.updateReceiveWindow(stream.getReceiveWindow())
|
|
return nil
|
|
}
|
|
|
|
// Receive a PING from the peer. Update RTT and send/receive window metrics if it's an ACK.
|
|
func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
|
|
ts := int64(binary.LittleEndian.Uint64(frame.Data[:]))
|
|
if !frame.IsAck() {
|
|
r.pingTimestamp.Set(ts)
|
|
return
|
|
}
|
|
|
|
// Update the computed RTT aggregations with a new measurement.
|
|
// `ts` is the time that the probe was sent.
|
|
// We assume that `time.Now()` is the time we received that probe.
|
|
r.metricsUpdater.updateRTT(&roundTripMeasurement{
|
|
receiveTime: time.Now(),
|
|
sendTime: time.Unix(0, ts),
|
|
})
|
|
}
|
|
|
|
// Receive a GOAWAY from the peer. Gracefully shut down our connection.
|
|
func (r *MuxReader) receiveGoAway(frame *http2.GoAwayFrame) error {
|
|
r.Shutdown()
|
|
// Close all streams above the last processed stream
|
|
lastStream := r.streams.LastLocalStreamID()
|
|
for i := frame.LastStreamID + 2; i <= lastStream; i++ {
|
|
if stream, ok := r.streams.Get(i); ok {
|
|
stream.Close()
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Receive a USE_DICTIONARY from the peer. Setup dictionary for stream.
|
|
func (r *MuxReader) receiveUseDictionary(frame *http2.UnknownFrame) error {
|
|
payload := frame.Payload()
|
|
streamID := frame.StreamID
|
|
|
|
// Check frame is formatted properly
|
|
if len(payload) != 1 {
|
|
return r.streamError(streamID, http2.ErrCodeProtocol)
|
|
}
|
|
|
|
stream, err := r.getStreamForFrame(frame)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if stream.receivedUseDict == true || stream.dictionaries.read == nil {
|
|
return r.streamError(streamID, http2.ErrCodeInternal)
|
|
}
|
|
|
|
stream.receivedUseDict = true
|
|
dictID := payload[0]
|
|
|
|
dictReader := stream.dictionaries.read.newReader(stream.readBuffer.(*SharedBuffer), dictID)
|
|
if dictReader == nil {
|
|
return r.streamError(streamID, http2.ErrCodeInternal)
|
|
}
|
|
|
|
stream.readBufferLock.Lock()
|
|
stream.readBuffer = dictReader
|
|
stream.readBufferLock.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Receive a SET_DICTIONARY from the peer. Update dictionaries accordingly.
|
|
func (r *MuxReader) receiveSetDictionary(frame *http2.UnknownFrame) (err error) {
|
|
|
|
payload := frame.Payload()
|
|
flags := frame.Flags
|
|
|
|
stream, err := r.getStreamForFrame(frame)
|
|
if err != nil && err != ErrClosedStream {
|
|
return err
|
|
}
|
|
reader, ok := stream.readBuffer.(*h2DictionaryReader)
|
|
if !ok {
|
|
return r.streamError(frame.StreamID, http2.ErrCodeProtocol)
|
|
}
|
|
|
|
// A SetDictionary frame consists of several
|
|
// Dictionary-Entries that specify how existing dictionaries
|
|
// are to be updated using the current stream data
|
|
// +---------------+---------------+
|
|
// | Dictionary-Entry (+) ...
|
|
// +---------------+---------------+
|
|
|
|
for {
|
|
// Each Dictionary-Entry is formatted as follows:
|
|
// +-------------------------------+
|
|
// | Dictionary-ID (8) |
|
|
// +---+---------------------------+
|
|
// | P | Size (7+) |
|
|
// +---+---------------------------+
|
|
// | E?| D?| Truncate? (6+) |
|
|
// +---+---------------------------+
|
|
// | Offset? (8+) |
|
|
// +-------------------------------+
|
|
|
|
var size, truncate, offset uint64
|
|
var p, e, d bool
|
|
|
|
// Parse a single Dictionary-Entry
|
|
if len(payload) < 2 { // Must have at least id and size
|
|
return MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
|
|
}
|
|
|
|
dictID := uint8(payload[0])
|
|
p = (uint8(payload[1]) >> 7) == 1
|
|
payload, size, err = http2ReadVarInt(7, payload[1:])
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if flags.Has(FlagSetDictionaryAppend) {
|
|
// Presence of FlagSetDictionaryAppend means we expect e, d and truncate
|
|
if len(payload) < 1 {
|
|
return MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
|
|
}
|
|
e = (uint8(payload[0]) >> 7) == 1
|
|
d = (uint8((payload[0])>>6) & 1) == 1
|
|
payload, truncate, err = http2ReadVarInt(6, payload)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
if flags.Has(FlagSetDictionaryOffset) {
|
|
// Presence of FlagSetDictionaryOffset means we expect offset
|
|
if len(payload) < 1 {
|
|
return MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
|
|
}
|
|
payload, offset, err = http2ReadVarInt(8, payload)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
setdict := setDictRequest{streamID: stream.streamID,
|
|
dictID: dictID,
|
|
dictSZ: size,
|
|
truncate: truncate,
|
|
offset: offset,
|
|
P: p,
|
|
E: e,
|
|
D: d}
|
|
|
|
// Find the right dictionary
|
|
dict, err := r.dictionaries.read.getDictByID(dictID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Register a dictionary update order for the dictionary and reader
|
|
updateEntry := &dictUpdate{reader: reader, dictionary: dict, s: setdict}
|
|
dict.queue = append(dict.queue, updateEntry)
|
|
reader.queue = append(reader.queue, updateEntry)
|
|
// End of frame
|
|
if len(payload) == 0 {
|
|
break
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Receives header frames from a stream. A non-nil error is a connection error.
|
|
func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
|
|
stream, err := r.getStreamForFrame(frame)
|
|
if err != nil && err != ErrUnknownStream && err != ErrClosedStream {
|
|
return err
|
|
}
|
|
if stream == nil {
|
|
// ignore window updates on closed streams
|
|
return nil
|
|
}
|
|
stream.replenishSendWindow(frame.Increment)
|
|
r.metricsUpdater.updateSendWindow(stream.getSendWindow())
|
|
return nil
|
|
}
|
|
|
|
// Raise a stream processing error, closing the stream. Runs on the write thread.
|
|
func (r *MuxReader) streamError(streamID uint32, e http2.ErrCode) error {
|
|
r.streamErrors.RaiseError(streamID, e)
|
|
return nil
|
|
}
|
|
|
|
func (r *MuxReader) connectionError(err error) error {
|
|
http2Code := http2.ErrCodeInternal
|
|
switch e := err.(type) {
|
|
case http2.ConnectionError:
|
|
http2Code = http2.ErrCode(e)
|
|
case MuxerProtocolError:
|
|
http2Code = e.h2code
|
|
}
|
|
r.sendGoAway(http2Code)
|
|
return err
|
|
}
|
|
|
|
// Instruct the writer to send a GOAWAY message if possible. This may fail in
|
|
// the case where an existing GOAWAY message is in flight or the writer event
|
|
// loop already ended.
|
|
func (r *MuxReader) sendGoAway(errCode http2.ErrCode) {
|
|
select {
|
|
case r.goAwayChan <- errCode:
|
|
default:
|
|
}
|
|
}
|