TUN-6530: Implement ICMPv4 proxy

This proxy uses unprivileged datagram-oriented endpoint and is shared by all quic connections
This commit is contained in:
cthuang 2022-08-18 16:03:47 +01:00
parent f6bd4aa039
commit 59f5b0df83
10 changed files with 440 additions and 126 deletions

View File

@ -48,6 +48,7 @@ type QUICConnection struct {
sessionManager datagramsession.Manager sessionManager datagramsession.Manager
// datagramMuxer mux/demux datagrams from quic connection // datagramMuxer mux/demux datagrams from quic connection
datagramMuxer quicpogs.BaseDatagramMuxer datagramMuxer quicpogs.BaseDatagramMuxer
packetRouter *packetRouter
controlStreamHandler ControlStreamHandler controlStreamHandler ControlStreamHandler
connOptions *tunnelpogs.ConnectionOptions connOptions *tunnelpogs.ConnectionOptions
} }
@ -61,6 +62,7 @@ func NewQUICConnection(
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler ControlStreamHandler, controlStreamHandler ControlStreamHandler,
logger *zerolog.Logger, logger *zerolog.Logger,
icmpProxy ingress.ICMPProxy,
) (*QUICConnection, error) { ) (*QUICConnection, error) {
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig) session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
if err != nil { if err != nil {
@ -68,7 +70,20 @@ func NewQUICConnection(
} }
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, sessionDemuxChan) var (
datagramMuxer quicpogs.BaseDatagramMuxer
pr *packetRouter
)
if icmpProxy != nil {
pr = &packetRouter{
muxer: quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan),
icmpProxy: icmpProxy,
logger: logger,
}
datagramMuxer = pr.muxer
} else {
datagramMuxer = quicpogs.NewDatagramMuxer(session, logger, sessionDemuxChan)
}
sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan)
return &QUICConnection{ return &QUICConnection{
@ -77,6 +92,7 @@ func NewQUICConnection(
logger: logger, logger: logger,
sessionManager: sessionManager, sessionManager: sessionManager,
datagramMuxer: datagramMuxer, datagramMuxer: datagramMuxer,
packetRouter: pr,
controlStreamHandler: controlStreamHandler, controlStreamHandler: controlStreamHandler,
connOptions: connOptions, connOptions: connOptions,
}, nil }, nil
@ -117,6 +133,12 @@ func (q *QUICConnection) Serve(ctx context.Context) error {
defer cancel() defer cancel()
return q.datagramMuxer.ServeReceive(ctx) return q.datagramMuxer.ServeReceive(ctx)
}) })
if q.packetRouter != nil {
errGroup.Go(func() error {
defer cancel()
return q.packetRouter.serve(ctx)
})
}
return errGroup.Wait() return errGroup.Wait()
} }
@ -305,6 +327,32 @@ func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32,
return q.orchestrator.UpdateConfig(version, config) return q.orchestrator.UpdateConfig(version, config)
} }
type packetRouter struct {
muxer *quicpogs.DatagramMuxerV2
icmpProxy ingress.ICMPProxy
logger *zerolog.Logger
}
func (pr *packetRouter) serve(ctx context.Context) error {
icmpDecoder := packet.NewICMPDecoder()
for {
pk, err := pr.muxer.ReceivePacket(ctx)
if err != nil {
return err
}
icmpPacket, err := icmpDecoder.Decode(pk)
if err != nil {
pr.logger.Err(err).Msg("Failed to decode ICMP packet from quic datagram")
continue
}
if err := pr.icmpProxy.Request(icmpPacket, pr.muxer); err != nil {
pr.logger.Err(err).Str("src", icmpPacket.Src.String()).Str("dst", icmpPacket.Dst.String()).Msg("Failed to send ICMP packet")
continue
}
}
}
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to // streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
// the client. // the client.
type streamReadWriteAcker struct { type streamReadWriteAcker struct {

View File

@ -682,6 +682,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection
&tunnelpogs.ConnectionOptions{}, &tunnelpogs.ConnectionOptions{},
fakeControlStream{}, fakeControlStream{},
&log, &log,
nil,
) )
require.NoError(t, err) require.NoError(t, err)
return qc return qc

View File

