TUN-1358: Close readyList after Muxer.Serve() has stopped running

This commit is contained in:
Nick Vollmar 2019-01-16 07:57:30 -06:00
parent 62b1ab8c98
commit 5bf6dd8f85
8 changed files with 255 additions and 157 deletions

View File

@ -71,7 +71,7 @@ type Muxer struct {
// muxWriter is the write process. // muxWriter is the write process.
muxWriter *MuxWriter muxWriter *MuxWriter
// muxMetricsUpdater is the process to update metrics // muxMetricsUpdater is the process to update metrics
muxMetricsUpdater *muxMetricsUpdater muxMetricsUpdater muxMetricsUpdater
// newStreamChan is used to create new streams on the writer thread. // newStreamChan is used to create new streams on the writer thread.
// The writer will assign the next available stream ID. // The writer will assign the next available stream ID.
newStreamChan chan MuxedStreamRequest newStreamChan chan MuxedStreamRequest
@ -163,11 +163,6 @@ func Handshake(
// set up reader/writer pair ready for serve // set up reader/writer pair ready for serve
streamErrors := NewStreamErrorMap() streamErrors := NewStreamErrorMap()
goAwayChan := make(chan http2.ErrCode, 1) goAwayChan := make(chan http2.ErrCode, 1)
updateRTTChan := make(chan *roundTripMeasurement, 1)
updateReceiveWindowChan := make(chan uint32, 1)
updateSendWindowChan := make(chan uint32, 1)
updateInBoundBytesChan := make(chan uint64)
updateOutBoundBytesChan := make(chan uint64)
inBoundCounter := NewAtomicCounter(0) inBoundCounter := NewAtomicCounter(0)
outBoundCounter := NewAtomicCounter(0) outBoundCounter := NewAtomicCounter(0)
pingTimestamp := NewPingTimestamp() pingTimestamp := NewPingTimestamp()
@ -184,6 +179,14 @@ func Handshake(
config.Logger.Warn("Minimum number of unacked heartbeats to send before closing the connection has been adjusted to ", maxRetries) config.Logger.Warn("Minimum number of unacked heartbeats to send before closing the connection has been adjusted to ", maxRetries)
} }
compBytesBefore, compBytesAfter := NewAtomicCounter(0), NewAtomicCounter(0)
m.muxMetricsUpdater = newMuxMetricsUpdater(
m.abortChan,
compBytesBefore,
compBytesAfter,
)
m.explicitShutdown = NewBooleanFuse() m.explicitShutdown = NewBooleanFuse()
m.muxReader = &MuxReader{ m.muxReader = &MuxReader{
f: m.f, f: m.f,
@ -198,45 +201,27 @@ func Handshake(
initialStreamWindow: m.config.DefaultWindowSize, initialStreamWindow: m.config.DefaultWindowSize,
streamWindowMax: m.config.MaxWindowSize, streamWindowMax: m.config.MaxWindowSize,
streamWriteBufferMaxLen: m.config.StreamWriteBufferMaxLen, streamWriteBufferMaxLen: m.config.StreamWriteBufferMaxLen,
r: m.r, r: m.r,
updateRTTChan: updateRTTChan, metricsUpdater: m.muxMetricsUpdater,
updateReceiveWindowChan: updateReceiveWindowChan, bytesRead: inBoundCounter,
updateSendWindowChan: updateSendWindowChan,
bytesRead: inBoundCounter,
updateInBoundBytesChan: updateInBoundBytesChan,
} }
m.muxWriter = &MuxWriter{ m.muxWriter = &MuxWriter{
f: m.f, f: m.f,
streams: m.streams, streams: m.streams,
streamErrors: streamErrors, streamErrors: streamErrors,
readyStreamChan: m.readyList.ReadyChannel(), readyStreamChan: m.readyList.ReadyChannel(),
newStreamChan: m.newStreamChan, newStreamChan: m.newStreamChan,
goAwayChan: goAwayChan, goAwayChan: goAwayChan,
abortChan: m.abortChan, abortChan: m.abortChan,
pingTimestamp: pingTimestamp, pingTimestamp: pingTimestamp,
idleTimer: NewIdleTimer(idleDuration, maxRetries), idleTimer: NewIdleTimer(idleDuration, maxRetries),
connActiveChan: connActive.WaitChannel(), connActiveChan: connActive.WaitChannel(),
maxFrameSize: defaultFrameSize, maxFrameSize: defaultFrameSize,
updateReceiveWindowChan: updateReceiveWindowChan, metricsUpdater: m.muxMetricsUpdater,
updateSendWindowChan: updateSendWindowChan, bytesWrote: outBoundCounter,
bytesWrote: outBoundCounter,
updateOutBoundBytesChan: updateOutBoundBytesChan,
} }
m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer) m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
compBytesBefore, compBytesAfter := NewAtomicCounter(0), NewAtomicCounter(0)
m.muxMetricsUpdater = newMuxMetricsUpdater(
updateRTTChan,
updateReceiveWindowChan,
updateSendWindowChan,
updateInBoundBytesChan,
updateOutBoundBytesChan,
m.abortChan,
compBytesBefore,
compBytesAfter,
)
if m.compressionQuality.dictSize > 0 && m.compressionQuality.nDicts > 0 { if m.compressionQuality.dictSize > 0 && m.compressionQuality.nDicts > 0 {
nd, sz := m.compressionQuality.nDicts, m.compressionQuality.dictSize nd, sz := m.compressionQuality.nDicts, m.compressionQuality.dictSize
writeDicts, dictChan := newH2WriteDictionaries( writeDicts, dictChan := newH2WriteDictionaries(
@ -322,6 +307,12 @@ func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time.
return nil return nil
} }
// Serve runs the event loops that comprise h2mux:
// - MuxReader.run()
// - MuxWriter.run()
// - muxMetricsUpdater.run()
// In the normal case, Shutdown() is called concurrently with Serve() to stop
// these loops.
func (m *Muxer) Serve(ctx context.Context) error { func (m *Muxer) Serve(ctx context.Context) error {
errGroup, _ := errgroup.WithContext(ctx) errGroup, _ := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
@ -352,6 +343,7 @@ func (m *Muxer) Serve(ctx context.Context) error {
return nil return nil
} }
// Shutdown is called to initiate the "happy path" of muxer termination.
func (m *Muxer) Shutdown() { func (m *Muxer) Shutdown() {
m.explicitShutdown.Fuse(true) m.explicitShutdown.Fuse(true)
m.muxReader.Shutdown() m.muxReader.Shutdown()
@ -418,12 +410,13 @@ func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, erro
} }
func (m *Muxer) Metrics() *MuxerMetrics { func (m *Muxer) Metrics() *MuxerMetrics {
return m.muxMetricsUpdater.Metrics() return m.muxMetricsUpdater.metrics()
} }
func (m *Muxer) abort() { func (m *Muxer) abort() {
m.abortOnce.Do(func() { m.abortOnce.Do(func() {
close(m.abortChan) close(m.abortChan)
m.readyList.Close()
m.streams.Abort() m.streams.Abort()
}) })
} }

View File

@ -829,7 +829,8 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) {
t.Fatalf("TestMultipleStreams failed") t.Fatalf("TestMultipleStreams failed")
} }
if q > CompressionNone && muxPair.OriginMux.muxMetricsUpdater.compBytesBefore.Value() <= 10*muxPair.OriginMux.muxMetricsUpdater.compBytesAfter.Value() { originMuxMetrics := muxPair.OriginMux.Metrics()
if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 10*originMuxMetrics.CompBytesAfter.Value() {
t.Fatalf("Cross-stream compression is expected to give a better compression ratio") t.Fatalf("Cross-stream compression is expected to give a better compression ratio")
} }
} }
@ -927,7 +928,8 @@ func TestSampleSiteWithDictionaries(t *testing.T) {
} }
wg.Wait() wg.Wait()
if q > CompressionNone && muxPair.OriginMux.muxMetricsUpdater.compBytesBefore.Value() <= 10*muxPair.OriginMux.muxMetricsUpdater.compBytesAfter.Value() { originMuxMetrics := muxPair.OriginMux.Metrics()
if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 10*originMuxMetrics.CompBytesAfter.Value() {
t.Fatalf("Cross-stream compression is expected to give a better compression ratio") t.Fatalf("Cross-stream compression is expected to give a better compression ratio")
} }
} }
@ -960,7 +962,8 @@ func TestLongSiteWithDictionaries(t *testing.T) {
} }
wg.Wait() wg.Wait()
if q > CompressionNone && muxPair.OriginMux.muxMetricsUpdater.compBytesBefore.Value() <= 100*muxPair.OriginMux.muxMetricsUpdater.compBytesAfter.Value() { originMuxMetrics := muxPair.OriginMux.Metrics()
if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 100*originMuxMetrics.CompBytesAfter.Value() {
t.Fatalf("Cross-stream compression is expected to give a better compression ratio") t.Fatalf("Cross-stream compression is expected to give a better compression ratio")
} }
} }

