From fefef3c43bba8a846d1ba7f6a99e5a99c21bf5a1 Mon Sep 17 00:00:00 2001 From: iBug Date: Thu, 12 Jan 2023 14:53:10 +0800 Subject: [PATCH] Replace Dial with ListenUDP, add unit test --- cmd/cloudflared/tunnel/configuration.go | 14 ++++++------- cmd/cloudflared/tunnel/configuration_test.go | 21 ++++++++++++++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index ac2ff126..dd2ca67d 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -489,19 +489,17 @@ func parseConfigBindAddress(ipstr string) (net.IP, error) { } func testIPBindable(ip net.IP) error { - var network, address string - if ip.To4() != nil { - network, address = "udp4", "127.0.0.1:4" - } else { - network, address = "udp6", "[::1]:4" + // "Unspecified" = let OS choose, so always bindable + if ip == nil { + return nil } - dialer := net.Dialer{LocalAddr: &net.UDPAddr{IP: ip}} - conn, err := dialer.Dial(network, address) + addr := &net.UDPAddr{IP: ip, Port: 0} + listener, err := net.ListenUDP("udp", addr) if err != nil { return err } - conn.Close() + listener.Close() return nil } diff --git a/cmd/cloudflared/tunnel/configuration_test.go b/cmd/cloudflared/tunnel/configuration_test.go index b4edf636..237bc829 100644 --- a/cmd/cloudflared/tunnel/configuration_test.go +++ b/cmd/cloudflared/tunnel/configuration_test.go @@ -9,6 +9,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/asn1" + "net" "os" "testing" @@ -214,3 +215,23 @@ func getCertPoolSubjects(certPool *x509.CertPool) ([]*pkix.Name, error) { func isUnrecoverableError(err error) bool { return err != nil && err.Error() != "crypto/x509: system root pool is not available on Windows" } + +func TestTestIPBindable(t *testing.T) { + assert.Nil(t, testIPBindable(nil)) + + // Public services - if one of these IPs is on the machine, the test environment is too weird + assert.NotNil(t, testIPBindable(net.ParseIP("8.8.8.8"))) + assert.NotNil(t, testIPBindable(net.ParseIP("1.1.1.1"))) + + addrs, err := net.InterfaceAddrs() + if err != nil { + t.Fatal(err) + } + for i, addr := range addrs { + if i >= 3 { + break + } + ip := addr.(*net.IPNet).IP + assert.Nil(t, testIPBindable(ip)) + } +}