Add new files missed from release 2017.11.3
This commit is contained in:
parent
ee499dfa88
commit
d40eb85da6
|
@ -0,0 +1,82 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Used to discover HA Warp servers
|
||||||
|
srvService = "warp"
|
||||||
|
srvProto = "tcp"
|
||||||
|
srvName = "cloudflarewarp.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ResolveEdgeIPs(addresses []string) ([]*net.TCPAddr, error) {
|
||||||
|
if len(addresses) > 0 {
|
||||||
|
var tcpAddrs []*net.TCPAddr
|
||||||
|
for _, address := range addresses {
|
||||||
|
// Addresses specified (for testing, usually)
|
||||||
|
tcpAddr, err := net.ResolveTCPAddr("tcp", address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tcpAddrs = append(tcpAddrs, tcpAddr)
|
||||||
|
}
|
||||||
|
return tcpAddrs, nil
|
||||||
|
}
|
||||||
|
// HA service discovery lookup
|
||||||
|
_, addrs, err := net.LookupSRV(srvService, srvProto, srvName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var resolvedIPsPerCNAME [][]*net.TCPAddr
|
||||||
|
var lookupErr error
|
||||||
|
for _, addr := range addrs {
|
||||||
|
ips, err := ResolveSRVToTCP(addr)
|
||||||
|
if err != nil || len(ips) == 0 {
|
||||||
|
// don't return early, we might be able to resolve other addresses
|
||||||
|
lookupErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resolvedIPsPerCNAME = append(resolvedIPsPerCNAME, ips)
|
||||||
|
}
|
||||||
|
ips := FlattenServiceIPs(resolvedIPsPerCNAME)
|
||||||
|
if lookupErr == nil && len(ips) == 0 {
|
||||||
|
return nil, fmt.Errorf("Unknown service discovery error")
|
||||||
|
}
|
||||||
|
return ips, lookupErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) {
|
||||||
|
ips, err := net.LookupIP(srv.Target)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
addrs := make([]*net.TCPAddr, len(ips))
|
||||||
|
for i, ip := range ips {
|
||||||
|
addrs[i] = &net.TCPAddr{IP: ip, Port: int(srv.Port)}
|
||||||
|
}
|
||||||
|
return addrs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlattenServiceIPs transposes and flattens the input slices such that the
|
||||||
|
// first element of the n inner slices are the first n elements of the result.
|
||||||
|
func FlattenServiceIPs(ipsByService [][]*net.TCPAddr) []*net.TCPAddr {
|
||||||
|
var result []*net.TCPAddr
|
||||||
|
for len(ipsByService) > 0 {
|
||||||
|
filtered := ipsByService[:0]
|
||||||
|
for _, ips := range ipsByService {
|
||||||
|
if len(ips) == 0 {
|
||||||
|
// sanity check
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, ips[0])
|
||||||
|
if len(ips) > 1 {
|
||||||
|
filtered = append(filtered, ips[1:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ipsByService = filtered
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFlattenServiceIPs(t *testing.T) {
|
||||||
|
result := FlattenServiceIPs([][]*net.TCPAddr{
|
||||||
|
[]*net.TCPAddr{
|
||||||
|
&net.TCPAddr{Port: 1},
|
||||||
|
&net.TCPAddr{Port: 2},
|
||||||
|
&net.TCPAddr{Port: 3},
|
||||||
|
&net.TCPAddr{Port: 4},
|
||||||
|
},
|
||||||
|
[]*net.TCPAddr{
|
||||||
|
&net.TCPAddr{Port: 10},
|
||||||
|
&net.TCPAddr{Port: 12},
|
||||||
|
&net.TCPAddr{Port: 13},
|
||||||
|
},
|
||||||
|
[]*net.TCPAddr{
|
||||||
|
&net.TCPAddr{Port: 21},
|
||||||
|
&net.TCPAddr{Port: 22},
|
||||||
|
&net.TCPAddr{Port: 23},
|
||||||
|
&net.TCPAddr{Port: 24},
|
||||||
|
&net.TCPAddr{Port: 25},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert.EqualValues(t, []*net.TCPAddr{
|
||||||
|
&net.TCPAddr{Port: 1},
|
||||||
|
&net.TCPAddr{Port: 10},
|
||||||
|
&net.TCPAddr{Port: 21},
|
||||||
|
&net.TCPAddr{Port: 2},
|
||||||
|
&net.TCPAddr{Port: 12},
|
||||||
|
&net.TCPAddr{Port: 22},
|
||||||
|
&net.TCPAddr{Port: 3},
|
||||||
|
&net.TCPAddr{Port: 13},
|
||||||
|
&net.TCPAddr{Port: 23},
|
||||||
|
&net.TCPAddr{Port: 4},
|
||||||
|
&net.TCPAddr{Port: 24},
|
||||||
|
&net.TCPAddr{Port: 25},
|
||||||
|
}, result)
|
||||||
|
}
|
|
@ -0,0 +1,258 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflare-warp/h2mux"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TunnelMetrics struct {
|
||||||
|
totalRequests prometheus.Counter
|
||||||
|
requestsPerTunnel *prometheus.CounterVec
|
||||||
|
// concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests
|
||||||
|
concurrentRequestsLock sync.Mutex
|
||||||
|
concurrentRequestsPerTunnel *prometheus.GaugeVec
|
||||||
|
// concurrentRequests records count of concurrent requests for each tunnel
|
||||||
|
concurrentRequests map[string]uint64
|
||||||
|
maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
|
||||||
|
// concurrentRequests records max count of concurrent requests for each tunnel
|
||||||
|
maxConcurrentRequests map[string]uint64
|
||||||
|
rtt prometheus.Gauge
|
||||||
|
rttMin prometheus.Gauge
|
||||||
|
rttMax prometheus.Gauge
|
||||||
|
timerRetries prometheus.Gauge
|
||||||
|
receiveWindowSizeAve prometheus.Gauge
|
||||||
|
sendWindowSizeAve prometheus.Gauge
|
||||||
|
receiveWindowSizeMin prometheus.Gauge
|
||||||
|
receiveWindowSizeMax prometheus.Gauge
|
||||||
|
sendWindowSizeMin prometheus.Gauge
|
||||||
|
sendWindowSizeMax prometheus.Gauge
|
||||||
|
responseByCode *prometheus.CounterVec
|
||||||
|
responseCodePerTunnel *prometheus.CounterVec
|
||||||
|
serverLocations *prometheus.GaugeVec
|
||||||
|
// locationLock is a mutex for oldServerLocations
|
||||||
|
locationLock sync.Mutex
|
||||||
|
// oldServerLocations stores the last server the tunnel was connected to
|
||||||
|
oldServerLocations map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metrics that can be collected without asking the edge
|
||||||
|
func NewTunnelMetrics() *TunnelMetrics {
|
||||||
|
totalRequests := prometheus.NewCounter(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Name: "total_requests",
|
||||||
|
Help: "Amount of requests proxied through all the tunnels",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(totalRequests)
|
||||||
|
|
||||||
|
requestsPerTunnel := prometheus.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Name: "requests_per_tunnel",
|
||||||
|
Help: "Amount of requests proxied through each tunnel",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(requestsPerTunnel)
|
||||||
|
|
||||||
|
concurrentRequestsPerTunnel := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "concurrent_requests_per_tunnel",
|
||||||
|
Help: "Concurrent requests proxied through each tunnel",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(concurrentRequestsPerTunnel)
|
||||||
|
|
||||||
|
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "max_concurrent_requests_per_tunnel",
|
||||||
|
Help: "Largest number of concurrent requests proxied through each tunnel so far",
|
||||||
|
},
|
||||||
|
[]string{"connection_id"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
|
||||||
|
|
||||||
|
rtt := prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "rtt",
|
||||||
|
Help: "Round-trip time",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(rtt)
|
||||||
|
|
||||||
|
rttMin := prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "rtt_min",
|
||||||
|
Help: "Shortest round-trip time",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(rttMin)
|
||||||
|
|
||||||
|
rttMax := prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "rtt_max",
|
||||||
|
Help: "Longest round-trip time",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(rttMax)
|
||||||
|
|
||||||
|
timerRetries := prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "timer_retries",
|
||||||
|
Help: "Unacknowledged heart beats count",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(timerRetries)
|
||||||
|
|
||||||
|
receiveWindowSizeAve := prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "receive_window_ave",
|
||||||
|
Help: "Average receive window size",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(receiveWindowSizeAve)
|
||||||
|
|
||||||
|
sendWindowSizeAve := prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "send_window_ave",
|
||||||
|
Help: "Average send window size",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(sendWindowSizeAve)
|
||||||
|
|
||||||
|
receiveWindowSizeMin := prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "receive_window_min",
|
||||||
|
Help: "Smallest receive window size",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(receiveWindowSizeMin)
|
||||||
|
|
||||||
|
receiveWindowSizeMax := prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "receive_window_max",
|
||||||
|
Help: "Largest receive window size",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(receiveWindowSizeMax)
|
||||||
|
|
||||||
|
sendWindowSizeMin := prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "send_window_min",
|
||||||
|
Help: "Smallest send window size",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(sendWindowSizeMin)
|
||||||
|
|
||||||
|
sendWindowSizeMax := prometheus.NewGauge(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "send_window_max",
|
||||||
|
Help: "Largest send window size",
|
||||||
|
})
|
||||||
|
prometheus.MustRegister(sendWindowSizeMax)
|
||||||
|
|
||||||
|
responseByCode := prometheus.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Name: "response_by_code",
|
||||||
|
Help: "Count of responses by HTTP status code",
|
||||||
|
},
|
||||||
|
[]string{"status_code"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(responseByCode)
|
||||||
|
|
||||||
|
responseCodePerTunnel := prometheus.NewCounterVec(
|
||||||
|
prometheus.CounterOpts{
|
||||||
|
Name: "response_code_per_tunnel",
|
||||||
|
Help: "Count of responses by HTTP status code fore each tunnel",
|
||||||
|
},
|
||||||
|
[]string{"connection_id", "status_code"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(responseCodePerTunnel)
|
||||||
|
|
||||||
|
serverLocations := prometheus.NewGaugeVec(
|
||||||
|
prometheus.GaugeOpts{
|
||||||
|
Name: "server_locations",
|
||||||
|
Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.",
|
||||||
|
},
|
||||||
|
[]string{"connection_id", "location"},
|
||||||
|
)
|
||||||
|
prometheus.MustRegister(serverLocations)
|
||||||
|
|
||||||
|
return &TunnelMetrics{
|
||||||
|
totalRequests: totalRequests,
|
||||||
|
requestsPerTunnel: requestsPerTunnel,
|
||||||
|
concurrentRequestsPerTunnel: concurrentRequestsPerTunnel,
|
||||||
|
concurrentRequests: make(map[string]uint64),
|
||||||
|
maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
|
||||||
|
maxConcurrentRequests: make(map[string]uint64),
|
||||||
|
rtt: rtt,
|
||||||
|
rttMin: rttMin,
|
||||||
|
rttMax: rttMax,
|
||||||
|
timerRetries: timerRetries,
|
||||||
|
receiveWindowSizeAve: receiveWindowSizeAve,
|
||||||
|
sendWindowSizeAve: sendWindowSizeAve,
|
||||||
|
receiveWindowSizeMin: receiveWindowSizeMin,
|
||||||
|
receiveWindowSizeMax: receiveWindowSizeMax,
|
||||||
|
sendWindowSizeMin: sendWindowSizeMin,
|
||||||
|
sendWindowSizeMax: sendWindowSizeMax,
|
||||||
|
responseByCode: responseByCode,
|
||||||
|
responseCodePerTunnel: responseCodePerTunnel,
|
||||||
|
serverLocations: serverLocations,
|
||||||
|
oldServerLocations: make(map[string]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TunnelMetrics) updateTunnelFlowControlMetrics(metrics *h2mux.FlowControlMetrics) {
|
||||||
|
t.receiveWindowSizeAve.Set(float64(metrics.AverageReceiveWindowSize))
|
||||||
|
t.sendWindowSizeAve.Set(float64(metrics.AverageSendWindowSize))
|
||||||
|
t.receiveWindowSizeMin.Set(float64(metrics.MinReceiveWindowSize))
|
||||||
|
t.receiveWindowSizeMax.Set(float64(metrics.MaxReceiveWindowSize))
|
||||||
|
t.sendWindowSizeMin.Set(float64(metrics.MinSendWindowSize))
|
||||||
|
t.sendWindowSizeMax.Set(float64(metrics.MaxSendWindowSize))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TunnelMetrics) incrementRequests(connectionID string) {
|
||||||
|
t.concurrentRequestsLock.Lock()
|
||||||
|
var concurrentRequests uint64
|
||||||
|
var ok bool
|
||||||
|
if concurrentRequests, ok = t.concurrentRequests[connectionID]; ok {
|
||||||
|
t.concurrentRequests[connectionID] += 1
|
||||||
|
concurrentRequests++
|
||||||
|
} else {
|
||||||
|
t.concurrentRequests[connectionID] = 1
|
||||||
|
concurrentRequests = 1
|
||||||
|
}
|
||||||
|
if maxConcurrentRequests, ok := t.maxConcurrentRequests[connectionID]; (ok && maxConcurrentRequests < concurrentRequests) || !ok {
|
||||||
|
t.maxConcurrentRequests[connectionID] = concurrentRequests
|
||||||
|
t.maxConcurrentRequestsPerTunnel.WithLabelValues(connectionID).Set(float64(concurrentRequests))
|
||||||
|
}
|
||||||
|
t.concurrentRequestsLock.Unlock()
|
||||||
|
|
||||||
|
t.totalRequests.Inc()
|
||||||
|
t.requestsPerTunnel.WithLabelValues(connectionID).Inc()
|
||||||
|
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
|
||||||
|
t.concurrentRequestsLock.Lock()
|
||||||
|
if _, ok := t.concurrentRequests[connectionID]; ok {
|
||||||
|
t.concurrentRequests[connectionID] -= 1
|
||||||
|
} else {
|
||||||
|
log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement concurrent requests count without increment it first.")
|
||||||
|
}
|
||||||
|
t.concurrentRequestsLock.Unlock()
|
||||||
|
|
||||||
|
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TunnelMetrics) incrementResponses(connectionID, code string) {
|
||||||
|
t.responseByCode.WithLabelValues(code).Inc()
|
||||||
|
t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) {
|
||||||
|
t.locationLock.Lock()
|
||||||
|
defer t.locationLock.Unlock()
|
||||||
|
if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
|
||||||
|
return
|
||||||
|
} else if ok {
|
||||||
|
t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec()
|
||||||
|
}
|
||||||
|
t.serverLocations.WithLabelValues(connectionID, loc).Inc()
|
||||||
|
t.oldServerLocations[connectionID] = loc
|
||||||
|
}
|
|
@ -0,0 +1,121 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// can only be called once
|
||||||
|
var m = NewTunnelMetrics()
|
||||||
|
|
||||||
|
func TestConcurrentRequestsSingleTunnel(t *testing.T) {
|
||||||
|
routines := 20
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(routines)
|
||||||
|
for i := 0; i < routines; i++ {
|
||||||
|
go func() {
|
||||||
|
m.incrementRequests("0")
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
assert.Len(t, m.concurrentRequests, 1)
|
||||||
|
assert.Equal(t, uint64(routines), m.concurrentRequests["0"])
|
||||||
|
assert.Len(t, m.maxConcurrentRequests, 1)
|
||||||
|
assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
|
||||||
|
|
||||||
|
wg.Add(routines / 2)
|
||||||
|
for i := 0; i < routines/2; i++ {
|
||||||
|
go func() {
|
||||||
|
m.decrementConcurrentRequests("0")
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
assert.Equal(t, uint64(routines-routines/2), m.concurrentRequests["0"])
|
||||||
|
assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrentRequestsMultiTunnel(t *testing.T) {
|
||||||
|
m.concurrentRequests = make(map[string]uint64)
|
||||||
|
m.maxConcurrentRequests = make(map[string]uint64)
|
||||||
|
tunnels := 20
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(tunnels)
|
||||||
|
for i := 0; i < tunnels; i++ {
|
||||||
|
go func(i int) {
|
||||||
|
// if we have j < i, then tunnel 0 won't have a chance to call incrementRequests
|
||||||
|
for j := 0; j < i+1; j++ {
|
||||||
|
id := strconv.Itoa(i)
|
||||||
|
m.incrementRequests(id)
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
assert.Len(t, m.concurrentRequests, tunnels)
|
||||||
|
assert.Len(t, m.maxConcurrentRequests, tunnels)
|
||||||
|
for i := 0; i < tunnels; i++ {
|
||||||
|
id := strconv.Itoa(i)
|
||||||
|
assert.Equal(t, uint64(i+1), m.concurrentRequests[id])
|
||||||
|
assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(tunnels)
|
||||||
|
for i := 0; i < tunnels; i++ {
|
||||||
|
go func(i int) {
|
||||||
|
for j := 0; j < i+1; j++ {
|
||||||
|
id := strconv.Itoa(i)
|
||||||
|
m.decrementConcurrentRequests(id)
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
assert.Len(t, m.concurrentRequests, tunnels)
|
||||||
|
assert.Len(t, m.maxConcurrentRequests, tunnels)
|
||||||
|
for i := 0; i < tunnels; i++ {
|
||||||
|
id := strconv.Itoa(i)
|
||||||
|
assert.Equal(t, uint64(0), m.concurrentRequests[id])
|
||||||
|
assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterServerLocation(t *testing.T) {
|
||||||
|
tunnels := 20
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(tunnels)
|
||||||
|
for i := 0; i < tunnels; i++ {
|
||||||
|
go func(i int) {
|
||||||
|
id := strconv.Itoa(i)
|
||||||
|
m.registerServerLocation(id, "LHR")
|
||||||
|
wg.Done()
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
for i := 0; i < tunnels; i++ {
|
||||||
|
id := strconv.Itoa(i)
|
||||||
|
assert.Equal(t, "LHR", m.oldServerLocations[id])
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(tunnels)
|
||||||
|
for i := 0; i < tunnels; i++ {
|
||||||
|
go func(i int) {
|
||||||
|
id := strconv.Itoa(i)
|
||||||
|
m.registerServerLocation(id, "AUS")
|
||||||
|
wg.Done()
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
for i := 0; i < tunnels; i++ {
|
||||||
|
id := strconv.Itoa(i)
|
||||||
|
assert.Equal(t, "AUS", m.oldServerLocations[id])
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,174 @@
|
||||||
|
package origin
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Waiting time before retrying a failed tunnel connection
|
||||||
|
tunnelRetryDuration = time.Minute
|
||||||
|
// Limit on the exponential backoff time period. (2^5 = 32 minutes)
|
||||||
|
tunnelRetryLimit = 5
|
||||||
|
// SRV record resolution TTL
|
||||||
|
resolveTTL = time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
type Supervisor struct {
|
||||||
|
config *TunnelConfig
|
||||||
|
edgeIPs []*net.TCPAddr
|
||||||
|
lastResolve time.Time
|
||||||
|
resolverC chan resolveResult
|
||||||
|
tunnelErrors chan tunnelError
|
||||||
|
tunnelsConnecting map[int]chan struct{}
|
||||||
|
nextConnectedIndex int
|
||||||
|
nextConnectedSignal chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type resolveResult struct {
|
||||||
|
edgeIPs []*net.TCPAddr
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type tunnelError struct {
|
||||||
|
index int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSupervisor(config *TunnelConfig) *Supervisor {
|
||||||
|
return &Supervisor{
|
||||||
|
config: config,
|
||||||
|
tunnelErrors: make(chan tunnelError),
|
||||||
|
tunnelsConnecting: map[int]chan struct{}{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error {
|
||||||
|
if err := s.initialize(ctx, connectedSignal); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tunnelsActive := s.config.HAConnections
|
||||||
|
tunnelsWaiting := []int{}
|
||||||
|
backoff := BackoffHandler{MaxRetries: tunnelRetryLimit, BaseTime: tunnelRetryDuration, RetryForever: true}
|
||||||
|
var backoffTimer <-chan time.Time
|
||||||
|
for tunnelsActive > 0 {
|
||||||
|
select {
|
||||||
|
// Context cancelled
|
||||||
|
case <-ctx.Done():
|
||||||
|
for tunnelsActive > 0 {
|
||||||
|
<-s.tunnelErrors
|
||||||
|
tunnelsActive--
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
// startTunnel returned with error
|
||||||
|
// (note that this may also be caused by context cancellation)
|
||||||
|
case tunnelError := <-s.tunnelErrors:
|
||||||
|
tunnelsActive--
|
||||||
|
if tunnelError.err != nil {
|
||||||
|
log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
|
||||||
|
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
||||||
|
s.waitForNextTunnel(tunnelError.index)
|
||||||
|
if backoffTimer != nil {
|
||||||
|
backoffTimer = backoff.BackoffTimer()
|
||||||
|
}
|
||||||
|
s.refreshEdgeIPs()
|
||||||
|
}
|
||||||
|
// Backoff was set and its timer expired
|
||||||
|
case <-backoffTimer:
|
||||||
|
backoffTimer = nil
|
||||||
|
for _, index := range tunnelsWaiting {
|
||||||
|
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
|
||||||
|
}
|
||||||
|
tunnelsActive += len(tunnelsWaiting)
|
||||||
|
tunnelsWaiting = nil
|
||||||
|
// Tunnel successfully connected
|
||||||
|
case <-s.nextConnectedSignal:
|
||||||
|
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
|
||||||
|
// No more tunnels outstanding, clear backoff timer
|
||||||
|
backoff.SetGracePeriod()
|
||||||
|
}
|
||||||
|
// DNS resolution returned
|
||||||
|
case result := <-s.resolverC:
|
||||||
|
s.lastResolve = time.Now()
|
||||||
|
s.resolverC = nil
|
||||||
|
if result.err == nil {
|
||||||
|
log.Debug("Service discovery refresh complete")
|
||||||
|
s.edgeIPs = result.edgeIPs
|
||||||
|
} else {
|
||||||
|
log.WithError(result.err).Error("Service discovery error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("All tunnels terminated")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
|
||||||
|
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.edgeIPs = edgeIPs
|
||||||
|
s.lastResolve = time.Now()
|
||||||
|
go s.startTunnel(ctx, 0, connectedSignal)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
<-s.tunnelErrors
|
||||||
|
return nil
|
||||||
|
case tunnelError := <-s.tunnelErrors:
|
||||||
|
return tunnelError.err
|
||||||
|
case <-connectedSignal:
|
||||||
|
}
|
||||||
|
// At least one successful connection, so start the rest
|
||||||
|
for i := 1; i < s.config.HAConnections; i++ {
|
||||||
|
go s.startTunnel(ctx, i, make(chan struct{}))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
||||||
|
// s.tunnelErrors.
|
||||||
|
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}) {
|
||||||
|
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal)
|
||||||
|
s.tunnelErrors <- tunnelError{index: index, err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) newConnectedTunnelSignal(index int) chan struct{} {
|
||||||
|
signal := make(chan struct{})
|
||||||
|
s.tunnelsConnecting[index] = signal
|
||||||
|
s.nextConnectedSignal = signal
|
||||||
|
s.nextConnectedIndex = index
|
||||||
|
return signal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) waitForNextTunnel(index int) bool {
|
||||||
|
delete(s.tunnelsConnecting, index)
|
||||||
|
s.nextConnectedSignal = nil
|
||||||
|
for k, v := range s.tunnelsConnecting {
|
||||||
|
s.nextConnectedIndex = k
|
||||||
|
s.nextConnectedSignal = v
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr {
|
||||||
|
return s.edgeIPs[index%len(s.edgeIPs)]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Supervisor) refreshEdgeIPs() {
|
||||||
|
if s.resolverC != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if time.Since(s.lastResolve) < resolveTTL {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.resolverC = make(chan resolveResult)
|
||||||
|
go func() {
|
||||||
|
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
|
||||||
|
s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
|
||||||
|
}()
|
||||||
|
}
|
Loading…
Reference in New Issue