package h2mux

import (
	"bytes"
	"encoding/binary"
	"io"
	"time"

	"github.com/rs/zerolog"
	"golang.org/x/net/http2"
	"golang.org/x/net/http2/hpack"
)

// MuxWriter holds the state used by run to serialise outgoing frames (stream
// headers and data, WINDOW_UPDATE, PING, GOAWAY, RST_STREAM and dictionary
// frames) onto the connection's http2.Framer.
type MuxWriter struct {
	// f is used to write HTTP2 frames.
	f *http2.Framer
	// streams tracks currently-open streams.
	streams *activeStreamMap
	// streamErrors receives stream errors raised by the MuxReader.
	streamErrors *StreamErrorMap
	// readyStreamChan is used to multiplex writable streams onto the single connection.
	// When a stream becomes writable its ID is sent on this channel.
	readyStreamChan <-chan uint32
	// newStreamChan is used to create new streams with a given set of headers.
	newStreamChan <-chan MuxedStreamRequest
	// goAwayChan is used to send a single GOAWAY message to the peer. The element received
	// is the HTTP/2 error code to send.
	goAwayChan <-chan http2.ErrCode
	// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
	abortChan <-chan struct{}
	// pingTimestamp is an atomic value containing the latest received ping timestamp.
	pingTimestamp *PingTimestamp
	// A timer used to measure idle connection time. Reset after sending data.
	idleTimer *IdleTimer
	// connActiveChan receives a signal that the connection received some (read) activity.
	connActiveChan <-chan struct{}
	// Maximum size of all frames that can be sent on this connection.
	maxFrameSize uint32
	// headerEncoder is the stateful header encoder for this connection.
	headerEncoder *hpack.Encoder
	// headerBuffer is the temporary buffer used by headerEncoder.
	headerBuffer bytes.Buffer
	// metricsUpdater is used to report metrics.
	metricsUpdater muxMetricsUpdater
	// bytesWrote is the amount of bytes written to data frames since the last time we called metricsUpdater.updateOutBoundBytes()
	bytesWrote *AtomicCounter
	// useDictChan delivers requests that run hands to writeUseDictionary, which
	// writes a USE_DICTIONARY frame followed by a SET_DICTIONARY frame.
	useDictChan <-chan useDictRequest
}

// MuxedStreamRequest asks the writer to open stream. If body is non-nil, run
// copies it into the stream on a separate goroutine once the stream has been
// registered.
type MuxedStreamRequest struct {
	stream *MuxedStream
	body   io.Reader
}

// NewMuxedStreamRequest builds a MuxedStreamRequest for stream; body may be nil.
func NewMuxedStreamRequest(stream *MuxedStream, body io.Reader) MuxedStreamRequest {
	return MuxedStreamRequest{
		stream: stream,
		body:   body,
	}
}

// flushBody copies the request body into the stream and then closes the
// stream's write side.
func (r *MuxedStreamRequest) flushBody() {
	// The copy result is ignored; CloseWrite is called regardless of whether the
	// copy completed or failed.
	_, _ = io.Copy(r.stream, r.body)
	r.stream.CloseWrite()
}

// tsToPingData encodes ts (run passes time.Now().UnixNano() or the last
// received ping timestamp) as the 8-byte PING payload, little-endian.
func tsToPingData(ts int64) [8]byte {
	pingData := [8]byte{}
	binary.LittleEndian.PutUint64(pingData[:], uint64(ts))
	return pingData
}

