diff --git a/connection/quic_connection_test.go b/connection/quic_connection_test.go index 49c14445..1c22605b 100644 --- a/connection/quic_connection_test.go +++ b/connection/quic_connection_test.go @@ -752,7 +752,8 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) sessionDemuxChan := make(chan *packet.Session, 4) datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, &log, sessionDemuxChan) sessionManager := datagramsession.NewManager(&log, datagramMuxer.SendToSession, sessionDemuxChan) - packetRouter := ingress.NewPacketRouter(nil, datagramMuxer, 0, &log) + var connIndex uint8 = 0 + packetRouter := ingress.NewPacketRouter(nil, datagramMuxer, connIndex, &log) datagramConn := &datagramV2Connection{ conn, diff --git a/connection/quic_datagram_v3.go b/connection/quic_datagram_v3.go index 00d3c950..1b42600e 100644 --- a/connection/quic_datagram_v3.go +++ b/connection/quic_datagram_v3.go @@ -10,6 +10,7 @@ import ( "github.com/quic-go/quic-go" "github.com/rs/zerolog" + "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/management" cfdquic "github.com/cloudflare/cloudflared/quic/v3" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -25,6 +26,7 @@ type datagramV3Connection struct { func NewDatagramV3Connection(ctx context.Context, conn quic.Connection, sessionManager cfdquic.SessionManager, + icmpRouter ingress.ICMPRouter, index uint8, metrics cfdquic.Metrics, logger *zerolog.Logger, @@ -34,7 +36,7 @@ func NewDatagramV3Connection(ctx context.Context, Int(management.EventTypeKey, int(management.UDP)). Uint8(LogFieldConnIndex, index). Logger() - datagramMuxer := cfdquic.NewDatagramConn(conn, sessionManager, index, metrics, &log) + datagramMuxer := cfdquic.NewDatagramConn(conn, sessionManager, icmpRouter, index, metrics, &log) return &datagramV3Connection{ conn, diff --git a/quic/v3/datagram.go b/quic/v3/datagram.go index 3c45e6b2..136f8fbc 100644 --- a/quic/v3/datagram.go +++ b/quic/v3/datagram.go @@ -222,11 +222,11 @@ const ( // // This method should be used in-place of MarshalBinary which will allocate in-place the required byte array to return. func MarshalPayloadHeaderTo(requestID RequestID, payload []byte) error { - if len(payload) < 17 { + if len(payload) < DatagramPayloadHeaderLen { return wrapMarshalErr(ErrDatagramPayloadHeaderTooSmall) } payload[0] = byte(UDPSessionPayloadType) - return requestID.MarshalBinaryTo(payload[1:17]) + return requestID.MarshalBinaryTo(payload[1:DatagramPayloadHeaderLen]) } func (s *UDPSessionPayloadDatagram) UnmarshalBinary(data []byte) error { @@ -239,18 +239,18 @@ func (s *UDPSessionPayloadDatagram) UnmarshalBinary(data []byte) error { } // Make sure that the slice provided is the right size to be parsed. - if len(data) < 17 || len(data) > maxPayloadPlusHeaderLen { + if len(data) < DatagramPayloadHeaderLen || len(data) > maxPayloadPlusHeaderLen { return wrapUnmarshalErr(ErrDatagramPayloadInvalidSize) } - requestID, err := RequestIDFromSlice(data[1:17]) + requestID, err := RequestIDFromSlice(data[1:DatagramPayloadHeaderLen]) if err != nil { return wrapUnmarshalErr(err) } *s = UDPSessionPayloadDatagram{ RequestID: requestID, - Payload: data[17:], + Payload: data[DatagramPayloadHeaderLen:], } return nil } @@ -370,3 +370,61 @@ func (s *UDPSessionRegistrationResponseDatagram) UnmarshalBinary(data []byte) er } return nil } + +// ICMPDatagram is used to propagate ICMPv4 and ICMPv6 payloads. +type ICMPDatagram struct { + Payload []byte +} + +// The maximum size that an ICMP packet can be. +const maxICMPPayloadLen = maxDatagramPayloadLen + +// The datagram structure for ICMPDatagram is: +// +// 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// 0| Type | | +// +-+-+-+-+-+-+-+-+ + +// . Payload . +// . . +// . . +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +func (d *ICMPDatagram) MarshalBinary() (data []byte, err error) { + if len(d.Payload) > maxICMPPayloadLen { + return nil, wrapMarshalErr(ErrDatagramICMPPayloadTooLarge) + } + // We shouldn't attempt to marshal an ICMP datagram with no ICMP payload provided + if len(d.Payload) == 0 { + return nil, wrapMarshalErr(ErrDatagramICMPPayloadMissing) + } + // Make room for the 1 byte ICMPType header + datagram := make([]byte, len(d.Payload)+datagramTypeLen) + datagram[0] = byte(ICMPType) + copy(datagram[1:], d.Payload) + return datagram, nil +} + +func (d *ICMPDatagram) UnmarshalBinary(data []byte) error { + datagramType, err := ParseDatagramType(data) + if err != nil { + return wrapUnmarshalErr(err) + } + if datagramType != ICMPType { + return wrapUnmarshalErr(ErrInvalidDatagramType) + } + + if len(data[1:]) > maxDatagramPayloadLen { + return wrapUnmarshalErr(ErrDatagramICMPPayloadTooLarge) + } + + // We shouldn't attempt to unmarshal an ICMP datagram with no ICMP payload provided + if len(data[1:]) == 0 { + return wrapUnmarshalErr(ErrDatagramICMPPayloadMissing) + } + + payload := make([]byte, len(data[1:])) + copy(payload, data[1:]) + d.Payload = payload + return nil +} diff --git a/quic/v3/datagram_errors.go b/quic/v3/datagram_errors.go index 9d92b7ea..cbe30abe 100644 --- a/quic/v3/datagram_errors.go +++ b/quic/v3/datagram_errors.go @@ -15,6 +15,8 @@ var ( ErrDatagramResponseInvalidSize error = errors.New("datagram response is an invalid size") ErrDatagramResponseMsgTooLargeMaximum error = fmt.Errorf("datagram response error message length exceeds the length of the datagram maximum: %d", maxResponseErrorMessageLen) ErrDatagramResponseMsgTooLargeDatagram error = fmt.Errorf("datagram response error message length exceeds the length of the provided datagram") + ErrDatagramICMPPayloadTooLarge error = fmt.Errorf("datagram icmp payload exceeds %d bytes", maxICMPPayloadLen) + ErrDatagramICMPPayloadMissing error = errors.New("datagram icmp payload is missing") ) func wrapMarshalErr(err error) error { diff --git a/quic/v3/datagram_test.go b/quic/v3/datagram_test.go index 2c5f06fb..834c4ae4 100644 --- a/quic/v3/datagram_test.go +++ b/quic/v3/datagram_test.go @@ -160,6 +160,12 @@ func TestTypeUnmarshalErrors(t *testing.T) { if !errors.Is(err, v3.ErrDatagramHeaderTooSmall) { t.Errorf("expected invalid length to throw error") } + + d4 := v3.ICMPDatagram{} + err = d4.UnmarshalBinary([]byte{}) + if !errors.Is(err, v3.ErrDatagramHeaderTooSmall) { + t.Errorf("expected invalid length to throw error") + } }) t.Run("invalid types", func(t *testing.T) { @@ -180,6 +186,12 @@ func TestTypeUnmarshalErrors(t *testing.T) { if !errors.Is(err, v3.ErrInvalidDatagramType) { t.Errorf("expected invalid type to throw error") } + + d4 := v3.ICMPDatagram{} + err = d4.UnmarshalBinary([]byte{byte(v3.UDPSessionPayloadType)}) + if !errors.Is(err, v3.ErrInvalidDatagramType) { + t.Errorf("expected invalid type to throw error") + } }) } @@ -343,6 +355,54 @@ func TestSessionRegistrationResponse(t *testing.T) { }) } +func TestICMPDatagram(t *testing.T) { + t.Run("basic", func(t *testing.T) { + payload := makePayload(128) + datagram := v3.ICMPDatagram{Payload: payload} + marshaled, err := datagram.MarshalBinary() + if err != nil { + t.Error(err) + } + unmarshaled := &v3.ICMPDatagram{} + err = unmarshaled.UnmarshalBinary(marshaled) + if err != nil { + t.Error(err) + } + require.Equal(t, payload, unmarshaled.Payload) + }) + + t.Run("payload size empty", func(t *testing.T) { + payload := []byte{} + datagram := v3.ICMPDatagram{Payload: payload} + _, err := datagram.MarshalBinary() + if !errors.Is(err, v3.ErrDatagramICMPPayloadMissing) { + t.Errorf("expected an error: %s", err) + } + payload = []byte{byte(v3.ICMPType)} + unmarshaled := &v3.ICMPDatagram{} + err = unmarshaled.UnmarshalBinary(payload) + if !errors.Is(err, v3.ErrDatagramICMPPayloadMissing) { + t.Errorf("expected an error: %s", err) + } + }) + + t.Run("payload size too large", func(t *testing.T) { + payload := makePayload(1280 + 1) // larger than the datagram size could be + datagram := v3.ICMPDatagram{Payload: payload} + _, err := datagram.MarshalBinary() + if !errors.Is(err, v3.ErrDatagramICMPPayloadTooLarge) { + t.Errorf("expected an error: %s", err) + } + payload = makePayload(1280 + 2) // larger than the datagram size could be + header + payload[0] = byte(v3.ICMPType) + unmarshaled := &v3.ICMPDatagram{} + err = unmarshaled.UnmarshalBinary(payload) + if !errors.Is(err, v3.ErrDatagramICMPPayloadTooLarge) { + t.Errorf("expected an error: %s", err) + } + }) +} + func compareRegistrationDatagrams(t *testing.T, l *v3.UDPSessionRegistrationDatagram, r *v3.UDPSessionRegistrationDatagram) bool { require.Equal(t, l.Payload, r.Payload) return l.RequestID == r.RequestID && @@ -377,3 +437,13 @@ func FuzzRegistrationResponseDatagram(f *testing.F) { } }) } + +func FuzzICMPDatagram(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + unmarshaled := v3.ICMPDatagram{} + err := unmarshaled.UnmarshalBinary(data) + if err == nil { + _, _ = unmarshaled.MarshalBinary() + } + }) +} diff --git a/quic/v3/icmp.go b/quic/v3/icmp.go new file mode 100644 index 00000000..e9e3cc01 --- /dev/null +++ b/quic/v3/icmp.go @@ -0,0 +1,52 @@ +package v3 + +import ( + "context" + + "github.com/rs/zerolog" + "go.opentelemetry.io/otel/trace" + + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/packet" + "github.com/cloudflare/cloudflared/tracing" +) + +// packetResponder is an implementation of the [ingress.ICMPResponder] which provides the ICMP Flow manager the +// return path to return and ICMP Echo response back to the QUIC muxer. +type packetResponder struct { + datagramMuxer DatagramICMPWriter + connID uint8 +} + +func newPacketResponder(datagramMuxer DatagramICMPWriter, connID uint8) ingress.ICMPResponder { + return &packetResponder{ + datagramMuxer, + connID, + } +} + +func (pr *packetResponder) ConnectionIndex() uint8 { + return pr.connID +} + +func (pr *packetResponder) ReturnPacket(pk *packet.ICMP) error { + return pr.datagramMuxer.SendICMPPacket(pk) +} + +func (pr *packetResponder) AddTraceContext(tracedCtx *tracing.TracedContext, serializedIdentity []byte) { + // datagram v3 does not support tracing ICMP packets +} + +func (pr *packetResponder) RequestSpan(ctx context.Context, pk *packet.ICMP) (context.Context, trace.Span) { + // datagram v3 does not support tracing ICMP packets + return ctx, tracing.NewNoopSpan() +} + +func (pr *packetResponder) ReplySpan(ctx context.Context, logger *zerolog.Logger) (context.Context, trace.Span) { + // datagram v3 does not support tracing ICMP packets + return ctx, tracing.NewNoopSpan() +} + +func (pr *packetResponder) ExportSpan() { + // datagram v3 does not support tracing ICMP packets +} diff --git a/quic/v3/icmp_test.go b/quic/v3/icmp_test.go new file mode 100644 index 00000000..3189a571 --- /dev/null +++ b/quic/v3/icmp_test.go @@ -0,0 +1,45 @@ +package v3_test + +import ( + "context" + "testing" + + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/packet" +) + +type noopICMPRouter struct{} + +func (noopICMPRouter) Request(ctx context.Context, pk *packet.ICMP, responder ingress.ICMPResponder) error { + return nil +} +func (noopICMPRouter) ConvertToTTLExceeded(pk *packet.ICMP, rawPacket packet.RawPacket) *packet.ICMP { + return nil +} + +type mockICMPRouter struct { + recv chan *packet.ICMP +} + +func newMockICMPRouter() *mockICMPRouter { + return &mockICMPRouter{ + recv: make(chan *packet.ICMP, 1), + } +} + +func (m *mockICMPRouter) Request(ctx context.Context, pk *packet.ICMP, responder ingress.ICMPResponder) error { + m.recv <- pk + return nil +} +func (mockICMPRouter) ConvertToTTLExceeded(pk *packet.ICMP, rawPacket packet.RawPacket) *packet.ICMP { + return packet.NewICMPTTLExceedPacket(pk.IP, rawPacket, testLocalAddr.AddrPort().Addr()) +} + +func assertICMPEqual(t *testing.T, expected *packet.ICMP, actual *packet.ICMP) { + if expected.Src != actual.Src { + t.Fatalf("Src address not equal: %+v\t%+v", expected, actual) + } + if expected.Dst != actual.Dst { + t.Fatalf("Dst address not equal: %+v\t%+v", expected, actual) + } +} diff --git a/quic/v3/muxer.go b/quic/v3/muxer.go index 4107a845..ed688fea 100644 --- a/quic/v3/muxer.go +++ b/quic/v3/muxer.go @@ -3,9 +3,14 @@ package v3 import ( "context" "errors" + "fmt" + "sync" "time" "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/packet" ) const ( @@ -15,24 +20,31 @@ const ( logSrcKey = "src" logDstKey = "dst" + logICMPTypeKey = "type" logDurationKey = "durationMS" ) // DatagramConn is the bridge that multiplexes writes and reads of datagrams for UDP sessions and ICMP packets to // a connection. type DatagramConn interface { - DatagramWriter + DatagramUDPWriter + DatagramICMPWriter // Serve provides a server interface to process and handle incoming QUIC datagrams and demux their datagram v3 payloads. Serve(context.Context) error // ID indicates connection index identifier ID() uint8 } -// DatagramWriter provides the Muxer interface to create proper Datagrams when sending over a connection. -type DatagramWriter interface { +// DatagramUDPWriter provides the Muxer interface to create proper UDP Datagrams when sending over a connection. +type DatagramUDPWriter interface { SendUDPSessionDatagram(datagram []byte) error SendUDPSessionResponse(id RequestID, resp SessionRegistrationResp) error - //SendICMPPacket(packet packet.IP) error +} + +// DatagramICMPWriter provides the Muxer interface to create ICMP Datagrams when sending over a connection. +type DatagramICMPWriter interface { + SendICMPPacket(icmp *packet.ICMP) error + SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawPacket) error } // QuicConnection provides an interface that matches [quic.Connection] for only the datagram operations. @@ -50,27 +62,38 @@ type datagramConn struct { conn QuicConnection index uint8 sessionManager SessionManager + icmpRouter ingress.ICMPRouter metrics Metrics logger *zerolog.Logger datagrams chan []byte readErrors chan error + + icmpEncoderPool sync.Pool // a pool of *packet.Encoder + icmpDecoder *packet.ICMPDecoder } -func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn { +func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn { log := logger.With().Uint8("datagramVersion", 3).Logger() return &datagramConn{ conn: conn, index: index, sessionManager: sessionManager, + icmpRouter: icmpRouter, metrics: metrics, logger: &log, datagrams: make(chan []byte, demuxChanCapacity), readErrors: make(chan error, 2), + icmpEncoderPool: sync.Pool{ + New: func() any { + return packet.NewEncoder() + }, + }, + icmpDecoder: packet.NewICMPDecoder(), } } -func (c datagramConn) ID() uint8 { +func (c *datagramConn) ID() uint8 { return c.index } @@ -90,6 +113,33 @@ func (c *datagramConn) SendUDPSessionResponse(id RequestID, resp SessionRegistra return c.conn.SendDatagram(data) } +func (c *datagramConn) SendICMPPacket(icmp *packet.ICMP) error { + cachedEncoder := c.icmpEncoderPool.Get() + // The encoded packet is a slice to a buffer owned by the encoder, so we shouldn't return the encoder back to the + // pool until the encoded packet is sent. + defer c.icmpEncoderPool.Put(cachedEncoder) + encoder, ok := cachedEncoder.(*packet.Encoder) + if !ok { + return fmt.Errorf("encoderPool returned %T, expect *packet.Encoder", cachedEncoder) + } + payload, err := encoder.Encode(icmp) + if err != nil { + return err + } + icmpDatagram := ICMPDatagram{ + Payload: payload.Data, + } + datagram, err := icmpDatagram.MarshalBinary() + if err != nil { + return err + } + return c.conn.SendDatagram(datagram) +} + +func (c *datagramConn) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawPacket) error { + return c.SendICMPPacket(c.icmpRouter.ConvertToTTLExceeded(icmp, rawPacket)) +} + var errReadTimeout error = errors.New("receive datagram timeout") // pollDatagrams will read datagrams from the underlying connection until the provided context is done. @@ -165,6 +215,14 @@ func (c *datagramConn) Serve(ctx context.Context) error { } logger := c.logger.With().Str(logFlowID, payload.RequestID.String()).Logger() c.handleSessionPayloadDatagram(payload, &logger) + case ICMPType: + packet := &ICMPDatagram{} + err := packet.UnmarshalBinary(datagram) + if err != nil { + c.logger.Err(err).Msgf("unable to unmarshal icmp datagram") + return + } + c.handleICMPPacket(packet) case UDPSessionRegistrationResponseType: // cloudflared should never expect to receive UDP session responses as it will not initiate new // sessions towards the edge. @@ -299,3 +357,41 @@ func (c *datagramConn) handleSessionPayloadDatagram(datagram *UDPSessionPayloadD return } } + +// Handles incoming ICMP datagrams. +func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) { + if c.icmpRouter == nil { + // ICMPRouter is disabled so we drop the current packet and ignore all incoming ICMP packets + return + } + + // Decode the provided ICMPDatagram as an ICMP packet + rawPacket := packet.RawPacket{Data: datagram.Payload} + icmp, err := c.icmpDecoder.Decode(rawPacket) + if err != nil { + c.logger.Err(err).Msgf("unable to marshal icmp packet") + return + } + + // If the ICMP packet's TTL is expired, we won't send it to the origin and immediately return a TTL Exceeded Message + if icmp.TTL <= 1 { + if err := c.SendICMPTTLExceed(icmp, rawPacket); err != nil { + c.logger.Err(err).Msg("failed to return ICMP TTL exceed error") + } + return + } + icmp.TTL-- + + // The context isn't really needed here since it's only really used throughout the ICMP router as a way to store + // the tracing context, however datagram V3 does not support tracing ICMP packets, so we just pass the current + // connection context which will have no tracing information available. + err = c.icmpRouter.Request(c.conn.Context(), icmp, newPacketResponder(c, c.index)) + if err != nil { + c.logger.Err(err). + Str(logSrcKey, icmp.Src.String()). + Str(logDstKey, icmp.Dst.String()). + Interface(logICMPTypeKey, icmp.Type). + Msgf("unable to write icmp datagram to origin") + return + } +} diff --git a/quic/v3/muxer_test.go b/quic/v3/muxer_test.go index ac9bf883..7b532ba3 100644 --- a/quic/v3/muxer_test.go +++ b/quic/v3/muxer_test.go @@ -11,9 +11,13 @@ import ( "testing" "time" + "github.com/google/gopacket/layers" "github.com/rs/zerolog" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" "github.com/cloudflare/cloudflared/ingress" + "github.com/cloudflare/cloudflared/packet" v3 "github.com/cloudflare/cloudflared/quic/v3" ) @@ -27,6 +31,8 @@ func (noopEyeball) SendUDPSessionDatagram(datagram []byte) error { return nil } func (noopEyeball) SendUDPSessionResponse(id v3.RequestID, resp v3.SessionRegistrationResp) error { return nil } +func (noopEyeball) SendICMPPacket(icmp *packet.ICMP) error { return nil } +func (noopEyeball) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawPacket) error { return nil } type mockEyeball struct { connID uint8 @@ -70,9 +76,14 @@ func (m *mockEyeball) SendUDPSessionResponse(id v3.RequestID, resp v3.SessionReg return nil } +func (m *mockEyeball) SendICMPPacket(icmp *packet.ICMP) error { return nil } +func (m *mockEyeball) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawPacket) error { + return nil +} + func TestDatagramConn_New(t *testing.T) { log := zerolog.Nop() - conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log) if conn == nil { t.Fatal("expected valid connection") } @@ -81,7 +92,7 @@ func TestDatagramConn_New(t *testing.T) { func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) { log := zerolog.Nop() quic := newMockQuicConn() - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log) payload := []byte{0xef, 0xef} conn.SendUDPSessionDatagram(payload) @@ -94,7 +105,7 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) { func TestDatagramConn_SendUDPSessionResponse(t *testing.T) { log := zerolog.Nop() quic := newMockQuicConn() - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log) conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable) resp := <-quic.recv @@ -115,7 +126,7 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) { func TestDatagramConnServe_ApplicationClosed(t *testing.T) { log := zerolog.Nop() quic := newMockQuicConn() - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -131,7 +142,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() quic.ctx = ctx - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log) err := conn.Serve(context.Background()) if !errors.Is(err, context.DeadlineExceeded) { @@ -142,7 +153,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) { func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) { log := zerolog.Nop() quic := &mockQuicConnReadError{err: net.ErrClosed} - conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log) err := conn.Serve(context.Background()) if !errors.Is(err, net.ErrClosed) { @@ -177,7 +188,7 @@ func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) { log := zerolog.New(logOutput) quic := newMockQuicConn() quic.send <- test.input - conn := v3.NewDatagramConn(quic, &mockSessionManager{}, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, &mockSessionManager{}, &noopICMPRouter{}, 0, &noopMetrics{}, &log) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -218,7 +229,7 @@ func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) { quic := newMockQuicConn() expectedErr := errors.New("unable to register session") sessionManager := mockSessionManager{expectedRegErr: expectedErr} - conn := v3.NewDatagramConn(quic, &sessionManager, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) // Setup the muxer ctx, cancel := context.WithCancelCause(context.Background()) @@ -253,7 +264,7 @@ func TestDatagramConnServe(t *testing.T) { quic := newMockQuicConn() session := newMockSession() sessionManager := mockSessionManager{session: &session} - conn := v3.NewDatagramConn(quic, &sessionManager, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) // Setup the muxer ctx, cancel := context.WithCancelCause(context.Background()) @@ -298,7 +309,7 @@ func TestDatagramConnServe_RegisterTwice(t *testing.T) { quic := newMockQuicConn() session := newMockSession() sessionManager := mockSessionManager{session: &session} - conn := v3.NewDatagramConn(quic, &sessionManager, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) // Setup the muxer ctx, cancel := context.WithCancelCause(context.Background()) @@ -360,9 +371,9 @@ func TestDatagramConnServe_MigrateConnection(t *testing.T) { quic := newMockQuicConn() session := newMockSession() sessionManager := mockSessionManager{session: &session} - conn := v3.NewDatagramConn(quic, &sessionManager, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) quic2 := newMockQuicConn() - conn2 := v3.NewDatagramConn(quic2, &sessionManager, 1, &noopMetrics{}, &log) + conn2 := v3.NewDatagramConn(quic2, &sessionManager, &noopICMPRouter{}, 1, &noopMetrics{}, &log) // Setup the muxer ctx, cancel := context.WithCancelCause(context.Background()) @@ -443,7 +454,7 @@ func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) { quic := newMockQuicConn() // mockSessionManager will return the ErrSessionNotFound for any session attempting to be queried by the muxer sessionManager := mockSessionManager{session: nil, expectedGetErr: v3.ErrSessionNotFound} - conn := v3.NewDatagramConn(quic, &sessionManager, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) // Setup the muxer ctx, cancel := context.WithCancelCause(context.Background()) @@ -471,7 +482,7 @@ func TestDatagramConnServe_Payload(t *testing.T) { quic := newMockQuicConn() session := newMockSession() sessionManager := mockSessionManager{session: &session} - conn := v3.NewDatagramConn(quic, &sessionManager, 0, &noopMetrics{}, &log) + conn := v3.NewDatagramConn(quic, &sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log) // Setup the muxer ctx, cancel := context.WithCancelCause(context.Background()) @@ -496,6 +507,116 @@ func TestDatagramConnServe_Payload(t *testing.T) { assertContextClosed(t, ctx, done, cancel) } +func TestDatagramConnServe_ICMPDatagram_TTLDecremented(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + router := newMockICMPRouter() + conn := v3.NewDatagramConn(quic, &mockSessionManager{}, router, 0, &noopMetrics{}, &log) + + // Setup the muxer + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("other error")) + done := make(chan error, 1) + go func() { + done <- conn.Serve(ctx) + }() + + // Send new ICMP Echo request + expectedICMP := &packet.ICMP{ + IP: &packet.IP{ + Src: netip.MustParseAddr("192.168.1.1"), + Dst: netip.MustParseAddr("10.0.0.1"), + Protocol: layers.IPProtocolICMPv4, + TTL: 20, + }, + Message: &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: 25821, + Seq: 58129, + Data: []byte("test ttl=0"), + }, + }, + } + datagram := newICMPDatagram(expectedICMP) + quic.send <- datagram + + // Router should receive the packet + actualICMP := <-router.recv + assertICMPEqual(t, expectedICMP, actualICMP) + if expectedICMP.TTL-1 != actualICMP.TTL { + t.Fatalf("TTL should be decremented by one before sending to origin: %d, %d", expectedICMP.TTL, actualICMP.TTL) + } + + // Cancel the muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx, done, cancel) +} + +func TestDatagramConnServe_ICMPDatagram_TTLExceeded(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + router := newMockICMPRouter() + conn := v3.NewDatagramConn(quic, &mockSessionManager{}, router, 0, &noopMetrics{}, &log) + + // Setup the muxer + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("other error")) + done := make(chan error, 1) + go func() { + done <- conn.Serve(ctx) + }() + + // Send new ICMP Echo request + expectedICMP := &packet.ICMP{ + IP: &packet.IP{ + Src: netip.MustParseAddr("192.168.1.1"), + Dst: netip.MustParseAddr("10.0.0.1"), + Protocol: layers.IPProtocolICMPv4, + TTL: 0, + }, + Message: &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: 25821, + Seq: 58129, + Data: []byte("test ttl=0"), + }, + }, + } + datagram := newICMPDatagram(expectedICMP) + quic.send <- datagram + + // Origin should not recieve a packet + select { + case <-router.recv: + t.Fatalf("TTL should be expired and no origin ICMP sent") + default: + } + + // Eyeball should receive the packet + datagram = <-quic.recv + icmpDatagram := v3.ICMPDatagram{} + err := icmpDatagram.UnmarshalBinary(datagram) + if err != nil { + t.Fatal(err) + } + decoder := packet.NewICMPDecoder() + ttlExpiredICMP, err := decoder.Decode(packet.RawPacket{Data: icmpDatagram.Payload}) + if err != nil { + t.Fatal(err) + } + + // Packet should be a TTL Exceeded ICMP + if ttlExpiredICMP.TTL != packet.DefaultTTL || ttlExpiredICMP.Message.Type != ipv4.ICMPTypeTimeExceeded { + t.Fatalf("ICMP packet should be a ICMP Exceeded: %+v", ttlExpiredICMP) + } + + // Cancel the muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx, done, cancel) +} + func newRegisterSessionDatagram(id v3.RequestID) []byte { datagram := v3.UDPSessionRegistrationDatagram{ RequestID: id, @@ -531,6 +652,22 @@ func newSessionPayloadDatagram(id v3.RequestID, payload []byte) []byte { return datagram } +func newICMPDatagram(pk *packet.ICMP) []byte { + encoder := packet.NewEncoder() + rawPacket, err := encoder.Encode(pk) + if err != nil { + panic(err) + } + datagram := v3.ICMPDatagram{ + Payload: rawPacket.Data, + } + payload, err := datagram.MarshalBinary() + if err != nil { + panic(err) + } + return payload +} + // Cancel the provided context and make sure it closes with the expected cancellation error func assertContextClosed(t *testing.T, ctx context.Context, done <-chan error, cancel context.CancelCauseFunc) { cancel(expectedContextCanceled) diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index c5ec9978..09983e11 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -607,6 +607,7 @@ func (e *EdgeTunnelServer) serveQUIC( ctx, conn, e.sessionManager, + e.config.ICMPRouterServer, connIndex, e.datagramMetrics, connLogger.Logger(),