Add new files missed from release 2017.11.3
parent ee499dfa88
commit d40eb85da6
@@ -0,0 +1,82 @@
package origin

import (
	"fmt"
	"net"
)

const (
	// Used to discover HA Warp servers
	srvService = "warp"
	srvProto   = "tcp"
	srvName    = "cloudflarewarp.com"
)

func ResolveEdgeIPs(addresses []string) ([]*net.TCPAddr, error) {
	if len(addresses) > 0 {
		var tcpAddrs []*net.TCPAddr
		for _, address := range addresses {
			// Addresses specified (for testing, usually)
			tcpAddr, err := net.ResolveTCPAddr("tcp", address)
			if err != nil {
				return nil, err
			}
			tcpAddrs = append(tcpAddrs, tcpAddr)
		}
		return tcpAddrs, nil
	}
	// HA service discovery lookup
	_, addrs, err := net.LookupSRV(srvService, srvProto, srvName)
	if err != nil {
		return nil, err
	}
	var resolvedIPsPerCNAME [][]*net.TCPAddr
	var lookupErr error
	for _, addr := range addrs {
		ips, err := ResolveSRVToTCP(addr)
		if err != nil || len(ips) == 0 {
			// don't return early, we might be able to resolve other addresses
			lookupErr = err
			continue
		}
		resolvedIPsPerCNAME = append(resolvedIPsPerCNAME, ips)
	}
	ips := FlattenServiceIPs(resolvedIPsPerCNAME)
	if lookupErr == nil && len(ips) == 0 {
		return nil, fmt.Errorf("Unknown service discovery error")
	}
	return ips, lookupErr
}

func ResolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) {
	ips, err := net.LookupIP(srv.Target)
	if err != nil {
		return nil, err
	}
	addrs := make([]*net.TCPAddr, len(ips))
	for i, ip := range ips {
		addrs[i] = &net.TCPAddr{IP: ip, Port: int(srv.Port)}
	}
	return addrs, nil
}

// FlattenServiceIPs transposes and flattens the input slices such that the
// first elements of the n inner slices become the first n elements of the result.
func FlattenServiceIPs(ipsByService [][]*net.TCPAddr) []*net.TCPAddr {
	var result []*net.TCPAddr
	for len(ipsByService) > 0 {
		filtered := ipsByService[:0]
		for _, ips := range ipsByService {
			if len(ips) == 0 {
				// sanity check
				continue
			}
			result = append(result, ips[0])
			if len(ips) > 1 {
				filtered = append(filtered, ips[1:])
			}
		}
		ipsByService = filtered
	}
	return result
}
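For orientation, here is a minimal usage sketch of the discovery helpers above. It is not part of the commit: the import path is assumed from the h2mux import later in this diff, and the address is a made-up example. With a non-empty address list, ResolveEdgeIPs simply parses the given host:port pairs; with an empty list it falls back to the _warp._tcp.cloudflarewarp.com SRV lookup.

package main

import (
	"fmt"

	"github.com/cloudflare/cloudflare-warp/origin" // assumed import path
)

func main() {
	// Explicit addresses bypass SRV discovery; pass nil to use the SRV record instead.
	edges, err := origin.ResolveEdgeIPs([]string{"203.0.113.10:7844"}) // hypothetical edge address
	if err != nil {
		fmt.Println("edge resolution failed:", err)
		return
	}
	for _, addr := range edges {
		fmt.Println("edge:", addr)
	}
}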
@@ -0,0 +1,45 @@
package origin

import (
	"net"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestFlattenServiceIPs(t *testing.T) {
	result := FlattenServiceIPs([][]*net.TCPAddr{
		[]*net.TCPAddr{
			&net.TCPAddr{Port: 1},
			&net.TCPAddr{Port: 2},
			&net.TCPAddr{Port: 3},
			&net.TCPAddr{Port: 4},
		},
		[]*net.TCPAddr{
			&net.TCPAddr{Port: 10},
			&net.TCPAddr{Port: 12},
			&net.TCPAddr{Port: 13},
		},
		[]*net.TCPAddr{
			&net.TCPAddr{Port: 21},
			&net.TCPAddr{Port: 22},
			&net.TCPAddr{Port: 23},
			&net.TCPAddr{Port: 24},
			&net.TCPAddr{Port: 25},
		},
	})
	assert.EqualValues(t, []*net.TCPAddr{
		&net.TCPAddr{Port: 1},
		&net.TCPAddr{Port: 10},
		&net.TCPAddr{Port: 21},
		&net.TCPAddr{Port: 2},
		&net.TCPAddr{Port: 12},
		&net.TCPAddr{Port: 22},
		&net.TCPAddr{Port: 3},
		&net.TCPAddr{Port: 13},
		&net.TCPAddr{Port: 23},
		&net.TCPAddr{Port: 4},
		&net.TCPAddr{Port: 24},
		&net.TCPAddr{Port: 25},
	}, result)
}
@@ -0,0 +1,258 @@
package origin

import (
	"sync"

	"github.com/cloudflare/cloudflare-warp/h2mux"

	log "github.com/Sirupsen/logrus"
	"github.com/prometheus/client_golang/prometheus"
)

type TunnelMetrics struct {
	totalRequests     prometheus.Counter
	requestsPerTunnel *prometheus.CounterVec
	// concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests
	concurrentRequestsLock      sync.Mutex
	concurrentRequestsPerTunnel *prometheus.GaugeVec
	// concurrentRequests records the count of concurrent requests for each tunnel
	concurrentRequests             map[string]uint64
	maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
	// maxConcurrentRequests records the highest concurrent request count seen for each tunnel
	maxConcurrentRequests map[string]uint64
	rtt                   prometheus.Gauge
	rttMin                prometheus.Gauge
	rttMax                prometheus.Gauge
	timerRetries          prometheus.Gauge
	receiveWindowSizeAve  prometheus.Gauge
	sendWindowSizeAve     prometheus.Gauge
	receiveWindowSizeMin  prometheus.Gauge
	receiveWindowSizeMax  prometheus.Gauge
	sendWindowSizeMin     prometheus.Gauge
	sendWindowSizeMax     prometheus.Gauge
	responseByCode        *prometheus.CounterVec
	responseCodePerTunnel *prometheus.CounterVec
	serverLocations       *prometheus.GaugeVec
	// locationLock is a mutex for oldServerLocations
	locationLock sync.Mutex
	// oldServerLocations stores the last server the tunnel was connected to
	oldServerLocations map[string]string
}

// Metrics that can be collected without asking the edge
func NewTunnelMetrics() *TunnelMetrics {
	totalRequests := prometheus.NewCounter(
		prometheus.CounterOpts{
			Name: "total_requests",
			Help: "Number of requests proxied through all the tunnels",
		})
	prometheus.MustRegister(totalRequests)

	requestsPerTunnel := prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "requests_per_tunnel",
			Help: "Number of requests proxied through each tunnel",
		},
		[]string{"connection_id"},
	)
	prometheus.MustRegister(requestsPerTunnel)

	concurrentRequestsPerTunnel := prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "concurrent_requests_per_tunnel",
			Help: "Concurrent requests proxied through each tunnel",
		},
		[]string{"connection_id"},
	)
	prometheus.MustRegister(concurrentRequestsPerTunnel)

	maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "max_concurrent_requests_per_tunnel",
			Help: "Largest number of concurrent requests proxied through each tunnel so far",
		},
		[]string{"connection_id"},
	)
	prometheus.MustRegister(maxConcurrentRequestsPerTunnel)

	rtt := prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "rtt",
			Help: "Round-trip time",
		})
	prometheus.MustRegister(rtt)

	rttMin := prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "rtt_min",
			Help: "Shortest round-trip time",
		})
	prometheus.MustRegister(rttMin)

	rttMax := prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "rtt_max",
			Help: "Longest round-trip time",
		})
	prometheus.MustRegister(rttMax)

	timerRetries := prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "timer_retries",
			Help: "Unacknowledged heartbeat count",
		})
	prometheus.MustRegister(timerRetries)

	receiveWindowSizeAve := prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "receive_window_ave",
			Help: "Average receive window size",
		})
	prometheus.MustRegister(receiveWindowSizeAve)

	sendWindowSizeAve := prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "send_window_ave",
			Help: "Average send window size",
		})
	prometheus.MustRegister(sendWindowSizeAve)

	receiveWindowSizeMin := prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "receive_window_min",
			Help: "Smallest receive window size",
		})
	prometheus.MustRegister(receiveWindowSizeMin)

	receiveWindowSizeMax := prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "receive_window_max",
			Help: "Largest receive window size",
		})
	prometheus.MustRegister(receiveWindowSizeMax)

	sendWindowSizeMin := prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "send_window_min",
			Help: "Smallest send window size",
		})
	prometheus.MustRegister(sendWindowSizeMin)

	sendWindowSizeMax := prometheus.NewGauge(
		prometheus.GaugeOpts{
			Name: "send_window_max",
			Help: "Largest send window size",
		})
	prometheus.MustRegister(sendWindowSizeMax)

	responseByCode := prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "response_by_code",
			Help: "Count of responses by HTTP status code",
		},
		[]string{"status_code"},
	)
	prometheus.MustRegister(responseByCode)

	responseCodePerTunnel := prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name: "response_code_per_tunnel",
			Help: "Count of responses by HTTP status code for each tunnel",
		},
		[]string{"connection_id", "status_code"},
	)
	prometheus.MustRegister(responseCodePerTunnel)

	serverLocations := prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "server_locations",
			Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.",
		},
		[]string{"connection_id", "location"},
	)
	prometheus.MustRegister(serverLocations)

	return &TunnelMetrics{
		totalRequests:                  totalRequests,
		requestsPerTunnel:              requestsPerTunnel,
		concurrentRequestsPerTunnel:    concurrentRequestsPerTunnel,
		concurrentRequests:             make(map[string]uint64),
		maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
		maxConcurrentRequests:          make(map[string]uint64),
		rtt:                            rtt,
		rttMin:                         rttMin,
		rttMax:                         rttMax,
		timerRetries:                   timerRetries,
		receiveWindowSizeAve:           receiveWindowSizeAve,
		sendWindowSizeAve:              sendWindowSizeAve,
		receiveWindowSizeMin:           receiveWindowSizeMin,
		receiveWindowSizeMax:           receiveWindowSizeMax,
		sendWindowSizeMin:              sendWindowSizeMin,
		sendWindowSizeMax:              sendWindowSizeMax,
		responseByCode:                 responseByCode,
		responseCodePerTunnel:          responseCodePerTunnel,
		serverLocations:                serverLocations,
		oldServerLocations:             make(map[string]string),
	}
}

func (t *TunnelMetrics) updateTunnelFlowControlMetrics(metrics *h2mux.FlowControlMetrics) {
	t.receiveWindowSizeAve.Set(float64(metrics.AverageReceiveWindowSize))
	t.sendWindowSizeAve.Set(float64(metrics.AverageSendWindowSize))
	t.receiveWindowSizeMin.Set(float64(metrics.MinReceiveWindowSize))
	t.receiveWindowSizeMax.Set(float64(metrics.MaxReceiveWindowSize))
	t.sendWindowSizeMin.Set(float64(metrics.MinSendWindowSize))
	t.sendWindowSizeMax.Set(float64(metrics.MaxSendWindowSize))
}

func (t *TunnelMetrics) incrementRequests(connectionID string) {
	t.concurrentRequestsLock.Lock()
	var concurrentRequests uint64
	var ok bool
	if concurrentRequests, ok = t.concurrentRequests[connectionID]; ok {
		t.concurrentRequests[connectionID] += 1
		concurrentRequests++
	} else {
		t.concurrentRequests[connectionID] = 1
		concurrentRequests = 1
	}
	if maxConcurrentRequests, ok := t.maxConcurrentRequests[connectionID]; (ok && maxConcurrentRequests < concurrentRequests) || !ok {
		t.maxConcurrentRequests[connectionID] = concurrentRequests
		t.maxConcurrentRequestsPerTunnel.WithLabelValues(connectionID).Set(float64(concurrentRequests))
	}
	t.concurrentRequestsLock.Unlock()

	t.totalRequests.Inc()
	t.requestsPerTunnel.WithLabelValues(connectionID).Inc()
	t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
}

func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
	t.concurrentRequestsLock.Lock()
	if _, ok := t.concurrentRequests[connectionID]; ok {
		t.concurrentRequests[connectionID] -= 1
	} else {
		log.Error("Concurrent requests per tunnel metrics went wrong; you can't decrement the concurrent request count without incrementing it first.")
	}
	t.concurrentRequestsLock.Unlock()

	t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
}

func (t *TunnelMetrics) incrementResponses(connectionID, code string) {
	t.responseByCode.WithLabelValues(code).Inc()
	t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
}

func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) {
	t.locationLock.Lock()
	defer t.locationLock.Unlock()
	if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
		return
	} else if ok {
		t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec()
	}
	t.serverLocations.WithLabelValues(connectionID, loc).Inc()
	t.oldServerLocations[connectionID] = loc
}
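As a rough illustration of how the proxy code might drive these counters, here is a sketch only, not part of the commit: the helper name and status value are hypothetical, and it has to live in package origin because the methods are unexported.

// recordProxiedRequest is a hypothetical helper showing the intended call order:
// count the request, track concurrency for this connection, then record the
// response code once the origin has replied.
func recordProxiedRequest(metrics *TunnelMetrics, connectionID string) {
	metrics.incrementRequests(connectionID)
	defer metrics.decrementConcurrentRequests(connectionID)

	// ... proxy the request to the origin here ...
	statusCode := "200" // placeholder; would come from the origin's response

	metrics.incrementResponses(connectionID, statusCode)
}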
@@ -0,0 +1,121 @@
package origin

import (
	"strconv"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
)

// can only be called once
var m = NewTunnelMetrics()

func TestConcurrentRequestsSingleTunnel(t *testing.T) {
	routines := 20
	var wg sync.WaitGroup
	wg.Add(routines)
	for i := 0; i < routines; i++ {
		go func() {
			m.incrementRequests("0")
			wg.Done()
		}()
	}
	wg.Wait()
	assert.Len(t, m.concurrentRequests, 1)
	assert.Equal(t, uint64(routines), m.concurrentRequests["0"])
	assert.Len(t, m.maxConcurrentRequests, 1)
	assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])

	wg.Add(routines / 2)
	for i := 0; i < routines/2; i++ {
		go func() {
			m.decrementConcurrentRequests("0")
			wg.Done()
		}()
	}
	wg.Wait()
	assert.Equal(t, uint64(routines-routines/2), m.concurrentRequests["0"])
	assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
}

func TestConcurrentRequestsMultiTunnel(t *testing.T) {
	m.concurrentRequests = make(map[string]uint64)
	m.maxConcurrentRequests = make(map[string]uint64)
	tunnels := 20
	var wg sync.WaitGroup
	wg.Add(tunnels)
	for i := 0; i < tunnels; i++ {
		go func(i int) {
			// if we have j < i, then tunnel 0 won't have a chance to call incrementRequests
			for j := 0; j < i+1; j++ {
				id := strconv.Itoa(i)
				m.incrementRequests(id)
			}
			wg.Done()
		}(i)
	}
	wg.Wait()

	assert.Len(t, m.concurrentRequests, tunnels)
	assert.Len(t, m.maxConcurrentRequests, tunnels)
	for i := 0; i < tunnels; i++ {
		id := strconv.Itoa(i)
		assert.Equal(t, uint64(i+1), m.concurrentRequests[id])
		assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
	}

	wg.Add(tunnels)
	for i := 0; i < tunnels; i++ {
		go func(i int) {
			for j := 0; j < i+1; j++ {
				id := strconv.Itoa(i)
				m.decrementConcurrentRequests(id)
			}
			wg.Done()
		}(i)
	}
	wg.Wait()

	assert.Len(t, m.concurrentRequests, tunnels)
	assert.Len(t, m.maxConcurrentRequests, tunnels)
	for i := 0; i < tunnels; i++ {
		id := strconv.Itoa(i)
		assert.Equal(t, uint64(0), m.concurrentRequests[id])
		assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
	}
}

func TestRegisterServerLocation(t *testing.T) {
	tunnels := 20
	var wg sync.WaitGroup
	wg.Add(tunnels)
	for i := 0; i < tunnels; i++ {
		go func(i int) {
			id := strconv.Itoa(i)
			m.registerServerLocation(id, "LHR")
			wg.Done()
		}(i)
	}
	wg.Wait()
	for i := 0; i < tunnels; i++ {
		id := strconv.Itoa(i)
		assert.Equal(t, "LHR", m.oldServerLocations[id])
	}

	wg.Add(tunnels)
	for i := 0; i < tunnels; i++ {
		go func(i int) {
			id := strconv.Itoa(i)
			m.registerServerLocation(id, "AUS")
			wg.Done()
		}(i)
	}
	wg.Wait()
	for i := 0; i < tunnels; i++ {
		id := strconv.Itoa(i)
		assert.Equal(t, "AUS", m.oldServerLocations[id])
	}
}
@@ -0,0 +1,174 @@
package origin

import (
	"fmt"
	"net"
	"time"

	log "github.com/Sirupsen/logrus"
	"golang.org/x/net/context"
)

const (
	// Waiting time before retrying a failed tunnel connection
	tunnelRetryDuration = time.Minute
	// Limit on the exponential backoff time period. (2^5 = 32 minutes)
	tunnelRetryLimit = 5
	// SRV record resolution TTL
	resolveTTL = time.Hour
)

type Supervisor struct {
	config              *TunnelConfig
	edgeIPs             []*net.TCPAddr
	lastResolve         time.Time
	resolverC           chan resolveResult
	tunnelErrors        chan tunnelError
	tunnelsConnecting   map[int]chan struct{}
	nextConnectedIndex  int
	nextConnectedSignal chan struct{}
}

type resolveResult struct {
	edgeIPs []*net.TCPAddr
	err     error
}

type tunnelError struct {
	index int
	err   error
}

func NewSupervisor(config *TunnelConfig) *Supervisor {
	return &Supervisor{
		config:            config,
		tunnelErrors:      make(chan tunnelError),
		tunnelsConnecting: map[int]chan struct{}{},
	}
}

func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error {
	if err := s.initialize(ctx, connectedSignal); err != nil {
		return err
	}
	tunnelsActive := s.config.HAConnections
	tunnelsWaiting := []int{}
	backoff := BackoffHandler{MaxRetries: tunnelRetryLimit, BaseTime: tunnelRetryDuration, RetryForever: true}
	var backoffTimer <-chan time.Time
	for tunnelsActive > 0 {
		select {
		// Context cancelled
		case <-ctx.Done():
			for tunnelsActive > 0 {
				<-s.tunnelErrors
				tunnelsActive--
			}
			return nil
		// startTunnel returned with error
		// (note that this may also be caused by context cancellation)
		case tunnelError := <-s.tunnelErrors:
			tunnelsActive--
			if tunnelError.err != nil {
				log.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
				tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
				s.waitForNextTunnel(tunnelError.index)
				if backoffTimer == nil {
					backoffTimer = backoff.BackoffTimer()
				}
				s.refreshEdgeIPs()
			}
		// Backoff was set and its timer expired
		case <-backoffTimer:
			backoffTimer = nil
			for _, index := range tunnelsWaiting {
				go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
			}
			tunnelsActive += len(tunnelsWaiting)
			tunnelsWaiting = nil
		// Tunnel successfully connected
		case <-s.nextConnectedSignal:
			if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
				// No more tunnels outstanding, clear backoff timer
				backoff.SetGracePeriod()
			}
		// DNS resolution returned
		case result := <-s.resolverC:
			s.lastResolve = time.Now()
			s.resolverC = nil
			if result.err == nil {
				log.Debug("Service discovery refresh complete")
				s.edgeIPs = result.edgeIPs
			} else {
				log.WithError(result.err).Error("Service discovery error")
			}
		}
	}
	return fmt.Errorf("All tunnels terminated")
}

func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
	edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
	if err != nil {
		return err
	}
	s.edgeIPs = edgeIPs
	s.lastResolve = time.Now()
	go s.startTunnel(ctx, 0, connectedSignal)
	select {
	case <-ctx.Done():
		<-s.tunnelErrors
		return nil
	case tunnelError := <-s.tunnelErrors:
		return tunnelError.err
	case <-connectedSignal:
	}
	// At least one successful connection, so start the rest
	for i := 1; i < s.config.HAConnections; i++ {
		go s.startTunnel(ctx, i, make(chan struct{}))
	}
	return nil
}

// startTunnel starts a new tunnel connection. The resulting error will be sent on
// s.tunnelErrors.
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}) {
	err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal)
	s.tunnelErrors <- tunnelError{index: index, err: err}
}

func (s *Supervisor) newConnectedTunnelSignal(index int) chan struct{} {
	signal := make(chan struct{})
	s.tunnelsConnecting[index] = signal
	s.nextConnectedSignal = signal
	s.nextConnectedIndex = index
	return signal
}

func (s *Supervisor) waitForNextTunnel(index int) bool {
	delete(s.tunnelsConnecting, index)
	s.nextConnectedSignal = nil
	for k, v := range s.tunnelsConnecting {
		s.nextConnectedIndex = k
		s.nextConnectedSignal = v
		return true
	}
	return false
}

func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr {
	return s.edgeIPs[index%len(s.edgeIPs)]
}

func (s *Supervisor) refreshEdgeIPs() {
	if s.resolverC != nil {
		return
	}
	if time.Since(s.lastResolve) < resolveTTL {
		return
	}
	s.resolverC = make(chan resolveResult)
	go func() {
		edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
		s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
	}()
}
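Finally, a sketch of how a caller might run the Supervisor. It is not part of the diff: it assumes a TunnelConfig value built elsewhere in the package, and that the tunnel code signals the first successful connection on connectedSignal.

// runSupervisor is a hypothetical wrapper: Run blocks until the context is
// cancelled or every HA tunnel has terminated, restarting failed tunnels with
// exponential backoff and refreshing edge IPs via ResolveEdgeIPs as needed.
func runSupervisor(ctx context.Context, config *TunnelConfig) error {
	connectedSignal := make(chan struct{})
	supervisor := NewSupervisor(config)
	return supervisor.Run(ctx, connectedSignal)
}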