// run is the writer's event loop. It returns nil when abortChan becomes
// readable, ErrConnectionDropped when the idle timer runs out of retries, or the
// first error returned while writing a frame.
func (w *MuxWriter) run(log *zerolog.Logger) error {
	defer log.Debug().Msg("mux - write: event loop finished")

	// done stops the metrics goroutine when run returns for any reason; without
	// it the goroutine (and its ticker) would outlive an error return until
	// abortChan eventually fires.
	done := make(chan struct{})
	defer close(done)

	// routine to periodically communicate bytesWrote
	go func() {
		ticker := time.NewTicker(updateFreq)
		defer ticker.Stop()
		for {
			select {
			case <-w.abortChan:
				return
			case <-done:
				return
			case <-ticker.C:
				w.metricsUpdater.updateOutBoundBytes(w.bytesWrote.Count())
			}
		}
	}()

	for {
		select {
		case <-w.abortChan:
			log.Debug().Msg("mux - write: aborting writer thread")
			return nil
		case errCode := <-w.goAwayChan:
			log.Debug().Msgf("mux - write: sending GOAWAY code %v", errCode)
			err := w.f.WriteGoAway(w.streams.LastPeerStreamID(), errCode, []byte{})
			if err != nil {
				return err
			}
			w.idleTimer.MarkActive()
		case <-w.pingTimestamp.GetUpdateChan():
			log.Debug().Msg("mux - write: sending PING ACK")
			err := w.f.WritePing(true, tsToPingData(w.pingTimestamp.Get()))
			if err != nil {
				return err
			}
			w.idleTimer.MarkActive()
		case <-w.idleTimer.C:
			if !w.idleTimer.Retry() {
				return ErrConnectionDropped
			}
			log.Debug().Msg("mux - write: sending PING")
			err := w.f.WritePing(false, tsToPingData(time.Now().UnixNano()))
			if err != nil {
				return err
			}
			w.idleTimer.ResetTimer()
		case <-w.connActiveChan:
			w.idleTimer.MarkActive()
		case <-w.streamErrors.GetSignalChan():
			for streamID, errCode := range w.streamErrors.GetErrors() {
				log.Debug().Msgf("mux - write: resetting stream with code: %v streamID: %d", errCode, streamID)
				err := w.f.WriteRSTStream(streamID, errCode)
				if err != nil {
					return err
				}
			}
			w.idleTimer.MarkActive()
		case streamRequest := <-w.newStreamChan:
			streamID := w.streams.AcquireLocalID()
			streamRequest.stream.streamID = streamID
			if !w.streams.Set(streamRequest.stream) {
				// Race between OpenStream and Shutdown, and Shutdown won. Let Shutdown (and the eventual abort) take
				// care of this stream. Ideally we'd pass the error directly to the stream object somehow so the
				// caller can be unblocked sooner, but the value of that optimisation is minimal for most of the
				// reasons why you'd call Shutdown anyway.
				continue
			}
			if streamRequest.body != nil {
				go streamRequest.flushBody()
			}
			err := w.writeStreamData(streamRequest.stream, log)
			if err != nil {
				return err
			}
			w.idleTimer.MarkActive()
		case streamID := <-w.readyStreamChan:
			stream, ok := w.streams.Get(streamID)
			if !ok {
				continue
			}
			err := w.writeStreamData(stream, log)
			if err != nil {
				return err
			}
			w.idleTimer.MarkActive()
		case useDict := <-w.useDictChan:
			err := w.writeUseDictionary(useDict)
			if err != nil {
				log.Error().Msgf("mux - write: error writing use dictionary: %s", err)
				return err
			}
			w.idleTimer.MarkActive()
		}
	}
}

// writeStreamData takes the next chunk from stream and writes whatever it
// calls for: a header block, a WINDOW_UPDATE, and/or DATA frames split at
// maxFrameSize. After DATA with EOF has been sent and the stream's read buffer
// is closed, the stream is removed from w.streams (RST_STREAM is sent first if
// the peer has not yet sent its own EOF).
func (w *MuxWriter) writeStreamData(stream *MuxedStream, log *zerolog.Logger) error {
	log.Debug().Msgf("mux - write: writable: streamID: %d", stream.streamID)
	chunk := stream.getChunk()
	w.metricsUpdater.updateReceiveWindow(stream.getReceiveWindow())
	w.metricsUpdater.updateSendWindow(stream.getSendWindow())
	if chunk.sendHeadersFrame() {
		err := w.writeHeaders(chunk.streamID, chunk.headers)
		if err != nil {
			log.Error().Msgf("mux - write: error writing headers: %s: streamID: %d", err, stream.streamID)
			return err
		}
		log.Debug().Msgf("mux - write: output headers: streamID: %d", stream.streamID)
	}

	if chunk.sendWindowUpdateFrame() {
		// Send a WINDOW_UPDATE frame to update our receive window.
		// If the Stream ID is zero, the window update applies to the connection as a whole
		// RFC7540 section-6.9.1 "A receiver that receives a flow-controlled frame MUST
		// always account for its contribution against the connection flow-control
		// window, unless the receiver treats this as a connection error"
		err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
		if err != nil {
			log.Error().Msgf("mux - write: error writing window update: %s: streamID: %d", err, stream.streamID)
			return err
		}
		log.Debug().Msgf("mux - write: increment receive window by %d streamID: %d", chunk.windowUpdate, stream.streamID)
	}

	for chunk.sendDataFrame() {
		payload, sentEOF := chunk.nextDataFrame(int(w.maxFrameSize))
		err := w.f.WriteData(chunk.streamID, sentEOF, payload)
		if err != nil {
			log.Error().Msgf("mux - write: error writing data: %s: streamID: %d", err, stream.streamID)
			return err
		}
		// update the amount of data wrote
		w.bytesWrote.IncrementBy(uint64(len(payload)))
		log.Debug().Msgf("mux - write: output data: %d: streamID: %d", len(payload), stream.streamID)

		if !sentEOF {
			continue
		}
		if !stream.readBuffer.Closed() {
			log.Debug().Msgf("mux - write: closing stream write side: streamID: %d", stream.streamID)
			continue
		}
		// transition into closed state
		if !stream.gotReceiveEOF() {
			// the peer may send data that we no longer want to receive. Force them into the
			// closed state.
			log.Debug().Msgf("mux - write: resetting stream: streamID: %d", stream.streamID)
			// Best effort: a failed reset is logged, but the stream is still
			// removed locally below.
			if err := w.f.WriteRSTStream(chunk.streamID, http2.ErrCodeNo); err != nil {
				log.Error().Msgf("mux - write: error resetting stream: %s: streamID: %d", err, stream.streamID)
			}
		} else {
			// Half-open stream transitioned into closed
			log.Debug().Msgf("mux - write: closing stream: streamID: %d", stream.streamID)
		}
		w.streams.Delete(chunk.streamID)
	}

	return nil
}

