90 lines
2.4 KiB
Go
90 lines
2.4 KiB
Go
|
package allregions
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"math"
|
||
|
"math/rand"
|
||
|
"net"
|
||
|
"reflect"
|
||
|
"testing/quick"
|
||
|
)
|
||
|
|
||
|
type mockAddrs struct {
|
||
|
// a set of synthetic SRV records
|
||
|
addrMap map[net.SRV][]*net.TCPAddr
|
||
|
// the total number of addresses, aggregated across addrMap.
|
||
|
// For the convenience of test code that would otherwise have to compute
|
||
|
// this by hand every time.
|
||
|
numAddrs int
|
||
|
}
|
||
|
|
||
|
func newMockAddrs(port uint16, numRegions uint8, numAddrsPerRegion uint8) mockAddrs {
|
||
|
addrMap := make(map[net.SRV][]*net.TCPAddr)
|
||
|
numAddrs := 0
|
||
|
|
||
|
for r := uint8(0); r < numRegions; r++ {
|
||
|
var (
|
||
|
srv = net.SRV{Target: fmt.Sprintf("test-region-%v.example.com", r), Port: port}
|
||
|
addrs []*net.TCPAddr
|
||
|
)
|
||
|
for a := uint8(0); a < numAddrsPerRegion; a++ {
|
||
|
addrs = append(addrs, &net.TCPAddr{
|
||
|
IP: net.ParseIP(fmt.Sprintf("10.0.%v.%v", r, a)),
|
||
|
Port: int(port),
|
||
|
})
|
||
|
}
|
||
|
addrMap[srv] = addrs
|
||
|
numAddrs += len(addrs)
|
||
|
}
|
||
|
return mockAddrs{addrMap: addrMap, numAddrs: numAddrs}
|
||
|
}
|
||
|
|
||
|
var _ quick.Generator = mockAddrs{}
|
||
|
|
||
|
func (mockAddrs) Generate(rand *rand.Rand, size int) reflect.Value {
|
||
|
port := uint16(rand.Intn(math.MaxUint16))
|
||
|
numRegions := uint8(1 + rand.Intn(10))
|
||
|
numAddrsPerRegion := uint8(1 + rand.Intn(32))
|
||
|
result := newMockAddrs(port, numRegions, numAddrsPerRegion)
|
||
|
return reflect.ValueOf(result)
|
||
|
}
|
||
|
|
||
|
// Returns a function compatible with net.LookupSRV that will return the SRV
|
||
|
// records from mockAddrs.
|
||
|
func mockNetLookupSRV(
|
||
|
m mockAddrs,
|
||
|
) func(service, proto, name string) (cname string, addrs []*net.SRV, err error) {
|
||
|
var addrs []*net.SRV
|
||
|
for k := range m.addrMap {
|
||
|
addr := k
|
||
|
addrs = append(addrs, &addr)
|
||
|
// We can't just do
|
||
|
// addrs = append(addrs, &k)
|
||
|
// `k` will be reused by subsequent loop iterations,
|
||
|
// so all the copies of `&k` would point to the same location.
|
||
|
}
|
||
|
return func(_, _, _ string) (string, []*net.SRV, error) {
|
||
|
return "", addrs, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Returns a function compatible with net.LookupIP that translates the SRV records
|
||
|
// from mockAddrs into IP addresses, based on the TCP addresses in mockAddrs.
|
||
|
func mockNetLookupIP(
|
||
|
m mockAddrs,
|
||
|
) func(host string) ([]net.IP, error) {
|
||
|
return func(host string) ([]net.IP, error) {
|
||
|
for srv, tcpAddrs := range m.addrMap {
|
||
|
if srv.Target != host {
|
||
|
continue
|
||
|
}
|
||
|
result := make([]net.IP, len(tcpAddrs))
|
||
|
for i, tcpAddr := range tcpAddrs {
|
||
|
result[i] = tcpAddr.IP
|
||
|
}
|
||
|
return result, nil
|
||
|
}
|
||
|
return nil, fmt.Errorf("No IPs for %v", host)
|
||
|
}
|
||
|
}
|