cloudflared-mirror/edgediscovery/allregions/mocks_for_test.go

90 lines
2.4 KiB
Go
Raw Permalink Normal View History

package allregions
import (
"fmt"
"math"
"math/rand"
"net"
"reflect"
"testing/quick"
)
// mockAddrs is a test fixture that pairs synthetic SRV records with the TCP
// addresses standing behind each of them.
type mockAddrs struct {
	// addrMap associates every synthetic SRV record with its TCP addresses.
	addrMap map[net.SRV][]*net.TCPAddr

	// numAddrs is the address count summed over all entries of addrMap.
	// It is stored so that test code does not need to recompute it by hand
	// every time it wants the total.
	numAddrs int
}
// newMockAddrs builds a mockAddrs containing numRegions synthetic SRV records,
// each mapped to numAddrsPerRegion TCP addresses that all use the given port.
//
// Region r gets the SRV target "test-region-<r>.example.com" and the addresses
// 10.0.<r>.0 up to 10.0.<r>.<numAddrsPerRegion-1>. The returned numAddrs is the
// total number of addresses created across all regions.
func newMockAddrs(port uint16, numRegions uint8, numAddrsPerRegion uint8) mockAddrs {
	mock := mockAddrs{addrMap: make(map[net.SRV][]*net.TCPAddr)}

	for region := uint8(0); region < numRegions; region++ {
		key := net.SRV{
			Target: fmt.Sprintf("test-region-%v.example.com", region),
			Port:   port,
		}

		// Left nil on purpose: a region with zero addresses maps to a nil slice.
		var regionAddrs []*net.TCPAddr
		for host := uint8(0); host < numAddrsPerRegion; host++ {
			ip := net.ParseIP(fmt.Sprintf("10.0.%v.%v", region, host))
			regionAddrs = append(regionAddrs, &net.TCPAddr{IP: ip, Port: int(port)})
		}

		mock.addrMap[key] = regionAddrs
		mock.numAddrs += len(regionAddrs)
	}

	return mock
}
// Compile-time assertion that mockAddrs satisfies quick.Generator.
var _ quick.Generator = mockAddrs{}

// Generate implements quick.Generator. It produces a mockAddrs with a random
// port, between 1 and 10 regions, and between 1 and 32 addresses per region.
// The size argument is not used.
//
// Note: rand.Intn excludes its upper bound, so the port is drawn from
// [0, math.MaxUint16) and the value 65535 is never produced.
func (mockAddrs) Generate(rand *rand.Rand, size int) reflect.Value {
	// The three draws must stay in this order so that a given seed keeps
	// producing the same fixture.
	randomPort := uint16(rand.Intn(math.MaxUint16))
	regionCount := uint8(1 + rand.Intn(10))
	addrsPerRegion := uint8(1 + rand.Intn(32))

	return reflect.ValueOf(newMockAddrs(randomPort, regionCount, addrsPerRegion))
}
// mockNetLookupSRV returns a function with the same signature as
// net.LookupSRV that serves the SRV records held in m.
//
// The records are collected once, when mockNetLookupSRV is called. The
// returned function ignores its arguments and always yields an empty cname,
// those records, and a nil error. Record order follows map iteration and is
// therefore unspecified.
func mockNetLookupSRV(
	m mockAddrs,
) func(service, proto, name string) (cname string, addrs []*net.SRV, err error) {
	var records []*net.SRV
	for srv := range m.addrMap {
		// Give every record its own storage rather than appending &srv:
		// the range variable may be reused by subsequent iterations, in
		// which case every stored pointer would refer to the same location.
		record := new(net.SRV)
		*record = srv
		records = append(records, record)
	}

	return func(_, _, _ string) (string, []*net.SRV, error) {
		return "", records, nil
	}
}
// mockNetLookupIP returns a function with the same signature as net.LookupIP
// that resolves a host by matching it against the SRV targets in m.
//
// On a match, the returned function yields the IPs of the TCP addresses stored
// for that SRV record; if no SRV target equals host it returns an error. The
// map is consulted on every call rather than snapshotted up front.
func mockNetLookupIP(m mockAddrs) func(host string) ([]net.IP, error) {
	lookup := func(host string) ([]net.IP, error) {
		for srv, tcpAddrs := range m.addrMap {
			if srv.Target == host {
				ips := make([]net.IP, 0, len(tcpAddrs))
				for _, tcpAddr := range tcpAddrs {
					ips = append(ips, tcpAddr.IP)
				}
				return ips, nil
			}
		}
		return nil, fmt.Errorf("No IPs for %v", host)
	}
	return lookup
}