@ -0,0 +1,139 @@
package ingress
import (
"context"
"fmt"
"net"
"strconv"
"github.com/google/gopacket/layers"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"golang.org/x/net/icmp"
"github.com/cloudflare/cloudflared/packet"
)
// ICMPProxy sends ICMP messages and listens for their responses
type ICMPProxy interface {
// Request sends an ICMP message
Request(pk *packet.ICMP, responder packet.FlowResponder) error
// ListenResponse listens for responses to the requests until context is done
ListenResponse(ctx context.Context) error
}
// TODO: TUN-6654 Extend support to IPv6
type icmpProxy struct {
srcFlowTracker *packet.FlowTracker
conn *icmp.PacketConn
logger *zerolog.Logger
encoder *packet.Encoder
}
// TODO: TUN-6586: Use echo ID as FlowID
type seqNumFlowID int
func (snf seqNumFlowID) ID() string {
return strconv.FormatInt(int64(snf), 10)
}
func NewICMPProxy(network string, listenIP net.IP, logger *zerolog.Logger) (*icmpProxy, error) {
conn, err := icmp.ListenPacket(network, listenIP.String())
if err != nil {
return nil, err
}
return &icmpProxy{
srcFlowTracker: packet.NewFlowTracker(),
conn: conn,
logger: logger,
encoder: packet.NewEncoder(),
}, nil
}
func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FlowResponder) error {
switch body := pk.Message.Body.(type) {
case *icmp.Echo:
return ip.sendICMPEchoRequest(pk, body, responder)
default:
return fmt.Errorf("sending ICMP %s is not implemented", pk.Type)
}
}
func (ip *icmpProxy) ListenResponse(ctx context.Context) error {
go func() {
<-ctx.Done()
ip.conn.Close()
}()
buf := make([]byte, 1500)
for {
n, src, err := ip.conn.ReadFrom(buf)
if err != nil {
return err
}
// TODO: TUN-6654 Check for IPv6
msg, err := icmp.ParseMessage(int(layers.IPProtocolICMPv4), buf[:n])
if err != nil {
ip.logger.Error().Err(err).Str("src", src.String()).Msg("Failed to parse ICMP message")
continue
}
switch body := msg.Body.(type) {
case *icmp.Echo:
if err := ip.handleEchoResponse(msg, body); err != nil {
ip.logger.Error().Err(err).Str("src", src.String()).Msg("Failed to handle ICMP response")
continue
}
default:
ip.logger.Warn().
Str("icmpType", fmt.Sprintf("%s", msg.Type)).
Msgf("Responding to this type of ICMP is not implemented")
continue
}
}
}
func (ip *icmpProxy) sendICMPEchoRequest(pk *packet.ICMP, echo *icmp.Echo, responder packet.FlowResponder) error {
flow := packet.Flow{
Src: pk.Src,
Dst: pk.Dst,
Responder: responder,
}
// TODO: TUN-6586 rewrite ICMP echo request identifier and use it to track flows
flowID := seqNumFlowID(echo.Seq)
// TODO: TUN-6588 clean up flows
if replaced := ip.srcFlowTracker.Register(flowID, &flow, true); replaced {
ip.logger.Info().Str("src", flow.Src.String()).Str("dst", flow.Dst.String()).Msg("Replaced flow")
}
var pseudoHeader []byte = nil
serializedMsg, err := pk.Marshal(pseudoHeader)
if err != nil {
return errors.Wrap(err, "Failed to encode ICMP message")
}
// The address needs to be of type UDPAddr when conn is created without priviledge
_, err = ip.conn.WriteTo(serializedMsg, &net.UDPAddr{
IP: pk.Dst.AsSlice(),
})
return err
}
func (ip *icmpProxy) handleEchoResponse(msg *icmp.Message, echo *icmp.Echo) error {
flow, ok := ip.srcFlowTracker.Get(seqNumFlowID(echo.Seq))
if !ok {
return fmt.Errorf("flow not found")
}
icmpPacket := packet.ICMP{
IP: &packet.IP{
Src: flow.Dst,
Dst: flow.Src,
Protocol: layers.IPProtocol(msg.Type.Protocol()),
},
Message: msg,
}
serializedPacket, err := ip.encoder.Encode(&icmpPacket)
if err != nil {
return errors.Wrap(err, "Failed to encode ICMP message")
}
if err := flow.Responder.SendPacket(serializedPacket); err != nil {
return errors.Wrap(err, "Failed to send packet to the edge")
}
return nil
}

View File

