From 2bef5dbe72bf7b0d9f4cf60e7ac9e1ca468e5a1e Mon Sep 17 00:00:00 2001 From: Chung-Ting Huang Date: Tue, 2 Apr 2019 18:12:09 -0500 Subject: [PATCH] TUN-1682: Add context to OpenStream to prevent it from blocking indefinitely. --- connection/connection.go | 11 +- h2mux/error.go | 2 + h2mux/h2mux.go | 13 +- h2mux/h2mux_test.go | 438 ++++++++++++++++++++++----------------- origin/tunnel.go | 25 ++- 5 files changed, 276 insertions(+), 213 deletions(-) diff --git a/connection/connection.go b/connection/connection.go index 220e9b28..f9dd6125 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -17,7 +17,8 @@ import ( ) const ( - dialTimeout = 5 * time.Second + dialTimeout = 5 * time.Second + openStreamTimeout = 30 * time.Second ) type dialError struct { @@ -70,7 +71,9 @@ func (h *h2muxHandler) serve(ctx context.Context) error { // Connect is used to establish connections with cloudflare's edge network func (h *h2muxHandler) connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters) (*tunnelpogs.ConnectResult, error) { - conn, err := h.newRPConn() + openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout) + defer cancel() + conn, err := h.newRPConn(openStreamCtx) if err != nil { return nil, errors.Wrap(err, "Failed to create new RPC connection") } @@ -83,8 +86,8 @@ func (h *h2muxHandler) shutdown() { h.muxer.Shutdown() } -func (h *h2muxHandler) newRPConn() (*rpc.Conn, error) { - stream, err := h.muxer.OpenStream([]h2mux.Header{ +func (h *h2muxHandler) newRPConn(ctx context.Context) (*rpc.Conn, error) { + stream, err := h.muxer.OpenStream(ctx, []h2mux.Header{ {Name: ":method", Value: "RPC"}, {Name: ":scheme", Value: "capnp"}, {Name: ":path", Value: "*"}, diff --git a/h2mux/error.go b/h2mux/error.go index efb1f37d..4a95bab9 100644 --- a/h2mux/error.go +++ b/h2mux/error.go @@ -7,6 +7,7 @@ import ( ) var ( + // HTTP2 error codes: https://http2.github.io/http2-spec/#ErrorCodes ErrHandshakeTimeout = MuxerHandshakeError{"1000 handshake timeout"} ErrBadHandshakeNotSettings = MuxerHandshakeError{"1001 unexpected response"} ErrBadHandshakeUnexpectedAck = MuxerHandshakeError{"1002 unexpected response"} @@ -22,6 +23,7 @@ var ( ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"} ErrConnectionClosed = MuxerApplicationError{"3001 connection closed"} ErrConnectionDropped = MuxerApplicationError{"3002 connection dropped"} + ErrOpenStreamTimeout = MuxerApplicationError{"3003 open stream timeout"} ErrClosedStream = MuxerStreamError{"4000 stream closed", http2.ErrCodeStreamClosed} ) diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index 1341c127..65034830 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -379,7 +379,7 @@ func isConnectionClosedError(err error) bool { // OpenStream opens a new data stream with the given headers. // Called by proxy server and tunnel -func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, error) { +func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader) (*MuxedStream, error) { stream := &MuxedStream{ responseHeadersReceived: make(chan struct{}), readBuffer: NewSharedBuffer(), @@ -397,15 +397,20 @@ func (m *Muxer) OpenStream(headers []Header, body io.Reader) (*MuxedStream, erro select { // Will be received by mux writer - case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: body}: + case <-ctx.Done(): + return nil, ErrOpenStreamTimeout case <-m.abortChan: return nil, ErrConnectionClosed + case m.newStreamChan <- MuxedStreamRequest{stream: stream, body: body}: } + select { + case <-ctx.Done(): + return nil, ErrOpenStreamTimeout + case <-m.abortChan: + return nil, ErrConnectionClosed case <-stream.responseHeadersReceived: return stream, nil - case <-m.abortChan: - return nil, ErrConnectionClosed } } diff --git a/h2mux/h2mux_test.go b/h2mux/h2mux_test.go index 241f8a0e..89aba250 100644 --- a/h2mux/h2mux_test.go +++ b/h2mux/h2mux_test.go @@ -15,7 +15,15 @@ import ( "testing" "time" + "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" +) + +const ( + testOpenStreamTimeout = time.Millisecond * 5000 + testHandshakeTimeout = time.Millisecond * 1000 ) func TestMain(m *testing.M) { @@ -35,11 +43,12 @@ type DefaultMuxerPair struct { doneC chan struct{} } -func NewDefaultMuxerPair() *DefaultMuxerPair { +func NewDefaultMuxerPair(t assert.TestingT, f MuxedStreamFunc) *DefaultMuxerPair { origin, edge := net.Pipe() - return &DefaultMuxerPair{ + p := &DefaultMuxerPair{ OriginMuxConfig: MuxerConfig{ - Timeout: time.Second, + Timeout: testHandshakeTimeout, + Handler: f, IsClient: true, Name: "origin", Logger: log.NewEntry(log.New()), @@ -49,7 +58,7 @@ func NewDefaultMuxerPair() *DefaultMuxerPair { }, OriginConn: origin, EdgeMuxConfig: MuxerConfig{ - Timeout: time.Second, + Timeout: testHandshakeTimeout, IsClient: false, Name: "edge", Logger: log.NewEntry(log.New()), @@ -60,13 +69,16 @@ func NewDefaultMuxerPair() *DefaultMuxerPair { EdgeConn: edge, doneC: make(chan struct{}), } + assert.NoError(t, p.Handshake()) + return p } -func NewCompressedMuxerPair(quality CompressionSetting) *DefaultMuxerPair { +func NewCompressedMuxerPair(t assert.TestingT, quality CompressionSetting, f MuxedStreamFunc) *DefaultMuxerPair { origin, edge := net.Pipe() - return &DefaultMuxerPair{ + p := &DefaultMuxerPair{ OriginMuxConfig: MuxerConfig{ Timeout: time.Second, + Handler: f, IsClient: true, Name: "origin", CompressionQuality: quality, @@ -83,44 +95,28 @@ func NewCompressedMuxerPair(quality CompressionSetting) *DefaultMuxerPair { EdgeConn: edge, doneC: make(chan struct{}), } + assert.NoError(t, p.Handshake()) + return p } -func (p *DefaultMuxerPair) Handshake(t *testing.T) { - edgeErrC := make(chan error) - originErrC := make(chan error) - go func() { - var err error +func (p *DefaultMuxerPair) Handshake() error { + ctx, cancel := context.WithTimeout(context.Background(), testHandshakeTimeout) + defer cancel() + errGroup, _ := errgroup.WithContext(ctx) + errGroup.Go(func() (err error) { p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig) - edgeErrC <- err - }() - go func() { - var err error + return errors.Wrap(err, "edge handshake failure") + }) + errGroup.Go(func() (err error) { p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig) - originErrC <- err - }() + return errors.Wrap(err, "origin handshake failure") + }) - select { - case err := <-edgeErrC: - if err != nil { - t.Fatalf("edge handshake failure: %s", err) - } - case <-time.After(time.Second * 5): - t.Fatalf("edge handshake timeout") - } - - select { - case err := <-originErrC: - if err != nil { - t.Fatalf("origin handshake failure: %s", err) - } - case <-time.After(time.Second * 5): - t.Fatalf("origin handshake timeout") - } + return errGroup.Wait() } -func (p *DefaultMuxerPair) HandshakeAndServe(t *testing.T) { +func (p *DefaultMuxerPair) Serve(t assert.TestingT) { ctx := context.Background() - p.Handshake(t) var wg sync.WaitGroup wg.Add(2) go func() { @@ -155,18 +151,23 @@ func (p *DefaultMuxerPair) Wait(t *testing.T) { } } +func (p *DefaultMuxerPair) OpenEdgeMuxStream(headers []Header, body io.Reader) (*MuxedStream, error) { + ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout) + defer cancel() + return p.EdgeMux.OpenStream(ctx, headers, body) +} + func TestHandshake(t *testing.T) { - muxPair := NewDefaultMuxerPair() - muxPair.Handshake(t) + f := func(stream *MuxedStream) error { + return nil + } + muxPair := NewDefaultMuxerPair(t, f) AssertIfPipeReadable(t, muxPair.OriginConn) AssertIfPipeReadable(t, muxPair.EdgeConn) } func TestSingleStream(t *testing.T) { - closeC := make(chan struct{}) - muxPair := NewDefaultMuxerPair() - muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { - defer close(closeC) + f := MuxedStreamFunc(func(stream *MuxedStream) error { if len(stream.Headers) != 1 { t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) } @@ -181,8 +182,6 @@ func TestSingleStream(t *testing.T) { }) buf := []byte("Hello world") stream.Write(buf) - // after this receive, the edge closed the stream - <-closeC n, err := io.ReadFull(stream, buf) if n > 0 { t.Fatalf("read %d bytes after EOF", n) @@ -192,9 +191,10 @@ func TestSingleStream(t *testing.T) { } return nil }) - muxPair.HandshakeAndServe(t) + muxPair := NewDefaultMuxerPair(t, f) + muxPair.Serve(t) - stream, err := muxPair.EdgeMux.OpenStream( + stream, err := muxPair.OpenEdgeMuxStream( []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) @@ -222,7 +222,6 @@ func TestSingleStream(t *testing.T) { t.Fatalf("expected response body %s, got %s", "Hello world", responseBody) } stream.Close() - closeC <- struct{}{} n, err = stream.Write([]byte("aaaaa")) if n > 0 { t.Fatalf("wrote %d bytes after EOF", n) @@ -230,13 +229,11 @@ func TestSingleStream(t *testing.T) { if err != io.EOF { t.Fatalf("expected EOF, got %s", err) } - <-closeC } func TestSingleStreamLargeResponseBody(t *testing.T) { - muxPair := NewDefaultMuxerPair() bodySize := 1 << 24 - muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { + f := MuxedStreamFunc(func(stream *MuxedStream) error { if len(stream.Headers) != 1 { t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) } @@ -265,9 +262,10 @@ func TestSingleStreamLargeResponseBody(t *testing.T) { return nil }) - muxPair.HandshakeAndServe(t) + muxPair := NewDefaultMuxerPair(t, f) + muxPair.Serve(t) - stream, err := muxPair.EdgeMux.OpenStream( + stream, err := muxPair.OpenEdgeMuxStream( []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) @@ -295,10 +293,7 @@ func TestSingleStreamLargeResponseBody(t *testing.T) { } func TestMultipleStreams(t *testing.T) { - muxPair := NewDefaultMuxerPair() - maxStreams := 64 - errorsC := make(chan error, maxStreams) - muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { + f := MuxedStreamFunc(func(stream *MuxedStream) error { if len(stream.Headers) != 1 { t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) } @@ -314,15 +309,18 @@ func TestMultipleStreams(t *testing.T) { log.Debugf("Wrote body for stream %s", stream.Headers[0].Value) return nil }) - muxPair.HandshakeAndServe(t) + muxPair := NewDefaultMuxerPair(t, f) + muxPair.Serve(t) + maxStreams := 64 + errorsC := make(chan error, maxStreams) var wg sync.WaitGroup wg.Add(maxStreams) for i := 0; i < maxStreams; i++ { go func(tokenId int) { defer wg.Done() tokenString := fmt.Sprintf("%d", tokenId) - stream, err := muxPair.EdgeMux.OpenStream( + stream, err := muxPair.OpenEdgeMuxStream( []Header{{Name: "client-token", Value: tokenString}}, nil, ) @@ -373,13 +371,12 @@ func TestMultipleStreams(t *testing.T) { func TestMultipleStreamsFlowControl(t *testing.T) { maxStreams := 32 - errorsC := make(chan error, maxStreams) responseSizes := make([]int32, maxStreams) for i := 0; i < maxStreams; i++ { responseSizes[i] = rand.Int31n(int32(defaultWindowSize << 4)) } - muxPair := NewDefaultMuxerPair() - muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { + + f := MuxedStreamFunc(func(stream *MuxedStream) error { if len(stream.Headers) != 1 { t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) } @@ -405,63 +402,48 @@ func TestMultipleStreamsFlowControl(t *testing.T) { } return nil }) - muxPair.HandshakeAndServe(t) + muxPair := NewDefaultMuxerPair(t, f) + muxPair.Serve(t) - var wg sync.WaitGroup - wg.Add(maxStreams) + errGroup, _ := errgroup.WithContext(context.Background()) for i := 0; i < maxStreams; i++ { - go func(tokenId int) { - defer wg.Done() - stream, err := muxPair.EdgeMux.OpenStream( + errGroup.Go(func() error { + stream, err := muxPair.OpenEdgeMuxStream( []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) if err != nil { - errorsC <- fmt.Errorf("stream %d error in OpenStream: %s", stream.streamID, err) - return + return fmt.Errorf("error in OpenStream: %d %s", stream.streamID, err) } if len(stream.Headers) != 1 { - errorsC <- fmt.Errorf("stream %d expected %d headers, got %d", stream.streamID, 1, len(stream.Headers)) - return + return fmt.Errorf("stream %d expected %d headers, got %d", stream.streamID, 1, len(stream.Headers)) } if stream.Headers[0].Name != "response-header" { - errorsC <- fmt.Errorf("stream %d expected header name %s, got %s", stream.streamID, "response-header", stream.Headers[0].Name) - return + return fmt.Errorf("stream %d expected header name %s, got %s", stream.streamID, "response-header", stream.Headers[0].Name) } if stream.Headers[0].Value != "responseValue" { - errorsC <- fmt.Errorf("stream %d expected header value %s, got %s", stream.streamID, "responseValue", stream.Headers[0].Value) - return + return fmt.Errorf("stream %d expected header value %s, got %s", stream.streamID, "responseValue", stream.Headers[0].Value) } responseBody := make([]byte, responseSizes[(stream.streamID-2)/2]) n, err := io.ReadFull(stream, responseBody) if err != nil { - errorsC <- fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err) - return + return fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err) } if n != len(responseBody) { - errorsC <- fmt.Errorf("stream %d expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n) - return + return fmt.Errorf("stream %d expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n) } - }(i) - } - wg.Wait() - close(errorsC) - testFail := false - for err := range errorsC { - testFail = true - log.Error(err) - } - if testFail { - t.Fatalf("TestMultipleStreamsFlowControl failed") + return nil + }) } + assert.NoError(t, errGroup.Wait()) } func TestGracefulShutdown(t *testing.T) { sendC := make(chan struct{}) responseBuf := bytes.Repeat([]byte("Hello world"), 65536) - muxPair := NewDefaultMuxerPair() - muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { + + f := MuxedStreamFunc(func(stream *MuxedStream) error { stream.WriteHeaders([]Header{ {Name: "response-header", Value: "responseValue"}, }) @@ -479,18 +461,19 @@ func TestGracefulShutdown(t *testing.T) { log.Debugf("Handler ends") return nil }) - muxPair.HandshakeAndServe(t) + muxPair := NewDefaultMuxerPair(t, f) + muxPair.Serve(t) - stream, err := muxPair.EdgeMux.OpenStream( + stream, err := muxPair.OpenEdgeMuxStream( []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) - // Start graceful shutdown of the edge mux - this should also close the origin mux when done - muxPair.EdgeMux.Shutdown() - close(sendC) if err != nil { t.Fatalf("error in OpenStream: %s", err) } + // Start graceful shutdown of the edge mux - this should also close the origin mux when done + muxPair.EdgeMux.Shutdown() + close(sendC) responseBody := make([]byte, len(responseBuf)) log.Debugf("Waiting for %d bytes", len(responseBuf)) n, err := io.ReadFull(stream, responseBody) @@ -511,8 +494,8 @@ func TestUnexpectedShutdown(t *testing.T) { sendC := make(chan struct{}) handlerFinishC := make(chan struct{}) responseBuf := bytes.Repeat([]byte("Hello world"), 65536) - muxPair := NewDefaultMuxerPair() - muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { + + f := MuxedStreamFunc(func(stream *MuxedStream) error { defer close(handlerFinishC) stream.WriteHeaders([]Header{ {Name: "response-header", Value: "responseValue"}, @@ -533,9 +516,10 @@ func TestUnexpectedShutdown(t *testing.T) { } return nil }) - muxPair.HandshakeAndServe(t) + muxPair := NewDefaultMuxerPair(t, f) + muxPair.Serve(t) - stream, err := muxPair.EdgeMux.OpenStream( + stream, err := muxPair.OpenEdgeMuxStream( []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) @@ -580,9 +564,8 @@ func EchoHandler(stream *MuxedStream) error { func TestOpenAfterDisconnect(t *testing.T) { for i := 0; i < 3; i++ { - muxPair := NewDefaultMuxerPair() - muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(EchoHandler) - muxPair.HandshakeAndServe(t) + muxPair := NewDefaultMuxerPair(t, EchoHandler) + muxPair.Serve(t) switch i { case 0: @@ -597,7 +580,7 @@ func TestOpenAfterDisconnect(t *testing.T) { muxPair.EdgeConn.Close() } - _, err := muxPair.EdgeMux.OpenStream( + _, err := muxPair.OpenEdgeMuxStream( []Header{{Name: "test-header", Value: "headerValue"}}, nil, ) @@ -608,11 +591,10 @@ func TestOpenAfterDisconnect(t *testing.T) { } func TestHPACK(t *testing.T) { - muxPair := NewDefaultMuxerPair() - muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(EchoHandler) - muxPair.HandshakeAndServe(t) + muxPair := NewDefaultMuxerPair(t, EchoHandler) + muxPair.Serve(t) - stream, err := muxPair.EdgeMux.OpenStream( + stream, err := muxPair.OpenEdgeMuxStream( []Header{ {Name: ":method", Value: "RPC"}, {Name: ":scheme", Value: "capnp"}, @@ -626,7 +608,7 @@ func TestHPACK(t *testing.T) { stream.Close() for i := 0; i < 3; i++ { - stream, err := muxPair.EdgeMux.OpenStream( + stream, err := muxPair.OpenEdgeMuxStream( []Header{ {Name: ":method", Value: "GET"}, {Name: ":scheme", Value: "https"}, @@ -688,8 +670,6 @@ func AssertIfPipeReadable(t *testing.T, pipe io.ReadCloser) { func TestMultipleStreamsWithDictionaries(t *testing.T) { for q := CompressionNone; q <= CompressionMax; q++ { - muxPair := NewCompressedMuxerPair(q) - htmlBody := `` + `` + @@ -712,7 +692,7 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) { `` + `` - muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(func(stream *MuxedStream) error { + f := MuxedStreamFunc(func(stream *MuxedStream) error { var contentType string var pathHeader Header @@ -744,8 +724,8 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) { return nil }) - - muxPair.HandshakeAndServe(t) + muxPair := NewCompressedMuxerPair(t, q, f) + muxPair.Serve(t) var wg sync.WaitGroup @@ -782,25 +762,26 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) { errorsC := make(chan error, len(paths)) for i, s := range paths { - go func(i int, path string) { + go func(index int, path string) { defer wg.Done() - stream, err := muxPair.EdgeMux.OpenStream( + stream, err := muxPair.OpenEdgeMuxStream( []Header{ {Name: ":method", Value: "GET"}, {Name: ":scheme", Value: "https"}, {Name: ":authority", Value: "tunnel.otterlyadorable.co.uk"}, {Name: ":path", Value: path}, {Name: "cf-ray", Value: "378948953f044408-SFO-DOG"}, - {Name: "idx", Value: strconv.Itoa(i)}, + {Name: "idx", Value: strconv.Itoa(index)}, {Name: "accept-encoding", Value: "gzip, br"}, }, nil, ) if err != nil { - t.Fatalf("error in OpenStream: %s", err) + errorsC <- fmt.Errorf("error in OpenStream: %v", err) + return } - expectBody := strings.Replace(htmlBody, "paragraph", path, 1) + strconv.Itoa(i) + expectBody := strings.Replace(htmlBody, "paragraph", path, 1) + strconv.Itoa(index) responseBody := make([]byte, len(expectBody)*2) n, err := stream.Read(responseBody) if err != nil { @@ -836,42 +817,47 @@ func TestMultipleStreamsWithDictionaries(t *testing.T) { } } -func sampleSiteHandler(stream *MuxedStream) error { - var contentType string - var pathHeader Header +func sampleSiteHandler(files map[string][]byte) MuxedStreamFunc { + return func(stream *MuxedStream) error { + var contentType string + var pathHeader Header - for _, h := range stream.Headers { - if h.Name == ":path" { - pathHeader = h - break + for _, h := range stream.Headers { + if h.Name == ":path" { + pathHeader = h + break + } } - } - if pathHeader.Name != ":path" { - panic("Couldn't find :path header in test") - } + if pathHeader.Name != ":path" { + return fmt.Errorf("Couldn't find :path header in test") + } - if strings.Contains(pathHeader.Value, "html") { - contentType = "text/html; charset=utf-8" - } else if strings.Contains(pathHeader.Value, "js") { - contentType = "application/javascript" - } else if strings.Contains(pathHeader.Value, "css") { - contentType = "text/css" - } else { - contentType = "img/gif" + if strings.Contains(pathHeader.Value, "html") { + contentType = "text/html; charset=utf-8" + } else if strings.Contains(pathHeader.Value, "js") { + contentType = "application/javascript" + } else if strings.Contains(pathHeader.Value, "css") { + contentType = "text/css" + } else { + contentType = "img/gif" + } + stream.WriteHeaders([]Header{ + Header{Name: "content-type", Value: contentType}, + }) + log.Debugf("Wrote headers for stream %s", pathHeader.Value) + file, ok := files[pathHeader.Value] + if !ok { + return fmt.Errorf("%s content is not preloaded", pathHeader.Value) + } + stream.Write(file) + log.Debugf("Wrote body for stream %s", pathHeader.Value) + return nil } - stream.WriteHeaders([]Header{ - Header{Name: "content-type", Value: contentType}, - }) - log.Debugf("Wrote headers for stream %s", pathHeader.Value) - b, _ := ioutil.ReadFile("./sample" + pathHeader.Value) - stream.Write(b) - log.Debugf("Wrote body for stream %s", pathHeader.Value) - return nil } -func sampleSiteTest(t *testing.T, muxPair *DefaultMuxerPair, path string) { - stream, err := muxPair.EdgeMux.OpenStream( +func sampleSiteTest(muxPair *DefaultMuxerPair, path string, files map[string][]byte) error { + stream, err := muxPair.OpenEdgeMuxStream( []Header{ {Name: ":method", Value: "GET"}, {Name: ":scheme", Value: "https"}, @@ -883,50 +869,75 @@ func sampleSiteTest(t *testing.T, muxPair *DefaultMuxerPair, path string) { nil, ) if err != nil { - t.Fatalf("error in OpenStream: %s", err) + return fmt.Errorf("error in OpenStream: %v", err) } - expectBody, _ := ioutil.ReadFile("./sample" + path) - responseBody := make([]byte, len(expectBody)) + file, ok := files[path] + if !ok { + return fmt.Errorf("%s content is not preloaded", path) + } + responseBody := make([]byte, len(file)) n, err := io.ReadFull(stream, responseBody) - log.Debugf("Got body for stream %s", path) if err != nil { - t.Fatalf("error from (*MuxedStream).Read: %s", err) + return fmt.Errorf("error from (*MuxedStream).Read: %v", err) } - if n != len(expectBody) { - t.Fatalf("expected response body to have %d bytes, got %d", len(expectBody), n) + if n != len(file) { + return fmt.Errorf("expected response body to have %d bytes, got %d", len(file), n) } - if string(responseBody[:n]) != string(expectBody) { - t.Fatalf("expected response body %s, got %s", expectBody, responseBody[:n]) + if string(responseBody[:n]) != string(file) { + return fmt.Errorf("expected response body %s, got %s", file, responseBody[:n]) } + return nil +} + +func loadSampleFiles(paths []string) (map[string][]byte, error) { + files := make(map[string][]byte) + for _, path := range paths { + if _, ok := files[path]; !ok { + expectBody, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + files[path] = expectBody + } + } + return files, nil } func TestSampleSiteWithDictionaries(t *testing.T) { + paths := []string{ + "./sample/index.html", + "./sample/index2.html", + "./sample/index1.html", + "./sample/ghost-url.min.js", + "./sample/jquery.fitvids.js", + "./sample/index1.html", + "./sample/index2.html", + "./sample/index.html", + } + files, err := loadSampleFiles(paths) + assert.NoError(t, err) + for q := CompressionNone; q <= CompressionMax; q++ { - muxPair := NewCompressedMuxerPair(q) - muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(sampleSiteHandler) - muxPair.HandshakeAndServe(t) + muxPair := NewCompressedMuxerPair(t, q, sampleSiteHandler(files)) + muxPair.Serve(t) var wg sync.WaitGroup - - paths := []string{ - "/index.html", - "/index2.html", - "/index1.html", - "/ghost-url.min.js", - "/jquery.fitvids.js", - "/index1.html", - "/index2.html", - "/index.html", - } + errC := make(chan error, len(paths)) wg.Add(len(paths)) for _, s := range paths { go func(path string) { - sampleSiteTest(t, muxPair, path) - wg.Done() + defer wg.Done() + errC <- sampleSiteTest(muxPair, path, files) }(s) } + wg.Wait() + close(errC) + + for err := range errC { + assert.NoError(t, err) + } originMuxMetrics := muxPair.OriginMux.Metrics() if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 10*originMuxMetrics.CompBytesAfter.Value() { @@ -936,35 +947,74 @@ func TestSampleSiteWithDictionaries(t *testing.T) { } func TestLongSiteWithDictionaries(t *testing.T) { + paths := []string{ + "./sample/index.html", + "./sample/index1.html", + "./sample/index2.html", + "./sample/ghost-url.min.js", + "./sample/jquery.fitvids.js", + } + files, err := loadSampleFiles(paths) + assert.NoError(t, err) for q := CompressionNone; q <= CompressionMedium; q++ { - muxPair := NewCompressedMuxerPair(q) - muxPair.OriginMuxConfig.Handler = MuxedStreamFunc(sampleSiteHandler) - muxPair.HandshakeAndServe(t) + muxPair := NewCompressedMuxerPair(t, q, sampleSiteHandler(files)) + muxPair.Serve(t) - var wg sync.WaitGroup rand.Seed(time.Now().Unix()) - paths := []string{ - "/index.html", - "/index1.html", - "/index2.html", - "/ghost-url.min.js", - "/jquery.fitvids.js"} - - tstLen := 1000 - wg.Add(tstLen) + tstLen := 500 + errGroup, _ := errgroup.WithContext(context.Background()) for i := 0; i < tstLen; i++ { - path := paths[rand.Int()%len(paths)] - go func(path string) { - sampleSiteTest(t, muxPair, path) - wg.Done() - }(path) + errGroup.Go(func() error { + path := paths[rand.Int()%len(paths)] + return sampleSiteTest(muxPair, path, files) + }) } - wg.Wait() + assert.NoError(t, errGroup.Wait()) originMuxMetrics := muxPair.OriginMux.Metrics() - if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 100*originMuxMetrics.CompBytesAfter.Value() { + if q > CompressionNone && originMuxMetrics.CompBytesBefore.Value() <= 10*originMuxMetrics.CompBytesAfter.Value() { t.Fatalf("Cross-stream compression is expected to give a better compression ratio") } } } + +func BenchmarkOpenStream(b *testing.B) { + const streams = 5000 + for i := 0; i < b.N; i++ { + b.StopTimer() + f := MuxedStreamFunc(func(stream *MuxedStream) error { + if len(stream.Headers) != 1 { + b.Fatalf("expected %d headers, got %d", 1, len(stream.Headers)) + } + if stream.Headers[0].Name != "test-header" { + b.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name) + } + if stream.Headers[0].Value != "headerValue" { + b.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value) + } + stream.WriteHeaders([]Header{ + {Name: "response-header", Value: "responseValue"}, + }) + return nil + }) + muxPair := NewDefaultMuxerPair(b, f) + muxPair.Serve(b) + b.StartTimer() + openStreams(b, muxPair, streams) + } +} + +func openStreams(b *testing.B, muxPair *DefaultMuxerPair, n int) { + errGroup, _ := errgroup.WithContext(context.Background()) + for i := 0; i < n; i++ { + errGroup.Go(func() error { + _, err := muxPair.OpenEdgeMuxStream( + []Header{{Name: "test-header", Value: "headerValue"}}, + nil, + ) + return err + }) + } + assert.NoError(b, errGroup.Wait()) +} diff --git a/origin/tunnel.go b/origin/tunnel.go index 60bf6f3e..a84b87d2 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -34,6 +34,7 @@ import ( const ( dialTimeout = 15 * time.Second + openStreamTimeout = 30 * time.Second lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" TagHeaderNamePrefix = "Cf-Warp-Tag-" DuplicateConnectionError = "EDUPCONN" @@ -339,11 +340,7 @@ func RegisterTunnel( uuid uuid.UUID, ) error { config.TransportLogger.Debug("initiating RPC stream to register") - stream, err := muxer.OpenStream([]h2mux.Header{ - {Name: ":method", Value: "RPC"}, - {Name: ":scheme", Value: "capnp"}, - {Name: ":path", Value: "*"}, - }, nil) + stream, err := openStream(ctx, muxer) if err != nil { // RPC stream open error return newClientRegisterTunnelError(err, config.Metrics.rpcFail) @@ -421,11 +418,8 @@ func processRegisterTunnelError(err string, permanentFailure bool, metrics *Tunn func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error { logger.Debug("initiating RPC stream to unregister") - stream, err := muxer.OpenStream([]h2mux.Header{ - {Name: ":method", Value: "RPC"}, - {Name: ":scheme", Value: "capnp"}, - {Name: ":path", Value: "*"}, - }, nil) + ctx := context.Background() + stream, err := openStream(ctx, muxer) if err != nil { // RPC stream open error return err @@ -434,7 +428,6 @@ func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log // stream response error return err } - ctx := context.Background() conn := rpc.NewConn( tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)), tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")), @@ -445,6 +438,16 @@ func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds()) } +func openStream(ctx context.Context, muxer *h2mux.Muxer) (*h2mux.MuxedStream, error) { + openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout) + defer cancel() + return muxer.OpenStream(openStreamCtx, []h2mux.Header{ + {Name: ":method", Value: "RPC"}, + {Name: ":scheme", Value: "capnp"}, + {Name: ":path", Value: "*"}, + }, nil) +} + func LogServerInfo( promise tunnelrpc.ServerInfo_Promise, connectionID uint8,