diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index 055348f0..1341c127 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -71,7 +71,7 @@ type Muxer struct { // muxWriter is the write process. muxWriter *MuxWriter // muxMetricsUpdater is the process to update metrics - muxMetricsUpdater *muxMetricsUpdater + muxMetricsUpdater muxMetricsUpdater // newStreamChan is used to create new streams on the writer thread. // The writer will assign the next available stream ID. newStreamChan chan MuxedStreamRequest @@ -163,11 +163,6 @@ func Handshake( // set up reader/writer pair ready for serve streamErrors := NewStreamErrorMap() goAwayChan := make(chan http2.ErrCode, 1) - updateRTTChan := make(chan *roundTripMeasurement, 1) - updateReceiveWindowChan := make(chan uint32, 1) - updateSendWindowChan := make(chan uint32, 1) - updateInBoundBytesChan := make(chan uint64) - updateOutBoundBytesChan := make(chan uint64) inBoundCounter := NewAtomicCounter(0) outBoundCounter := NewAtomicCounter(0) pingTimestamp := NewPingTimestamp() @@ -184,6 +179,14 @@ func Handshake( config.Logger.Warn("Minimum number of unacked heartbeats to send before closing the connection has been adjusted to ", maxRetries) } + compBytesBefore, compBytesAfter := NewAtomicCounter(0), NewAtomicCounter(0) + + m.muxMetricsUpdater = newMuxMetricsUpdater( + m.abortChan, + compBytesBefore, + compBytesAfter, + ) + m.explicitShutdown = NewBooleanFuse() m.muxReader = &MuxReader{ f: m.f, @@ -198,45 +201,27 @@ func Handshake( initialStreamWindow: m.config.DefaultWindowSize, streamWindowMax: m.config.MaxWindowSize, streamWriteBufferMaxLen: m.config.StreamWriteBufferMaxLen, - r: m.r, - updateRTTChan: updateRTTChan, - updateReceiveWindowChan: updateReceiveWindowChan, - updateSendWindowChan: updateSendWindowChan, - bytesRead: inBoundCounter, - updateInBoundBytesChan: updateInBoundBytesChan, + r: m.r, + metricsUpdater: m.muxMetricsUpdater, + bytesRead: inBoundCounter, } m.muxWriter = &MuxWriter{ - f: m.f, - streams: m.streams, - streamErrors: streamErrors, - readyStreamChan: m.readyList.ReadyChannel(), - newStreamChan: m.newStreamChan, - goAwayChan: goAwayChan, - abortChan: m.abortChan, - pingTimestamp: pingTimestamp, - idleTimer: NewIdleTimer(idleDuration, maxRetries), - connActiveChan: connActive.WaitChannel(), - maxFrameSize: defaultFrameSize, - updateReceiveWindowChan: updateReceiveWindowChan, - updateSendWindowChan: updateSendWindowChan, - bytesWrote: outBoundCounter, - updateOutBoundBytesChan: updateOutBoundBytesChan, + f: m.f, + streams: m.streams, + streamErrors: streamErrors, + readyStreamChan: m.readyList.ReadyChannel(), + newStreamChan: m.newStreamChan, + goAwayChan: goAwayChan, + abortChan: m.abortChan, + pingTimestamp: pingTimestamp, + idleTimer: NewIdleTimer(idleDuration, maxRetries), + connActiveChan: connActive.WaitChannel(), + maxFrameSize: defaultFrameSize, + metricsUpdater: m.muxMetricsUpdater, + bytesWrote: outBoundCounter, } m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer) - compBytesBefore, compBytesAfter := NewAtomicCounter(0), NewAtomicCounter(0) - - m.muxMetricsUpdater = newMuxMetricsUpdater( - updateRTTChan, - updateReceiveWindowChan, - updateSendWindowChan, - updateInBoundBytesChan, - updateOutBoundBytesChan, - m.abortChan, - compBytesBefore, - compBytesAfter, - ) - if m.compressionQuality.dictSize > 0 && m.compressionQuality.nDicts > 0 { nd, sz := m.compressionQuality.nDicts, m.compressionQuality.dictSize writeDicts, dictChan := newH2WriteDictionaries( @@ -322,6 +307,12 @@ func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time. return nil } +// Serve runs the event loops that comprise h2mux: +// - MuxReader.run() +// - MuxWriter.run() +// - muxMetricsUpdater.run() +// In the normal case, Shutdown() is called concurrently with Serve() to stop +// these loops. func (m *Muxer) Serve(ctx context.Context) error { errGroup, _ := errgroup.WithContext(ctx) errGroup.Go(func() error { @@ -352,6 +343,7 @@ func (m *Muxer) Serve(ctx context.Context) error { return nil } +// Shutdown is called to initiate the "happy path" of muxer termination. func (m *Muxer) Shutdown() { m.explicitShutdown.Fuse(true) m.muxReader.Shutdown() @@ -418,12 +410,13 @@ func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, erro } func (m *Muxer) Metrics() *MuxerMetrics { - return m.muxMetricsUpdater.Metrics() + return m.muxMetricsUpdater.metrics() } func (m *Muxer) abort() { m.abortOnce.Do(func() { close(m.abortChan) + m.readyList.Close() m.streams.Abort() }) } diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go index 6c5ac73a..241f8a0e 100644 --- a/h2mux/h2mux_test.go +++ b/h2mux/h2mux_test.go @@ -829,7 +829,8 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) { t.Fatalf("TestMultipleStreams failed") } - if q > CompressionNone && muxPair.OriginMux.muxMetricsUpdater.compBytesBefore.Value() <= 10*muxPair.OriginMux.muxMetricsUpdater.compBytesAfter.Value() { + originMuxMetrics := muxPair.OriginMux.Metrics() + if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 10*originMuxMetrics.CompBytesAfter.Value() { t.Fatalf("Cross-stream compression is expected to give a better compression ratio") } } @@ -927,7 +928,8 @@ func TestSampleSiteWithDictionaries(t *testing.T) { } wg.Wait() - if q > CompressionNone && muxPair.OriginMux.muxMetricsUpdater.compBytesBefore.Value() <= 10*muxPair.OriginMux.muxMetricsUpdater.compBytesAfter.Value() { + originMuxMetrics := muxPair.OriginMux.Metrics() + if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 10*originMuxMetrics.CompBytesAfter.Value() { t.Fatalf("Cross-stream compression is expected to give a better compression ratio") } } @@ -960,7 +962,8 @@ func TestLongSiteWithDictionaries(t *testing.T) { } wg.Wait() - if q > CompressionNone && muxPair.OriginMux.muxMetricsUpdater.compBytesBefore.Value() <= 100*muxPair.OriginMux.muxMetricsUpdater.compBytesAfter.Value() { + originMuxMetrics := muxPair.OriginMux.Metrics() + if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 100*originMuxMetrics.CompBytesAfter.Value() { t.Fatalf("Cross-stream compression is expected to give a better compression ratio") } } diff --git a/h2mux/muxmetrics.go b/h2mux/muxmetrics.go index cb241bf6..feea307e 100644 --- a/h2mux/muxmetrics.go +++ b/h2mux/muxmetrics.go @@ -17,7 +17,24 @@ const ( updateFreq = time.Second ) -type muxMetricsUpdater struct { +type muxMetricsUpdater interface { + // metrics returns the latest metrics + metrics() *MuxerMetrics + // run is a blocking call to start the event loop + run(logger *log.Entry) error + // updateRTTChan is called by muxReader to report new RTT measurements + updateRTT(rtt *roundTripMeasurement) + //updateReceiveWindowChan is called by muxReader and muxWriter when receiveWindow size is updated + updateReceiveWindow(receiveWindow uint32) + //updateSendWindowChan is called by muxReader and muxWriter when sendWindow size is updated + updateSendWindow(sendWindow uint32) + // updateInBoundBytesChan is called periodicallyby muxReader to report bytesRead + updateInBoundBytes(inBoundBytes uint64) + // updateOutBoundBytesChan is called periodically by muxWriter to report bytesWrote + updateOutBoundBytes(outBoundBytes uint64) +} + +type muxMetricsUpdaterImpl struct { // rttData keeps record of rtt, rttMin, rttMax and last measured time rttData *rttData // receiveWindowData keeps record of receive window measurement @@ -28,16 +45,16 @@ type muxMetricsUpdater struct { inBoundRate *rate // outBoundRate is outgoing bytes/sec outBoundRate *rate - // updateRTTChan is the channel to receive new RTT measurement from muxReader - updateRTTChan <-chan *roundTripMeasurement - //updateReceiveWindowChan is the channel to receive updated receiveWindow size from muxReader and muxWriter - updateReceiveWindowChan <-chan uint32 - //updateSendWindowChan is the channel to receive updated sendWindow size from muxReader and muxWriter - updateSendWindowChan <-chan uint32 - // updateInBoundBytesChan us the channel to receive bytesRead from muxReader - updateInBoundBytesChan <-chan uint64 - // updateOutBoundBytesChan us the channel to receive bytesWrote from muxWriter - updateOutBoundBytesChan <-chan uint64 + // updateRTTChan is the channel to receive new RTT measurement + updateRTTChan chan *roundTripMeasurement + //updateReceiveWindowChan is the channel to receive updated receiveWindow size + updateReceiveWindowChan chan uint32 + //updateSendWindowChan is the channel to receive updated sendWindow size + updateSendWindowChan chan uint32 + // updateInBoundBytesChan us the channel to receive bytesRead + updateInBoundBytesChan chan uint64 + // updateOutBoundBytesChan us the channel to receive bytesWrote + updateOutBoundBytesChan chan uint64 // shutdownC is to signal the muxerMetricsUpdater to shutdown abortChan <-chan struct{} @@ -84,15 +101,16 @@ type rate struct { } func newMuxMetricsUpdater( - updateRTTChan <-chan *roundTripMeasurement, - updateReceiveWindowChan <-chan uint32, - updateSendWindowChan <-chan uint32, - updateInBoundBytesChan <-chan uint64, - updateOutBoundBytesChan <-chan uint64, abortChan <-chan struct{}, compBytesBefore, compBytesAfter *AtomicCounter, -) *muxMetricsUpdater { - return &muxMetricsUpdater{ +) muxMetricsUpdater { + updateRTTChan := make(chan *roundTripMeasurement, 1) + updateReceiveWindowChan := make(chan uint32, 1) + updateSendWindowChan := make(chan uint32, 1) + updateInBoundBytesChan := make(chan uint64) + updateOutBoundBytesChan := make(chan uint64) + + return &muxMetricsUpdaterImpl{ rttData: newRTTData(), receiveWindowData: newFlowControlData(), sendWindowData: newFlowControlData(), @@ -109,7 +127,7 @@ func newMuxMetricsUpdater( } } -func (updater *muxMetricsUpdater) Metrics() *MuxerMetrics { +func (updater *muxMetricsUpdaterImpl) metrics() *MuxerMetrics { m := &MuxerMetrics{} m.RTT, m.RTTMin, m.RTTMax = updater.rttData.metrics() m.ReceiveWindowAve, m.ReceiveWindowMin, m.ReceiveWindowMax = updater.receiveWindowData.metrics() @@ -120,7 +138,7 @@ func (updater *muxMetricsUpdater) Metrics() *MuxerMetrics { return m } -func (updater *muxMetricsUpdater) run(parentLogger *log.Entry) error { +func (updater *muxMetricsUpdaterImpl) run(parentLogger *log.Entry) error { logger := parentLogger.WithFields(log.Fields{ "subsystem": "mux", "dir": "metrics", @@ -152,6 +170,43 @@ func (updater *muxMetricsUpdater) run(parentLogger *log.Entry) error { } } +func (updater *muxMetricsUpdaterImpl) updateRTT(rtt *roundTripMeasurement) { + select { + case updater.updateRTTChan <- rtt: + case <-updater.abortChan: + } + +} + +func (updater *muxMetricsUpdaterImpl) updateReceiveWindow(receiveWindow uint32) { + select { + case updater.updateReceiveWindowChan <- receiveWindow: + case <-updater.abortChan: + } +} + +func (updater *muxMetricsUpdaterImpl) updateSendWindow(sendWindow uint32) { + select { + case updater.updateSendWindowChan <- sendWindow: + case <-updater.abortChan: + } +} + +func (updater *muxMetricsUpdaterImpl) updateInBoundBytes(inBoundBytes uint64) { + select { + case updater.updateInBoundBytesChan <- inBoundBytes: + case <-updater.abortChan: + } + +} + +func (updater *muxMetricsUpdaterImpl) updateOutBoundBytes(outBoundBytes uint64) { + select { + case updater.updateOutBoundBytesChan <- outBoundBytes: + case <-updater.abortChan: + } +} + func newRTTData() *rttData { return &rttData{} } diff --git a/h2mux/muxmetrics_test.go b/h2mux/muxmetrics_test.go index 7a9d4792..f74e5dda 100644 --- a/h2mux/muxmetrics_test.go +++ b/h2mux/muxmetrics_test.go @@ -86,24 +86,11 @@ func TestFlowControlDataUpdate(t *testing.T) { } func TestMuxMetricsUpdater(t *testing.T) { - t.Skip("Race condition") - updateRTTChan := make(chan *roundTripMeasurement) - updateReceiveWindowChan := make(chan uint32) - updateSendWindowChan := make(chan uint32) - updateInBoundBytesChan := make(chan uint64) - updateOutBoundBytesChan := make(chan uint64) - abortChan := make(chan struct{}) + t.Skip("Inherently racy test due to muxMetricsUpdaterImpl.run()") errChan := make(chan error) + abortChan := make(chan struct{}) compBefore, compAfter := NewAtomicCounter(0), NewAtomicCounter(0) - m := newMuxMetricsUpdater(updateRTTChan, - updateReceiveWindowChan, - updateSendWindowChan, - updateInBoundBytesChan, - updateOutBoundBytesChan, - abortChan, - compBefore, - compAfter, - ) + m := newMuxMetricsUpdater(abortChan, compBefore, compAfter) logger := log.NewEntry(log.New()) go func() { @@ -116,42 +103,44 @@ func TestMuxMetricsUpdater(t *testing.T) { // mock muxReader readerStart := time.Now() rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart} - updateRTTChan <- rm + m.updateRTT(rm) go func() { defer wg.Done() - // Becareful if dataPoints is not divisibile by 4 + assert.Equal(t, 0, dataPoints%4, + "dataPoints is not divisible by 4; this test should be adjusted accordingly") readerSend := readerStart.Add(time.Millisecond) for i := 1; i <= dataPoints/4; i++ { readerReceive := readerSend.Add(time.Duration(i) * time.Millisecond) rm := &roundTripMeasurement{receiveTime: readerReceive, sendTime: readerSend} - updateRTTChan <- rm + m.updateRTT(rm) readerSend = readerReceive.Add(time.Millisecond) + m.updateReceiveWindow(uint32(i)) + m.updateSendWindow(uint32(i)) - updateReceiveWindowChan <- uint32(i) - updateSendWindowChan <- uint32(i) - - updateInBoundBytesChan <- uint64(i) + m.updateInBoundBytes(uint64(i)) } }() // mock muxWriter go func() { defer wg.Done() + assert.Equal(t, 0, dataPoints%4, + "dataPoints is not divisible by 4; this test should be adjusted accordingly") for j := dataPoints/4 + 1; j <= dataPoints/2; j++ { - updateReceiveWindowChan <- uint32(j) - updateSendWindowChan <- uint32(j) + m.updateReceiveWindow(uint32(j)) + m.updateSendWindow(uint32(j)) - // should always be disgard since the send time is before readerSend + // should always be disgarded since the send time is before readerSend rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart.Add(-time.Duration(j*dataPoints) * time.Millisecond)} - updateRTTChan <- rm + m.updateRTT(rm) - updateOutBoundBytesChan <- uint64(j) + m.updateOutBoundBytes(uint64(j)) } }() wg.Wait() - metrics := m.Metrics() + metrics := m.metrics() points := dataPoints / 2 assert.Equal(t, time.Millisecond, metrics.RTTMin) assert.Equal(t, time.Duration(dataPoints/4)*time.Millisecond, metrics.RTTMax) diff --git a/h2mux/muxreader.go b/h2mux/muxreader.go index f80c1ab7..3bdc8216 100644 --- a/h2mux/muxreader.go +++ b/h2mux/muxreader.go @@ -39,16 +39,10 @@ type MuxReader struct { streamWriteBufferMaxLen int // r is a reference to the underlying connection used when shutting down. r io.Closer - // updateRTTChan is the channel to send new RTT measurement to muxerMetricsUpdater - updateRTTChan chan<- *roundTripMeasurement - // updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater - updateReceiveWindowChan chan<- uint32 - // updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater - updateSendWindowChan chan<- uint32 - // bytesRead is the amount of bytes read from data frame since the last time we send bytes read to metrics + // metricsUpdater is used to report metrics + metricsUpdater muxMetricsUpdater + // bytesRead is the amount of bytes read from data frames since the last time we called metricsUpdater.updateInBoundBytes() bytesRead *AtomicCounter - // updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater - updateInBoundBytesChan chan<- uint64 // dictionaries holds the h2 cross-stream compression dictionaries dictionaries h2Dictionaries } @@ -81,7 +75,7 @@ func (r *MuxReader) run(parentLogger *log.Entry) error { case <-r.abortChan: return case <-tickC: - r.updateInBoundBytesChan <- r.bytesRead.Count() + r.metricsUpdater.updateInBoundBytes(r.bytesRead.Count()) } } }() @@ -289,7 +283,7 @@ func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, parentLogger *log.E if !stream.consumeReceiveWindow(uint32(len(data))) { return r.streamError(stream.streamID, http2.ErrCodeFlowControl) } - r.updateReceiveWindowChan <- stream.getReceiveWindow() + r.metricsUpdater.updateReceiveWindow(stream.getReceiveWindow()) return nil } @@ -301,13 +295,13 @@ func (r *MuxReader) receivePingData(frame *http2.PingFrame) { return } - // Update updates the computed values with a new measurement. - // outgoingTime is the time that the probe was sent. - // We assume that time.Now() is the time we received that probe. - r.updateRTTChan <- &roundTripMeasurement{ + // Update the computed RTT aggregations with a new measurement. + // `ts` is the time that the probe was sent. + // We assume that `time.Now()` is the time we received that probe. + r.metricsUpdater.updateRTT(&roundTripMeasurement{ receiveTime: time.Now(), sendTime: time.Unix(0, ts), - } + }) } // Receive a GOAWAY from the peer. Gracefully shut down our connection. @@ -468,7 +462,7 @@ func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error { return nil } stream.replenishSendWindow(frame.Increment) - r.updateSendWindowChan <- stream.getSendWindow() + r.metricsUpdater.updateSendWindow(stream.getSendWindow()) return nil } diff --git a/h2mux/muxwriter.go b/h2mux/muxwriter.go index 03b83860..b0769356 100644 --- a/h2mux/muxwriter.go +++ b/h2mux/muxwriter.go @@ -40,14 +40,11 @@ type MuxWriter struct { headerEncoder *hpack.Encoder // headerBuffer is the temporary buffer used by headerEncoder. headerBuffer bytes.Buffer - // updateReceiveWindowChan is the channel to update receiveWindow size to muxerMetricsUpdater - updateReceiveWindowChan chan<- uint32 - // updateSendWindowChan is the channel to update sendWindow size to muxerMetricsUpdater - updateSendWindowChan chan<- uint32 - // bytesWrote is the amount of bytes wrote to data frame since the last time we send bytes wrote to metrics + + // metricsUpdater is used to report metrics + metricsUpdater muxMetricsUpdater + // bytesWrote is the amount of bytes written to data frames since the last time we called metricsUpdater.updateOutBoundBytes() bytesWrote *AtomicCounter - // updateOutBoundBytesChan is the channel to send bytesWrote to muxerMetricsUpdater - updateOutBoundBytesChan chan<- uint64 useDictChan <-chan useDictRequest } @@ -83,7 +80,7 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error { case <-w.abortChan: return case <-tickC: - w.updateOutBoundBytesChan <- w.bytesWrote.Count() + w.metricsUpdater.updateOutBoundBytes(w.bytesWrote.Count()) } } }() @@ -172,8 +169,8 @@ func (w *MuxWriter) run(parentLogger *log.Entry) error { func (w *MuxWriter) writeStreamData(stream *MuxedStream, logger *log.Entry) error { logger.Debug("writable") chunk := stream.getChunk() - w.updateReceiveWindowChan <- stream.getReceiveWindow() - w.updateSendWindowChan <- stream.getSendWindow() + w.metricsUpdater.updateReceiveWindow(stream.getReceiveWindow()) + w.metricsUpdater.updateSendWindow(stream.getSendWindow()) if chunk.sendHeadersFrame() { err := w.writeHeaders(chunk.streamID, chunk.headers) if err != nil { diff --git a/h2mux/readylist.go b/h2mux/readylist.go index 5215e464..d1a18c6d 100644 --- a/h2mux/readylist.go +++ b/h2mux/readylist.go @@ -1,15 +1,23 @@ package h2mux +import "sync" + // ReadyList multiplexes several event signals onto a single channel. type ReadyList struct { + // signalC is used to signal that a stream can be enqueued signalC chan uint32 - waitC chan uint32 + // waitC is used to signal the ID of the first ready descriptor + waitC chan uint32 + // doneC is used to signal that run should terminate + doneC chan struct{} + closeOnce sync.Once } func NewReadyList() *ReadyList { rl := &ReadyList{ signalC: make(chan uint32), waitC: make(chan uint32), + doneC: make(chan struct{}), } go rl.run() return rl @@ -17,7 +25,11 @@ func NewReadyList() *ReadyList { // ID is the stream ID func (r *ReadyList) Signal(ID uint32) { - r.signalC <- ID + select { + case r.signalC <- ID: + // ReadyList already closed + case <-r.doneC: + } } func (r *ReadyList) ReadyChannel() <-chan uint32 { @@ -25,7 +37,9 @@ func (r *ReadyList) ReadyChannel() <-chan uint32 { } func (r *ReadyList) Close() { - close(r.signalC) + r.closeOnce.Do(func() { + close(r.doneC) + }) } func (r *ReadyList) run() { @@ -35,28 +49,25 @@ func (r *ReadyList) run() { activeDescriptors := newReadyDescriptorMap() for { if firstReady == nil { - // Wait for first ready descriptor - i, ok := <-r.signalC - if !ok { - // closed + select { + case i := <-r.signalC: + firstReady = activeDescriptors.SetIfMissing(i) + case <-r.doneC: return } - firstReady = activeDescriptors.SetIfMissing(i) } select { case r.waitC <- firstReady.ID: activeDescriptors.Delete(firstReady.ID) firstReady = queue.Dequeue() - case i, ok := <-r.signalC: - if !ok { - // closed - return - } + case i := <-r.signalC: newReady := activeDescriptors.SetIfMissing(i) if newReady != nil { // key doesn't exist queue.Enqueue(newReady) } + case <-r.doneC: + return } } } diff --git a/h2mux/readylist_test.go b/h2mux/readylist_test.go index 1bf5f0bf..6ee9cfbf 100644 --- a/h2mux/readylist_test.go +++ b/h2mux/readylist_test.go @@ -3,36 +3,59 @@ package h2mux import ( "testing" "time" + + "github.com/stretchr/testify/assert" ) -func TestReadyList(t *testing.T) { +func assertEmpty(t *testing.T, rl *ReadyList) { + select { + case <-rl.ReadyChannel(): + t.Fatal("Spurious wakeup") + default: + } +} + +func assertClosed(t *testing.T, rl *ReadyList) { + select { + case _, ok := <-rl.ReadyChannel(): + assert.False(t, ok, "ReadyChannel was not closed") + case <-time.After(100 * time.Millisecond): + t.Fatalf("Timeout") + } +} + +func receiveWithTimeout(t *testing.T, rl *ReadyList) uint32 { + select { + case i := <-rl.ReadyChannel(): + return i + case <-time.After(100 * time.Millisecond): + t.Fatalf("Timeout") + return 0 + } +} + +func TestReadyListEmpty(t *testing.T) { rl := NewReadyList() - c := rl.ReadyChannel() - // helper functions - assertEmpty := func() { - select { - case <-c: - t.Fatalf("Spurious wakeup") - default: - } - } - receiveWithTimeout := func() uint32 { - select { - case i := <-c: - return i - case <-time.After(100 * time.Millisecond): - t.Fatalf("Timeout") - return 0 - } - } + // no signals, receive should fail - assertEmpty() + assertEmpty(t, rl) +} +func TestReadyListSignal(t *testing.T) { + rl := NewReadyList() + assertEmpty(t, rl) + rl.Signal(0) - if receiveWithTimeout() != 0 { + if receiveWithTimeout(t, rl) != 0 { t.Fatalf("Received wrong ID of signalled event") } - // no new signals, receive should fail - assertEmpty() + + assertEmpty(t, rl) +} + +func TestReadyListMultipleSignals(t *testing.T) { + rl := NewReadyList() + assertEmpty(t, rl) + // Signals should not block; // Duplicate unhandled signals should not cause multiple wakeups signalled := [5]bool{} @@ -42,12 +65,45 @@ func TestReadyList(t *testing.T) { } // All signals should be received once (in any order) for range signalled { - i := receiveWithTimeout() + i := receiveWithTimeout(t, rl) if signalled[i] { t.Fatalf("Received signal %d more than once", i) } signalled[i] = true } + for i := range signalled { + if !signalled[i] { + t.Fatalf("Never received signal %d", i) + } + } + assertEmpty(t, rl) +} + +func TestReadyListClose(t *testing.T) { + rl := NewReadyList() + rl.Close() + + // readyList.run() occurs in a separate goroutine, + // so there's no way to directly check that run() has terminated. + // Perform an indirect check: is the ready channel closed? + assertClosed(t, rl) + + // a second rl.Close() shouldn't cause a panic + rl.Close() + + // Signal shouldn't block after Close() + done := make(chan struct{}) + go func() { + for i := 0; i < 5; i++ { + rl.Signal(uint32(i)) + } + close(done) + }() + select { + case <-done: + case <-time.After(100 * time.Millisecond): + t.Fatal("Test timed out") + } } func TestReadyDescriptorQueue(t *testing.T) {