diff --git a/edgediscovery/allregions/discovery.go b/edgediscovery/allregions/discovery.go index 59460e4d..6ffeded3 100644 --- a/edgediscovery/allregions/discovery.go +++ b/edgediscovery/allregions/discovery.go @@ -30,6 +30,12 @@ var ( netLookupIP = net.LookupIP ) +// EdgeAddr is a representation of possible ways to refer an edge location. +type EdgeAddr struct { + TCP *net.TCPAddr + UDP *net.UDPAddr +} + // If the call to net.LookupSRV fails, try to fall back to DoT from Cloudflare directly. // // Note: Instead of DoT, we could also have used DoH. Either of these: @@ -53,7 +59,7 @@ var friendlyDNSErrorLines = []string{ } // EdgeDiscovery implements HA service discovery lookup. -func edgeDiscovery(log *zerolog.Logger) ([][]*net.TCPAddr, error) { +func edgeDiscovery(log *zerolog.Logger) ([][]*EdgeAddr, error) { _, addrs, err := netLookupSRV(srvService, srvProto, srvName) if err != nil { _, fallbackAddrs, fallbackErr := fallbackLookupSRV(srvService, srvProto, srvName) @@ -69,16 +75,16 @@ func edgeDiscovery(log *zerolog.Logger) ([][]*net.TCPAddr, error) { addrs = fallbackAddrs } - var resolvedIPsPerCNAME [][]*net.TCPAddr + var resolvedAddrPerCNAME [][]*EdgeAddr for _, addr := range addrs { - ips, err := resolveSRVToTCP(addr) + edgeAddrs, err := resolveSRV(addr) if err != nil { return nil, err } - resolvedIPsPerCNAME = append(resolvedIPsPerCNAME, ips) + resolvedAddrPerCNAME = append(resolvedAddrPerCNAME, edgeAddrs) } - return resolvedIPsPerCNAME, nil + return resolvedAddrPerCNAME, nil } func lookupSRVWithDOT(service, proto, name string) (cname string, addrs []*net.SRV, err error) { @@ -100,7 +106,7 @@ func lookupSRVWithDOT(service, proto, name string) (cname string, addrs []*net.S return r.LookupSRV(ctx, srvService, srvProto, srvName) } -func resolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) { +func resolveSRV(srv *net.SRV) ([]*EdgeAddr, error) { ips, err := netLookupIP(srv.Target) if err != nil { return nil, errors.Wrapf(err, "Couldn't resolve SRV record %v", srv) @@ -108,23 +114,36 @@ func resolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) { if len(ips) == 0 { return nil, fmt.Errorf("SRV record %v had no IPs", srv) } - addrs := make([]*net.TCPAddr, len(ips)) + addrs := make([]*EdgeAddr, len(ips)) for i, ip := range ips { - addrs[i] = &net.TCPAddr{IP: ip, Port: int(srv.Port)} + addrs[i] = &EdgeAddr{ + TCP: &net.TCPAddr{IP: ip, Port: int(srv.Port)}, + UDP: &net.UDPAddr{IP: ip, Port: int(srv.Port)}, + } } return addrs, nil } // ResolveAddrs resolves TCP address given a list of addresses. Address can be a hostname, however, it will return at most one // of the hostname's IP addresses. -func ResolveAddrs(addrs []string, log *zerolog.Logger) (resolved []*net.TCPAddr) { +func ResolveAddrs(addrs []string, log *zerolog.Logger) (resolved []*EdgeAddr) { for _, addr := range addrs { tcpAddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { - log.Err(err).Msgf("Failed to resolve %s", addr) - } else { - resolved = append(resolved, tcpAddr) + log.Error().Msgf("Failed to resolve %s to TCP address, err: %v", addr, err) + continue } + + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + log.Error().Msgf("Failed to resolve %s to UDP address, err: %v", addr, err) + continue + } + resolved = append(resolved, &EdgeAddr{ + TCP: tcpAddr, + UDP: udpAddr, + }) + } return } diff --git a/edgediscovery/allregions/discovery_test.go b/edgediscovery/allregions/discovery_test.go index 82f03ac1..007519b2 100644 --- a/edgediscovery/allregions/discovery_test.go +++ b/edgediscovery/allregions/discovery_test.go @@ -1,12 +1,17 @@ package allregions import ( + "fmt" "testing" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" ) +func (ea *EdgeAddr) String() string { + return fmt.Sprintf("%s-%s", ea.TCP, ea.UDP) +} + func TestEdgeDiscovery(t *testing.T) { mockAddrs := newMockAddrs(19, 2, 5) netLookupSRV = mockNetLookupSRV(mockAddrs) diff --git a/edgediscovery/allregions/mocks_for_test.go b/edgediscovery/allregions/mocks_for_test.go index 3f0470a1..36298be2 100644 --- a/edgediscovery/allregions/mocks_for_test.go +++ b/edgediscovery/allregions/mocks_for_test.go @@ -11,7 +11,7 @@ import ( type mockAddrs struct { // a set of synthetic SRV records - addrMap map[net.SRV][]*net.TCPAddr + addrMap map[net.SRV][]*EdgeAddr // the total number of addresses, aggregated across addrMap. // For the convenience of test code that would otherwise have to compute // this by hand every time. @@ -19,19 +19,24 @@ type mockAddrs struct { } func newMockAddrs(port uint16, numRegions uint8, numAddrsPerRegion uint8) mockAddrs { - addrMap := make(map[net.SRV][]*net.TCPAddr) + addrMap := make(map[net.SRV][]*EdgeAddr) numAddrs := 0 for r := uint8(0); r < numRegions; r++ { var ( srv = net.SRV{Target: fmt.Sprintf("test-region-%v.example.com", r), Port: port} - addrs []*net.TCPAddr + addrs []*EdgeAddr ) for a := uint8(0); a < numAddrsPerRegion; a++ { - addrs = append(addrs, &net.TCPAddr{ + tcpAddr := &net.TCPAddr{ IP: net.ParseIP(fmt.Sprintf("10.0.%v.%v", r, a)), Port: int(port), - }) + } + udpAddr := &net.UDPAddr{ + IP: net.ParseIP(fmt.Sprintf("10.0.%v.%v", r, a)), + Port: int(port), + } + addrs = append(addrs, &EdgeAddr{tcpAddr, udpAddr}) } addrMap[srv] = addrs numAddrs += len(addrs) @@ -74,13 +79,13 @@ func mockNetLookupIP( m mockAddrs, ) func(host string) ([]net.IP, error) { return func(host string) ([]net.IP, error) { - for srv, tcpAddrs := range m.addrMap { + for srv, addrs := range m.addrMap { if srv.Target != host { continue } - result := make([]net.IP, len(tcpAddrs)) - for i, tcpAddr := range tcpAddrs { - result[i] = tcpAddr.IP + result := make([]net.IP, len(addrs)) + for i, addr := range addrs { + result[i] = addr.TCP.IP } return result, nil } diff --git a/edgediscovery/allregions/region.go b/edgediscovery/allregions/region.go index 30b19808..4d268c77 100644 --- a/edgediscovery/allregions/region.go +++ b/edgediscovery/allregions/region.go @@ -1,29 +1,27 @@ package allregions -import ( - "net" -) - // Region contains cloudflared edge addresses. The edge is partitioned into several regions for // redundancy purposes. type Region struct { - connFor map[*net.TCPAddr]UsedBy + connFor map[*EdgeAddr]UsedBy } // NewRegion creates a region with the given addresses, which are all unused. -func NewRegion(addrs []*net.TCPAddr) Region { +func NewRegion(addrs []*EdgeAddr) Region { // The zero value of UsedBy is Unused(), so we can just initialize the map's values with their // zero values. - m := make(map[*net.TCPAddr]UsedBy) + connFor := make(map[*EdgeAddr]UsedBy) for _, addr := range addrs { - m[addr] = Unused() + connFor[addr] = Unused() + } + return Region{ + connFor: connFor, } - return Region{connFor: m} } // AddrUsedBy finds the address used by the given connection in this region. // Returns nil if the connection isn't using any IP. -func (r *Region) AddrUsedBy(connID int) *net.TCPAddr { +func (r *Region) AddrUsedBy(connID int) *EdgeAddr { for addr, used := range r.connFor { if used.Used && used.ConnID == connID { return addr @@ -45,7 +43,7 @@ func (r Region) AvailableAddrs() int { // GetUnusedIP returns a random unused address in this region. // Returns nil if all addresses are in use. -func (r Region) GetUnusedIP(excluding *net.TCPAddr) *net.TCPAddr { +func (r Region) GetUnusedIP(excluding *EdgeAddr) *EdgeAddr { for addr, usedby := range r.connFor { if !usedby.Used && addr != excluding { return addr @@ -55,7 +53,7 @@ func (r Region) GetUnusedIP(excluding *net.TCPAddr) *net.TCPAddr { } // Use the address, assigning it to a proxy connection. -func (r Region) Use(addr *net.TCPAddr, connID int) { +func (r Region) Use(addr *EdgeAddr, connID int) { if addr == nil { return } @@ -63,7 +61,7 @@ func (r Region) Use(addr *net.TCPAddr, connID int) { } // GetAnyAddress returns an arbitrary address from the region. -func (r Region) GetAnyAddress() *net.TCPAddr { +func (r Region) GetAnyAddress() *EdgeAddr { for addr := range r.connFor { return addr } @@ -72,7 +70,7 @@ func (r Region) GetAnyAddress() *net.TCPAddr { // GiveBack the address, ensuring it is no longer assigned to an IP. // Returns true if the address is in this region. -func (r Region) GiveBack(addr *net.TCPAddr) (ok bool) { +func (r Region) GiveBack(addr *EdgeAddr) (ok bool) { if _, ok := r.connFor[addr]; !ok { return false } diff --git a/edgediscovery/allregions/region_test.go b/edgediscovery/allregions/region_test.go index f5fb64ad..d83dea61 100644 --- a/edgediscovery/allregions/region_test.go +++ b/edgediscovery/allregions/region_test.go @@ -1,15 +1,12 @@ package allregions import ( - "fmt" - "net" "reflect" "testing" ) func TestRegion_New(t *testing.T) { - r := NewRegion([]*net.TCPAddr{&addr0, &addr1, &addr2}) - fmt.Println(r.connFor) + r := NewRegion([]*EdgeAddr{&addr0, &addr1, &addr2}) if r.AvailableAddrs() != 3 { t.Errorf("r.AvailableAddrs() == %v but want 3", r.AvailableAddrs()) } @@ -17,7 +14,7 @@ func TestRegion_New(t *testing.T) { func TestRegion_AddrUsedBy(t *testing.T) { type fields struct { - connFor map[*net.TCPAddr]UsedBy + connFor map[*EdgeAddr]UsedBy } type args struct { connID int @@ -26,11 +23,11 @@ func TestRegion_AddrUsedBy(t *testing.T) { name string fields fields args args - want *net.TCPAddr + want *EdgeAddr }{ { name: "happy trivial test", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: InUse(0), }}, args: args{connID: 0}, @@ -38,7 +35,7 @@ func TestRegion_AddrUsedBy(t *testing.T) { }, { name: "sad trivial test", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: InUse(0), }}, args: args{connID: 1}, @@ -46,7 +43,7 @@ func TestRegion_AddrUsedBy(t *testing.T) { }, { name: "sad test", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: InUse(0), &addr1: InUse(1), &addr2: InUse(2), @@ -56,7 +53,7 @@ func TestRegion_AddrUsedBy(t *testing.T) { }, { name: "happy test", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: InUse(0), &addr1: InUse(1), &addr2: InUse(2), @@ -79,7 +76,7 @@ func TestRegion_AddrUsedBy(t *testing.T) { func TestRegion_AvailableAddrs(t *testing.T) { type fields struct { - connFor map[*net.TCPAddr]UsedBy + connFor map[*EdgeAddr]UsedBy } tests := []struct { name string @@ -88,7 +85,7 @@ func TestRegion_AvailableAddrs(t *testing.T) { }{ { name: "contains addresses", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: InUse(0), &addr1: Unused(), &addr2: InUse(2), @@ -97,7 +94,7 @@ func TestRegion_AvailableAddrs(t *testing.T) { }, { name: "all free", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: Unused(), &addr1: Unused(), &addr2: Unused(), @@ -106,7 +103,7 @@ func TestRegion_AvailableAddrs(t *testing.T) { }, { name: "all used", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: InUse(0), &addr1: InUse(1), &addr2: InUse(2), @@ -115,7 +112,7 @@ func TestRegion_AvailableAddrs(t *testing.T) { }, { name: "empty", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{}}, + fields: fields{connFor: map[*EdgeAddr]UsedBy{}}, want: 0, }, } @@ -133,20 +130,20 @@ func TestRegion_AvailableAddrs(t *testing.T) { func TestRegion_GetUnusedIP(t *testing.T) { type fields struct { - connFor map[*net.TCPAddr]UsedBy + connFor map[*EdgeAddr]UsedBy } type args struct { - excluding *net.TCPAddr + excluding *EdgeAddr } tests := []struct { name string fields fields args args - want *net.TCPAddr + want *EdgeAddr }{ { name: "happy test with excluding set", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: Unused(), &addr1: Unused(), &addr2: InUse(2), @@ -156,7 +153,7 @@ func TestRegion_GetUnusedIP(t *testing.T) { }, { name: "happy test with no excluding", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: InUse(0), &addr1: Unused(), &addr2: InUse(2), @@ -166,7 +163,7 @@ func TestRegion_GetUnusedIP(t *testing.T) { }, { name: "sad test with no excluding", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: InUse(0), &addr1: InUse(1), &addr2: InUse(2), @@ -176,7 +173,7 @@ func TestRegion_GetUnusedIP(t *testing.T) { }, { name: "sad test with excluding", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: Unused(), &addr1: InUse(1), &addr2: InUse(2), @@ -199,10 +196,10 @@ func TestRegion_GetUnusedIP(t *testing.T) { func TestRegion_GiveBack(t *testing.T) { type fields struct { - connFor map[*net.TCPAddr]UsedBy + connFor map[*EdgeAddr]UsedBy } type args struct { - addr *net.TCPAddr + addr *EdgeAddr } tests := []struct { name string @@ -213,7 +210,7 @@ func TestRegion_GiveBack(t *testing.T) { }{ { name: "sad test with excluding", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr1: InUse(1), }}, args: args{addr: &addr1}, @@ -222,7 +219,7 @@ func TestRegion_GiveBack(t *testing.T) { }, { name: "sad test with excluding", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr1: InUse(1), }}, args: args{addr: &addr2}, @@ -247,7 +244,7 @@ func TestRegion_GiveBack(t *testing.T) { func TestRegion_GetAnyAddress(t *testing.T) { type fields struct { - connFor map[*net.TCPAddr]UsedBy + connFor map[*EdgeAddr]UsedBy } tests := []struct { name string @@ -256,19 +253,19 @@ func TestRegion_GetAnyAddress(t *testing.T) { }{ { name: "Sad test -- GetAnyAddress should only fail if the region is empty", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{}}, + fields: fields{connFor: map[*EdgeAddr]UsedBy{}}, wantNil: true, }, { name: "Happy test (all addresses unused)", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: Unused(), }}, wantNil: false, }, { name: "Happy test (GetAnyAddress can still return addresses used by proxy conns)", - fields: fields{connFor: map[*net.TCPAddr]UsedBy{ + fields: fields{connFor: map[*EdgeAddr]UsedBy{ &addr0: InUse(2), }}, wantNil: false, diff --git a/edgediscovery/allregions/regions.go b/edgediscovery/allregions/regions.go index cabb871c..5098561e 100644 --- a/edgediscovery/allregions/regions.go +++ b/edgediscovery/allregions/regions.go @@ -2,7 +2,6 @@ package allregions import ( "fmt" - "net" "github.com/rs/zerolog" ) @@ -20,16 +19,16 @@ type Regions struct { // ResolveEdge resolves the Cloudflare edge, returning all regions discovered. func ResolveEdge(log *zerolog.Logger) (*Regions, error) { - addrLists, err := edgeDiscovery(log) + edgeAddrs, err := edgeDiscovery(log) if err != nil { return nil, err } - if len(addrLists) < 2 { - return nil, fmt.Errorf("expected at least 2 Cloudflare Regions regions, but SRV only returned %v", len(addrLists)) + if len(edgeAddrs) < 2 { + return nil, fmt.Errorf("expected at least 2 Cloudflare Regions regions, but SRV only returned %v", len(edgeAddrs)) } return &Regions{ - region1: NewRegion(addrLists[0]), - region2: NewRegion(addrLists[1]), + region1: NewRegion(edgeAddrs[0]), + region2: NewRegion(edgeAddrs[1]), }, nil } @@ -45,9 +44,9 @@ func StaticEdge(hostnames []string, log *zerolog.Logger) (*Regions, error) { // NewNoResolve doesn't resolve the edge. Instead it just uses the given addresses. // You probably only need this for testing. -func NewNoResolve(addrs []*net.TCPAddr) *Regions { - region1 := make([]*net.TCPAddr, 0) - region2 := make([]*net.TCPAddr, 0) +func NewNoResolve(addrs []*EdgeAddr) *Regions { + region1 := make([]*EdgeAddr, 0) + region2 := make([]*EdgeAddr, 0) for i, v := range addrs { if i%2 == 0 { region1 = append(region1, v) @@ -67,7 +66,7 @@ func NewNoResolve(addrs []*net.TCPAddr) *Regions { // ------------------------------------ // GetAnyAddress returns an arbitrary address from the larger region. -func (rs *Regions) GetAnyAddress() *net.TCPAddr { +func (rs *Regions) GetAnyAddress() *EdgeAddr { if addr := rs.region1.GetAnyAddress(); addr != nil { return addr } @@ -76,7 +75,7 @@ func (rs *Regions) GetAnyAddress() *net.TCPAddr { // AddrUsedBy finds the address used by the given connection. // Returns nil if the connection isn't using an address. -func (rs *Regions) AddrUsedBy(connID int) *net.TCPAddr { +func (rs *Regions) AddrUsedBy(connID int) *EdgeAddr { if addr := rs.region1.AddrUsedBy(connID); addr != nil { return addr } @@ -85,7 +84,7 @@ func (rs *Regions) AddrUsedBy(connID int) *net.TCPAddr { // GetUnusedAddr gets an unused addr from the edge, excluding the given addr. Prefer to use addresses // evenly across both regions. -func (rs *Regions) GetUnusedAddr(excluding *net.TCPAddr, connID int) *net.TCPAddr { +func (rs *Regions) GetUnusedAddr(excluding *EdgeAddr, connID int) *EdgeAddr { if rs.region1.AvailableAddrs() > rs.region2.AvailableAddrs() { return getAddrs(excluding, connID, &rs.region1, &rs.region2) } @@ -95,7 +94,7 @@ func (rs *Regions) GetUnusedAddr(excluding *net.TCPAddr, connID int) *net.TCPAdd // getAddrs tries to grab address form `first` region, then `second` region // this is an unrolled loop over 2 element array -func getAddrs(excluding *net.TCPAddr, connID int, first *Region, second *Region) *net.TCPAddr { +func getAddrs(excluding *EdgeAddr, connID int, first *Region, second *Region) *EdgeAddr { addr := first.GetUnusedIP(excluding) if addr != nil { first.Use(addr, connID) @@ -117,7 +116,7 @@ func (rs *Regions) AvailableAddrs() int { // GiveBack the address so that other connections can use it. // Returns true if the address is in this edge. -func (rs *Regions) GiveBack(addr *net.TCPAddr) bool { +func (rs *Regions) GiveBack(addr *EdgeAddr) bool { if found := rs.region1.GiveBack(addr); found { return found } diff --git a/edgediscovery/allregions/regions_test.go b/edgediscovery/allregions/regions_test.go index 8abdf950..d73b75e1 100644 --- a/edgediscovery/allregions/regions_test.go +++ b/edgediscovery/allregions/regions_test.go @@ -8,31 +8,59 @@ import ( ) var ( - addr0 = net.TCPAddr{ - IP: net.ParseIP("123.4.5.0"), - Port: 8000, - Zone: "", + addr0 = EdgeAddr{ + TCP: &net.TCPAddr{ + IP: net.ParseIP("123.4.5.0"), + Port: 8000, + Zone: "", + }, + UDP: &net.UDPAddr{ + IP: net.ParseIP("123.4.5.0"), + Port: 8000, + Zone: "", + }, } - addr1 = net.TCPAddr{ - IP: net.ParseIP("123.4.5.1"), - Port: 8000, - Zone: "", + addr1 = EdgeAddr{ + TCP: &net.TCPAddr{ + IP: net.ParseIP("123.4.5.1"), + Port: 8000, + Zone: "", + }, + UDP: &net.UDPAddr{ + IP: net.ParseIP("123.4.5.1"), + Port: 8000, + Zone: "", + }, } - addr2 = net.TCPAddr{ - IP: net.ParseIP("123.4.5.2"), - Port: 8000, - Zone: "", + addr2 = EdgeAddr{ + TCP: &net.TCPAddr{ + IP: net.ParseIP("123.4.5.2"), + Port: 8000, + Zone: "", + }, + UDP: &net.UDPAddr{ + IP: net.ParseIP("123.4.5.2"), + Port: 8000, + Zone: "", + }, } - addr3 = net.TCPAddr{ - IP: net.ParseIP("123.4.5.3"), - Port: 8000, - Zone: "", + addr3 = EdgeAddr{ + TCP: &net.TCPAddr{ + IP: net.ParseIP("123.4.5.3"), + Port: 8000, + Zone: "", + }, + UDP: &net.UDPAddr{ + IP: net.ParseIP("123.4.5.3"), + Port: 8000, + Zone: "", + }, } ) func makeRegions() Regions { - r1 := NewRegion([]*net.TCPAddr{&addr0, &addr1}) - r2 := NewRegion([]*net.TCPAddr{&addr2, &addr3}) + r1 := NewRegion([]*EdgeAddr{&addr0, &addr1}) + r2 := NewRegion([]*EdgeAddr{&addr2, &addr3}) return Regions{region1: r1, region2: r2} } @@ -105,7 +133,7 @@ func TestRegions_GetUnusedAddr_Excluding_Region2(t *testing.T) { func TestNewNoResolveBalancesRegions(t *testing.T) { type args struct { - addrs []*net.TCPAddr + addrs []*EdgeAddr } tests := []struct { name string @@ -113,11 +141,11 @@ func TestNewNoResolveBalancesRegions(t *testing.T) { }{ { name: "one address", - args: args{addrs: []*net.TCPAddr{&addr0}}, + args: args{addrs: []*EdgeAddr{&addr0}}, }, { name: "two addresses", - args: args{addrs: []*net.TCPAddr{&addr0, &addr1}}, + args: args{addrs: []*EdgeAddr{&addr0, &addr1}}, }, } for _, tt := range tests { diff --git a/edgediscovery/edgediscovery.go b/edgediscovery/edgediscovery.go index a559ae36..c93e18c9 100644 --- a/edgediscovery/edgediscovery.go +++ b/edgediscovery/edgediscovery.go @@ -2,10 +2,10 @@ package edgediscovery import ( "fmt" - "net" "sync" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/cloudflare/cloudflared/edgediscovery/allregions" ) @@ -54,7 +54,7 @@ func StaticEdge(log *zerolog.Logger, hostnames []string) (*Edge, error) { } // MockEdge creates a Cloudflare Edge from arbitrary TCP addresses. Used for testing. -func MockEdge(log *zerolog.Logger, addrs []*net.TCPAddr) *Edge { +func MockEdge(log *zerolog.Logger, addrs []*allregions.EdgeAddr) *Edge { regions := allregions.NewNoResolve(addrs) return &Edge{ log: log, @@ -67,7 +67,7 @@ func MockEdge(log *zerolog.Logger, addrs []*net.TCPAddr) *Edge { // ------------------------------------ // GetAddrForRPC gives this connection an edge Addr. -func (ed *Edge) GetAddrForRPC() (*net.TCPAddr, error) { +func (ed *Edge) GetAddrForRPC() (*allregions.EdgeAddr, error) { ed.Lock() defer ed.Unlock() addr := ed.regions.GetAnyAddress() @@ -78,9 +78,8 @@ func (ed *Edge) GetAddrForRPC() (*net.TCPAddr, error) { } // GetAddr gives this proxy connection an edge Addr. Prefer Addrs this connection has already used. -func (ed *Edge) GetAddr(connIndex int) (*net.TCPAddr, error) { +func (ed *Edge) GetAddr(connIndex int) (*allregions.EdgeAddr, error) { log := ed.log.With().Int(LogFieldConnIndex, connIndex).Logger() - ed.Lock() defer ed.Unlock() @@ -96,14 +95,12 @@ func (ed *Edge) GetAddr(connIndex int) (*net.TCPAddr, error) { log.Debug().Msg("edgediscovery - GetAddr: No addresses left to give proxy connection") return nil, errNoAddressesLeft } - log.Debug().Str(LogFieldAddress, addr.String()).Msg("edgediscovery - GetAddr: Giving connection its new address") + log.Debug().Msg("edgediscovery - GetAddr: Giving connection its new address") return addr, nil } // GetDifferentAddr gives back the proxy connection's edge Addr and uses a new one. -func (ed *Edge) GetDifferentAddr(connIndex int) (*net.TCPAddr, error) { - log := ed.log.With().Int(LogFieldConnIndex, connIndex).Logger() - +func (ed *Edge) GetDifferentAddr(connIndex int) (*allregions.EdgeAddr, error) { ed.Lock() defer ed.Unlock() @@ -117,7 +114,7 @@ func (ed *Edge) GetDifferentAddr(connIndex int) (*net.TCPAddr, error) { // note: if oldAddr were not nil, it will become available on the next iteration return nil, errNoAddressesLeft } - log.Debug().Str(LogFieldAddress, addr.String()).Msg("edgediscovery - GetDifferentAddr: Giving connection its new address") + log.Debug().Msg("edgediscovery - GetDifferentAddr: Giving connection its new address") return addr, nil } @@ -130,7 +127,7 @@ func (ed *Edge) AvailableAddrs() int { // GiveBack the address so that other connections can use it. // Returns true if the address is in this edge. -func (ed *Edge) GiveBack(addr *net.TCPAddr) bool { +func (ed *Edge) GiveBack(addr *allregions.EdgeAddr) bool { ed.Lock() defer ed.Unlock() ed.log.Debug().Msg("edgediscovery - GiveBack: Address now unused") diff --git a/edgediscovery/edgediscovery_test.go b/edgediscovery/edgediscovery_test.go index e52142a1..9cc93807 100644 --- a/edgediscovery/edgediscovery_test.go +++ b/edgediscovery/edgediscovery_test.go @@ -6,35 +6,65 @@ import ( "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + + "github.com/cloudflare/cloudflared/edgediscovery/allregions" ) var ( - addr0 = net.TCPAddr{ - IP: net.ParseIP("123.0.0.0"), - Port: 8000, - Zone: "", + addr0 = allregions.EdgeAddr{ + TCP: &net.TCPAddr{ + IP: net.ParseIP("123.0.0.0"), + Port: 8000, + Zone: "", + }, + UDP: &net.UDPAddr{ + IP: net.ParseIP("123.0.0.0"), + Port: 8000, + Zone: "", + }, } - addr1 = net.TCPAddr{ - IP: net.ParseIP("123.0.0.1"), - Port: 8000, - Zone: "", + addr1 = allregions.EdgeAddr{ + TCP: &net.TCPAddr{ + IP: net.ParseIP("123.0.0.1"), + Port: 8000, + Zone: "", + }, + UDP: &net.UDPAddr{ + IP: net.ParseIP("123.0.0.1"), + Port: 8000, + Zone: "", + }, } - addr2 = net.TCPAddr{ - IP: net.ParseIP("123.0.0.2"), - Port: 8000, - Zone: "", + addr2 = allregions.EdgeAddr{ + TCP: &net.TCPAddr{ + IP: net.ParseIP("123.0.0.2"), + Port: 8000, + Zone: "", + }, + UDP: &net.UDPAddr{ + IP: net.ParseIP("123.0.0.2"), + Port: 8000, + Zone: "", + }, } - addr3 = net.TCPAddr{ - IP: net.ParseIP("123.0.0.3"), - Port: 8000, - Zone: "", + addr3 = allregions.EdgeAddr{ + TCP: &net.TCPAddr{ + IP: net.ParseIP("123.0.0.3"), + Port: 8000, + Zone: "", + }, + UDP: &net.UDPAddr{ + IP: net.ParseIP("123.0.0.3"), + Port: 8000, + Zone: "", + }, } - log = zerolog.Nop() + testLogger = zerolog.Nop() ) func TestGiveBack(t *testing.T) { - edge := MockEdge(&log, []*net.TCPAddr{&addr0, &addr1, &addr2, &addr3}) + edge := MockEdge(&testLogger, []*allregions.EdgeAddr{&addr0, &addr1, &addr2, &addr3}) // Give this connection an address assert.Equal(t, 4, edge.AvailableAddrs()) @@ -51,7 +81,7 @@ func TestGiveBack(t *testing.T) { func TestRPCAndProxyShareSingleEdgeIP(t *testing.T) { // Make an edge with a single IP - edge := MockEdge(&log, []*net.TCPAddr{&addr0}) + edge := MockEdge(&testLogger, []*allregions.EdgeAddr{&addr0}) tunnelConnID := 0 // Use the IP for a tunnel @@ -65,7 +95,7 @@ func TestRPCAndProxyShareSingleEdgeIP(t *testing.T) { } func TestGetAddrForRPC(t *testing.T) { - edge := MockEdge(&log, []*net.TCPAddr{&addr0, &addr1, &addr2, &addr3}) + edge := MockEdge(&testLogger, []*allregions.EdgeAddr{&addr0, &addr1, &addr2, &addr3}) // Get a connection assert.Equal(t, 4, edge.AvailableAddrs()) @@ -83,7 +113,7 @@ func TestGetAddrForRPC(t *testing.T) { func TestOnePerRegion(t *testing.T) { // Make an edge with only one address - edge := MockEdge(&log, []*net.TCPAddr{&addr0, &addr1}) + edge := MockEdge(&testLogger, []*allregions.EdgeAddr{&addr0, &addr1}) // Use the only address const connID = 0 @@ -105,7 +135,7 @@ func TestOnePerRegion(t *testing.T) { func TestOnlyOneAddrLeft(t *testing.T) { // Make an edge with only one address - edge := MockEdge(&log, []*net.TCPAddr{&addr0}) + edge := MockEdge(&testLogger, []*allregions.EdgeAddr{&addr0}) // Use the only address const connID = 0 @@ -125,7 +155,7 @@ func TestOnlyOneAddrLeft(t *testing.T) { func TestNoAddrsLeft(t *testing.T) { // Make an edge with no addresses - edge := MockEdge(&log, []*net.TCPAddr{}) + edge := MockEdge(&testLogger, []*allregions.EdgeAddr{}) _, err := edge.GetAddr(2) assert.Error(t, err) @@ -134,7 +164,7 @@ func TestNoAddrsLeft(t *testing.T) { } func TestGetAddr(t *testing.T) { - edge := MockEdge(&log, []*net.TCPAddr{&addr0, &addr1, &addr2, &addr3}) + edge := MockEdge(&testLogger, []*allregions.EdgeAddr{&addr0, &addr1, &addr2, &addr3}) // Give this connection an address const connID = 0 @@ -149,7 +179,7 @@ func TestGetAddr(t *testing.T) { } func TestGetDifferentAddr(t *testing.T) { - edge := MockEdge(&log, []*net.TCPAddr{&addr0, &addr1, &addr2, &addr3}) + edge := MockEdge(&testLogger, []*allregions.EdgeAddr{&addr0, &addr1, &addr2, &addr3}) // Give this connection an address assert.Equal(t, 4, edge.AvailableAddrs()) diff --git a/origin/supervisor.go b/origin/supervisor.go index 09888514..a0bbaa90 100644 --- a/origin/supervisor.go +++ b/origin/supervisor.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "time" "github.com/google/uuid" @@ -12,6 +11,7 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" + "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" @@ -60,7 +60,7 @@ var errEarlyShutdown = errors.New("shutdown started") type tunnelError struct { index int - addr *net.TCPAddr + addr *allregions.EdgeAddr err error } @@ -226,7 +226,7 @@ func (s *Supervisor) startFirstTunnel( connectedSignal *signal.Signal, ) { var ( - addr *net.TCPAddr + addr *allregions.EdgeAddr err error ) const firstConnIndex = 0 @@ -294,7 +294,7 @@ func (s *Supervisor) startTunnel( connectedSignal *signal.Signal, ) { var ( - addr *net.TCPAddr + addr *allregions.EdgeAddr err error ) defer func() { @@ -347,7 +347,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) return nil, err } - edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.EdgeTLSConfigs[connection.H2mux], arbitraryEdgeIP) + edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.EdgeTLSConfigs[connection.H2mux], arbitraryEdgeIP.TCP) if err != nil { return nil, err } diff --git a/origin/tunnel.go b/origin/tunnel.go index 83a74302..366ac54f 100644 --- a/origin/tunnel.go +++ b/origin/tunnel.go @@ -17,6 +17,7 @@ import ( "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" + "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" @@ -125,7 +126,7 @@ func ServeTunnelLoop( ctx context.Context, credentialManager *reconnectCredentialManager, config *TunnelConfig, - addr *net.TCPAddr, + addr *allregions.EdgeAddr, connIndex uint8, connectedSignal *signal.Signal, cloudflaredUUID uuid.UUID, @@ -246,7 +247,7 @@ func ServeTunnel( connLog *zerolog.Logger, credentialManager *reconnectCredentialManager, config *TunnelConfig, - addr *net.TCPAddr, + addr *allregions.EdgeAddr, connIndex uint8, fuse *h2mux.BooleanFuse, backoff *protocolFallback, @@ -270,7 +271,7 @@ func ServeTunnel( defer config.Observer.SendDisconnect(connIndex) - edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr) + edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP) if err != nil { connLog.Err(err).Msg("Unable to establish connection with Cloudflare edge") return err, true