TUN-2505: Terminate stream on receipt of RST_STREAM; MuxedStream.CloseWrite() should terminate the MuxedStream.Write() loop

This commit is contained in:
Nick Vollmar 2019-11-05 17:24:00 -06:00
parent 3a9a0a0d75
commit e14ec1a1fb
2 changed files with 29 additions and 8 deletions

View File

@ -137,9 +137,10 @@ func (s *MuxedStream) Write(p []byte) (int, error) {
// If the buffer is full, block till there is more room. // If the buffer is full, block till there is more room.
// Use a loop to recheck the buffer size after the lock is reacquired. // Use a loop to recheck the buffer size after the lock is reacquired.
for s.writeBufferMaxLen <= s.writeBuffer.Len() { for s.writeBufferMaxLen <= s.writeBuffer.Len() {
s.writeLock.Unlock() s.awaitWriteBufferHasSpace()
<-s.writeBufferHasSpace if s.writeEOF {
s.writeLock.Lock() return totalWritten, io.EOF
}
} }
amountToWrite := len(p) - totalWritten amountToWrite := len(p) - totalWritten
spaceAvailable := s.writeBufferMaxLen - s.writeBuffer.Len() spaceAvailable := s.writeBufferMaxLen - s.writeBuffer.Len()
@ -188,6 +189,9 @@ func (s *MuxedStream) CloseWrite() error {
if c, ok := s.writeBuffer.(io.Closer); ok { if c, ok := s.writeBuffer.(io.Closer); ok {
c.Close() c.Close()
} }
// Allow MuxedStream::Write() to terminate its loop with err=io.EOF, if needed
s.notifyWriteBufferHasSpace()
// We need to send something over the wire, even if it's an END_STREAM with no data
s.writeNotify() s.writeNotify()
return nil return nil
} }
@ -238,6 +242,23 @@ func (s *MuxedStream) TunnelHostname() TunnelHostname {
return s.tunnelHostname return s.tunnelHostname
} }
// Block until a value is sent on writeBufferHasSpace.
// Must be called while holding writeLock
func (s *MuxedStream) awaitWriteBufferHasSpace() {
s.writeLock.Unlock()
<-s.writeBufferHasSpace
s.writeLock.Lock()
}
// Send a value on writeBufferHasSpace without blocking.
// Must be called while holding writeLock
func (s *MuxedStream) notifyWriteBufferHasSpace() {
select {
case s.writeBufferHasSpace <- struct{}{}:
default:
}
}
func (s *MuxedStream) getReceiveWindow() uint32 { func (s *MuxedStream) getReceiveWindow() uint32 {
s.writeLock.Lock() s.writeLock.Lock()
defer s.writeLock.Unlock() defer s.writeLock.Unlock()
@ -361,12 +382,9 @@ func (s *MuxedStream) getChunk() *streamChunk {
writeLen, _ := io.CopyN(&chunk.buffer, s.writeBuffer, int64(s.sendWindow)) writeLen, _ := io.CopyN(&chunk.buffer, s.writeBuffer, int64(s.sendWindow))
s.sendWindow -= uint32(writeLen) s.sendWindow -= uint32(writeLen)
// Non-blocking channel send. This will allow MuxedStream::Write() to continue, if needed // Allow MuxedStream::Write() to continue, if needed
if s.writeBuffer.Len() < s.writeBufferMaxLen { if s.writeBuffer.Len() < s.writeBufferMaxLen {
select { s.notifyWriteBufferHasSpace()
case s.writeBufferHasSpace <- struct{}{}:
default:
}
} }
// When we write the chunk, we'll write the WINDOW_UPDATE frame if needed // When we write the chunk, we'll write the WINDOW_UPDATE frame if needed

View File

@ -125,6 +125,9 @@ func (r *MuxReader) run(parentLogger *log.Entry) error {
if streamID == 0 { if streamID == 0 {
return ErrInvalidStream return ErrInvalidStream
} }
if stream, ok := r.streams.Get(streamID); ok {
stream.Close()
}
r.streams.Delete(streamID) r.streams.Delete(streamID)
case *http2.PingFrame: case *http2.PingFrame:
r.receivePingData(f) r.receivePingData(f)