cloudflared-mirror/h2mux/muxreader.go

502 lines
14 KiB
Go

package h2mux
import (
"bytes"
"encoding/binary"
"io"
"net/url"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
)
// MuxReader is the read side of an h2mux connection. Its run loop reads HTTP/2
// frames from f and dispatches them to the receive* methods, reporting stream
// errors, GOAWAY requests and metrics to the writer and the metrics updater
// through the fields below.
type MuxReader struct {
// f is used to read HTTP2 frames.
f *http2.Framer
// handler provides a callback to receive new streams. if nil, new streams cannot be accepted.
handler MuxedStreamHandler
// streams tracks currently-open streams.
streams *activeStreamMap
// readyList is used to signal writable streams.
readyList *ReadyList
// streamErrors lets us report stream errors to the MuxWriter.
streamErrors *StreamErrorMap
// goAwayChan is used to tell the writer to send a GOAWAY message.
goAwayChan chan<- http2.ErrCode
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
abortChan <-chan struct{}
// pingTimestamp is an atomic value containing the latest received ping timestamp.
pingTimestamp *PingTimestamp
// connActive is used to signal to the writer that something happened on the connection.
// This is used to clear idle timeout disconnection deadlines.
connActive Signal
// The initial value for the send and receive window of a new stream.
initialStreamWindow uint32
// The max value for the send window of a stream.
streamWindowMax uint32
// The max size for the write buffer of a stream
streamWriteBufferMaxLen int
// r is a reference to the underlying connection used when shutting down.
r io.Closer
// updateRTTChan is the channel to send new RTT measurement to muxerMetricsUpdater
updateRTTChan chan<- *roundTripMeasurement
// updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
updateReceiveWindowChan chan<- uint32
// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
updateSendWindowChan chan<- uint32
// bytesRead is the amount of bytes read from data frame since the last time we send bytes read to metrics
bytesRead *AtomicCounter
// updateInBoundBytesChan is the channel to send bytesRead counts to muxerMetricsUpdater
updateInBoundBytesChan chan<- uint64
// dictionaries holds the h2 cross-stream compression dictionaries
dictionaries h2Dictionaries
}
// Shutdown starts a graceful shutdown of the reader side.
//
// It asks r.streams to shut down, which yields a channel that is closed once
// the last stream ends. If that channel is nil, Shutdown does nothing more
// (presumably because a shutdown is already in progress; confirm in
// activeStreamMap). Otherwise it requests a GOAWAY with NO_ERROR from the
// writer and, in the background, closes the underlying connection r.r once
// every stream has ended. Closing the connection is what makes the writer
// abort.
func (r *MuxReader) Shutdown() {
	streamsDone := r.streams.Shutdown()
	if streamsDone == nil {
		return
	}

	r.sendGoAway(http2.ErrCodeNo)

	go func() {
		// Wait for the final stream, then drop the connection.
		<-streamsDone
		_ = r.r.Close()
	}()
}
// run is the reader's event loop. It reads HTTP/2 frames from r.f until the
// connection fails or is closed, signalling r.connActive for every frame and
// dispatching it to the matching receive* handler.
//
// Return values:
//   - nil when the connection is closed while no streams are open;
//   - the read error when the connection is closed with streams still open;
//   - otherwise the error that made the connection unusable, after a GOAWAY
//     has been requested via connectionError.
//
// A stream-level read error (http2.StreamError) resets only that stream and
// the loop keeps reading.
//
// run also starts a goroutine that periodically reports r.bytesRead to the
// metrics updater. That goroutine exits when r.abortChan becomes readable,
// not when run returns.
func (r *MuxReader) run(parentLogger *log.Entry) error {
	logger := parentLogger.WithFields(log.Fields{
		"subsystem": "mux",
		"dir":       "read",
	})
	defer logger.Debug("event loop finished")

	// routine to periodically update bytesRead
	go func() {
		if updateFreq <= 0 {
			// time.NewTicker panics on a non-positive period (time.Tick just
			// returned a nil channel), so there is nothing to report.
			return
		}
		// Unlike time.Tick, a Ticker can be stopped, so it is not leaked once
		// the muxer aborts.
		ticker := time.NewTicker(updateFreq)
		defer ticker.Stop()
		for {
			select {
			case <-r.abortChan:
				return
			case <-ticker.C:
				// Don't block forever on the metrics channel if the muxer
				// aborts while nobody is receiving.
				select {
				case r.updateInBoundBytesChan <- r.bytesRead.Count():
				case <-r.abortChan:
					return
				}
			}
		}
	}()

	for {
		frame, err := r.f.ReadFrame()
		if err != nil {
			switch e := err.(type) {
			case http2.StreamError:
				// A stream error only affects one stream: reset it and keep
				// reading. frame is nil here, so it must not reach the dispatch
				// switch below. There it would be reported as an unexpected
				// frame type and tear down the whole connection.
				logger.WithError(err).Warn("stream error")
				r.streamError(e.StreamID, e.Code)
				r.connActive.Signal()
				continue
			case http2.ConnectionError:
				logger.WithError(err).Warn("connection error")
				return r.connectionError(err)
			default:
				if !isConnectionClosedError(err) {
					logger.WithError(err).Warn("frame read error")
					return r.connectionError(err)
				}
				if r.streams.Len() == 0 {
					logger.Debug("shutting down")
					return nil
				}
				logger.Warn("connection closed unexpectedly")
				return err
			}
		}

		r.connActive.Signal()
		logger.WithField("data", frame).Debug("read frame")
		switch f := frame.(type) {
		case *http2.DataFrame:
			err = r.receiveFrameData(f, logger)
		case *http2.MetaHeadersFrame:
			err = r.receiveHeaderData(f)
		case *http2.RSTStreamFrame:
			if streamID := f.Header().StreamID; streamID == 0 {
				// Route through connectionError, like every other protocol
				// violation, so the peer is sent a GOAWAY.
				err = ErrInvalidStream
			} else {
				r.streams.Delete(streamID)
			}
		case *http2.PingFrame:
			r.receivePingData(f)
		case *http2.GoAwayFrame:
			err = r.receiveGoAway(f)
		// The receiver of a flow-controlled frame sends a WINDOW_UPDATE frame as it
		// consumes data and frees up space in flow-control windows
		case *http2.WindowUpdateFrame:
			err = r.updateStreamWindow(f)
		case *http2.UnknownFrame:
			// h2mux extension frames.
			switch f.Header().Type {
			case FrameUseDictionary:
				err = r.receiveUseDictionary(f)
			case FrameSetDictionary:
				err = r.receiveSetDictionary(f)
			default:
				err = ErrUnexpectedFrameType
			}
		default:
			err = ErrUnexpectedFrameType
		}
		if err != nil {
			logger.WithField("data", frame).WithError(err).Debug("frame error")
			return r.connectionError(err)
		}
	}
}
// newMuxedStream builds the state for the stream with the given ID.
//
// Both flow-control windows start at r.initialStreamWindow, and the receive
// window may grow up to r.streamWindowMax. The write buffer is capped at
// r.streamWriteBufferMaxLen. writeBufferHasSpace is a one-slot buffered
// channel. The ready list and compression dictionaries are shared with the
// reader.
func (r *MuxReader) newMuxedStream(streamID uint32) *MuxedStream {
	window := r.initialStreamWindow
	stream := &MuxedStream{
		streamID: streamID,

		readBuffer:          NewSharedBuffer(),
		writeBuffer:         new(bytes.Buffer),
		writeBufferMaxLen:   r.streamWriteBufferMaxLen,
		writeBufferHasSpace: make(chan struct{}, 1),

		receiveWindow:           window,
		receiveWindowCurrentMax: window,
		receiveWindowMax:        r.streamWindowMax,
		sendWindow:              window,

		readyList:    r.readyList,
		dictionaries: r.dictionaries,
	}
	return stream
}
// getStreamForFrame returns the open stream that frame is addressed to.
//
// When no open stream matches, it returns a nil stream and one of:
//   - ErrUnexpectedFrameType if the frame targets stream 0;
//   - ErrClosedStream if the ID is locally initiated, or is below the highest
//     peer stream ID seen so far (r.streams.LastPeerStreamID);
//   - ErrUnknownStream otherwise.
func (r *MuxReader) getStreamForFrame(frame http2.Frame) (*MuxedStream, error) {
	id := frame.Header().StreamID
	if id == 0 {
		return nil, ErrUnexpectedFrameType
	}

	stream, found := r.streams.Get(id)
	switch {
	case found:
		return stream, nil
	case r.streams.IsLocalStreamID(id):
		// One of ours that is no longer in the map.
		return nil, ErrClosedStream
	case id < r.streams.LastPeerStreamID():
		// A peer stream that has already come and gone.
		return nil, ErrClosedStream
	default:
		return nil, ErrUnknownStream
	}
}
// defaultStreamErrorHandler turns a getStreamForFrame error into the value a
// frame handler should return. A nil result means "not a connection error".
//
// If the frame has the END_STREAM bit set, the error is dropped and nil is
// returned, whatever the error is. Otherwise a missing stream
// (ErrUnknownStream or ErrClosedStream) raises a STREAM_CLOSED stream error.
// Any other error is passed through unchanged as a connection error.
func (r *MuxReader) defaultStreamErrorHandler(err error, header http2.FrameHeader) error {
	switch {
	case header.Flags.Has(http2.FlagHeadersEndStream):
		// golang.org/x/net/http2 uses the same bit (0x1) for
		// FlagDataEndStream, so this also matches DATA frames.
		return nil
	case err == ErrUnknownStream, err == ErrClosedStream:
		return r.streamError(header.StreamID, http2.ErrCodeStreamClosed)
	default:
		return err
	}
}
// receiveHeaderData handles a HEADERS frame, which the framer delivers as a
// MetaHeadersFrame with its header block already decoded. A non-nil return
// is a connection error. Per-stream problems are reported with streamError,
// which returns nil.
//
// Frame handling:
//   - On a peer-initiated stream ID the frame opens a new stream. The stream
//     is registered in r.streams and served by r.handler on its own goroutine.
//   - On a locally-initiated stream it carries the response headers, and
//     stream.responseHeadersReceived is closed to wake whoever is waiting.
//   - A second HEADERS frame on a local stream carries trailers. Trailers are
//     not supported (their fields are discarded), but their END_STREAM is
//     honoured.
//
// END_STREAM closes the stream's receive side. The stream is still
// dispatched, because a request or response without a body is complete, not
// something to discard.
func (r *MuxReader) receiveHeaderData(frame *http2.MetaHeadersFrame) error {
	var stream *MuxedStream
	sid := frame.Header().StreamID
	if sid == 0 {
		return ErrUnexpectedFrameType
	}
	endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream)
	newStream := r.streams.IsPeerStreamID(sid)
	if newStream {
		// header request
		if r.handler == nil {
			// Without a handler, peer-initiated streams cannot be served.
			// Refuse them instead of starting a goroutine that would call a
			// nil handler.
			return r.streamError(sid, http2.ErrCodeRefusedStream)
		}
		ok, err := r.streams.AcquirePeerID(sid)
		if !ok {
			// ignore new streams while shutting down
			return r.streamError(sid, err)
		}
		stream = r.newMuxedStream(sid)
		// Set returns false if a stream already existed with that ID or we are shutting down.
		if !r.streams.Set(stream) {
			// got HEADERS frame for an existing stream
			// TODO support trailers
			return r.streamError(sid, http2.ErrCodeInternal)
		}
	} else {
		// header response
		var err error
		if stream, err = r.getStreamForFrame(frame); err != nil {
			return r.defaultStreamErrorHandler(err, frame.Header())
		}
		select {
		case <-stream.responseHeadersReceived:
			// The response headers were already delivered, so this frame
			// carries trailers (TODO support trailers). Leave stream.Headers
			// untouched, because the consumer may already be reading it. Do
			// not close responseHeadersReceived a second time; that would
			// panic. A trailing HEADERS frame must end the stream
			// (RFC 7540 §8.1), so one without END_STREAM is malformed.
			if !endStream {
				return r.streamError(sid, http2.ErrCodeProtocol)
			}
			if stream.receiveEOF() {
				r.streams.Delete(sid)
			}
			return nil
		default:
		}
	}
	headers := make([]Header, 0, len(frame.Fields))
	for _, header := range frame.Fields {
		switch header.Name {
		case ":method":
			stream.method = header.Value
		case ":path":
			// Best effort: an unparsable path leaves stream.path empty.
			u, err := url.Parse(header.Value)
			if err == nil {
				stream.path = u.Path
			}
		case "accept-encoding":
			// remove accept-encoding if dictionaries are enabled
			if r.dictionaries.write != nil {
				continue
			}
		}
		headers = append(headers, Header{Name: header.Name, Value: header.Value})
	}
	stream.Headers = headers
	if endStream {
		// No body follows. Close the receive side before dispatching, so the
		// handler (or the waiter on the response) sees EOF straight away.
		// When receiveEOF reports true, drop the stream from the map, as
		// receiveFrameData does.
		if stream.receiveEOF() {
			r.streams.Delete(sid)
		}
	}
	if newStream {
		go r.handleStream(stream)
	} else {
		close(stream.responseHeadersReceived)
	}
	return nil
}
// handleStream serves a peer-initiated stream with r.handler and closes the
// stream when the handler returns.
//
// The handler is documented as possibly nil ("new streams cannot be
// accepted"). In that case the stream is closed immediately, instead of
// panicking on a method call through a nil handler.
func (r *MuxReader) handleStream(stream *MuxedStream) {
	defer stream.Close()
	if r.handler == nil {
		return
	}
	r.handler.ServeStream(stream)
}
// receiveFrameData handles a DATA frame. A non-nil return is a connection
// error.
//
// A non-empty payload is appended to the stream's read buffer and counted in
// r.bytesRead. A buffer write failure raises an INTERNAL_ERROR stream error.
// Then one of two things happens:
//   - Without END_STREAM, the payload length is charged to the stream's
//     receive window. An overrun raises a FLOW_CONTROL_ERROR stream error;
//     otherwise the new window size is sent to r.updateReceiveWindowChan.
//   - With END_STREAM, the stream's receive side is closed. It is removed from
//     r.streams when receiveEOF reports true (logged as "stream closed").
//
// NOTE(review): the final (END_STREAM) frame's payload is never charged to
// the receive window.
func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.Entry) error {
	hdr := frame.Header()
	streamLogger := parentLogger.WithField("stream", hdr.StreamID)
	stream, lookupErr := r.getStreamForFrame(frame)
	if lookupErr != nil {
		return r.defaultStreamErrorHandler(lookupErr, hdr)
	}

	payload := frame.Data()
	if len(payload) != 0 {
		written, writeErr := stream.readBuffer.Write(payload)
		if writeErr != nil {
			return r.streamError(stream.streamID, http2.ErrCodeInternal)
		}
		r.bytesRead.IncrementBy(uint64(written))
	}

	if !hdr.Flags.Has(http2.FlagDataEndStream) {
		// More data may follow: account for this frame in the receive window.
		if !stream.consumeReceiveWindow(uint32(len(payload))) {
			return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
		}
		r.updateReceiveWindowChan <- stream.getReceiveWindow()
		return nil
	}

	// END_STREAM: the peer has finished sending on this stream.
	if stream.receiveEOF() {
		r.streams.Delete(stream.streamID)
		streamLogger.Debug("stream closed")
	} else {
		streamLogger.Debug("shutdown receive side")
	}
	return nil
}
// receivePingData handles a PING frame. The 8-byte payload is decoded as a
// little-endian int64 holding a Unix timestamp in nanoseconds.
//
// Two cases:
//   - A PING from the peer: the timestamp is stored in r.pingTimestamp
//     (presumably so the writer can echo it in the ACK; confirm in the writer).
//   - A PING ACK: a round-trip measurement is sent to r.updateRTTChan, with
//     the decoded timestamp as the send time and the current time as the
//     receive time. That send blocks until the metrics updater receives it.
func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
	stamp := int64(binary.LittleEndian.Uint64(frame.Data[:]))
	if frame.IsAck() {
		// The ACK echoes the time the probe was sent; treat "now" as the
		// moment the reply arrived.
		r.updateRTTChan <- &roundTripMeasurement{
			receiveTime: time.Now(),
			sendTime:    time.Unix(0, stamp),
		}
		return
	}
	r.pingTimestamp.Set(stamp)
}
// receiveGoAway handles a GOAWAY from the peer. It starts a graceful shutdown
// (Shutdown) and closes every locally-initiated stream whose ID is greater
// than frame.LastStreamID. The peer has declared that it did not process
// those streams (RFC 7540 §6.8). Peer-initiated streams are not closed here;
// GOAWAY's last-stream-id says nothing about them. It always returns nil.
func (r *MuxReader) receiveGoAway(frame *http2.GoAwayFrame) error {
	r.Shutdown()
	// Close all local streams above the last processed stream.
	//
	// Start right after LastStreamID. It is 0 when the peer processed
	// nothing, and the old "+2" start then skipped stream 1 for odd local IDs.
	// Skip peer-initiated IDs, which the old loop closed as well.
	lastStream := r.streams.LastLocalStreamID()
	for id := frame.LastStreamID + 1; id <= lastStream; id++ {
		if !r.streams.IsLocalStreamID(id) {
			continue
		}
		if stream, ok := r.streams.Get(id); ok {
			stream.Close()
		}
	}
	return nil
}
// receiveUseDictionary handles a USE_DICTIONARY frame. Its single payload
// byte is a dictionary ID, and the stream's read buffer is replaced, under
// readBufferLock, by a dictionary reader built over the original
// *SharedBuffer.
//
// A non-nil return is a connection error. The following raise stream errors
// instead:
//   - a malformed payload;
//   - a frame on a stream that no longer exists;
//   - a repeated USE_DICTIONARY, or a stream without read dictionaries;
//   - a read buffer that is not the original *SharedBuffer;
//   - failure to build the reader.
func (r *MuxReader) receiveUseDictionary(frame *http2.UnknownFrame) error {
	payload := frame.Payload()
	streamID := frame.StreamID
	// Check frame is formatted properly
	if len(payload) != 1 {
		return r.streamError(streamID, http2.ErrCodeProtocol)
	}
	stream, err := r.getStreamForFrame(frame)
	switch err {
	case nil:
	case ErrUnknownStream, ErrClosedStream:
		// The stream is already gone. Reject the frame like other frames on
		// dead streams, instead of tearing down the whole connection.
		return r.streamError(streamID, http2.ErrCodeStreamClosed)
	default:
		return err
	}
	if stream.receivedUseDict || stream.dictionaries.read == nil {
		return r.streamError(streamID, http2.ErrCodeInternal)
	}
	stream.receivedUseDict = true
	// The dictionary reader wraps the stream's original buffer. Check the type
	// rather than risk a panic if some other reader is installed.
	shared, ok := stream.readBuffer.(*SharedBuffer)
	if !ok {
		return r.streamError(streamID, http2.ErrCodeInternal)
	}
	dictID := payload[0]
	dictReader := stream.dictionaries.read.newReader(shared, dictID)
	if dictReader == nil {
		return r.streamError(streamID, http2.ErrCodeInternal)
	}
	stream.readBufferLock.Lock()
	stream.readBuffer = dictReader
	stream.readBufferLock.Unlock()
	return nil
}
// receiveSetDictionary handles a SET_DICTIONARY frame. The payload is a
// sequence of Dictionary-Entries. Each one is parsed into a setDictRequest and
// queued as a dictUpdate on both the dictionary it names and the stream's
// h2DictionaryReader.
//
// A non-nil return is a connection error. That covers a missing stream, a
// truncated or malformed entry, and an unknown dictionary ID. If the stream is
// not reading through an h2DictionaryReader (no prior USE_DICTIONARY), a
// PROTOCOL_ERROR stream error is raised instead.
func (r *MuxReader) receiveSetDictionary(frame *http2.UnknownFrame) (err error) {
	payload := frame.Payload()
	flags := frame.Flags
	stream, err := r.getStreamForFrame(frame)
	if err != nil {
		// This includes ErrClosedStream. Previously that case fell through
		// with a nil stream and panicked on the next line. With the stream
		// gone there is no reader to attach the updates to. Dropping them
		// would presumably desynchronise the shared dictionaries from the
		// peer's (confirm), so fail the connection instead.
		return err
	}
	reader, ok := stream.readBuffer.(*h2DictionaryReader)
	if !ok {
		return r.streamError(frame.StreamID, http2.ErrCodeProtocol)
	}
	// A SetDictionary frame consists of several
	// Dictionary-Entries that specify how existing dictionaries
	// are to be updated using the current stream data
	// +---------------+---------------+
	// |   Dictionary-Entry (+)    ...
	// +---------------+---------------+
	for {
		// Each Dictionary-Entry is formatted as follows:
		// +-------------------------------+
		// |       Dictionary-ID (8)       |
		// +---+---------------------------+
		// | P |        Size (7+)          |
		// +---+---------------------------+
		// | E?| D?|  Truncate? (6+)       |
		// +---+---------------------------+
		// |         Offset? (8+)          |
		// +-------------------------------+
		var size, truncate, offset uint64
		var p, e, d bool
		// Parse a single Dictionary-Entry
		if len(payload) < 2 { // Must have at least id and size
			return MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
		}
		dictID := uint8(payload[0])
		p = (uint8(payload[1]) >> 7) == 1
		payload, size, err = http2ReadVarInt(7, payload[1:])
		if err != nil {
			return err
		}
		if flags.Has(FlagSetDictionaryAppend) {
			// Presence of FlagSetDictionaryAppend means we expect e, d and truncate
			if len(payload) < 1 {
				return MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
			}
			e = (uint8(payload[0]) >> 7) == 1
			d = (uint8((payload[0])>>6) & 1) == 1
			payload, truncate, err = http2ReadVarInt(6, payload)
			if err != nil {
				return err
			}
		}
		if flags.Has(FlagSetDictionaryOffset) {
			// Presence of FlagSetDictionaryOffset means we expect offset
			if len(payload) < 1 {
				return MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
			}
			payload, offset, err = http2ReadVarInt(8, payload)
			if err != nil {
				return err
			}
		}
		setdict := setDictRequest{
			streamID: stream.streamID,
			dictID:   dictID,
			dictSZ:   size,
			truncate: truncate,
			offset:   offset,
			P:        p,
			E:        e,
			D:        d,
		}
		// Find the right dictionary
		dict, lookupErr := r.dictionaries.read.getDictByID(dictID)
		if lookupErr != nil {
			return lookupErr
		}
		// Register a dictionary update order for the dictionary and reader
		updateEntry := &dictUpdate{reader: reader, dictionary: dict, s: setdict}
		dict.queue = append(dict.queue, updateEntry)
		reader.queue = append(reader.queue, updateEntry)
		// End of frame
		if len(payload) == 0 {
			break
		}
	}
	return nil
}
// updateStreamWindow handles a WINDOW_UPDATE frame. It replenishes the send
// window of the stream the frame names and reports the new window size to
// r.updateSendWindowChan. A non-nil return is a connection error.
//
// The following updates are ignored:
//   - updates for closed or unknown streams;
//   - connection-level updates (stream 0). The reader tracks only per-stream
//     send windows, and a connection-level WINDOW_UPDATE is legal HTTP/2
//     (RFC 7540 §6.9), so it must not tear down the connection. Previously it
//     reached getStreamForFrame and became ErrUnexpectedFrameType.
func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
	if frame.Header().StreamID == 0 {
		return nil
	}
	stream, err := r.getStreamForFrame(frame)
	if err != nil && err != ErrUnknownStream && err != ErrClosedStream {
		return err
	}
	if stream == nil {
		// ignore window updates on closed streams
		return nil
	}
	stream.replenishSendWindow(frame.Increment)
	r.updateSendWindowChan <- stream.getSendWindow()
	return nil
}
// streamError records error code e for stream streamID in r.streamErrors,
// through which stream errors are reported to the MuxWriter. That is where
// the stream is actually reset.
//
// A stream error is not a connection error, so this always returns nil. That
// lets frame handlers write `return r.streamError(...)`.
func (r *MuxReader) streamError(streamID uint32, e http2.ErrCode) error {
	pending := r.streamErrors
	pending.RaiseError(streamID, e)
	return nil
}
// connectionError handles an error that is fatal to the whole connection.
//
// It requests a GOAWAY from the writer (best effort; see sendGoAway). The
// GOAWAY code is taken from err when err is an http2.ConnectionError or a
// MuxerProtocolError, and is INTERNAL_ERROR otherwise. err is returned
// unchanged, so callers can write `return r.connectionError(err)`.
func (r *MuxReader) connectionError(err error) error {
	code := http2.ErrCodeInternal
	if connErr, ok := err.(http2.ConnectionError); ok {
		code = http2.ErrCode(connErr)
	} else if protoErr, ok := err.(MuxerProtocolError); ok {
		code = protoErr.h2code
	}
	r.sendGoAway(code)
	return err
}
// sendGoAway asks the writer to send a GOAWAY carrying errCode, without ever
// blocking. If r.goAwayChan cannot accept the request immediately, the
// request is dropped. That happens when a GOAWAY is already in flight or the
// writer's event loop has ended.
func (r *MuxReader) sendGoAway(errCode http2.ErrCode) {
	goAway := r.goAwayChan
	select {
	case goAway <- errCode:
		// handed to the writer
	default:
		// writer busy or gone: dropping the request is intentional
	}
}