diff --git a/connection/quic.go b/connection/quic.go index 7a61bdbb..10361088 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -48,6 +48,7 @@ type QUICConnection struct { sessionManager datagramsession.Manager // datagramMuxer mux/demux datagrams from quic connection datagramMuxer quicpogs.BaseDatagramMuxer + packetRouter *packetRouter controlStreamHandler ControlStreamHandler connOptions *tunnelpogs.ConnectionOptions } @@ -61,6 +62,7 @@ func NewQUICConnection( connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler ControlStreamHandler, logger *zerolog.Logger, + icmpProxy ingress.ICMPProxy, ) (*QUICConnection, error) { session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig) if err != nil { @@ -68,7 +70,20 @@ func NewQUICConnection( } sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) - datagramMuxer := quicpogs.NewDatagramMuxer(session, logger, sessionDemuxChan) + var ( + datagramMuxer quicpogs.BaseDatagramMuxer + pr *packetRouter + ) + if icmpProxy != nil { + pr = &packetRouter{ + muxer: quicpogs.NewDatagramMuxerV2(session, logger, sessionDemuxChan), + icmpProxy: icmpProxy, + logger: logger, + } + datagramMuxer = pr.muxer + } else { + datagramMuxer = quicpogs.NewDatagramMuxer(session, logger, sessionDemuxChan) + } sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) return &QUICConnection{ @@ -77,6 +92,7 @@ func NewQUICConnection( logger: logger, sessionManager: sessionManager, datagramMuxer: datagramMuxer, + packetRouter: pr, controlStreamHandler: controlStreamHandler, connOptions: connOptions, }, nil @@ -117,6 +133,12 @@ func (q *QUICConnection) Serve(ctx context.Context) error { defer cancel() return q.datagramMuxer.ServeReceive(ctx) }) + if q.packetRouter != nil { + errGroup.Go(func() error { + defer cancel() + return q.packetRouter.serve(ctx) + }) + } return errGroup.Wait() } @@ -305,6 +327,32 @@ func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32, return q.orchestrator.UpdateConfig(version, config) } +type packetRouter struct { + muxer *quicpogs.DatagramMuxerV2 + icmpProxy ingress.ICMPProxy + logger *zerolog.Logger +} + +func (pr *packetRouter) serve(ctx context.Context) error { + icmpDecoder := packet.NewICMPDecoder() + for { + pk, err := pr.muxer.ReceivePacket(ctx) + if err != nil { + return err + } + icmpPacket, err := icmpDecoder.Decode(pk) + if err != nil { + pr.logger.Err(err).Msg("Failed to decode ICMP packet from quic datagram") + continue + } + + if err := pr.icmpProxy.Request(icmpPacket, pr.muxer); err != nil { + pr.logger.Err(err).Str("src", icmpPacket.Src.String()).Str("dst", icmpPacket.Dst.String()).Msg("Failed to send ICMP packet") + continue + } + } +} + // streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to // the client. type streamReadWriteAcker struct { diff --git a/connection/quic_test.go b/connection/quic_test.go index 0afb3953..c860400e 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -682,6 +682,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection &tunnelpogs.ConnectionOptions{}, fakeControlStream{}, &log, + nil, ) require.NoError(t, err) return qc diff --git a/ingress/origin_icmp_proxy.go b/ingress/origin_icmp_proxy.go new file mode 100644 index 00000000..672e833e --- /dev/null +++ b/ingress/origin_icmp_proxy.go @@ -0,0 +1,139 @@ +package ingress + +import ( + "context" + "fmt" + "net" + "strconv" + + "github.com/google/gopacket/layers" + "github.com/pkg/errors" + "github.com/rs/zerolog" + "golang.org/x/net/icmp" + + "github.com/cloudflare/cloudflared/packet" +) + +// ICMPProxy sends ICMP messages and listens for their responses +type ICMPProxy interface { + // Request sends an ICMP message + Request(pk *packet.ICMP, responder packet.FlowResponder) error + // ListenResponse listens for responses to the requests until context is done + ListenResponse(ctx context.Context) error +} + +// TODO: TUN-6654 Extend support to IPv6 +type icmpProxy struct { + srcFlowTracker *packet.FlowTracker + conn *icmp.PacketConn + logger *zerolog.Logger + encoder *packet.Encoder +} + +// TODO: TUN-6586: Use echo ID as FlowID +type seqNumFlowID int + +func (snf seqNumFlowID) ID() string { + return strconv.FormatInt(int64(snf), 10) +} + +func NewICMPProxy(network string, listenIP net.IP, logger *zerolog.Logger) (*icmpProxy, error) { + conn, err := icmp.ListenPacket(network, listenIP.String()) + if err != nil { + return nil, err + } + return &icmpProxy{ + srcFlowTracker: packet.NewFlowTracker(), + conn: conn, + logger: logger, + encoder: packet.NewEncoder(), + }, nil +} + +func (ip *icmpProxy) Request(pk *packet.ICMP, responder packet.FlowResponder) error { + switch body := pk.Message.Body.(type) { + case *icmp.Echo: + return ip.sendICMPEchoRequest(pk, body, responder) + default: + return fmt.Errorf("sending ICMP %s is not implemented", pk.Type) + } +} + +func (ip *icmpProxy) ListenResponse(ctx context.Context) error { + go func() { + <-ctx.Done() + ip.conn.Close() + }() + buf := make([]byte, 1500) + for { + n, src, err := ip.conn.ReadFrom(buf) + if err != nil { + return err + } + // TODO: TUN-6654 Check for IPv6 + msg, err := icmp.ParseMessage(int(layers.IPProtocolICMPv4), buf[:n]) + if err != nil { + ip.logger.Error().Err(err).Str("src", src.String()).Msg("Failed to parse ICMP message") + continue + } + switch body := msg.Body.(type) { + case *icmp.Echo: + if err := ip.handleEchoResponse(msg, body); err != nil { + ip.logger.Error().Err(err).Str("src", src.String()).Msg("Failed to handle ICMP response") + continue + } + default: + ip.logger.Warn(). + Str("icmpType", fmt.Sprintf("%s", msg.Type)). + Msgf("Responding to this type of ICMP is not implemented") + continue + } + } +} + +func (ip *icmpProxy) sendICMPEchoRequest(pk *packet.ICMP, echo *icmp.Echo, responder packet.FlowResponder) error { + flow := packet.Flow{ + Src: pk.Src, + Dst: pk.Dst, + Responder: responder, + } + // TODO: TUN-6586 rewrite ICMP echo request identifier and use it to track flows + flowID := seqNumFlowID(echo.Seq) + // TODO: TUN-6588 clean up flows + if replaced := ip.srcFlowTracker.Register(flowID, &flow, true); replaced { + ip.logger.Info().Str("src", flow.Src.String()).Str("dst", flow.Dst.String()).Msg("Replaced flow") + } + var pseudoHeader []byte = nil + serializedMsg, err := pk.Marshal(pseudoHeader) + if err != nil { + return errors.Wrap(err, "Failed to encode ICMP message") + } + // The address needs to be of type UDPAddr when conn is created without priviledge + _, err = ip.conn.WriteTo(serializedMsg, &net.UDPAddr{ + IP: pk.Dst.AsSlice(), + }) + return err +} + +func (ip *icmpProxy) handleEchoResponse(msg *icmp.Message, echo *icmp.Echo) error { + flow, ok := ip.srcFlowTracker.Get(seqNumFlowID(echo.Seq)) + if !ok { + return fmt.Errorf("flow not found") + } + icmpPacket := packet.ICMP{ + IP: &packet.IP{ + Src: flow.Dst, + Dst: flow.Src, + Protocol: layers.IPProtocol(msg.Type.Protocol()), + }, + Message: msg, + } + serializedPacket, err := ip.encoder.Encode(&icmpPacket) + if err != nil { + return errors.Wrap(err, "Failed to encode ICMP message") + } + if err := flow.Responder.SendPacket(serializedPacket); err != nil { + return errors.Wrap(err, "Failed to send packet to the edge") + } + return nil +} diff --git a/ingress/origin_icmp_proxy_test.go b/ingress/origin_icmp_proxy_test.go new file mode 100644 index 00000000..9ca8ebee --- /dev/null +++ b/ingress/origin_icmp_proxy_test.go @@ -0,0 +1,150 @@ +package ingress + +import ( + "context" + "fmt" + "net/netip" + "runtime" + "testing" + + "github.com/google/gopacket/layers" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + + "github.com/cloudflare/cloudflared/packet" +) + +var ( + noopLogger = zerolog.Nop() + localhostIP = netip.MustParseAddr("127.0.0.1") +) + +// TestICMPProxyEcho makes sure we can send ICMP echo via the Request method and receives response via the +// ListenResponse method +func TestICMPProxyEcho(t *testing.T) { + skipWindows(t) + const ( + echoID = 36571 + endSeq = 100 + ) + proxy, err := NewICMPProxy("udp4", localhostIP.AsSlice(), &noopLogger) + require.NoError(t, err) + + proxyDone := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + proxy.ListenResponse(ctx) + close(proxyDone) + }() + + responder := echoFlowResponder{ + decoder: packet.NewICMPDecoder(), + respChan: make(chan []byte), + } + + ip := packet.IP{ + Src: localhostIP, + Dst: localhostIP, + Protocol: layers.IPProtocolICMPv4, + } + for i := 0; i < endSeq; i++ { + pk := packet.ICMP{ + IP: &ip, + Message: &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: echoID, + Seq: i, + Data: []byte(fmt.Sprintf("icmp echo seq %d", i)), + }, + }, + } + require.NoError(t, proxy.Request(&pk, &responder)) + responder.validate(t, &pk) + } + cancel() + <-proxyDone +} + +// TestICMPProxyRejectNotEcho makes sure it rejects messages other than echo +func TestICMPProxyRejectNotEcho(t *testing.T) { + skipWindows(t) + msgs := []icmp.Message{ + { + Type: ipv4.ICMPTypeDestinationUnreachable, + Code: 1, + Body: &icmp.DstUnreach{ + Data: []byte("original packet"), + }, + }, + { + Type: ipv4.ICMPTypeTimeExceeded, + Code: 1, + Body: &icmp.TimeExceeded{ + Data: []byte("original packet"), + }, + }, + { + Type: ipv4.ICMPType(2), + Code: 0, + Body: &icmp.PacketTooBig{ + MTU: 1280, + Data: []byte("original packet"), + }, + }, + } + proxy, err := NewICMPProxy("udp4", localhostIP.AsSlice(), &noopLogger) + require.NoError(t, err) + + responder := echoFlowResponder{ + decoder: packet.NewICMPDecoder(), + respChan: make(chan []byte), + } + for _, m := range msgs { + pk := packet.ICMP{ + IP: &packet.IP{ + Src: localhostIP, + Dst: localhostIP, + Protocol: layers.IPProtocolICMPv4, + }, + Message: &m, + } + require.Error(t, proxy.Request(&pk, &responder)) + } +} + +func skipWindows(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Cannot create non-privileged datagram-oriented ICMP endpoint on Windows") + } +} + +type echoFlowResponder struct { + decoder *packet.ICMPDecoder + respChan chan []byte +} + +func (efr *echoFlowResponder) SendPacket(pk packet.RawPacket) error { + copiedPacket := make([]byte, len(pk.Data)) + copy(copiedPacket, pk.Data) + efr.respChan <- copiedPacket + return nil +} + +func (efr *echoFlowResponder) validate(t *testing.T, echoReq *packet.ICMP) { + pk := <-efr.respChan + decoded, err := efr.decoder.Decode(packet.RawPacket{Data: pk}) + require.NoError(t, err) + require.Equal(t, decoded.Src, echoReq.Dst) + require.Equal(t, decoded.Dst, echoReq.Src) + require.Equal(t, echoReq.Protocol, decoded.Protocol) + + require.Equal(t, ipv4.ICMPTypeEchoReply, decoded.Type) + require.Equal(t, 0, decoded.Code) + require.NotZero(t, decoded.Checksum) + // TODO: TUN-6586: Enable this validation when ICMP echo ID matches on Linux + //require.Equal(t, echoReq.Body, decoded.Body) +} diff --git a/packet/decoder.go b/packet/decoder.go index 3af4d025..147cbd16 100644 --- a/packet/decoder.go +++ b/packet/decoder.go @@ -75,9 +75,9 @@ func NewIPDecoder() *IPDecoder { } } -func (pd *IPDecoder) Decode(packet []byte) (*IP, error) { +func (pd *IPDecoder) Decode(packet RawPacket) (*IP, error) { // Should decode to IP layer - decoded, err := pd.decodeByVersion(packet) + decoded, err := pd.decodeByVersion(packet.Data) if err != nil { return nil, err } @@ -139,9 +139,9 @@ func NewICMPDecoder() *ICMPDecoder { } } -func (pd *ICMPDecoder) Decode(packet []byte) (*ICMP, error) { +func (pd *ICMPDecoder) Decode(packet RawPacket) (*ICMP, error) { // Should decode to IP and optionally ICMP layer - decoded, err := pd.decodeByVersion(packet) + decoded, err := pd.decodeByVersion(packet.Data) if err != nil { return nil, err } diff --git a/packet/decoder_test.go b/packet/decoder_test.go index e315bf1f..6db377dd 100644 --- a/packet/decoder_test.go +++ b/packet/decoder_test.go @@ -43,11 +43,11 @@ func TestDecodeIP(t *testing.T) { p, err := encoder.Encode(&udp) require.NoError(t, err) - ipPacket, err := ipDecoder.Decode(p.Data) + ipPacket, err := ipDecoder.Decode(p) require.NoError(t, err) assertIPLayer(t, &udp.IP, ipPacket) - icmpPacket, err := icmpDecoder.Decode(p.Data) + icmpPacket, err := icmpDecoder.Decode(p) require.Error(t, err) require.Nil(t, icmpPacket) } @@ -137,14 +137,14 @@ func TestDecodeICMP(t *testing.T) { p, err := encoder.Encode(test.packet) require.NoError(t, err) - ipPacket, err := ipDecoder.Decode(p.Data) + ipPacket, err := ipDecoder.Decode(p) require.NoError(t, err) if ipPacket.Src.Is4() { assertIPLayer(t, &ipv4Packet, ipPacket) } else { assertIPLayer(t, &ipv6Packet, ipPacket) } - icmpPacket, err := icmpDecoder.Decode(p.Data) + icmpPacket, err := icmpDecoder.Decode(p) require.NoError(t, err) require.Equal(t, ipPacket, icmpPacket.IP) @@ -202,11 +202,11 @@ func TestDecodeBadPackets(t *testing.T) { ipDecoder := NewIPDecoder() icmpDecoder := NewICMPDecoder() for _, test := range tests { - ipPacket, err := ipDecoder.Decode(test.packet) + ipPacket, err := ipDecoder.Decode(RawPacket{Data: test.packet}) require.Error(t, err) require.Nil(t, ipPacket) - icmpPacket, err := icmpDecoder.Decode(test.packet) + icmpPacket, err := icmpDecoder.Decode(RawPacket{Data: test.packet}) require.Error(t, err) require.Nil(t, icmpPacket) } diff --git a/packet/flow.go b/packet/flow.go index f4196fa1..3eea5652 100644 --- a/packet/flow.go +++ b/packet/flow.go @@ -2,19 +2,17 @@ package packet import ( "errors" - "net" "net/netip" "sync" ) -type flowID string - var ( ErrFlowNotFound = errors.New("flow not found") ) -func newFlowID(ip net.IP) flowID { - return flowID(ip.String()) +// FlowID represents a key type that can be used by FlowTracker +type FlowID interface { + ID() string } type Flow struct { @@ -37,32 +35,29 @@ type FlowResponder interface { SendPacket(pk RawPacket) error } -// SrcFlowTracker tracks flow from the perspective of eyeball to origin -// flowID is the source IP -type SrcFlowTracker struct { +// FlowTracker tracks flow from the perspective of eyeball to origin +type FlowTracker struct { lock sync.RWMutex - flows map[flowID]*Flow + flows map[FlowID]*Flow } -func NewSrcFlowTracker() *SrcFlowTracker { - return &SrcFlowTracker{ - flows: make(map[flowID]*Flow), +func NewFlowTracker() *FlowTracker { + return &FlowTracker{ + flows: make(map[FlowID]*Flow), } } -func (sft *SrcFlowTracker) Get(srcIP net.IP) (*Flow, bool) { +func (sft *FlowTracker) Get(id FlowID) (*Flow, bool) { sft.lock.RLock() defer sft.lock.RUnlock() - id := newFlowID(srcIP) flow, ok := sft.flows[id] return flow, ok } // Registers a flow. If shouldReplace = true, replace the current flow -func (sft *SrcFlowTracker) Register(flow *Flow, shouldReplace bool) (replaced bool) { +func (sft *FlowTracker) Register(id FlowID, flow *Flow, shouldReplace bool) (replaced bool) { sft.lock.Lock() defer sft.lock.Unlock() - id := flowID(flow.Src.String()) currentFlow, ok := sft.flows[id] if !ok { sft.flows[id] = flow @@ -77,10 +72,9 @@ func (sft *SrcFlowTracker) Register(flow *Flow, shouldReplace bool) (replaced bo } // Unregisters a flow. If force = true, delete it even if it maps to a different flow -func (sft *SrcFlowTracker) Unregister(flow *Flow, force bool) (forceDeleted bool) { +func (sft *FlowTracker) Unregister(id FlowID, flow *Flow, force bool) (forceDeleted bool) { sft.lock.Lock() defer sft.lock.Unlock() - id := flowID(flow.Src.String()) currentFlow, ok := sft.flows[id] if !ok { return false diff --git a/quic/datagram_test.go b/quic/datagram_test.go index c26362a9..bd55f425 100644 --- a/quic/datagram_test.go +++ b/quic/datagram_test.go @@ -145,7 +145,7 @@ func testDatagram(t *testing.T, version uint8, sessionToPayloads []*packet.Sessi received, err := muxer.ReceivePacket(ctx) require.NoError(t, err) - receivedICMP, err := icmpDecoder.Decode(received.Data) + receivedICMP, err := icmpDecoder.Decode(received) require.NoError(t, err) require.Equal(t, pk.IP, receivedICMP.IP) require.Equal(t, pk.Type, receivedICMP.Type) diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index 7a100a59..af556b16 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net" "strings" "time" @@ -15,6 +16,7 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/orchestration" "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" @@ -44,7 +46,7 @@ type Supervisor struct { config *TunnelConfig orchestrator *orchestration.Orchestrator edgeIPs *edgediscovery.Edge - edgeTunnelServer EdgeTunnelServer + edgeTunnelServer *EdgeTunnelServer tunnelErrors chan tunnelError tunnelsConnecting map[int]chan struct{} tunnelsProtocolFallback map[int]*protocolFallback @@ -114,6 +116,15 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato gracefulShutdownC: gracefulShutdownC, connAwareLogger: log, } + if useDatagramV2(config) { + // For non-privileged datagram-oriented ICMP endpoints, network must be "udp4" or "udp6" + // TODO: TUN-6654 listen for IPv6 and decide if it should listen on specific IP + icmpProxy, err := ingress.NewICMPProxy("udp4", net.IPv4zero, config.Log) + if err != nil { + return nil, err + } + edgeTunnelServer.icmpProxy = icmpProxy + } useReconnectToken := false if config.ClassicTunnel != nil { @@ -125,7 +136,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato config: config, orchestrator: orchestrator, edgeIPs: edgeIPs, - edgeTunnelServer: edgeTunnelServer, + edgeTunnelServer: &edgeTunnelServer, tunnelErrors: make(chan tunnelError), tunnelsConnecting: map[int]chan struct{}{}, tunnelsProtocolFallback: map[int]*protocolFallback{}, @@ -142,6 +153,14 @@ func (s *Supervisor) Run( ctx context.Context, connectedSignal *signal.Signal, ) error { + if s.edgeTunnelServer.icmpProxy != nil { + go func() { + if err := s.edgeTunnelServer.icmpProxy.ListenResponse(ctx); err != nil { + s.log.Logger().Err(err).Msg("icmp proxy terminated") + } + }() + } + if err := s.initialize(ctx, connectedSignal); err != nil { if err == errEarlyShutdown { return nil @@ -413,3 +432,15 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts) return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions) } + +func useDatagramV2(config *TunnelConfig) bool { + if config.NamedTunnel == nil { + return false + } + for _, feature := range config.NamedTunnel.Client.Features { + if feature == FeatureDatagramV2 { + return true + } + } + return false +} diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 0b9998d8..2652c22d 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -20,6 +20,7 @@ import ( "github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery/allregions" "github.com/cloudflare/cloudflared/h2mux" + "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/orchestration" quicpogs "github.com/cloudflare/cloudflared/quic" "github.com/cloudflare/cloudflared/retry" @@ -193,11 +194,12 @@ type EdgeTunnelServer struct { reconnectCh chan ReconnectSignal gracefulShutdownC <-chan struct{} tracker *tunnelstate.ConnTracker + icmpProxy ingress.ICMPProxy connAwareLogger *ConnAwareLogger } -func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFallback *protocolFallback, connectedSignal *signal.Signal) error { +func (e *EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFallback *protocolFallback, connectedSignal *signal.Signal) error { haConnections.Inc() defer haConnections.Dec() @@ -229,20 +231,14 @@ func (e EdgeTunnelServer) Serve(ctx context.Context, connIndex uint8, protocolFa // to another protocol when a particular metal doesn't support new protocol // Each connection can also have it's own IP version because individual connections might fallback // to another IP version. - err, recoverable := ServeTunnel( + err, recoverable := e.serveTunnel( ctx, connLog, - e.credentialManager, - e.config, - e.orchestrator, addr, connIndex, connectedFuse, protocolFallback, - e.cloudflaredUUID, - e.reconnectCh, protocolFallback.protocol, - e.gracefulShutdownC, ) // If the connection is recoverable, we want to maintain the same IP @@ -361,20 +357,14 @@ func selectNextProtocol( // ServeTunnel runs a single tunnel connection, returns nil on graceful shutdown, // on error returns a flag indicating if error can be retried -func ServeTunnel( +func (e *EdgeTunnelServer) serveTunnel( ctx context.Context, connLog *ConnAwareLogger, - credentialManager *reconnectCredentialManager, - config *TunnelConfig, - orchestrator *orchestration.Orchestrator, addr *allregions.EdgeAddr, connIndex uint8, fuse *h2mux.BooleanFuse, backoff *protocolFallback, - cloudflaredUUID uuid.UUID, - reconnectCh chan ReconnectSignal, protocol connection.Protocol, - gracefulShutdownC <-chan struct{}, ) (err error, recoverable bool) { // Treat panics as recoverable errors defer func() { @@ -389,21 +379,15 @@ func ServeTunnel( } }() - defer config.Observer.SendDisconnect(connIndex) - err, recoverable = serveTunnel( + defer e.config.Observer.SendDisconnect(connIndex) + err, recoverable = e.serveConnection( ctx, connLog, - credentialManager, - config, - orchestrator, addr, connIndex, fuse, backoff, - cloudflaredUUID, - reconnectCh, protocol, - gracefulShutdownC, ) if err != nil { @@ -416,7 +400,7 @@ func ServeTunnel( connLog.ConnAwareLogger().Err(err).Msg("Register tunnel error from server side") // Don't send registration error return from server to Sentry. They are // logged on server side - if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 { + if incidents := e.config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 { connLog.ConnAwareLogger().Msg(activeIncidentsMsg(incidents)) } return err.Cause, !err.Permanent @@ -442,93 +426,73 @@ func ServeTunnel( return nil, false } -func serveTunnel( +func (e *EdgeTunnelServer) serveConnection( ctx context.Context, connLog *ConnAwareLogger, - credentialManager *reconnectCredentialManager, - config *TunnelConfig, - orchestrator *orchestration.Orchestrator, addr *allregions.EdgeAddr, connIndex uint8, fuse *h2mux.BooleanFuse, backoff *protocolFallback, - cloudflaredUUID uuid.UUID, - reconnectCh chan ReconnectSignal, protocol connection.Protocol, - gracefulShutdownC <-chan struct{}, ) (err error, recoverable bool) { connectedFuse := &connectedFuse{ fuse: fuse, backoff: backoff, } controlStream := connection.NewControlStream( - config.Observer, + e.config.Observer, connectedFuse, - config.NamedTunnel, + e.config.NamedTunnel, connIndex, addr.UDP.IP, nil, - gracefulShutdownC, - config.GracePeriod, + e.gracefulShutdownC, + e.config.GracePeriod, protocol, ) switch protocol { case connection.QUIC, connection.QUICWarp: - connOptions := config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries())) - return ServeQUIC(ctx, + connOptions := e.config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries())) + return e.serveQUIC(ctx, addr.UDP, - config, - orchestrator, connLog, connOptions, controlStream, - connIndex, - reconnectCh, - gracefulShutdownC) + connIndex) case connection.HTTP2, connection.HTTP2Warp: - edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP) + edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP) if err != nil { connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge") return err, true } - connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries())) - if err := ServeHTTP2( + connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries())) + if err := e.serveHTTP2( ctx, connLog, - config, - orchestrator, edgeConn, connOptions, controlStream, connIndex, - gracefulShutdownC, - reconnectCh, ); err != nil { return err, false } default: - edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP) + edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, e.config.EdgeTLSConfigs[protocol], addr.TCP) if err != nil { connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge") return err, true } - if err := ServeH2mux( + if err := e.serveH2mux( ctx, connLog, - credentialManager, - config, - orchestrator, edgeConn, connIndex, connectedFuse, - cloudflaredUUID, - reconnectCh, - gracefulShutdownC, ); err != nil { return err, false } @@ -544,30 +508,24 @@ func (r unrecoverableError) Error() string { return r.err.Error() } -func ServeH2mux( +func (e *EdgeTunnelServer) serveH2mux( ctx context.Context, connLog *ConnAwareLogger, - credentialManager *reconnectCredentialManager, - config *TunnelConfig, - orchestrator *orchestration.Orchestrator, edgeConn net.Conn, connIndex uint8, connectedFuse *connectedFuse, - cloudflaredUUID uuid.UUID, - reconnectCh chan ReconnectSignal, - gracefulShutdownC <-chan struct{}, ) error { connLog.Logger().Debug().Msgf("Connecting via h2mux") // Returns error from parsing the origin URL or handshake errors handler, err, recoverable := connection.NewH2muxConnection( - orchestrator, - config.GracePeriod, - config.MuxerConfig, + e.orchestrator, + e.config.GracePeriod, + e.config.MuxerConfig, edgeConn, connIndex, - config.Observer, - gracefulShutdownC, - config.Log, + e.config.Observer, + e.gracefulShutdownC, + e.config.Log, ) if err != nil { if !recoverable { @@ -579,42 +537,38 @@ func ServeH2mux( errGroup, serveCtx := errgroup.WithContext(ctx) errGroup.Go(func() error { - if config.NamedTunnel != nil { - connOptions := config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries())) - return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse) + if e.config.NamedTunnel != nil { + connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries())) + return handler.ServeNamedTunnel(serveCtx, e.config.NamedTunnel, connOptions, connectedFuse) } - registrationOptions := config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID) - return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse) + registrationOptions := e.config.registrationOptions(connIndex, edgeConn.LocalAddr().String(), e.cloudflaredUUID) + return handler.ServeClassicTunnel(serveCtx, e.config.ClassicTunnel, e.credentialManager, registrationOptions, connectedFuse) }) errGroup.Go(func() error { - return listenReconnect(serveCtx, reconnectCh, gracefulShutdownC) + return listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC) }) return errGroup.Wait() } -func ServeHTTP2( +func (e *EdgeTunnelServer) serveHTTP2( ctx context.Context, connLog *ConnAwareLogger, - config *TunnelConfig, - orchestrator *orchestration.Orchestrator, tlsServerConn net.Conn, connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler connection.ControlStreamHandler, connIndex uint8, - gracefulShutdownC <-chan struct{}, - reconnectCh chan ReconnectSignal, ) error { connLog.Logger().Debug().Msgf("Connecting via http2") h2conn := connection.NewHTTP2Connection( tlsServerConn, - orchestrator, + e.orchestrator, connOptions, - config.Observer, + e.config.Observer, connIndex, controlStreamHandler, - config.Log, + e.config.Log, ) errGroup, serveCtx := errgroup.WithContext(ctx) @@ -623,7 +577,7 @@ func ServeHTTP2( }) errGroup.Go(func() error { - err := listenReconnect(serveCtx, reconnectCh, gracefulShutdownC) + err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC) if err != nil { // forcefully break the connection (this is only used for testing) connLog.Logger().Debug().Msg("Forcefully breaking http2 connection") @@ -635,19 +589,15 @@ func ServeHTTP2( return errGroup.Wait() } -func ServeQUIC( +func (e *EdgeTunnelServer) serveQUIC( ctx context.Context, edgeAddr *net.UDPAddr, - config *TunnelConfig, - orchestrator *orchestration.Orchestrator, connLogger *ConnAwareLogger, connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler connection.ControlStreamHandler, connIndex uint8, - reconnectCh chan ReconnectSignal, - gracefulShutdownC <-chan struct{}, ) (err error, recoverable bool) { - tlsConfig := config.EdgeTLSConfigs[connection.QUIC] + tlsConfig := e.config.EdgeTLSConfigs[connection.QUIC] quicConfig := &quic.Config{ HandshakeIdleTimeout: quicpogs.HandshakeIdleTimeout, MaxIdleTimeout: quicpogs.MaxIdleTimeout, @@ -663,10 +613,11 @@ func ServeQUIC( quicConfig, edgeAddr, tlsConfig, - orchestrator, + e.orchestrator, connOptions, controlStreamHandler, - connLogger.Logger()) + connLogger.Logger(), + e.icmpProxy) if err != nil { connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection") return err, true @@ -682,7 +633,7 @@ func ServeQUIC( }) errGroup.Go(func() error { - err := listenReconnect(serveCtx, reconnectCh, gracefulShutdownC) + err := listenReconnect(serveCtx, e.reconnectCh, e.gracefulShutdownC) if err != nil { // forcefully break the connection (this is only used for testing) connLogger.Logger().Debug().Msg("Forcefully breaking quic connection")