cloudflared-mirror/ingress/packet_router.go

187 lines
5.5 KiB
Go

package ingress
import (
"context"
"fmt"
"github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"github.com/cloudflare/cloudflared/packet"
quicpogs "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tracing"
)
// muxer is the upstream of raw packets. Packets are received from it and replies are sent back through it.
type muxer interface {
// SendPacket sends pk to upstream
SendPacket(pk quicpogs.Packet) error
// ReceivePacket waits for the next raw packet from upstream
ReceivePacket(ctx context.Context) (quicpogs.Packet, error)
}
// PacketRouter routes packets between Upstream and ICMPRouter. Currently only ICMP packets are handled; packets
// that fail to decode as ICMP are logged and dropped.
type PacketRouter struct {
// icmpRouter handles the decoded ICMP requests. When it is nil, incoming packets are dropped
icmpRouter ICMPRouter
// muxer is the upstream packets are received from, and where TTL exceeded messages are sent
muxer muxer
// connIndex is handed to every responder created for an incoming packet
connIndex uint8
logger *zerolog.Logger
// encoder encodes the TTL exceeded messages sent by the router itself
encoder *packet.Encoder
// decoder decodes raw packets into ICMP packets
decoder *packet.ICMPDecoder
}
// NewPacketRouter creates a PacketRouter that handles ICMP packets. Packets are still read from muxer when
// icmpRouter is nil, but they are dropped.
func NewPacketRouter(icmpRouter ICMPRouter, muxer muxer, connIndex uint8, logger *zerolog.Logger) *PacketRouter {
	// The router owns its encoder and decoder; everything else is supplied by the caller.
	router := &PacketRouter{
		encoder: packet.NewEncoder(),
		decoder: packet.NewICMPDecoder(),
	}
	router.icmpRouter = icmpRouter
	router.muxer = muxer
	router.connIndex = connIndex
	router.logger = logger
	return router
}
// Serve receives packets from the muxer and handles them one at a time. It keeps running until fetching the
// next packet fails, and returns that error.
func (r *PacketRouter) Serve(ctx context.Context) error {
	var (
		raw       packet.RawPacket
		responder ICMPResponder
		err       error
	)
	for raw, responder, err = r.nextPacket(ctx); err == nil; raw, responder, err = r.nextPacket(ctx) {
		r.handlePacket(ctx, raw, responder)
	}
	return err
}
// nextPacket waits for the next datagram from the muxer and returns its payload as a raw packet, together with
// the responder to use for replies. When the datagram carries a tracing identity, the responder is given a trace
// context; if that identity can't be unmarshaled, the failure is logged and the packet is returned untraced.
// An error is returned when receiving fails or when the datagram type is unexpected.
func (r *PacketRouter) nextPacket(ctx context.Context) (packet.RawPacket, ICMPResponder, error) {
	datagram, err := r.muxer.ReceivePacket(ctx)
	if err != nil {
		return packet.RawPacket{}, nil, err
	}
	// Every responder gets an encoder of its own instead of sharing the router's
	responder := newPacketResponder(r.muxer, r.connIndex, packet.NewEncoder())

	kind := datagram.Type()
	traced := kind == quicpogs.DatagramTypeIPWithTrace
	if !traced && kind != quicpogs.DatagramTypeIP {
		return packet.RawPacket{}, nil, fmt.Errorf("unexpected datagram type %d", kind)
	}

	if traced {
		var identity tracing.Identity
		if err := identity.UnmarshalBinary(datagram.Metadata()); err != nil {
			r.logger.Err(err).Bytes("tracingIdentity", datagram.Metadata()).Msg("Failed to unmarshal tracing identity")
		} else {
			tracedCtx := tracing.NewTracedContext(ctx, identity.String(), r.logger)
			responder.AddTraceContext(tracedCtx, datagram.Metadata())
		}
	}
	return packet.RawPacket{Data: datagram.Payload()}, responder, nil
}
// handlePacket decodes rawPacket as an ICMP packet and passes it to the ICMP router with its TTL decremented.
// A packet whose TTL is 1 or less is not forwarded; a TTL exceeded message is sent back instead.
// Failures are logged rather than returned, and nothing is done at all when there is no ICMP router.
func (r *PacketRouter) handlePacket(ctx context.Context, rawPacket packet.RawPacket, responder ICMPResponder) {
	if r.icmpRouter == nil {
		// ICMP Proxy feature is disabled, drop packets
		return
	}

	decoded, err := r.decoder.Decode(rawPacket)
	if err != nil {
		r.logger.Err(err).Msg("Failed to decode ICMP packet from quic datagram")
		return
	}

	if decoded.TTL > 1 {
		decoded.TTL--
		if err := r.icmpRouter.Request(ctx, decoded, responder); err != nil {
			event := r.logger.Err(err)
			event = event.Str("src", decoded.Src.String())
			event = event.Str("dst", decoded.Dst.String())
			event = event.Interface("type", decoded.Type)
			event.Msg("Failed to send ICMP packet")
		}
		return
	}

	// The packet can't travel any further, tell the sender instead of forwarding it
	if err := r.sendTTLExceedMsg(decoded, rawPacket); err != nil {
		r.logger.Err(err).Msg("Failed to return ICMP TTL exceed error")
	}
}
// sendTTLExceedMsg has the ICMP router convert pk into a TTL exceeded message, encodes it with the router's
// encoder and sends it to the muxer. It returns the encoding or sending error, if any.
func (r *PacketRouter) sendTTLExceedMsg(pk *packet.ICMP, rawPacket packet.RawPacket) error {
	encoded, err := r.encoder.Encode(r.icmpRouter.ConvertToTTLExceeded(pk, rawPacket))
	if err != nil {
		return err
	}
	reply := quicpogs.RawPacket(encoded)
	return r.muxer.SendPacket(reply)
}
// packetResponder should not be used concurrently. This assumption is upheld because reply packets are ready one-by-one
type packetResponder struct {
// datagramMuxer is where encoded replies and exported tracing spans are sent
datagramMuxer muxer
// connIndex is the value reported by ConnectionIndex
connIndex uint8
// encoder encodes the reply packets
encoder *packet.Encoder
// tracedCtx is set by AddTraceContext. Tracing is disabled while it is nil
tracedCtx *tracing.TracedContext
// serializedIdentity is sent as the tracing identity along with exported spans
serializedIdentity []byte
// hadReply tracks if there has been any reply for this flow
hadReply bool
}
// newPacketResponder creates a responder that replies through datagramMuxer using encoder. It starts without a
// trace context, so tracing stays disabled until AddTraceContext is called.
func newPacketResponder(datagramMuxer muxer, connIndex uint8, encoder *packet.Encoder) ICMPResponder {
	responder := new(packetResponder)
	responder.datagramMuxer = datagramMuxer
	responder.connIndex = connIndex
	responder.encoder = encoder
	return responder
}
// tracingEnabled reports whether a trace context has been attached to this responder.
func (pr *packetResponder) tracingEnabled() bool {
return pr.tracedCtx != nil
}
// ConnectionIndex returns the connection index this responder was created with.
func (pr *packetResponder) ConnectionIndex() uint8 {
return pr.connIndex
}
// ReturnPacket encodes pk and sends it to the muxer. It returns the encoding error, or otherwise the result of
// sending the packet.
func (pr *packetResponder) ReturnPacket(pk *packet.ICMP) error {
	encoded, err := pr.encoder.Encode(pk)
	if err != nil {
		return err
	}
	// The flow counts as replied once the packet is encoded, whether or not sending succeeds
	pr.hadReply = true
	reply := quicpogs.RawPacket(encoded)
	return pr.datagramMuxer.SendPacket(reply)
}
// AddTraceContext attaches tracedCtx to the responder, which enables tracing, and stores the serialized identity
// that is sent along with exported spans.
func (pr *packetResponder) AddTraceContext(tracedCtx *tracing.TracedContext, serializedIdentity []byte) {
	pr.tracedCtx, pr.serializedIdentity = tracedCtx, serializedIdentity
}
// RequestSpan starts an "icmp-echo-request" span on the responder's trace context, tagged with the source and
// destination of pk. When tracing is disabled it returns ctx unchanged together with a noop span.
func (pr *packetResponder) RequestSpan(ctx context.Context, pk *packet.ICMP) (context.Context, trace.Span) {
	if !pr.tracingEnabled() {
		return ctx, tracing.NewNoopSpan()
	}
	tracer := pr.tracedCtx.Tracer()
	endpoints := trace.WithAttributes(
		attribute.String("src", pk.Src.String()),
		attribute.String("dst", pk.Dst.String()),
	)
	// The span is started from the trace context, not from ctx
	return tracer.Start(pr.tracedCtx, "icmp-echo-request", endpoints)
}
// ReplySpan starts an "icmp-echo-reply" span on the responder's trace context, but only while no reply has been
// returned for this flow yet. Otherwise, or when tracing is disabled, it returns ctx unchanged together with a
// noop span. The logger argument is not used.
func (pr *packetResponder) ReplySpan(ctx context.Context, logger *zerolog.Logger) (context.Context, trace.Span) {
	if pr.tracingEnabled() && !pr.hadReply {
		tracer := pr.tracedCtx.Tracer()
		return tracer.Start(pr.tracedCtx, "icmp-echo-reply")
	}
	return ctx, tracing.NewNoopSpan()
}
// ExportSpan sends the spans collected on the trace context to the muxer, together with the serialized tracing
// identity. Nothing is sent when tracing is disabled or when there are no spans.
func (pr *packetResponder) ExportSpan() {
	if !pr.tracingEnabled() {
		return
	}
	spans := pr.tracedCtx.GetProtoSpans()
	if len(spans) == 0 {
		return
	}
	exported := &quicpogs.TracingSpanPacket{
		Spans:           spans,
		TracingIdentity: pr.serializedIdentity,
	}
	// The send error is dropped: this method has no error return and the responder has no logger to report it to
	_ = pr.datagramMuxer.SendPacket(exported)
}