TUN-6530: Implement ICMPv4 proxy
This proxy uses unprivileged datagram-oriented endpoint and is shared by all quic connections
This commit is contained in:
parent
f6bd4aa039
commit
59f5b0df83
|
@ -48,6 +48,7 @@ type QUICConnection struct {
|
||||||
sessionManager datagramsession.Manager
|
sessionManager datagramsession.Manager
|
||||||
// datagramMuxer mux/demux datagrams from quic connection
|
// datagramMuxer mux/demux datagrams from quic connection
|
||||||
datagramMuxer quicpogs.BaseDatagramMuxer
|
datagramMuxer quicpogs.BaseDatagramMuxer
|
||||||
|
packetRouter *packetRouter
|
||||||
controlStreamHandler ControlStreamHandler
|
controlStreamHandler ControlStreamHandler
|
||||||
connOptions *tunnelpogs.ConnectionOptions
|
connOptions *tunnelpogs.ConnectionOptions
|
||||||
}
|
}
|
||||||
|
@ -61,6 +62,7 @@ func NewQUICConnection(
|
||||||
connOptions *tunnelpogs.ConnectionOptions,
|
connOptions *tunnelpogs.ConnectionOptions,
|
||||||
controlStreamHandler ControlStreamHandler,
|
controlStreamHandler ControlStreamHandler,
|
||||||
logger *zerolog.Logger,
|
logger *zerolog.Logger,
|
||||||
|
icmpProxy ingress.ICMPProxy,
|
||||||
) (*QUICConnection, error) {
|
) (*QUICConnection, error) {
|
||||||
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
|
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -68,7 +70,20 @@ func NewQUICConnection(
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
|
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
|
||||||
datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, sessionDemuxChan)
|
var (
|
||||||
|
datagramMuxer quicpogs.BaseDatagramMuxer
|
||||||
|
pr *packetRouter
|
||||||
|
)
|
||||||
|
if icmpProxy != nil {
|
||||||
|
pr = &packetRouter{
|
||||||
|
muxer: quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan),
|
||||||
|
icmpProxy: icmpProxy,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
datagramMuxer = pr.muxer
|
||||||
|
} else {
|
||||||
|
datagramMuxer = quicpogs.NewDatagramMuxer(session, logger, sessionDemuxChan)
|
||||||
|
}
|
||||||
sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan)
|
sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan)
|
||||||
|
|
||||||
return &QUICConnection{
|
return &QUICConnection{
|
||||||
|
@ -77,6 +92,7 @@ func NewQUICConnection(
|
||||||
logger: logger,
|
logger: logger,
|
||||||
sessionManager: sessionManager,
|
sessionManager: sessionManager,
|
||||||
datagramMuxer: datagramMuxer,
|
datagramMuxer: datagramMuxer,
|
||||||
|
packetRouter: pr,
|
||||||
controlStreamHandler: controlStreamHandler,
|
controlStreamHandler: controlStreamHandler,
|
||||||
connOptions: connOptions,
|
connOptions: connOptions,
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -117,6 +133,12 @@ func (q *QUICConnection) Serve(ctx context.Context) error {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
return q.datagramMuxer.ServeReceive(ctx)
|
return q.datagramMuxer.ServeReceive(ctx)
|
||||||
})
|
})
|
||||||
|
if q.packetRouter != nil {
|
||||||
|
errGroup.Go(func() error {
|
||||||
|
defer cancel()
|
||||||
|
return q.packetRouter.serve(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return errGroup.Wait()
|
return errGroup.Wait()
|
||||||
}
|
}
|
||||||
|
@ -305,6 +327,32 @@ func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32,
|
||||||
return q.orchestrator.UpdateConfig(version, config)
|
return q.orchestrator.UpdateConfig(version, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type packetRouter struct {
|
||||||
|
muxer *quicpogs.DatagramMuxerV2
|
||||||
|
icmpProxy ingress.ICMPProxy
|
||||||
|
logger *zerolog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pr *packetRouter) serve(ctx context.Context) error {
|
||||||
|
icmpDecoder := packet.NewICMPDecoder()
|
||||||
|
for {
|
||||||
|
pk, err := pr.muxer.ReceivePacket(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
icmpPacket, err := icmpDecoder.Decode(pk)
|
||||||
|
if err != nil {
|
||||||
|
pr.logger.Err(err).Msg("Failed to decode ICMP packet from quic datagram")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pr.icmpProxy.Request(icmpPacket, pr.muxer); err != nil {
|
||||||
|
pr.logger.Err(err).Str("src", icmpPacket.Src.String()).Str("dst", icmpPacket.Dst.String()).Msg("Failed to send ICMP packet")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
||||||
// the client.
|
// the client.
|
||||||
type streamReadWriteAcker struct {
|
type streamReadWriteAcker struct {
|
||||||
|
|
|
@ -682,6 +682,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection
|
||||||
&tunnelpogs.ConnectionOptions{},
|
&tunnelpogs.ConnectionOptions{},
|
||||||
fakeControlStream{},
|
fakeControlStream{},
|
||||||
&log,
|
&log,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return qc
|
return qc
|
||||||
|
|
|
@ -0,0 +1,139 @@
|
||||||
|
package ingress
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"golang.org/x/net/icmp"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/packet"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ICMPProxy sends ICMP messages and listens for their responses
|
||||||
|
type ICMPProxy interface {
|
||||||
|
// Request sends an ICMP message
|
||||||
|
Request(pk *packet.ICMP, responder packet.FlowResponder) error
|
||||||
|
// ListenResponse listens for responses to the requests until context is done
|
||||||
|
ListenResponse(ctx context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: TUN-6654 Extend support to IPv6
|
||||||
|
type icmpProxy struct {
|
||||||
|
srcFlowTracker *packet.FlowTracker
|
||||||
|
conn *icmp.PacketConn
|
||||||
|
logger *zerolog.Logger
|
||||||
|
encoder *packet.Encoder
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: TUN-6586: Use echo ID as FlowID
|
||||||
|
type seqNumFlowID int
|
||||||
|
|
||||||
|
func (snf seqNumFlowID) ID() string {
|
||||||
|
return strconv.FormatInt(int64(snf), 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewICMPProxy(network string, listenIP net.IP, logger *zerolog.Logger) (*icmpProxy, error) {
|
||||||
|
conn, err := icmp.ListenPacket(network, listenIP.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &icmpProxy{
|
||||||
|
srcFlowTracker: packet.NewFlowTracker(),
|
||||||
|
conn: conn,
|
||||||
|
logger: logger,
|
||||||
|
encoder: packet.NewEncoder(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FlowResponder) error {
|
||||||
|
switch body := pk.Message.Body.(type) {
|
||||||
|
case *icmp.Echo:
|
||||||
|
return ip.sendICMPEchoRequest(pk, body, responder)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("sending ICMP %s is not implemented", pk.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ip *icmpProxy) ListenResponse(ctx context.Context) error {
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
ip.conn.Close()
|
||||||
|
}()
|
||||||
|
buf := make([]byte, 1500)
|
||||||
|
for {
|
||||||
|
n, src, err := ip.conn.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// TODO: TUN-6654 Check for IPv6
|
||||||
|
msg, err := icmp.ParseMessage(int(layers.IPProtocolICMPv4), buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
ip.logger.Error().Err(err).Str("src", src.String()).Msg("Failed to parse ICMP message")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch body := msg.Body.(type) {
|
||||||
|
case *icmp.Echo:
|
||||||
|
if err := ip.handleEchoResponse(msg, body); err != nil {
|
||||||
|
ip.logger.Error().Err(err).Str("src", src.String()).Msg("Failed to handle ICMP response")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ip.logger.Warn().
|
||||||
|
Str("icmpType", fmt.Sprintf("%s", msg.Type)).
|
||||||
|
Msgf("Responding to this type of ICMP is not implemented")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ip *icmpProxy) sendICMPEchoRequest(pk *packet.ICMP, echo *icmp.Echo, responder packet.FlowResponder) error {
|
||||||
|
flow := packet.Flow{
|
||||||
|
Src: pk.Src,
|
||||||
|
Dst: pk.Dst,
|
||||||
|
Responder: responder,
|
||||||
|
}
|
||||||
|
// TODO: TUN-6586 rewrite ICMP echo request identifier and use it to track flows
|
||||||
|
flowID := seqNumFlowID(echo.Seq)
|
||||||
|
// TODO: TUN-6588 clean up flows
|
||||||
|
if replaced := ip.srcFlowTracker.Register(flowID, &flow, true); replaced {
|
||||||
|
ip.logger.Info().Str("src", flow.Src.String()).Str("dst", flow.Dst.String()).Msg("Replaced flow")
|
||||||
|
}
|
||||||
|
var pseudoHeader []byte = nil
|
||||||
|
serializedMsg, err := pk.Marshal(pseudoHeader)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Failed to encode ICMP message")
|
||||||
|
}
|
||||||
|
// The address needs to be of type UDPAddr when conn is created without priviledge
|
||||||
|
_, err = ip.conn.WriteTo(serializedMsg, &net.UDPAddr{
|
||||||
|
IP: pk.Dst.AsSlice(),
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ip *icmpProxy) handleEchoResponse(msg *icmp.Message, echo *icmp.Echo) error {
|
||||||
|
flow, ok := ip.srcFlowTracker.Get(seqNumFlowID(echo.Seq))
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("flow not found")
|
||||||
|
}
|
||||||
|
icmpPacket := packet.ICMP{
|
||||||
|
IP: &packet.IP{
|
||||||
|
Src: flow.Dst,
|
||||||
|
Dst: flow.Src,
|
||||||
|
Protocol: layers.IPProtocol(msg.Type.Protocol()),
|
||||||
|
},
|
||||||
|
Message: msg,
|
||||||
|
}
|
||||||
|
serializedPacket, err := ip.encoder.Encode(&icmpPacket)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "Failed to encode ICMP message")
|
||||||
|
}
|
||||||
|
if err := flow.Responder.SendPacket(serializedPacket); err != nil {
|
||||||
|
return errors.Wrap(err, "Failed to send packet to the edge")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,150 @@
|
||||||
|
package ingress
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/net/icmp"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
|
||||||
|
"github.com/cloudflare/cloudflared/packet"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
noopLogger = zerolog.Nop()
|
||||||
|
localhostIP = netip.MustParseAddr("127.0.0.1")
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestICMPProxyEcho makes sure we can send ICMP echo via the Request method and receives response via the
|
||||||
|
// ListenResponse method
|
||||||
|
func TestICMPProxyEcho(t *testing.T) {
|
||||||
|
skipWindows(t)
|
||||||
|
const (
|
||||||
|
echoID = 36571
|
||||||
|
endSeq = 100
|
||||||
|
)
|
||||||
|
proxy, err := NewICMPProxy("udp4", localhostIP.AsSlice(), &noopLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
proxyDone := make(chan struct{})
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
proxy.ListenResponse(ctx)
|
||||||
|
close(proxyDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
responder := echoFlowResponder{
|
||||||
|
decoder: packet.NewICMPDecoder(),
|
||||||
|
respChan: make(chan []byte),
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := packet.IP{
|
||||||
|
Src: localhostIP,
|
||||||
|
Dst: localhostIP,
|
||||||
|
Protocol: layers.IPProtocolICMPv4,
|
||||||
|
}
|
||||||
|
for i := 0; i < endSeq; i++ {
|
||||||
|
pk := packet.ICMP{
|
||||||
|
IP: &ip,
|
||||||
|
Message: &icmp.Message{
|
||||||
|
Type: ipv4.ICMPTypeEcho,
|
||||||
|
Code: 0,
|
||||||
|
Body: &icmp.Echo{
|
||||||
|
ID: echoID,
|
||||||
|
Seq: i,
|
||||||
|
Data: []byte(fmt.Sprintf("icmp echo seq %d", i)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.NoError(t, proxy.Request(&pk, &responder))
|
||||||
|
responder.validate(t, &pk)
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
<-proxyDone
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestICMPProxyRejectNotEcho makes sure it rejects messages other than echo
|
||||||
|
func TestICMPProxyRejectNotEcho(t *testing.T) {
|
||||||
|
skipWindows(t)
|
||||||
|
msgs := []icmp.Message{
|
||||||
|
{
|
||||||
|
Type: ipv4.ICMPTypeDestinationUnreachable,
|
||||||
|
Code: 1,
|
||||||
|
Body: &icmp.DstUnreach{
|
||||||
|
Data: []byte("original packet"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: ipv4.ICMPTypeTimeExceeded,
|
||||||
|
Code: 1,
|
||||||
|
Body: &icmp.TimeExceeded{
|
||||||
|
Data: []byte("original packet"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: ipv4.ICMPType(2),
|
||||||
|
Code: 0,
|
||||||
|
Body: &icmp.PacketTooBig{
|
||||||
|
MTU: 1280,
|
||||||
|
Data: []byte("original packet"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
proxy, err := NewICMPProxy("udp4", localhostIP.AsSlice(), &noopLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
responder := echoFlowResponder{
|
||||||
|
decoder: packet.NewICMPDecoder(),
|
||||||
|
respChan: make(chan []byte),
|
||||||
|
}
|
||||||
|
for _, m := range msgs {
|
||||||
|
pk := packet.ICMP{
|
||||||
|
IP: &packet.IP{
|
||||||
|
Src: localhostIP,
|
||||||
|
Dst: localhostIP,
|
||||||
|
Protocol: layers.IPProtocolICMPv4,
|
||||||
|
},
|
||||||
|
Message: &m,
|
||||||
|
}
|
||||||
|
require.Error(t, proxy.Request(&pk, &responder))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func skipWindows(t *testing.T) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("Cannot create non-privileged datagram-oriented ICMP endpoint on Windows")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type echoFlowResponder struct {
|
||||||
|
decoder *packet.ICMPDecoder
|
||||||
|
respChan chan []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (efr *echoFlowResponder) SendPacket(pk packet.RawPacket) error {
|
||||||
|
copiedPacket := make([]byte, len(pk.Data))
|
||||||
|
copy(copiedPacket, pk.Data)
|
||||||
|
efr.respChan <- copiedPacket
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (efr *echoFlowResponder) validate(t *testing.T, echoReq *packet.ICMP) {
|
||||||
|
pk := <-efr.respChan
|
||||||
|
decoded, err := efr.decoder.Decode(packet.RawPacket{Data: pk})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, decoded.Src, echoReq.Dst)
|
||||||
|
require.Equal(t, decoded.Dst, echoReq.Src)
|
||||||
|
require.Equal(t, echoReq.Protocol, decoded.Protocol)
|
||||||
|
|
||||||
|
require.Equal(t, ipv4.ICMPTypeEchoReply, decoded.Type)
|
||||||
|
require.Equal(t, 0, decoded.Code)
|
||||||
|
require.NotZero(t, decoded.Checksum)
|
||||||
|
// TODO: TUN-6586: Enable this validation when ICMP echo ID matches on Linux
|
||||||
|
//require.Equal(t, echoReq.Body, decoded.Body)
|
||||||
|
}
|
|
@ -75,9 +75,9 @@ func NewIPDecoder() *IPDecoder {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pd *IPDecoder) Decode(packet []byte) (*IP, error) {
|
func (pd *IPDecoder) Decode(packet RawPacket) (*IP, error) {
|
||||||
// Should decode to IP layer
|
// Should decode to IP layer
|
||||||
decoded, err := pd.decodeByVersion(packet)
|
decoded, err := pd.decodeByVersion(packet.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -139,9 +139,9 @@ func NewICMPDecoder() *ICMPDecoder {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pd *ICMPDecoder) Decode(packet []byte) (*ICMP, error) {
|
func (pd *ICMPDecoder) Decode(packet RawPacket) (*ICMP, error) {
|
||||||
// Should decode to IP and optionally ICMP layer
|
// Should decode to IP and optionally ICMP layer
|
||||||
decoded, err := pd.decodeByVersion(packet)
|
decoded, err := pd.decodeByVersion(packet.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,11 +43,11 @@ func TestDecodeIP(t *testing.T) {
|
||||||
p, err := encoder.Encode(&udp)
|
p, err := encoder.Encode(&udp)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ipPacket, err := ipDecoder.Decode(p.Data)
|
ipPacket, err := ipDecoder.Decode(p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assertIPLayer(t, &udp.IP, ipPacket)
|
assertIPLayer(t, &udp.IP, ipPacket)
|
||||||
|
|
||||||
icmpPacket, err := icmpDecoder.Decode(p.Data)
|
icmpPacket, err := icmpDecoder.Decode(p)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, icmpPacket)
|
require.Nil(t, icmpPacket)
|
||||||
}
|
}
|
||||||
|
@ -137,14 +137,14 @@ func TestDecodeICMP(t *testing.T) {
|
||||||
p, err := encoder.Encode(test.packet)
|
p, err := encoder.Encode(test.packet)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ipPacket, err := ipDecoder.Decode(p.Data)
|
ipPacket, err := ipDecoder.Decode(p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if ipPacket.Src.Is4() {
|
if ipPacket.Src.Is4() {
|
||||||
assertIPLayer(t, &ipv4Packet, ipPacket)
|
assertIPLayer(t, &ipv4Packet, ipPacket)
|
||||||
} else {
|
} else {
|
||||||
assertIPLayer(t, &ipv6Packet, ipPacket)
|
assertIPLayer(t, &ipv6Packet, ipPacket)
|
||||||
}
|
}
|
||||||
icmpPacket, err := icmpDecoder.Decode(p.Data)
|
icmpPacket, err := icmpDecoder.Decode(p)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, ipPacket, icmpPacket.IP)
|
require.Equal(t, ipPacket, icmpPacket.IP)
|
||||||
|
|
||||||
|
@ -202,11 +202,11 @@ func TestDecodeBadPackets(t *testing.T) {
|
||||||
ipDecoder := NewIPDecoder()
|
ipDecoder := NewIPDecoder()
|
||||||
icmpDecoder := NewICMPDecoder()
|
icmpDecoder := NewICMPDecoder()
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
ipPacket, err := ipDecoder.Decode(test.packet)
|
ipPacket, err := ipDecoder.Decode(RawPacket{Data: test.packet})
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, ipPacket)
|
require.Nil(t, ipPacket)
|
||||||
|
|
||||||
icmpPacket, err := icmpDecoder.Decode(test.packet)
|
icmpPacket, err := icmpDecoder.Decode(RawPacket{Data: test.packet})
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Nil(t, icmpPacket)
|
require.Nil(t, icmpPacket)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,19 +2,17 @@ package packet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
type flowID string
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrFlowNotFound = errors.New("flow not found")
|
ErrFlowNotFound = errors.New("flow not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
func newFlowID(ip net.IP) flowID {
|
// FlowID represents a key type that can be used by FlowTracker
|
||||||
return flowID(ip.String())
|
type FlowID interface {
|
||||||
|
ID() string
|
||||||
}
|
}
|
||||||
|
|
||||||
type Flow struct {
|
type Flow struct {
|
||||||
|
@ -37,32 +35,29 @@ type FlowResponder interface {
|
||||||
SendPacket(pk RawPacket) error
|
SendPacket(pk RawPacket) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// SrcFlowTracker tracks flow from the perspective of eyeball to origin
|
// FlowTracker tracks flow from the perspective of eyeball to origin
|
||||||
// flowID is the source IP
|
type FlowTracker struct {
|
||||||
type SrcFlowTracker struct {
|
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
flows map[flowID]*Flow
|
flows map[FlowID]*Flow
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSrcFlowTracker() *SrcFlowTracker {
|
func NewFlowTracker() *FlowTracker {
|
||||||
return &SrcFlowTracker{
|
return &FlowTracker{
|
||||||
flows: make(map[flowID]*Flow),
|
flows: make(map[FlowID]*Flow),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sft *SrcFlowTracker) Get(srcIP net.IP) (*Flow, bool) {
|
func (sft *FlowTracker) Get(id FlowID) (*Flow, bool) {
|
||||||
sft.lock.RLock()
|
sft.lock.RLock()
|
||||||
defer sft.lock.RUnlock()
|
defer sft.lock.RUnlock()
|
||||||
id := newFlowID(srcIP)
|
|
||||||
flow, ok := sft.flows[id]
|
flow, ok := sft.flows[id]
|
||||||
return flow, ok
|
return flow, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// Registers a flow. If shouldReplace = true, replace the current flow
|
// Registers a flow. If shouldReplace = true, replace the current flow
|
||||||
func (sft *SrcFlowTracker) Register(flow *Flow, shouldReplace bool) (replaced bool) {
|
func (sft *FlowTracker) Register(id FlowID, flow *Flow, shouldReplace bool) (replaced bool) {
|
||||||
sft.lock.Lock()
|
sft.lock.Lock()
|
||||||
defer sft.lock.Unlock()
|
defer sft.lock.Unlock()
|
||||||
id := flowID(flow.Src.String())
|
|
||||||
currentFlow, ok := sft.flows[id]
|
currentFlow, ok := sft.flows[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
sft.flows[id] = flow
|
sft.flows[id] = flow
|
||||||
|
@ -77,10 +72,9 @@ func (sft *SrcFlowTracker) Register(flow *Flow, shouldReplace bool) (replaced bo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unregisters a flow. If force = true, delete it even if it maps to a different flow
|
// Unregisters a flow. If force = true, delete it even if it maps to a different flow
|
||||||
func (sft *SrcFlowTracker) Unregister(flow *Flow, force bool) (forceDeleted bool) {
|
func (sft *FlowTracker) Unregister(id FlowID, flow *Flow, force bool) (forceDeleted bool) {
|
||||||
sft.lock.Lock()
|
sft.lock.Lock()
|
||||||
defer sft.lock.Unlock()
|
defer sft.lock.Unlock()
|
||||||
id := flowID(flow.Src.String())
|
|
||||||
currentFlow, ok := sft.flows[id]
|
currentFlow, ok := sft.flows[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
|
|
|
@ -145,7 +145,7 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi
|
||||||
received, err := muxer.ReceivePacket(ctx)
|
received, err := muxer.ReceivePacket(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
receivedICMP, err := icmpDecoder.Decode(received.Data)
|
receivedICMP, err := icmpDecoder.Decode(received)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, pk.IP, receivedICMP.IP)
|
require.Equal(t, pk.IP, receivedICMP.IP)
|
||||||
require.Equal(t, pk.Type, receivedICMP.Type)
|
require.Equal(t, pk.Type, receivedICMP.Type)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -15,6 +16,7 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||||
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
"github.com/cloudflare/cloudflared/orchestration"
|
"github.com/cloudflare/cloudflared/orchestration"
|
||||||
"github.com/cloudflare/cloudflared/retry"
|
"github.com/cloudflare/cloudflared/retry"
|
||||||
"github.com/cloudflare/cloudflared/signal"
|
"github.com/cloudflare/cloudflared/signal"
|
||||||
|
@ -44,7 +46,7 @@ type Supervisor struct {
|
||||||
config *TunnelConfig
|
config *TunnelConfig
|
||||||
orchestrator *orchestration.Orchestrator
|
orchestrator *orchestration.Orchestrator
|
||||||
edgeIPs *edgediscovery.Edge
|
edgeIPs *edgediscovery.Edge
|
||||||
edgeTunnelServer EdgeTunnelServer
|
edgeTunnelServer *EdgeTunnelServer
|
||||||
tunnelErrors chan tunnelError
|
tunnelErrors chan tunnelError
|
||||||
tunnelsConnecting map[int]chan struct{}
|
tunnelsConnecting map[int]chan struct{}
|
||||||
tunnelsProtocolFallback map[int]*protocolFallback
|
tunnelsProtocolFallback map[int]*protocolFallback
|
||||||
|
@ -114,6 +116,15 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
|
||||||
gracefulShutdownC: gracefulShutdownC,
|
gracefulShutdownC: gracefulShutdownC,
|
||||||
connAwareLogger: log,
|
connAwareLogger: log,
|
||||||
}
|
}
|
||||||
|
if useDatagramV2(config) {
|
||||||
|
// For non-privileged datagram-oriented ICMP endpoints, network must be "udp4" or "udp6"
|
||||||
|
// TODO: TUN-6654 listen for IPv6 and decide if it should listen on specific IP
|
||||||
|
icmpProxy, err := ingress.NewICMPProxy("udp4", net.IPv4zero, config.Log)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
edgeTunnelServer.icmpProxy = icmpProxy
|
||||||
|
}
|
||||||
|
|
||||||
useReconnectToken := false
|
useReconnectToken := false
|
||||||
if config.ClassicTunnel != nil {
|
if config.ClassicTunnel != nil {
|
||||||
|
@ -125,7 +136,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
|
||||||
config: config,
|
config: config,
|
||||||
orchestrator: orchestrator,
|
orchestrator: orchestrator,
|
||||||
edgeIPs: edgeIPs,
|
edgeIPs: edgeIPs,
|
||||||
edgeTunnelServer: edgeTunnelServer,
|
edgeTunnelServer: &edgeTunnelServer,
|
||||||
tunnelErrors: make(chan tunnelError),
|
tunnelErrors: make(chan tunnelError),
|
||||||
tunnelsConnecting: map[int]chan struct{}{},
|
tunnelsConnecting: map[int]chan struct{}{},
|
||||||
tunnelsProtocolFallback: map[int]*protocolFallback{},
|
tunnelsProtocolFallback: map[int]*protocolFallback{},
|
||||||
|
@ -142,6 +153,14 @@ func (s *Supervisor) Run(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
connectedSignal *signal.Signal,
|
connectedSignal *signal.Signal,
|
||||||
) error {
|
) error {
|
||||||
|
if s.edgeTunnelServer.icmpProxy != nil {
|
||||||
|
go func() {
|
||||||
|
if err := s.edgeTunnelServer.icmpProxy.ListenResponse(ctx); err != nil {
|
||||||
|
s.log.Logger().Err(err).Msg("icmp proxy terminated")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
if err := s.initialize(ctx, connectedSignal); err != nil {
|
if err := s.initialize(ctx, connectedSignal); err != nil {
|
||||||
if err == errEarlyShutdown {
|
if err == errEarlyShutdown {
|
||||||
return nil
|
return nil
|
||||||
|
@ -413,3 +432,15 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
|
||||||
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
|
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
|
||||||
return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions)
|
return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func useDatagramV2(config *TunnelConfig) bool {
|
||||||
|
if config.NamedTunnel == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, feature := range config.NamedTunnel.Client.Features {
|
||||||
|
if feature == FeatureDatagramV2 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||||
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
||||||
"github.com/cloudflare/cloudflared/h2mux"
|
"github.com/cloudflare/cloudflared/h2mux"
|
||||||
|
"github.com/cloudflare/cloudflared/ingress"
|
||||||
"github.com/cloudflare/cloudflared/orchestration"
|
"github.com/cloudflare/cloudflared/orchestration"
|
||||||
quicpogs "github.com/cloudflare/cloudflared/quic"
|
quicpogs "github.com/cloudflare/cloudflared/quic"
|
||||||
"github.com/cloudflare/cloudflared/retry"
|
"github.com/cloudflare/cloudflared/retry"
|
||||||
|
@ -193,11 +194,12 @@ type EdgeTunnelServer struct {
|
||||||
reconnectCh chan ReconnectSignal
|
reconnectCh chan ReconnectSignal
|
||||||
gracefulShutdownC <-chan struct{}
|
gracefulShutdownC <-chan struct{}
|
||||||
tracker *tunnelstate.ConnTracker
|
tracker *tunnelstate.ConnTracker
|
||||||
|
icmpProxy ingress.ICMPProxy
|
||||||
|
|
||||||
connAwareLogger *ConnAwareLogger
|
connAwareLogger *ConnAwareLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFallback *protocolFallback, connectedSignal *signal.Signal) error {
|
func (e *EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFallback *protocolFallback, connectedSignal *signal.Signal) error {
|
||||||
haConnections.Inc()
|
haConnections.Inc()
|
||||||
defer haConnections.Dec()
|
defer haConnections.Dec()
|
||||||
|
|
||||||
|
@ -229,20 +231,14 @@ func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFa
|
||||||
// to another protocol when a particular metal doesn't support new protocol
|
// to another protocol when a particular metal doesn't support new protocol
|
||||||
// Each connection can also have it's own IP version because individual connections might fallback
|
// Each connection can also have it's own IP version because individual connections might fallback
|
||||||
// to another IP version.
|
// to another IP version.
|
||||||
err, recoverable := ServeTunnel(
|
err, recoverable := e.serveTunnel(
|
||||||
ctx,
|
ctx,
|
||||||
connLog,
|
connLog,
|
||||||
e.credentialManager,
|
|
||||||
e.config,
|
|
||||||
e.orchestrator,
|
|
||||||
addr,
|
addr,
|
||||||
connIndex,
|
connIndex,
|
||||||
connectedFuse,
|
connectedFuse,
|
||||||
protocolFallback,
|
protocolFallback,
|
||||||
e.cloudflaredUUID,
|
|
||||||
e.reconnectCh,
|
|
||||||
protocolFallback.protocol,
|
protocolFallback.protocol,
|
||||||
e.gracefulShutdownC,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// If the connection is recoverable, we want to maintain the same IP
|
// If the connection is recoverable, we want to maintain the same IP
|
||||||
|
@ -361,20 +357,14 @@ func selectNextProtocol(
|
||||||
|
|
||||||
// ServeTunnel runs a single tunnel connection, returns nil on graceful shutdown,
|
// ServeTunnel runs a single tunnel connection, returns nil on graceful shutdown,
|
||||||
// on error returns a flag indicating if error can be retried
|
// on error returns a flag indicating if error can be retried
|
||||||
func ServeTunnel(
|
func (e *EdgeTunnelServer) serveTunnel(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
connLog *ConnAwareLogger,
|
connLog *ConnAwareLogger,
|
||||||
credentialManager *reconnectCredentialManager,
|
|
||||||
config *TunnelConfig,
|
|
||||||
orchestrator *orchestration.Orchestrator,
|
|
||||||
addr *allregions.EdgeAddr,
|
addr *allregions.EdgeAddr,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
fuse *h2mux.BooleanFuse,
|
fuse *h2mux.BooleanFuse,
|
||||||
backoff *protocolFallback,
|
backoff *protocolFallback,
|
||||||
cloudflaredUUID uuid.UUID,
|
|
||||||
reconnectCh chan ReconnectSignal,
|
|
||||||
protocol connection.Protocol,
|
protocol connection.Protocol,
|
||||||
gracefulShutdownC <-chan struct{},
|
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
// Treat panics as recoverable errors
|
// Treat panics as recoverable errors
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -389,21 +379,15 @@ func ServeTunnel(
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
defer config.Observer.SendDisconnect(connIndex)
|
defer e.config.Observer.SendDisconnect(connIndex)
|
||||||
err, recoverable = serveTunnel(
|
err, recoverable = e.serveConnection(
|
||||||
ctx,
|
ctx,
|
||||||
connLog,
|
connLog,
|
||||||
credentialManager,
|
|
||||||
config,
|
|
||||||
orchestrator,
|
|
||||||
addr,
|
addr,
|
||||||
connIndex,
|
connIndex,
|
||||||
fuse,
|
fuse,
|
||||||
backoff,
|
backoff,
|
||||||
cloudflaredUUID,
|
|
||||||
reconnectCh,
|
|
||||||
protocol,
|
protocol,
|
||||||
gracefulShutdownC,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -416,7 +400,7 @@ func ServeTunnel(
|
||||||
connLog.ConnAwareLogger().Err(err).Msg("Register tunnel error from server side")
|
connLog.ConnAwareLogger().Err(err).Msg("Register tunnel error from server side")
|
||||||
// Don't send registration error return from server to Sentry. They are
|
// Don't send registration error return from server to Sentry. They are
|
||||||
// logged on server side
|
// logged on server side
|
||||||
if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 {
|
if incidents := e.config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 {
|
||||||
connLog.ConnAwareLogger().Msg(activeIncidentsMsg(incidents))
|
connLog.ConnAwareLogger().Msg(activeIncidentsMsg(incidents))
|
||||||
}
|
}
|
||||||
return err.Cause, !err.Permanent
|
return err.Cause, !err.Permanent
|
||||||
|
@ -442,93 +426,73 @@ func ServeTunnel(
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func serveTunnel(
|
func (e *EdgeTunnelServer) serveConnection(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
connLog *ConnAwareLogger,
|
connLog *ConnAwareLogger,
|
||||||
credentialManager *reconnectCredentialManager,
|
|
||||||
config *TunnelConfig,
|
|
||||||
orchestrator *orchestration.Orchestrator,
|
|
||||||
addr *allregions.EdgeAddr,
|
addr *allregions.EdgeAddr,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
fuse *h2mux.BooleanFuse,
|
fuse *h2mux.BooleanFuse,
|
||||||
backoff *protocolFallback,
|
backoff *protocolFallback,
|
||||||
cloudflaredUUID uuid.UUID,
|
|
||||||
reconnectCh chan ReconnectSignal,
|
|
||||||
protocol connection.Protocol,
|
protocol connection.Protocol,
|
||||||
gracefulShutdownC <-chan struct{},
|
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
connectedFuse := &connectedFuse{
|
connectedFuse := &connectedFuse{
|
||||||
fuse: fuse,
|
fuse: fuse,
|
||||||
backoff: backoff,
|
backoff: backoff,
|
||||||
}
|
}
|
||||||
controlStream := connection.NewControlStream(
|
controlStream := connection.NewControlStream(
|
||||||
config.Observer,
|
e.config.Observer,
|
||||||
connectedFuse,
|
connectedFuse,
|
||||||
config.NamedTunnel,
|
e.config.NamedTunnel,
|
||||||
connIndex,
|
connIndex,
|
||||||
addr.UDP.IP,
|
addr.UDP.IP,
|
||||||
nil,
|
nil,
|
||||||
gracefulShutdownC,
|
e.gracefulShutdownC,
|
||||||
config.GracePeriod,
|
e.config.GracePeriod,
|
||||||
protocol,
|
protocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case connection.QUIC, connection.QUICWarp:
|
case connection.QUIC, connection.QUICWarp:
|
||||||
connOptions := config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
|
connOptions := e.config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
|
||||||
return ServeQUIC(ctx,
|
return e.serveQUIC(ctx,
|
||||||
addr.UDP,
|
addr.UDP,
|
||||||
config,
|
|
||||||
orchestrator,
|
|
||||||
connLog,
|
connLog,
|
||||||
connOptions,
|
connOptions,
|
||||||
controlStream,
|
controlStream,
|
||||||
connIndex,
|
connIndex)
|
||||||
reconnectCh,
|
|
||||||
gracefulShutdownC)
|
|
||||||
|
|
||||||
case connection.HTTP2, connection.HTTP2Warp:
|
case connection.HTTP2, connection.HTTP2Warp:
|
||||||
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP)
|
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
|
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
|
||||||
return err, true
|
return err, true
|
||||||
}
|
}
|
||||||
|
|
||||||
connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
|
connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
|
||||||
if err := ServeHTTP2(
|
if err := e.serveHTTP2(
|
||||||
ctx,
|
ctx,
|
||||||
connLog,
|
connLog,
|
||||||
config,
|
|
||||||
orchestrator,
|
|
||||||
edgeConn,
|
edgeConn,
|
||||||
connOptions,
|
connOptions,
|
||||||
controlStream,
|
controlStream,
|
||||||
connIndex,
|
connIndex,
|
||||||
gracefulShutdownC,
|
|
||||||
reconnectCh,
|
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err, false
|
return err, false
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP)
|
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
|
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
|
||||||
return err, true
|
return err, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ServeH2mux(
|
if err := e.serveH2mux(
|
||||||
ctx,
|
ctx,
|
||||||
connLog,
|
connLog,
|
||||||
credentialManager,
|
|
||||||
config,
|
|
||||||
orchestrator,
|
|
||||||
edgeConn,
|
edgeConn,
|
||||||
connIndex,
|
connIndex,
|
||||||
connectedFuse,
|
connectedFuse,
|
||||||
cloudflaredUUID,
|
|
||||||
reconnectCh,
|
|
||||||
gracefulShutdownC,
|
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err, false
|
return err, false
|
||||||
}
|
}
|
||||||
|
@ -544,30 +508,24 @@ func (r unrecoverableError) Error() string {
|
||||||
return r.err.Error()
|
return r.err.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServeH2mux(
|
func (e *EdgeTunnelServer) serveH2mux(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
connLog *ConnAwareLogger,
|
connLog *ConnAwareLogger,
|
||||||
credentialManager *reconnectCredentialManager,
|
|
||||||
config *TunnelConfig,
|
|
||||||
orchestrator *orchestration.Orchestrator,
|
|
||||||
edgeConn net.Conn,
|
edgeConn net.Conn,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
connectedFuse *connectedFuse,
|
connectedFuse *connectedFuse,
|
||||||
cloudflaredUUID uuid.UUID,
|
|
||||||
reconnectCh chan ReconnectSignal,
|
|
||||||
gracefulShutdownC <-chan struct{},
|
|
||||||
) error {
|
) error {
|
||||||
connLog.Logger().Debug().Msgf("Connecting via h2mux")
|
connLog.Logger().Debug().Msgf("Connecting via h2mux")
|
||||||
// Returns error from parsing the origin URL or handshake errors
|
// Returns error from parsing the origin URL or handshake errors
|
||||||
handler, err, recoverable := connection.NewH2muxConnection(
|
handler, err, recoverable := connection.NewH2muxConnection(
|
||||||
orchestrator,
|
e.orchestrator,
|
||||||
config.GracePeriod,
|
e.config.GracePeriod,
|
||||||
config.MuxerConfig,
|
e.config.MuxerConfig,
|
||||||
edgeConn,
|
edgeConn,
|
||||||
connIndex,
|
connIndex,
|
||||||
config.Observer,
|
e.config.Observer,
|
||||||
gracefulShutdownC,
|
e.gracefulShutdownC,
|
||||||
config.Log,
|
e.config.Log,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !recoverable {
|
if !recoverable {
|
||||||
|
@ -579,42 +537,38 @@ func ServeH2mux(
|
||||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
if config.NamedTunnel != nil {
|
if e.config.NamedTunnel != nil {
|
||||||
connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries()))
|
connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries()))
|
||||||
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse)
|
return handler.ServeNamedTunnel(serveCtx, e.config.NamedTunnel, connOptions, connectedFuse)
|
||||||
}
|
}
|
||||||
registrationOptions := config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
|
registrationOptions := e.config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), e.cloudflaredUUID)
|
||||||
return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
|
return handler.ServeClassicTunnel(serveCtx, e.config.ClassicTunnel, e.credentialManager, registrationOptions, connectedFuse)
|
||||||
})
|
})
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
return listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)
|
return listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
|
||||||
})
|
})
|
||||||
|
|
||||||
return errGroup.Wait()
|
return errGroup.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServeHTTP2(
|
func (e *EdgeTunnelServer) serveHTTP2(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
connLog *ConnAwareLogger,
|
connLog *ConnAwareLogger,
|
||||||
config *TunnelConfig,
|
|
||||||
orchestrator *orchestration.Orchestrator,
|
|
||||||
tlsServerConn net.Conn,
|
tlsServerConn net.Conn,
|
||||||
connOptions *tunnelpogs.ConnectionOptions,
|
connOptions *tunnelpogs.ConnectionOptions,
|
||||||
controlStreamHandler connection.ControlStreamHandler,
|
controlStreamHandler connection.ControlStreamHandler,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
gracefulShutdownC <-chan struct{},
|
|
||||||
reconnectCh chan ReconnectSignal,
|
|
||||||
) error {
|
) error {
|
||||||
connLog.Logger().Debug().Msgf("Connecting via http2")
|
connLog.Logger().Debug().Msgf("Connecting via http2")
|
||||||
h2conn := connection.NewHTTP2Connection(
|
h2conn := connection.NewHTTP2Connection(
|
||||||
tlsServerConn,
|
tlsServerConn,
|
||||||
orchestrator,
|
e.orchestrator,
|
||||||
connOptions,
|
connOptions,
|
||||||
config.Observer,
|
e.config.Observer,
|
||||||
connIndex,
|
connIndex,
|
||||||
controlStreamHandler,
|
controlStreamHandler,
|
||||||
config.Log,
|
e.config.Log,
|
||||||
)
|
)
|
||||||
|
|
||||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||||
|
@ -623,7 +577,7 @@ func ServeHTTP2(
|
||||||
})
|
})
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
err := listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)
|
err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// forcefully break the connection (this is only used for testing)
|
// forcefully break the connection (this is only used for testing)
|
||||||
connLog.Logger().Debug().Msg("Forcefully breaking http2 connection")
|
connLog.Logger().Debug().Msg("Forcefully breaking http2 connection")
|
||||||
|
@ -635,19 +589,15 @@ func ServeHTTP2(
|
||||||
return errGroup.Wait()
|
return errGroup.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func ServeQUIC(
|
func (e *EdgeTunnelServer) serveQUIC(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
edgeAddr *net.UDPAddr,
|
edgeAddr *net.UDPAddr,
|
||||||
config *TunnelConfig,
|
|
||||||
orchestrator *orchestration.Orchestrator,
|
|
||||||
connLogger *ConnAwareLogger,
|
connLogger *ConnAwareLogger,
|
||||||
connOptions *tunnelpogs.ConnectionOptions,
|
connOptions *tunnelpogs.ConnectionOptions,
|
||||||
controlStreamHandler connection.ControlStreamHandler,
|
controlStreamHandler connection.ControlStreamHandler,
|
||||||
connIndex uint8,
|
connIndex uint8,
|
||||||
reconnectCh chan ReconnectSignal,
|
|
||||||
gracefulShutdownC <-chan struct{},
|
|
||||||
) (err error, recoverable bool) {
|
) (err error, recoverable bool) {
|
||||||
tlsConfig := config.EdgeTLSConfigs[connection.QUIC]
|
tlsConfig := e.config.EdgeTLSConfigs[connection.QUIC]
|
||||||
quicConfig := &quic.Config{
|
quicConfig := &quic.Config{
|
||||||
HandshakeIdleTimeout: quicpogs.HandshakeIdleTimeout,
|
HandshakeIdleTimeout: quicpogs.HandshakeIdleTimeout,
|
||||||
MaxIdleTimeout: quicpogs.MaxIdleTimeout,
|
MaxIdleTimeout: quicpogs.MaxIdleTimeout,
|
||||||
|
@ -663,10 +613,11 @@ func ServeQUIC(
|
||||||
quicConfig,
|
quicConfig,
|
||||||
edgeAddr,
|
edgeAddr,
|
||||||
tlsConfig,
|
tlsConfig,
|
||||||
orchestrator,
|
e.orchestrator,
|
||||||
connOptions,
|
connOptions,
|
||||||
controlStreamHandler,
|
controlStreamHandler,
|
||||||
connLogger.Logger())
|
connLogger.Logger(),
|
||||||
|
e.icmpProxy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")
|
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")
|
||||||
return err, true
|
return err, true
|
||||||
|
@ -682,7 +633,7 @@ func ServeQUIC(
|
||||||
})
|
})
|
||||||
|
|
||||||
errGroup.Go(func() error {
|
errGroup.Go(func() error {
|
||||||
err := listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)
|
err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// forcefully break the connection (this is only used for testing)
|
// forcefully break the connection (this is only used for testing)
|
||||||
connLogger.Logger().Debug().Msg("Forcefully breaking quic connection")
|
connLogger.Logger().Debug().Msg("Forcefully breaking quic connection")
|
||||||
|
|
Loading…
Reference in New Issue