TUN-6530: Implement ICMPv4 proxy
This proxy uses unprivileged datagram-oriented endpoint and is shared by all quic connections
This commit is contained in:
parent
f6bd4aa039
commit
59f5b0df83
|
@ -48,6 +48,7 @@ type QUICConnection struct {
|
|||
sessionManager datagramsession.Manager
|
||||
// datagramMuxer mux/demux datagrams from quic connection
|
||||
datagramMuxer quicpogs.BaseDatagramMuxer
|
||||
packetRouter *packetRouter
|
||||
controlStreamHandler ControlStreamHandler
|
||||
connOptions *tunnelpogs.ConnectionOptions
|
||||
}
|
||||
|
@ -61,6 +62,7 @@ func NewQUICConnection(
|
|||
connOptions *tunnelpogs.ConnectionOptions,
|
||||
controlStreamHandler ControlStreamHandler,
|
||||
logger *zerolog.Logger,
|
||||
icmpProxy ingress.ICMPProxy,
|
||||
) (*QUICConnection, error) {
|
||||
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
|
||||
if err != nil {
|
||||
|
@ -68,7 +70,20 @@ func NewQUICConnection(
|
|||
}
|
||||
|
||||
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
|
||||
datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, sessionDemuxChan)
|
||||
var (
|
||||
datagramMuxer quicpogs.BaseDatagramMuxer
|
||||
pr *packetRouter
|
||||
)
|
||||
if icmpProxy != nil {
|
||||
pr = &packetRouter{
|
||||
muxer: quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan),
|
||||
icmpProxy: icmpProxy,
|
||||
logger: logger,
|
||||
}
|
||||
datagramMuxer = pr.muxer
|
||||
} else {
|
||||
datagramMuxer = quicpogs.NewDatagramMuxer(session, logger, sessionDemuxChan)
|
||||
}
|
||||
sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan)
|
||||
|
||||
return &QUICConnection{
|
||||
|
@ -77,6 +92,7 @@ func NewQUICConnection(
|
|||
logger: logger,
|
||||
sessionManager: sessionManager,
|
||||
datagramMuxer: datagramMuxer,
|
||||
packetRouter: pr,
|
||||
controlStreamHandler: controlStreamHandler,
|
||||
connOptions: connOptions,
|
||||
}, nil
|
||||
|
@ -117,6 +133,12 @@ func (q *QUICConnection) Serve(ctx context.Context) error {
|
|||
defer cancel()
|
||||
return q.datagramMuxer.ServeReceive(ctx)
|
||||
})
|
||||
if q.packetRouter != nil {
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
return q.packetRouter.serve(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
return errGroup.Wait()
|
||||
}
|
||||
|
@ -305,6 +327,32 @@ func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32,
|
|||
return q.orchestrator.UpdateConfig(version, config)
|
||||
}
|
||||
|
||||
type packetRouter struct {
|
||||
muxer *quicpogs.DatagramMuxerV2
|
||||
icmpProxy ingress.ICMPProxy
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
func (pr *packetRouter) serve(ctx context.Context) error {
|
||||
icmpDecoder := packet.NewICMPDecoder()
|
||||
for {
|
||||
pk, err := pr.muxer.ReceivePacket(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
icmpPacket, err := icmpDecoder.Decode(pk)
|
||||
if err != nil {
|
||||
pr.logger.Err(err).Msg("Failed to decode ICMP packet from quic datagram")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := pr.icmpProxy.Request(icmpPacket, pr.muxer); err != nil {
|
||||
pr.logger.Err(err).Str("src", icmpPacket.Src.String()).Str("dst", icmpPacket.Dst.String()).Msg("Failed to send ICMP packet")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
||||
// the client.
|
||||
type streamReadWriteAcker struct {
|
||||
|
|
|
@ -682,6 +682,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection
|
|||
&tunnelpogs.ConnectionOptions{},
|
||||
fakeControlStream{},
|
||||
&log,
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return qc
|
||||
|
|
|
@ -0,0 +1,139 @@
|
|||
package ingress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/net/icmp"
|
||||
|
||||
"github.com/cloudflare/cloudflared/packet"
|
||||
)
|
||||
|
||||
// ICMPProxy sends ICMP messages and listens for their responses
|
||||
type ICMPProxy interface {
|
||||
// Request sends an ICMP message
|
||||
Request(pk *packet.ICMP, responder packet.FlowResponder) error
|
||||
// ListenResponse listens for responses to the requests until context is done
|
||||
ListenResponse(ctx context.Context) error
|
||||
}
|
||||
|
||||
// TODO: TUN-6654 Extend support to IPv6
|
||||
type icmpProxy struct {
|
||||
srcFlowTracker *packet.FlowTracker
|
||||
conn *icmp.PacketConn
|
||||
logger *zerolog.Logger
|
||||
encoder *packet.Encoder
|
||||
}
|
||||
|
||||
// TODO: TUN-6586: Use echo ID as FlowID
|
||||
type seqNumFlowID int
|
||||
|
||||
func (snf seqNumFlowID) ID() string {
|
||||
return strconv.FormatInt(int64(snf), 10)
|
||||
}
|
||||
|
||||
func NewICMPProxy(network string, listenIP net.IP, logger *zerolog.Logger) (*icmpProxy, error) {
|
||||
conn, err := icmp.ListenPacket(network, listenIP.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &icmpProxy{
|
||||
srcFlowTracker: packet.NewFlowTracker(),
|
||||
conn: conn,
|
||||
logger: logger,
|
||||
encoder: packet.NewEncoder(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FlowResponder) error {
|
||||
switch body := pk.Message.Body.(type) {
|
||||
case *icmp.Echo:
|
||||
return ip.sendICMPEchoRequest(pk, body, responder)
|
||||
default:
|
||||
return fmt.Errorf("sending ICMP %s is not implemented", pk.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func (ip *icmpProxy) ListenResponse(ctx context.Context) error {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
ip.conn.Close()
|
||||
}()
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, src, err := ip.conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: TUN-6654 Check for IPv6
|
||||
msg, err := icmp.ParseMessage(int(layers.IPProtocolICMPv4), buf[:n])
|
||||
if err != nil {
|
||||
ip.logger.Error().Err(err).Str("src", src.String()).Msg("Failed to parse ICMP message")
|
||||
continue
|
||||
}
|
||||
switch body := msg.Body.(type) {
|
||||
case *icmp.Echo:
|
||||
if err := ip.handleEchoResponse(msg, body); err != nil {
|
||||
ip.logger.Error().Err(err).Str("src", src.String()).Msg("Failed to handle ICMP response")
|
||||
continue
|
||||
}
|
||||
default:
|
||||
ip.logger.Warn().
|
||||
Str("icmpType", fmt.Sprintf("%s", msg.Type)).
|
||||
Msgf("Responding to this type of ICMP is not implemented")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ip *icmpProxy) sendICMPEchoRequest(pk *packet.ICMP, echo *icmp.Echo, responder packet.FlowResponder) error {
|
||||
flow := packet.Flow{
|
||||
Src: pk.Src,
|
||||
Dst: pk.Dst,
|
||||
Responder: responder,
|
||||
}
|
||||
// TODO: TUN-6586 rewrite ICMP echo request identifier and use it to track flows
|
||||
flowID := seqNumFlowID(echo.Seq)
|
||||
// TODO: TUN-6588 clean up flows
|
||||
if replaced := ip.srcFlowTracker.Register(flowID, &flow, true); replaced {
|
||||
ip.logger.Info().Str("src", flow.Src.String()).Str("dst", flow.Dst.String()).Msg("Replaced flow")
|
||||
}
|
||||
var pseudoHeader []byte = nil
|
||||
serializedMsg, err := pk.Marshal(pseudoHeader)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to encode ICMP message")
|
||||
}
|
||||
// The address needs to be of type UDPAddr when conn is created without priviledge
|
||||
_, err = ip.conn.WriteTo(serializedMsg, &net.UDPAddr{
|
||||
IP: pk.Dst.AsSlice(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (ip *icmpProxy) handleEchoResponse(msg *icmp.Message, echo *icmp.Echo) error {
|
||||
flow, ok := ip.srcFlowTracker.Get(seqNumFlowID(echo.Seq))
|
||||
if !ok {
|
||||
return fmt.Errorf("flow not found")
|
||||
}
|
||||
icmpPacket := packet.ICMP{
|
||||
IP: &packet.IP{
|
||||
Src: flow.Dst,
|
||||
Dst: flow.Src,
|
||||
Protocol: layers.IPProtocol(msg.Type.Protocol()),
|
||||
},
|
||||
Message: msg,
|
||||
}
|
||||
serializedPacket, err := ip.encoder.Encode(&icmpPacket)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to encode ICMP message")
|
||||
}
|
||||
if err := flow.Responder.SendPacket(serializedPacket); err != nil {
|
||||
return errors.Wrap(err, "Failed to send packet to the edge")
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
package ingress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
|
||||
"github.com/cloudflare/cloudflared/packet"
|
||||
)
|
||||
|
||||
var (
|
||||
noopLogger = zerolog.Nop()
|
||||
localhostIP = netip.MustParseAddr("127.0.0.1")
|
||||
)
|
||||
|
||||
// TestICMPProxyEcho makes sure we can send ICMP echo via the Request method and receives response via the
|
||||
// ListenResponse method
|
||||
func TestICMPProxyEcho(t *testing.T) {
|
||||
skipWindows(t)
|
||||
const (
|
||||
echoID = 36571
|
||||
endSeq = 100
|
||||
)
|
||||
proxy, err := NewICMPProxy("udp4", localhostIP.AsSlice(), &noopLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyDone := make(chan struct{})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
proxy.ListenResponse(ctx)
|
||||
close(proxyDone)
|
||||
}()
|
||||
|
||||
responder := echoFlowResponder{
|
||||
decoder: packet.NewICMPDecoder(),
|
||||
respChan: make(chan []byte),
|
||||
}
|
||||
|
||||
ip := packet.IP{
|
||||
Src: localhostIP,
|
||||
Dst: localhostIP,
|
||||
Protocol: layers.IPProtocolICMPv4,
|
||||
}
|
||||
for i := 0; i < endSeq; i++ {
|
||||
pk := packet.ICMP{
|
||||
IP: &ip,
|
||||
Message: &icmp.Message{
|
||||
Type: ipv4.ICMPTypeEcho,
|
||||
Code: 0,
|
||||
Body: &icmp.Echo{
|
||||
ID: echoID,
|
||||
Seq: i,
|
||||
Data: []byte(fmt.Sprintf("icmp echo seq %d", i)),
|
||||
},
|
||||
},
|
||||
}
|
||||
require.NoError(t, proxy.Request(&pk, &responder))
|
||||
responder.validate(t, &pk)
|
||||
}
|
||||
cancel()
|
||||
<-proxyDone
|
||||
}
|
||||
|
||||
// TestICMPProxyRejectNotEcho makes sure it rejects messages other than echo
|
||||
func TestICMPProxyRejectNotEcho(t *testing.T) {
|
||||
skipWindows(t)
|
||||
msgs := []icmp.Message{
|
||||
{
|
||||
Type: ipv4.ICMPTypeDestinationUnreachable,
|
||||
Code: 1,
|
||||
Body: &icmp.DstUnreach{
|
||||
Data: []byte("original packet"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: ipv4.ICMPTypeTimeExceeded,
|
||||
Code: 1,
|
||||
Body: &icmp.TimeExceeded{
|
||||
Data: []byte("original packet"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: ipv4.ICMPType(2),
|
||||
Code: 0,
|
||||
Body: &icmp.PacketTooBig{
|
||||
MTU: 1280,
|
||||
Data: []byte("original packet"),
|
||||
},
|
||||
},
|
||||
}
|
||||
proxy, err := NewICMPProxy("udp4", localhostIP.AsSlice(), &noopLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
responder := echoFlowResponder{
|
||||
decoder: packet.NewICMPDecoder(),
|
||||
respChan: make(chan []byte),
|
||||
}
|
||||
for _, m := range msgs {
|
||||
pk := packet.ICMP{
|
||||
IP: &packet.IP{
|
||||
Src: localhostIP,
|
||||
Dst: localhostIP,
|
||||
Protocol: layers.IPProtocolICMPv4,
|
||||
},
|
||||
Message: &m,
|
||||
}
|
||||
require.Error(t, proxy.Request(&pk, &responder))
|
||||
}
|
||||
}
|
||||
|
||||
func skipWindows(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Cannot create non-privileged datagram-oriented ICMP endpoint on Windows")
|
||||
}
|
||||
}
|
||||
|
||||
type echoFlowResponder struct {
|
||||
decoder *packet.ICMPDecoder
|
||||
respChan chan []byte
|
||||
}
|
||||
|
||||
func (efr *echoFlowResponder) SendPacket(pk packet.RawPacket) error {
|
||||
copiedPacket := make([]byte, len(pk.Data))
|
||||
copy(copiedPacket, pk.Data)
|
||||
efr.respChan <- copiedPacket
|
||||
return nil
|
||||
}
|
||||
|
||||
func (efr *echoFlowResponder) validate(t *testing.T, echoReq *packet.ICMP) {
|
||||
pk := <-efr.respChan
|
||||
decoded, err := efr.decoder.Decode(packet.RawPacket{Data: pk})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, decoded.Src, echoReq.Dst)
|
||||
require.Equal(t, decoded.Dst, echoReq.Src)
|
||||
require.Equal(t, echoReq.Protocol, decoded.Protocol)
|
||||
|
||||
require.Equal(t, ipv4.ICMPTypeEchoReply, decoded.Type)
|
||||
require.Equal(t, 0, decoded.Code)
|
||||
require.NotZero(t, decoded.Checksum)
|
||||
// TODO: TUN-6586: Enable this validation when ICMP echo ID matches on Linux
|
||||
//require.Equal(t, echoReq.Body, decoded.Body)
|
||||
}
|
|
@ -75,9 +75,9 @@ func NewIPDecoder() *IPDecoder {
|
|||
}
|
||||
}
|
||||
|
||||
func (pd *IPDecoder) Decode(packet []byte) (*IP, error) {
|
||||
func (pd *IPDecoder) Decode(packet RawPacket) (*IP, error) {
|
||||
// Should decode to IP layer
|
||||
decoded, err := pd.decodeByVersion(packet)
|
||||
decoded, err := pd.decodeByVersion(packet.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -139,9 +139,9 @@ func NewICMPDecoder() *ICMPDecoder {
|
|||
}
|
||||
}
|
||||
|
||||
func (pd *ICMPDecoder) Decode(packet []byte) (*ICMP, error) {
|
||||
func (pd *ICMPDecoder) Decode(packet RawPacket) (*ICMP, error) {
|
||||
// Should decode to IP and optionally ICMP layer
|
||||
decoded, err := pd.decodeByVersion(packet)
|
||||
decoded, err := pd.decodeByVersion(packet.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -43,11 +43,11 @@ func TestDecodeIP(t *testing.T) {
|
|||
p, err := encoder.Encode(&udp)
|
||||
require.NoError(t, err)
|
||||
|
||||
ipPacket, err := ipDecoder.Decode(p.Data)
|
||||
ipPacket, err := ipDecoder.Decode(p)
|
||||
require.NoError(t, err)
|
||||
assertIPLayer(t, &udp.IP, ipPacket)
|
||||
|
||||
icmpPacket, err := icmpDecoder.Decode(p.Data)
|
||||
icmpPacket, err := icmpDecoder.Decode(p)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, icmpPacket)
|
||||
}
|
||||
|
@ -137,14 +137,14 @@ func TestDecodeICMP(t *testing.T) {
|
|||
p, err := encoder.Encode(test.packet)
|
||||
require.NoError(t, err)
|
||||
|
||||
ipPacket, err := ipDecoder.Decode(p.Data)
|
||||
ipPacket, err := ipDecoder.Decode(p)
|
||||
require.NoError(t, err)
|
||||
if ipPacket.Src.Is4() {
|
||||
assertIPLayer(t, &ipv4Packet, ipPacket)
|
||||
} else {
|
||||
assertIPLayer(t, &ipv6Packet, ipPacket)
|
||||
}
|
||||
icmpPacket, err := icmpDecoder.Decode(p.Data)
|
||||
icmpPacket, err := icmpDecoder.Decode(p)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ipPacket, icmpPacket.IP)
|
||||
|
||||
|
@ -202,11 +202,11 @@ func TestDecodeBadPackets(t *testing.T) {
|
|||
ipDecoder := NewIPDecoder()
|
||||
icmpDecoder := NewICMPDecoder()
|
||||
for _, test := range tests {
|
||||
ipPacket, err := ipDecoder.Decode(test.packet)
|
||||
ipPacket, err := ipDecoder.Decode(RawPacket{Data: test.packet})
|
||||
require.Error(t, err)
|
||||
require.Nil(t, ipPacket)
|
||||
|
||||
icmpPacket, err := icmpDecoder.Decode(test.packet)
|
||||
icmpPacket, err := icmpDecoder.Decode(RawPacket{Data: test.packet})
|
||||
require.Error(t, err)
|
||||
require.Nil(t, icmpPacket)
|
||||
}
|
||||
|
|
|
@ -2,19 +2,17 @@ package packet
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type flowID string
|
||||
|
||||
var (
|
||||
ErrFlowNotFound = errors.New("flow not found")
|
||||
)
|
||||
|
||||
func newFlowID(ip net.IP) flowID {
|
||||
return flowID(ip.String())
|
||||
// FlowID represents a key type that can be used by FlowTracker
|
||||
type FlowID interface {
|
||||
ID() string
|
||||
}
|
||||
|
||||
type Flow struct {
|
||||
|
@ -37,32 +35,29 @@ type FlowResponder interface {
|
|||
SendPacket(pk RawPacket) error
|
||||
}
|
||||
|
||||
// SrcFlowTracker tracks flow from the perspective of eyeball to origin
|
||||
// flowID is the source IP
|
||||
type SrcFlowTracker struct {
|
||||
// FlowTracker tracks flow from the perspective of eyeball to origin
|
||||
type FlowTracker struct {
|
||||
lock sync.RWMutex
|
||||
flows map[flowID]*Flow
|
||||
flows map[FlowID]*Flow
|
||||
}
|
||||
|
||||
func NewSrcFlowTracker() *SrcFlowTracker {
|
||||
return &SrcFlowTracker{
|
||||
flows: make(map[flowID]*Flow),
|
||||
func NewFlowTracker() *FlowTracker {
|
||||
return &FlowTracker{
|
||||
flows: make(map[FlowID]*Flow),
|
||||
}
|
||||
}
|
||||
|
||||
func (sft *SrcFlowTracker) Get(srcIP net.IP) (*Flow, bool) {
|
||||
func (sft *FlowTracker) Get(id FlowID) (*Flow, bool) {
|
||||
sft.lock.RLock()
|
||||
defer sft.lock.RUnlock()
|
||||
id := newFlowID(srcIP)
|
||||
flow, ok := sft.flows[id]
|
||||
return flow, ok
|
||||
}
|
||||
|
||||
// Registers a flow. If shouldReplace = true, replace the current flow
|
||||
func (sft *SrcFlowTracker) Register(flow *Flow, shouldReplace bool) (replaced bool) {
|
||||
func (sft *FlowTracker) Register(id FlowID, flow *Flow, shouldReplace bool) (replaced bool) {
|
||||
sft.lock.Lock()
|
||||
defer sft.lock.Unlock()
|
||||
id := flowID(flow.Src.String())
|
||||
currentFlow, ok := sft.flows[id]
|
||||
if !ok {
|
||||
sft.flows[id] = flow
|
||||
|
@ -77,10 +72,9 @@ func (sft *SrcFlowTracker) Register(flow *Flow, shouldReplace bool) (replaced bo
|
|||
}
|
||||
|
||||
// Unregisters a flow. If force = true, delete it even if it maps to a different flow
|
||||
func (sft *SrcFlowTracker) Unregister(flow *Flow, force bool) (forceDeleted bool) {
|
||||
func (sft *FlowTracker) Unregister(id FlowID, flow *Flow, force bool) (forceDeleted bool) {
|
||||
sft.lock.Lock()
|
||||
defer sft.lock.Unlock()
|
||||
id := flowID(flow.Src.String())
|
||||
currentFlow, ok := sft.flows[id]
|
||||
if !ok {
|
||||
return false
|
||||
|
|
|
@ -145,7 +145,7 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi
|
|||
received, err := muxer.ReceivePacket(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
receivedICMP, err := icmpDecoder.Decode(received.Data)
|
||||
receivedICMP, err := icmpDecoder.Decode(received)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pk.IP, receivedICMP.IP)
|
||||
require.Equal(t, pk.Type, receivedICMP.Type)
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -15,6 +16,7 @@ import (
|
|||
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/orchestration"
|
||||
"github.com/cloudflare/cloudflared/retry"
|
||||
"github.com/cloudflare/cloudflared/signal"
|
||||
|
@ -44,7 +46,7 @@ type Supervisor struct {
|
|||
config *TunnelConfig
|
||||
orchestrator *orchestration.Orchestrator
|
||||
edgeIPs *edgediscovery.Edge
|
||||
edgeTunnelServer EdgeTunnelServer
|
||||
edgeTunnelServer *EdgeTunnelServer
|
||||
tunnelErrors chan tunnelError
|
||||
tunnelsConnecting map[int]chan struct{}
|
||||
tunnelsProtocolFallback map[int]*protocolFallback
|
||||
|
@ -114,6 +116,15 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
|
|||
gracefulShutdownC: gracefulShutdownC,
|
||||
connAwareLogger: log,
|
||||
}
|
||||
if useDatagramV2(config) {
|
||||
// For non-privileged datagram-oriented ICMP endpoints, network must be "udp4" or "udp6"
|
||||
// TODO: TUN-6654 listen for IPv6 and decide if it should listen on specific IP
|
||||
icmpProxy, err := ingress.NewICMPProxy("udp4", net.IPv4zero, config.Log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
edgeTunnelServer.icmpProxy = icmpProxy
|
||||
}
|
||||
|
||||
useReconnectToken := false
|
||||
if config.ClassicTunnel != nil {
|
||||
|
@ -125,7 +136,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
|
|||
config: config,
|
||||
orchestrator: orchestrator,
|
||||
edgeIPs: edgeIPs,
|
||||
edgeTunnelServer: edgeTunnelServer,
|
||||
edgeTunnelServer: &edgeTunnelServer,
|
||||
tunnelErrors: make(chan tunnelError),
|
||||
tunnelsConnecting: map[int]chan struct{}{},
|
||||
tunnelsProtocolFallback: map[int]*protocolFallback{},
|
||||
|
@ -142,6 +153,14 @@ func (s *Supervisor) Run(
|
|||
ctx context.Context,
|
||||
connectedSignal *signal.Signal,
|
||||
) error {
|
||||
if s.edgeTunnelServer.icmpProxy != nil {
|
||||
go func() {
|
||||
if err := s.edgeTunnelServer.icmpProxy.ListenResponse(ctx); err != nil {
|
||||
s.log.Logger().Err(err).Msg("icmp proxy terminated")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err := s.initialize(ctx, connectedSignal); err != nil {
|
||||
if err == errEarlyShutdown {
|
||||
return nil
|
||||
|
@ -413,3 +432,15 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
|
|||
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
|
||||
return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions)
|
||||
}
|
||||
|
||||
func useDatagramV2(config *TunnelConfig) bool {
|
||||
if config.NamedTunnel == nil {
|
||||
return false
|
||||
}
|
||||
for _, feature := range config.NamedTunnel.Client.Features {
|
||||
if feature == FeatureDatagramV2 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/orchestration"
|
||||
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||
"github.com/cloudflare/cloudflared/retry"
|
||||
|
@ -193,11 +194,12 @@ type EdgeTunnelServer struct {
|
|||
reconnectCh chan ReconnectSignal
|
||||
gracefulShutdownC <-chan struct{}
|
||||
tracker *tunnelstate.ConnTracker
|
||||
icmpProxy ingress.ICMPProxy
|
||||
|
||||
connAwareLogger *ConnAwareLogger
|
||||
}
|
||||
|
||||
func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFallback *protocolFallback, connectedSignal *signal.Signal) error {
|
||||
func (e *EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFallback *protocolFallback, connectedSignal *signal.Signal) error {
|
||||
haConnections.Inc()
|
||||
defer haConnections.Dec()
|
||||
|
||||
|
@ -229,20 +231,14 @@ func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFa
|
|||
// to another protocol when a particular metal doesn't support new protocol
|
||||
// Each connection can also have it's own IP version because individual connections might fallback
|
||||
// to another IP version.
|
||||
err, recoverable := ServeTunnel(
|
||||
err, recoverable := e.serveTunnel(
|
||||
ctx,
|
||||
connLog,
|
||||
e.credentialManager,
|
||||
e.config,
|
||||
e.orchestrator,
|
||||
addr,
|
||||
connIndex,
|
||||
connectedFuse,
|
||||
protocolFallback,
|
||||
e.cloudflaredUUID,
|
||||
e.reconnectCh,
|
||||
protocolFallback.protocol,
|
||||
e.gracefulShutdownC,
|
||||
)
|
||||
|
||||
// If the connection is recoverable, we want to maintain the same IP
|
||||
|
@ -361,20 +357,14 @@ func selectNextProtocol(
|
|||
|
||||
// ServeTunnel runs a single tunnel connection, returns nil on graceful shutdown,
|
||||
// on error returns a flag indicating if error can be retried
|
||||
func ServeTunnel(
|
||||
func (e *EdgeTunnelServer) serveTunnel(
|
||||
ctx context.Context,
|
||||
connLog *ConnAwareLogger,
|
||||
credentialManager *reconnectCredentialManager,
|
||||
config *TunnelConfig,
|
||||
orchestrator *orchestration.Orchestrator,
|
||||
addr *allregions.EdgeAddr,
|
||||
connIndex uint8,
|
||||
fuse *h2mux.BooleanFuse,
|
||||
backoff *protocolFallback,
|
||||
cloudflaredUUID uuid.UUID,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
protocol connection.Protocol,
|
||||
gracefulShutdownC <-chan struct{},
|
||||
) (err error, recoverable bool) {
|
||||
// Treat panics as recoverable errors
|
||||
defer func() {
|
||||
|
@ -389,21 +379,15 @@ func ServeTunnel(
|
|||
}
|
||||
}()
|
||||
|
||||
defer config.Observer.SendDisconnect(connIndex)
|
||||
err, recoverable = serveTunnel(
|
||||
defer e.config.Observer.SendDisconnect(connIndex)
|
||||
err, recoverable = e.serveConnection(
|
||||
ctx,
|
||||
connLog,
|
||||
credentialManager,
|
||||
config,
|
||||
orchestrator,
|
||||
addr,
|
||||
connIndex,
|
||||
fuse,
|
||||
backoff,
|
||||
cloudflaredUUID,
|
||||
reconnectCh,
|
||||
protocol,
|
||||
gracefulShutdownC,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
|
@ -416,7 +400,7 @@ func ServeTunnel(
|
|||
connLog.ConnAwareLogger().Err(err).Msg("Register tunnel error from server side")
|
||||
// Don't send registration error return from server to Sentry. They are
|
||||
// logged on server side
|
||||
if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 {
|
||||
if incidents := e.config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 {
|
||||
connLog.ConnAwareLogger().Msg(activeIncidentsMsg(incidents))
|
||||
}
|
||||
return err.Cause, !err.Permanent
|
||||
|
@ -442,93 +426,73 @@ func ServeTunnel(
|
|||
return nil, false
|
||||
}
|
||||
|
||||
func serveTunnel(
|
||||
func (e *EdgeTunnelServer) serveConnection(
|
||||
ctx context.Context,
|
||||
connLog *ConnAwareLogger,
|
||||
credentialManager *reconnectCredentialManager,
|
||||
config *TunnelConfig,
|
||||
orchestrator *orchestration.Orchestrator,
|
||||
addr *allregions.EdgeAddr,
|
||||
connIndex uint8,
|
||||
fuse *h2mux.BooleanFuse,
|
||||
backoff *protocolFallback,
|
||||
cloudflaredUUID uuid.UUID,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
protocol connection.Protocol,
|
||||
gracefulShutdownC <-chan struct{},
|
||||
) (err error, recoverable bool) {
|
||||
connectedFuse := &connectedFuse{
|
||||
fuse: fuse,
|
||||
backoff: backoff,
|
||||
}
|
||||
controlStream := connection.NewControlStream(
|
||||
config.Observer,
|
||||
e.config.Observer,
|
||||
connectedFuse,
|
||||
config.NamedTunnel,
|
||||
e.config.NamedTunnel,
|
||||
connIndex,
|
||||
addr.UDP.IP,
|
||||
nil,
|
||||
gracefulShutdownC,
|
||||
config.GracePeriod,
|
||||
e.gracefulShutdownC,
|
||||
e.config.GracePeriod,
|
||||
protocol,
|
||||
)
|
||||
|
||||
switch protocol {
|
||||
case connection.QUIC, connection.QUICWarp:
|
||||
connOptions := config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
|
||||
return ServeQUIC(ctx,
|
||||
connOptions := e.config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
|
||||
return e.serveQUIC(ctx,
|
||||
addr.UDP,
|
||||
config,
|
||||
orchestrator,
|
||||
connLog,
|
||||
connOptions,
|
||||
controlStream,
|
||||
connIndex,
|
||||
reconnectCh,
|
||||
gracefulShutdownC)
|
||||
connIndex)
|
||||
|
||||
case connection.HTTP2, connection.HTTP2Warp:
|
||||
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP)
|
||||
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP)
|
||||
if err != nil {
|
||||
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
|
||||
return err, true
|
||||
}
|
||||
|
||||
connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
|
||||
if err := ServeHTTP2(
|
||||
connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
|
||||
if err := e.serveHTTP2(
|
||||
ctx,
|
||||
connLog,
|
||||
config,
|
||||
orchestrator,
|
||||
edgeConn,
|
||||
connOptions,
|
||||
controlStream,
|
||||
connIndex,
|
||||
gracefulShutdownC,
|
||||
reconnectCh,
|
||||
); err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
default:
|
||||
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP)
|
||||
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP)
|
||||
if err != nil {
|
||||
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
|
||||
return err, true
|
||||
}
|
||||
|
||||
if err := ServeH2mux(
|
||||
if err := e.serveH2mux(
|
||||
ctx,
|
||||
connLog,
|
||||
credentialManager,
|
||||
config,
|
||||
orchestrator,
|
||||
edgeConn,
|
||||
connIndex,
|
||||
connectedFuse,
|
||||
cloudflaredUUID,
|
||||
reconnectCh,
|
||||
gracefulShutdownC,
|
||||
); err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
@ -544,30 +508,24 @@ func (r unrecoverableError) Error() string {
|
|||
return r.err.Error()
|
||||
}
|
||||
|
||||
func ServeH2mux(
|
||||
func (e *EdgeTunnelServer) serveH2mux(
|
||||
ctx context.Context,
|
||||
connLog *ConnAwareLogger,
|
||||
credentialManager *reconnectCredentialManager,
|
||||
config *TunnelConfig,
|
||||
orchestrator *orchestration.Orchestrator,
|
||||
edgeConn net.Conn,
|
||||
connIndex uint8,
|
||||
connectedFuse *connectedFuse,
|
||||
cloudflaredUUID uuid.UUID,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
gracefulShutdownC <-chan struct{},
|
||||
) error {
|
||||
connLog.Logger().Debug().Msgf("Connecting via h2mux")
|
||||
// Returns error from parsing the origin URL or handshake errors
|
||||
handler, err, recoverable := connection.NewH2muxConnection(
|
||||
orchestrator,
|
||||
config.GracePeriod,
|
||||
config.MuxerConfig,
|
||||
e.orchestrator,
|
||||
e.config.GracePeriod,
|
||||
e.config.MuxerConfig,
|
||||
edgeConn,
|
||||
connIndex,
|
||||
config.Observer,
|
||||
gracefulShutdownC,
|
||||
config.Log,
|
||||
e.config.Observer,
|
||||
e.gracefulShutdownC,
|
||||
e.config.Log,
|
||||
)
|
||||
if err != nil {
|
||||
if !recoverable {
|
||||
|
@ -579,42 +537,38 @@ func ServeH2mux(
|
|||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||
|
||||
errGroup.Go(func() error {
|
||||
if config.NamedTunnel != nil {
|
||||
connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries()))
|
||||
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse)
|
||||
if e.config.NamedTunnel != nil {
|
||||
connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries()))
|
||||
return handler.ServeNamedTunnel(serveCtx, e.config.NamedTunnel, connOptions, connectedFuse)
|
||||
}
|
||||
registrationOptions := config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
|
||||
return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
|
||||
registrationOptions := e.config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), e.cloudflaredUUID)
|
||||
return handler.ServeClassicTunnel(serveCtx, e.config.ClassicTunnel, e.credentialManager, registrationOptions, connectedFuse)
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
return listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)
|
||||
return listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
|
||||
})
|
||||
|
||||
return errGroup.Wait()
|
||||
}
|
||||
|
||||
func ServeHTTP2(
|
||||
func (e *EdgeTunnelServer) serveHTTP2(
|
||||
ctx context.Context,
|
||||
connLog *ConnAwareLogger,
|
||||
config *TunnelConfig,
|
||||
orchestrator *orchestration.Orchestrator,
|
||||
tlsServerConn net.Conn,
|
||||
connOptions *tunnelpogs.ConnectionOptions,
|
||||
controlStreamHandler connection.ControlStreamHandler,
|
||||
connIndex uint8,
|
||||
gracefulShutdownC <-chan struct{},
|
||||
reconnectCh chan ReconnectSignal,
|
||||
) error {
|
||||
connLog.Logger().Debug().Msgf("Connecting via http2")
|
||||
h2conn := connection.NewHTTP2Connection(
|
||||
tlsServerConn,
|
||||
orchestrator,
|
||||
e.orchestrator,
|
||||
connOptions,
|
||||
config.Observer,
|
||||
e.config.Observer,
|
||||
connIndex,
|
||||
controlStreamHandler,
|
||||
config.Log,
|
||||
e.config.Log,
|
||||
)
|
||||
|
||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||
|
@ -623,7 +577,7 @@ func ServeHTTP2(
|
|||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
err := listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)
|
||||
err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
|
||||
if err != nil {
|
||||
// forcefully break the connection (this is only used for testing)
|
||||
connLog.Logger().Debug().Msg("Forcefully breaking http2 connection")
|
||||
|
@ -635,19 +589,15 @@ func ServeHTTP2(
|
|||
return errGroup.Wait()
|
||||
}
|
||||
|
||||
func ServeQUIC(
|
||||
func (e *EdgeTunnelServer) serveQUIC(
|
||||
ctx context.Context,
|
||||
edgeAddr *net.UDPAddr,
|
||||
config *TunnelConfig,
|
||||
orchestrator *orchestration.Orchestrator,
|
||||
connLogger *ConnAwareLogger,
|
||||
connOptions *tunnelpogs.ConnectionOptions,
|
||||
controlStreamHandler connection.ControlStreamHandler,
|
||||
connIndex uint8,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
gracefulShutdownC <-chan struct{},
|
||||
) (err error, recoverable bool) {
|
||||
tlsConfig := config.EdgeTLSConfigs[connection.QUIC]
|
||||
tlsConfig := e.config.EdgeTLSConfigs[connection.QUIC]
|
||||
quicConfig := &quic.Config{
|
||||
HandshakeIdleTimeout: quicpogs.HandshakeIdleTimeout,
|
||||
MaxIdleTimeout: quicpogs.MaxIdleTimeout,
|
||||
|
@ -663,10 +613,11 @@ func ServeQUIC(
|
|||
quicConfig,
|
||||
edgeAddr,
|
||||
tlsConfig,
|
||||
orchestrator,
|
||||
e.orchestrator,
|
||||
connOptions,
|
||||
controlStreamHandler,
|
||||
connLogger.Logger())
|
||||
connLogger.Logger(),
|
||||
e.icmpProxy)
|
||||
if err != nil {
|
||||
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")
|
||||
return err, true
|
||||
|
@ -682,7 +633,7 @@ func ServeQUIC(
|
|||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
err := listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)
|
||||
err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
|
||||
if err != nil {
|
||||
// forcefully break the connection (this is only used for testing)
|
||||
connLogger.Logger().Debug().Msg("Forcefully breaking quic connection")
|
||||
|
|
Loading…
Reference in New Issue