Add Tunnel CLI option "edge-bind-address"

This commit is contained in:
iBug 2023-01-10 18:28:09 +08:00
parent de4fd472f3
commit 5a8a45be3d
7 changed files with 55 additions and 17 deletions

View File

@ -83,12 +83,12 @@ const (
LogFieldTmpTraceFilename = "tmpTraceFilename" LogFieldTmpTraceFilename = "tmpTraceFilename"
LogFieldTraceOutputFilepath = "traceOutputFilepath" LogFieldTraceOutputFilepath = "traceOutputFilepath"
tunnelCmdErrorMessage = `You did not specify any valid additional argument to the cloudflared tunnel command. tunnelCmdErrorMessage = `You did not specify any valid additional argument to the cloudflared tunnel command.
If you are trying to run a Quick Tunnel then you need to explicitly pass the --url flag. If you are trying to run a Quick Tunnel then you need to explicitly pass the --url flag.
Eg. cloudflared tunnel --url localhost:8080/. Eg. cloudflared tunnel --url localhost:8080/.
Please note that Quick Tunnels are meant to be ephemeral and should only be used for testing purposes. Please note that Quick Tunnels are meant to be ephemeral and should only be used for testing purposes.
For production usage, we recommend creating Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/) For production usage, we recommend creating Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)
` `
) )
@ -539,11 +539,17 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: "edge-ip-version", Name: "edge-ip-version",
Usage: "Cloudflare Edge ip address version to connect with. {4, 6, auto}", Usage: "Cloudflare Edge IP address version to connect with. {4, 6, auto}",
EnvVars: []string{"TUNNEL_EDGE_IP_VERSION"}, EnvVars: []string{"TUNNEL_EDGE_IP_VERSION"},
Value: "4", Value: "4",
Hidden: false, Hidden: false,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "edge-bind-address",
Usage: "Bind to IP address for outgoing connections to Cloudflare Edge.",
EnvVars: []string{"TUNNEL_EDGE_BIND_ADDRESS"},
Hidden: false,
}),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: tlsconfig.CaCertFlag, Name: tlsconfig.CaCertFlag,
Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.", Usage: "Certificate Authority authenticating connections with Cloudflare's edge network.",

View File