@ -0,0 +1,150 @@
package ingress
import (
"context"
"fmt"
"net/netip"
"runtime"
"testing"
"github.com/google/gopacket/layers"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"github.com/cloudflare/cloudflared/packet"
)
var (
noopLogger = zerolog.Nop()
localhostIP = netip.MustParseAddr("127.0.0.1")
)
// TestICMPProxyEcho makes sure we can send ICMP echo via the Request method and receives response via the
// ListenResponse method
func TestICMPProxyEcho(t *testing.T) {
skipWindows(t)
const (
echoID = 36571
endSeq = 100
)
proxy, err := NewICMPProxy("udp4", localhostIP.AsSlice(), &noopLogger)
require.NoError(t, err)
proxyDone := make(chan struct{})
ctx, cancel := context.WithCancel(context.Background())
go func() {
proxy.ListenResponse(ctx)
close(proxyDone)
}()
responder := echoFlowResponder{
decoder: packet.NewICMPDecoder(),
respChan: make(chan []byte),
}
ip := packet.IP{
Src: localhostIP,
Dst: localhostIP,
Protocol: layers.IPProtocolICMPv4,
}
for i := 0; i < endSeq; i++ {
pk := packet.ICMP{
IP: &ip,
Message: &icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &icmp.Echo{
ID: echoID,
Seq: i,
Data: []byte(fmt.Sprintf("icmp echo seq %d", i)),
},
},
}
require.NoError(t, proxy.Request(&pk, &responder))
responder.validate(t, &pk)
}
cancel()
<-proxyDone
}
// TestICMPProxyRejectNotEcho makes sure it rejects messages other than echo
func TestICMPProxyRejectNotEcho(t *testing.T) {
skipWindows(t)
msgs := []icmp.Message{
{
Type: ipv4.ICMPTypeDestinationUnreachable,
Code: 1,
Body: &icmp.DstUnreach{
Data: []byte("original packet"),
},
},
{
Type: ipv4.ICMPTypeTimeExceeded,
Code: 1,
Body: &icmp.TimeExceeded{
Data: []byte("original packet"),
},
},
{
Type: ipv4.ICMPType(2),
Code: 0,
Body: &icmp.PacketTooBig{
MTU: 1280,
Data: []byte("original packet"),
},
},
}
proxy, err := NewICMPProxy("udp4", localhostIP.AsSlice(), &noopLogger)
require.NoError(t, err)
responder := echoFlowResponder{
decoder: packet.NewICMPDecoder(),
respChan: make(chan []byte),
}
for _, m := range msgs {
pk := packet.ICMP{
IP: &packet.IP{
Src: localhostIP,
Dst: localhostIP,
Protocol: layers.IPProtocolICMPv4,
},
Message: &m,
}
require.Error(t, proxy.Request(&pk, &responder))
}
}
func skipWindows(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Cannot create non-privileged datagram-oriented ICMP endpoint on Windows")
}
}
type echoFlowResponder struct {
decoder *packet.ICMPDecoder
respChan chan []byte
}
func (efr *echoFlowResponder) SendPacket(pk packet.RawPacket) error {
copiedPacket := make([]byte, len(pk.Data))
copy(copiedPacket, pk.Data)
efr.respChan <- copiedPacket
return nil
}
func (efr *echoFlowResponder) validate(t *testing.T, echoReq *packet.ICMP) {
pk := <-efr.respChan
decoded, err := efr.decoder.Decode(packet.RawPacket{Data: pk})
require.NoError(t, err)
require.Equal(t, decoded.Src, echoReq.Dst)
require.Equal(t, decoded.Dst, echoReq.Src)
require.Equal(t, echoReq.Protocol, decoded.Protocol)
require.Equal(t, ipv4.ICMPTypeEchoReply, decoded.Type)
require.Equal(t, 0, decoded.Code)
require.NotZero(t, decoded.Checksum)
// TODO: TUN-6586: Enable this validation when ICMP echo ID matches on Linux
//require.Equal(t, echoReq.Body, decoded.Body)
}

View File

