From e14ec1a1fbba4ce65edddb7961643017e1014ab8 Mon Sep 17 00:00:00 2001 From: Nick Vollmar Date: Tue, 5 Nov 2019 17:24:00 -0600 Subject: [PATCH] TUN-2505: Terminate stream on receipt of RST_STREAM; MuxedStream.CloseWrite() should terminate the MuxedStream.Write() loop --- h2mux/muxedstream.go | 34 ++++++++++++++++++++++++++-------- h2mux/muxreader.go | 3 +++ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/h2mux/muxedstream.go b/h2mux/muxedstream.go index a1ac2fd7..6bafa19d 100644 --- a/h2mux/muxedstream.go +++ b/h2mux/muxedstream.go @@ -137,9 +137,10 @@ func (s *MuxedStream) Write(p []byte) (int, error) { // If the buffer is full, block till there is more room. // Use a loop to recheck the buffer size after the lock is reacquired. for s.writeBufferMaxLen <= s.writeBuffer.Len() { - s.writeLock.Unlock() - <-s.writeBufferHasSpace - s.writeLock.Lock() + s.awaitWriteBufferHasSpace() + if s.writeEOF { + return totalWritten, io.EOF + } } amountToWrite := len(p) - totalWritten spaceAvailable := s.writeBufferMaxLen - s.writeBuffer.Len() @@ -188,6 +189,9 @@ func (s *MuxedStream) CloseWrite() error { if c, ok := s.writeBuffer.(io.Closer); ok { c.Close() } + // Allow MuxedStream::Write() to terminate its loop with err=io.EOF, if needed + s.notifyWriteBufferHasSpace() + // We need to send something over the wire, even if it's an END_STREAM with no data s.writeNotify() return nil } @@ -238,6 +242,23 @@ func (s *MuxedStream) TunnelHostname() TunnelHostname { return s.tunnelHostname } +// Block until a value is sent on writeBufferHasSpace. +// Must be called while holding writeLock +func (s *MuxedStream) awaitWriteBufferHasSpace() { + s.writeLock.Unlock() + <-s.writeBufferHasSpace + s.writeLock.Lock() +} + +// Send a value on writeBufferHasSpace without blocking. +// Must be called while holding writeLock +func (s *MuxedStream) notifyWriteBufferHasSpace() { + select { + case s.writeBufferHasSpace <- struct{}{}: + default: + } +} + func (s *MuxedStream) getReceiveWindow() uint32 { s.writeLock.Lock() defer s.writeLock.Unlock() @@ -361,12 +382,9 @@ func (s *MuxedStream) getChunk() *streamChunk { writeLen, _ := io.CopyN(&chunk.buffer, s.writeBuffer, int64(s.sendWindow)) s.sendWindow -= uint32(writeLen) - // Non-blocking channel send. This will allow MuxedStream::Write() to continue, if needed + // Allow MuxedStream::Write() to continue, if needed if s.writeBuffer.Len() < s.writeBufferMaxLen { - select { - case s.writeBufferHasSpace <- struct{}{}: - default: - } + s.notifyWriteBufferHasSpace() } // When we write the chunk, we'll write the WINDOW_UPDATE frame if needed diff --git a/h2mux/muxreader.go b/h2mux/muxreader.go index 6228d685..728c94c4 100644 --- a/h2mux/muxreader.go +++ b/h2mux/muxreader.go @@ -125,6 +125,9 @@ func (r *MuxReader) run(parentLogger *log.Entry) error { if streamID == 0 { return ErrInvalidStream } + if stream, ok := r.streams.Get(streamID); ok { + stream.Close() + } r.streams.Delete(streamID) case *http2.PingFrame: r.receivePingData(f)