View File

@ -17,7 +17,24 @@ const (
updateFreq = time.Second updateFreq = time.Second
) )
type muxMetricsUpdater struct { type muxMetricsUpdater interface {
// metrics returns the latest metrics
metrics() *MuxerMetrics
// run is a blocking call to start the event loop
run(logger *log.Entry) error
// updateRTTChan is called by muxReader to report new RTT measurements
updateRTT(rtt *roundTripMeasurement)
//updateReceiveWindowChan is called by muxReader and muxWriter when receiveWindow size is updated
updateReceiveWindow(receiveWindow uint32)
//updateSendWindowChan is called by muxReader and muxWriter when sendWindow size is updated
updateSendWindow(sendWindow uint32)
// updateInBoundBytesChan is called periodicallyby muxReader to report bytesRead
updateInBoundBytes(inBoundBytes uint64)
// updateOutBoundBytesChan is called periodically by muxWriter to report bytesWrote
updateOutBoundBytes(outBoundBytes uint64)
}
type muxMetricsUpdaterImpl struct {
// rttData keeps record of rtt, rttMin, rttMax and last measured time // rttData keeps record of rtt, rttMin, rttMax and last measured time
rttData *rttData rttData *rttData
// receiveWindowData keeps record of receive window measurement // receiveWindowData keeps record of receive window measurement
@ -28,16 +45,16 @@ type muxMetricsUpdater struct {
inBoundRate *rate inBoundRate *rate
// outBoundRate is outgoing bytes/sec // outBoundRate is outgoing bytes/sec
outBoundRate *rate outBoundRate *rate
// updateRTTChan is the channel to receive new RTT measurement from muxReader // updateRTTChan is the channel to receive new RTT measurement
updateRTTChan <-chan *roundTripMeasurement updateRTTChan chan *roundTripMeasurement
//updateReceiveWindowChan is the channel to receive updated receiveWindow size from muxReader and muxWriter //updateReceiveWindowChan is the channel to receive updated receiveWindow size
updateReceiveWindowChan <-chan uint32 updateReceiveWindowChan chan uint32
//updateSendWindowChan is the channel to receive updated sendWindow size from muxReader and muxWriter //updateSendWindowChan is the channel to receive updated sendWindow size
updateSendWindowChan <-chan uint32 updateSendWindowChan chan uint32
// updateInBoundBytesChan us the channel to receive bytesRead from muxReader // updateInBoundBytesChan us the channel to receive bytesRead
updateInBoundBytesChan <-chan uint64 updateInBoundBytesChan chan uint64
// updateOutBoundBytesChan us the channel to receive bytesWrote from muxWriter // updateOutBoundBytesChan us the channel to receive bytesWrote
updateOutBoundBytesChan <-chan uint64 updateOutBoundBytesChan chan uint64
// shutdownC is to signal the muxerMetricsUpdater to shutdown // shutdownC is to signal the muxerMetricsUpdater to shutdown
abortChan <-chan struct{} abortChan <-chan struct{}
@ -84,15 +101,16 @@ type rate struct {
} }
func newMuxMetricsUpdater( func newMuxMetricsUpdater(
updateRTTChan <-chan *roundTripMeasurement,
updateReceiveWindowChan <-chan uint32,
updateSendWindowChan <-chan uint32,
updateInBoundBytesChan <-chan uint64,
updateOutBoundBytesChan <-chan uint64,
abortChan <-chan struct{}, abortChan <-chan struct{},
compBytesBefore, compBytesAfter *AtomicCounter, compBytesBefore, compBytesAfter *AtomicCounter,
) *muxMetricsUpdater { ) muxMetricsUpdater {
return &muxMetricsUpdater{ updateRTTChan := make(chan *roundTripMeasurement, 1)
updateReceiveWindowChan := make(chan uint32, 1)
updateSendWindowChan := make(chan uint32, 1)
updateInBoundBytesChan := make(chan uint64)
updateOutBoundBytesChan := make(chan uint64)
return &muxMetricsUpdaterImpl{
rttData: newRTTData(), rttData: newRTTData(),
receiveWindowData: newFlowControlData(), receiveWindowData: newFlowControlData(),
sendWindowData: newFlowControlData(), sendWindowData: newFlowControlData(),
@ -109,7 +127,7 @@ func newMuxMetricsUpdater(
} }
} }
func (updater *muxMetricsUpdater) Metrics() *MuxerMetrics { func (updater *muxMetricsUpdaterImpl) metrics() *MuxerMetrics {
m := &MuxerMetrics{} m := &MuxerMetrics{}
m.RTT, m.RTTMin, m.RTTMax = updater.rttData.metrics() m.RTT, m.RTTMin, m.RTTMax = updater.rttData.metrics()
m.ReceiveWindowAve, m.ReceiveWindowMin, m.ReceiveWindowMax = updater.receiveWindowData.metrics() m.ReceiveWindowAve, m.ReceiveWindowMin, m.ReceiveWindowMax = updater.receiveWindowData.metrics()
@ -120,7 +138,7 @@ func (updater *muxMetricsUpdater) Metrics() *MuxerMetrics {
return m return m
} }
func (updater *muxMetricsUpdater) run(parentLogger *log.Entry) error { func (updater *muxMetricsUpdaterImpl) run(parentLogger *log.Entry) error {
logger := parentLogger.WithFields(log.Fields{ logger := parentLogger.WithFields(log.Fields{
"subsystem": "mux", "subsystem": "mux",
"dir": "metrics", "dir": "metrics",
@ -152,6 +170,43 @@ func (updater *muxMetricsUpdater) run(parentLogger *log.Entry) error {
} }
} }
func (updater *muxMetricsUpdaterImpl) updateRTT(rtt *roundTripMeasurement) {
select {
case updater.updateRTTChan <- rtt:
case <-updater.abortChan:
}
}
func (updater *muxMetricsUpdaterImpl) updateReceiveWindow(receiveWindow uint32) {
select {
case updater.updateReceiveWindowChan <- receiveWindow:
case <-updater.abortChan:
}
}
func (updater *muxMetricsUpdaterImpl) updateSendWindow(sendWindow uint32) {
select {
case updater.updateSendWindowChan <- sendWindow:
case <-updater.abortChan:
}
}
func (updater *muxMetricsUpdaterImpl) updateInBoundBytes(inBoundBytes uint64) {
select {
case updater.updateInBoundBytesChan <- inBoundBytes:
case <-updater.abortChan:
}
}
func (updater *muxMetricsUpdaterImpl) updateOutBoundBytes(outBoundBytes uint64) {
select {
case updater.updateOutBoundBytesChan <- outBoundBytes:
case <-updater.abortChan:
}
}
func newRTTData() *rttData { func newRTTData() *rttData {
return &rttData{} return &rttData{}
} }

View File

@ -86,24 +86,11 @@ func TestFlowControlDataUpdate(t *testing.T) {
} }
func TestMuxMetricsUpdater(t *testing.T) { func TestMuxMetricsUpdater(t *testing.T) {
t.Skip("Race condition") t.Skip("Inherently racy test due to muxMetricsUpdaterImpl.run()")
updateRTTChan := make(chan *roundTripMeasurement)
updateReceiveWindowChan := make(chan uint32)
updateSendWindowChan := make(chan uint32)
updateInBoundBytesChan := make(chan uint64)
updateOutBoundBytesChan := make(chan uint64)
abortChan := make(chan struct{})
errChan := make(chan error) errChan := make(chan error)
abortChan := make(chan struct{})
compBefore, compAfter := NewAtomicCounter(0), NewAtomicCounter(0) compBefore, compAfter := NewAtomicCounter(0), NewAtomicCounter(0)
m := newMuxMetricsUpdater(updateRTTChan, m := newMuxMetricsUpdater(abortChan, compBefore, compAfter)
updateReceiveWindowChan,
updateSendWindowChan,
updateInBoundBytesChan,
updateOutBoundBytesChan,
abortChan,
compBefore,
compAfter,
)
logger := log.NewEntry(log.New()) logger := log.NewEntry(log.New())
go func() { go func() {
@ -116,42 +103,44 @@ func TestMuxMetricsUpdater(t *testing.T) {
// mock muxReader // mock muxReader
readerStart := time.Now() readerStart := time.Now()
rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart} rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart}
updateRTTChan <- rm m.updateRTT(rm)
go func() { go func() {
defer wg.Done() defer wg.Done()
// Becareful if dataPoints is not divisibile by 4 assert.Equal(t, 0, dataPoints%4,
"dataPoints is not divisible by 4; this test should be adjusted accordingly")
readerSend := readerStart.Add(time.Millisecond) readerSend := readerStart.Add(time.Millisecond)
for i := 1; i <= dataPoints/4; i++ { for i := 1; i <= dataPoints/4; i++ {
readerReceive := readerSend.Add(time.Duration(i) * time.Millisecond) readerReceive := readerSend.Add(time.Duration(i) * time.Millisecond)
rm := &roundTripMeasurement{receiveTime: readerReceive, sendTime: readerSend} rm := &roundTripMeasurement{receiveTime: readerReceive, sendTime: readerSend}
updateRTTChan <- rm m.updateRTT(rm)
readerSend = readerReceive.Add(time.Millisecond) readerSend = readerReceive.Add(time.Millisecond)
m.updateReceiveWindow(uint32(i))
m.updateSendWindow(uint32(i))
updateReceiveWindowChan <- uint32(i) m.updateInBoundBytes(uint64(i))
updateSendWindowChan <- uint32(i)
updateInBoundBytesChan <- uint64(i)
} }
}() }()
// mock muxWriter // mock muxWriter
go func() { go func() {
defer wg.Done() defer wg.Done()
assert.Equal(t, 0, dataPoints%4,
"dataPoints is not divisible by 4; this test should be adjusted accordingly")
for j := dataPoints/4 + 1; j <= dataPoints/2; j++ { for j := dataPoints/4 + 1; j <= dataPoints/2; j++ {
updateReceiveWindowChan <- uint32(j) m.updateReceiveWindow(uint32(j))
updateSendWindowChan <- uint32(j) m.updateSendWindow(uint32(j))
// should always be disgard since the send time is before readerSend // should always be disgarded since the send time is before readerSend
rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart.Add(-time.Duration(j*dataPoints) * time.Millisecond)} rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart.Add(-time.Duration(j*dataPoints) * time.Millisecond)}
updateRTTChan <- rm m.updateRTT(rm)
updateOutBoundBytesChan <- uint64(j) m.updateOutBoundBytes(uint64(j))
} }
}() }()
wg.Wait() wg.Wait()
metrics := m.Metrics() metrics := m.metrics()
points := dataPoints / 2 points := dataPoints / 2
assert.Equal(t, time.Millisecond, metrics.RTTMin) assert.Equal(t, time.Millisecond, metrics.RTTMin)
assert.Equal(t, time.Duration(dataPoints/4)*time.Millisecond, metrics.RTTMax) assert.Equal(t, time.Duration(dataPoints/4)*time.Millisecond, metrics.RTTMax)

View File

@ -39,16 +39,10 @@ type MuxReader struct {
streamWriteBufferMaxLen int streamWriteBufferMaxLen int
// r is a reference to the underlying connection used when shutting down. // r is a reference to the underlying connection used when shutting down.
r io.Closer r io.Closer
// updateRTTChan is the channel to send new RTT measurement to muxerMetricsUpdater // metricsUpdater is used to report metrics
updateRTTChan chan<- *roundTripMeasurement metricsUpdater muxMetricsUpdater
// updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater // bytesRead is the amount of bytes read from data frames since the last time we called metricsUpdater.updateInBoundBytes()
updateReceiveWindowChan chan<- uint32
// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater
updateSendWindowChan chan<- uint32
// bytesRead is the amount of bytes read from data frame since the last time we send bytes read to metrics
bytesRead *AtomicCounter bytesRead *AtomicCounter
// updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
updateInBoundBytesChan chan<- uint64
// dictionaries holds the h2 cross-stream compression dictionaries // dictionaries holds the h2 cross-stream compression dictionaries
dictionaries h2Dictionaries dictionaries h2Dictionaries
} }
@ -81,7 +75,7 @@ func (r *MuxReader) run(parentLogger *log.Entry) error {
case <-r.abortChan: case <-r.abortChan:
return return
case <-tickC: case <-tickC:
r.updateInBoundBytesChan <- r.bytesRead.Count() r.metricsUpdater.updateInBoundBytes(r.bytesRead.Count())
} }
} }
}() }()
@ -289,7 +283,7 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.E
if !stream.consumeReceiveWindow(uint32(len(data))) { if !stream.consumeReceiveWindow(uint32(len(data))) {
return r.streamError(stream.streamID, http2.ErrCodeFlowControl) return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
} }
r.updateReceiveWindowChan <- stream.getReceiveWindow() r.metricsUpdater.updateReceiveWindow(stream.getReceiveWindow())
return nil return nil
} }
@ -301,13 +295,13 @@ func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
return return
} }
// Update updates the computed values with a new measurement. // Update the computed RTT aggregations with a new measurement.
// outgoingTime is the time that the probe was sent. // `ts` is the time that the probe was sent.
// We assume that time.Now() is the time we received that probe. // We assume that `time.Now()` is the time we received that probe.
r.updateRTTChan <- &roundTripMeasurement{ r.metricsUpdater.updateRTT(&roundTripMeasurement{
receiveTime: time.Now(), receiveTime: time.Now(),
sendTime: time.Unix(0, ts), sendTime: time.Unix(0, ts),
} })
} }
// Receive a GOAWAY from the peer. Gracefully shut down our connection. // Receive a GOAWAY from the peer. Gracefully shut down our connection.
@ -468,7 +462,7 @@ func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
return nil return nil
} }
stream.replenishSendWindow(frame.Increment) stream.replenishSendWindow(frame.Increment)
r.updateSendWindowChan <- stream.getSendWindow() r.metricsUpdater.updateSendWindow(stream.getSendWindow())
return nil return nil
} }