@ -48,7 +48,7 @@ var (
secretFlags = [2]*altsrc.StringFlag{credentialsContentsFlag, tunnelTokenFlag} secretFlags = [2]*altsrc.StringFlag{credentialsContentsFlag, tunnelTokenFlag}
defaultFeatures = []string{supervisor.FeatureAllowRemoteConfig, supervisor.FeatureSerializedHeaders, supervisor.FeatureDatagramV2, supervisor.FeatureQUICSupportEOF} defaultFeatures = []string{supervisor.FeatureAllowRemoteConfig, supervisor.FeatureSerializedHeaders, supervisor.FeatureDatagramV2, supervisor.FeatureQUICSupportEOF}
configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version"} configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"}
) )
// returns the first path that contains a cert.pem file. If none of the DefaultConfigSearchDirectories // returns the first path that contains a cert.pem file. If none of the DefaultConfigSearchDirectories
@ -349,6 +349,11 @@ func prepareTunnelConfig(
return nil, nil, err return nil, nil, err
} }
edgeBindAddr, err := parseConfigBindAddress(c.String("edge-bind-address"))
if err != nil {
return nil, nil, err
}
var pqKexIdx int var pqKexIdx int
if needPQ { if needPQ {
pqKexIdx = mathRand.Intn(len(supervisor.PQKexes)) pqKexIdx = mathRand.Intn(len(supervisor.PQKexes))
@ -366,6 +371,7 @@ func prepareTunnelConfig(
EdgeAddrs: c.StringSlice("edge"), EdgeAddrs: c.StringSlice("edge"),
Region: c.String("region"), Region: c.String("region"),
EdgeIPVersion: edgeIPVersion, EdgeIPVersion: edgeIPVersion,
EdgeBindAddr: edgeBindAddr,
HAConnections: c.Int("ha-connections"), HAConnections: c.Int("ha-connections"),
IncidentLookup: supervisor.NewIncidentLookup(), IncidentLookup: supervisor.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"), IsAutoupdated: c.Bool("is-autoupdated"),
@ -463,6 +469,18 @@ func parseConfigIPVersion(version string) (v allregions.ConfigIPVersion, err err
return return
} }
func parseConfigBindAddress(ipstr string) (net.IP, error) {
// Unspecified - it's fine
if ipstr == "" {
return nil, nil
}
ip := net.ParseIP(ipstr)
if ip == nil {
return nil, fmt.Errorf("invalid value for edge-bind-address: %s", ipstr)
}
return ip, nil
}
func newPacketConfig(c *cli.Context, logger *zerolog.Logger) (*ingress.GlobalRouterConfig, error) { func newPacketConfig(c *cli.Context, logger *zerolog.Logger) (*ingress.GlobalRouterConfig, error) {
ipv4Src, err := determineICMPv4Src(c.String("icmpv4-src"), logger) ipv4Src, err := determineICMPv4Src(c.String("icmpv4-src"), logger)
if err != nil { if err != nil {

View File

@ -66,6 +66,7 @@ type QUICConnection struct {
func NewQUICConnection( func NewQUICConnection(
quicConfig *quic.Config, quicConfig *quic.Config,
edgeAddr net.Addr, edgeAddr net.Addr,
localAddr net.IP,
connIndex uint8, connIndex uint8,
tlsConfig *tls.Config, tlsConfig *tls.Config,
orchestrator Orchestrator, orchestrator Orchestrator,
@ -74,7 +75,7 @@ func NewQUICConnection(
logger *zerolog.Logger, logger *zerolog.Logger,
packetRouterConfig *ingress.GlobalRouterConfig, packetRouterConfig *ingress.GlobalRouterConfig,
) (*QUICConnection, error) { ) (*QUICConnection, error) {
udpConn, err := createUDPConnForConnIndex(connIndex, logger) udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -525,13 +526,17 @@ func (rp *muxerWrapper) Close() error {
return nil return nil
} }
func createUDPConnForConnIndex(connIndex uint8, logger *zerolog.Logger) (*net.UDPConn, error) { func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, logger *zerolog.Logger) (*net.UDPConn, error) {
portMapMutex.Lock() portMapMutex.Lock()
defer portMapMutex.Unlock() defer portMapMutex.Unlock()
if localIP == nil {
localIP = net.IPv4zero
}
// if port was not set yet, it will be zero, so bind will randomly allocate one. // if port was not set yet, it will be zero, so bind will randomly allocate one.
if port, ok := portForConnIndex[connIndex]; ok { if port, ok := portForConnIndex[connIndex]; ok {
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: port}) udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: localIP, Port: port})
// if there wasn't an error, or if port was 0 (independently of error or not, just return) // if there wasn't an error, or if port was 0 (independently of error or not, just return)
if err == nil { if err == nil {
return udpConn, nil return udpConn, nil
@ -541,7 +546,7 @@ func createUDPConnForConnIndex(connIndex uint8, logger *zerolog.Logger) (*net.UD
} }
// if we reached here, then there was an error or port as not been allocated it. // if we reached here, then there was an error or port as not been allocated it.
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: localIP, Port: 0})
if err == nil { if err == nil {
udpAddr, ok := (udpConn.LocalAddr()).(*net.UDPAddr) udpAddr, ok := (udpConn.LocalAddr()).(*net.UDPAddr)
if !ok { if !ok {

View File

@ -572,7 +572,7 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
func TestCreateUDPConnReuseSourcePort(t *testing.T) { func TestCreateUDPConnReuseSourcePort(t *testing.T) {
logger := zerolog.Nop() logger := zerolog.Nop()
conn, err := createUDPConnForConnIndex(0, &logger) conn, err := createUDPConnForConnIndex(0, nil, &logger)
require.NoError(t, err) require.NoError(t, err)
getPortFunc := func(conn *net.UDPConn) int { getPortFunc := func(conn *net.UDPConn) int {
@ -586,17 +586,17 @@ func TestCreateUDPConnReuseSourcePort(t *testing.T) {
conn.Close() conn.Close()
// should get the same port as before. // should get the same port as before.
conn, err = createUDPConnForConnIndex(0, &logger) conn, err = createUDPConnForConnIndex(0, nil, &logger)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, initialPort, getPortFunc(conn)) require.Equal(t, initialPort, getPortFunc(conn))
// new index, should get a different port // new index, should get a different port
conn1, err := createUDPConnForConnIndex(1, &logger) conn1, err := createUDPConnForConnIndex(1, nil, &logger)
require.NoError(t, err) require.NoError(t, err)
require.NotEqual(t, initialPort, getPortFunc(conn1)) require.NotEqual(t, initialPort, getPortFunc(conn1))
// not closing the conn and trying to obtain a new conn for same index should give a different random port // not closing the conn and trying to obtain a new conn for same index should give a different random port
conn, err = createUDPConnForConnIndex(0, &logger) conn, err = createUDPConnForConnIndex(0, nil, &logger)
require.NoError(t, err) require.NoError(t, err)
require.NotEqual(t, initialPort, getPortFunc(conn)) require.NotEqual(t, initialPort, getPortFunc(conn))
} }

View File

@ -15,12 +15,16 @@ func DialEdge(
timeout time.Duration, timeout time.Duration,
tlsConfig *tls.Config, tlsConfig *tls.Config,
edgeTCPAddr *net.TCPAddr, edgeTCPAddr *net.TCPAddr,
localIP net.IP,
) (net.Conn, error) { ) (net.Conn, error) {
// Inherit from parent context so we can cancel (Ctrl-C) while dialing // Inherit from parent context so we can cancel (Ctrl-C) while dialing
dialCtx, dialCancel := context.WithTimeout(ctx, timeout) dialCtx, dialCancel := context.WithTimeout(ctx, timeout)
defer dialCancel() defer dialCancel()
dialer := net.Dialer{} dialer := net.Dialer{}
if localIP != nil {
dialer.LocalAddr = &net.TCPAddr{IP: localIP, Port: 0}
}
edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeTCPAddr.String()) edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeTCPAddr.String())
if err != nil { if err != nil {
return nil, newDialError(err, "DialContext error") return nil, newDialError(err, "DialContext error")

View File

@ -94,6 +94,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
log := NewConnAwareLogger(config.Log, tracker, config.Observer) log := NewConnAwareLogger(config.Log, tracker, config.Observer)
edgeAddrHandler := NewIPAddrFallback(config.MaxEdgeAddrRetries) edgeAddrHandler := NewIPAddrFallback(config.MaxEdgeAddrRetries)
edgeBindAddr := config.EdgeBindAddr
edgeTunnelServer := EdgeTunnelServer{ edgeTunnelServer := EdgeTunnelServer{
config: config, config: config,
@ -102,6 +103,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
credentialManager: reconnectCredentialManager, credentialManager: reconnectCredentialManager,
edgeAddrs: edgeIPs, edgeAddrs: edgeIPs,
edgeAddrHandler: edgeAddrHandler, edgeAddrHandler: edgeAddrHandler,
edgeBindAddr: edgeBindAddr,
tracker: tracker, tracker: tracker,
reconnectCh: reconnectCh, reconnectCh: reconnectCh,
gracefulShutdownC: gracefulShutdownC, gracefulShutdownC: gracefulShutdownC,
@ -384,7 +386,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
return nil, err return nil, err
} }
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.EdgeTLSConfigs[connection.H2mux], arbitraryEdgeIP.TCP) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.EdgeTLSConfigs[connection.H2mux], arbitraryEdgeIP.TCP, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -49,6 +49,7 @@ type TunnelConfig struct {
EdgeAddrs []string EdgeAddrs []string
Region string Region string
EdgeIPVersion allregions.ConfigIPVersion EdgeIPVersion allregions.ConfigIPVersion
EdgeBindAddr net.IP
HAConnections int HAConnections int
IncidentLookup IncidentLookup IncidentLookup IncidentLookup
IsAutoupdated bool IsAutoupdated bool
@ -209,6 +210,7 @@ type EdgeTunnelServer struct {
credentialManager *reconnectCredentialManager credentialManager *reconnectCredentialManager
edgeAddrHandler EdgeAddrHandler edgeAddrHandler EdgeAddrHandler
edgeAddrs *edgediscovery.Edge edgeAddrs *edgediscovery.Edge
edgeBindAddr net.IP
reconnectCh chan ReconnectSignal reconnectCh chan ReconnectSignal
gracefulShutdownC <-chan struct{} gracefulShutdownC <-chan struct{}
tracker *tunnelstate.ConnTracker tracker *tunnelstate.ConnTracker
@ -499,7 +501,7 @@ func (e *EdgeTunnelServer) serveConnection(
connIndex) connIndex)
case connection.HTTP2, connection.HTTP2Warp: case connection.HTTP2, connection.HTTP2Warp:
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP, e.edgeBindAddr)
if err != nil { if err != nil {
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge") connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
return err, true return err, true
@ -518,7 +520,7 @@ func (e *EdgeTunnelServer) serveConnection(
} }
default: default:
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP, e.edgeBindAddr)
if err != nil { if err != nil {
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge") connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
return err, true return err, true
@ -678,6 +680,7 @@ func (e *EdgeTunnelServer) serveQUIC(
quicConn, err := connection.NewQUICConnection( quicConn, err := connection.NewQUICConnection(
quicConfig, quicConfig,
edgeAddr, edgeAddr,
e.edgeBindAddr,
connIndex, connIndex,
tlsConfig, tlsConfig,
e.orchestrator, e.orchestrator,