cloudflared-mirror/h2mux/muxwriter.go

285 lines
8.7 KiB
Go

package h2mux
import (
"bytes"
"encoding/binary"
"io"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
type MuxWriter struct {
// f is used to write HTTP2 frames.
f *http2.Framer
// streams tracks currently-open streams.
streams *activeStreamMap
// streamErrors receives stream errors raised by the MuxReader.
streamErrors *StreamErrorMap
// readyStreamChan is used to multiplex writable streams onto the single connection.
// When a stream becomes writable its ID is sent on this channel.
readyStreamChan <-chan uint32
// newStreamChan is used to create new streams with a given set of headers.
newStreamChan <-chan MuxedStreamRequest
// goAwayChan is used to send a single GOAWAY message to the peer. The element received
// is the HTTP/2 error code to send.
goAwayChan <-chan http2.ErrCode
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
abortChan <-chan struct{}
// pingTimestamp is an atomic value containing the latest received ping timestamp.
pingTimestamp *PingTimestamp
// A timer used to measure idle connection time. Reset after sending data.
idleTimer *IdleTimer
// connActiveChan receives a signal that the connection received some (read) activity.
connActiveChan <-chan struct{}
// Maximum size of all frames that can be sent on this connection.
maxFrameSize uint32
// headerEncoder is the stateful header encoder for this connection
headerEncoder *hpack.Encoder
// headerBuffer is the temporary buffer used by headerEncoder.
headerBuffer bytes.Buffer
// metricsUpdater is used to report metrics
metricsUpdater muxMetricsUpdater
// bytesWrote is the amount of bytes written to data frames since the last time we called metricsUpdater.updateOutBoundBytes()
bytesWrote *AtomicCounter
useDictChan <-chan useDictRequest
}
type MuxedStreamRequest struct {
stream *MuxedStream
body io.Reader
}
func (r *MuxedStreamRequest) flushBody() {
io.Copy(r.stream, r.body)
r.stream.CloseWrite()
}
func tsToPingData(ts int64) [8]byte {
pingData := [8]byte{}
binary.LittleEndian.PutUint64(pingData[:], uint64(ts))
return pingData
}
func (w *MuxWriter) run(parentLogger *log.Entry) error {
logger := parentLogger.WithFields(log.Fields{
"subsystem": "mux",
"dir": "write",
})
defer logger.Debug("event loop finished")
// routine to periodically communicate bytesWrote
go func() {
tickC := time.Tick(updateFreq)
for {
select {
case <-w.abortChan:
return
case <-tickC:
w.metricsUpdater.updateOutBoundBytes(w.bytesWrote.Count())
}
}
}()
for {
select {
case <-w.abortChan:
logger.Debug("aborting writer thread")
return nil
case errCode := <-w.goAwayChan:
logger.Debug("sending GOAWAY code ", errCode)
err := w.f.WriteGoAway(w.streams.LastPeerStreamID(), errCode, []byte{})
if err != nil {
return err
}
w.idleTimer.MarkActive()
case <-w.pingTimestamp.GetUpdateChan():
logger.Debug("sending PING ACK")
err := w.f.WritePing(true, tsToPingData(w.pingTimestamp.Get()))
if err != nil {
return err
}
w.idleTimer.MarkActive()
case <-w.idleTimer.C:
if !w.idleTimer.Retry() {
return ErrConnectionDropped
}
logger.Debug("sending PING")
err := w.f.WritePing(false, tsToPingData(time.Now().UnixNano()))
if err != nil {
return err
}
w.idleTimer.ResetTimer()
case <-w.connActiveChan:
w.idleTimer.MarkActive()
case <-w.streamErrors.GetSignalChan():
for streamID, errCode := range w.streamErrors.GetErrors() {
logger.WithField("stream", streamID).WithField("code", errCode).Debug("resetting stream")
err := w.f.WriteRSTStream(streamID, errCode)
if err != nil {
return err
}
}
w.idleTimer.MarkActive()
case streamRequest := <-w.newStreamChan:
streamID := w.streams.AcquireLocalID()
streamRequest.stream.streamID = streamID
if !w.streams.Set(streamRequest.stream) {
// Race between OpenStream and Shutdown, and Shutdown won. Let Shutdown (and the eventual abort) take
// care of this stream. Ideally we'd pass the error directly to the stream object somehow so the
// caller can be unblocked sooner, but the value of that optimisation is minimal for most of the
// reasons why you'd call Shutdown anyway.
continue
}
if streamRequest.body != nil {
go streamRequest.flushBody()
}
streamLogger := logger.WithField("stream", streamID)
err := w.writeStreamData(streamRequest.stream, streamLogger)
if err != nil {
return err
}
w.idleTimer.MarkActive()
case streamID := <-w.readyStreamChan:
streamLogger := logger.WithField("stream", streamID)
stream, ok := w.streams.Get(streamID)
if !ok {
continue
}
err := w.writeStreamData(stream, streamLogger)
if err != nil {
return err
}
w.idleTimer.MarkActive()
case useDict := <-w.useDictChan:
err := w.writeUseDictionary(useDict)
if err != nil {
logger.WithError(err).Warn("error writing use dictionary")
return err
}
w.idleTimer.MarkActive()
}
}
}
func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
logger.Debug("writable")
chunk := stream.getChunk()
w.metricsUpdater.updateReceiveWindow(stream.getReceiveWindow())
w.metricsUpdater.updateSendWindow(stream.getSendWindow())
if chunk.sendHeadersFrame() {
err := w.writeHeaders(chunk.streamID, chunk.headers)
if err != nil {
logger.WithError(err).Warn("error writing headers")
return err
}
logger.Debug("output headers")
}
if chunk.sendWindowUpdateFrame() {
// Send a WINDOW_UPDATE frame to update our receive window.
// If the Stream ID is zero, the window update applies to the connection as a whole
// RFC7540 section-6.9.1 "A receiver that receives a flow-controlled frame MUST
// always account for its contribution against the connection flow-control
// window, unless the receiver treats this as a connection error"
err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
if err != nil {
logger.WithError(err).Warn("error writing window update")
return err
}
logger.Debugf("increment receive window by %d", chunk.windowUpdate)
}
for chunk.sendDataFrame() {
payload, sentEOF := chunk.nextDataFrame(int(w.maxFrameSize))
err := w.f.WriteData(chunk.streamID, sentEOF, payload)
if err != nil {
logger.WithError(err).Warn("error writing data")
return err
}
// update the amount of data wrote
w.bytesWrote.IncrementBy(uint64(len(payload)))
logger.WithField("len", len(payload)).Debug("output data")
if sentEOF {
if stream.readBuffer.Closed() {
// transition into closed state
if !stream.gotReceiveEOF() {
// the peer may send data that we no longer want to receive. Force them into the
// closed state.
logger.Debug("resetting stream")
w.f.WriteRSTStream(chunk.streamID, http2.ErrCodeNo)
} else {
// Half-open stream transitioned into closed
logger.Debug("closing stream")
}
w.streams.Delete(chunk.streamID)
} else {
logger.Debug("closing stream write side")
}
}
}
return nil
}
func (w *MuxWriter) encodeHeaders(headers []Header) ([]byte, error) {
w.headerBuffer.Reset()
for _, header := range headers {
err := w.headerEncoder.WriteField(hpack.HeaderField{
Name: header.Name,
Value: header.Value,
})
if err != nil {
return nil, err
}
}
return w.headerBuffer.Bytes(), nil
}
// writeHeaders writes a block of encoded headers, splitting it into multiple frames if necessary.
func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
encodedHeaders, err := w.encodeHeaders(headers)
if err != nil {
return err
}
blockSize := int(w.maxFrameSize)
endHeaders := len(encodedHeaders) == 0
for !endHeaders && err == nil {
blockFragment := encodedHeaders
if len(encodedHeaders) > blockSize {
blockFragment = blockFragment[:blockSize]
encodedHeaders = encodedHeaders[blockSize:]
// Send CONTINUATION frame if the headers can't be fit into 1 frame
err = w.f.WriteContinuation(streamID, endHeaders, blockFragment)
} else {
endHeaders = true
err = w.f.WriteHeaders(http2.HeadersFrameParam{
StreamID: streamID,
EndHeaders: endHeaders,
BlockFragment: blockFragment,
})
}
}
return err
}
func (w *MuxWriter) writeUseDictionary(dictRequest useDictRequest) error {
err := w.f.WriteRawFrame(FrameUseDictionary, 0, dictRequest.streamID, []byte{byte(dictRequest.dictID)})
if err != nil {
return err
}
payload := make([]byte, 0, 64)
for _, set := range dictRequest.setDict {
payload = append(payload, byte(set.dictID))
payload = appendVarInt(payload, 7, uint64(set.dictSZ))
payload = append(payload, 0x80) // E = 1, D = 0, Truncate = 0
}
err = w.f.WriteRawFrame(FrameSetDictionary, FlagSetDictionaryAppend, dictRequest.streamID, payload)
return err
}