@ -75,9 +75,9 @@ func NewIPDecoder() *IPDecoder {
} }
} }
func (pd *IPDecoder) Decode(packet []byte) (*IP, error) { func (pd *IPDecoder) Decode(packet RawPacket) (*IP, error) {
// Should decode to IP layer // Should decode to IP layer
decoded, err := pd.decodeByVersion(packet) decoded, err := pd.decodeByVersion(packet.Data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -139,9 +139,9 @@ func NewICMPDecoder() *ICMPDecoder {
} }
} }
func (pd *ICMPDecoder) Decode(packet []byte) (*ICMP, error) { func (pd *ICMPDecoder) Decode(packet RawPacket) (*ICMP, error) {
// Should decode to IP and optionally ICMP layer // Should decode to IP and optionally ICMP layer
decoded, err := pd.decodeByVersion(packet) decoded, err := pd.decodeByVersion(packet.Data)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -43,11 +43,11 @@ func TestDecodeIP(t *testing.T) {
p, err := encoder.Encode(&udp) p, err := encoder.Encode(&udp)
require.NoError(t, err) require.NoError(t, err)
ipPacket, err := ipDecoder.Decode(p.Data) ipPacket, err := ipDecoder.Decode(p)
require.NoError(t, err) require.NoError(t, err)
assertIPLayer(t, &udp.IP, ipPacket) assertIPLayer(t, &udp.IP, ipPacket)
icmpPacket, err := icmpDecoder.Decode(p.Data) icmpPacket, err := icmpDecoder.Decode(p)
require.Error(t, err) require.Error(t, err)
require.Nil(t, icmpPacket) require.Nil(t, icmpPacket)
} }
@ -137,14 +137,14 @@ func TestDecodeICMP(t *testing.T) {
p, err := encoder.Encode(test.packet) p, err := encoder.Encode(test.packet)
require.NoError(t, err) require.NoError(t, err)
ipPacket, err := ipDecoder.Decode(p.Data) ipPacket, err := ipDecoder.Decode(p)
require.NoError(t, err) require.NoError(t, err)
if ipPacket.Src.Is4() { if ipPacket.Src.Is4() {
assertIPLayer(t, &ipv4Packet, ipPacket) assertIPLayer(t, &ipv4Packet, ipPacket)
} else { } else {
assertIPLayer(t, &ipv6Packet, ipPacket) assertIPLayer(t, &ipv6Packet, ipPacket)
} }
icmpPacket, err := icmpDecoder.Decode(p.Data) icmpPacket, err := icmpDecoder.Decode(p)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, ipPacket, icmpPacket.IP) require.Equal(t, ipPacket, icmpPacket.IP)
@ -202,11 +202,11 @@ func TestDecodeBadPackets(t *testing.T) {
ipDecoder := NewIPDecoder() ipDecoder := NewIPDecoder()
icmpDecoder := NewICMPDecoder() icmpDecoder := NewICMPDecoder()
for _, test := range tests { for _, test := range tests {
ipPacket, err := ipDecoder.Decode(test.packet) ipPacket, err := ipDecoder.Decode(RawPacket{Data: test.packet})
require.Error(t, err) require.Error(t, err)
require.Nil(t, ipPacket) require.Nil(t, ipPacket)
icmpPacket, err := icmpDecoder.Decode(test.packet) icmpPacket, err := icmpDecoder.Decode(RawPacket{Data: test.packet})
require.Error(t, err) require.Error(t, err)
require.Nil(t, icmpPacket) require.Nil(t, icmpPacket)
} }

View File

@ -2,19 +2,17 @@ package packet
import ( import (
"errors" "errors"
"net"
"net/netip" "net/netip"
"sync" "sync"
) )
type flowID string
var ( var (
ErrFlowNotFound = errors.New("flow not found") ErrFlowNotFound = errors.New("flow not found")
) )
func newFlowID(ip net.IP) flowID { // FlowID represents a key type that can be used by FlowTracker
return flowID(ip.String()) type FlowID interface {
ID() string
} }
type Flow struct { type Flow struct {
@ -37,32 +35,29 @@ type FlowResponder interface {
SendPacket(pk RawPacket) error SendPacket(pk RawPacket) error
} }
// SrcFlowTracker tracks flow from the perspective of eyeball to origin // FlowTracker tracks flow from the perspective of eyeball to origin
// flowID is the source IP type FlowTracker struct {
type SrcFlowTracker struct {
lock sync.RWMutex lock sync.RWMutex
flows map[flowID]*Flow flows map[FlowID]*Flow
} }
func NewSrcFlowTracker() *SrcFlowTracker { func NewFlowTracker() *FlowTracker {
return &SrcFlowTracker{ return &FlowTracker{
flows: make(map[flowID]*Flow), flows: make(map[FlowID]*Flow),
} }
} }
func (sft *SrcFlowTracker) Get(srcIP net.IP) (*Flow, bool) { func (sft *FlowTracker) Get(id FlowID) (*Flow, bool) {
sft.lock.RLock() sft.lock.RLock()
defer sft.lock.RUnlock() defer sft.lock.RUnlock()
id := newFlowID(srcIP)
flow, ok := sft.flows[id] flow, ok := sft.flows[id]
return flow, ok return flow, ok
} }
// Registers a flow. If shouldReplace = true, replace the current flow // Registers a flow. If shouldReplace = true, replace the current flow
func (sft *SrcFlowTracker) Register(flow *Flow, shouldReplace bool) (replaced bool) { func (sft *FlowTracker) Register(id FlowID, flow *Flow, shouldReplace bool) (replaced bool) {
sft.lock.Lock() sft.lock.Lock()
defer sft.lock.Unlock() defer sft.lock.Unlock()
id := flowID(flow.Src.String())
currentFlow, ok := sft.flows[id] currentFlow, ok := sft.flows[id]
if !ok { if !ok {
sft.flows[id] = flow sft.flows[id] = flow
@ -77,10 +72,9 @@ func (sft *SrcFlowTracker) Register(flow *Flow, shouldReplace bool) (replaced bo
} }
// Unregisters a flow. If force = true, delete it even if it maps to a different flow // Unregisters a flow. If force = true, delete it even if it maps to a different flow
func (sft *SrcFlowTracker) Unregister(flow *Flow, force bool) (forceDeleted bool) { func (sft *FlowTracker) Unregister(id FlowID, flow *Flow, force bool) (forceDeleted bool) {
sft.lock.Lock() sft.lock.Lock()
defer sft.lock.Unlock() defer sft.lock.Unlock()
id := flowID(flow.Src.String())
currentFlow, ok := sft.flows[id] currentFlow, ok := sft.flows[id]
if !ok { if !ok {
return false return false

View File

@ -145,7 +145,7 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi
received, err := muxer.ReceivePacket(ctx) received, err := muxer.ReceivePacket(ctx)
require.NoError(t, err) require.NoError(t, err)
receivedICMP, err := icmpDecoder.Decode(received.Data) receivedICMP, err := icmpDecoder.Decode(received)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, pk.IP, receivedICMP.IP) require.Equal(t, pk.IP, receivedICMP.IP)
require.Equal(t, pk.Type, receivedICMP.Type) require.Equal(t, pk.Type, receivedICMP.Type)

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"strings" "strings"
"time" "time"
@ -15,6 +16,7 @@ import (
"github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/edgediscovery/allregions"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/orchestration"
"github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/retry"
"github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/signal"
@ -44,7 +46,7 @@ type Supervisor struct {
config *TunnelConfig config *TunnelConfig
orchestrator *orchestration.Orchestrator orchestrator *orchestration.Orchestrator
edgeIPs *edgediscovery.Edge edgeIPs *edgediscovery.Edge
edgeTunnelServer EdgeTunnelServer edgeTunnelServer *EdgeTunnelServer
tunnelErrors chan tunnelError tunnelErrors chan tunnelError
tunnelsConnecting map[int]chan struct{} tunnelsConnecting map[int]chan struct{}
tunnelsProtocolFallback map[int]*protocolFallback tunnelsProtocolFallback map[int]*protocolFallback
@ -114,6 +116,15 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
gracefulShutdownC: gracefulShutdownC, gracefulShutdownC: gracefulShutdownC,
connAwareLogger: log, connAwareLogger: log,
} }
if useDatagramV2(config) {
// For non-privileged datagram-oriented ICMP endpoints, network must be "udp4" or "udp6"
// TODO: TUN-6654 listen for IPv6 and decide if it should listen on specific IP
icmpProxy, err := ingress.NewICMPProxy("udp4", net.IPv4zero, config.Log)
if err != nil {
return nil, err
}
edgeTunnelServer.icmpProxy = icmpProxy
}
useReconnectToken := false useReconnectToken := false
if config.ClassicTunnel != nil { if config.ClassicTunnel != nil {
@ -125,7 +136,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
config: config, config: config,
orchestrator: orchestrator, orchestrator: orchestrator,
edgeIPs: edgeIPs, edgeIPs: edgeIPs,
edgeTunnelServer: edgeTunnelServer, edgeTunnelServer: &edgeTunnelServer,
tunnelErrors: make(chan tunnelError), tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{}, tunnelsConnecting: map[int]chan struct{}{},
tunnelsProtocolFallback: map[int]*protocolFallback{}, tunnelsProtocolFallback: map[int]*protocolFallback{},
@ -142,6 +153,14 @@ func (s *Supervisor) Run(
ctx context.Context, ctx context.Context,
connectedSignal *signal.Signal, connectedSignal *signal.Signal,
) error { ) error {
if s.edgeTunnelServer.icmpProxy != nil {
go func() {
if err := s.edgeTunnelServer.icmpProxy.ListenResponse(ctx); err != nil {
s.log.Logger().Err(err).Msg("icmp proxy terminated")
}
}()
}
if err := s.initialize(ctx, connectedSignal); err != nil { if err := s.initialize(ctx, connectedSignal); err != nil {
if err == errEarlyShutdown { if err == errEarlyShutdown {
return nil return nil
@ -413,3 +432,15 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts) registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions) return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions)
} }
func useDatagramV2(config *TunnelConfig) bool {
if config.NamedTunnel == nil {
return false
}
for _, feature := range config.NamedTunnel.Client.Features {
if feature == FeatureDatagramV2 {
return true
}
}
return false
}

View File

@ -20,6 +20,7 @@ import (
"github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/edgediscovery/allregions"
"github.com/cloudflare/cloudflared/h2mux" "github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/orchestration"
quicpogs "github.com/cloudflare/cloudflared/quic" quicpogs "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/retry"
@ -193,11 +194,12 @@ type EdgeTunnelServer struct {
reconnectCh chan ReconnectSignal reconnectCh chan ReconnectSignal
gracefulShutdownC <-chan struct{} gracefulShutdownC <-chan struct{}
tracker *tunnelstate.ConnTracker tracker *tunnelstate.ConnTracker
icmpProxy ingress.ICMPProxy
connAwareLogger *ConnAwareLogger connAwareLogger *ConnAwareLogger
} }
func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFallback *protocolFallback, connectedSignal *signal.Signal) error { func (e *EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFallback *protocolFallback, connectedSignal *signal.Signal) error {
haConnections.Inc() haConnections.Inc()
defer haConnections.Dec() defer haConnections.Dec()
@ -229,20 +231,14 @@ func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFa
// to another protocol when a particular metal doesn't support new protocol // to another protocol when a particular metal doesn't support new protocol
// Each connection can also have it's own IP version because individual connections might fallback // Each connection can also have it's own IP version because individual connections might fallback
// to another IP version. // to another IP version.
err, recoverable := ServeTunnel( err, recoverable := e.serveTunnel(
ctx, ctx,
connLog, connLog,
e.credentialManager,
e.config,
e.orchestrator,
addr, addr,
connIndex, connIndex,
connectedFuse, connectedFuse,
protocolFallback, protocolFallback,
e.cloudflaredUUID,
e.reconnectCh,
protocolFallback.protocol, protocolFallback.protocol,
e.gracefulShutdownC,
) )
// If the connection is recoverable, we want to maintain the same IP // If the connection is recoverable, we want to maintain the same IP
@ -361,20 +357,14 @@ func selectNextProtocol(
// ServeTunnel runs a single tunnel connection, returns nil on graceful shutdown, // ServeTunnel runs a single tunnel connection, returns nil on graceful shutdown,
// on error returns a flag indicating if error can be retried // on error returns a flag indicating if error can be retried
func ServeTunnel( func (e *EdgeTunnelServer) serveTunnel(
ctx context.Context, ctx context.Context,
connLog *ConnAwareLogger, connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
orchestrator *orchestration.Orchestrator,
addr *allregions.EdgeAddr, addr *allregions.EdgeAddr,
connIndex uint8, connIndex uint8,
fuse *h2mux.BooleanFuse, fuse *h2mux.BooleanFuse,
backoff *protocolFallback, backoff *protocolFallback,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
protocol connection.Protocol, protocol connection.Protocol,
gracefulShutdownC <-chan struct{},
) (err error, recoverable bool) { ) (err error, recoverable bool) {
// Treat panics as recoverable errors // Treat panics as recoverable errors
defer func() { defer func() {
@ -389,21 +379,15 @@ func ServeTunnel(
} }
}() }()
defer config.Observer.SendDisconnect(connIndex) defer e.config.Observer.SendDisconnect(connIndex)
err, recoverable = serveTunnel( err, recoverable = e.serveConnection(
ctx, ctx,
connLog, connLog,
credentialManager,
config,
orchestrator,
addr, addr,
connIndex, connIndex,
fuse, fuse,
backoff, backoff,
cloudflaredUUID,
reconnectCh,
protocol, protocol,
gracefulShutdownC,
) )
if err != nil { if err != nil {
@ -416,7 +400,7 @@ func ServeTunnel(
connLog.ConnAwareLogger().Err(err).Msg("Register tunnel error from server side") connLog.ConnAwareLogger().Err(err).Msg("Register tunnel error from server side")
// Don't send registration error return from server to Sentry. They are // Don't send registration error return from server to Sentry. They are
// logged on server side // logged on server side
if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 { if incidents := e.config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 {
connLog.ConnAwareLogger().Msg(activeIncidentsMsg(incidents)) connLog.ConnAwareLogger().Msg(activeIncidentsMsg(incidents))
} }
return err.Cause, !err.Permanent return err.Cause, !err.Permanent
@ -442,93 +426,73 @@ func ServeTunnel(
return nil, false return nil, false
} }
func serveTunnel( func (e *EdgeTunnelServer) serveConnection(
ctx context.Context, ctx context.Context,
connLog *ConnAwareLogger, connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
orchestrator *orchestration.Orchestrator,
addr *allregions.EdgeAddr, addr *allregions.EdgeAddr,
connIndex uint8, connIndex uint8,
fuse *h2mux.BooleanFuse, fuse *h2mux.BooleanFuse,
backoff *protocolFallback, backoff *protocolFallback,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
protocol connection.Protocol, protocol connection.Protocol,
gracefulShutdownC <-chan struct{},
) (err error, recoverable bool) { ) (err error, recoverable bool) {
connectedFuse := &connectedFuse{ connectedFuse := &connectedFuse{
fuse: fuse, fuse: fuse,
backoff: backoff, backoff: backoff,
} }
controlStream := connection.NewControlStream( controlStream := connection.NewControlStream(
config.Observer, e.config.Observer,
connectedFuse, connectedFuse,
config.NamedTunnel, e.config.NamedTunnel,
connIndex, connIndex,
addr.UDP.IP, addr.UDP.IP,
nil, nil,
gracefulShutdownC, e.gracefulShutdownC,
config.GracePeriod, e.config.GracePeriod,
protocol, protocol,
) )
switch protocol { switch protocol {
case connection.QUIC, connection.QUICWarp: case connection.QUIC, connection.QUICWarp:
connOptions := config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries())) connOptions := e.config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
return ServeQUIC(ctx, return e.serveQUIC(ctx,
addr.UDP, addr.UDP,
config,
orchestrator,
connLog, connLog,
connOptions, connOptions,
controlStream, controlStream,
connIndex, connIndex)
reconnectCh,
gracefulShutdownC)
case connection.HTTP2, connection.HTTP2Warp: case connection.HTTP2, connection.HTTP2Warp:
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP)
if err != nil { if err != nil {
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge") connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
return err, true return err, true
} }
connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries())) connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
if err := ServeHTTP2( if err := e.serveHTTP2(
ctx, ctx,
connLog, connLog,
config,
orchestrator,
edgeConn, edgeConn,
connOptions, connOptions,
controlStream, controlStream,
connIndex, connIndex,
gracefulShutdownC,
reconnectCh,
); err != nil { ); err != nil {
return err, false return err, false
} }
default: default:
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP) edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP)
if err != nil { if err != nil {
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge") connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
return err, true return err, true
} }
if err := ServeH2mux( if err := e.serveH2mux(
ctx, ctx,
connLog, connLog,
credentialManager,
config,
orchestrator,
edgeConn, edgeConn,
connIndex, connIndex,
connectedFuse, connectedFuse,
cloudflaredUUID,
reconnectCh,
gracefulShutdownC,
); err != nil { ); err != nil {
return err, false return err, false
} }
@ -544,30 +508,24 @@ func (r unrecoverableError) Error() string {
return r.err.Error() return r.err.Error()
} }
func ServeH2mux( func (e *EdgeTunnelServer) serveH2mux(
ctx context.Context, ctx context.Context,
connLog *ConnAwareLogger, connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
orchestrator *orchestration.Orchestrator,
edgeConn net.Conn, edgeConn net.Conn,
connIndex uint8, connIndex uint8,
connectedFuse *connectedFuse, connectedFuse *connectedFuse,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
gracefulShutdownC <-chan struct{},
) error { ) error {
connLog.Logger().Debug().Msgf("Connecting via h2mux") connLog.Logger().Debug().Msgf("Connecting via h2mux")
// Returns error from parsing the origin URL or handshake errors // Returns error from parsing the origin URL or handshake errors
handler, err, recoverable := connection.NewH2muxConnection( handler, err, recoverable := connection.NewH2muxConnection(
orchestrator, e.orchestrator,
config.GracePeriod, e.config.GracePeriod,
config.MuxerConfig, e.config.MuxerConfig,
edgeConn, edgeConn,
connIndex, connIndex,
config.Observer, e.config.Observer,
gracefulShutdownC, e.gracefulShutdownC,
config.Log, e.config.Log,
) )
if err != nil { if err != nil {
if !recoverable { if !recoverable {
@ -579,42 +537,38 @@ func ServeH2mux(
errGroup, serveCtx := errgroup.WithContext(ctx) errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
if config.NamedTunnel != nil { if e.config.NamedTunnel != nil {
connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries())) connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries()))
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse) return handler.ServeNamedTunnel(serveCtx, e.config.NamedTunnel, connOptions, connectedFuse)
} }
registrationOptions := config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID) registrationOptions := e.config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), e.cloudflaredUUID)
return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse) return handler.ServeClassicTunnel(serveCtx, e.config.ClassicTunnel, e.credentialManager, registrationOptions, connectedFuse)
}) })
errGroup.Go(func() error { errGroup.Go(func() error {
return listenReconnect(serveCtx, reconnectCh, gracefulShutdownC) return listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
}) })
return errGroup.Wait() return errGroup.Wait()
} }
func ServeHTTP2( func (e *EdgeTunnelServer) serveHTTP2(
ctx context.Context, ctx context.Context,
connLog *ConnAwareLogger, connLog *ConnAwareLogger,
config *TunnelConfig,
orchestrator *orchestration.Orchestrator,
tlsServerConn net.Conn, tlsServerConn net.Conn,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler connection.ControlStreamHandler, controlStreamHandler connection.ControlStreamHandler,
connIndex uint8, connIndex uint8,
gracefulShutdownC <-chan struct{},
reconnectCh chan ReconnectSignal,
) error { ) error {
connLog.Logger().Debug().Msgf("Connecting via http2") connLog.Logger().Debug().Msgf("Connecting via http2")
h2conn := connection.NewHTTP2Connection( h2conn := connection.NewHTTP2Connection(
tlsServerConn, tlsServerConn,
orchestrator, e.orchestrator,
connOptions, connOptions,
config.Observer, e.config.Observer,
connIndex, connIndex,
controlStreamHandler, controlStreamHandler,
config.Log, e.config.Log,
) )
errGroup, serveCtx := errgroup.WithContext(ctx) errGroup, serveCtx := errgroup.WithContext(ctx)
@ -623,7 +577,7 @@ func ServeHTTP2(
}) })
errGroup.Go(func() error { errGroup.Go(func() error {
err := listenReconnect(serveCtx, reconnectCh, gracefulShutdownC) err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
if err != nil { if err != nil {
// forcefully break the connection (this is only used for testing) // forcefully break the connection (this is only used for testing)
connLog.Logger().Debug().Msg("Forcefully breaking http2 connection") connLog.Logger().Debug().Msg("Forcefully breaking http2 connection")
@ -635,19 +589,15 @@ func ServeHTTP2(
return errGroup.Wait() return errGroup.Wait()
} }
func ServeQUIC( func (e *EdgeTunnelServer) serveQUIC(
ctx context.Context, ctx context.Context,
edgeAddr *net.UDPAddr, edgeAddr *net.UDPAddr,
config *TunnelConfig,
orchestrator *orchestration.Orchestrator,
connLogger *ConnAwareLogger, connLogger *ConnAwareLogger,
connOptions *tunnelpogs.ConnectionOptions, connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler connection.ControlStreamHandler, controlStreamHandler connection.ControlStreamHandler,
connIndex uint8, connIndex uint8,
reconnectCh chan ReconnectSignal,
gracefulShutdownC <-chan struct{},
) (err error, recoverable bool) { ) (err error, recoverable bool) {
tlsConfig := config.EdgeTLSConfigs[connection.QUIC] tlsConfig := e.config.EdgeTLSConfigs[connection.QUIC]
quicConfig := &quic.Config{ quicConfig := &quic.Config{
HandshakeIdleTimeout: quicpogs.HandshakeIdleTimeout, HandshakeIdleTimeout: quicpogs.HandshakeIdleTimeout,
MaxIdleTimeout: quicpogs.MaxIdleTimeout, MaxIdleTimeout: quicpogs.MaxIdleTimeout,
@ -663,10 +613,11 @@ func ServeQUIC(
quicConfig, quicConfig,
edgeAddr, edgeAddr,
tlsConfig, tlsConfig,
orchestrator, e.orchestrator,
connOptions, connOptions,
controlStreamHandler, controlStreamHandler,
connLogger.Logger()) connLogger.Logger(),
e.icmpProxy)
if err != nil { if err != nil {
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection") connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")
return err, true return err, true
@ -682,7 +633,7 @@ func ServeQUIC(
}) })
errGroup.Go(func() error { errGroup.Go(func() error {
err := listenReconnect(serveCtx, reconnectCh, gracefulShutdownC) err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
if err != nil { if err != nil {
// forcefully break the connection (this is only used for testing) // forcefully break the connection (this is only used for testing)
connLogger.Logger().Debug().Msg("Forcefully breaking quic connection") connLogger.Logger().Debug().Msg("Forcefully breaking quic connection")