Add new files missed from release 2017.11.3

This commit is contained in:
Chris Branch 2017-11-29 14:43:31 +00:00
parent ee499dfa88
commit d40eb85da6
5 changed files with 680 additions and 0 deletions

origin/discovery.go Normal file

@@ -0,0 +1,82 @@
package origin
import (
"fmt"
"net"
)
const (
// SRV record constants used to discover HA Warp servers; net.LookupSRV below resolves _warp._tcp.cloudflarewarp.com
srvService = "warp"
srvProto = "tcp"
srvName = "cloudflarewarp.com"
)
func ResolveEdgeIPs(addresses []string) ([]*net.TCPAddr, error) {
if len(addresses) > 0 {
// Addresses specified (for testing, usually)
var tcpAddrs []*net.TCPAddr
for _, address := range addresses {
tcpAddr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return nil, err
}
tcpAddrs = append(tcpAddrs, tcpAddr)
}
return tcpAddrs, nil
}
// HA service discovery lookup
_, addrs, err := net.LookupSRV(srvService, srvProto, srvName)
if err != nil {
return nil, err
}
var resolvedIPsPerCNAME [][]*net.TCPAddr
var lookupErr error
for _, addr := range addrs {
ips, err := ResolveSRVToTCP(addr)
if err != nil || len(ips) == 0 {
// don't return early, we might be able to resolve other addresses
lookupErr = err
continue
}
resolvedIPsPerCNAME = append(resolvedIPsPerCNAME, ips)
}
ips := FlattenServiceIPs(resolvedIPsPerCNAME)
if lookupErr == nil && len(ips) == 0 {
return nil, fmt.Errorf("Unknown service discovery error")
}
return ips, lookupErr
}
func ResolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) {
ips, err := net.LookupIP(srv.Target)
if err != nil {
return nil, err
}
addrs := make([]*net.TCPAddr, len(ips))
for i, ip := range ips {
addrs[i] = &net.TCPAddr{IP: ip, Port: int(srv.Port)}
}
return addrs, nil
}
// FlattenServiceIPs transposes and flattens the input slices such that the
// first element of the n inner slices are the first n elements of the result.
func FlattenServiceIPs(ipsByService [][]*net.TCPAddr) []*net.TCPAddr {
var result []*net.TCPAddr
for len(ipsByService) > 0 {
filtered := ipsByService[:0]
for _, ips := range ipsByService {
if len(ips) == 0 {
// sanity check
continue
}
result = append(result, ips[0])
if len(ips) > 1 {
filtered = append(filtered, ips[1:])
}
}
ipsByService = filtered
}
return result
}
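A brief usage sketch (not part of this commit) may help orient readers of discovery.go: the caller either passes explicit addresses, which bypasses discovery, or passes none and relies on the SRV lookup. The package import path, address, and port below are assumptions for illustration only.

// Illustrative caller, assuming the package lives at
// github.com/cloudflare/cloudflare-warp/origin; the address and port are made up.
package main

import (
	"fmt"
	"log"

	"github.com/cloudflare/cloudflare-warp/origin"
)

func main() {
	// Explicit addresses skip SRV discovery (the path used mainly for testing).
	addrs, err := origin.ResolveEdgeIPs([]string{"198.51.100.1:7844"})
	if err != nil {
		log.Fatal(err)
	}
	for _, a := range addrs {
		fmt.Println(a.String())
	}

	// An empty slice falls back to the _warp._tcp.cloudflarewarp.com SRV lookup,
	// resolving each SRV target and flattening the results across CNAMEs.
	if addrs, err = origin.ResolveEdgeIPs(nil); err != nil {
		log.Println("service discovery failed:", err)
	} else {
		fmt.Println("discovered", len(addrs), "edge addresses")
	}
}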

origin/discovery_test.go Normal file

@@ -0,0 +1,45 @@
package origin
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestFlattenServiceIPs(t *testing.T) {
result := FlattenServiceIPs([][]*net.TCPAddr{
[]*net.TCPAddr{
&net.TCPAddr{Port: 1},
&net.TCPAddr{Port: 2},
&net.TCPAddr{Port: 3},
&net.TCPAddr{Port: 4},
},
[]*net.TCPAddr{
&net.TCPAddr{Port: 10},
&net.TCPAddr{Port: 12},
&net.TCPAddr{Port: 13},
},
[]*net.TCPAddr{
&net.TCPAddr{Port: 21},
&net.TCPAddr{Port: 22},
&net.TCPAddr{Port: 23},
&net.TCPAddr{Port: 24},
&net.TCPAddr{Port: 25},
},
})
assert.EqualValues(t, []*net.TCPAddr{
&net.TCPAddr{Port: 1},
&net.TCPAddr{Port: 10},
&net.TCPAddr{Port: 21},
&net.TCPAddr{Port: 2},
&net.TCPAddr{Port: 12},
&net.TCPAddr{Port: 22},
&net.TCPAddr{Port: 3},
&net.TCPAddr{Port: 13},
&net.TCPAddr{Port: 23},
&net.TCPAddr{Port: 4},
&net.TCPAddr{Port: 24},
&net.TCPAddr{Port: 25},
}, result)
}

origin/metrics.go Normal file

@@ -0,0 +1,258 @@
package origin
import (
"sync"
"github.com/cloudflare/cloudflare-warp/h2mux"
log "github.com/Sirupsen/logrus"
"github.com/prometheus/client_golang/prometheus"
)
type TunnelMetrics struct {
totalRequests prometheus.Counter
requestsPerTunnel *prometheus.CounterVec
// concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests
concurrentRequestsLock sync.Mutex
concurrentRequestsPerTunnel *prometheus.GaugeVec
// concurrentRequests records count of concurrent requests for each tunnel
concurrentRequests map[string]uint64
maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
// maxConcurrentRequests records the max count of concurrent requests for each tunnel
maxConcurrentRequests map[string]uint64
rtt prometheus.Gauge
rttMin prometheus.Gauge
rttMax prometheus.Gauge
timerRetries prometheus.Gauge
receiveWindowSizeAve prometheus.Gauge
sendWindowSizeAve prometheus.Gauge
receiveWindowSizeMin prometheus.Gauge
receiveWindowSizeMax prometheus.Gauge
sendWindowSizeMin prometheus.Gauge
sendWindowSizeMax prometheus.Gauge
responseByCode *prometheus.CounterVec
responseCodePerTunnel *prometheus.CounterVec
serverLocations *prometheus.GaugeVec
// locationLock is a mutex for oldServerLocations
locationLock sync.Mutex
// oldServerLocations stores the last server the tunnel was connected to
oldServerLocations map[string]string
}
// Metrics that can be collected without asking the edge
func NewTunnelMetrics() *TunnelMetrics {
totalRequests := prometheus.NewCounter(
prometheus.CounterOpts{
Name: "total_requests",
Help: "Amount of requests proxied through all the tunnels",
})
prometheus.MustRegister(totalRequests)
requestsPerTunnel := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "requests_per_tunnel",
Help: "Amount of requests proxied through each tunnel",
},
[]string{"connection_id"},
)
prometheus.MustRegister(requestsPerTunnel)
concurrentRequestsPerTunnel := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "concurrent_requests_per_tunnel",
Help: "Concurrent requests proxied through each tunnel",
},
[]string{"connection_id"},
)
prometheus.MustRegister(concurrentRequestsPerTunnel)
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "max_concurrent_requests_per_tunnel",
Help: "Largest number of concurrent requests proxied through each tunnel so far",
},
[]string{"connection_id"},
)
prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
rtt := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "rtt",
Help: "Round-trip time",
})
prometheus.MustRegister(rtt)
rttMin := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "rtt_min",
Help: "Shortest round-trip time",
})
prometheus.MustRegister(rttMin)
rttMax := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "rtt_max",
Help: "Longest round-trip time",
})
prometheus.MustRegister(rttMax)
timerRetries := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "timer_retries",
Help: "Unacknowledged heart beats count",
})
prometheus.MustRegister(timerRetries)
receiveWindowSizeAve := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "receive_window_ave",
Help: "Average receive window size",
})
prometheus.MustRegister(receiveWindowSizeAve)
sendWindowSizeAve := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "send_window_ave",
Help: "Average send window size",
})
prometheus.MustRegister(sendWindowSizeAve)
receiveWindowSizeMin := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "receive_window_min",
Help: "Smallest receive window size",
})
prometheus.MustRegister(receiveWindowSizeMin)
receiveWindowSizeMax := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "receive_window_max",
Help: "Largest receive window size",
})
prometheus.MustRegister(receiveWindowSizeMax)
sendWindowSizeMin := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "send_window_min",
Help: "Smallest send window size",
})
prometheus.MustRegister(sendWindowSizeMin)
sendWindowSizeMax := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "send_window_max",
Help: "Largest send window size",
})
prometheus.MustRegister(sendWindowSizeMax)
responseByCode := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "response_by_code",
Help: "Count of responses by HTTP status code",
},
[]string{"status_code"},
)
prometheus.MustRegister(responseByCode)
responseCodePerTunnel := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "response_code_per_tunnel",
Help: "Count of responses by HTTP status code fore each tunnel",
},
[]string{"connection_id", "status_code"},
)
prometheus.MustRegister(responseCodePerTunnel)
serverLocations := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "server_locations",
Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.",
},
[]string{"connection_id", "location"},
)
prometheus.MustRegister(serverLocations)
return &TunnelMetrics{
totalRequests: totalRequests,
requestsPerTunnel: requestsPerTunnel,
concurrentRequestsPerTunnel: concurrentRequestsPerTunnel,
concurrentRequests: make(map[string]uint64),
maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
maxConcurrentRequests: make(map[string]uint64),
rtt: rtt,
rttMin: rttMin,
rttMax: rttMax,
timerRetries: timerRetries,
receiveWindowSizeAve: receiveWindowSizeAve,
sendWindowSizeAve: sendWindowSizeAve,
receiveWindowSizeMin: receiveWindowSizeMin,
receiveWindowSizeMax: receiveWindowSizeMax,
sendWindowSizeMin: sendWindowSizeMin,
sendWindowSizeMax: sendWindowSizeMax,
responseByCode: responseByCode,
responseCodePerTunnel: responseCodePerTunnel,
serverLocations: serverLocations,
oldServerLocations: make(map[string]string),
}
}
func (t *TunnelMetrics) updateTunnelFlowControlMetrics(metrics *h2mux.FlowControlMetrics) {
t.receiveWindowSizeAve.Set(float64(metrics.AverageReceiveWindowSize))
t.sendWindowSizeAve.Set(float64(metrics.AverageSendWindowSize))
t.receiveWindowSizeMin.Set(float64(metrics.MinReceiveWindowSize))
t.receiveWindowSizeMax.Set(float64(metrics.MaxReceiveWindowSize))
t.sendWindowSizeMin.Set(float64(metrics.MinSendWindowSize))
t.sendWindowSizeMax.Set(float64(metrics.MaxSendWindowSize))
}
func (t *TunnelMetrics) incrementRequests(connectionID string) {
t.concurrentRequestsLock.Lock()
var concurrentRequests uint64
var ok bool
if concurrentRequests, ok = t.concurrentRequests[connectionID]; ok {
t.concurrentRequests[connectionID] += 1
concurrentRequests++
} else {
t.concurrentRequests[connectionID] = 1
concurrentRequests = 1
}
if maxConcurrentRequests, ok := t.maxConcurrentRequests[connectionID]; (ok && maxConcurrentRequests < concurrentRequests) || !ok {
t.maxConcurrentRequests[connectionID] = concurrentRequests
t.maxConcurrentRequestsPerTunnel.WithLabelValues(connectionID).Set(float64(concurrentRequests))
}
t.concurrentRequestsLock.Unlock()
t.totalRequests.Inc()
t.requestsPerTunnel.WithLabelValues(connectionID).Inc()
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
}
func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
t.concurrentRequestsLock.Lock()
if _, ok := t.concurrentRequests[connectionID]; ok {
t.concurrentRequests[connectionID] -= 1
} else {
log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.")
}
t.concurrentRequestsLock.Unlock()
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
}
func (t *TunnelMetrics) incrementResponses(connectionID, code string) {
t.responseByCode.WithLabelValues(code).Inc()
t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
}
func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) {
t.locationLock.Lock()
defer t.locationLock.Unlock()
if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
return
} else if ok {
t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec()
}
t.serverLocations.WithLabelValues(connectionID, loc).Inc()
t.oldServerLocations[connectionID] = loc
}
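As a hedged sketch of how these counters are intended to be driven, a caller in this package might bracket each proxied request as shown below; the handler shape and the status code handling are assumptions, not code from this commit.

// Illustrative only: wrap a proxied request with the metrics calls above.
// connectionID identifies the tunnel carrying the request; proxy is the
// assumed request-handling callback returning an HTTP status code string.
func trackProxiedRequest(metrics *TunnelMetrics, connectionID string, proxy func() (statusCode string)) {
	metrics.incrementRequests(connectionID)
	defer metrics.decrementConcurrentRequests(connectionID)

	// Proxy the request to the origin; the returned status code feeds the
	// per-code and per-tunnel response counters.
	code := proxy()
	metrics.incrementResponses(connectionID, code)
}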

origin/metrics_test.go Normal file

@@ -0,0 +1,121 @@
package origin
import (
"strconv"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
// can only be called once per process: NewTunnelMetrics registers its collectors with prometheus.MustRegister, which panics on duplicate registration
var m = NewTunnelMetrics()
func TestConcurrentRequestsSingleTunnel(t *testing.T) {
routines := 20
var wg sync.WaitGroup
wg.Add(routines)
for i := 0; i < routines; i++ {
go func() {
m.incrementRequests("0")
wg.Done()
}()
}
wg.Wait()
assert.Len(t, m.concurrentRequests, 1)
assert.Equal(t, uint64(routines), m.concurrentRequests["0"])
assert.Len(t, m.maxConcurrentRequests, 1)
assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
wg.Add(routines / 2)
for i := 0; i < routines/2; i++ {
go func() {
m.decrementConcurrentRequests("0")
wg.Done()
}()
}
wg.Wait()
assert.Equal(t, uint64(routines-routines/2), m.concurrentRequests["0"])
assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
}
func TestConcurrentRequestsMultiTunnel(t *testing.T) {
m.concurrentRequests = make(map[string]uint64)
m.maxConcurrentRequests = make(map[string]uint64)
tunnels := 20
var wg sync.WaitGroup
wg.Add(tunnels)
for i := 0; i < tunnels; i++ {
go func(i int) {
// loop with j < i+1 rather than j < i so that tunnel 0 also calls incrementRequests at least once
for j := 0; j < i+1; j++ {
id := strconv.Itoa(i)
m.incrementRequests(id)
}
wg.Done()
}(i)
}
wg.Wait()
assert.Len(t, m.concurrentRequests, tunnels)
assert.Len(t, m.maxConcurrentRequests, tunnels)
for i := 0; i < tunnels; i++ {
id := strconv.Itoa(i)
assert.Equal(t, uint64(i+1), m.concurrentRequests[id])
assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
}
wg.Add(tunnels)
for i := 0; i < tunnels; i++ {
go func(i int) {
for j := 0; j < i+1; j++ {
id := strconv.Itoa(i)
m.decrementConcurrentRequests(id)
}
wg.Done()
}(i)
}
wg.Wait()
assert.Len(t, m.concurrentRequests, tunnels)
assert.Len(t, m.maxConcurrentRequests, tunnels)
for i := 0; i < tunnels; i++ {
id := strconv.Itoa(i)
assert.Equal(t, uint64(0), m.concurrentRequests[id])
assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
}
}
func TestRegisterServerLocation(t *testing.T) {
tunnels := 20
var wg sync.WaitGroup
wg.Add(tunnels)
for i := 0; i < tunnels; i++ {
go func(i int) {
id := strconv.Itoa(i)
m.registerServerLocation(id, "LHR")
wg.Done()
}(i)
}
wg.Wait()
for i := 0; i < tunnels; i++ {
id := strconv.Itoa(i)
assert.Equal(t, "LHR", m.oldServerLocations[id])
}
wg.Add(tunnels)
for i := 0; i < tunnels; i++ {
go func(i int) {
id := strconv.Itoa(i)
m.registerServerLocation(id, "AUS")
wg.Done()
}(i)
}
wg.Wait()
for i := 0; i < tunnels; i++ {
id := strconv.Itoa(i)
assert.Equal(t, "AUS", m.oldServerLocations[id])
}
}

origin/supervisor.go Normal file

@@ -0,0 +1,174 @@
package origin
import (
"fmt"
"net"
"time"
log "github.com/Sirupsen/logrus"
"golang.org/x/net/context"
)
const (
// Waiting time before retrying a failed tunnel connection
tunnelRetryDuration = time.Minute
// Limit on the exponential backoff time period. (2^5 = 32 minutes)
tunnelRetryLimit = 5
// SRV record resolution TTL
resolveTTL = time.Hour
)
type Supervisor struct {
config *TunnelConfig
edgeIPs []*net.TCPAddr
lastResolve time.Time
resolverC chan resolveResult
tunnelErrors chan tunnelError
tunnelsConnecting map[int]chan struct{}
nextConnectedIndex int
nextConnectedSignal chan struct{}
}
type resolveResult struct {
edgeIPs []*net.TCPAddr
err error
}
type tunnelError struct {
index int
err error
}
func NewSupervisor(config *TunnelConfig) *Supervisor {
return &Supervisor{
config: config,
tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{},
}
}
func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error {
if err := s.initialize(ctx, connectedSignal); err != nil {
return err
}
tunnelsActive := s.config.HAConnections
tunnelsWaiting := []int{}
backoff := BackoffHandler{MaxRetries: tunnelRetryLimit, BaseTime: tunnelRetryDuration, RetryForever: true}
var backoffTimer <-chan time.Time
for tunnelsActive > 0 {
select {
// Context cancelled
case <-ctx.Done():
for tunnelsActive > 0 {
<-s.tunnelErrors
tunnelsActive--
}
return nil
// startTunnel returned with error
// (note that this may also be caused by context cancellation)
case tunnelError := <-s.tunnelErrors:
tunnelsActive--
if tunnelError.err != nil {
log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
s.waitForNextTunnel(tunnelError.index)
if backoffTimer == nil {
backoffTimer = backoff.BackoffTimer()
}
s.refreshEdgeIPs()
}
// Backoff was set and its timer expired
case <-backoffTimer:
backoffTimer = nil
for _, index := range tunnelsWaiting {
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
}
tunnelsActive += len(tunnelsWaiting)
tunnelsWaiting = nil
// Tunnel successfully connected
case <-s.nextConnectedSignal:
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
// No more tunnels outstanding, clear backoff timer
backoff.SetGracePeriod()
}
// DNS resolution returned
case result := <-s.resolverC:
s.lastResolve = time.Now()
s.resolverC = nil
if result.err == nil {
log.Debug("Service discovery refresh complete")
s.edgeIPs = result.edgeIPs
} else {
log.WithError(result.err).Error("Service discovery error")
}
}
}
return fmt.Errorf("All tunnels terminated")
}
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
if err != nil {
return err
}
s.edgeIPs = edgeIPs
s.lastResolve = time.Now()
go s.startTunnel(ctx, 0, connectedSignal)
select {
case <-ctx.Done():
<-s.tunnelErrors
return nil
case tunnelError := <-s.tunnelErrors:
return tunnelError.err
case <-connectedSignal:
}
// At least one successful connection, so start the rest
for i := 1; i < s.config.HAConnections; i++ {
go s.startTunnel(ctx, i, make(chan struct{}))
}
return nil
}
// startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}) {
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal)
s.tunnelErrors <- tunnelError{index: index, err: err}
}
func (s *Supervisor) newConnectedTunnelSignal(index int) chan struct{} {
signal := make(chan struct{})
s.tunnelsConnecting[index] = signal
s.nextConnectedSignal = signal
s.nextConnectedIndex = index
return signal
}
func (s *Supervisor) waitForNextTunnel(index int) bool {
delete(s.tunnelsConnecting, index)
s.nextConnectedSignal = nil
for k, v := range s.tunnelsConnecting {
s.nextConnectedIndex = k
s.nextConnectedSignal = v
return true
}
return false
}
func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr {
return s.edgeIPs[index%len(s.edgeIPs)]
}
func (s *Supervisor) refreshEdgeIPs() {
if s.resolverC != nil {
return
}
if time.Since(s.lastResolve) < resolveTTL {
return
}
s.resolverC = make(chan resolveResult)
go func() {
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
}()
}
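Finally, a hedged sketch of how the Supervisor might be driven by its caller; the TunnelConfig construction and the shutdown handling are assumptions beyond what this commit shows.

// Illustrative only: run the supervisor until the context is cancelled.
// TunnelConfig fields are defined elsewhere in the repository and are not
// shown here.
func runSupervisor(config *TunnelConfig) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	connectedSignal := make(chan struct{})
	go func() {
		<-connectedSignal
		log.Info("First tunnel connected")
	}()

	return NewSupervisor(config).Run(ctx, connectedSignal)
}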