diff --git a/origin/metrics.go b/origin/metrics.go index 694eea22..844dbc56 100644 --- a/origin/metrics.go +++ b/origin/metrics.go @@ -1,6 +1,8 @@ package origin import ( + "hash/fnv" + "strconv" "sync" "time" @@ -28,25 +30,25 @@ type muxerMetrics struct { } type TunnelMetrics struct { - haConnections prometheus.Gauge - totalRequests prometheus.Counter - requestsPerTunnel *prometheus.CounterVec + haConnections prometheus.Gauge + totalRequests prometheus.Counter + requests *prometheus.CounterVec // concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests concurrentRequestsLock sync.Mutex concurrentRequestsPerTunnel *prometheus.GaugeVec - // concurrentRequests records count of concurrent requests for each tunnel - concurrentRequests map[string]uint64 + // concurrentRequests records count of concurrent requests for each tunnel, keyed by hash of label values + concurrentRequests map[uint64]uint64 maxConcurrentRequestsPerTunnel *prometheus.GaugeVec - // concurrentRequests records max count of concurrent requests for each tunnel - maxConcurrentRequests map[string]uint64 + // concurrentRequests records max count of concurrent requests for each tunnel, keyed by hash of label values + maxConcurrentRequests map[uint64]uint64 timerRetries prometheus.Gauge - responseByCode *prometheus.CounterVec - responseCodePerTunnel *prometheus.CounterVec - serverLocations *prometheus.GaugeVec + + reponses *prometheus.CounterVec + serverLocations *prometheus.GaugeVec // locationLock is a mutex for oldServerLocations locationLock sync.Mutex // oldServerLocations stores the last server the tunnel was connected to - oldServerLocations map[string]string + oldServerLocations map[uint64]string muxerMetrics *muxerMetrics } @@ -206,22 +208,22 @@ func newMuxerMetrics() *muxerMetrics { } } -func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) { - m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT)) - 
m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin)) - m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax)) - m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve) - m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve) - m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin)) - m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax)) - m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin)) - m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax)) - m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr)) - m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin)) - m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax)) - m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr)) - m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin)) - m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax)) +func (m *muxerMetrics) update(metricLabelValues []string, metrics *h2mux.MuxerMetrics) { + m.rtt.WithLabelValues(metricLabelValues...).Set(convertRTTMilliSec(metrics.RTT)) + m.rttMin.WithLabelValues(metricLabelValues...).Set(convertRTTMilliSec(metrics.RTTMin)) + m.rttMax.WithLabelValues(metricLabelValues...).Set(convertRTTMilliSec(metrics.RTTMax)) + m.receiveWindowAve.WithLabelValues(metricLabelValues...).Set(metrics.ReceiveWindowAve) + m.sendWindowAve.WithLabelValues(metricLabelValues...).Set(metrics.SendWindowAve) + m.receiveWindowMin.WithLabelValues(metricLabelValues...).Set(float64(metrics.ReceiveWindowMin)) + m.receiveWindowMax.WithLabelValues(metricLabelValues...).Set(float64(metrics.ReceiveWindowMax)) + 
m.sendWindowMin.WithLabelValues(metricLabelValues...).Set(float64(metrics.SendWindowMin)) + m.sendWindowMax.WithLabelValues(metricLabelValues...).Set(float64(metrics.SendWindowMax)) + m.inBoundRateCurr.WithLabelValues(metricLabelValues...).Set(float64(metrics.InBoundRateCurr)) + m.inBoundRateMin.WithLabelValues(metricLabelValues...).Set(float64(metrics.InBoundRateMin)) + m.inBoundRateMax.WithLabelValues(metricLabelValues...).Set(float64(metrics.InBoundRateMax)) + m.outBoundRateCurr.WithLabelValues(metricLabelValues...).Set(float64(metrics.OutBoundRateCurr)) + m.outBoundRateMin.WithLabelValues(metricLabelValues...).Set(float64(metrics.OutBoundRateMin)) + m.outBoundRateMax.WithLabelValues(metricLabelValues...).Set(float64(metrics.OutBoundRateMax)) } func convertRTTMilliSec(t time.Duration) float64 { @@ -278,14 +280,14 @@ func NewTunnelMetrics() *TunnelMetrics { }) prometheus.MustRegister(timerRetries) - responseByCode := prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: "argotunnel_response_by_code", - Help: "Count of responses by HTTP status code", - }, - []string{"status_code"}, - ) - prometheus.MustRegister(responseByCode) + // responseByCode := prometheus.NewCounterVec( + // prometheus.CounterOpts{ + // Name: "argotunnel_response_by_code", + // Help: "Count of responses by HTTP status code", + // }, + // []string{"status_code"}, + // ) + // prometheus.MustRegister(responseByCode) responseCodePerTunnel := prometheus.NewCounterVec( prometheus.CounterOpts{ @@ -308,20 +310,28 @@ func NewTunnelMetrics() *TunnelMetrics { return &TunnelMetrics{ haConnections: haConnections, totalRequests: totalRequests, - requestsPerTunnel: requestsPerTunnel, + requests: requestsPerTunnel, concurrentRequestsPerTunnel: concurrentRequestsPerTunnel, - concurrentRequests: make(map[string]uint64), + concurrentRequests: make(map[uint64]uint64), maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel, - maxConcurrentRequests: make(map[string]uint64), + maxConcurrentRequests: 
make(map[uint64]uint64), timerRetries: timerRetries, - responseByCode: responseByCode, - responseCodePerTunnel: responseCodePerTunnel, - serverLocations: serverLocations, - oldServerLocations: make(map[string]string), - muxerMetrics: newMuxerMetrics(), + + reponses: responseCodePerTunnel, + serverLocations: serverLocations, + oldServerLocations: make(map[uint64]string), + muxerMetrics: newMuxerMetrics(), } } +func hashLabelValues(labelValues []string) uint64 { + h := fnv.New64() + for _, text := range labelValues { + h.Write(append([]byte(text), 0)) // NUL separator so ["ab","c"] and ["a","bc"] hash differently + } + return h.Sum64() +} + func (t *TunnelMetrics) incrementHaConnections() { t.haConnections.Inc() } @@ -330,56 +340,61 @@ func (t *TunnelMetrics) decrementHaConnections() { t.haConnections.Dec() } -func (t *TunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) { - t.muxerMetrics.update(connectionID, metrics) +func (t *TunnelMetrics) updateMuxerMetrics(metricLabelValues []string, metrics *h2mux.MuxerMetrics) { + t.muxerMetrics.update(metricLabelValues, metrics) } -func (t *TunnelMetrics) incrementRequests(connectionID string) { +func (t *TunnelMetrics) incrementRequests(metricLabelValues []string) { t.concurrentRequestsLock.Lock() var concurrentRequests uint64 var ok bool - if concurrentRequests, ok = t.concurrentRequests[connectionID]; ok { - t.concurrentRequests[connectionID] += 1 + hashKey := hashLabelValues(metricLabelValues) + if concurrentRequests, ok = t.concurrentRequests[hashKey]; ok { + t.concurrentRequests[hashKey] += 1 concurrentRequests++ } else { - t.concurrentRequests[connectionID] = 1 + t.concurrentRequests[hashKey] = 1 concurrentRequests = 1 } - if maxConcurrentRequests, ok := t.maxConcurrentRequests[connectionID]; (ok && maxConcurrentRequests < concurrentRequests) || !ok { - t.maxConcurrentRequests[connectionID] = concurrentRequests - t.maxConcurrentRequestsPerTunnel.WithLabelValues(connectionID).Set(float64(concurrentRequests)) + if maxConcurrentRequests, ok := 
t.maxConcurrentRequests[hashKey]; (ok && maxConcurrentRequests < concurrentRequests) || !ok { + t.maxConcurrentRequests[hashKey] = concurrentRequests + t.maxConcurrentRequestsPerTunnel.WithLabelValues(metricLabelValues...).Set(float64(concurrentRequests)) } t.concurrentRequestsLock.Unlock() t.totalRequests.Inc() - t.requestsPerTunnel.WithLabelValues(connectionID).Inc() - t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc() + t.requests.WithLabelValues(metricLabelValues...).Inc() + t.concurrentRequestsPerTunnel.WithLabelValues(metricLabelValues...).Inc() } -func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) { +func (t *TunnelMetrics) decrementConcurrentRequests(metricLabelValues []string) { t.concurrentRequestsLock.Lock() - if _, ok := t.concurrentRequests[connectionID]; ok { - t.concurrentRequests[connectionID] -= 1 + hashKey := hashLabelValues(metricLabelValues) + if _, ok := t.concurrentRequests[hashKey]; ok { + t.concurrentRequests[hashKey] -= 1 } t.concurrentRequestsLock.Unlock() - t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec() + t.concurrentRequestsPerTunnel.WithLabelValues(metricLabelValues...).Dec() } -func (t *TunnelMetrics) incrementResponses(connectionID, code string) { - t.responseByCode.WithLabelValues(code).Inc() - t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc() +func (t *TunnelMetrics) incrementResponses(metricLabelValues []string, responseCode int) { + labelValues := append(metricLabelValues, strconv.Itoa(responseCode)) + t.reponses.WithLabelValues(labelValues...).Inc() } -func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) { +func (t *TunnelMetrics) registerServerLocation(metricLabelValues []string, loc string) { t.locationLock.Lock() defer t.locationLock.Unlock() - if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc { + hashKey := hashLabelValues(metricLabelValues) + if oldLoc, ok := t.oldServerLocations[hashKey]; ok && oldLoc == loc { 
return } else if ok { - t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec() + labelValues := append(metricLabelValues, oldLoc) + t.serverLocations.WithLabelValues(labelValues...).Dec() } - t.serverLocations.WithLabelValues(connectionID, loc).Inc() - t.oldServerLocations[connectionID] = loc + labelValues := append(metricLabelValues, loc) + t.serverLocations.WithLabelValues(labelValues...).Inc() + t.oldServerLocations[hashKey] = loc } diff --git a/origin/metrics_test.go b/origin/metrics_test.go index b6cc8206..15df9753 100644 --- a/origin/metrics_test.go +++ b/origin/metrics_test.go @@ -15,33 +15,37 @@ func TestConcurrentRequestsSingleTunnel(t *testing.T) { routines := 20 var wg sync.WaitGroup wg.Add(routines) + + baseLabels := []string{"0"} + hashKey := hashLabelValues(baseLabels) + for i := 0; i < routines; i++ { go func() { - m.incrementRequests("0") + m.incrementRequests(baseLabels) wg.Done() }() } wg.Wait() assert.Len(t, m.concurrentRequests, 1) - assert.Equal(t, uint64(routines), m.concurrentRequests["0"]) + assert.Equal(t, uint64(routines), m.concurrentRequests[hashKey]) assert.Len(t, m.maxConcurrentRequests, 1) - assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"]) + assert.Equal(t, uint64(routines), m.maxConcurrentRequests[hashKey]) wg.Add(routines / 2) for i := 0; i < routines/2; i++ { go func() { - m.decrementConcurrentRequests("0") + m.decrementConcurrentRequests(baseLabels) wg.Done() }() } wg.Wait() - assert.Equal(t, uint64(routines-routines/2), m.concurrentRequests["0"]) - assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"]) + assert.Equal(t, uint64(routines-routines/2), m.concurrentRequests[hashKey]) + assert.Equal(t, uint64(routines), m.maxConcurrentRequests[hashKey]) } func TestConcurrentRequestsMultiTunnel(t *testing.T) { - m.concurrentRequests = make(map[string]uint64) - m.maxConcurrentRequests = make(map[string]uint64) + m.concurrentRequests = make(map[uint64]uint64) + m.maxConcurrentRequests = 
make(map[uint64]uint64) tunnels := 20 var wg sync.WaitGroup wg.Add(tunnels) @@ -49,8 +53,8 @@ func TestConcurrentRequestsMultiTunnel(t *testing.T) { go func(i int) { // if we have j < i, then tunnel 0 won't have a chance to call incrementRequests for j := 0; j < i+1; j++ { - id := strconv.Itoa(i) - m.incrementRequests(id) + labels := []string{strconv.Itoa(i)} + m.incrementRequests(labels) } wg.Done() }(i) @@ -60,17 +64,18 @@ func TestConcurrentRequestsMultiTunnel(t *testing.T) { assert.Len(t, m.concurrentRequests, tunnels) assert.Len(t, m.maxConcurrentRequests, tunnels) for i := 0; i < tunnels; i++ { - id := strconv.Itoa(i) - assert.Equal(t, uint64(i+1), m.concurrentRequests[id]) - assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id]) + labels := []string{strconv.Itoa(i)} + hashKey := hashLabelValues(labels) + assert.Equal(t, uint64(i+1), m.concurrentRequests[hashKey]) + assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[hashKey]) } wg.Add(tunnels) for i := 0; i < tunnels; i++ { go func(i int) { for j := 0; j < i+1; j++ { - id := strconv.Itoa(i) - m.decrementConcurrentRequests(id) + labels := []string{strconv.Itoa(i)} + m.decrementConcurrentRequests(labels) } wg.Done() }(i) @@ -80,9 +85,10 @@ func TestConcurrentRequestsMultiTunnel(t *testing.T) { assert.Len(t, m.concurrentRequests, tunnels) assert.Len(t, m.maxConcurrentRequests, tunnels) for i := 0; i < tunnels; i++ { - id := strconv.Itoa(i) - assert.Equal(t, uint64(0), m.concurrentRequests[id]) - assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id]) + labels := []string{strconv.Itoa(i)} + hashKey := hashLabelValues(labels) + assert.Equal(t, uint64(0), m.concurrentRequests[hashKey]) + assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[hashKey]) } } @@ -93,29 +99,31 @@ func TestRegisterServerLocation(t *testing.T) { wg.Add(tunnels) for i := 0; i < tunnels; i++ { go func(i int) { - id := strconv.Itoa(i) - m.registerServerLocation(id, "LHR") + labels := []string{strconv.Itoa(i)} + 
m.registerServerLocation(labels, "LHR") wg.Done() }(i) } wg.Wait() for i := 0; i < tunnels; i++ { - id := strconv.Itoa(i) - assert.Equal(t, "LHR", m.oldServerLocations[id]) + labels := []string{strconv.Itoa(i)} + hashKey := hashLabelValues(labels) + assert.Equal(t, "LHR", m.oldServerLocations[hashKey]) } wg.Add(tunnels) for i := 0; i < tunnels; i++ { go func(i int) { - id := strconv.Itoa(i) - m.registerServerLocation(id, "AUS") + labels := []string{strconv.Itoa(i)} + m.registerServerLocation(labels, "AUS") wg.Done() }(i) } wg.Wait() for i := 0; i < tunnels; i++ { - id := strconv.Itoa(i) - assert.Equal(t, "AUS", m.oldServerLocations[id]) + labels := []string{strconv.Itoa(i)} + hashKey := hashLabelValues(labels) + assert.Equal(t, "AUS", m.oldServerLocations[hashKey]) } } diff --git a/origin/tunnel.go b/origin/tunnel.go index d4d8f37d..5e3a4a65 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -28,7 +28,7 @@ import ( ) const ( - dialTimeout = 15 * time.Second + dialTimeout = 15 * time.Second lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" TagHeaderNamePrefix = "Cf-Warp-Tag-" DuplicateConnectionError = "EDUPCONN" @@ -382,7 +382,7 @@ func LogServerInfo( return } logger.Infof("Connected to %s", serverInfo.LocationName) - metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName) + // metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName) } func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error { @@ -433,6 +433,10 @@ type TunnelHandler struct { tlsConfig *tls.Config tags []tunnelpogs.Tag metrics *TunnelMetrics + + baseMetricsLabelKeys []string + baseMetricsLabelValues []string + // connectionID is only used by metrics, and prometheus requires labels to be string connectionID string logger *log.Logger @@ -500,8 +504,12 @@ func (h *TunnelHandler) AppendTagHeaders(r *http.Request) { } } +func (h *TunnelHandler) 
getCombinedMetricsLabels(connectionID string) []string { + // Copy before appending: appending directly to baseMetricsLabelValues can write into its spare capacity and race across concurrent streams. + return append(append([]string(nil), h.baseMetricsLabelValues...), connectionID) +} + func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error { - h.metrics.incrementRequests(h.connectionID) + h.metrics.incrementRequests(h.getCombinedMetricsLabels(h.connectionID)) req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream}) if err != nil { h.logger.WithError(err).Panic("Unexpected error from http.NewRequest") } @@ -516,6 +525,8 @@ h.logRequest(req, cfRay, lbProbe) if websocket.IsWebSocketUpgrade(req) { conn, response, err := websocket.ClientConnect(req, h.tlsConfig) if err != nil { h.logError(stream, err) } else { + // Only record the status code on success: response is nil when err != nil. + h.metrics.incrementResponses(h.getCombinedMetricsLabels(h.connectionID), response.StatusCode) @@ -524,22 +535,24 @@ // Copy to/from stream to the undelying connection. 
Use the underlying // connection because cloudflared doesn't operate on the message themselves websocket.Stream(conn.UnderlyingConn(), stream) - h.metrics.incrementResponses(h.connectionID, "200") h.logResponse(response, cfRay, lbProbe) } } else { response, err := h.httpClient.RoundTrip(req) if err != nil { h.logError(stream, err) } else { defer response.Body.Close() + // Only record the status code on success: response is nil when RoundTrip returns an error. + h.metrics.incrementResponses(h.getCombinedMetricsLabels(h.connectionID), response.StatusCode) stream.WriteHeaders(H1ResponseToH2Response(response)) io.Copy(stream, response.Body) - h.metrics.incrementResponses(h.connectionID, "200") + h.logResponse(response, cfRay, lbProbe) } } - h.metrics.decrementConcurrentRequests(h.connectionID) + h.metrics.decrementConcurrentRequests(h.getCombinedMetricsLabels(h.connectionID)) return nil } @@ -547,7 +556,7 @@ func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) { h.logger.WithError(err).Error("HTTP request error") stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}}) stream.Write([]byte("502 Bad Gateway")) - h.metrics.incrementResponses(h.connectionID, "502") + h.metrics.incrementResponses(h.getCombinedMetricsLabels(h.connectionID), http.StatusBadGateway) } func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool) { @@ -573,7 +582,8 @@ } func (h *TunnelHandler) UpdateMetrics(connectionID string) { - h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics()) + // Use the connectionID passed by the caller, not the handler's cached one. + h.metrics.updateMuxerMetrics(h.getCombinedMetricsLabels(connectionID), h.muxer.Metrics()) } func uint8ToString(input uint8) string {