diff --git a/origin/discovery.go b/origin/discovery.go
new file mode 100644
index 00000000..388b5b1f
--- /dev/null
+++ b/origin/discovery.go
@@ -0,0 +1,82 @@
+package origin
+
+import (
+	"fmt"
+	"net"
+)
+
+const (
+	// SRV record used to discover HA Warp servers
+	srvService = "warp"
+	srvProto   = "tcp"
+	srvName    = "cloudflarewarp.com"
+)
+
+func ResolveEdgeIPs(addresses []string) ([]*net.TCPAddr, error) {
+	if len(addresses) > 0 {
+		// Addresses were specified explicitly (usually for testing)
+		var tcpAddrs []*net.TCPAddr
+		for _, address := range addresses {
+			tcpAddr, err := net.ResolveTCPAddr("tcp", address)
+			if err != nil {
+				return nil, err
+			}
+			tcpAddrs = append(tcpAddrs, tcpAddr)
+		}
+		return tcpAddrs, nil
+	}
+	// HA service discovery lookup
+	_, addrs, err := net.LookupSRV(srvService, srvProto, srvName)
+	if err != nil {
+		return nil, err
+	}
+	var resolvedIPsPerCNAME [][]*net.TCPAddr
+	var lookupErr error
+	for _, addr := range addrs {
+		ips, err := ResolveSRVToTCP(addr)
+		if err != nil || len(ips) == 0 {
+			// don't return early; we might be able to resolve other addresses
+			lookupErr = err
+			continue
+		}
+		resolvedIPsPerCNAME = append(resolvedIPsPerCNAME, ips)
+	}
+	ips := FlattenServiceIPs(resolvedIPsPerCNAME)
+	if lookupErr == nil && len(ips) == 0 {
+		return nil, fmt.Errorf("unknown service discovery error")
+	}
+	return ips, lookupErr
+}
+
+func ResolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) {
+	ips, err := net.LookupIP(srv.Target)
+	if err != nil {
+		return nil, err
+	}
+	addrs := make([]*net.TCPAddr, len(ips))
+	for i, ip := range ips {
+		addrs[i] = &net.TCPAddr{IP: ip, Port: int(srv.Port)}
+	}
+	return addrs, nil
+}
+
+// FlattenServiceIPs transposes and flattens the input slices such that the
+// first elements of the n inner slices become the first n elements of the
+// result; remaining elements follow in the same round-robin order.
+func FlattenServiceIPs(ipsByService [][]*net.TCPAddr) []*net.TCPAddr {
+	var result []*net.TCPAddr
+	for len(ipsByService) > 0 {
+		// filter in place: reuse the backing array, dropping exhausted slices
+		filtered := ipsByService[:0]
+		for _, ips := range ipsByService {
+			if len(ips) == 0 {
+				// sanity check
+				continue
+			}
+			result = append(result, ips[0])
+			if len(ips) > 1 {
+				filtered = append(filtered, ips[1:])
+			}
+		}
+		ipsByService = filtered
+	}
+	return result
+}
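The round-robin contract is easiest to see as a testable Example. A minimal sketch that could sit alongside discovery_test.go; the `ExampleFlattenServiceIPs` function itself is an illustration, not part of this diff:

```go
package origin

import (
	"fmt"
	"net"
)

// ExampleFlattenServiceIPs shows the interleaving: consecutive results
// alternate across the per-SRV-target lists, so consecutive HA connections
// are spread across clusters.
func ExampleFlattenServiceIPs() {
	clusterA := []*net.TCPAddr{{Port: 1}, {Port: 2}}
	clusterB := []*net.TCPAddr{{Port: 10}}
	for _, addr := range FlattenServiceIPs([][]*net.TCPAddr{clusterA, clusterB}) {
		fmt.Println(addr.Port)
	}
	// Output:
	// 1
	// 10
	// 2
}
```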
diff --git a/origin/discovery_test.go b/origin/discovery_test.go
new file mode 100644
index 00000000..fecf9c23
--- /dev/null
+++ b/origin/discovery_test.go
@@ -0,0 +1,45 @@
+package origin
+
+import (
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestFlattenServiceIPs(t *testing.T) {
+	result := FlattenServiceIPs([][]*net.TCPAddr{
+		[]*net.TCPAddr{
+			&net.TCPAddr{Port: 1},
+			&net.TCPAddr{Port: 2},
+			&net.TCPAddr{Port: 3},
+			&net.TCPAddr{Port: 4},
+		},
+		[]*net.TCPAddr{
+			&net.TCPAddr{Port: 10},
+			&net.TCPAddr{Port: 12},
+			&net.TCPAddr{Port: 13},
+		},
+		[]*net.TCPAddr{
+			&net.TCPAddr{Port: 21},
+			&net.TCPAddr{Port: 22},
+			&net.TCPAddr{Port: 23},
+			&net.TCPAddr{Port: 24},
+			&net.TCPAddr{Port: 25},
+		},
+	})
+	assert.EqualValues(t, []*net.TCPAddr{
+		&net.TCPAddr{Port: 1},
+		&net.TCPAddr{Port: 10},
+		&net.TCPAddr{Port: 21},
+		&net.TCPAddr{Port: 2},
+		&net.TCPAddr{Port: 12},
+		&net.TCPAddr{Port: 22},
+		&net.TCPAddr{Port: 3},
+		&net.TCPAddr{Port: 13},
+		&net.TCPAddr{Port: 23},
+		&net.TCPAddr{Port: 4},
+		&net.TCPAddr{Port: 24},
+		&net.TCPAddr{Port: 25},
+	}, result)
+}
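One note before the metrics file: everything below registers against the default Prometheus registry via `MustRegister`, so exposing the metrics is a one-liner elsewhere in the binary. A hedged sketch; the `/metrics` endpoint and listen address are assumptions for illustration, not part of this diff:

```go
package main

import (
	"log"
	"net/http"

	"github.com/prometheus/client_golang/prometheus/promhttp"
)

func main() {
	// Serves every metric registered via prometheus.MustRegister in metrics.go.
	http.Handle("/metrics", promhttp.Handler())
	log.Fatal(http.ListenAndServe("127.0.0.1:9090", nil)) // address is illustrative
}
```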
[]string{"connection_id"}, + ) + prometheus.MustRegister(requestsPerTunnel) + + concurrentRequestsPerTunnel := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "concurrent_requests_per_tunnel", + Help: "Concurrent requests proxied through each tunnel", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(concurrentRequestsPerTunnel) + + maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "max_concurrent_requests_per_tunnel", + Help: "Largest number of concurrent requests proxied through each tunnel so far", + }, + []string{"connection_id"}, + ) + prometheus.MustRegister(maxConcurrentRequestsPerTunnel) + + rtt := prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "rtt", + Help: "Round-trip time", + }) + prometheus.MustRegister(rtt) + + rttMin := prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "rtt_min", + Help: "Shortest round-trip time", + }) + prometheus.MustRegister(rttMin) + + rttMax := prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "rtt_max", + Help: "Longest round-trip time", + }) + prometheus.MustRegister(rttMax) + + timerRetries := prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "timer_retries", + Help: "Unacknowledged heart beats count", + }) + prometheus.MustRegister(timerRetries) + + receiveWindowSizeAve := prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "receive_window_ave", + Help: "Average receive window size", + }) + prometheus.MustRegister(receiveWindowSizeAve) + + sendWindowSizeAve := prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "send_window_ave", + Help: "Average send window size", + }) + prometheus.MustRegister(sendWindowSizeAve) + + receiveWindowSizeMin := prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "receive_window_min", + Help: "Smallest receive window size", + }) + prometheus.MustRegister(receiveWindowSizeMin) + + receiveWindowSizeMax := prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "receive_window_max", + Help: "Largest receive window size", + }) + prometheus.MustRegister(receiveWindowSizeMax) + + sendWindowSizeMin := prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "send_window_min", + Help: "Smallest send window size", + }) + prometheus.MustRegister(sendWindowSizeMin) + + sendWindowSizeMax := prometheus.NewGauge( + prometheus.GaugeOpts{ + Name: "send_window_max", + Help: "Largest send window size", + }) + prometheus.MustRegister(sendWindowSizeMax) + + responseByCode := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "response_by_code", + Help: "Count of responses by HTTP status code", + }, + []string{"status_code"}, + ) + prometheus.MustRegister(responseByCode) + + responseCodePerTunnel := prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "response_code_per_tunnel", + Help: "Count of responses by HTTP status code fore each tunnel", + }, + []string{"connection_id", "status_code"}, + ) + prometheus.MustRegister(responseCodePerTunnel) + + serverLocations := prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "server_locations", + Help: "Where each tunnel is connected to. 
diff --git a/origin/metrics_test.go b/origin/metrics_test.go
new file mode 100644
index 00000000..b6cc8206
--- /dev/null
+++ b/origin/metrics_test.go
@@ -0,0 +1,121 @@
+package origin
+
+import (
+	"strconv"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// NewTunnelMetrics can only be called once per process, since it registers
+// against the default Prometheus registry, so all tests share this instance.
+var m = NewTunnelMetrics()
+
+func TestConcurrentRequestsSingleTunnel(t *testing.T) {
+	routines := 20
+	var wg sync.WaitGroup
+	wg.Add(routines)
+	for i := 0; i < routines; i++ {
+		go func() {
+			m.incrementRequests("0")
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+	assert.Len(t, m.concurrentRequests, 1)
+	assert.Equal(t, uint64(routines), m.concurrentRequests["0"])
+	assert.Len(t, m.maxConcurrentRequests, 1)
+	assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
+
+	wg.Add(routines / 2)
+	for i := 0; i < routines/2; i++ {
+		go func() {
+			m.decrementConcurrentRequests("0")
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+	assert.Equal(t, uint64(routines-routines/2), m.concurrentRequests["0"])
+	assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
+}
+
+func TestConcurrentRequestsMultiTunnel(t *testing.T) {
+	m.concurrentRequests = make(map[string]uint64)
+	m.maxConcurrentRequests = make(map[string]uint64)
+	tunnels := 20
+	var wg sync.WaitGroup
+	wg.Add(tunnels)
+	for i := 0; i < tunnels; i++ {
+		go func(i int) {
+			// tunnel i increments i+1 times, so even tunnel 0 calls incrementRequests
+			for j := 0; j < i+1; j++ {
+				id := strconv.Itoa(i)
+				m.incrementRequests(id)
+			}
+			wg.Done()
+		}(i)
+	}
+	wg.Wait()
+
+	assert.Len(t, m.concurrentRequests, tunnels)
+	assert.Len(t, m.maxConcurrentRequests, tunnels)
+	for i := 0; i < tunnels; i++ {
+		id := strconv.Itoa(i)
+		assert.Equal(t, uint64(i+1), m.concurrentRequests[id])
+		assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
+	}
+
+	wg.Add(tunnels)
+	for i := 0; i < tunnels; i++ {
+		go func(i int) {
+			for j := 0; j < i+1; j++ {
+				id := strconv.Itoa(i)
+				m.decrementConcurrentRequests(id)
+			}
+			wg.Done()
+		}(i)
+	}
+	wg.Wait()
+
+	assert.Len(t, m.concurrentRequests, tunnels)
+	assert.Len(t, m.maxConcurrentRequests, tunnels)
+	for i := 0; i < tunnels; i++ {
+		id := strconv.Itoa(i)
+		assert.Equal(t, uint64(0), m.concurrentRequests[id])
+		assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
+	}
+}
+
+func TestRegisterServerLocation(t *testing.T) {
+	tunnels := 20
+	var wg sync.WaitGroup
+	wg.Add(tunnels)
+	for i := 0; i < tunnels; i++ {
+		go func(i int) {
+			id := strconv.Itoa(i)
+			m.registerServerLocation(id, "LHR")
+			wg.Done()
+		}(i)
+	}
+	wg.Wait()
+	for i := 0; i < tunnels; i++ {
+		id := strconv.Itoa(i)
+		assert.Equal(t, "LHR", m.oldServerLocations[id])
+	}
+
+	wg.Add(tunnels)
+	for i := 0; i < tunnels; i++ {
+		go func(i int) {
+			id := strconv.Itoa(i)
+			m.registerServerLocation(id, "AUS")
+			wg.Done()
+		}(i)
+	}
+	wg.Wait()
+	for i := 0; i < tunnels; i++ {
+		id := strconv.Itoa(i)
+		assert.Equal(t, "AUS", m.oldServerLocations[id])
+	}
+}
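supervisor.go below leans on a `BackoffHandler` that is not part of this diff. A hedged reading of its contract, inferred only from the call sites in `Run` (the field and method names are exactly those used below; the timing behavior is an assumption based on the constants' comments):

```go
package origin

import "time"

// backoffContractSketch is illustrative; BackoffHandler itself lives
// elsewhere in the repository.
func backoffContractSketch() {
	backoff := BackoffHandler{MaxRetries: tunnelRetryLimit, BaseTime: tunnelRetryDuration, RetryForever: true}

	// BackoffTimer arms the next retry delay; with BaseTime = 1 minute the
	// periods presumably grow 1m, 2m, 4m, ... and, because RetryForever is
	// set, stay capped at 2^5 = 32 minutes rather than giving up.
	var timer <-chan time.Time = backoff.BackoffTimer()
	<-timer

	// SetGracePeriod resets the schedule once a connection stays healthy,
	// so the next failure starts again from BaseTime.
	backoff.SetGracePeriod()
}
```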
diff --git a/origin/supervisor.go b/origin/supervisor.go
new file mode 100644
index 00000000..b242c33c
--- /dev/null
+++ b/origin/supervisor.go
@@ -0,0 +1,174 @@
+package origin
+
+import (
+	"fmt"
+	"net"
+	"time"
+
+	log "github.com/Sirupsen/logrus"
+	"golang.org/x/net/context"
+)
+
+const (
+	// Waiting time before retrying a failed tunnel connection
+	tunnelRetryDuration = time.Minute
+	// Limit on the exponential backoff period (2^5 = 32 minutes)
+	tunnelRetryLimit = 5
+	// SRV record resolution TTL
+	resolveTTL = time.Hour
+)
+
+type Supervisor struct {
+	config              *TunnelConfig
+	edgeIPs             []*net.TCPAddr
+	lastResolve         time.Time
+	resolverC           chan resolveResult
+	tunnelErrors        chan tunnelError
+	tunnelsConnecting   map[int]chan struct{}
+	nextConnectedIndex  int
+	nextConnectedSignal chan struct{}
+}
+
+type resolveResult struct {
+	edgeIPs []*net.TCPAddr
+	err     error
+}
+
+type tunnelError struct {
+	index int
+	err   error
+}
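+// NewSupervisor creates a Supervisor that maintains config.HAConnections
+// concurrent tunnel connections to the edge.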
+func NewSupervisor(config *TunnelConfig) *Supervisor {
+	return &Supervisor{
+		config:            config,
+		tunnelErrors:      make(chan tunnelError),
+		tunnelsConnecting: map[int]chan struct{}{},
+	}
+}
+
+func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error {
+	if err := s.initialize(ctx, connectedSignal); err != nil {
+		return err
+	}
+	tunnelsActive := s.config.HAConnections
+	tunnelsWaiting := []int{}
+	backoff := BackoffHandler{MaxRetries: tunnelRetryLimit, BaseTime: tunnelRetryDuration, RetryForever: true}
+	var backoffTimer <-chan time.Time
+	for tunnelsActive > 0 {
+		select {
+		// Context cancelled
+		case <-ctx.Done():
+			for tunnelsActive > 0 {
+				<-s.tunnelErrors
+				tunnelsActive--
+			}
+			return nil
+		// startTunnel returned with error
+		// (note that this may also be caused by context cancellation)
+		case tunnelError := <-s.tunnelErrors:
+			tunnelsActive--
+			if tunnelError.err != nil {
+				log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
+				tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
+				s.waitForNextTunnel(tunnelError.index)
+				// Start the backoff timer if it isn't already running
+				if backoffTimer == nil {
+					backoffTimer = backoff.BackoffTimer()
+				}
+				s.refreshEdgeIPs()
+			}
+		// Backoff was set and its timer expired
+		case <-backoffTimer:
+			backoffTimer = nil
+			for _, index := range tunnelsWaiting {
+				go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
+			}
+			tunnelsActive += len(tunnelsWaiting)
+			tunnelsWaiting = nil
+		// Tunnel successfully connected
+		case <-s.nextConnectedSignal:
+			if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
+				// No more tunnels outstanding, clear backoff timer
+				backoff.SetGracePeriod()
+			}
+		// DNS resolution returned
+		case result := <-s.resolverC:
+			s.lastResolve = time.Now()
+			s.resolverC = nil
+			if result.err == nil {
+				log.Debug("Service discovery refresh complete")
+				s.edgeIPs = result.edgeIPs
+			} else {
+				log.WithError(result.err).Error("Service discovery error")
+			}
+		}
+	}
+	return fmt.Errorf("all tunnels terminated")
+}
+
+func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
+	edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
+	if err != nil {
+		return err
+	}
+	s.edgeIPs = edgeIPs
+	s.lastResolve = time.Now()
+	go s.startTunnel(ctx, 0, connectedSignal)
+	select {
+	case <-ctx.Done():
+		<-s.tunnelErrors
+		return nil
+	case tunnelError := <-s.tunnelErrors:
+		return tunnelError.err
+	case <-connectedSignal:
+	}
+	// At least one successful connection, so start the rest
+	for i := 1; i < s.config.HAConnections; i++ {
+		go s.startTunnel(ctx, i, make(chan struct{}))
+	}
+	return nil
+}
+
+// startTunnel starts a new tunnel connection. The resulting error will be sent on
+// s.tunnelErrors.
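+// It blocks until the tunnel terminates, which is why every caller launches
+// it in its own goroutine.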
+func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}) {
+	err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal)
+	s.tunnelErrors <- tunnelError{index: index, err: err}
+}
+
+// newConnectedTunnelSignal makes a fresh connected-signal channel for the
+// given tunnel index and marks it as the next one to watch.
+func (s *Supervisor) newConnectedTunnelSignal(index int) chan struct{} {
+	signal := make(chan struct{})
+	s.tunnelsConnecting[index] = signal
+	s.nextConnectedSignal = signal
+	s.nextConnectedIndex = index
+	return signal
+}
+
+// waitForNextTunnel drops the given index from the connecting set and, if any
+// tunnels are still connecting, picks an arbitrary one to watch next. It
+// reports whether such a tunnel exists.
+func (s *Supervisor) waitForNextTunnel(index int) bool {
+	delete(s.tunnelsConnecting, index)
+	s.nextConnectedSignal = nil
+	for k, v := range s.tunnelsConnecting {
+		s.nextConnectedIndex = k
+		s.nextConnectedSignal = v
+		return true
+	}
+	return false
+}
+
+func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr {
+	return s.edgeIPs[index%len(s.edgeIPs)]
+}
+
+// refreshEdgeIPs starts an asynchronous SRV re-resolution, at most once per
+// resolveTTL; the result is consumed by Run via s.resolverC.
+func (s *Supervisor) refreshEdgeIPs() {
+	if s.resolverC != nil {
+		return
+	}
+	if time.Since(s.lastResolve) < resolveTTL {
+		return
+	}
+	s.resolverC = make(chan resolveResult)
+	go func() {
+		edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
+		s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
+	}()
+}
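Finally, a hedged sketch of driving the supervisor from a caller. The `runUntilShutdown` wrapper and `shutdown` channel are assumptions; `Run`'s draining behavior on cancellation comes from the code above, and the sketch assumes `connectedSignal` is closed (not sent on) once the first tunnel connects:

```go
package origin

import "golang.org/x/net/context"

// runUntilShutdown is illustrative and not part of this diff.
func runUntilShutdown(config *TunnelConfig, shutdown <-chan struct{}) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	connectedSignal := make(chan struct{})
	errC := make(chan error, 1)
	go func() { errC <- NewSupervisor(config).Run(ctx, connectedSignal) }()

	select {
	case <-connectedSignal:
		// first tunnel is up; initialize has started the remaining HA connections
	case err := <-errC:
		return err // initial connection failed
	}

	<-shutdown
	cancel()
	// On cancellation, Run waits for every in-flight tunnel to report back
	// before returning nil, so this receive is bounded by tunnel teardown.
	return <-errC
}
```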