diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index c746a585..4c38aaaf 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -14,7 +14,7 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" "github.com/cloudflare/cloudflared/cmd/cloudflared/config" - "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/origin" "github.com/cloudflare/cloudflared/tlsconfig" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -280,13 +280,13 @@ func prepareTunnelConfig( }, nil } -func serviceDiscoverer(c *cli.Context, logger *logrus.Logger) (connection.EdgeServiceDiscoverer, error) { +func serviceDiscoverer(c *cli.Context, logger *logrus.Logger) (*edgediscovery.Edge, error) { // If --edge is specfied, resolve edge server addresses if len(c.StringSlice("edge")) > 0 { - return connection.NewEdgeHostnameResolver(c.StringSlice("edge")) + return edgediscovery.StaticEdge(logger, c.StringSlice("edge")) } // Otherwise lookup edge server addresses through service discovery - return connection.NewEdgeAddrResolver(logger) + return edgediscovery.ResolveEdge(logger) } func isRunningFromTerminal() bool { diff --git a/connection/connection.go b/connection/connection.go index ab318d25..b898403f 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -18,9 +18,11 @@ const ( ) type Connection struct { - id uuid.UUID - muxer *h2mux.Muxer - addr *net.TCPAddr + id uuid.UUID + muxer *h2mux.Muxer + addr *net.TCPAddr + isLongLived bool + longLivedID int } func newConnection(muxer *h2mux.Muxer, addr *net.TCPAddr) (*Connection, error) { diff --git a/connection/discovery.go b/connection/discovery.go deleted file mode 100644 index 18c6cad5..00000000 --- a/connection/discovery.go +++ /dev/null @@ -1,420 +0,0 @@ -package connection - -import ( - "context" - "crypto/tls" - "fmt" - "math/rand" - "net" - "sync" - "time" - - "github.com/pkg/errors" - "github.com/sirupsen/logrus" -) - -const ( - // Used to discover HA origintunneld servers - srvService = "origintunneld" - srvProto = "tcp" - srvName = "argotunnel.com" - - // Used to fallback to DoT when we can't use the default resolver to - // discover HA origintunneld servers (GitHub issue #75). - dotServerName = "cloudflare-dns.com" - dotServerAddr = "1.1.1.1:853" - dotTimeout = time.Duration(15 * time.Second) - - // SRV record resolution TTL - resolveEdgeAddrTTL = 1 * time.Hour - - subsystemEdgeAddrResolver = "edgeAddrResolver" -) - -// Redeclare network functions so they can be overridden in tests. -var ( - netLookupSRV = net.LookupSRV - netLookupIP = net.LookupIP -) - -// If the call to net.LookupSRV fails, try to fall back to DoT from Cloudflare directly. -// -// Note: Instead of DoT, we could also have used DoH. Either of these: -// - directly via the JSON API (https://1.1.1.1/dns-query?ct=application/dns-json&name=_origintunneld._tcp.argotunnel.com&type=srv) -// - indirectly via `tunneldns.NewUpstreamHTTPS()` -// But both of these cases miss out on a key feature from the stdlib: -// "The returned records are sorted by priority and randomized by weight within a priority." -// (https://golang.org/pkg/net/#Resolver.LookupSRV) -// Does this matter? I don't know. It may someday. Let's use DoT so we don't need to worry about it. -// See also: Go feature request for stdlib-supported DoH: https://github.com/golang/go/issues/27552 -var fallbackLookupSRV = lookupSRVWithDOT - -var friendlyDNSErrorLines = []string{ - `Please try the following things to diagnose this issue:`, - ` 1. ensure that argotunnel.com is returning "origintunneld" service records.`, - ` Run your system's equivalent of: dig srv _origintunneld._tcp.argotunnel.com`, - ` 2. ensure that your DNS resolver is not returning compressed SRV records.`, - ` See GitHub issue https://github.com/golang/go/issues/27546`, - ` For example, you could use Cloudflare's 1.1.1.1 as your resolver:`, - ` https://developers.cloudflare.com/1.1.1.1/setting-up-1.1.1.1/`, -} - -// EdgeServiceDiscoverer is an interface for looking up Cloudflare's edge network addresses -type EdgeServiceDiscoverer interface { - // Addr returns an unused address to connect to cloudflare's edge network. - // Before this method returns, the address will be removed from the pool of available addresses, - // so the caller can assume they have exclusive access to the address for tunneling purposes. - // The caller should remember to put it back via ReplaceAddr or MarkAddrBad. - Addr() (*net.TCPAddr, error) - // AnyAddr returns an address to connect to cloudflare's edge network. - // It may or may not be in active use for a tunnel. - // The caller should NOT return it via ReplaceAddr or MarkAddrBad! - AnyAddr() (*net.TCPAddr, error) - // ReplaceAddr is called when the address is no longer needed, e.g. due to a scaling-down of numHAConnections. - // It returns the address to the pool of available addresses. - ReplaceAddr(addr *net.TCPAddr) - // MarkAddrBad is called when there was a connectivity error for the address. - // It marks the address as unused but doesn't return it to the pool of available addresses. - MarkAddrBad(addr *net.TCPAddr) - // AvailableAddrs returns the number of addresses available for use - // (less those that have been marked bad). - AvailableAddrs() int - // Refresh rediscovers Cloudflare's edge network addresses. - // It resets the state of "bad" addresses but not those in active use. - Refresh() error -} - -// EdgeAddrResolver discovers the addresses of Cloudflare's edge network through SRV record. -// It implements EdgeServiceDiscoverer interface -type EdgeAddrResolver struct { - sync.Mutex - // HA regions - regions []*region - // Logger for noteworthy events - logger *logrus.Entry -} - -type region struct { - // Addresses that we expect will be in active use - addrs []*net.TCPAddr - // Addresses that are in active use. - // This is actually a set of net.TCPAddr's, but we can't make a map like - // map[net.TCPAddr]bool - // since net.TCPAddr contains a field of type net.IP and therefore it cannot be used as a map key. - // So instead we use map[string]*net.TCPAddr, where the keys are obtained by net.TCPAddr.String(). - // (We keep the "raw" *net.TCPAddr values for the convenience of AnyAddr(). If that method didn't - // exist, we wouldn't strictly need the values, and this could be a map[string]bool.) - inUse map[string]*net.TCPAddr - // Addresses that were discarded due to a network error. - // Not sure what we'll do with these, but it feels good to keep them around for now. - bad []*net.TCPAddr -} - -func NewEdgeAddrResolver(logger *logrus.Logger) (EdgeServiceDiscoverer, error) { - r := &EdgeAddrResolver{ - logger: logger.WithField("subsystem", subsystemEdgeAddrResolver), - } - if err := r.Refresh(); err != nil { - return nil, err - } - return r, nil -} - -func (r *EdgeAddrResolver) Addr() (*net.TCPAddr, error) { - r.Lock() - defer r.Unlock() - - // compute the largest region based on len(addrs) - var largestRegion *region - { - if len(r.regions) == 0 { - return nil, errors.New("No HA regions") - } - largestRegion = r.regions[0] - for _, region := range r.regions[1:] { - if len(region.addrs) > len(largestRegion.addrs) { - largestRegion = region - } - } - if len(largestRegion.addrs) == 0 { - return nil, errors.New("No IP address to claim") - } - } - - var addr *net.TCPAddr - addr, largestRegion.addrs = popAddr(largestRegion.addrs) - largestRegion.inUse[addr.String()] = addr - return addr, nil -} - -func (r *EdgeAddrResolver) AnyAddr() (*net.TCPAddr, error) { - r.Lock() - defer r.Unlock() - for _, region := range r.regions { - // return an unused addr - if len(region.addrs) > 0 { - return region.addrs[rand.Intn(len(region.addrs))], nil - } - // return an addr that's in use - for _, addr := range region.inUse { - return addr, nil - } - } - return nil, fmt.Errorf("No IP addresses") -} - -func (r *EdgeAddrResolver) ReplaceAddr(addr *net.TCPAddr) { - r.Lock() - defer r.Unlock() - addrString := addr.String() - for _, region := range r.regions { - if _, ok := region.inUse[addrString]; ok { - delete(region.inUse, addrString) - region.addrs = append(region.addrs, addr) - break - } - } -} - -func (r *EdgeAddrResolver) MarkAddrBad(addr *net.TCPAddr) { - r.Lock() - defer r.Unlock() - addrString := addr.String() - for _, region := range r.regions { - if _, ok := region.inUse[addrString]; ok { - delete(region.inUse, addrString) - region.bad = append(region.bad, addr) - break - } - } -} - -func (r *EdgeAddrResolver) AvailableAddrs() int { - r.Lock() - defer r.Unlock() - result := 0 - for _, region := range r.regions { - result += len(region.addrs) - } - return result -} - -func (r *EdgeAddrResolver) Refresh() error { - addrLists, err := EdgeDiscovery(r.logger) - if err != nil { - return err - } - - r.Lock() - defer r.Unlock() - inUse := allInUse(r.regions) - r.regions = makeHARegions(addrLists, inUse) - return nil -} - -// EdgeDiscovery implements HA service discovery lookup. -func EdgeDiscovery(logger *logrus.Entry) ([][]*net.TCPAddr, error) { - _, addrs, err := netLookupSRV(srvService, srvProto, srvName) - if err != nil { - _, fallbackAddrs, fallbackErr := fallbackLookupSRV(srvService, srvProto, srvName) - if fallbackErr != nil || len(fallbackAddrs) == 0 { - // use the original DNS error `err` in messages, not `fallbackErr` - logger.Errorln("Error looking up Cloudflare edge IPs: the DNS query failed:", err) - for _, s := range friendlyDNSErrorLines { - logger.Errorln(s) - } - return nil, errors.Wrapf(err, "Could not lookup srv records on _%v._%v.%v", srvService, srvProto, srvName) - } - // Accept the fallback results and keep going - addrs = fallbackAddrs - } - - var resolvedIPsPerCNAME [][]*net.TCPAddr - for _, addr := range addrs { - ips, err := resolveSRVToTCP(addr) - if err != nil { - return nil, err - } - resolvedIPsPerCNAME = append(resolvedIPsPerCNAME, ips) - } - - return resolvedIPsPerCNAME, nil -} - -func lookupSRVWithDOT(service, proto, name string) (cname string, addrs []*net.SRV, err error) { - // Inspiration: https://github.com/artyom/dot/blob/master/dot.go - r := &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, _ string, _ string) (net.Conn, error) { - var dialer net.Dialer - conn, err := dialer.DialContext(ctx, "tcp", dotServerAddr) - if err != nil { - return nil, err - } - tlsConfig := &tls.Config{ServerName: dotServerName} - return tls.Client(conn, tlsConfig), nil - }, - } - ctx, cancel := context.WithTimeout(context.Background(), dotTimeout) - defer cancel() - return r.LookupSRV(ctx, srvService, srvProto, srvName) -} - -func resolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) { - ips, err := netLookupIP(srv.Target) - if err != nil { - return nil, errors.Wrapf(err, "Couldn't resolve SRV record %v", srv) - } - if len(ips) == 0 { - return nil, fmt.Errorf("SRV record %v had no IPs", srv) - } - addrs := make([]*net.TCPAddr, len(ips)) - for i, ip := range ips { - addrs[i] = &net.TCPAddr{IP: ip, Port: int(srv.Port)} - } - return addrs, nil -} - -// EdgeHostnameResolver discovers the addresses of Cloudflare's edge network via a list of server hostnames. -// It implements EdgeServiceDiscoverer interface, and is used mainly for testing connectivity. -type EdgeHostnameResolver struct { - sync.Mutex - // hostnames of edge servers - hostnames []string - // Addrs to connect to cloudflare's edge network - addrs []*net.TCPAddr - // Addresses that are in active use. - // This is actually a set of net.TCPAddr's. We have to encode the keys - // with .String(), since net.TCPAddr contains a field of type net.IP and - // therefore it cannot be used as a map key - inUse map[string]*net.TCPAddr - // Addresses that were discarded due to a network error. - // Not sure what we'll do with these, but it feels good to keep them around for now. - bad []*net.TCPAddr -} - -func NewEdgeHostnameResolver(edgeHostnames []string) (EdgeServiceDiscoverer, error) { - r := &EdgeHostnameResolver{ - hostnames: edgeHostnames, - inUse: map[string]*net.TCPAddr{}, - } - if err := r.Refresh(); err != nil { - return nil, err - } - return r, nil -} - -func (r *EdgeHostnameResolver) Addr() (*net.TCPAddr, error) { - r.Lock() - defer r.Unlock() - if len(r.addrs) == 0 { - return nil, errors.New("No IP address to claim") - } - var addr *net.TCPAddr - addr, r.addrs = popAddr(r.addrs) - r.inUse[addr.String()] = addr - return addr, nil -} - -func (r *EdgeHostnameResolver) AnyAddr() (*net.TCPAddr, error) { - r.Lock() - defer r.Unlock() - // return an unused addr - if len(r.addrs) > 0 { - return r.addrs[rand.Intn(len(r.addrs))], nil - } - // return an addr that's in use - for _, addr := range r.inUse { - return addr, nil - } - return nil, errors.New("No IP addresses") -} - -func (r *EdgeHostnameResolver) ReplaceAddr(addr *net.TCPAddr) { - r.Lock() - defer r.Unlock() - delete(r.inUse, addr.String()) - r.addrs = append(r.addrs, addr) -} -func (r *EdgeHostnameResolver) MarkAddrBad(addr *net.TCPAddr) { - r.Lock() - defer r.Unlock() - delete(r.inUse, addr.String()) - r.bad = append(r.bad, addr) -} - -func (r *EdgeHostnameResolver) AvailableAddrs() int { - r.Lock() - defer r.Unlock() - return len(r.addrs) -} - -func (r *EdgeHostnameResolver) Refresh() error { - newAddrs, err := ResolveAddrs(r.hostnames) - if err != nil { - return err - } - r.Lock() - defer r.Unlock() - var notInUse []*net.TCPAddr - for _, newAddr := range newAddrs { - if _, ok := r.inUse[newAddr.String()]; !ok { - notInUse = append(notInUse, newAddr) - } - } - r.addrs = notInUse - r.bad = nil - return nil -} - -// Resolve TCP address given a list of addresses. Address can be a hostname, however, it will return at most one -// of the hostname's IP addresses -func ResolveAddrs(addrs []string) ([]*net.TCPAddr, error) { - var tcpAddrs []*net.TCPAddr - for _, addr := range addrs { - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - tcpAddrs = append(tcpAddrs, tcpAddr) - } - return tcpAddrs, nil -} - -// Compute total set of IP addresses in use. This is useful if the regions -// are returned in a different order, or if an IP address is assigned to -// a different region for some reasion. -func allInUse(regions []*region) map[string]*net.TCPAddr { - result := make(map[string]*net.TCPAddr) - for _, region := range regions { - for k, v := range region.inUse { - result[k] = v - } - } - return result -} - -func makeHARegions(addrLists [][]*net.TCPAddr, inUse map[string]*net.TCPAddr) (regions []*region) { - for _, addrList := range addrLists { - region := ®ion{inUse: map[string]*net.TCPAddr{}} - for _, addr := range addrList { - addrString := addr.String() - // No matter what region `addr` used to belong to, it's now a part - // of this region, so add it to this region's `inUse` map. - if _, ok := inUse[addrString]; ok { - region.inUse[addrString] = addr - } else { - region.addrs = append(region.addrs, addr) - } - } - regions = append(regions, region) - } - return -} - -func popAddr(addrs []*net.TCPAddr) (*net.TCPAddr, []*net.TCPAddr) { - first := addrs[0] - addrs[0] = nil // prevent memory leak - addrs = addrs[1:] - return first, addrs -} diff --git a/connection/discovery_test.go b/connection/discovery_test.go deleted file mode 100644 index c74ebe81..00000000 --- a/connection/discovery_test.go +++ /dev/null @@ -1,317 +0,0 @@ -package connection - -import ( - "net" - "sync" - "testing" - "testing/quick" - "time" - - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" -) - -func TestEdgeDiscovery(t *testing.T) { - mockAddrs := newMockAddrs(19, 2, 5) - netLookupSRV = mockNetLookupSRV(mockAddrs) - netLookupIP = mockNetLookupIP(mockAddrs) - - expectedAddrSet := map[string]bool{} - for _, addrs := range mockAddrs.addrMap { - for _, addr := range addrs { - expectedAddrSet[addr.String()] = true - } - } - - addrLists, err := EdgeDiscovery(logrus.New().WithFields(logrus.Fields{})) - assert.NoError(t, err) - actualAddrSet := map[string]bool{} - for _, addrs := range addrLists { - for _, addr := range addrs { - actualAddrSet[addr.String()] = true - } - } - - assert.Equal(t, expectedAddrSet, actualAddrSet) -} - -func TestAllInUse(t *testing.T) { - for _, testCase := range []struct { - regions []*region - expected map[string]*net.TCPAddr - }{ - { - regions: nil, - expected: map[string]*net.TCPAddr{}, - }, - { - regions: []*region{ - ®ion{inUse: map[string]*net.TCPAddr{}}, - ®ion{inUse: map[string]*net.TCPAddr{}}, - }, - expected: map[string]*net.TCPAddr{}, - }, - { - regions: []*region{ - ®ion{inUse: map[string]*net.TCPAddr{":1": &net.TCPAddr{Port: 1}}}, - ®ion{inUse: map[string]*net.TCPAddr{":4": &net.TCPAddr{Port: 4}}}, - }, - expected: map[string]*net.TCPAddr{":1": &net.TCPAddr{Port: 1}, ":4": &net.TCPAddr{Port: 4}}, - }, - } { - actual := allInUse(testCase.regions) - assert.Equal(t, testCase.expected, actual) - } -} - -func TestMakeRegions(t *testing.T) { - for _, testCase := range []struct { - addrList [][]*net.TCPAddr - inUse map[string]*net.TCPAddr - expected []*region - }{ - { - addrList: [][]*net.TCPAddr{}, - expected: nil, - }, - { - addrList: [][]*net.TCPAddr{ - []*net.TCPAddr{&net.TCPAddr{Port: 1}, &net.TCPAddr{Port: 2}}, - }, - expected: []*region{ - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 1}, &net.TCPAddr{Port: 2}}, inUse: map[string]*net.TCPAddr{}}, - }, - }, - { - addrList: [][]*net.TCPAddr{ - []*net.TCPAddr{&net.TCPAddr{Port: 1}, &net.TCPAddr{Port: 2}}, - []*net.TCPAddr{&net.TCPAddr{Port: 3}, &net.TCPAddr{Port: 4}}, - }, - expected: []*region{ - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 1}, &net.TCPAddr{Port: 2}}, inUse: map[string]*net.TCPAddr{}}, - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 3}, &net.TCPAddr{Port: 4}}, inUse: map[string]*net.TCPAddr{}}, - }, - }, - { - addrList: [][]*net.TCPAddr{ - []*net.TCPAddr{&net.TCPAddr{Port: 1}, &net.TCPAddr{Port: 2}}, - []*net.TCPAddr{&net.TCPAddr{Port: 3}, &net.TCPAddr{Port: 4}}, - }, - inUse: map[string]*net.TCPAddr{ - ":1": &net.TCPAddr{Port: 1}, - ":4": &net.TCPAddr{Port: 4}, - }, - expected: []*region{ - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 2}}, inUse: map[string]*net.TCPAddr{":1": &net.TCPAddr{Port: 1}}}, - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 3}}, inUse: map[string]*net.TCPAddr{":4": &net.TCPAddr{Port: 4}}}, - }, - }, - } { - actual := makeHARegions(testCase.addrList, testCase.inUse) - assert.Equal(t, testCase.expected, actual) - } -} - -func assertIsBalanced(t *testing.T, regions []*region) bool { - // Compute max(len(region.addrs) for region in regions) - // No region should have significantly fewer addresses than this - var longestAddrs int - { - longestAddrs = 0 - for _, region := range regions { - if l := len(region.addrs); l > longestAddrs { - longestAddrs = l - } - } - } - for _, region := range regions { - if len(region.addrs) == longestAddrs || len(region.addrs) == longestAddrs-1 { - continue - } - return assert.Fail(t, - "found a region with %v free addrs, while the longest addrs list is %v", - len(region.addrs), longestAddrs) - } - return true -} - -// Various end-to-end tests, run with quickcheck (i.e. the testing/quick package) -func TestEdgeAddrResolver(t *testing.T) { - concurrentReplacement := func(mockAddrs mockAddrs) bool { - netLookupSRV = mockNetLookupSRV(mockAddrs) - netLookupIP = mockNetLookupIP(mockAddrs) - - resolver, err := NewEdgeAddrResolver(logrus.New()) - if !assert.NoError(t, err) { - return false - } - assert.Equal(t, mockAddrs.numAddrs, resolver.AvailableAddrs(), - "every address should be initially available") - - // Create several goroutines to simulate HA connections that acquire - // and replace IP addresses. - var wg sync.WaitGroup - wg.Add(mockAddrs.numAddrs) - for i := 0; i < mockAddrs.numAddrs; i++ { - go func() { - defer wg.Done() - const reconnectionCount = 50 - for i := 0; i < reconnectionCount; i++ { - if resolver.AvailableAddrs() == 0 { - err = resolver.Refresh() - assert.NoError(t, err) - } - addr, err := resolver.Addr() - if !assert.NoError(t, err) { - return - } - time.Sleep(0) // allow some other goroutine to run - resolver.ReplaceAddr(addr) - time.Sleep(0) // allow some other goroutine to run - } - }() - } - wg.Wait() - assert.Equal(t, mockAddrs.numAddrs, resolver.AvailableAddrs(), - "every address should be available after replacement") - return !t.Failed() - } - - badAddrWithRefresh := func(mockAddrs mockAddrs) bool { - netLookupSRV = mockNetLookupSRV(mockAddrs) - netLookupIP = mockNetLookupIP(mockAddrs) - - resolver, err := NewEdgeAddrResolver(logrus.New()) - if !assert.NoError(t, err) { - return false - } - assert.Equal(t, mockAddrs.numAddrs, resolver.AvailableAddrs(), - "every address should be initially available") - - var addrs []*net.TCPAddr - for i := 0; i < mockAddrs.numAddrs; i++ { - assert.Equal(t, mockAddrs.numAddrs-i, resolver.AvailableAddrs()) - addr, err := resolver.Addr() - assert.NoError(t, err) - addrs = append(addrs, addr) - } - assert.Equal(t, 0, resolver.AvailableAddrs(), "all addresses should have been taken") - _, err = resolver.Addr() - assert.Error(t, err) - - anyAddr, err := resolver.AnyAddr() - assert.NoError(t, err, "should still be okay to call AnyAddr") - - resolver.MarkAddrBad(anyAddr) - - assert.Equal(t, 0, resolver.AvailableAddrs(), "all addresses should still be used") - _, err = resolver.Addr() - assert.Error(t, err, "all addresses should still be used") - - err = resolver.Refresh() - assert.NoError(t, err, "Refresh() should have worked") - - assert.Equal(t, 1, resolver.AvailableAddrs(), - "Refresh() should have reset the state of the 'bad' address") - addr, err := resolver.Addr() - assert.NoError(t, err) - assert.Equal(t, anyAddr, addr) - - _, err = resolver.Addr() - assert.Error(t, err, "all addresses should be used again") - - return !t.Failed() - } - - assert.NoError(t, quick.Check(concurrentReplacement, nil)) - assert.NoError(t, quick.Check(badAddrWithRefresh, nil)) -} - -// "White-box" test: runs Addr() and checks internal state -func TestEdgeAddrResolver_Addr(t *testing.T) { - e := &EdgeAddrResolver{regions: nil} - addr, err := e.Addr() - assert.Error(t, err) - - testRegions := func() []*region { - return []*region{ - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 1}}, inUse: map[string]*net.TCPAddr{":2": &net.TCPAddr{Port: 2}, ":3": &net.TCPAddr{Port: 3}}}, - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 4}, &net.TCPAddr{Port: 5}}, inUse: map[string]*net.TCPAddr{":6": &net.TCPAddr{Port: 6}}}, - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 7}, &net.TCPAddr{Port: 8}}, inUse: map[string]*net.TCPAddr{":9": &net.TCPAddr{Port: 9}}}, - } - } - e = &EdgeAddrResolver{regions: testRegions()} - addr, err = e.Addr() - assert.NoError(t, err) - assert.Equal(t, &net.TCPAddr{Port: 4}, addr) - var expected []*region - { - expected = testRegions() - expected[1].addrs = expected[1].addrs[1:] - expected[1].inUse[":4"] = &net.TCPAddr{Port: 4} - } - assert.Equal(t, expected, e.regions) -} - -// "White-box" test: runs AnyAddr() and checks internal state -func TestEdgeAddrResolver_AnyAddr(t *testing.T) { - e := &EdgeAddrResolver{regions: nil} - addr, err := e.AnyAddr() - assert.Error(t, err) - - e = &EdgeAddrResolver{regions: []*region{®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 1}}, inUse: map[string]*net.TCPAddr{":2": &net.TCPAddr{Port: 2}}}}} - addr, err = e.AnyAddr() - assert.NoError(t, err) - assert.Equal(t, &net.TCPAddr{Port: 1}, addr, "should have chosen the inactive address") - - e = &EdgeAddrResolver{regions: []*region{®ion{inUse: map[string]*net.TCPAddr{":1": &net.TCPAddr{Port: 1}}}}} - addr, err = e.AnyAddr() - assert.NoError(t, err) - assert.Equal(t, &net.TCPAddr{Port: 1}, addr, "should have chosen an active address rather than nothing") -} - -// "White-box" test: runs ReplaceAddr() and checks internal state -func TestEdgeAddrResolver_ReplaceAddr(t *testing.T) { - e := &EdgeAddrResolver{regions: nil} - e.ReplaceAddr(&net.TCPAddr{Port: 1}) // this shouldn't panic, I guess - - testRegions := func() []*region { - return []*region{ - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 1}}, inUse: map[string]*net.TCPAddr{":2": &net.TCPAddr{Port: 2}, ":3": &net.TCPAddr{Port: 3}}}, - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 4}, &net.TCPAddr{Port: 5}}, inUse: map[string]*net.TCPAddr{":6": &net.TCPAddr{Port: 6}}}, - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 7}, &net.TCPAddr{Port: 8}}, inUse: map[string]*net.TCPAddr{":9": &net.TCPAddr{Port: 9}}}, - } - } - e = &EdgeAddrResolver{regions: testRegions()} - e.ReplaceAddr(&net.TCPAddr{Port: 6}) - var expected []*region - { - expected = testRegions() - delete(expected[1].inUse, ":6") - expected[1].addrs = append(expected[1].addrs, &net.TCPAddr{Port: 6}) - } - assert.Equal(t, expected, e.regions) -} - -// "White-box" test: runs MarkAddrBad() and checks internal state -func TestEdgeAddrResolver_MarkAddrBad(t *testing.T) { - e := &EdgeAddrResolver{regions: nil} - e.ReplaceAddr(&net.TCPAddr{Port: 1}) // this shouldn't panic, I guess - - testRegions := func() []*region { - return []*region{ - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 1}}, inUse: map[string]*net.TCPAddr{":2": &net.TCPAddr{Port: 2}, ":3": &net.TCPAddr{Port: 3}}}, - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 4}, &net.TCPAddr{Port: 5}}, inUse: map[string]*net.TCPAddr{":6": &net.TCPAddr{Port: 6}}}, - ®ion{addrs: []*net.TCPAddr{&net.TCPAddr{Port: 7}, &net.TCPAddr{Port: 8}}, inUse: map[string]*net.TCPAddr{":9": &net.TCPAddr{Port: 9}}}, - } - } - e = &EdgeAddrResolver{regions: testRegions()} - e.MarkAddrBad(&net.TCPAddr{Port: 6}) - var expected []*region - { - expected = testRegions() - delete(expected[1].inUse, ":6") - expected[1].bad = append(expected[1].bad, &net.TCPAddr{Port: 6}) - } - assert.Equal(t, expected, e.regions) -} diff --git a/connection/manager.go b/connection/manager.go index b5ceacfe..eeac2e01 100644 --- a/connection/manager.go +++ b/connection/manager.go @@ -13,6 +13,7 @@ import ( "github.com/sirupsen/logrus" "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" + "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/streamhandler" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -35,7 +36,7 @@ type EdgeManager struct { // cloudflaredConfig is the cloudflared configuration that is determined when the process first starts cloudflaredConfig *CloudflaredConfig // serviceDiscoverer returns the next edge addr to connect to - serviceDiscoverer EdgeServiceDiscoverer + serviceDiscoverer *edgediscovery.Edge // state is attributes of ConnectionManager that can change during runtime. state *edgeManagerState @@ -73,7 +74,7 @@ func NewEdgeManager( edgeConnMgrConfigurable *EdgeManagerConfigurable, userCredential []byte, tlsConfig *tls.Config, - serviceDiscoverer EdgeServiceDiscoverer, + serviceDiscoverer *edgediscovery.Edge, cloudflaredConfig *CloudflaredConfig, logger *logrus.Logger, ) *EdgeManager { @@ -91,27 +92,29 @@ func NewEdgeManager( func (em *EdgeManager) Run(ctx context.Context) error { defer em.shutdown() - resolveEdgeIPTicker := time.Tick(resolveEdgeAddrTTL) + // Currently, declarative tunnels don't have any concept of a stable connection + // Each edge connection is transient and when it dies, it is replaced by a different one, + // not restarted. + // So in the future we should really change this so that n connections are stored individually + connIndex := 0 for { select { case <-ctx.Done(): return errors.Wrap(ctx.Err(), "EdgeConnectionManager terminated") - case <-resolveEdgeIPTicker: - if err := em.serviceDiscoverer.Refresh(); err != nil { - em.logger.WithError(err).Warn("Cannot refresh Cloudflare edge addresses") - } default: time.Sleep(1 * time.Second) } // Create/delete connection one at a time, so we don't need to adjust for connections that are being created/deleted // in shouldCreateConnection or shouldReduceConnection calculation if em.state.shouldCreateConnection(em.serviceDiscoverer.AvailableAddrs()) { - if connErr := em.newConnection(ctx); connErr != nil { + if connErr := em.newConnection(ctx, connIndex); connErr != nil { if !connErr.ShouldRetry { em.logger.WithError(connErr).Error(em.noRetryMessage()) return connErr } em.logger.WithError(connErr).Error("cannot create new connection") + } else { + connIndex++ } } else if em.state.shouldReduceConnection() { if err := em.closeConnection(ctx); err != nil { @@ -126,8 +129,8 @@ func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurab em.state.updateConfigurable(newConfigurable) } -func (em *EdgeManager) newConnection(ctx context.Context) *tunnelpogs.ConnectError { - edgeTCPAddr, err := em.serviceDiscoverer.Addr() +func (em *EdgeManager) newConnection(ctx context.Context, index int) *tunnelpogs.ConnectError { + edgeTCPAddr, err := em.serviceDiscoverer.GetAddr(index) if err != nil { return retryConnection(fmt.Sprintf("edge address discovery error: %v", err)) } @@ -197,7 +200,7 @@ func (em *EdgeManager) serveConn(ctx context.Context, conn *Connection) { err := conn.Serve(ctx) em.logger.WithError(err).Warn("Connection closed") em.state.closeConnection(conn) - em.serviceDiscoverer.ReplaceAddr(conn.addr) + em.serviceDiscoverer.GiveBack(conn.addr) } func (em *EdgeManager) noRetryMessage() string { diff --git a/connection/manager_test.go b/connection/manager_test.go index a8dd0e58..465fc951 100644 --- a/connection/manager_test.go +++ b/connection/manager_test.go @@ -1,6 +1,7 @@ package connection import ( + "net" "testing" "time" @@ -8,8 +9,8 @@ import ( "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" + "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/streamhandler" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -48,14 +49,15 @@ func mockEdgeManager() *EdgeManager { newConfigChan := make(chan<- *pogs.ClientConfig) useConfigResultChan := make(<-chan *pogs.UseConfigurationResult) logger := logrus.New() + edge := edgediscovery.MockEdge(logger, []*net.TCPAddr{}) return NewEdgeManager( streamhandler.NewStreamHandler(newConfigChan, useConfigResultChan, logger), configurable, []byte{}, nil, - &mockEdgeServiceDiscoverer{}, + edge, cloudflaredConfig, - logrus.New(), + logger, ) } diff --git a/edgediscovery/allregions/discovery.go b/edgediscovery/allregions/discovery.go new file mode 100644 index 00000000..b18e3df6 --- /dev/null +++ b/edgediscovery/allregions/discovery.go @@ -0,0 +1,135 @@ +package allregions + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +const ( + // Used to discover HA origintunneld servers + srvService = "origintunneld" + srvProto = "tcp" + srvName = "argotunnel.com" + + // Used to fallback to DoT when we can't use the default resolver to + // discover HA origintunneld servers (GitHub issue #75). + dotServerName = "cloudflare-dns.com" + dotServerAddr = "1.1.1.1:853" + dotTimeout = time.Duration(15 * time.Second) + + // SRV record resolution TTL + resolveEdgeAddrTTL = 1 * time.Hour + + subsystemEdgeAddrResolver = "edgeAddrResolver" +) + +// Redeclare network functions so they can be overridden in tests. +var ( + netLookupSRV = net.LookupSRV + netLookupIP = net.LookupIP +) + +// If the call to net.LookupSRV fails, try to fall back to DoT from Cloudflare directly. +// +// Note: Instead of DoT, we could also have used DoH. Either of these: +// - directly via the JSON API (https://1.1.1.1/dns-query?ct=application/dns-json&name=_origintunneld._tcp.argotunnel.com&type=srv) +// - indirectly via `tunneldns.NewUpstreamHTTPS()` +// But both of these cases miss out on a key feature from the stdlib: +// "The returned records are sorted by priority and randomized by weight within a priority." +// (https://golang.org/pkg/net/#Resolver.LookupSRV) +// Does this matter? I don't know. It may someday. Let's use DoT so we don't need to worry about it. +// See also: Go feature request for stdlib-supported DoH: https://github.com/golang/go/issues/27552 +var fallbackLookupSRV = lookupSRVWithDOT + +var friendlyDNSErrorLines = []string{ + `Please try the following things to diagnose this issue:`, + ` 1. ensure that argotunnel.com is returning "origintunneld" service records.`, + ` Run your system's equivalent of: dig srv _origintunneld._tcp.argotunnel.com`, + ` 2. ensure that your DNS resolver is not returning compressed SRV records.`, + ` See GitHub issue https://github.com/golang/go/issues/27546`, + ` For example, you could use Cloudflare's 1.1.1.1 as your resolver:`, + ` https://developers.cloudflare.com/1.1.1.1/setting-up-1.1.1.1/`, +} + +// EdgeDiscovery implements HA service discovery lookup. +func edgeDiscovery(logger *logrus.Entry) ([][]*net.TCPAddr, error) { + _, addrs, err := netLookupSRV(srvService, srvProto, srvName) + if err != nil { + _, fallbackAddrs, fallbackErr := fallbackLookupSRV(srvService, srvProto, srvName) + if fallbackErr != nil || len(fallbackAddrs) == 0 { + // use the original DNS error `err` in messages, not `fallbackErr` + logger.Errorln("Error looking up Cloudflare edge IPs: the DNS query failed:", err) + for _, s := range friendlyDNSErrorLines { + logger.Errorln(s) + } + return nil, errors.Wrapf(err, "Could not lookup srv records on _%v._%v.%v", srvService, srvProto, srvName) + } + // Accept the fallback results and keep going + addrs = fallbackAddrs + } + + var resolvedIPsPerCNAME [][]*net.TCPAddr + for _, addr := range addrs { + ips, err := resolveSRVToTCP(addr) + if err != nil { + return nil, err + } + resolvedIPsPerCNAME = append(resolvedIPsPerCNAME, ips) + } + + return resolvedIPsPerCNAME, nil +} + +func lookupSRVWithDOT(service, proto, name string) (cname string, addrs []*net.SRV, err error) { + // Inspiration: https://github.com/artyom/dot/blob/master/dot.go + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, _ string, _ string) (net.Conn, error) { + var dialer net.Dialer + conn, err := dialer.DialContext(ctx, "tcp", dotServerAddr) + if err != nil { + return nil, err + } + tlsConfig := &tls.Config{ServerName: dotServerName} + return tls.Client(conn, tlsConfig), nil + }, + } + ctx, cancel := context.WithTimeout(context.Background(), dotTimeout) + defer cancel() + return r.LookupSRV(ctx, srvService, srvProto, srvName) +} + +func resolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) { + ips, err := netLookupIP(srv.Target) + if err != nil { + return nil, errors.Wrapf(err, "Couldn't resolve SRV record %v", srv) + } + if len(ips) == 0 { + return nil, fmt.Errorf("SRV record %v had no IPs", srv) + } + addrs := make([]*net.TCPAddr, len(ips)) + for i, ip := range ips { + addrs[i] = &net.TCPAddr{IP: ip, Port: int(srv.Port)} + } + return addrs, nil +} + +// ResolveAddrs resolves TCP address given a list of addresses. Address can be a hostname, however, it will return at most one +// of the hostname's IP addresses +func ResolveAddrs(addrs []string) ([]*net.TCPAddr, error) { + var tcpAddrs []*net.TCPAddr + for _, addr := range addrs { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + tcpAddrs = append(tcpAddrs, tcpAddr) + } + return tcpAddrs, nil +} diff --git a/edgediscovery/allregions/discovery_test.go b/edgediscovery/allregions/discovery_test.go new file mode 100644 index 00000000..0ead20bc --- /dev/null +++ b/edgediscovery/allregions/discovery_test.go @@ -0,0 +1,32 @@ +package allregions + +import ( + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +func TestEdgeDiscovery(t *testing.T) { + mockAddrs := newMockAddrs(19, 2, 5) + netLookupSRV = mockNetLookupSRV(mockAddrs) + netLookupIP = mockNetLookupIP(mockAddrs) + + expectedAddrSet := map[string]bool{} + for _, addrs := range mockAddrs.addrMap { + for _, addr := range addrs { + expectedAddrSet[addr.String()] = true + } + } + + addrLists, err := edgeDiscovery(logrus.New().WithFields(logrus.Fields{})) + assert.NoError(t, err) + actualAddrSet := map[string]bool{} + for _, addrs := range addrLists { + for _, addr := range addrs { + actualAddrSet[addr.String()] = true + } + } + + assert.Equal(t, expectedAddrSet, actualAddrSet) +} diff --git a/edgediscovery/allregions/mocks_for_test.go b/edgediscovery/allregions/mocks_for_test.go new file mode 100644 index 00000000..3f0470a1 --- /dev/null +++ b/edgediscovery/allregions/mocks_for_test.go @@ -0,0 +1,89 @@ +package allregions + +import ( + "fmt" + "math" + "math/rand" + "net" + "reflect" + "testing/quick" +) + +type mockAddrs struct { + // a set of synthetic SRV records + addrMap map[net.SRV][]*net.TCPAddr + // the total number of addresses, aggregated across addrMap. + // For the convenience of test code that would otherwise have to compute + // this by hand every time. + numAddrs int +} + +func newMockAddrs(port uint16, numRegions uint8, numAddrsPerRegion uint8) mockAddrs { + addrMap := make(map[net.SRV][]*net.TCPAddr) + numAddrs := 0 + + for r := uint8(0); r < numRegions; r++ { + var ( + srv = net.SRV{Target: fmt.Sprintf("test-region-%v.example.com", r), Port: port} + addrs []*net.TCPAddr + ) + for a := uint8(0); a < numAddrsPerRegion; a++ { + addrs = append(addrs, &net.TCPAddr{ + IP: net.ParseIP(fmt.Sprintf("10.0.%v.%v", r, a)), + Port: int(port), + }) + } + addrMap[srv] = addrs + numAddrs += len(addrs) + } + return mockAddrs{addrMap: addrMap, numAddrs: numAddrs} +} + +var _ quick.Generator = mockAddrs{} + +func (mockAddrs) Generate(rand *rand.Rand, size int) reflect.Value { + port := uint16(rand.Intn(math.MaxUint16)) + numRegions := uint8(1 + rand.Intn(10)) + numAddrsPerRegion := uint8(1 + rand.Intn(32)) + result := newMockAddrs(port, numRegions, numAddrsPerRegion) + return reflect.ValueOf(result) +} + +// Returns a function compatible with net.LookupSRV that will return the SRV +// records from mockAddrs. +func mockNetLookupSRV( + m mockAddrs, +) func(service, proto, name string) (cname string, addrs []*net.SRV, err error) { + var addrs []*net.SRV + for k := range m.addrMap { + addr := k + addrs = append(addrs, &addr) + // We can't just do + // addrs = append(addrs, &k) + // `k` will be reused by subsequent loop iterations, + // so all the copies of `&k` would point to the same location. + } + return func(_, _, _ string) (string, []*net.SRV, error) { + return "", addrs, nil + } +} + +// Returns a function compatible with net.LookupIP that translates the SRV records +// from mockAddrs into IP addresses, based on the TCP addresses in mockAddrs. +func mockNetLookupIP( + m mockAddrs, +) func(host string) ([]net.IP, error) { + return func(host string) ([]net.IP, error) { + for srv, tcpAddrs := range m.addrMap { + if srv.Target != host { + continue + } + result := make([]net.IP, len(tcpAddrs)) + for i, tcpAddr := range tcpAddrs { + result[i] = tcpAddr.IP + } + return result, nil + } + return nil, fmt.Errorf("No IPs for %v", host) + } +} diff --git a/edgediscovery/allregions/region.go b/edgediscovery/allregions/region.go new file mode 100644 index 00000000..31fb7311 --- /dev/null +++ b/edgediscovery/allregions/region.go @@ -0,0 +1,78 @@ +package allregions + +import ( + "net" +) + +// Region contains cloudflared edge addresses. The edge is partitioned into several regions for +// redundancy purposes. +type Region struct { + connFor map[*net.TCPAddr]UsedBy +} + +// NewRegion creates a region with the given addresses, which are all unused. +func NewRegion(addrs []*net.TCPAddr) Region { + // The zero value of UsedBy is Unused(), so we can just initialize the map's values with their + // zero values. + m := make(map[*net.TCPAddr]UsedBy) + for _, addr := range addrs { + m[addr] = Unused() + } + return Region{connFor: m} +} + +// AddrUsedBy finds the address used by the given connection in this region. +// Returns nil if the connection isn't using any IP. +func (r *Region) AddrUsedBy(connID int) *net.TCPAddr { + for addr, used := range r.connFor { + if used.Used && used.ConnID == connID { + return addr + } + } + return nil +} + +// AvailableAddrs counts how many unused addresses this region contains. +func (r Region) AvailableAddrs() int { + n := 0 + for _, usedby := range r.connFor { + if !usedby.Used { + n++ + } + } + return n +} + +// GetUnusedIP returns a random unused address in this region. +// Returns nil if all addresses are in use. +func (r Region) GetUnusedIP(excluding *net.TCPAddr) *net.TCPAddr { + for addr, usedby := range r.connFor { + if !usedby.Used && addr != excluding { + return addr + } + } + return nil +} + +// Use the address, assigning it to a proxy connection. +func (r Region) Use(addr *net.TCPAddr, connID int) { + r.connFor[addr] = InUse(connID) +} + +// GetAnyAddress returns an arbitrary address from the region. +func (r Region) GetAnyAddress() *net.TCPAddr { + for addr := range r.connFor { + return addr + } + return nil +} + +// GiveBack the address, ensuring it is no longer assigned to an IP. +// Returns true if the address is in this region. +func (r Region) GiveBack(addr *net.TCPAddr) (ok bool) { + if _, ok := r.connFor[addr]; !ok { + return false + } + r.connFor[addr] = Unused() + return true +} diff --git a/edgediscovery/allregions/region_test.go b/edgediscovery/allregions/region_test.go new file mode 100644 index 00000000..f5fb64ad --- /dev/null +++ b/edgediscovery/allregions/region_test.go @@ -0,0 +1,287 @@ +package allregions + +import ( + "fmt" + "net" + "reflect" + "testing" +) + +func TestRegion_New(t *testing.T) { + r := NewRegion([]*net.TCPAddr{&addr0, &addr1, &addr2}) + fmt.Println(r.connFor) + if r.AvailableAddrs() != 3 { + t.Errorf("r.AvailableAddrs() == %v but want 3", r.AvailableAddrs()) + } +} + +func TestRegion_AddrUsedBy(t *testing.T) { + type fields struct { + connFor map[*net.TCPAddr]UsedBy + } + type args struct { + connID int + } + tests := []struct { + name string + fields fields + args args + want *net.TCPAddr + }{ + { + name: "happy trivial test", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: InUse(0), + }}, + args: args{connID: 0}, + want: &addr0, + }, + { + name: "sad trivial test", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: InUse(0), + }}, + args: args{connID: 1}, + want: nil, + }, + { + name: "sad test", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: InUse(0), + &addr1: InUse(1), + &addr2: InUse(2), + }}, + args: args{connID: 3}, + want: nil, + }, + { + name: "happy test", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: InUse(0), + &addr1: InUse(1), + &addr2: InUse(2), + }}, + args: args{connID: 1}, + want: &addr1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &Region{ + connFor: tt.fields.connFor, + } + if got := r.AddrUsedBy(tt.args.connID); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Region.AddrUsedBy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRegion_AvailableAddrs(t *testing.T) { + type fields struct { + connFor map[*net.TCPAddr]UsedBy + } + tests := []struct { + name string + fields fields + want int + }{ + { + name: "contains addresses", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: InUse(0), + &addr1: Unused(), + &addr2: InUse(2), + }}, + want: 1, + }, + { + name: "all free", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: Unused(), + &addr1: Unused(), + &addr2: Unused(), + }}, + want: 3, + }, + { + name: "all used", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: InUse(0), + &addr1: InUse(1), + &addr2: InUse(2), + }}, + want: 0, + }, + { + name: "empty", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{}}, + want: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := Region{ + connFor: tt.fields.connFor, + } + if got := r.AvailableAddrs(); got != tt.want { + t.Errorf("Region.AvailableAddrs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRegion_GetUnusedIP(t *testing.T) { + type fields struct { + connFor map[*net.TCPAddr]UsedBy + } + type args struct { + excluding *net.TCPAddr + } + tests := []struct { + name string + fields fields + args args + want *net.TCPAddr + }{ + { + name: "happy test with excluding set", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: Unused(), + &addr1: Unused(), + &addr2: InUse(2), + }}, + args: args{excluding: &addr0}, + want: &addr1, + }, + { + name: "happy test with no excluding", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: InUse(0), + &addr1: Unused(), + &addr2: InUse(2), + }}, + args: args{excluding: nil}, + want: &addr1, + }, + { + name: "sad test with no excluding", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: InUse(0), + &addr1: InUse(1), + &addr2: InUse(2), + }}, + args: args{excluding: nil}, + want: nil, + }, + { + name: "sad test with excluding", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: Unused(), + &addr1: InUse(1), + &addr2: InUse(2), + }}, + args: args{excluding: &addr0}, + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := Region{ + connFor: tt.fields.connFor, + } + if got := r.GetUnusedIP(tt.args.excluding); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Region.GetUnusedIP() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestRegion_GiveBack(t *testing.T) { + type fields struct { + connFor map[*net.TCPAddr]UsedBy + } + type args struct { + addr *net.TCPAddr + } + tests := []struct { + name string + fields fields + args args + wantOk bool + availableAfter int + }{ + { + name: "sad test with excluding", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr1: InUse(1), + }}, + args: args{addr: &addr1}, + wantOk: true, + availableAfter: 1, + }, + { + name: "sad test with excluding", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr1: InUse(1), + }}, + args: args{addr: &addr2}, + wantOk: false, + availableAfter: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := Region{ + connFor: tt.fields.connFor, + } + if gotOk := r.GiveBack(tt.args.addr); gotOk != tt.wantOk { + t.Errorf("Region.GiveBack() = %v, want %v", gotOk, tt.wantOk) + } + if tt.availableAfter != r.AvailableAddrs() { + t.Errorf("Region.AvailableAddrs() = %v, want %v", r.AvailableAddrs(), tt.availableAfter) + } + }) + } +} + +func TestRegion_GetAnyAddress(t *testing.T) { + type fields struct { + connFor map[*net.TCPAddr]UsedBy + } + tests := []struct { + name string + fields fields + wantNil bool + }{ + { + name: "Sad test -- GetAnyAddress should only fail if the region is empty", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{}}, + wantNil: true, + }, + { + name: "Happy test (all addresses unused)", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: Unused(), + }}, + wantNil: false, + }, + { + name: "Happy test (GetAnyAddress can still return addresses used by proxy conns)", + fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + &addr0: InUse(2), + }}, + wantNil: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := Region{ + connFor: tt.fields.connFor, + } + if got := r.GetAnyAddress(); tt.wantNil != (got == nil) { + t.Errorf("Region.GetAnyAddress() = %v, but should it return nil? %v", got, tt.wantNil) + } + }) + } +} diff --git a/edgediscovery/allregions/regions.go b/edgediscovery/allregions/regions.go new file mode 100644 index 00000000..8a883523 --- /dev/null +++ b/edgediscovery/allregions/regions.go @@ -0,0 +1,118 @@ +package allregions + +import ( + "fmt" + "net" + + "github.com/sirupsen/logrus" +) + +// Regions stores Cloudflare edge network IPs, partitioned into two regions. +// This is NOT thread-safe. Users of this package should use it with a lock. +type Regions struct { + region1 Region + region2 Region +} + +// ------------------------------------ +// Constructors +// ------------------------------------ + +// ResolveEdge resolves the Cloudflare edge, returning all regions discovered. +func ResolveEdge(logger *logrus.Entry) (*Regions, error) { + addrLists, err := edgeDiscovery(logger) + if err != nil { + return nil, err + } + if len(addrLists) < 2 { + return nil, fmt.Errorf("expected at least 2 Cloudflare Regions regions, but SRV only returned %v", len(addrLists)) + } + return &Regions{ + region1: NewRegion(addrLists[0]), + region2: NewRegion(addrLists[1]), + }, nil +} + +// StaticEdge creates a list of edge addresses from the list of hostnames. +// Mainly used for testing connectivity. +func StaticEdge(hostnames []string) (*Regions, error) { + addrs, err := ResolveAddrs(hostnames) + if err != nil { + return nil, err + } + return NewNoResolve(addrs), nil +} + +// NewNoResolve doesn't resolve the edge. Instead it just uses the given addresses. +// You probably only need this for testing. +func NewNoResolve(addrs []*net.TCPAddr) *Regions { + region1 := make([]*net.TCPAddr, 0) + region2 := make([]*net.TCPAddr, 0) + for i, v := range addrs { + if i%2 == 0 { + region1 = append(region1, v) + } else { + region2 = append(region2, v) + } + } + + return &Regions{ + region1: NewRegion(region1), + region2: NewRegion(region2), + } +} + +// ------------------------------------ +// Methods +// ------------------------------------ + +// GetAnyAddress returns an arbitrary address from the larger region. +func (rs *Regions) GetAnyAddress() *net.TCPAddr { + if rs.region1.AvailableAddrs() > rs.region2.AvailableAddrs() { + return rs.region1.GetAnyAddress() + } + return rs.region2.GetAnyAddress() +} + +// AddrUsedBy finds the address used by the given connection. +// Returns nil if the connection isn't using an address. +func (rs *Regions) AddrUsedBy(connID int) *net.TCPAddr { + if addr := rs.region1.AddrUsedBy(connID); addr != nil { + return addr + } + return rs.region2.AddrUsedBy(connID) +} + +// GetUnusedAddr gets an unused addr from the edge, excluding the given addr. Prefer to use addresses +// evenly across both regions. +func (rs *Regions) GetUnusedAddr(excluding *net.TCPAddr, connID int) *net.TCPAddr { + var addr *net.TCPAddr + if rs.region1.AvailableAddrs() > rs.region2.AvailableAddrs() { + addr = rs.region1.GetUnusedIP(excluding) + rs.region1.Use(addr, connID) + } else { + addr = rs.region2.GetUnusedIP(excluding) + rs.region2.Use(addr, connID) + } + + if addr == nil { + return nil + } + + // Mark the address as used and return it + return addr +} + +// AvailableAddrs returns how many edge addresses aren't used. +func (rs *Regions) AvailableAddrs() int { + return rs.region1.AvailableAddrs() + rs.region2.AvailableAddrs() +} + +// GiveBack the address so that other connections can use it. +// Returns true if the address is in this edge. +func (rs *Regions) GiveBack(addr *net.TCPAddr) bool { + if found := rs.region1.GiveBack(addr); found { + return found + } + return rs.region2.GiveBack(addr) +} diff --git a/edgediscovery/allregions/regions_test.go b/edgediscovery/allregions/regions_test.go new file mode 100644 index 00000000..6c88f281 --- /dev/null +++ b/edgediscovery/allregions/regions_test.go @@ -0,0 +1,140 @@ +package allregions + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +var ( + addr0 = net.TCPAddr{ + IP: net.ParseIP("123.4.5.0"), + Port: 8000, + Zone: "", + } + addr1 = net.TCPAddr{ + IP: net.ParseIP("123.4.5.1"), + Port: 8000, + Zone: "", + } + addr2 = net.TCPAddr{ + IP: net.ParseIP("123.4.5.2"), + Port: 8000, + Zone: "", + } + addr3 = net.TCPAddr{ + IP: net.ParseIP("123.4.5.3"), + Port: 8000, + Zone: "", + } +) + +func makeRegions() Regions { + r1 := NewRegion([]*net.TCPAddr{&addr0, &addr1}) + r2 := NewRegion([]*net.TCPAddr{&addr2, &addr3}) + return Regions{region1: r1, region2: r2} +} + +func TestRegions_AddrUsedBy(t *testing.T) { + rs := makeRegions() + addr1 := rs.GetUnusedAddr(nil, 1) + assert.Equal(t, addr1, rs.AddrUsedBy(1)) + addr2 := rs.GetUnusedAddr(nil, 2) + assert.Equal(t, addr2, rs.AddrUsedBy(2)) + addr3 := rs.GetUnusedAddr(nil, 3) + assert.Equal(t, addr3, rs.AddrUsedBy(3)) +} + +func TestRegions_Giveback_Region1(t *testing.T) { + rs := makeRegions() + rs.region1.Use(&addr0, 0) + rs.region1.Use(&addr1, 1) + rs.region2.Use(&addr2, 2) + rs.region2.Use(&addr3, 3) + + assert.Equal(t, 0, rs.AvailableAddrs()) + + rs.GiveBack(&addr0) + assert.Equal(t, &addr0, rs.GetUnusedAddr(nil, 3)) +} +func TestRegions_Giveback_Region2(t *testing.T) { + rs := makeRegions() + rs.region1.Use(&addr0, 0) + rs.region1.Use(&addr1, 1) + rs.region2.Use(&addr2, 2) + rs.region2.Use(&addr3, 3) + + assert.Equal(t, 0, rs.AvailableAddrs()) + + rs.GiveBack(&addr2) + assert.Equal(t, &addr2, rs.GetUnusedAddr(nil, 2)) +} + +func TestRegions_GetUnusedAddr_OneAddrLeft(t *testing.T) { + rs := makeRegions() + + rs.region1.Use(&addr0, 0) + rs.region1.Use(&addr1, 1) + rs.region2.Use(&addr2, 2) + + assert.Equal(t, 1, rs.AvailableAddrs()) + assert.Equal(t, &addr3, rs.GetUnusedAddr(nil, 3)) +} + +func TestRegions_GetUnusedAddr_Excluding_Region1(t *testing.T) { + rs := makeRegions() + + rs.region1.Use(&addr0, 0) + rs.region1.Use(&addr1, 1) + + assert.Equal(t, 2, rs.AvailableAddrs()) + assert.Equal(t, &addr3, rs.GetUnusedAddr(&addr2, 3)) +} + +func TestRegions_GetUnusedAddr_Excluding_Region2(t *testing.T) { + rs := makeRegions() + + rs.region2.Use(&addr2, 0) + rs.region2.Use(&addr3, 1) + + assert.Equal(t, 2, rs.AvailableAddrs()) + assert.Equal(t, &addr1, rs.GetUnusedAddr(&addr0, 1)) +} + +func TestNewNoResolveBalancesRegions(t *testing.T) { + type args struct { + addrs []*net.TCPAddr + } + tests := []struct { + name string + args args + }{ + { + name: "one address", + args: args{addrs: []*net.TCPAddr{&addr0}}, + }, + { + name: "two addresses", + args: args{addrs: []*net.TCPAddr{&addr0, &addr1}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + regions := NewNoResolve(tt.args.addrs) + RegionsIsBalanced(t, regions) + }) + } +} + +func RegionsIsBalanced(t *testing.T, rs *Regions) { + delta := rs.region1.AvailableAddrs() - rs.region2.AvailableAddrs() + assert.True(t, abs(delta) <= 1) +} + +func abs(x int) int { + if x >= 0 { + return x + } + return -x +} diff --git a/edgediscovery/allregions/usedby.go b/edgediscovery/allregions/usedby.go new file mode 100644 index 00000000..6405c8f7 --- /dev/null +++ b/edgediscovery/allregions/usedby.go @@ -0,0 +1,14 @@ +package allregions + +type UsedBy struct { + ConnID int + Used bool +} + +func InUse(connID int) UsedBy { + return UsedBy{ConnID: connID, Used: true} +} + +func Unused() UsedBy { + return UsedBy{} +} diff --git a/edgediscovery/edgediscovery.go b/edgediscovery/edgediscovery.go new file mode 100644 index 00000000..abf83a66 --- /dev/null +++ b/edgediscovery/edgediscovery.go @@ -0,0 +1,143 @@ +package edgediscovery + +import ( + "fmt" + "net" + "sync" + + "github.com/cloudflare/cloudflared/edgediscovery/allregions" + + "github.com/sirupsen/logrus" +) + +const ( + subsystem = "edgediscovery" +) + +var errNoAddressesLeft = fmt.Errorf("There are no free edge addresses left") + +// Edge finds addresses on the Cloudflare edge and hands them out to connections. +type Edge struct { + regions *allregions.Regions + sync.Mutex + logger *logrus.Entry +} + +// ------------------------------------ +// Constructors +// ------------------------------------ + +// ResolveEdge runs the initial discovery of the Cloudflare edge, finding Addrs that can be allocated +// to connections. +func ResolveEdge(l *logrus.Logger) (*Edge, error) { + logger := l.WithField("subsystem", subsystem) + regions, err := allregions.ResolveEdge(logger) + if err != nil { + return new(Edge), err + } + return &Edge{ + logger: logger, + regions: regions, + }, nil +} + +// StaticEdge creates a list of edge addresses from the list of hostnames. Mainly used for testing connectivity. +func StaticEdge(l *logrus.Logger, hostnames []string) (*Edge, error) { + logger := l.WithField("subsystem", subsystem) + regions, err := allregions.StaticEdge(hostnames) + if err != nil { + return new(Edge), err + } + return &Edge{ + logger: logger, + regions: regions, + }, nil +} + +// MockEdge creates a Cloudflare Edge from arbitrary TCP addresses. Used for testing. +func MockEdge(l *logrus.Logger, addrs []*net.TCPAddr) *Edge { + logger := l.WithField("subsystem", subsystem) + regions := allregions.NewNoResolve(addrs) + return &Edge{ + logger: logger, + regions: regions, + } +} + +// ------------------------------------ +// Methods +// ------------------------------------ + +// GetAddrForRPC gives this connection an edge Addr. +func (ed *Edge) GetAddrForRPC() (*net.TCPAddr, error) { + ed.Lock() + defer ed.Unlock() + addr := ed.regions.GetAnyAddress() + if addr == nil { + return nil, errNoAddressesLeft + } + return addr, nil +} + +// GetAddr gives this proxy connection an edge Addr. Prefer Addrs this connection has already used. +func (ed *Edge) GetAddr(connID int) (*net.TCPAddr, error) { + ed.Lock() + defer ed.Unlock() + logger := ed.logger.WithFields(logrus.Fields{ + "connID": connID, + "function": "GetAddr", + }) + + // If this connection has already used an edge addr, return it. + if addr := ed.regions.AddrUsedBy(connID); addr != nil { + logger.Debug("Returning same address back to proxy connection") + return addr, nil + } + + // Otherwise, give it an unused one + addr := ed.regions.GetUnusedAddr(nil, connID) + if addr == nil { + logger.Debug("No addresses left to give proxy connection") + return nil, errNoAddressesLeft + } + logger.Debug("Giving connection its new address") + return addr, nil +} + +// GetDifferentAddr gives back the proxy connection's edge Addr and uses a new one. +func (ed *Edge) GetDifferentAddr(connID int) (*net.TCPAddr, error) { + ed.Lock() + defer ed.Unlock() + logger := ed.logger.WithFields(logrus.Fields{ + "connID": connID, + "function": "GetDifferentAddr", + }) + + oldAddr := ed.regions.AddrUsedBy(connID) + if oldAddr != nil { + ed.regions.GiveBack(oldAddr) + } + addr := ed.regions.GetUnusedAddr(oldAddr, connID) + if addr == nil { + logger.Debug("No addresses left to give proxy connection") + return nil, errNoAddressesLeft + } + logger.Debug("Giving connection its new address") + return addr, nil +} + +// AvailableAddrs returns how many unused addresses there are left. +func (ed *Edge) AvailableAddrs() int { + ed.Lock() + defer ed.Unlock() + return ed.regions.AvailableAddrs() +} + +// GiveBack the address so that other connections can use it. +// Returns true if the address is in this edge. +func (ed *Edge) GiveBack(addr *net.TCPAddr) bool { + ed.Lock() + defer ed.Unlock() + ed.logger.WithField("function", "GiveBack").Debug("Address now unused") + return ed.regions.GiveBack(addr) +} diff --git a/edgediscovery/edgediscovery_test.go b/edgediscovery/edgediscovery_test.go new file mode 100644 index 00000000..e25dc1bc --- /dev/null +++ b/edgediscovery/edgediscovery_test.go @@ -0,0 +1,130 @@ +package edgediscovery + +import ( + "net" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +var ( + addr0 = net.TCPAddr{ + IP: net.ParseIP("123.0.0.0"), + Port: 8000, + Zone: "", + } + addr1 = net.TCPAddr{ + IP: net.ParseIP("123.0.0.1"), + Port: 8000, + Zone: "", + } + addr2 = net.TCPAddr{ + IP: net.ParseIP("123.0.0.2"), + Port: 8000, + Zone: "", + } + addr3 = net.TCPAddr{ + IP: net.ParseIP("123.0.0.3"), + Port: 8000, + Zone: "", + } +) + +func TestGiveBack(t *testing.T) { + l := logrus.New() + edge := MockEdge(l, []*net.TCPAddr{&addr0, &addr1, &addr2, &addr3}) + + // Give this connection an address + assert.Equal(t, 4, edge.AvailableAddrs()) + const connID = 0 + addr, err := edge.GetAddr(connID) + assert.NoError(t, err) + assert.NotNil(t, addr) + assert.Equal(t, 3, edge.AvailableAddrs()) + + // Get it back + edge.GiveBack(addr) + assert.Equal(t, 4, edge.AvailableAddrs()) +} +func TestGetAddrForRPC(t *testing.T) { + l := logrus.New() + edge := MockEdge(l, []*net.TCPAddr{&addr0, &addr1, &addr2, &addr3}) + + // Get a connection + assert.Equal(t, 4, edge.AvailableAddrs()) + addr, err := edge.GetAddrForRPC() + assert.NoError(t, err) + assert.NotNil(t, addr) + + // Using an address for RPC shouldn't consume it + assert.Equal(t, 4, edge.AvailableAddrs()) + + // Get it back + edge.GiveBack(addr) + assert.Equal(t, 4, edge.AvailableAddrs()) +} + +func TestOnlyOneAddrLeft(t *testing.T) { + l := logrus.New() + + // Make an edge with only one address + edge := MockEdge(l, []*net.TCPAddr{&addr0}) + + // Use the only address + const connID = 0 + addr, err := edge.GetAddr(connID) + assert.NoError(t, err) + assert.NotNil(t, addr) + + // If that edge address is "bad", there's no alternative address. + _, err = edge.GetDifferentAddr(connID) + assert.Error(t, err) +} + +func TestNoAddrsLeft(t *testing.T) { + l := logrus.New() + + // Make an edge with no addresses + edge := MockEdge(l, []*net.TCPAddr{}) + + _, err := edge.GetAddr(2) + assert.Error(t, err) + _, err = edge.GetAddrForRPC() + assert.Error(t, err) +} + +func TestGetAddr(t *testing.T) { + l := logrus.New() + edge := MockEdge(l, []*net.TCPAddr{&addr0, &addr1, &addr2, &addr3}) + + // Give this connection an address + const connID = 0 + addr, err := edge.GetAddr(connID) + assert.NoError(t, err) + assert.NotNil(t, addr) + + // If the same connection requests another address, it should get the same one. + addr2, err := edge.GetAddr(connID) + assert.NoError(t, err) + assert.Equal(t, addr, addr2) +} + +func TestGetDifferentAddr(t *testing.T) { + l := logrus.New() + edge := MockEdge(l, []*net.TCPAddr{&addr0, &addr1, &addr2, &addr3}) + + // Give this connection an address + assert.Equal(t, 4, edge.AvailableAddrs()) + const connID = 0 + addr, err := edge.GetAddr(connID) + assert.NoError(t, err) + assert.NotNil(t, addr) + assert.Equal(t, 3, edge.AvailableAddrs()) + + // If the same connection requests another address, it should get the same one. + addr2, err := edge.GetDifferentAddr(connID) + assert.NoError(t, err) + assert.NotEqual(t, addr, addr2) + assert.Equal(t, 3, edge.AvailableAddrs()) +} diff --git a/origin/supervisor.go b/origin/supervisor.go index a28ce573..b8df583f 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -12,6 +12,7 @@ import ( "github.com/sirupsen/logrus" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/signal" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -39,10 +40,12 @@ var ( errEventDigestUnset = errors.New("event digest unset") ) +// Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and +// reconnects them if they disconnect. type Supervisor struct { cloudflaredUUID uuid.UUID config *TunnelConfig - edgeIPs connection.EdgeServiceDiscoverer + edgeIPs *edgediscovery.Edge lastResolve time.Time resolverC chan resolveResult tunnelErrors chan tunnelError @@ -73,13 +76,13 @@ type tunnelError struct { func NewSupervisor(config *TunnelConfig, u uuid.UUID) (*Supervisor, error) { var ( - edgeIPs connection.EdgeServiceDiscoverer + edgeIPs *edgediscovery.Edge err error ) if len(config.EdgeAddrs) > 0 { - edgeIPs, err = connection.NewEdgeHostnameResolver(config.EdgeAddrs) + edgeIPs, err = edgediscovery.StaticEdge(config.Logger, config.EdgeAddrs) } else { - edgeIPs, err = connection.NewEdgeAddrResolver(config.Logger) + edgeIPs, err = edgediscovery.ResolveEdge(config.Logger) } if err != nil { return nil, err @@ -141,14 +144,8 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er backoffTimer = backoff.BackoffTimer() } - // If the error is a dial error, the problem is likely to be network related - // try another addr before refreshing since we are likely to get back the - // same IPs in the same order. Same problem with duplicate connection error. - if s.unusedIPs() && tunnelError.addr != nil { - s.edgeIPs.MarkAddrBad(tunnelError.addr) - } else { - s.refreshEdgeIPs() - } + // Previously we'd mark the edge address as bad here, but now we'll just silently use + // another. } // Backoff was set and its timer expired case <-backoffTimer: @@ -191,11 +188,6 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal) er func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal) error { logger := s.logger - err := s.edgeIPs.Refresh() - if err != nil { - logger.Infof("ResolveEdgeIPs err") - return err - } s.lastResolve = time.Now() availableAddrs := int(s.edgeIPs.AvailableAddrs()) if s.config.HAConnections > availableAddrs { @@ -228,17 +220,19 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign addr *net.TCPAddr err error ) + const thisConnID = 0 defer func() { - s.tunnelErrors <- tunnelError{index: 0, addr: addr, err: err} + s.tunnelErrors <- tunnelError{index: thisConnID, addr: addr, err: err} }() - addr, err = s.edgeIPs.Addr() + addr, err = s.edgeIPs.GetAddr(thisConnID) if err != nil { return } - err = ServeTunnelLoop(ctx, s, s.config, addr, 0, connectedSignal, s.cloudflaredUUID) - + err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID) + // If the first tunnel disconnects, keep restarting it. + edgeErrors := 0 for s.unusedIPs() { if ctx.Err() != nil { return @@ -249,15 +243,17 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign // try the next address if it was a dialError(network problem) or // dupConnRegisterTunnelError case connection.DialError, dupConnRegisterTunnelError: - s.edgeIPs.MarkAddrBad(addr) + edgeErrors++ default: return } - addr, err = s.edgeIPs.Addr() - if err != nil { - return + if edgeErrors >= 2 { + addr, err = s.edgeIPs.GetDifferentAddr(thisConnID) + if err != nil { + return + } } - err = ServeTunnelLoop(ctx, s, s.config, addr, 0, connectedSignal, s.cloudflaredUUID) + err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID) } } @@ -272,7 +268,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal s.tunnelErrors <- tunnelError{index: index, addr: addr, err: err} }() - addr, err = s.edgeIPs.Addr() + addr, err = s.edgeIPs.GetAddr(index) if err != nil { return } @@ -302,20 +298,6 @@ func (s *Supervisor) unusedIPs() bool { return s.edgeIPs.AvailableAddrs() > s.config.HAConnections } -func (s *Supervisor) refreshEdgeIPs() { - if s.resolverC != nil { - return - } - if time.Since(s.lastResolve) < resolveTTL { - return - } - s.resolverC = make(chan resolveResult) - go func() { - err := s.edgeIPs.Refresh() - s.resolverC <- resolveResult{err: err} - }() -} - func (s *Supervisor) ReconnectToken() ([]byte, error) { s.jwtLock.RLock() defer s.jwtLock.RUnlock() @@ -385,7 +367,7 @@ func (s *Supervisor) refreshAuth( } func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) { - arbitraryEdgeIP, err := s.edgeIPs.AnyAddr() + arbitraryEdgeIP, err := s.edgeIPs.GetAddrForRPC() if err != nil { return nil, err } diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index 7ae111df..ca9764b1 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -1,3 +1,4 @@ +// Package supervisor is used by declarative tunnels to get/apply new config from the edge. package supervisor import ( @@ -13,6 +14,7 @@ import ( "github.com/cloudflare/cloudflared/cmd/cloudflared/updater" "github.com/cloudflare/cloudflared/connection" + "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/streamhandler" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -56,7 +58,7 @@ func NewSupervisor( defaultClientConfig *pogs.ClientConfig, userCredential []byte, tlsConfig *tls.Config, - serviceDiscoverer connection.EdgeServiceDiscoverer, + serviceDiscoverer *edgediscovery.Edge, cloudflaredConfig *connection.CloudflaredConfig, autoupdater *updater.AutoUpdater, supportAutoupdate bool,