package h2mux import ( "bytes" "encoding/binary" "io" "time" log "github.com/sirupsen/logrus" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" ) type MuxWriter struct { // f is used to write HTTP2 frames. f *http2.Framer // streams tracks currently-open streams. streams *activeStreamMap // streamErrors receives stream errors raised by the MuxReader. streamErrors *StreamErrorMap // readyStreamChan is used to multiplex writable streams onto the single connection. // When a stream becomes writable its ID is sent on this channel. readyStreamChan <-chan uint32 // newStreamChan is used to create new streams with a given set of headers. newStreamChan <-chan MuxedStreamRequest // goAwayChan is used to send a single GOAWAY message to the peer. The element received // is the HTTP/2 error code to send. goAwayChan <-chan http2.ErrCode // abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop. abortChan <-chan struct{} // pingTimestamp is an atomic value containing the latest received ping timestamp. pingTimestamp *PingTimestamp // A timer used to measure idle connection time. Reset after sending data. idleTimer *IdleTimer // connActiveChan receives a signal that the connection received some (read) activity. connActiveChan <-chan struct{} // Maximum size of all frames that can be sent on this connection. maxFrameSize uint32 // headerEncoder is the stateful header encoder for this connection headerEncoder *hpack.Encoder // headerBuffer is the temporary buffer used by headerEncoder. headerBuffer bytes.Buffer // metricsUpdater is used to report metrics metricsUpdater muxMetricsUpdater // bytesWrote is the amount of bytes written to data frames since the last time we called metricsUpdater.updateOutBoundBytes() bytesWrote *AtomicCounter useDictChan <-chan useDictRequest } type MuxedStreamRequest struct { stream *MuxedStream body io.Reader } func (r *MuxedStreamRequest) flushBody() { io.Copy(r.stream, r.body) r.stream.CloseWrite() } func tsToPingData(ts int64) [8]byte { pingData := [8]byte{} binary.LittleEndian.PutUint64(pingData[:], uint64(ts)) return pingData } func (w *MuxWriter) run(parentLogger *log.Entry) error { logger := parentLogger.WithFields(log.Fields{ "subsystem": "mux", "dir": "write", }) defer logger.Debug("event loop finished") // routine to periodically communicate bytesWrote go func() { tickC := time.Tick(updateFreq) for { select { case <-w.abortChan: return case <-tickC: w.metricsUpdater.updateOutBoundBytes(w.bytesWrote.Count()) } } }() for { select { case <-w.abortChan: logger.Debug("aborting writer thread") return nil case errCode := <-w.goAwayChan: logger.Debug("sending GOAWAY code ", errCode) err := w.f.WriteGoAway(w.streams.LastPeerStreamID(), errCode, []byte{}) if err != nil { return err } w.idleTimer.MarkActive() case <-w.pingTimestamp.GetUpdateChan(): logger.Debug("sending PING ACK") err := w.f.WritePing(true, tsToPingData(w.pingTimestamp.Get())) if err != nil { return err } w.idleTimer.MarkActive() case <-w.idleTimer.C: if !w.idleTimer.Retry() { return ErrConnectionDropped } logger.Debug("sending PING") err := w.f.WritePing(false, tsToPingData(time.Now().UnixNano())) if err != nil { return err } w.idleTimer.ResetTimer() case <-w.connActiveChan: w.idleTimer.MarkActive() case <-w.streamErrors.GetSignalChan(): for streamID, errCode := range w.streamErrors.GetErrors() { logger.WithField("stream", streamID).WithField("code", errCode).Debug("resetting stream") err := w.f.WriteRSTStream(streamID, errCode) if err != nil { return err } } w.idleTimer.MarkActive() case streamRequest := <-w.newStreamChan: streamID := w.streams.AcquireLocalID() streamRequest.stream.streamID = streamID if !w.streams.Set(streamRequest.stream) { // Race between OpenStream and Shutdown, and Shutdown won. Let Shutdown (and the eventual abort) take // care of this stream. Ideally we'd pass the error directly to the stream object somehow so the // caller can be unblocked sooner, but the value of that optimisation is minimal for most of the // reasons why you'd call Shutdown anyway. continue } if streamRequest.body != nil { go streamRequest.flushBody() } streamLogger := logger.WithField("stream", streamID) err := w.writeStreamData(streamRequest.stream, streamLogger) if err != nil { return err } w.idleTimer.MarkActive() case streamID := <-w.readyStreamChan: streamLogger := logger.WithField("stream", streamID) stream, ok := w.streams.Get(streamID) if !ok { continue } err := w.writeStreamData(stream, streamLogger) if err != nil { return err } w.idleTimer.MarkActive() case useDict := <-w.useDictChan: err := w.writeUseDictionary(useDict) if err != nil { logger.WithError(err).Warn("error writing use dictionary") return err } w.idleTimer.MarkActive() } } } func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error { logger.Debug("writable") chunk := stream.getChunk() w.metricsUpdater.updateReceiveWindow(stream.getReceiveWindow()) w.metricsUpdater.updateSendWindow(stream.getSendWindow()) if chunk.sendHeadersFrame() { err := w.writeHeaders(chunk.streamID, chunk.headers) if err != nil { logger.WithError(err).Warn("error writing headers") return err } logger.Debug("output headers") } if chunk.sendWindowUpdateFrame() { // Send a WINDOW_UPDATE frame to update our receive window. // If the Stream ID is zero, the window update applies to the connection as a whole // RFC7540 section-6.9.1 "A receiver that receives a flow-controlled frame MUST // always account for its contribution against the connection flow-control // window, unless the receiver treats this as a connection error" err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate) if err != nil { logger.WithError(err).Warn("error writing window update") return err } logger.Debugf("increment receive window by %d", chunk.windowUpdate) } for chunk.sendDataFrame() { payload, sentEOF := chunk.nextDataFrame(int(w.maxFrameSize)) err := w.f.WriteData(chunk.streamID, sentEOF, payload) if err != nil { logger.WithError(err).Warn("error writing data") return err } // update the amount of data wrote w.bytesWrote.IncrementBy(uint64(len(payload))) logger.WithField("len", len(payload)).Debug("output data") if sentEOF { if stream.readBuffer.Closed() { // transition into closed state if !stream.gotReceiveEOF() { // the peer may send data that we no longer want to receive. Force them into the // closed state. logger.Debug("resetting stream") w.f.WriteRSTStream(chunk.streamID, http2.ErrCodeNo) } else { // Half-open stream transitioned into closed logger.Debug("closing stream") } w.streams.Delete(chunk.streamID) } else { logger.Debug("closing stream write side") } } } return nil } func (w *MuxWriter) encodeHeaders(headers []Header) ([]byte, error) { w.headerBuffer.Reset() for _, header := range headers { err := w.headerEncoder.WriteField(hpack.HeaderField{ Name: header.Name, Value: header.Value, }) if err != nil { return nil, err } } return w.headerBuffer.Bytes(), nil } // writeHeaders writes a block of encoded headers, splitting it into multiple frames if necessary. func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error { encodedHeaders, err := w.encodeHeaders(headers) if err != nil { return err } blockSize := int(w.maxFrameSize) endHeaders := len(encodedHeaders) == 0 for !endHeaders && err == nil { blockFragment := encodedHeaders if len(encodedHeaders) > blockSize { blockFragment = blockFragment[:blockSize] encodedHeaders = encodedHeaders[blockSize:] // Send CONTINUATION frame if the headers can't be fit into 1 frame err = w.f.WriteContinuation(streamID, endHeaders, blockFragment) } else { endHeaders = true err = w.f.WriteHeaders(http2.HeadersFrameParam{ StreamID: streamID, EndHeaders: endHeaders, BlockFragment: blockFragment, }) } } return err } func (w *MuxWriter) writeUseDictionary(dictRequest useDictRequest) error { err := w.f.WriteRawFrame(FrameUseDictionary, 0, dictRequest.streamID, []byte{byte(dictRequest.dictID)}) if err != nil { return err } payload := make([]byte, 0, 64) for _, set := range dictRequest.setDict { payload = append(payload, byte(set.dictID)) payload = appendVarInt(payload, 7, uint64(set.dictSZ)) payload = append(payload, 0x80) // E = 1, D = 0, Truncate = 0 } err = w.f.WriteRawFrame(FrameSetDictionary, FlagSetDictionaryAppend, dictRequest.streamID, payload) return err }