GH-352: Add Tunnel CLI option "edge-bind-address" (#870)

* Add Tunnel CLI option "edge-bind-address"
This commit is contained in:
iBug 2023-03-01 00:11:42 +08:00 committed by GitHub
parent b97979487e
commit fed60ae4c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 129 additions and 16 deletions

View File

@ -551,11 +551,17 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "edge-ip-version", Name: "edge-ip-version",
Usage: "Cloudflare Edge ip address version to connect with. {4, 6, auto}", Usage: "Cloudflare Edge IP address version to connect with. {4, 6, auto}",
EnvVars: []string{"TUNNEL_EDGE_IP_VERSION"}, EnvVars: []string{"TUNNEL_EDGE_IP_VERSION"},
Value: "4", Value: "4",
Hidden: false, Hidden: false,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "edge-bind-address",
Usage: "Bind to IP address for outgoing connections to Cloudflare Edge.",
EnvVars: []string{"TUNNEL_EDGE_BIND_ADDRESS"},
Hidden: false,
}),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: tlsconfig.CaCertFlag, Name: tlsconfig.CaCertFlag,
Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.", Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.",

View File

@ -44,7 +44,7 @@ var (
secretFlags = [2]*altsrc.StringFlag{credentialsContentsFlag, tunnelTokenFlag} secretFlags = [2]*altsrc.StringFlag{credentialsContentsFlag, tunnelTokenFlag}
defaultFeatures = []string{supervisor.FeatureAllowRemoteConfig, supervisor.FeatureSerializedHeaders, supervisor.FeatureDatagramV2, supervisor.FeatureQUICSupportEOF} defaultFeatures = []string{supervisor.FeatureAllowRemoteConfig, supervisor.FeatureSerializedHeaders, supervisor.FeatureDatagramV2, supervisor.FeatureQUICSupportEOF}
configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version"} configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"}
) )
// returns the first path that contains a cert.pem file. If none of the DefaultConfigSearchDirectories // returns the first path that contains a cert.pem file. If none of the DefaultConfigSearchDirectories
@ -284,6 +284,18 @@ func prepareTunnelConfig(
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
edgeBindAddr, err := parseConfigBindAddress(c.String("edge-bind-address"))
if err != nil {
return nil, nil, err
}
if err := testIPBindable(edgeBindAddr); err != nil {
return nil, nil, fmt.Errorf("invalid edge-bind-address %s: %v", edgeBindAddr, err)
}
edgeIPVersion, err = adjustIPVersionByBindAddress(edgeIPVersion, edgeBindAddr)
if err != nil {
// This is not a fatal error, we just overrode edgeIPVersion
log.Warn().Str("edgeIPVersion", edgeIPVersion.String()).Err(err).Msg("Overriding edge-ip-version")
}
var pqKexIdx int var pqKexIdx int
if needPQ { if needPQ {
@ -302,6 +314,7 @@ func prepareTunnelConfig(
EdgeAddrs: c.StringSlice("edge"), EdgeAddrs: c.StringSlice("edge"),
Region: c.String("region"), Region: c.String("region"),
EdgeIPVersion: edgeIPVersion, EdgeIPVersion: edgeIPVersion,
EdgeBindAddr: edgeBindAddr,
HAConnections: c.Int("ha-connections"), HAConnections: c.Int("ha-connections"),
IncidentLookup: supervisor.NewIncidentLookup(), IncidentLookup: supervisor.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"), IsAutoupdated: c.Bool("is-autoupdated"),
@ -394,6 +407,51 @@ func parseConfigIPVersion(version string) (v allregions.ConfigIPVersion, err err
return return
} }
func parseConfigBindAddress(ipstr string) (net.IP, error) {
// Unspecified - it's fine
if ipstr == "" {
return nil, nil
}
ip := net.ParseIP(ipstr)
if ip == nil {
return nil, fmt.Errorf("invalid value for edge-bind-address: %s", ipstr)
}
return ip, nil
}
func testIPBindable(ip net.IP) error {
// "Unspecified" = let OS choose, so always bindable
if ip == nil {
return nil
}
addr := &net.UDPAddr{IP: ip, Port: 0}
listener, err := net.ListenUDP("udp", addr)
if err != nil {
return err
}
listener.Close()
return nil
}
func adjustIPVersionByBindAddress(ipVersion allregions.ConfigIPVersion, ip net.IP) (allregions.ConfigIPVersion, error) {
if ip == nil {
return ipVersion, nil
}
// https://pkg.go.dev/net#IP.To4: "If ip is not an IPv4 address, To4 returns nil."
if ip.To4() != nil {
if ipVersion == allregions.IPv6Only {
return allregions.IPv4Only, fmt.Errorf("IPv4 bind address is specified, but edge-ip-version is IPv6")
}
return allregions.IPv4Only, nil
} else {
if ipVersion == allregions.IPv4Only {
return allregions.IPv6Only, fmt.Errorf("IPv6 bind address is specified, but edge-ip-version is IPv4")
}
return allregions.IPv6Only, nil
}
}
func newPacketConfig(c *cli.Context, logger *zerolog.Logger) (*ingress.GlobalRouterConfig, error) { func newPacketConfig(c *cli.Context, logger *zerolog.Logger) (*ingress.GlobalRouterConfig, error) {
ipv4Src, err := determineICMPv4Src(c.String("icmpv4-src"), logger) ipv4Src, err := determineICMPv4Src(c.String("icmpv4-src"), logger)
if err != nil { if err != nil {

View File

@ -9,6 +9,7 @@ import (
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"net"
"os" "os"
"testing" "testing"
@ -214,3 +215,23 @@ func getCertPoolSubjects(certPool *x509.CertPool) ([]*pkix.Name, error) {
func isUnrecoverableError(err error) bool { func isUnrecoverableError(err error) bool {
return err != nil && err.Error() != "crypto/x509: system root pool is not available on Windows" return err != nil && err.Error() != "crypto/x509: system root pool is not available on Windows"
} }
func TestTestIPBindable(t *testing.T) {
assert.Nil(t, testIPBindable(nil))
// Public services - if one of these IPs is on the machine, the test environment is too weird
assert.NotNil(t, testIPBindable(net.ParseIP("8.8.8.8")))
assert.NotNil(t, testIPBindable(net.ParseIP("1.1.1.1")))
addrs, err := net.InterfaceAddrs()
if err != nil {
t.Fatal(err)
}
for i, addr := range addrs {
if i >= 3 {
break
}
ip := addr.(*net.IPNet).IP
assert.Nil(t, testIPBindable(ip))
}
}

View File

@ -67,6 +67,7 @@ type QUICConnection struct {
func NewQUICConnection( func NewQUICConnection(
quicConfig *quic.Config, quicConfig *quic.Config,
edgeAddr net.Addr, edgeAddr net.Addr,
localAddr net.IP,
connIndex uint8, connIndex uint8,
tlsConfig *tls.Config, tlsConfig *tls.Config,
orchestrator Orchestrator, orchestrator Orchestrator,
@ -75,7 +76,7 @@ func NewQUICConnection(
logger *zerolog.Logger, logger *zerolog.Logger,
packetRouterConfig *ingress.GlobalRouterConfig, packetRouterConfig *ingress.GlobalRouterConfig,
) (*QUICConnection, error) { ) (*QUICConnection, error) {
udpConn, err := createUDPConnForConnIndex(connIndex, logger) udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -563,13 +564,17 @@ func (rp *muxerWrapper) Close() error {
return nil return nil
} }
func createUDPConnForConnIndex(connIndex uint8, logger *zerolog.Logger) (*net.UDPConn, error) { func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, logger *zerolog.Logger) (*net.UDPConn, error) {
portMapMutex.Lock() portMapMutex.Lock()
defer portMapMutex.Unlock() defer portMapMutex.Unlock()
if localIP == nil {
localIP = net.IPv4zero
}
// if port was not set yet, it will be zero, so bind will randomly allocate one. // if port was not set yet, it will be zero, so bind will randomly allocate one.
if port, ok := portForConnIndex[connIndex]; ok { if port, ok := portForConnIndex[connIndex]; ok {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: port}) udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: localIP, Port: port})
// if there wasn't an error, or if port was 0 (independently of error or not, just return) // if there wasn't an error, or if port was 0 (independently of error or not, just return)
if err == nil { if err == nil {
return udpConn, nil return udpConn, nil
@ -579,7 +584,7 @@ func createUDPConnForConnIndex(connIndex uint8, logger *zerolog.Logger) (*net.UD
} }
// if we reached here, then there was an error or port as not been allocated it. // if we reached here, then there was an error or port as not been allocated it.
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: localIP, Port: 0})
if err == nil { if err == nil {
udpAddr, ok := (udpConn.LocalAddr()).(*net.UDPAddr) udpAddr, ok := (udpConn.LocalAddr()).(*net.UDPAddr)
if !ok { if !ok {

View File

@ -572,7 +572,7 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
func TestCreateUDPConnReuseSourcePort(t *testing.T) { func TestCreateUDPConnReuseSourcePort(t *testing.T) {
logger := zerolog.Nop() logger := zerolog.Nop()
conn, err := createUDPConnForConnIndex(0, &logger) conn, err := createUDPConnForConnIndex(0, nil, &logger)
require.NoError(t, err) require.NoError(t, err)
getPortFunc := func(conn *net.UDPConn) int { getPortFunc := func(conn *net.UDPConn) int {
@ -586,17 +586,17 @@ func TestCreateUDPConnReuseSourcePort(t *testing.T) {
conn.Close() conn.Close()
// should get the same port as before. // should get the same port as before.
conn, err = createUDPConnForConnIndex(0, &logger) conn, err = createUDPConnForConnIndex(0, nil, &logger)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, initialPort, getPortFunc(conn)) require.Equal(t, initialPort, getPortFunc(conn))
// new index, should get a different port // new index, should get a different port
conn1, err := createUDPConnForConnIndex(1, &logger) conn1, err := createUDPConnForConnIndex(1, nil, &logger)
require.NoError(t, err) require.NoError(t, err)
require.NotEqual(t, initialPort, getPortFunc(conn1)) require.NotEqual(t, initialPort, getPortFunc(conn1))
// not closing the conn and trying to obtain a new conn for same index should give a different random port // not closing the conn and trying to obtain a new conn for same index should give a different random port
conn, err = createUDPConnForConnIndex(0, &logger) conn, err = createUDPConnForConnIndex(0, nil, &logger)
require.NoError(t, err) require.NoError(t, err)
require.NotEqual(t, initialPort, getPortFunc(conn)) require.NotEqual(t, initialPort, getPortFunc(conn))
} }
@ -716,6 +716,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU
qc, err := NewQUICConnection( qc, err := NewQUICConnection(
testQUICConfig, testQUICConfig,
udpListenerAddr, udpListenerAddr,
nil,
index, index,
tlsClientConfig, tlsClientConfig,
&mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}}, &mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}},

View File

@ -41,6 +41,19 @@ const (
IPv6Only ConfigIPVersion = 6 IPv6Only ConfigIPVersion = 6
) )
func (c ConfigIPVersion) String() string {
switch c {
case Auto:
return "auto"
case IPv4Only:
return "4"
case IPv6Only:
return "6"
default:
return ""
}
}
// IPVersion is the IP version of an EdgeAddr // IPVersion is the IP version of an EdgeAddr
type EdgeIPVersion int8 type EdgeIPVersion int8

View File

@ -15,12 +15,16 @@ func DialEdge(
timeout time.Duration, timeout time.Duration,
tlsConfig *tls.Config, tlsConfig *tls.Config,
edgeTCPAddr *net.TCPAddr, edgeTCPAddr *net.TCPAddr,
localIP net.IP,
) (net.Conn, error) { ) (net.Conn, error) {
// Inherit from parent context so we can cancel (Ctrl-C) while dialing // Inherit from parent context so we can cancel (Ctrl-C) while dialing
dialCtx, dialCancel := context.WithTimeout(ctx, timeout) dialCtx, dialCancel := context.WithTimeout(ctx, timeout)
defer dialCancel() defer dialCancel()
dialer := net.Dialer{} dialer := net.Dialer{}
if localIP != nil {
dialer.LocalAddr = &net.TCPAddr{IP: localIP, Port: 0}
}
edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeTCPAddr.String()) edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeTCPAddr.String())
if err != nil { if err != nil {
return nil, newDialError(err, "DialContext error") return nil, newDialError(err, "DialContext error")

View File

@ -82,6 +82,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
log := NewConnAwareLogger(config.Log, tracker, config.Observer) log := NewConnAwareLogger(config.Log, tracker, config.Observer)
edgeAddrHandler := NewIPAddrFallback(config.MaxEdgeAddrRetries) edgeAddrHandler := NewIPAddrFallback(config.MaxEdgeAddrRetries)
edgeBindAddr := config.EdgeBindAddr
edgeTunnelServer := EdgeTunnelServer{ edgeTunnelServer := EdgeTunnelServer{
config: config, config: config,
@ -89,6 +90,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
credentialManager: reconnectCredentialManager, credentialManager: reconnectCredentialManager,
edgeAddrs: edgeIPs, edgeAddrs: edgeIPs,
edgeAddrHandler: edgeAddrHandler, edgeAddrHandler: edgeAddrHandler,
edgeBindAddr: edgeBindAddr,
tracker: tracker, tracker: tracker,
reconnectCh: reconnectCh, reconnectCh: reconnectCh,
gracefulShutdownC: gracefulShutdownC, gracefulShutdownC: gracefulShutdownC,

View File

@ -49,6 +49,7 @@ type TunnelConfig struct {
EdgeAddrs []string EdgeAddrs []string
Region string Region string
EdgeIPVersion allregions.ConfigIPVersion EdgeIPVersion allregions.ConfigIPVersion
EdgeBindAddr net.IP
HAConnections int HAConnections int
IncidentLookup IncidentLookup IncidentLookup IncidentLookup
IsAutoupdated bool IsAutoupdated bool
@ -207,6 +208,7 @@ type EdgeTunnelServer struct {
credentialManager *reconnectCredentialManager credentialManager *reconnectCredentialManager
edgeAddrHandler EdgeAddrHandler edgeAddrHandler EdgeAddrHandler
edgeAddrs *edgediscovery.Edge edgeAddrs *edgediscovery.Edge
edgeBindAddr net.IP
reconnectCh chan ReconnectSignal reconnectCh chan ReconnectSignal
gracefulShutdownC <-chan struct{} gracefulShutdownC <-chan struct{}
tracker *tunnelstate.ConnTracker tracker *tunnelstate.ConnTracker
@ -497,7 +499,7 @@ func (e *EdgeTunnelServer) serveConnection(
connIndex) connIndex)
case connection.HTTP2: case connection.HTTP2:
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP, e.edgeBindAddr)
if err != nil { if err != nil {
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge") connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
return err, true return err, true
@ -516,7 +518,7 @@ func (e *EdgeTunnelServer) serveConnection(
} }
default: default:
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP, e.edgeBindAddr)
if err != nil { if err != nil {
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge") connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
return err, true return err, true
@ -672,6 +674,7 @@ func (e *EdgeTunnelServer) serveQUIC(
quicConn, err := connection.NewQUICConnection( quicConn, err := connection.NewQUICConnection(
quicConfig, quicConfig,
edgeAddr, edgeAddr,
e.edgeBindAddr,
connIndex, connIndex,
tlsConfig, tlsConfig,
e.orchestrator, e.orchestrator,