diff --git a/buffer/pool.go b/buffer/pool.go new file mode 100644 index 00000000..3e7d2b62 --- /dev/null +++ b/buffer/pool.go @@ -0,0 +1,27 @@ +package buffer + +import ( + "sync" +) + +type Pool struct { + buffers sync.Pool +} + +func NewPool(bufferSize int) *Pool { + return &Pool{ + buffers: sync.Pool{ + New: func() interface{} { + return make([]byte, bufferSize) + }, + }, + } +} + +func (p *Pool) Get() []byte { + return p.buffers.Get().([]byte) +} + +func (p *Pool) Put(buf []byte) { + p.buffers.Put(buf) +} diff --git a/origin/supervisor.go b/origin/supervisor.go index b8df583f..a1c78571 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" "github.com/sirupsen/logrus" + "github.com/cloudflare/cloudflared/buffer" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" @@ -62,6 +63,8 @@ type Supervisor struct { eventDigestLock *sync.RWMutex eventDigest []byte + + bufferPool *buffer.Pool } type resolveResult struct { @@ -96,6 +99,7 @@ func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) { logger: config.Logger.WithField("subsystem", "supervisor"), jwtLock: &sync.RWMutex{}, eventDigestLock: &sync.RWMutex{}, + bufferPool: buffer.NewPool(512 * 1024), }, nil } @@ -230,7 +234,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign return } - err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID) + err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool) // If the first tunnel disconnects, keep restarting it. edgeErrors := 0 for s.unusedIPs() { @@ -253,7 +257,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign return } } - err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID) + err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool) } } @@ -272,7 +276,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal if err != nil { return } - err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID) + err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool) } func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal { diff --git a/origin/tunnel.go b/origin/tunnel.go index 4ee6b55e..98cebfd3 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -20,6 +20,7 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" + "github.com/cloudflare/cloudflared/buffer" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/h2mux" @@ -178,6 +179,7 @@ func ServeTunnelLoop(ctx context.Context, connectionID uint8, connectedSignal *signal.Signal, u uuid.UUID, + bufferPool *buffer.Pool, ) error { connectionLogger := config.Logger.WithField("connectionID", connectionID) config.Metrics.incrementHaConnections() @@ -201,6 +203,7 @@ func ServeTunnelLoop(ctx context.Context, connectedFuse, &backoff, u, + bufferPool, ) if recoverable { if duration, ok := backoff.GetBackoffDuration(ctx); ok { @@ -223,6 +226,7 @@ func ServeTunnel( connectedFuse *h2mux.BooleanFuse, backoff *BackoffHandler, u uuid.UUID, + bufferPool *buffer.Pool, ) (err error, recoverable bool) { // Treat panics as recoverable errors defer func() { @@ -243,7 +247,7 @@ func ServeTunnel( tags["ha"] = connectionTag // Returns error from parsing the origin URL or handshake errors - handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID) + handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID, bufferPool) if err != nil { errLog := logger.WithError(err) switch err.(type) { @@ -500,6 +504,8 @@ type TunnelHandler struct { connectionID string logger *log.Logger noChunkedEncoding bool + + bufferPool *buffer.Pool } // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error @@ -507,6 +513,7 @@ func NewTunnelHandler(ctx context.Context, config *TunnelConfig, addr *net.TCPAddr, connectionID uint8, + bufferPool *buffer.Pool, ) (*TunnelHandler, string, error) { originURL, err := validation.ValidateUrl(config.OriginUrl) if err != nil { @@ -522,6 +529,7 @@ func NewTunnelHandler(ctx context.Context, connectionID: uint8ToString(connectionID), logger: config.Logger, noChunkedEncoding: config.NoChunkedEncoding, + bufferPool: bufferPool, } if h.httpClient == nil { h.httpClient = http.DefaultTransport @@ -642,7 +650,9 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) } else { // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream // compression generates dictionary on first write - io.CopyBuffer(stream, response.Body, make([]byte, 512*1024)) + buf := h.bufferPool.Get() + defer h.bufferPool.Put(buf) + io.CopyBuffer(stream, response.Body, buf) } return response, nil } diff --git a/originservice/originservice.go b/originservice/originservice.go index 3cd0af53..2277385f 100644 --- a/originservice/originservice.go +++ b/originservice/originservice.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" + "github.com/cloudflare/cloudflared/buffer" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/log" @@ -33,6 +34,7 @@ type HTTPService struct { client http.RoundTripper originURL *url.URL chunkedEncoding bool + bufferPool *buffer.Pool } func NewHTTPService(transport http.RoundTripper, url *url.URL, chunkedEncoding bool) OriginService { @@ -40,6 +42,7 @@ func NewHTTPService(transport http.RoundTripper, url *url.URL, chunkedEncoding b client: transport, originURL: url, chunkedEncoding: chunkedEncoding, + bufferPool: buffer.NewPool(512 * 1024), } } @@ -71,7 +74,9 @@ func (hc *HTTPService) Proxy(stream *h2mux.MuxedStream, req *http.Request) (*htt } else { // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream // compression generates dictionary on first write - io.CopyBuffer(stream, resp.Body, make([]byte, 512*1024)) + buf := hc.bufferPool.Get() + defer hc.bufferPool.Put(buf) + io.CopyBuffer(stream, resp.Body, buf) } return resp, nil } @@ -142,10 +147,11 @@ func (wsc *WebsocketService) Shutdown() { // HelloWorldService talks to the hello world example origin type HelloWorldService struct { - client http.RoundTripper - listener net.Listener - originURL *url.URL - shutdownC chan struct{} + client http.RoundTripper + listener net.Listener + originURL *url.URL + shutdownC chan struct{} + bufferPool *buffer.Pool } func NewHelloWorldService(transport http.RoundTripper) (OriginService, error) { @@ -164,7 +170,8 @@ func NewHelloWorldService(transport http.RoundTripper) (OriginService, error) { Scheme: "https", Host: listener.Addr().String(), }, - shutdownC: shutdownC, + shutdownC: shutdownC, + bufferPool: buffer.NewPool(512 * 1024), }, nil } @@ -184,7 +191,9 @@ func (hwc *HelloWorldService) Proxy(stream *h2mux.MuxedStream, req *http.Request // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream // compression generates dictionary on first write - io.CopyBuffer(stream, resp.Body, make([]byte, 512*1024)) + buf := hwc.bufferPool.Get() + defer hwc.bufferPool.Put(buf) + io.CopyBuffer(stream, resp.Body, buf) return resp, nil }