TUN-1682: Add context to OpenStream to prevent it from blocking indefinitely.

This commit is contained in:
Chung-Ting Huang 2019-04-02 18:12:09 -05:00
parent 13d25a52a9
commit 2bef5dbe72
5 changed files with 276 additions and 213 deletions

View File

@ -18,6 +18,7 @@ import (
const ( const (
dialTimeout = 5 * time.Second dialTimeout = 5 * time.Second
openStreamTimeout = 30 * time.Second
) )
type dialError struct { type dialError struct {
@ -70,7 +71,9 @@ func (h *h2muxHandler) serve(ctx context.Context) error {
// Connect is used to establish connections with cloudflare's edge network // Connect is used to establish connections with cloudflare's edge network
func (h *h2muxHandler) connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters) (*tunnelpogs.ConnectResult, error) { func (h *h2muxHandler) connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters) (*tunnelpogs.ConnectResult, error) {
conn, err := h.newRPConn() openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
defer cancel()
conn, err := h.newRPConn(openStreamCtx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Failed to create new RPC connection") return nil, errors.Wrap(err, "Failed to create new RPC connection")
} }
@ -83,8 +86,8 @@ func (h *h2muxHandler) shutdown() {
h.muxer.Shutdown() h.muxer.Shutdown()
} }
func (h *h2muxHandler) newRPConn() (*rpc.Conn, error) { func (h *h2muxHandler) newRPConn(ctx context.Context) (*rpc.Conn, error) {
stream, err := h.muxer.OpenStream([]h2mux.Header{ stream, err := h.muxer.OpenStream(ctx, []h2mux.Header{
{Name: ":method", Value: "RPC"}, {Name: ":method", Value: "RPC"},
{Name: ":scheme", Value: "capnp"}, {Name: ":scheme", Value: "capnp"},
{Name: ":path", Value: "*"}, {Name: ":path", Value: "*"},

View File

@ -7,6 +7,7 @@ import (
) )
var ( var (
// HTTP2 error codes: https://http2.github.io/http2-spec/#ErrorCodes
ErrHandshakeTimeout = MuxerHandshakeError{"1000 handshake timeout"} ErrHandshakeTimeout = MuxerHandshakeError{"1000 handshake timeout"}
ErrBadHandshakeNotSettings = MuxerHandshakeError{"1001 unexpected response"} ErrBadHandshakeNotSettings = MuxerHandshakeError{"1001 unexpected response"}
ErrBadHandshakeUnexpectedAck = MuxerHandshakeError{"1002 unexpected response"} ErrBadHandshakeUnexpectedAck = MuxerHandshakeError{"1002 unexpected response"}
@ -22,6 +23,7 @@ var (
ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"} ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"}
ErrConnectionClosed = MuxerApplicationError{"3001 connection closed"} ErrConnectionClosed = MuxerApplicationError{"3001 connection closed"}
ErrConnectionDropped = MuxerApplicationError{"3002 connection dropped"} ErrConnectionDropped = MuxerApplicationError{"3002 connection dropped"}
ErrOpenStreamTimeout = MuxerApplicationError{"3003 open stream timeout"}
ErrClosedStream = MuxerStreamError{"4000 stream closed", http2.ErrCodeStreamClosed} ErrClosedStream = MuxerStreamError{"4000 stream closed", http2.ErrCodeStreamClosed}
) )

View File

@ -379,7 +379,7 @@ func isConnectionClosedError(err error) bool {
// OpenStream opens a new data stream with the given headers. // OpenStream opens a new data stream with the given headers.
// Called by proxy server and tunnel // Called by proxy server and tunnel
func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, error) { func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader) (*MuxedStream, error) {
stream := &MuxedStream{ stream := &MuxedStream{
responseHeadersReceived: make(chan struct{}), responseHeadersReceived: make(chan struct{}),
readBuffer: NewSharedBuffer(), readBuffer: NewSharedBuffer(),
@ -397,15 +397,20 @@ func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, erro
select { select {
// Will be received by mux writer // Will be received by mux writer
case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: body}: case <-ctx.Done():
return nil, ErrOpenStreamTimeout
case <-m.abortChan: case <-m.abortChan:
return nil, ErrConnectionClosed return nil, ErrConnectionClosed
case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: body}:
} }
select { select {
case <-ctx.Done():
return nil, ErrOpenStreamTimeout
case <-m.abortChan:
return nil, ErrConnectionClosed
case <-stream.responseHeadersReceived: case <-stream.responseHeadersReceived:
return stream, nil return stream, nil
case <-m.abortChan:
return nil, ErrConnectionClosed
} }
} }

View File

@ -15,7 +15,15 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
)
const (
testOpenStreamTimeout = time.Millisecond * 5000
testHandshakeTimeout = time.Millisecond * 1000
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -35,11 +43,12 @@ type DefaultMuxerPair struct {
doneC chan struct{} doneC chan struct{}
} }
func NewDefaultMuxerPair() *DefaultMuxerPair { func NewDefaultMuxerPair(t assert.TestingT, f MuxedStreamFunc) *DefaultMuxerPair {
origin, edge := net.Pipe() origin, edge := net.Pipe()
return &DefaultMuxerPair{ p := &DefaultMuxerPair{
OriginMuxConfig: MuxerConfig{ OriginMuxConfig: MuxerConfig{
Timeout: time.Second, Timeout: testHandshakeTimeout,
Handler: f,
IsClient: true, IsClient: true,
Name: "origin", Name: "origin",
Logger: log.NewEntry(log.New()), Logger: log.NewEntry(log.New()),
@ -49,7 +58,7 @@ func NewDefaultMuxerPair() *DefaultMuxerPair {
}, },
OriginConn: origin, OriginConn: origin,
EdgeMuxConfig: MuxerConfig{ EdgeMuxConfig: MuxerConfig{
Timeout: time.Second, Timeout: testHandshakeTimeout,
IsClient: false, IsClient: false,
Name: "edge", Name: "edge",
Logger: log.NewEntry(log.New()), Logger: log.NewEntry(log.New()),
@ -60,13 +69,16 @@ func NewDefaultMuxerPair() *DefaultMuxerPair {
EdgeConn: edge, EdgeConn: edge,
doneC: make(chan struct{}), doneC: make(chan struct{}),
} }
assert.NoError(t, p.Handshake())
return p
} }
func NewCompressedMuxerPair(quality CompressionSetting) *DefaultMuxerPair { func NewCompressedMuxerPair(t assert.TestingT, quality CompressionSetting, f MuxedStreamFunc) *DefaultMuxerPair {
origin, edge := net.Pipe() origin, edge := net.Pipe()
return &DefaultMuxerPair{ p := &DefaultMuxerPair{
OriginMuxConfig: MuxerConfig{ OriginMuxConfig: MuxerConfig{
Timeout: time.Second, Timeout: time.Second,
Handler: f,
IsClient: true, IsClient: true,
Name: "origin", Name: "origin",
CompressionQuality: quality, CompressionQuality: quality,
@ -83,44 +95,28 @@ func NewCompressedMuxerPair(quality CompressionSetting) *DefaultMuxerPair {
EdgeConn: edge, EdgeConn: edge,
doneC: make(chan struct{}), doneC: make(chan struct{}),
} }
assert.NoError(t, p.Handshake())
return p
} }
func (p *DefaultMuxerPair) Handshake(t *testing.T) { func (p *DefaultMuxerPair) Handshake() error {
edgeErrC := make(chan error) ctx, cancel := context.WithTimeout(context.Background(), testHandshakeTimeout)
originErrC := make(chan error) defer cancel()
go func() { errGroup, _ := errgroup.WithContext(ctx)
var err error errGroup.Go(func() (err error) {
p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig) p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig)
edgeErrC <- err return errors.Wrap(err, "edge handshake failure")
}() })
go func() { errGroup.Go(func() (err error) {
var err error
p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig) p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig)
originErrC <- err return errors.Wrap(err, "origin handshake failure")
}() })
select { return errGroup.Wait()
case err := <-edgeErrC:
if err != nil {
t.Fatalf("edge handshake failure: %s", err)
}
case <-time.After(time.Second * 5):
t.Fatalf("edge handshake timeout")
}
select {
case err := <-originErrC:
if err != nil {
t.Fatalf("origin handshake failure: %s", err)
}
case <-time.After(time.Second * 5):
t.Fatalf("origin handshake timeout")
}
} }
func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) { func (p *DefaultMuxerPair) Serve(t assert.TestingT) {
ctx := context.Background() ctx := context.Background()
p.Handshake(t)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
go func() { go func() {
@ -155,18 +151,23 @@ func (p *DefaultMuxerPair) Wait(t *testing.T) {
} }
} }
func (p *DefaultMuxerPair) OpenEdgeMuxStream(headers []Header, body io.Reader) (*MuxedStream, error) {
ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
defer cancel()
return p.EdgeMux.OpenStream(ctx, headers, body)
}
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
muxPair := NewDefaultMuxerPair() f := func(stream *MuxedStream) error {
muxPair.Handshake(t) return nil
}
muxPair := NewDefaultMuxerPair(t, f)
AssertIfPipeReadable(t, muxPair.OriginConn) AssertIfPipeReadable(t, muxPair.OriginConn)
AssertIfPipeReadable(t, muxPair.EdgeConn) AssertIfPipeReadable(t, muxPair.EdgeConn)
} }
func TestSingleStream(t *testing.T) { func TestSingleStream(t *testing.T) {
closeC := make(chan struct{}) f := MuxedStreamFunc(func(stream *MuxedStream) error {
muxPair := NewDefaultMuxerPair()
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
defer close(closeC)
if len(stream.Headers) != 1 { if len(stream.Headers) != 1 {
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
} }
@ -181,8 +182,6 @@ func TestSingleStream(t *testing.T) {
}) })
buf := []byte("Hello world") buf := []byte("Hello world")
stream.Write(buf) stream.Write(buf)
// after this receive, the edge closed the stream
<-closeC
n, err := io.ReadFull(stream, buf) n, err := io.ReadFull(stream, buf)
if n > 0 { if n > 0 {
t.Fatalf("read %d bytes after EOF", n) t.Fatalf("read %d bytes after EOF", n)
@ -192,9 +191,10 @@ func TestSingleStream(t *testing.T) {
} }
return nil return nil
}) })
muxPair.HandshakeAndServe(t) muxPair := NewDefaultMuxerPair(t, f)
muxPair.Serve(t)
stream, err := muxPair.EdgeMux.OpenStream( stream, err := muxPair.OpenEdgeMuxStream(
[]Header{{Name: "test-header", Value: "headerValue"}}, []Header{{Name: "test-header", Value: "headerValue"}},
nil, nil,
) )
@ -222,7 +222,6 @@ func TestSingleStream(t *testing.T) {
t.Fatalf("expected response body %s, got %s", "Hello world", responseBody) t.Fatalf("expected response body %s, got %s", "Hello world", responseBody)
} }
stream.Close() stream.Close()
closeC <- struct{}{}
n, err = stream.Write([]byte("aaaaa")) n, err = stream.Write([]byte("aaaaa"))
if n > 0 { if n > 0 {
t.Fatalf("wrote %d bytes after EOF", n) t.Fatalf("wrote %d bytes after EOF", n)
@ -230,13 +229,11 @@ func TestSingleStream(t *testing.T) {
if err != io.EOF { if err != io.EOF {
t.Fatalf("expected EOF, got %s", err) t.Fatalf("expected EOF, got %s", err)
} }
<-closeC
} }
func TestSingleStreamLargeResponseBody(t *testing.T) { func TestSingleStreamLargeResponseBody(t *testing.T) {
muxPair := NewDefaultMuxerPair()
bodySize := 1 << 24 bodySize := 1 << 24
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { f := MuxedStreamFunc(func(stream *MuxedStream) error {
if len(stream.Headers) != 1 { if len(stream.Headers) != 1 {
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
} }
@ -265,9 +262,10 @@ func TestSingleStreamLargeResponseBody(t *testing.T) {
return nil return nil
}) })
muxPair.HandshakeAndServe(t) muxPair := NewDefaultMuxerPair(t, f)
muxPair.Serve(t)
stream, err := muxPair.EdgeMux.OpenStream( stream, err := muxPair.OpenEdgeMuxStream(
[]Header{{Name: "test-header", Value: "headerValue"}}, []Header{{Name: "test-header", Value: "headerValue"}},
nil, nil,
) )
@ -295,10 +293,7 @@ func TestSingleStreamLargeResponseBody(t *testing.T) {
} }
func TestMultipleStreams(t *testing.T) { func TestMultipleStreams(t *testing.T) {
muxPair := NewDefaultMuxerPair() f := MuxedStreamFunc(func(stream *MuxedStream) error {
maxStreams := 64
errorsC := make(chan error, maxStreams)
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error {
if len(stream.Headers) != 1 { if len(stream.Headers) != 1 {
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
} }
@ -314,15 +309,18 @@ func TestMultipleStreams(t *testing.T) {
log.Debugf("Wrote body for stream %s", stream.Headers[0].Value) log.Debugf("Wrote body for stream %s", stream.Headers[0].Value)
return nil return nil
}) })
muxPair.HandshakeAndServe(t) muxPair := NewDefaultMuxerPair(t, f)
muxPair.Serve(t)
maxStreams := 64
errorsC := make(chan error, maxStreams)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(maxStreams) wg.Add(maxStreams)
for i := 0; i < maxStreams; i++ { for i := 0; i < maxStreams; i++ {
go func(tokenId int) { go func(tokenId int) {
defer wg.Done() defer wg.Done()
tokenString := fmt.Sprintf("%d", tokenId) tokenString := fmt.Sprintf("%d", tokenId)
stream, err := muxPair.EdgeMux.OpenStream( stream, err := muxPair.OpenEdgeMuxStream(
[]Header{{Name: "client-token", Value: tokenString}}, []Header{{Name: "client-token", Value: tokenString}},
nil, nil,
) )
@ -373,13 +371,12 @@ func TestMultipleStreams(t *testing.T) {
func TestMultipleStreamsFlowControl(t *testing.T) { func TestMultipleStreamsFlowControl(t *testing.T) {
maxStreams := 32 maxStreams := 32
errorsC := make(chan error, maxStreams)
responseSizes := make([]int32, maxStreams) responseSizes := make([]int32, maxStreams)
for i := 0; i < maxStreams; i++ { for i := 0; i < maxStreams; i++ {
responseSizes[i] = rand.Int31n(int32(defaultWindowSize << 4)) responseSizes[i] = rand.Int31n(int32(defaultWindowSize << 4))
} }
muxPair := NewDefaultMuxerPair()
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { f := MuxedStreamFunc(func(stream *MuxedStream) error {
if len(stream.Headers) != 1 { if len(stream.Headers) != 1 {
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
} }
@ -405,63 +402,48 @@ func TestMultipleStreamsFlowControl(t *testing.T) {
} }
return nil return nil
}) })
muxPair.HandshakeAndServe(t) muxPair := NewDefaultMuxerPair(t, f)
muxPair.Serve(t)
var wg sync.WaitGroup errGroup, _ := errgroup.WithContext(context.Background())
wg.Add(maxStreams)
for i := 0; i < maxStreams; i++ { for i := 0; i < maxStreams; i++ {
go func(tokenId int) { errGroup.Go(func() error {
defer wg.Done() stream, err := muxPair.OpenEdgeMuxStream(
stream, err := muxPair.EdgeMux.OpenStream(
[]Header{{Name: "test-header", Value: "headerValue"}}, []Header{{Name: "test-header", Value: "headerValue"}},
nil, nil,
) )
if err != nil { if err != nil {
errorsC <- fmt.Errorf("stream %d error in OpenStream: %s", stream.streamID, err) return fmt.Errorf("error in OpenStream: %d %s", stream.streamID, err)
return
} }
if len(stream.Headers) != 1 { if len(stream.Headers) != 1 {
errorsC <- fmt.Errorf("stream %d expected %d headers, got %d", stream.streamID, 1, len(stream.Headers)) return fmt.Errorf("stream %d expected %d headers, got %d", stream.streamID, 1, len(stream.Headers))
return
} }
if stream.Headers[0].Name != "response-header" { if stream.Headers[0].Name != "response-header" {
errorsC <- fmt.Errorf("stream %d expected header name %s, got %s", stream.streamID, "response-header", stream.Headers[0].Name) return fmt.Errorf("stream %d expected header name %s, got %s", stream.streamID, "response-header", stream.Headers[0].Name)
return
} }
if stream.Headers[0].Value != "responseValue" { if stream.Headers[0].Value != "responseValue" {
errorsC <- fmt.Errorf("stream %d expected header value %s, got %s", stream.streamID, "responseValue", stream.Headers[0].Value) return fmt.Errorf("stream %d expected header value %s, got %s", stream.streamID, "responseValue", stream.Headers[0].Value)
return
} }
responseBody := make([]byte, responseSizes[(stream.streamID-2)/2]) responseBody := make([]byte, responseSizes[(stream.streamID-2)/2])
n, err := io.ReadFull(stream, responseBody) n, err := io.ReadFull(stream, responseBody)
if err != nil { if err != nil {
errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err) return fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err)
return
} }
if n != len(responseBody) { if n != len(responseBody) {
errorsC <- fmt.Errorf("stream %d expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n) return fmt.Errorf("stream %d expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n)
return
} }
}(i) return nil
} })
wg.Wait()
close(errorsC)
testFail := false
for err := range errorsC {
testFail = true
log.Error(err)
}
if testFail {
t.Fatalf("TestMultipleStreamsFlowControl failed")
} }
assert.NoError(t, errGroup.Wait())
} }
func TestGracefulShutdown(t *testing.T) { func TestGracefulShutdown(t *testing.T) {
sendC := make(chan struct{}) sendC := make(chan struct{})
responseBuf := bytes.Repeat([]byte("Hello world"), 65536) responseBuf := bytes.Repeat([]byte("Hello world"), 65536)
muxPair := NewDefaultMuxerPair()
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { f := MuxedStreamFunc(func(stream *MuxedStream) error {
stream.WriteHeaders([]Header{ stream.WriteHeaders([]Header{
{Name: "response-header", Value: "responseValue"}, {Name: "response-header", Value: "responseValue"},
}) })
@ -479,18 +461,19 @@ func TestGracefulShutdown(t *testing.T) {
log.Debugf("Handler ends") log.Debugf("Handler ends")
return nil return nil
}) })
muxPair.HandshakeAndServe(t) muxPair := NewDefaultMuxerPair(t, f)
muxPair.Serve(t)
stream, err := muxPair.EdgeMux.OpenStream( stream, err := muxPair.OpenEdgeMuxStream(
[]Header{{Name: "test-header", Value: "headerValue"}}, []Header{{Name: "test-header", Value: "headerValue"}},
nil, nil,
) )
// Start graceful shutdown of the edge mux - this should also close the origin mux when done
muxPair.EdgeMux.Shutdown()
close(sendC)
if err != nil { if err != nil {
t.Fatalf("error in OpenStream: %s", err) t.Fatalf("error in OpenStream: %s", err)
} }
// Start graceful shutdown of the edge mux - this should also close the origin mux when done
muxPair.EdgeMux.Shutdown()
close(sendC)
responseBody := make([]byte, len(responseBuf)) responseBody := make([]byte, len(responseBuf))
log.Debugf("Waiting for %d bytes", len(responseBuf)) log.Debugf("Waiting for %d bytes", len(responseBuf))
n, err := io.ReadFull(stream, responseBody) n, err := io.ReadFull(stream, responseBody)
@ -511,8 +494,8 @@ func TestUnexpectedShutdown(t *testing.T) {
sendC := make(chan struct{}) sendC := make(chan struct{})
handlerFinishC := make(chan struct{}) handlerFinishC := make(chan struct{})
responseBuf := bytes.Repeat([]byte("Hello world"), 65536) responseBuf := bytes.Repeat([]byte("Hello world"), 65536)
muxPair := NewDefaultMuxerPair()
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { f := MuxedStreamFunc(func(stream *MuxedStream) error {
defer close(handlerFinishC) defer close(handlerFinishC)
stream.WriteHeaders([]Header{ stream.WriteHeaders([]Header{
{Name: "response-header", Value: "responseValue"}, {Name: "response-header", Value: "responseValue"},
@ -533,9 +516,10 @@ func TestUnexpectedShutdown(t *testing.T) {
} }
return nil return nil
}) })
muxPair.HandshakeAndServe(t) muxPair := NewDefaultMuxerPair(t, f)
muxPair.Serve(t)
stream, err := muxPair.EdgeMux.OpenStream( stream, err := muxPair.OpenEdgeMuxStream(
[]Header{{Name: "test-header", Value: "headerValue"}}, []Header{{Name: "test-header", Value: "headerValue"}},
nil, nil,
) )
@ -580,9 +564,8 @@ func EchoHandler(stream *MuxedStream) error {
func TestOpenAfterDisconnect(t *testing.T) { func TestOpenAfterDisconnect(t *testing.T) {
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
muxPair := NewDefaultMuxerPair() muxPair := NewDefaultMuxerPair(t, EchoHandler)
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(EchoHandler) muxPair.Serve(t)
muxPair.HandshakeAndServe(t)
switch i { switch i {
case 0: case 0:
@ -597,7 +580,7 @@ func TestOpenAfterDisconnect(t *testing.T) {
muxPair.EdgeConn.Close() muxPair.EdgeConn.Close()
} }
_, err := muxPair.EdgeMux.OpenStream( _, err := muxPair.OpenEdgeMuxStream(
[]Header{{Name: "test-header", Value: "headerValue"}}, []Header{{Name: "test-header", Value: "headerValue"}},
nil, nil,
) )
@ -608,11 +591,10 @@ func TestOpenAfterDisconnect(t *testing.T) {
} }
func TestHPACK(t *testing.T) { func TestHPACK(t *testing.T) {
muxPair := NewDefaultMuxerPair() muxPair := NewDefaultMuxerPair(t, EchoHandler)
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(EchoHandler) muxPair.Serve(t)
muxPair.HandshakeAndServe(t)
stream, err := muxPair.EdgeMux.OpenStream( stream, err := muxPair.OpenEdgeMuxStream(
[]Header{ []Header{
{Name: ":method", Value: "RPC"}, {Name: ":method", Value: "RPC"},
{Name: ":scheme", Value: "capnp"}, {Name: ":scheme", Value: "capnp"},
@ -626,7 +608,7 @@ func TestHPACK(t *testing.T) {
stream.Close() stream.Close()
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
stream, err := muxPair.EdgeMux.OpenStream( stream, err := muxPair.OpenEdgeMuxStream(
[]Header{ []Header{
{Name: ":method", Value: "GET"}, {Name: ":method", Value: "GET"},
{Name: ":scheme", Value: "https"}, {Name: ":scheme", Value: "https"},
@ -688,8 +670,6 @@ func AssertIfPipeReadable(t *testing.T, pipe io.ReadCloser) {
func TestMultipleStreamsWithDictionaries(t *testing.T) { func TestMultipleStreamsWithDictionaries(t *testing.T) {
for q := CompressionNone; q <= CompressionMax; q++ { for q := CompressionNone; q <= CompressionMax; q++ {
muxPair := NewCompressedMuxerPair(q)
htmlBody := `<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN"` + htmlBody := `<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN"` +
`"http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd">` + `"http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd">` +
`<html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en">` + `<html xmlns="http://www.w3.org/1999/xhtml" xml:lang="en">` +
@ -712,7 +692,7 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) {
`</body>` + `</body>` +
`</html>` `</html>`
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { f := MuxedStreamFunc(func(stream *MuxedStream) error {
var contentType string var contentType string
var pathHeader Header var pathHeader Header
@ -744,8 +724,8 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) {
return nil return nil
}) })
muxPair := NewCompressedMuxerPair(t, q, f)
muxPair.HandshakeAndServe(t) muxPair.Serve(t)
var wg sync.WaitGroup var wg sync.WaitGroup
@ -782,25 +762,26 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) {
errorsC := make(chan error, len(paths)) errorsC := make(chan error, len(paths))
for i, s := range paths { for i, s := range paths {
go func(i int, path string) { go func(index int, path string) {
defer wg.Done() defer wg.Done()
stream, err := muxPair.EdgeMux.OpenStream( stream, err := muxPair.OpenEdgeMuxStream(
[]Header{ []Header{
{Name: ":method", Value: "GET"}, {Name: ":method", Value: "GET"},
{Name: ":scheme", Value: "https"}, {Name: ":scheme", Value: "https"},
{Name: ":authority", Value: "tunnel.otterlyadorable.co.uk"}, {Name: ":authority", Value: "tunnel.otterlyadorable.co.uk"},
{Name: ":path", Value: path}, {Name: ":path", Value: path},
{Name: "cf-ray", Value: "378948953f044408-SFO-DOG"}, {Name: "cf-ray", Value: "378948953f044408-SFO-DOG"},
{Name: "idx", Value: strconv.Itoa(i)}, {Name: "idx", Value: strconv.Itoa(index)},
{Name: "accept-encoding", Value: "gzip, br"}, {Name: "accept-encoding", Value: "gzip, br"},
}, },
nil, nil,
) )
if err != nil { if err != nil {
t.Fatalf("error in OpenStream: %s", err) errorsC <- fmt.Errorf("error in OpenStream: %v", err)
return
} }
expectBody := strings.Replace(htmlBody, "paragraph", path, 1) + strconv.Itoa(i) expectBody := strings.Replace(htmlBody, "paragraph", path, 1) + strconv.Itoa(index)
responseBody := make([]byte, len(expectBody)*2) responseBody := make([]byte, len(expectBody)*2)
n, err := stream.Read(responseBody) n, err := stream.Read(responseBody)
if err != nil { if err != nil {
@ -836,7 +817,8 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) {
} }
} }
func sampleSiteHandler(stream *MuxedStream) error { func sampleSiteHandler(files map[string][]byte) MuxedStreamFunc {
return func(stream *MuxedStream) error {
var contentType string var contentType string
var pathHeader Header var pathHeader Header
@ -848,7 +830,7 @@ func sampleSiteHandler(stream *MuxedStream) error {
} }
if pathHeader.Name != ":path" { if pathHeader.Name != ":path" {
panic("Couldn't find :path header in test") return fmt.Errorf("Couldn't find :path header in test")
} }
if strings.Contains(pathHeader.Value, "html") { if strings.Contains(pathHeader.Value, "html") {
@ -864,14 +846,18 @@ func sampleSiteHandler(stream *MuxedStream) error {
Header{Name: "content-type", Value: contentType}, Header{Name: "content-type", Value: contentType},
}) })
log.Debugf("Wrote headers for stream %s", pathHeader.Value) log.Debugf("Wrote headers for stream %s", pathHeader.Value)
b, _ := ioutil.ReadFile("./sample" + pathHeader.Value) file, ok := files[pathHeader.Value]
stream.Write(b) if !ok {
return fmt.Errorf("%s content is not preloaded", pathHeader.Value)
}
stream.Write(file)
log.Debugf("Wrote body for stream %s", pathHeader.Value) log.Debugf("Wrote body for stream %s", pathHeader.Value)
return nil return nil
}
} }
func sampleSiteTest(t *testing.T, muxPair *DefaultMuxerPair, path string) { func sampleSiteTest(muxPair *DefaultMuxerPair, path string, files map[string][]byte) error {
stream, err := muxPair.EdgeMux.OpenStream( stream, err := muxPair.OpenEdgeMuxStream(
[]Header{ []Header{
{Name: ":method", Value: "GET"}, {Name: ":method", Value: "GET"},
{Name: ":scheme", Value: "https"}, {Name: ":scheme", Value: "https"},
@ -883,50 +869,75 @@ func sampleSiteTest(t *testing.T, muxPair *DefaultMuxerPair, path string) {
nil, nil,
) )
if err != nil { if err != nil {
t.Fatalf("error in OpenStream: %s", err) return fmt.Errorf("error in OpenStream: %v", err)
} }
expectBody, _ := ioutil.ReadFile("./sample" + path) file, ok := files[path]
responseBody := make([]byte, len(expectBody)) if !ok {
return fmt.Errorf("%s content is not preloaded", path)
}
responseBody := make([]byte, len(file))
n, err := io.ReadFull(stream, responseBody) n, err := io.ReadFull(stream, responseBody)
log.Debugf("Got body for stream %s", path)
if err != nil { if err != nil {
t.Fatalf("error from (*MuxedStream).Read: %s", err) return fmt.Errorf("error from (*MuxedStream).Read: %v", err)
} }
if n != len(expectBody) { if n != len(file) {
t.Fatalf("expected response body to have %d bytes, got %d", len(expectBody), n) return fmt.Errorf("expected response body to have %d bytes, got %d", len(file), n)
} }
if string(responseBody[:n]) != string(expectBody) { if string(responseBody[:n]) != string(file) {
t.Fatalf("expected response body %s, got %s", expectBody, responseBody[:n]) return fmt.Errorf("expected response body %s, got %s", file, responseBody[:n])
} }
return nil
}
func loadSampleFiles(paths []string) (map[string][]byte, error) {
files := make(map[string][]byte)
for _, path := range paths {
if _, ok := files[path]; !ok {
expectBody, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
files[path] = expectBody
}
}
return files, nil
} }
func TestSampleSiteWithDictionaries(t *testing.T) { func TestSampleSiteWithDictionaries(t *testing.T) {
paths := []string{
"./sample/index.html",
"./sample/index2.html",
"./sample/index1.html",
"./sample/ghost-url.min.js",
"./sample/jquery.fitvids.js",
"./sample/index1.html",
"./sample/index2.html",
"./sample/index.html",
}
files, err := loadSampleFiles(paths)
assert.NoError(t, err)
for q := CompressionNone; q <= CompressionMax; q++ { for q := CompressionNone; q <= CompressionMax; q++ {
muxPair := NewCompressedMuxerPair(q) muxPair := NewCompressedMuxerPair(t, q, sampleSiteHandler(files))
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(sampleSiteHandler) muxPair.Serve(t)
muxPair.HandshakeAndServe(t)
var wg sync.WaitGroup var wg sync.WaitGroup
errC := make(chan error, len(paths))
paths := []string{
"/index.html",
"/index2.html",
"/index1.html",
"/ghost-url.min.js",
"/jquery.fitvids.js",
"/index1.html",
"/index2.html",
"/index.html",
}
wg.Add(len(paths)) wg.Add(len(paths))
for _, s := range paths { for _, s := range paths {
go func(path string) { go func(path string) {
sampleSiteTest(t, muxPair, path) defer wg.Done()
wg.Done() errC <- sampleSiteTest(muxPair, path, files)
}(s) }(s)
} }
wg.Wait() wg.Wait()
close(errC)
for err := range errC {
assert.NoError(t, err)
}
originMuxMetrics := muxPair.OriginMux.Metrics() originMuxMetrics := muxPair.OriginMux.Metrics()
if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 10*originMuxMetrics.CompBytesAfter.Value() { if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 10*originMuxMetrics.CompBytesAfter.Value() {
@ -936,35 +947,74 @@ func TestSampleSiteWithDictionaries(t *testing.T) {
} }
func TestLongSiteWithDictionaries(t *testing.T) { func TestLongSiteWithDictionaries(t *testing.T) {
paths := []string{
"./sample/index.html",
"./sample/index1.html",
"./sample/index2.html",
"./sample/ghost-url.min.js",
"./sample/jquery.fitvids.js",
}
files, err := loadSampleFiles(paths)
assert.NoError(t, err)
for q := CompressionNone; q <= CompressionMedium; q++ { for q := CompressionNone; q <= CompressionMedium; q++ {
muxPair := NewCompressedMuxerPair(q) muxPair := NewCompressedMuxerPair(t, q, sampleSiteHandler(files))
muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(sampleSiteHandler) muxPair.Serve(t)
muxPair.HandshakeAndServe(t)
var wg sync.WaitGroup
rand.Seed(time.Now().Unix()) rand.Seed(time.Now().Unix())
paths := []string{ tstLen := 500
"/index.html", errGroup, _ := errgroup.WithContext(context.Background())
"/index1.html",
"/index2.html",
"/ghost-url.min.js",
"/jquery.fitvids.js"}
tstLen := 1000
wg.Add(tstLen)
for i := 0; i < tstLen; i++ { for i := 0; i < tstLen; i++ {
errGroup.Go(func() error {
path := paths[rand.Int()%len(paths)] path := paths[rand.Int()%len(paths)]
go func(path string) { return sampleSiteTest(muxPair, path, files)
sampleSiteTest(t, muxPair, path) })
wg.Done()
}(path)
} }
wg.Wait() assert.NoError(t, errGroup.Wait())
originMuxMetrics := muxPair.OriginMux.Metrics() originMuxMetrics := muxPair.OriginMux.Metrics()
if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 100*originMuxMetrics.CompBytesAfter.Value() { if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 10*originMuxMetrics.CompBytesAfter.Value() {
t.Fatalf("Cross-stream compression is expected to give a better compression ratio") t.Fatalf("Cross-stream compression is expected to give a better compression ratio")
} }
} }
} }
func BenchmarkOpenStream(b *testing.B) {
const streams = 5000
for i := 0; i < b.N; i++ {
b.StopTimer()
f := MuxedStreamFunc(func(stream *MuxedStream) error {
if len(stream.Headers) != 1 {
b.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
}
if stream.Headers[0].Name != "test-header" {
b.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
}
if stream.Headers[0].Value != "headerValue" {
b.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
}
stream.WriteHeaders([]Header{
{Name: "response-header", Value: "responseValue"},
})
return nil
})
muxPair := NewDefaultMuxerPair(b, f)
muxPair.Serve(b)
b.StartTimer()
openStreams(b, muxPair, streams)
}
}
func openStreams(b *testing.B, muxPair *DefaultMuxerPair, n int) {
errGroup, _ := errgroup.WithContext(context.Background())
for i := 0; i < n; i++ {
errGroup.Go(func() error {
_, err := muxPair.OpenEdgeMuxStream(
[]Header{{Name: "test-header", Value: "headerValue"}},
nil,
)
return err
})
}
assert.NoError(b, errGroup.Wait())
}

View File

@ -34,6 +34,7 @@ import (
const ( const (
dialTimeout = 15 * time.Second dialTimeout = 15 * time.Second
openStreamTimeout = 30 * time.Second
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
TagHeaderNamePrefix = "Cf-Warp-Tag-" TagHeaderNamePrefix = "Cf-Warp-Tag-"
DuplicateConnectionError = "EDUPCONN" DuplicateConnectionError = "EDUPCONN"
@ -339,11 +340,7 @@ func RegisterTunnel(
uuid uuid.UUID, uuid uuid.UUID,
) error { ) error {
config.TransportLogger.Debug("initiating RPC stream to register") config.TransportLogger.Debug("initiating RPC stream to register")
stream, err := muxer.OpenStream([]h2mux.Header{ stream, err := openStream(ctx, muxer)
{Name: ":method", Value: "RPC"},
{Name: ":scheme", Value: "capnp"},
{Name: ":path", Value: "*"},
}, nil)
if err != nil { if err != nil {
// RPC stream open error // RPC stream open error
return newClientRegisterTunnelError(err, config.Metrics.rpcFail) return newClientRegisterTunnelError(err, config.Metrics.rpcFail)
@ -421,11 +418,8 @@ func processRegisterTunnelError(err string, permanentFailure bool, metrics *Tunn
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error { func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
logger.Debug("initiating RPC stream to unregister") logger.Debug("initiating RPC stream to unregister")
stream, err := muxer.OpenStream([]h2mux.Header{ ctx := context.Background()
{Name: ":method", Value: "RPC"}, stream, err := openStream(ctx, muxer)
{Name: ":scheme", Value: "capnp"},
{Name: ":path", Value: "*"},
}, nil)
if err != nil { if err != nil {
// RPC stream open error // RPC stream open error
return err return err
@ -434,7 +428,6 @@ func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log
// stream response error // stream response error
return err return err
} }
ctx := context.Background()
conn := rpc.NewConn( conn := rpc.NewConn(
tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)), tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")), tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
@ -445,6 +438,16 @@ func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log
return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds()) return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
} }
func openStream(ctx context.Context, muxer *h2mux.Muxer) (*h2mux.MuxedStream, error) {
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
defer cancel()
return muxer.OpenStream(openStreamCtx, []h2mux.Header{
{Name: ":method", Value: "RPC"},
{Name: ":scheme", Value: "capnp"},
{Name: ":path", Value: "*"},
}, nil)
}
func LogServerInfo( func LogServerInfo(
promise tunnelrpc.ServerInfo_Promise, promise tunnelrpc.ServerInfo_Promise,
connectionID uint8, connectionID uint8,