// encodeHeaders HPACK-encodes headers with the connection's stateful encoder.
// The returned slice aliases w.headerBuffer and is only valid until the next
// call to encodeHeaders.
func (w *MuxWriter) encodeHeaders(headers []Header) ([]byte, error) {
	w.headerBuffer.Reset()
	for _, header := range headers {
		err := w.headerEncoder.WriteField(hpack.HeaderField{
			Name:  header.Name,
			Value: header.Value,
		})
		if err != nil {
			return nil, err
		}
	}

	return w.headerBuffer.Bytes(), nil
}

// writeHeaders writes a block of encoded headers, splitting it into multiple frames if necessary.
// The first fragment goes in a HEADERS frame and the rest in CONTINUATION
// frames; END_HEADERS is set only on the final frame. An empty encoded block
// writes nothing.
func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
	encodedHeaders, err := w.encodeHeaders(headers)
	if err != nil || len(encodedHeaders) == 0 {
		return err
	}

	blockSize := int(w.maxFrameSize)
	// CONTINUATION is unnecessary when the whole block fits in one frame. A frame
	// payload may be exactly maxFrameSize bytes, so the comparison is inclusive;
	// using < here previously sent an exactly-full block twice.
	if len(encodedHeaders) <= blockSize {
		return w.f.WriteHeaders(http2.HeadersFrameParam{
			StreamID:      streamID,
			EndHeaders:    true,
			BlockFragment: encodedHeaders,
		})
	}

	fragments := chopEncodedHeaders(encodedHeaders, blockSize)
	// Derive END_HEADERS from the fragment count rather than assuming there are
	// at least two, so a degenerate split can never duplicate a fragment or leave
	// the header block unterminated.
	last := len(fragments) - 1
	if err := w.f.WriteHeaders(http2.HeadersFrameParam{
		StreamID:      streamID,
		EndHeaders:    last == 0,
		BlockFragment: fragments[0],
	}); err != nil {
		return err
	}
	for i := 1; i <= last; i++ {
		if err := w.f.WriteContinuation(streamID, i == last, fragments[i]); err != nil {
			return err
		}
	}
	return nil
}

// chopEncodedHeaders splits headers into consecutive chunks of chunkSize bytes,
// with the final chunk holding any remainder, i.e. ceil(len(headers)/chunkSize)
// chunks. The chunks alias headers rather than copying it. Empty input yields
// nil; a non-positive chunkSize yields headers as a single chunk instead of
// looping forever.
func chopEncodedHeaders(headers []byte, chunkSize int) [][]byte {
	if len(headers) == 0 {
		return nil
	}
	if chunkSize <= 0 {
		return [][]byte{headers}
	}

	divided := make([][]byte, 0, (len(headers)+chunkSize-1)/chunkSize)
	for i := 0; i < len(headers); i += chunkSize {
		end := i + chunkSize
		if end > len(headers) {
			end = len(headers)
		}
		divided = append(divided, headers[i:end])
	}
	return divided
}

// writeUseDictionary writes a USE_DICTIONARY frame carrying dictRequest.dictID
// on dictRequest.streamID, then a SET_DICTIONARY frame (FlagSetDictionaryAppend)
// with one entry per element of dictRequest.setDict: the dictionary ID byte, its
// size via appendVarInt with a 7-bit prefix, and a flags byte.
func (w *MuxWriter) writeUseDictionary(dictRequest useDictRequest) error {
	err := w.f.WriteRawFrame(FrameUseDictionary, 0, dictRequest.streamID, []byte{byte(dictRequest.dictID)})
	if err != nil {
		return err
	}
	payload := make([]byte, 0, 64)
	for _, set := range dictRequest.setDict {
		payload = append(payload, byte(set.dictID))
		payload = appendVarInt(payload, 7, uint64(set.dictSZ))
		payload = append(payload, 0x80) // E = 1, D = 0, Truncate = 0
	}

	err = w.f.WriteRawFrame(FrameSetDictionary, FlagSetDictionaryAppend, dictRequest.streamID, payload)
	return err
}