diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index 310b79dd..862dc15d 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -397,7 +397,6 @@ func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader return stream, nil } - func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) { stream := m.NewStream(RPCHeaders()) if err := m.MakeMuxedStreamRequest(ctx, MuxedStreamRequest{stream: stream, body: nil}); err != nil { @@ -425,6 +424,13 @@ func (m *Muxer) MakeMuxedStreamRequest(ctx context.Context, request MuxedStreamR } } +func (m *Muxer) CloseStreamRead(stream *MuxedStream) { + stream.CloseRead() + if stream.WriteClosed() { + m.streams.Delete(stream.streamID) + } +} + func (m *Muxer) AwaitResponseHeaders(ctx context.Context, stream *MuxedStream) error { select { case <-ctx.Done(): diff --git a/h2mux/muxedstream.go b/h2mux/muxedstream.go index 286c54ab..a1ac2fd7 100644 --- a/h2mux/muxedstream.go +++ b/h2mux/muxedstream.go @@ -192,6 +192,12 @@ func (s *MuxedStream) CloseWrite() error { return nil } +func (s *MuxedStream) WriteClosed() bool { + s.writeLock.Lock() + defer s.writeLock.Unlock() + return s.writeEOF +} + func (s *MuxedStream) WriteHeaders(headers []Header) error { s.writeLock.Lock() defer s.writeLock.Unlock() @@ -351,7 +357,6 @@ func (s *MuxedStream) getChunk() *streamChunk { sendData: !s.sentEOF, eof: s.writeEOF && uint32(s.writeBuffer.Len()) <= s.sendWindow, } - // Copy at most s.sendWindow bytes, adjust the sendWindow accordingly writeLen, _ := io.CopyN(&chunk.buffer, s.writeBuffer, int64(s.sendWindow)) s.sendWindow -= uint32(writeLen)