diff --git a/h2mux/error.go b/h2mux/error.go index 467e99da..71581acc 100644 --- a/h2mux/error.go +++ b/h2mux/error.go @@ -20,11 +20,12 @@ var ( ErrUnknownStream = MuxerProtocolError{"2002 unknown stream", http2.ErrCodeProtocol} ErrInvalidStream = MuxerProtocolError{"2003 invalid stream", http2.ErrCodeProtocol} - ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"} - ErrConnectionClosed = MuxerApplicationError{"3001 connection closed"} - ErrConnectionDropped = MuxerApplicationError{"3002 connection dropped"} - ErrOpenStreamTimeout = MuxerApplicationError{"3003 open stream timeout"} - ErrResponseHeadersTimeout = MuxerApplicationError{"3004 timeout waiting for initial response headers"} + ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"} + ErrStreamRequestConnectionClosed = MuxerApplicationError{"3001 connection closed while opening stream"} + ErrConnectionDropped = MuxerApplicationError{"3002 connection dropped"} + ErrStreamRequestTimeout = MuxerApplicationError{"3003 open stream timeout"} + ErrResponseHeadersTimeout = MuxerApplicationError{"3004 timeout waiting for initial response headers"} + ErrResponseHeadersConnectionClosed = MuxerApplicationError{"3005 connection closed while waiting for initial response headers"} ErrClosedStream = MuxerStreamError{"4000 stream closed", http2.ErrCodeStreamClosed} ) diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index 6e0905c2..03040324 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -1,7 +1,6 @@ package h2mux import ( - "bytes" "context" "io" "strings" @@ -388,72 +387,52 @@ func isConnectionClosedError(err error) bool { // OpenStream opens a new data stream with the given headers. // Called by proxy server and tunnel func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader) (*MuxedStream, error) { - stream := &MuxedStream{ - responseHeadersReceived: make(chan struct{}), - readBuffer: NewSharedBuffer(), - writeBuffer: &bytes.Buffer{}, - writeBufferMaxLen: m.config.StreamWriteBufferMaxLen, - writeBufferHasSpace: make(chan struct{}, 1), - receiveWindow: m.config.DefaultWindowSize, - receiveWindowCurrentMax: m.config.DefaultWindowSize, - receiveWindowMax: m.config.MaxWindowSize, - sendWindow: m.config.DefaultWindowSize, - readyList: m.readyList, - writeHeaders: headers, - dictionaries: m.muxReader.dictionaries, + stream := m.NewStream(headers) + if err := m.MakeMuxedStreamRequest(ctx, MuxedStreamRequest{stream, body}); err != nil { + return nil, err } + if err := m.AwaitResponseHeaders(ctx, stream); err != nil { + return nil, err + } + return stream, nil +} + +func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) { + stream := m.NewStream(RPCHeaders()) + if err := m.MakeMuxedStreamRequest(ctx, MuxedStreamRequest{stream: stream, body: nil}); err != nil { + return nil, err + } + if err := m.AwaitResponseHeaders(ctx, stream); err != nil { + return nil, err + } + return stream, nil +} + +func (m *Muxer) NewStream(headers []Header) *MuxedStream { + return NewStream(m.config, headers, m.readyList, m.muxReader.dictionaries) +} + +func (m *Muxer) MakeMuxedStreamRequest(ctx context.Context, request MuxedStreamRequest) error { select { + case <-ctx.Done(): + return ErrStreamRequestTimeout + case <-m.abortChan: + return ErrStreamRequestConnectionClosed // Will be received by mux writer - case <-ctx.Done(): - return nil, ErrOpenStreamTimeout - case <-m.abortChan: - return nil, ErrConnectionClosed - case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: body}: - } - - select { - case <-ctx.Done(): - return nil, ErrResponseHeadersTimeout - case <-m.abortChan: - return nil, ErrConnectionClosed - case <-stream.responseHeadersReceived: - return stream, nil + case m.newStreamChan <- request: + return nil } } -func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) { - stream := &MuxedStream{ - responseHeadersReceived: make(chan struct{}), - readBuffer: NewSharedBuffer(), - writeBuffer: &bytes.Buffer{}, - writeBufferMaxLen: m.config.StreamWriteBufferMaxLen, - writeBufferHasSpace: make(chan struct{}, 1), - receiveWindow: m.config.DefaultWindowSize, - receiveWindowCurrentMax: m.config.DefaultWindowSize, - receiveWindowMax: m.config.MaxWindowSize, - sendWindow: m.config.DefaultWindowSize, - readyList: m.readyList, - writeHeaders: RPCHeaders(), - dictionaries: m.muxReader.dictionaries, - } - - select { - // Will be received by mux writer - case <-ctx.Done(): - return nil, ErrOpenStreamTimeout - case <-m.abortChan: - return nil, ErrConnectionClosed - case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: nil}: - } - +func (m *Muxer) AwaitResponseHeaders(ctx context.Context, stream *MuxedStream) error { select { case <-ctx.Done(): - return nil, ErrResponseHeadersTimeout + return ErrResponseHeadersTimeout case <-m.abortChan: - return nil, ErrConnectionClosed + return ErrResponseHeadersConnectionClosed case <-stream.responseHeadersReceived: - return stream, nil + return nil } } diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go index 89aba250..0efa75e1 100644 --- a/h2mux/h2mux_test.go +++ b/h2mux/h2mux_test.go @@ -584,8 +584,8 @@ func TestOpenAfterDisconnect(t *testing.T) { []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) - if err != ErrConnectionClosed { - t.Fatalf("unexpected error in OpenStream: %s", err) + if err != ErrStreamRequestConnectionClosed && err != ErrResponseHeadersConnectionClosed { + t.Fatalf("case %v: unexpected error in OpenStream: %v", i, err) } } } diff --git a/h2mux/muxedstream.go b/h2mux/muxedstream.go index 44d6f1e2..286c54ab 100644 --- a/h2mux/muxedstream.go +++ b/h2mux/muxedstream.go @@ -88,6 +88,23 @@ func (th TunnelHostname) IsSet() bool { return th != "" } +func NewStream(config MuxerConfig, writeHeaders []Header, readyList *ReadyList, dictionaries h2Dictionaries) *MuxedStream { + return &MuxedStream{ + responseHeadersReceived: make(chan struct{}), + readBuffer: NewSharedBuffer(), + writeBuffer: &bytes.Buffer{}, + writeBufferMaxLen: config.StreamWriteBufferMaxLen, + writeBufferHasSpace: make(chan struct{}, 1), + receiveWindow: config.DefaultWindowSize, + receiveWindowCurrentMax: config.DefaultWindowSize, + receiveWindowMax: config.MaxWindowSize, + sendWindow: config.DefaultWindowSize, + readyList: readyList, + writeHeaders: writeHeaders, + dictionaries: dictionaries, + } +} + func (s *MuxedStream) Read(p []byte) (n int, err error) { var readBuffer ReadWriteClosedCloser if s.dictionaries.read != nil {