View File

@ -40,14 +40,11 @@ type MuxWriter struct {
headerEncoder *hpack.Encoder headerEncoder *hpack.Encoder
// headerBuffer is the temporary buffer used by headerEncoder. // headerBuffer is the temporary buffer used by headerEncoder.
headerBuffer bytes.Buffer headerBuffer bytes.Buffer
// updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater
updateReceiveWindowChan chan<- uint32 // metricsUpdater is used to report metrics
// updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater metricsUpdater muxMetricsUpdater
updateSendWindowChan chan<- uint32 // bytesWrote is the amount of bytes written to data frames since the last time we called metricsUpdater.updateOutBoundBytes()
// bytesWrote is the amount of bytes wrote to data frame since the last time we send bytes wrote to metrics
bytesWrote *AtomicCounter bytesWrote *AtomicCounter
// updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater
updateOutBoundBytesChan chan<- uint64
useDictChan <-chan useDictRequest useDictChan <-chan useDictRequest
} }
@ -83,7 +80,7 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error {
case <-w.abortChan: case <-w.abortChan:
return return
case <-tickC: case <-tickC:
w.updateOutBoundBytesChan <- w.bytesWrote.Count() w.metricsUpdater.updateOutBoundBytes(w.bytesWrote.Count())
} }
} }
}() }()
@ -172,8 +169,8 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error {
func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error { func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error {
logger.Debug("writable") logger.Debug("writable")
chunk := stream.getChunk() chunk := stream.getChunk()
w.updateReceiveWindowChan <- stream.getReceiveWindow() w.metricsUpdater.updateReceiveWindow(stream.getReceiveWindow())
w.updateSendWindowChan <- stream.getSendWindow() w.metricsUpdater.updateSendWindow(stream.getSendWindow())
if chunk.sendHeadersFrame() { if chunk.sendHeadersFrame() {
err := w.writeHeaders(chunk.streamID, chunk.headers) err := w.writeHeaders(chunk.streamID, chunk.headers)
if err != nil { if err != nil {

View File

@ -1,15 +1,23 @@
package h2mux package h2mux
import "sync"
// ReadyList multiplexes several event signals onto a single channel. // ReadyList multiplexes several event signals onto a single channel.
type ReadyList struct { type ReadyList struct {
// signalC is used to signal that a stream can be enqueued
signalC chan uint32 signalC chan uint32
waitC chan uint32 // waitC is used to signal the ID of the first ready descriptor
waitC chan uint32
// doneC is used to signal that run should terminate
doneC chan struct{}
closeOnce sync.Once
} }
func NewReadyList() *ReadyList { func NewReadyList() *ReadyList {
rl := &ReadyList{ rl := &ReadyList{
signalC: make(chan uint32), signalC: make(chan uint32),
waitC: make(chan uint32), waitC: make(chan uint32),
doneC: make(chan struct{}),
} }
go rl.run() go rl.run()
return rl return rl
@ -17,7 +25,11 @@ func NewReadyList() *ReadyList {
// ID is the stream ID // ID is the stream ID
func (r *ReadyList) Signal(ID uint32) { func (r *ReadyList) Signal(ID uint32) {
r.signalC <- ID select {
case r.signalC <- ID:
// ReadyList already closed
case <-r.doneC:
}
} }
func (r *ReadyList) ReadyChannel() <-chan uint32 { func (r *ReadyList) ReadyChannel() <-chan uint32 {
@ -25,7 +37,9 @@ func (r *ReadyList) ReadyChannel() <-chan uint32 {
} }
func (r *ReadyList) Close() { func (r *ReadyList) Close() {
close(r.signalC) r.closeOnce.Do(func() {
close(r.doneC)
})
} }
func (r *ReadyList) run() { func (r *ReadyList) run() {
@ -35,28 +49,25 @@ func (r *ReadyList) run() {
activeDescriptors := newReadyDescriptorMap() activeDescriptors := newReadyDescriptorMap()
for { for {
if firstReady == nil { if firstReady == nil {
// Wait for first ready descriptor select {
i, ok := <-r.signalC case i := <-r.signalC:
if !ok { firstReady = activeDescriptors.SetIfMissing(i)
// closed case <-r.doneC:
return return
} }
firstReady = activeDescriptors.SetIfMissing(i)
} }
select { select {
case r.waitC <- firstReady.ID: case r.waitC <- firstReady.ID:
activeDescriptors.Delete(firstReady.ID) activeDescriptors.Delete(firstReady.ID)
firstReady = queue.Dequeue() firstReady = queue.Dequeue()
case i, ok := <-r.signalC: case i := <-r.signalC:
if !ok {
// closed
return
}
newReady := activeDescriptors.SetIfMissing(i) newReady := activeDescriptors.SetIfMissing(i)
if newReady != nil { if newReady != nil {
// key doesn't exist // key doesn't exist
queue.Enqueue(newReady) queue.Enqueue(newReady)
} }
case <-r.doneC:
return
} }
} }
} }

View File

@ -3,36 +3,59 @@ package h2mux
import ( import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func TestReadyList(t *testing.T) { func assertEmpty(t *testing.T, rl *ReadyList) {
select {
case <-rl.ReadyChannel():
t.Fatal("Spurious wakeup")
default:
}
}
func assertClosed(t *testing.T, rl *ReadyList) {
select {
case _, ok := <-rl.ReadyChannel():
assert.False(t, ok, "ReadyChannel was not closed")
case <-time.After(100 * time.Millisecond):
t.Fatalf("Timeout")
}
}
func receiveWithTimeout(t *testing.T, rl *ReadyList) uint32 {
select {
case i := <-rl.ReadyChannel():
return i
case <-time.After(100 * time.Millisecond):
t.Fatalf("Timeout")
return 0
}
}
func TestReadyListEmpty(t *testing.T) {
rl := NewReadyList() rl := NewReadyList()
c := rl.ReadyChannel()
// helper functions
assertEmpty := func() {
select {
case <-c:
t.Fatalf("Spurious wakeup")
default:
}
}
receiveWithTimeout := func() uint32 {
select {
case i := <-c:
return i
case <-time.After(100 * time.Millisecond):
t.Fatalf("Timeout")
return 0
}
}
// no signals, receive should fail // no signals, receive should fail
assertEmpty() assertEmpty(t, rl)
}
func TestReadyListSignal(t *testing.T) {
rl := NewReadyList()
assertEmpty(t, rl)
rl.Signal(0) rl.Signal(0)
if receiveWithTimeout() != 0 { if receiveWithTimeout(t, rl) != 0 {
t.Fatalf("Received wrong ID of signalled event") t.Fatalf("Received wrong ID of signalled event")
} }
// no new signals, receive should fail
assertEmpty() assertEmpty(t, rl)
}
func TestReadyListMultipleSignals(t *testing.T) {
rl := NewReadyList()
assertEmpty(t, rl)
// Signals should not block; // Signals should not block;
// Duplicate unhandled signals should not cause multiple wakeups // Duplicate unhandled signals should not cause multiple wakeups
signalled := [5]bool{} signalled := [5]bool{}
@ -42,12 +65,45 @@ func TestReadyList(t *testing.T) {
} }
// All signals should be received once (in any order) // All signals should be received once (in any order)
for range signalled { for range signalled {
i := receiveWithTimeout() i := receiveWithTimeout(t, rl)
if signalled[i] { if signalled[i] {
t.Fatalf("Received signal %d more than once", i) t.Fatalf("Received signal %d more than once", i)
} }
signalled[i] = true signalled[i] = true
} }
for i := range signalled {
if !signalled[i] {
t.Fatalf("Never received signal %d", i)
}
}
assertEmpty(t, rl)
}
func TestReadyListClose(t *testing.T) {
rl := NewReadyList()
rl.Close()
// readyList.run() occurs in a separate goroutine,
// so there's no way to directly check that run() has terminated.
// Perform an indirect check: is the ready channel closed?
assertClosed(t, rl)
// a second rl.Close() shouldn't cause a panic
rl.Close()
// Signal shouldn't block after Close()
done := make(chan struct{})
go func() {
for i := 0; i < 5; i++ {
rl.Signal(uint32(i))
}
close(done)
}()
select {
case <-done:
case <-time.After(100 * time.Millisecond):
t.Fatal("Test timed out")
}
} }
func TestReadyDescriptorQueue(t *testing.T) { func TestReadyDescriptorQueue(t *testing.T) {