perf(cloudflared): reuse memory from buffer pool to get better throughput (#161)

* perf(cloudflared): reuse memory from buffer pool to get better throughput

https://github.com/cloudflare/cloudflared/issues/160
This commit is contained in:
Rueian 2020-02-25 01:06:19 +08:00 committed by GitHub
parent 6488843ac4
commit 464bb53049
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 64 additions and 12 deletions

29
buffer/pool.go Normal file
View File

@ -0,0 +1,29 @@
package buffer
import (
"sync"
)
type Pool struct {
// A Pool must not be copied after first use.
// https://golang.org/pkg/sync/#Pool
buffers sync.Pool
}
func NewPool(bufferSize int) *Pool {
return &Pool{
buffers: sync.Pool{
New: func() interface{} {
return make([]byte, bufferSize)
},
},
}
}
func (p *Pool) Get() []byte {
return p.buffers.Get().([]byte)
}
func (p *Pool) Put(buf []byte) {
p.buffers.Put(buf)
}

View File

@ -11,6 +11,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/cloudflare/cloudflared/buffer"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
@ -62,6 +63,8 @@ type Supervisor struct {
eventDigestLock *sync.RWMutex eventDigestLock *sync.RWMutex
eventDigest []byte eventDigest []byte
bufferPool *buffer.Pool
} }
type resolveResult struct { type resolveResult struct {
@ -96,6 +99,7 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) {
logger: config.Logger.WithField("subsystem", "supervisor"), logger: config.Logger.WithField("subsystem", "supervisor"),
jwtLock: &sync.RWMutex{}, jwtLock: &sync.RWMutex{},
eventDigestLock: &sync.RWMutex{}, eventDigestLock: &sync.RWMutex{},
bufferPool: buffer.NewPool(512 * 1024),
}, nil }, nil
} }
@ -230,7 +234,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return return
} }
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID) err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool)
// If the first tunnel disconnects, keep restarting it. // If the first tunnel disconnects, keep restarting it.
edgeErrors := 0 edgeErrors := 0
for s.unusedIPs() { for s.unusedIPs() {
@ -253,7 +257,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return return
} }
} }
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID) err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool)
} }
} }
@ -272,7 +276,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
if err != nil { if err != nil {
return return
} }
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID) err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool)
} }
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal { func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {

View File

@ -20,6 +20,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/buffer"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
@ -178,6 +179,7 @@ func ServeTunnelLoop(ctx context.Context,
connectionID uint8, connectionID uint8,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
u uuid.UUID, u uuid.UUID,
bufferPool *buffer.Pool,
) error { ) error {
connectionLogger := config.Logger.WithField("connectionID", connectionID) connectionLogger := config.Logger.WithField("connectionID", connectionID)
config.Metrics.incrementHaConnections() config.Metrics.incrementHaConnections()
@ -201,6 +203,7 @@ func ServeTunnelLoop(ctx context.Context,
connectedFuse, connectedFuse,
&backoff, &backoff,
u, u,
bufferPool,
) )
if recoverable { if recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok { if duration, ok := backoff.GetBackoffDuration(ctx); ok {
@ -223,6 +226,7 @@ func ServeTunnel(
connectedFuse *h2mux.BooleanFuse, connectedFuse *h2mux.BooleanFuse,
backoff *BackoffHandler, backoff *BackoffHandler,
u uuid.UUID, u uuid.UUID,
bufferPool *buffer.Pool,
) (err error, recoverable bool) { ) (err error, recoverable bool) {
// Treat panics as recoverable errors // Treat panics as recoverable errors
defer func() { defer func() {
@ -243,7 +247,7 @@ func ServeTunnel(
tags["ha"] = connectionTag tags["ha"] = connectionTag
// Returns error from parsing the origin URL or handshake errors // Returns error from parsing the origin URL or handshake errors
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID) handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID, bufferPool)
if err != nil { if err != nil {
errLog := logger.WithError(err) errLog := logger.WithError(err)
switch err.(type) { switch err.(type) {
@ -500,6 +504,8 @@ type TunnelHandler struct {
connectionID string connectionID string
logger *log.Logger logger *log.Logger
noChunkedEncoding bool noChunkedEncoding bool
bufferPool *buffer.Pool
} }
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
@ -507,6 +513,7 @@ func NewTunnelHandler(ctx context.Context,
config *TunnelConfig, config *TunnelConfig,
addr *net.TCPAddr, addr *net.TCPAddr,
connectionID uint8, connectionID uint8,
bufferPool *buffer.Pool,
) (*TunnelHandler, string, error) { ) (*TunnelHandler, string, error) {
originURL, err := validation.ValidateUrl(config.OriginUrl) originURL, err := validation.ValidateUrl(config.OriginUrl)
if err != nil { if err != nil {
@ -522,6 +529,7 @@ func NewTunnelHandler(ctx context.Context,
connectionID: uint8ToString(connectionID), connectionID: uint8ToString(connectionID),
logger: config.Logger, logger: config.Logger,
noChunkedEncoding: config.NoChunkedEncoding, noChunkedEncoding: config.NoChunkedEncoding,
bufferPool: bufferPool,
} }
if h.httpClient == nil { if h.httpClient == nil {
h.httpClient = http.DefaultTransport h.httpClient = http.DefaultTransport
@ -642,7 +650,9 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request)
} else { } else {
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates dictionary on first write // compression generates dictionary on first write
io.CopyBuffer(stream, response.Body, make([]byte, 512*1024)) buf := h.bufferPool.Get()
defer h.bufferPool.Put(buf)
io.CopyBuffer(stream, response.Body, buf)
} }
return response, nil return response, nil
} }

View File

@ -12,6 +12,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/cloudflare/cloudflared/buffer"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/log" "github.com/cloudflare/cloudflared/log"
@ -33,6 +34,7 @@ type HTTPService struct {
client http.RoundTripper client http.RoundTripper
originURL *url.URL originURL *url.URL
chunkedEncoding bool chunkedEncoding bool
bufferPool *buffer.Pool
} }
func NewHTTPService(transport http.RoundTripper, url *url.URL, chunkedEncoding bool) OriginService { func NewHTTPService(transport http.RoundTripper, url *url.URL, chunkedEncoding bool) OriginService {
@ -40,6 +42,7 @@ func NewHTTPService(transport http.RoundTripper, url *url.URL, chunkedEncoding b
client: transport, client: transport,
originURL: url, originURL: url,
chunkedEncoding: chunkedEncoding, chunkedEncoding: chunkedEncoding,
bufferPool: buffer.NewPool(512 * 1024),
} }
} }
@ -71,7 +74,9 @@ func (hc *HTTPService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*htt
} else { } else {
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates dictionary on first write // compression generates dictionary on first write
io.CopyBuffer(stream, resp.Body, make([]byte, 512*1024)) buf := hc.bufferPool.Get()
defer hc.bufferPool.Put(buf)
io.CopyBuffer(stream, resp.Body, buf)
} }
return resp, nil return resp, nil
} }
@ -142,10 +147,11 @@ func (wsc *WebsocketService) Shutdown() {
// HelloWorldService talks to the hello world example origin // HelloWorldService talks to the hello world example origin
type HelloWorldService struct { type HelloWorldService struct {
client http.RoundTripper client http.RoundTripper
listener net.Listener listener net.Listener
originURL *url.URL originURL *url.URL
shutdownC chan struct{} shutdownC chan struct{}
bufferPool *buffer.Pool
} }
func NewHelloWorldService(transport http.RoundTripper) (OriginService, error) { func NewHelloWorldService(transport http.RoundTripper) (OriginService, error) {
@ -164,7 +170,8 @@ func NewHelloWorldService(transport http.RoundTripper) (OriginService, error) {
Scheme: "https", Scheme: "https",
Host: listener.Addr().String(), Host: listener.Addr().String(),
}, },
shutdownC: shutdownC, shutdownC: shutdownC,
bufferPool: buffer.NewPool(512 * 1024),
}, nil }, nil
} }
@ -184,7 +191,9 @@ func (hwc *HelloWorldService) Proxy(stream *h2mux.MuxedStream, req *http.Request
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates dictionary on first write // compression generates dictionary on first write
io.CopyBuffer(stream, resp.Body, make([]byte, 512*1024)) buf := hwc.bufferPool.Get()
defer hwc.bufferPool.Put(buf)
io.CopyBuffer(stream, resp.Body, buf)
return resp, nil return resp, nil
} }