From e72e4b882ac5b57a8f1a086f11d5f70e60f33be6 Mon Sep 17 00:00:00 2001
From: Nate Franzen
Date: Fri, 31 Aug 2018 15:02:20 -0700
Subject: [PATCH] refactor prometheus metrics

---
 cmd/cloudflared/configuration.go |  12 +-
 cmd/cloudflared/main.go          |   7 +-
 origin/metrics.go                | 374 +++++++++++++------------------
 origin/metrics_test.go           | 120 ----------
 origin/supervisor.go             |   7 +-
 origin/tunnel.go                 |  47 ++--
 6 files changed, 206 insertions(+), 361 deletions(-)

diff --git a/cmd/cloudflared/configuration.go b/cmd/cloudflared/configuration.go
index 345eb7ff..11f55a4a 100644
--- a/cmd/cloudflared/configuration.go
+++ b/cmd/cloudflared/configuration.go
@@ -36,6 +36,11 @@ var (
 	defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp", "/usr/local/etc/cloudflared", "/etc/cloudflared"}
 )
 
+// NoAdditionalMetricsLabels returns an empty slice of label keys or label values
+func NoAdditionalMetricsLabels() []string {
+	return make([]string, 0)
+}
+
 const defaultCredentialFile = "cert.pem"
 
 func fileExists(path string) (bool, error) {
@@ -243,7 +248,7 @@ func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, logger, pr
 		return nil, errors.Wrap(err, "Error loading cert pool")
 	}
 
-	tunnelMetrics := origin.NewTunnelMetrics()
+	tunnelMetrics := origin.InitializeTunnelMetrics(NoAdditionalMetricsLabels())
 	httpTransport := &http.Transport{
 		Proxy: http.ProxyFromEnvironment,
 		DialContext: (&net.Dialer{
@@ -291,6 +296,11 @@ func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, logger, pr
 	}, nil
 }
 
+// newMetricsUpdater returns a default implementation with no additional metrics label values
+func newMetricsUpdater(config *origin.TunnelConfig) (origin.TunnelMetricsUpdater, error) {
+	return origin.NewTunnelMetricsUpdater(config.Metrics, NoAdditionalMetricsLabels())
+}
+
 func loadCertPool(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) {
 	const originCAPoolFlag = "origin-ca-pool"
 	originCAPoolFilename := c.String(originCAPoolFlag)
diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go
index 91a6858b..c14cfa1c 100644
--- a/cmd/cloudflared/main.go
+++ b/cmd/cloudflared/main.go
@@ -587,11 +587,14 @@ func startServer(c *cli.Context, shutdownC, graceShutdownC chan struct{}) error
 	if err != nil {
 		return err
 	}
-
+	metricsUpdater, err := newMetricsUpdater(tunnelConfig)
+	if err != nil {
+		return err
+	}
 	wg.Add(1)
 	go func() {
 		defer wg.Done()
-		errC <- origin.StartTunnelDaemon(tunnelConfig, graceShutdownC, connectedSignal)
+		errC <- origin.StartTunnelDaemon(tunnelConfig, metricsUpdater, graceShutdownC, connectedSignal)
 	}()
 
 	return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"))
diff --git a/origin/metrics.go b/origin/metrics.go
index d69a9009..2624d170 100644
--- a/origin/metrics.go
+++ b/origin/metrics.go
@@ -1,7 +1,7 @@
 package origin
 
 import (
-	"sync"
+	"fmt"
 	"time"
 
 	"github.com/cloudflare/cloudflared/h2mux"
@@ -9,7 +9,19 @@ import (
 	"github.com/prometheus/client_golang/prometheus"
 )
 
-type muxerMetrics struct {
+// TunnelMetrics contains pointers to the global prometheus metrics and their common label keys
+type TunnelMetrics struct {
+	connectionKey string
+	locationKey   string
+	statusKey     string
+	commonKeys    []string
+
+	haConnections prometheus.Gauge
+	timerRetries  prometheus.Gauge
+
+	requests  *prometheus.CounterVec
+	responses *prometheus.CounterVec
+
 	rtt              *prometheus.GaugeVec
 	rttMin           *prometheus.GaugeVec
 	rttMax           *prometheus.GaugeVec
 	receiveWindowAve *prometheus.GaugeVec
 	sendWindowAve    *prometheus.GaugeVec
 	receiveWindowMin *prometheus.GaugeVec
 	receiveWindowMax *prometheus.GaugeVec
 	sendWindowMin    *prometheus.GaugeVec
 	sendWindowMax    *prometheus.GaugeVec
 	inBoundRateCurr  *prometheus.GaugeVec
 	inBoundRateMin   *prometheus.GaugeVec
 	inBoundRateMax   *prometheus.GaugeVec
 	outBoundRateCurr *prometheus.GaugeVec
 	outBoundRateMin  *prometheus.GaugeVec
 	outBoundRateMax  *prometheus.GaugeVec
 	compBytesBefore  *prometheus.GaugeVec
 	compBytesAfter   *prometheus.GaugeVec
 	compRateAve      *prometheus.GaugeVec
 }
 
-type TunnelMetrics struct {
-	haConnections prometheus.Gauge
-	totalRequests prometheus.Counter
-	requestsPerTunnel *prometheus.CounterVec
-	// concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests
-	concurrentRequestsLock sync.Mutex
-	concurrentRequestsPerTunnel *prometheus.GaugeVec
-	// concurrentRequests records count of concurrent requests for each tunnel
-	concurrentRequests map[string]uint64
-	maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
-	// concurrentRequests records max count of concurrent requests for each tunnel
-	maxConcurrentRequests map[string]uint64
-	timerRetries prometheus.Gauge
-	responseByCode *prometheus.CounterVec
-	responseCodePerTunnel *prometheus.CounterVec
-	serverLocations *prometheus.GaugeVec
-	// locationLock is a mutex for oldServerLocations
-	locationLock sync.Mutex
-	// oldServerLocations stores the last server the tunnel was connected to
-	oldServerLocations map[string]string
+// TunnelMetricsUpdater separates the prometheus metrics from the update process.
+// The updater can be initialized with shared label values, while other
+// labels (connectionID, status) are set when each metric is updated
+type TunnelMetricsUpdater interface {
+	setServerLocation(connectionID, loc string)
-	muxerMetrics *muxerMetrics
+	incrementHaConnections()
+	decrementHaConnections()
+
+	incrementRequests(connectionID string)
+	incrementResponses(connectionID, code string)
+
+	updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics)
 }
 
-func newMuxerMetrics() *muxerMetrics {
+type tunnelMetricsUpdater struct {
+
+	// metrics is a set of pointers to prometheus metrics, configured globally
+	metrics *TunnelMetrics
+
+	// commonValues is the group of label values set for this updater
+	commonValues []string
+
+	// serverLocations maps the connectionID to a server location string
+	serverLocations map[string]string
+}
+
+// NewTunnelMetricsUpdater creates a metrics updater with common label values
+func NewTunnelMetricsUpdater(metrics *TunnelMetrics, commonLabelValues []string) (TunnelMetricsUpdater, error) {
+
+	if len(commonLabelValues) != len(metrics.commonKeys) {
+		return nil, fmt.Errorf("failed to create updater, mismatched counts of metric label keys (%v) and values (%v)", metrics.commonKeys, commonLabelValues)
+	}
+	return &tunnelMetricsUpdater{
+		metrics:         metrics,
+		commonValues:    commonLabelValues,
+		serverLocations: make(map[string]string, 1),
+	}, nil
+}
+
+// InitializeTunnelMetrics configures the prometheus metrics globally with common label keys
+func InitializeTunnelMetrics(commonLabelKeys []string) *TunnelMetrics {
+
+	connectionKey := "connection_id"
+	locationKey := "location"
+	statusKey := "status"
+
+	labelKeys := append(commonLabelKeys, connectionKey, locationKey)
+
+	// not a labelled vector
+	haConnections := prometheus.NewGauge(
+		prometheus.GaugeOpts{
+			Name: "ha_connections",
+			Help: "Number of active HA connections",
+		},
+	)
+	prometheus.MustRegister(haConnections)
+
+	// not a labelled vector
+	timerRetries := prometheus.NewGauge(
+		prometheus.GaugeOpts{
+			Name: "timer_retries",
+			Help: "Count of unacknowledged heartbeats",
+		})
+	prometheus.MustRegister(timerRetries)
+
+	requests := prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Name: "requests",
+			Help: "Count of requests",
+		},
+		labelKeys,
+	)
+	prometheus.MustRegister(requests)
+
+	responses := prometheus.NewCounterVec(
+		prometheus.CounterOpts{
+			Name: "responses",
+			Help: "Count of responses",
+		},
+		append(labelKeys, statusKey),
+	)
+	prometheus.MustRegister(responses)
+
 	rtt := prometheus.NewGaugeVec(
 		prometheus.GaugeOpts{
 			Name: "rtt",
 			Help: "Round-trip time in millisecond",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(rtt)
 
@@ -69,7 +140,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "rtt_min",
 			Help: "Shortest round-trip time in millisecond",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(rttMin)
 
@@ -78,7 +149,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "rtt_max",
 			Help: "Longest round-trip time in millisecond",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(rttMax)
 
@@ -87,7 +158,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "receive_window_ave",
 			Help: "Average receive window size in bytes",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(receiveWindowAve)
 
@@ -96,7 +167,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "send_window_ave",
 			Help: "Average send window size in bytes",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(sendWindowAve)
 
@@ -105,7 +176,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "receive_window_min",
 			Help: "Smallest receive window size in bytes",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(receiveWindowMin)
 
@@ -114,7 +185,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "receive_window_max",
 			Help: "Largest receive window size in bytes",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(receiveWindowMax)
 
@@ -123,7 +194,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "send_window_min",
 			Help: "Smallest send window size in bytes",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(sendWindowMin)
 
@@ -132,7 +203,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "send_window_max",
 			Help: "Largest send window size in bytes",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(sendWindowMax)
 
@@ -141,7 +212,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "inbound_bytes_per_sec_curr",
 			Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(inBoundRateCurr)
 
@@ -150,7 +221,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "inbound_bytes_per_sec_min",
 			Help: "Minimum non-zero inbounding bytes per second",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(inBoundRateMin)
 
@@ -159,7 +230,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "inbound_bytes_per_sec_max",
 			Help: "Maximum inbounding bytes per second",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(inBoundRateMax)
 
@@ -168,7 +239,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "outbound_bytes_per_sec_curr",
 			Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(outBoundRateCurr)
 
@@ -177,7 +248,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "outbound_bytes_per_sec_min",
 			Help: "Minimum non-zero outbounding bytes per second",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(outBoundRateMin)
 
@@ -186,7 +257,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "outbound_bytes_per_sec_max",
 			Help: "Maximum outbounding bytes per second",
 		},
-		[]string{"connection_id"},
+		labelKeys,
 	)
 	prometheus.MustRegister(outBoundRateMax)
 
@@ -195,7 +266,7 @@ func newMuxerMetrics() *muxerMetrics {
 			Name: "comp_bytes_before",
 			Help: "Bytes sent via cross-stream compression, pre compression",
 		},
[]string{"connection_id"}, + labelKeys, ) prometheus.MustRegister(compBytesBefore) @@ -204,7 +275,7 @@ func newMuxerMetrics() *muxerMetrics { Name: "comp_bytes_after", Help: "Bytes sent via cross-stream compression, post compression", }, - []string{"connection_id"}, + labelKeys, ) prometheus.MustRegister(compBytesAfter) @@ -213,11 +284,23 @@ func newMuxerMetrics() *muxerMetrics { Name: "comp_rate_ave", Help: "Average outbound cross-stream compression ratio", }, - []string{"connection_id"}, + labelKeys, ) prometheus.MustRegister(compRateAve) - return &muxerMetrics{ + return &TunnelMetrics{ + + connectionKey: connectionKey, + locationKey: locationKey, + statusKey: statusKey, + commonKeys: commonLabelKeys, + + haConnections: haConnections, + timerRetries: timerRetries, + + requests: requests, + responses: responses, + rtt: rtt, rttMin: rttMin, rttMax: rttMax, @@ -235,187 +318,54 @@ func newMuxerMetrics() *muxerMetrics { outBoundRateMax: outBoundRateMax, compBytesBefore: compBytesBefore, compBytesAfter: compBytesAfter, - compRateAve: compRateAve, - } + compRateAve: compRateAve} } -func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) { - m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT)) - m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin)) - m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax)) - m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve) - m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve) - m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin)) - m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax)) - m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin)) - m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax)) - m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr)) - m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin)) - m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax)) - m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr)) - m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin)) - m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax)) - m.compBytesBefore.WithLabelValues(connectionID).Set(float64(metrics.CompBytesBefore.Value())) - m.compBytesAfter.WithLabelValues(connectionID).Set(float64(metrics.CompBytesAfter.Value())) - m.compRateAve.WithLabelValues(connectionID).Set(float64(metrics.CompRateAve())) +func (t *tunnelMetricsUpdater) incrementHaConnections() { + t.metrics.haConnections.Inc() +} + +func (t *tunnelMetricsUpdater) decrementHaConnections() { + t.metrics.haConnections.Dec() } func convertRTTMilliSec(t time.Duration) float64 { return float64(t / time.Millisecond) } +func (t *tunnelMetricsUpdater) updateMuxerMetrics(connectionID string, muxMetrics *h2mux.MuxerMetrics) { + values := append(t.commonValues, connectionID, t.serverLocations[connectionID]) -// Metrics that can be collected without asking the edge -func NewTunnelMetrics() *TunnelMetrics { - haConnections := prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "ha_connections", - Help: "Number of active ha connections", - }) - prometheus.MustRegister(haConnections) - - totalRequests := prometheus.NewCounter( - prometheus.CounterOpts{ - Name: 
"total_requests", - Help: "Amount of requests proxied through all the tunnels", - }) - prometheus.MustRegister(totalRequests) - - requestsPerTunnel := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "requests_per_tunnel", - Help: "Amount of requests proxied through each tunnel", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(requestsPerTunnel) - - concurrentRequestsPerTunnel := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "concurrent_requests_per_tunnel", - Help: "Concurrent requests proxied through each tunnel", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(concurrentRequestsPerTunnel) - - maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "max_concurrent_requests_per_tunnel", - Help: "Largest number of concurrent requests proxied through each tunnel so far", - }, - []string{"connection_id"}, - ) - prometheus.MustRegister(maxConcurrentRequestsPerTunnel) - - timerRetries := prometheus.NewGauge( - prometheus.GaugeOpts{ - Name: "timer_retries", - Help: "Unacknowledged heart beats count", - }) - prometheus.MustRegister(timerRetries) - - responseByCode := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "response_by_code", - Help: "Count of responses by HTTP status code", - }, - []string{"status_code"}, - ) - prometheus.MustRegister(responseByCode) - - responseCodePerTunnel := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "response_code_per_tunnel", - Help: "Count of responses by HTTP status code fore each tunnel", - }, - []string{"connection_id", "status_code"}, - ) - prometheus.MustRegister(responseCodePerTunnel) - - serverLocations := prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: "server_locations", - Help: "Where each tunnel is connected to. 
-			Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.",
-		},
-		[]string{"connection_id", "location"},
-	)
-	prometheus.MustRegister(serverLocations)
-
-	return &TunnelMetrics{
-		haConnections:                  haConnections,
-		totalRequests:                  totalRequests,
-		requestsPerTunnel:              requestsPerTunnel,
-		concurrentRequestsPerTunnel:    concurrentRequestsPerTunnel,
-		concurrentRequests:             make(map[string]uint64),
-		maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
-		maxConcurrentRequests:          make(map[string]uint64),
-		timerRetries:                   timerRetries,
-		responseByCode:                 responseByCode,
-		responseCodePerTunnel:          responseCodePerTunnel,
-		serverLocations:                serverLocations,
-		oldServerLocations:             make(map[string]string),
-		muxerMetrics:                   newMuxerMetrics(),
-	}
+	t.metrics.rtt.WithLabelValues(values...).Set(convertRTTMilliSec(muxMetrics.RTT))
+	t.metrics.rttMin.WithLabelValues(values...).Set(convertRTTMilliSec(muxMetrics.RTTMin))
+	t.metrics.rttMax.WithLabelValues(values...).Set(convertRTTMilliSec(muxMetrics.RTTMax))
+	t.metrics.receiveWindowAve.WithLabelValues(values...).Set(muxMetrics.ReceiveWindowAve)
+	t.metrics.sendWindowAve.WithLabelValues(values...).Set(muxMetrics.SendWindowAve)
+	t.metrics.receiveWindowMin.WithLabelValues(values...).Set(float64(muxMetrics.ReceiveWindowMin))
+	t.metrics.receiveWindowMax.WithLabelValues(values...).Set(float64(muxMetrics.ReceiveWindowMax))
+	t.metrics.sendWindowMin.WithLabelValues(values...).Set(float64(muxMetrics.SendWindowMin))
+	t.metrics.sendWindowMax.WithLabelValues(values...).Set(float64(muxMetrics.SendWindowMax))
+	t.metrics.inBoundRateCurr.WithLabelValues(values...).Set(float64(muxMetrics.InBoundRateCurr))
+	t.metrics.inBoundRateMin.WithLabelValues(values...).Set(float64(muxMetrics.InBoundRateMin))
+	t.metrics.inBoundRateMax.WithLabelValues(values...).Set(float64(muxMetrics.InBoundRateMax))
+	t.metrics.outBoundRateCurr.WithLabelValues(values...).Set(float64(muxMetrics.OutBoundRateCurr))
+	t.metrics.outBoundRateMin.WithLabelValues(values...).Set(float64(muxMetrics.OutBoundRateMin))
+	t.metrics.outBoundRateMax.WithLabelValues(values...).Set(float64(muxMetrics.OutBoundRateMax))
+	t.metrics.compBytesBefore.WithLabelValues(values...).Set(float64(muxMetrics.CompBytesBefore.Value()))
+	t.metrics.compBytesAfter.WithLabelValues(values...).Set(float64(muxMetrics.CompBytesAfter.Value()))
+	t.metrics.compRateAve.WithLabelValues(values...).Set(float64(muxMetrics.CompRateAve()))
 }
 
-func (t *TunnelMetrics) incrementHaConnections() {
-	t.haConnections.Inc()
+func (t *tunnelMetricsUpdater) incrementRequests(connectionID string) {
+	values := append(append([]string(nil), t.commonValues...), connectionID, t.serverLocations[connectionID])
+	t.metrics.requests.WithLabelValues(values...).Inc()
 }
 
-func (t *TunnelMetrics) decrementHaConnections() {
-	t.haConnections.Dec()
+func (t *tunnelMetricsUpdater) incrementResponses(connectionID, code string) {
+	values := append(append([]string(nil), t.commonValues...), connectionID, t.serverLocations[connectionID], code)
+
+	t.metrics.responses.WithLabelValues(values...).Inc()
 }
 
-func (t *TunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
-	t.muxerMetrics.update(connectionID, metrics)
-}
-
-func (t *TunnelMetrics) incrementRequests(connectionID string) {
-	t.concurrentRequestsLock.Lock()
-	var concurrentRequests uint64
-	var ok bool
-	if concurrentRequests, ok = t.concurrentRequests[connectionID]; ok {
-		t.concurrentRequests[connectionID] += 1
-		concurrentRequests++
-	} else {
-		t.concurrentRequests[connectionID] = 1
-		concurrentRequests = 1
-	}
-	if maxConcurrentRequests, ok := t.maxConcurrentRequests[connectionID]; (ok && maxConcurrentRequests < concurrentRequests) || !ok {
-		t.maxConcurrentRequests[connectionID] = concurrentRequests
-		t.maxConcurrentRequestsPerTunnel.WithLabelValues(connectionID).Set(float64(concurrentRequests))
-	}
-	t.concurrentRequestsLock.Unlock()
-
-	t.totalRequests.Inc()
-	t.requestsPerTunnel.WithLabelValues(connectionID).Inc()
-	t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
-}
-
-func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
-	t.concurrentRequestsLock.Lock()
-	if _, ok := t.concurrentRequests[connectionID]; ok {
-		t.concurrentRequests[connectionID] -= 1
-	}
-	t.concurrentRequestsLock.Unlock()
-
-	t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
-}
-
-func (t *TunnelMetrics) incrementResponses(connectionID, code string) {
-	t.responseByCode.WithLabelValues(code).Inc()
-	t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
-
-}
-
-func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) {
-	t.locationLock.Lock()
-	defer t.locationLock.Unlock()
-	if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
-		return
-	} else if ok {
-		t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec()
-	}
-	t.serverLocations.WithLabelValues(connectionID, loc).Inc()
-	t.oldServerLocations[connectionID] = loc
+func (t *tunnelMetricsUpdater) setServerLocation(connectionID, loc string) {
+	t.serverLocations[connectionID] = loc
 }
diff --git a/origin/metrics_test.go b/origin/metrics_test.go
index b6cc8206..7b106c51 100644
--- a/origin/metrics_test.go
+++ b/origin/metrics_test.go
@@ -1,121 +1 @@
 package origin
-
-import (
-	"strconv"
-	"sync"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-// can only be called once
-var m = NewTunnelMetrics()
-
-func TestConcurrentRequestsSingleTunnel(t *testing.T) {
-	routines := 20
-	var wg sync.WaitGroup
-	wg.Add(routines)
-	for i := 0; i < routines; i++ {
-		go func() {
-			m.incrementRequests("0")
-			wg.Done()
-		}()
-	}
-	wg.Wait()
-	assert.Len(t, m.concurrentRequests, 1)
-	assert.Equal(t, uint64(routines), m.concurrentRequests["0"])
-	assert.Len(t, m.maxConcurrentRequests, 1)
-	assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
-
-	wg.Add(routines / 2)
-	for i := 0; i < routines/2; i++ {
-		go func() {
-			m.decrementConcurrentRequests("0")
-			wg.Done()
-		}()
-	}
-	wg.Wait()
-	assert.Equal(t, uint64(routines-routines/2), m.concurrentRequests["0"])
-	assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
-}
-
-func TestConcurrentRequestsMultiTunnel(t *testing.T) {
-	m.concurrentRequests = make(map[string]uint64)
-	m.maxConcurrentRequests = make(map[string]uint64)
-	tunnels := 20
-	var wg sync.WaitGroup
-	wg.Add(tunnels)
-	for i := 0; i < tunnels; i++ {
-		go func(i int) {
-			// if we have j < i, then tunnel 0 won't have a chance to call incrementRequests
-			for j := 0; j < i+1; j++ {
-				id := strconv.Itoa(i)
-				m.incrementRequests(id)
-			}
-			wg.Done()
-		}(i)
-	}
-	wg.Wait()
-
-	assert.Len(t, m.concurrentRequests, tunnels)
-	assert.Len(t, m.maxConcurrentRequests, tunnels)
-	for i := 0; i < tunnels; i++ {
-		id := strconv.Itoa(i)
-		assert.Equal(t, uint64(i+1), m.concurrentRequests[id])
-		assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
-	}
-
-	wg.Add(tunnels)
-	for i := 0; i < tunnels; i++ {
-		go func(i int) {
-			for j := 0; j < i+1; j++ {
-				id := strconv.Itoa(i)
-				m.decrementConcurrentRequests(id)
-			}
-			wg.Done()
-		}(i)
-	}
-	wg.Wait()
-
-	assert.Len(t, m.concurrentRequests, tunnels)
-	assert.Len(t, m.maxConcurrentRequests, tunnels)
-	for i := 0; i < tunnels; i++ {
-		id := strconv.Itoa(i)
-		assert.Equal(t, uint64(0), m.concurrentRequests[id])
-		assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
-	}
-
-}
-
-func TestRegisterServerLocation(t *testing.T) {
-	tunnels := 20
-	var wg sync.WaitGroup
-	wg.Add(tunnels)
-	for i := 0; i < tunnels; i++ {
-		go func(i int) {
-			id := strconv.Itoa(i)
-			m.registerServerLocation(id, "LHR")
-			wg.Done()
-		}(i)
-	}
-	wg.Wait()
-	for i := 0; i < tunnels; i++ {
-		id := strconv.Itoa(i)
-		assert.Equal(t, "LHR", m.oldServerLocations[id])
-	}
-
-	wg.Add(tunnels)
-	for i := 0; i < tunnels; i++ {
-		go func(i int) {
-			id := strconv.Itoa(i)
-			m.registerServerLocation(id, "AUS")
-			wg.Done()
-		}(i)
-	}
-	wg.Wait()
-	for i := 0; i < tunnels; i++ {
-		id := strconv.Itoa(i)
-		assert.Equal(t, "AUS", m.oldServerLocations[id])
-	}
-
-}
diff --git a/origin/supervisor.go b/origin/supervisor.go
index 297d5153..abfa3241 100644
--- a/origin/supervisor.go
+++ b/origin/supervisor.go
@@ -19,6 +19,7 @@ const (
 
 type Supervisor struct {
 	config *TunnelConfig
+	metrics TunnelMetricsUpdater
 	edgeIPs []*net.TCPAddr
 	// nextUnusedEdgeIP is the index of the next addr in edgeIPs to try
 	nextUnusedEdgeIP int
@@ -155,7 +156,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct
 // startFirstTunnel starts the first tunnel connection. The resulting error will be sent on
 // s.tunnelErrors. It will send a signal via connectedSignal if registration succeeds
 func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan struct{}) {
-	err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal)
+	err := ServeTunnelLoop(ctx, s.config, s.metrics, s.getEdgeIP(0), 0, connectedSignal)
 	defer func() {
 		s.tunnelErrors <- tunnelError{index: 0, err: err}
 	}()
@@ -176,14 +177,14 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan
 		default:
 			return
 		}
-		err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal)
+		err = ServeTunnelLoop(ctx, s.config, s.metrics, s.getEdgeIP(0), 0, connectedSignal)
 	}
 }
 
 // startTunnel starts a new tunnel connection. The resulting error will be sent on
 // s.tunnelErrors.
 func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}) {
-	err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal)
+	err := ServeTunnelLoop(ctx, s.config, s.metrics, s.getEdgeIP(index), uint8(index), connectedSignal)
 	s.tunnelErrors <- tunnelError{index: index, err: err}
 }
 
diff --git a/origin/tunnel.go b/origin/tunnel.go
index 6cdafd8f..f2eaa65c 100644
--- a/origin/tunnel.go
+++ b/origin/tunnel.go
@@ -123,7 +123,7 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
 	}
 }
 
-func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
+func StartTunnelDaemon(config *TunnelConfig, metrics TunnelMetricsUpdater, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
 	ctx, cancel := context.WithCancel(context.Background())
 	go func() {
 		<-shutdownC
@@ -137,19 +137,20 @@ func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connecte
 		if err != nil {
 			return err
 		}
-		return ServeTunnelLoop(ctx, config, addrs[0], 0, connectedSignal)
+		return ServeTunnelLoop(ctx, config, metrics, addrs[0], 0, connectedSignal)
 	}
 }
 
 func ServeTunnelLoop(ctx context.Context,
 	config *TunnelConfig,
+	metrics TunnelMetricsUpdater,
 	addr *net.TCPAddr,
 	connectionID uint8,
 	connectedSignal chan struct{},
 ) error {
 	logger := config.Logger
-	config.Metrics.incrementHaConnections()
-	defer config.Metrics.decrementHaConnections()
+	metrics.incrementHaConnections()
+	defer metrics.decrementHaConnections()
 	backoff := BackoffHandler{MaxRetries: config.Retries}
 	// Used to close connectedSignal no more than once
 	connectedFuse := h2mux.NewBooleanFuse()
@@ -161,7 +162,7 @@ func ServeTunnelLoop(ctx context.Context,
 	// Ensure the above goroutine will terminate if we return without connecting
 	defer connectedFuse.Fuse(false)
 	for {
-		err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff)
+		err, recoverable := ServeTunnel(ctx, config, metrics, addr, connectionID, connectedFuse, &backoff)
 		if recoverable {
 			if duration, ok := backoff.GetBackoffDuration(ctx); ok {
 				logger.Infof("Retrying in %s seconds", duration)
@@ -176,6 +177,7 @@ func ServeTunnelLoop(ctx context.Context,
 func ServeTunnel(
 	ctx context.Context,
 	config *TunnelConfig,
+	metrics TunnelMetricsUpdater,
 	addr *net.TCPAddr,
 	connectionID uint8,
 	connectedFuse *h2mux.BooleanFuse,
@@ -201,7 +203,7 @@ func ServeTunnel(
 	tags["ha"] = connectionTag
 
 	// Returns error from parsing the origin URL or handshake errors
-	handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
+	handler, originLocalIP, err := NewTunnelHandler(ctx, config, metrics, addr.String(), connectionID)
 	if err != nil {
 		errLog := config.Logger.WithError(err)
 		switch err.(type) {
@@ -219,7 +221,7 @@ func ServeTunnel(
 	errGroup, serveCtx := errgroup.WithContext(ctx)
 
 	errGroup.Go(func() error {
-		err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP)
+		err := RegisterTunnel(serveCtx, handler.muxer, config, metrics, connectionID, originLocalIP)
 		if err == nil {
 			connectedFuse.Fuse(true)
 			backoff.SetGracePeriod()
@@ -290,7 +292,7 @@ func IsRPCStreamResponse(headers []h2mux.Header) bool {
 	return true
 }
 
-func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error {
+func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, metrics TunnelMetricsUpdater, connectionID uint8, originLocalIP string) error {
 	config.Logger.Debug("initiating RPC stream to register")
 	stream, err := muxer.OpenStream([]h2mux.Header{
 		{Name: ":method", Value: "RPC"},
@@ -322,7 +324,7 @@ func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfi
 		config.Hostname,
 		config.RegistrationOptions(connectionID, originLocalIP),
 	)
-	LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, config.Logger)
+	LogServerInfo(serverInfoPromise.Result(), connectionID, metrics, config.Logger)
 	if err != nil {
 		// RegisterTunnel RPC failure
 		return clientRegisterTunnelError{cause: err}
@@ -373,7 +375,7 @@ func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log
 func LogServerInfo(
 	promise tunnelrpc.ServerInfo_Promise,
 	connectionID uint8,
-	metrics *TunnelMetrics,
+	metrics TunnelMetricsUpdater,
 	logger *log.Logger,
 ) {
 	serverInfoMessage, err := promise.Struct()
@@ -387,7 +389,7 @@ func LogServerInfo(
 		return
 	}
 	logger.Infof("Connected to %s", serverInfo.LocationName)
-	metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName)
+	metrics.setServerLocation(uint8ToString(connectionID), serverInfo.LocationName)
 }
 
 func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
@@ -437,7 +439,7 @@ type TunnelHandler struct {
 	httpClient http.RoundTripper
 	tlsConfig  *tls.Config
 	tags       []tunnelpogs.Tag
-	metrics    *TunnelMetrics
+	metrics    TunnelMetricsUpdater
 	// connectionID is only used by metrics, and prometheus requires labels to be string
 	connectionID string
 	logger       *log.Logger
@@ -449,6 +451,7 @@ var dialer = net.Dialer{DualStack: true}
 // NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
 func NewTunnelHandler(ctx context.Context,
 	config *TunnelConfig,
+	metrics TunnelMetricsUpdater,
 	addr string,
 	connectionID uint8,
 ) (*TunnelHandler, string, error) {
@@ -461,7 +464,7 @@ func NewTunnelHandler(ctx context.Context,
 		httpClient:   config.HTTPTransport,
 		tlsConfig:    config.ClientTlsConfig,
 		tags:         config.Tags,
-		metrics:      config.Metrics,
+		metrics:      metrics,
 		connectionID: uint8ToString(connectionID),
 		logger:       config.Logger,
 		noChunkedEncoding: config.NoChunkedEncoding,
@@ -525,15 +528,15 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 	if websocket.IsWebSocketUpgrade(req) {
 		conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
 		if err != nil {
-			h.logError(stream, err)
+			h.logError(stream, response, err)
 		} else {
-			stream.WriteHeaders(H1ResponseToH2Response(response))
 			defer conn.Close()
+			stream.WriteHeaders(H1ResponseToH2Response(response))
 			// Copy between the stream and the underlying connection. Use the underlying
 			// connection because cloudflared doesn't operate on the messages themselves
 			websocket.Stream(conn.UnderlyingConn(), stream)
-			h.metrics.incrementResponses(h.connectionID, "200")
 			h.logResponse(response, cfRay, lbProbe)
+			h.metrics.incrementResponses(h.connectionID, strconv.Itoa(response.StatusCode))
 		}
 	} else {
 		// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
@@ -551,7 +554,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 
 		response, err := h.httpClient.RoundTrip(req)
 		if err != nil {
-			h.logError(stream, err)
+			h.logError(stream, response, err)
 		} else {
 			defer response.Body.Close()
 			stream.WriteHeaders(H1ResponseToH2Response(response))
@@ -563,11 +566,10 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
 				io.CopyBuffer(stream, response.Body, make([]byte, 512*1024))
 			}
 
-			h.metrics.incrementResponses(h.connectionID, "200")
 			h.logResponse(response, cfRay, lbProbe)
+			h.metrics.incrementResponses(h.connectionID, strconv.Itoa(response.StatusCode))
 		}
 	}
-	h.metrics.decrementConcurrentRequests(h.connectionID)
 	return nil
 }
 
@@ -590,11 +592,10 @@ func (h *TunnelHandler) isEventStream(response *http.Response) bool {
 	return false
 }
 
-func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
+func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, response *http.Response, err error) {
 	h.logger.WithError(err).Error("HTTP request error")
-	stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
-	stream.Write([]byte("502 Bad Gateway"))
-	h.metrics.incrementResponses(h.connectionID, "502")
+	// response is nil when the round trip itself failed, so fall back to a 502
+	status, statusText := http.StatusBadGateway, "502 Bad Gateway"
+	if response != nil {
+		status, statusText = response.StatusCode, response.Status
+	}
+	stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: strconv.Itoa(status)}})
+	stream.Write([]byte(statusText))
 }
 
 func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool) {
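
How the pieces fit together, as a minimal sketch (not part of the patch): InitializeTunnelMetrics registers the collectors once per process (prometheus.MustRegister panics on duplicate registration), and each TunnelMetricsUpdater then supplies exactly one value per common key. The extra "pool" label key and its value below are hypothetical, purely for illustration, and the helper would have to live in package origin since the updater's methods are unexported.

```go
package origin

// exampleMetricsWiring is a hypothetical helper showing the intended call order.
func exampleMetricsWiring() error {
	// Register the global collectors once, with one extra common label key.
	metrics := InitializeTunnelMetrics([]string{"pool"})

	// Each updater must supply one value per common key, in the same order.
	updater, err := NewTunnelMetricsUpdater(metrics, []string{"pool-a"})
	if err != nil {
		return err // mismatched key/value counts are rejected here
	}

	// Per-connection labels are filled in at update time.
	updater.setServerLocation("0", "LHR") // normally set from the RegisterTunnel RPC result
	updater.incrementRequests("0")        // requests{pool="pool-a",connection_id="0",location="LHR"}
	updater.incrementResponses("0", "200")
	return nil
}
```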
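A note on the `append(append([]string(nil), t.commonValues...), ...)` shape in the updater methods: appending directly to a slice that has spare capacity writes into its shared backing array, so two goroutines building label slices off the same commonValues could alias each other's values. Copying first keeps each call's slice private. A standalone illustration of the hazard (not from the patch):

```go
package main

import "fmt"

func main() {
	base := make([]string, 1, 4) // len 1, cap 4: spare capacity
	base[0] = "common"
	x := append(base, "conn-1") // reuses base's backing array
	y := append(base, "conn-2") // overwrites the same element
	fmt.Println(x[1], y[1])     // prints "conn-2 conn-2": x and y alias
}
```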
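One follow-up worth considering: serverLocations is a plain map that setServerLocation writes from the registration goroutine while incrementRequests, incrementResponses, and updateMuxerMetrics read it from proxy goroutines, so an updater shared across connections can trip Go's concurrent map access detector. A mutex-guarded variant might look like the sketch below; the guardedUpdater and location names are hypothetical, and it assumes "sync" is imported in package origin.

```go
type guardedUpdater struct {
	metrics      *TunnelMetrics
	commonValues []string

	mu              sync.RWMutex
	serverLocations map[string]string
}

func (t *guardedUpdater) setServerLocation(connectionID, loc string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.serverLocations[connectionID] = loc
}

// location is a helper the increment methods would call instead of
// reading the map directly.
func (t *guardedUpdater) location(connectionID string) string {
	t.mu.RLock()
	defer t.mu.RUnlock()
	return t.serverLocations[connectionID]
}
```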