diff --git a/.gitignore b/.gitignore index c258e58c..2af7a1ed 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ cscope.* ssh_server_tests/.env /.cover built_artifacts/ +component-tests/.venv diff --git a/.teamcity/install-cloudflare-go.sh b/.teamcity/install-cloudflare-go.sh index ae24c038..9677ed7e 100755 --- a/.teamcity/install-cloudflare-go.sh +++ b/.teamcity/install-cloudflare-go.sh @@ -3,6 +3,6 @@ cd /tmp git clone -q https://github.com/cloudflare/go cd go/src -# https://github.com/cloudflare/go/tree/ec0a014545f180b0c74dfd687698657a9e86e310 is version go1.22.2-devel-cf -git checkout -q ec0a014545f180b0c74dfd687698657a9e86e310 -./make.bash \ No newline at end of file +# https://github.com/cloudflare/go/tree/f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38 is version go1.22.5-devel-cf +git checkout -q f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38 +./make.bash diff --git a/.teamcity/windows/component-test.ps1 b/.teamcity/windows/component-test.ps1 index fe70738e..548fac6b 100644 --- a/.teamcity/windows/component-test.ps1 +++ b/.teamcity/windows/component-test.ps1 @@ -37,7 +37,7 @@ if ($LASTEXITCODE -ne 0) { throw "Failed unit tests" } Write-Output "Running component tests" -python -m pip --disable-pip-version-check install --upgrade -r component-tests/requirements.txt +python -m pip --disable-pip-version-check install --upgrade -r component-tests/requirements.txt --use-pep517 python component-tests/setup.py --type create python -m pytest component-tests -o log_cli=true --log-cli-level=INFO if ($LASTEXITCODE -ne 0) { diff --git a/.teamcity/windows/install-cloudflare-go.ps1 b/.teamcity/windows/install-cloudflare-go.ps1 index eedc7c15..6ff957b9 100644 --- a/.teamcity/windows/install-cloudflare-go.ps1 +++ b/.teamcity/windows/install-cloudflare-go.ps1 @@ -9,8 +9,8 @@ Set-Location "$Env:Temp" git clone -q https://github.com/cloudflare/go Write-Output "Building go..." cd go/src -# https://github.com/cloudflare/go/tree/ec0a014545f180b0c74dfd687698657a9e86e310 is version go1.22.2-devel-cf -git checkout -q ec0a014545f180b0c74dfd687698657a9e86e310 +# https://github.com/cloudflare/go/tree/f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38 is version go1.22.5-devel-cf +git checkout -q f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38 & ./make.bat -Write-Output "Installed" \ No newline at end of file +Write-Output "Installed" diff --git a/.teamcity/windows/install-go-msi.ps1 b/.teamcity/windows/install-go-msi.ps1 index cb5602c1..7756c1c4 100644 --- a/.teamcity/windows/install-go-msi.ps1 +++ b/.teamcity/windows/install-go-msi.ps1 @@ -1,6 +1,6 @@ $ErrorActionPreference = "Stop" $ProgressPreference = "SilentlyContinue" -$GoMsiVersion = "go1.22.2.windows-amd64.msi" +$GoMsiVersion = "go1.22.5.windows-amd64.msi" Write-Output "Downloading go installer..." 
@@ -17,4 +17,4 @@ Install-Package "$Env:Temp\$GoMsiVersion" -Force # Go installer updates global $PATH go env -Write-Output "Installed" \ No newline at end of file +Write-Output "Installed" diff --git a/Dockerfile b/Dockerfile index 639dc5ca..8dac4752 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ # use a builder image for building cloudflare ARG TARGET_GOOS ARG TARGET_GOARCH -FROM golang:1.22.2 as builder +FROM golang:1.22.5 as builder ENV GO111MODULE=on \ CGO_ENABLED=0 \ TARGET_GOOS=${TARGET_GOOS} \ diff --git a/Dockerfile.amd64 b/Dockerfile.amd64 index f17969cb..d1cdbcbf 100644 --- a/Dockerfile.amd64 +++ b/Dockerfile.amd64 @@ -1,5 +1,5 @@ # use a builder image for building cloudflare -FROM golang:1.22.2 as builder +FROM golang:1.22.5 as builder ENV GO111MODULE=on \ CGO_ENABLED=0 diff --git a/Dockerfile.arm64 b/Dockerfile.arm64 index 67f4935d..0190bf63 100644 --- a/Dockerfile.arm64 +++ b/Dockerfile.arm64 @@ -1,5 +1,5 @@ # use a builder image for building cloudflare -FROM golang:1.22.2 as builder +FROM golang:1.22.5 as builder ENV GO111MODULE=on \ CGO_ENABLED=0 diff --git a/Makefile b/Makefile index 46fee2a9..c8652bec 100644 --- a/Makefile +++ b/Makefile @@ -165,10 +165,6 @@ cover: # Generate the HTML report that can be viewed from the browser in CI. $Q go tool cover -html ".cover/c.out" -o .cover/all.html -.PHONY: test-ssh-server -test-ssh-server: - docker-compose -f ssh_server_tests/docker-compose.yml up - .PHONY: install-go install-go: rm -rf ${CF_GO_PATH} diff --git a/RELEASE_NOTES b/RELEASE_NOTES index 53df2dfa..ebd55393 100644 --- a/RELEASE_NOTES +++ b/RELEASE_NOTES @@ -1,3 +1,15 @@ +2024.11.0 +- 2024-11-05 VULN-66059: remove ssh server tests +- 2024-11-04 TUN-8700: Add datagram v3 muxer +- 2024-11-04 TUN-8646: Allow experimental feature support for datagram v3 +- 2024-11-04 TUN-8641: Expose methods to simplify V3 Datagram parsing on the edge +- 2024-10-31 TUN-8708: Bump python min version to 3.10 +- 2024-10-31 TUN-8667: Add datagram v3 session manager +- 2024-10-25 TUN-8692: remove dashes from session id +- 2024-10-24 TUN-8694: Rework release script +- 2024-10-24 TUN-8661: Refactor connection methods to support future different datagram muxing methods +- 2024-07-22 TUN-8553: Bump go to 1.22.5 and go-boring 1.22.5-1 + 2024.10.1 - 2024-10-23 TUN-8694: Fix github release script - 2024-10-21 Revert "TUN-8592: Use metadata from the edge to determine if request body is empty for QUIC transport" diff --git a/cfsetup.yaml b/cfsetup.yaml index 62e8de5d..be7dbe3c 100644 --- a/cfsetup.yaml +++ b/cfsetup.yaml @@ -1,4 +1,4 @@ -pinned_go: &pinned_go go-boring=1.22.2-1 +pinned_go: &pinned_go go-boring=1.22.5-1 build_dir: &build_dir /cfsetup_build default-flavor: bullseye diff --git a/component-tests/README.md b/component-tests/README.md index 537fb47e..6eac7782 100644 --- a/component-tests/README.md +++ b/component-tests/README.md @@ -1,9 +1,9 @@ # Requirements -1. Python 3.7 or later with packages in the given `requirements.txt` - - E.g. with conda: - - `conda create -n component-tests python=3.7` - - `conda activate component-tests` - - `pip3 install -r requirements.txt` +1. Python 3.10 or later with packages in the given `requirements.txt` + - E.g. with venv: + - `python3 -m venv ./.venv` + - `source ./.venv/bin/activate` + - `python3 -m pip install -r requirements.txt` 2. 
Create a config yaml file, for example: ``` diff --git a/connection/quic_datagram_v3.go b/connection/quic_datagram_v3.go new file mode 100644 index 00000000..00d3c950 --- /dev/null +++ b/connection/quic_datagram_v3.go @@ -0,0 +1,56 @@ +package connection + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/google/uuid" + "github.com/quic-go/quic-go" + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/management" + cfdquic "github.com/cloudflare/cloudflared/quic/v3" + "github.com/cloudflare/cloudflared/tunnelrpc/pogs" +) + +type datagramV3Connection struct { + conn quic.Connection + // datagramMuxer mux/demux datagrams from quic connection + datagramMuxer cfdquic.DatagramConn + logger *zerolog.Logger +} + +func NewDatagramV3Connection(ctx context.Context, + conn quic.Connection, + sessionManager cfdquic.SessionManager, + index uint8, + metrics cfdquic.Metrics, + logger *zerolog.Logger, +) DatagramSessionHandler { + log := logger. + With(). + Int(management.EventTypeKey, int(management.UDP)). + Uint8(LogFieldConnIndex, index). + Logger() + datagramMuxer := cfdquic.NewDatagramConn(conn, sessionManager, index, metrics, &log) + + return &datagramV3Connection{ + conn, + datagramMuxer, + logger, + } +} + +func (d *datagramV3Connection) Serve(ctx context.Context) error { + return d.datagramMuxer.Serve(ctx) +} + +func (d *datagramV3Connection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*pogs.RegisterUdpSessionResponse, error) { + return nil, fmt.Errorf("datagram v3 does not support RegisterUdpSession RPC") +} + +func (d *datagramV3Connection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error { + return fmt.Errorf("datagram v3 does not support UnregisterUdpSession RPC") +} diff --git a/dev.Dockerfile b/dev.Dockerfile index e8a8ceba..8986040a 100644 --- a/dev.Dockerfile +++ b/dev.Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.22.2 as builder +FROM golang:1.22.5 as builder ENV GO111MODULE=on \ CGO_ENABLED=0 WORKDIR /go/src/github.com/cloudflare/cloudflared/ diff --git a/ingress/origin_udp_proxy.go b/ingress/origin_udp_proxy.go index 836489be..012c05c0 100644 --- a/ingress/origin_udp_proxy.go +++ b/ingress/origin_udp_proxy.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "net" + "net/netip" ) type UDPProxy interface { @@ -30,3 +31,16 @@ func DialUDP(dstIP net.IP, dstPort uint16) (UDPProxy, error) { return &udpProxy{udpConn}, nil } + +func DialUDPAddrPort(dest netip.AddrPort) (*net.UDPConn, error) { + addr := net.UDPAddrFromAddrPort(dest) + + // We use nil as local addr to force runtime to find the best suitable local address IP given the destination + // address as context. 
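+	// For example: dialing 192.0.2.10:53 with a nil local address binds an
+	// ephemeral local port on whichever interface the kernel's routing table
+	// selects for that destination.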
+ udpConn, err := net.DialUDP("udp", nil, addr) + if err != nil { + return nil, fmt.Errorf("unable to dial udp to origin %s: %w", dest, err) + } + + return udpConn, nil +} diff --git a/quic/v3/datagram.go b/quic/v3/datagram.go index d5c2ac1b..3c45e6b2 100644 --- a/quic/v3/datagram.go +++ b/quic/v3/datagram.go @@ -24,10 +24,10 @@ const ( datagramTypeLen = 1 // 1280 is the default datagram packet length used before MTU discovery: https://github.com/quic-go/quic-go/blob/v0.45.0/internal/protocol/params.go#L12 - maxDatagramLen = 1280 + maxDatagramPayloadLen = 1280 ) -func parseDatagramType(data []byte) (DatagramType, error) { +func ParseDatagramType(data []byte) (DatagramType, error) { if len(data) < datagramTypeLen { return 0, ErrDatagramHeaderTooSmall } @@ -100,10 +100,10 @@ func (s *UDPSessionRegistrationDatagram) MarshalBinary() (data []byte, err error } var maxPayloadLen int if ipv6 { - maxPayloadLen = maxDatagramLen - sessionRegistrationIPv6DatagramHeaderLen + maxPayloadLen = maxDatagramPayloadLen + sessionRegistrationIPv6DatagramHeaderLen flags |= sessionRegistrationFlagsIPMask } else { - maxPayloadLen = maxDatagramLen - sessionRegistrationIPv4DatagramHeaderLen + maxPayloadLen = maxDatagramPayloadLen + sessionRegistrationIPv4DatagramHeaderLen } // Make sure that the payload being bundled can actually fit in the payload destination if len(s.Payload) > maxPayloadLen { @@ -140,7 +140,7 @@ func (s *UDPSessionRegistrationDatagram) MarshalBinary() (data []byte, err error } func (s *UDPSessionRegistrationDatagram) UnmarshalBinary(data []byte) error { - datagramType, err := parseDatagramType(data) + datagramType, err := ParseDatagramType(data) if err != nil { return err } @@ -192,10 +192,10 @@ type UDPSessionPayloadDatagram struct { } const ( - datagramPayloadHeaderLen = datagramTypeLen + datagramRequestIdLen + DatagramPayloadHeaderLen = datagramTypeLen + datagramRequestIdLen // The maximum size that a proxied UDP payload can be in a [UDPSessionPayloadDatagram] - maxPayloadPlusHeaderLen = maxDatagramLen - datagramPayloadHeaderLen + maxPayloadPlusHeaderLen = maxDatagramPayloadLen + DatagramPayloadHeaderLen ) // The datagram structure for UDPSessionPayloadDatagram is: @@ -230,7 +230,7 @@ func MarshalPayloadHeaderTo(requestID RequestID, payload []byte) error { } func (s *UDPSessionPayloadDatagram) UnmarshalBinary(data []byte) error { - datagramType, err := parseDatagramType(data) + datagramType, err := ParseDatagramType(data) if err != nil { return err } @@ -270,7 +270,7 @@ const ( datagramSessionRegistrationResponseLen = datagramTypeLen + datagramRespTypeLen + datagramRequestIdLen + datagramRespErrMsgLen // The maximum size that an error message can be in a [UDPSessionRegistrationResponseDatagram]. 
- maxResponseErrorMessageLen = maxDatagramLen - datagramSessionRegistrationResponseLen + maxResponseErrorMessageLen = maxDatagramPayloadLen - datagramSessionRegistrationResponseLen ) // SessionRegistrationResp represents all of the responses that a UDP session registration response @@ -330,7 +330,7 @@ func (s *UDPSessionRegistrationResponseDatagram) MarshalBinary() (data []byte, e } func (s *UDPSessionRegistrationResponseDatagram) UnmarshalBinary(data []byte) error { - datagramType, err := parseDatagramType(data) + datagramType, err := ParseDatagramType(data) if err != nil { return wrapUnmarshalErr(err) } diff --git a/quic/v3/datagram_errors.go b/quic/v3/datagram_errors.go index 244915db..9d92b7ea 100644 --- a/quic/v3/datagram_errors.go +++ b/quic/v3/datagram_errors.go @@ -7,7 +7,7 @@ import ( var ( ErrInvalidDatagramType error = errors.New("invalid datagram type expected") - ErrDatagramHeaderTooSmall error = fmt.Errorf("datagram should have at least %d bytes", datagramTypeLen) + ErrDatagramHeaderTooSmall error = fmt.Errorf("datagram should have at least %d byte", datagramTypeLen) ErrDatagramPayloadTooLarge error = errors.New("payload length is too large to be bundled in datagram") ErrDatagramPayloadHeaderTooSmall error = errors.New("payload length is too small to fit the datagram header") ErrDatagramPayloadInvalidSize error = errors.New("datagram provided is an invalid size") diff --git a/quic/v3/datagram_test.go b/quic/v3/datagram_test.go index b2e77f89..ff46ef24 100644 --- a/quic/v3/datagram_test.go +++ b/quic/v3/datagram_test.go @@ -21,7 +21,7 @@ func makePayload(size int) []byte { } func TestSessionRegistration_MarshalUnmarshal(t *testing.T) { - payload := makePayload(1254) + payload := makePayload(1280) tests := []*v3.UDPSessionRegistrationDatagram{ // Default (IPv4) { @@ -236,7 +236,7 @@ func TestSessionPayload(t *testing.T) { }) t.Run("payload size too large", func(t *testing.T) { - datagram := makePayload(17 + 1264) // 1263 is the largest payload size allowed + datagram := makePayload(17 + 1281) // 1280 is the largest payload size allowed err := v3.MarshalPayloadHeaderTo(testRequestID, datagram) if err != nil { t.Error(err) diff --git a/quic/v3/manager.go b/quic/v3/manager.go new file mode 100644 index 00000000..f5b0667f --- /dev/null +++ b/quic/v3/manager.go @@ -0,0 +1,103 @@ +package v3 + +import ( + "errors" + "net" + "net/netip" + "sync" + + "github.com/rs/zerolog" +) + +var ( + // ErrSessionNotFound indicates that a session has not been registered yet for the request id. + ErrSessionNotFound = errors.New("flow not found") + // ErrSessionBoundToOtherConn is returned when a registration already exists for a different connection. + ErrSessionBoundToOtherConn = errors.New("flow is in use by another connection") + // ErrSessionAlreadyRegistered is returned when a registration already exists for this connection. + ErrSessionAlreadyRegistered = errors.New("flow is already registered for this connection") +) + +type SessionManager interface { + // RegisterSession will register a new session if it does not already exist for the request ID. + // During new session creation, the session will also bind the UDP socket for the origin. + // If the session exists for a different connection, it will return [ErrSessionBoundToOtherConn]. + RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramConn) (Session, error) + // GetSession returns an active session if available for the provided connection. + // If the session does not exist, it will return [ErrSessionNotFound]. 
If the session exists for a different + // connection, it will return [ErrSessionBoundToOtherConn]. + GetSession(requestID RequestID) (Session, error) + // UnregisterSession will remove a session from the current session manager. It will attempt to close the session + // before removal. + UnregisterSession(requestID RequestID) +} + +type DialUDP func(dest netip.AddrPort) (*net.UDPConn, error) + +type sessionManager struct { + sessions map[RequestID]Session + mutex sync.RWMutex + originDialer DialUDP + metrics Metrics + log *zerolog.Logger +} + +func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP) SessionManager { + return &sessionManager{ + sessions: make(map[RequestID]Session), + originDialer: originDialer, + metrics: metrics, + log: log, + } +} + +func (s *sessionManager) RegisterSession(request *UDPSessionRegistrationDatagram, conn DatagramConn) (Session, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + // Check to make sure session doesn't already exist for requestID + if session, exists := s.sessions[request.RequestID]; exists { + if conn.ID() == session.ConnectionID() { + return nil, ErrSessionAlreadyRegistered + } + return nil, ErrSessionBoundToOtherConn + } + // Attempt to bind the UDP socket for the new session + origin, err := s.originDialer(request.Dest) + if err != nil { + return nil, err + } + // Create and insert the new session in the map + session := NewSession( + request.RequestID, + request.IdleDurationHint, + origin, + origin.RemoteAddr(), + origin.LocalAddr(), + conn, + s.metrics, + s.log) + s.sessions[request.RequestID] = session + return session, nil +} + +func (s *sessionManager) GetSession(requestID RequestID) (Session, error) { + s.mutex.RLock() + defer s.mutex.RUnlock() + session, exists := s.sessions[requestID] + if exists { + return session, nil + } + return nil, ErrSessionNotFound +} + +func (s *sessionManager) UnregisterSession(requestID RequestID) { + s.mutex.Lock() + defer s.mutex.Unlock() + // Get the session and make sure to close it if it isn't already closed + session, exists := s.sessions[requestID] + if exists { + // We ignore any errors when attempting to close the session + _ = session.Close() + } + delete(s.sessions, requestID) +} diff --git a/quic/v3/manager_test.go b/quic/v3/manager_test.go new file mode 100644 index 00000000..71defadd --- /dev/null +++ b/quic/v3/manager_test.go @@ -0,0 +1,80 @@ +package v3_test + +import ( + "errors" + "net/netip" + "strings" + "testing" + "time" + + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/ingress" + v3 "github.com/cloudflare/cloudflared/quic/v3" +) + +func TestRegisterSession(t *testing.T) { + log := zerolog.Nop() + manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort) + + request := v3.UDPSessionRegistrationDatagram{ + RequestID: testRequestID, + Dest: netip.MustParseAddrPort("127.0.0.1:5000"), + Traced: false, + IdleDurationHint: 5 * time.Second, + Payload: nil, + } + session, err := manager.RegisterSession(&request, &noopEyeball{}) + if err != nil { + t.Fatalf("register session should've succeeded: %v", err) + } + if request.RequestID != session.ID() { + t.Fatalf("session id doesn't match: %v != %v", request.RequestID, session.ID()) + } + + // We shouldn't be able to register another session with the same request id + _, err = manager.RegisterSession(&request, &noopEyeball{}) + if !errors.Is(err, v3.ErrSessionAlreadyRegistered) { + t.Fatalf("session is already registered for this connection: %v", err) + } + + // We shouldn't be 
able to register another session with the same request id for a different connection
+	_, err = manager.RegisterSession(&request, &noopEyeball{connID: 1})
+	if !errors.Is(err, v3.ErrSessionBoundToOtherConn) {
+		t.Fatalf("session is already registered for a separate connection: %v", err)
+	}
+
+	// Get session
+	sessionGet, err := manager.GetSession(request.RequestID)
+	if err != nil {
+		t.Fatalf("get session failed: %v", err)
+	}
+	if session.ID() != sessionGet.ID() {
+		t.Fatalf("sessions do not match: %v != %v", session.ID(), sessionGet.ID())
+	}
+
+	// Remove the session
+	manager.UnregisterSession(request.RequestID)
+
+	// Get session should fail
+	_, err = manager.GetSession(request.RequestID)
+	if !errors.Is(err, v3.ErrSessionNotFound) {
+		t.Fatalf("get session should have failed with ErrSessionNotFound: %v", err)
+	}
+
+	// Closing the original session should return that the socket is already closed (by the session unregistration)
+	err = session.Close()
+	if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
+		t.Fatalf("session should've closed without issue: %v", err)
+	}
+}
+
+func TestGetSession_Empty(t *testing.T) {
+	log := zerolog.Nop()
+	manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort)
+
+	_, err := manager.GetSession(testRequestID)
+	if !errors.Is(err, v3.ErrSessionNotFound) {
+		t.Fatalf("get session should find no session: %v", err)
+	}
+}
diff --git a/quic/v3/metrics.go b/quic/v3/metrics.go
new file mode 100644
index 00000000..8f9cf19e
--- /dev/null
+++ b/quic/v3/metrics.go
@@ -0,0 +1,90 @@
+package v3
+
+import (
+	"github.com/prometheus/client_golang/prometheus"
+)
+
+const (
+	namespace = "cloudflared"
+	subsystem = "udp"
+)
+
+type Metrics interface {
+	IncrementFlows()
+	DecrementFlows()
+	PayloadTooLarge()
+	RetryFlowResponse()
+	MigrateFlow()
+}
+
+type metrics struct {
+	activeUDPFlows     prometheus.Gauge
+	totalUDPFlows      prometheus.Counter
+	payloadTooLarge    prometheus.Counter
+	retryFlowResponses prometheus.Counter
+	migratedFlows      prometheus.Counter
+}
+
+func (m *metrics) IncrementFlows() {
+	m.totalUDPFlows.Inc()
+	m.activeUDPFlows.Inc()
+}
+
+func (m *metrics) DecrementFlows() {
+	m.activeUDPFlows.Dec()
+}
+
+func (m *metrics) PayloadTooLarge() {
+	m.payloadTooLarge.Inc()
+}
+
+func (m *metrics) RetryFlowResponse() {
+	m.retryFlowResponses.Inc()
+}
+
+func (m *metrics) MigrateFlow() {
+	m.migratedFlows.Inc()
+}
+
+func NewMetrics(registerer prometheus.Registerer) Metrics {
+	m := &metrics{
+		activeUDPFlows: prometheus.NewGauge(prometheus.GaugeOpts{
+			Namespace: namespace,
+			Subsystem: subsystem,
+			Name:      "active_flows",
+			Help:      "Concurrent count of UDP flows that are being proxied to any origin",
+		}),
+		totalUDPFlows: prometheus.NewCounter(prometheus.CounterOpts{
+			Namespace: namespace,
+			Subsystem: subsystem,
+			Name:      "total_flows",
+			Help:      "Total count of UDP flows that have been proxied to any origin",
+		}),
+		payloadTooLarge: prometheus.NewCounter(prometheus.CounterOpts{
+			Namespace: namespace,
+			Subsystem: subsystem,
+			Name:      "payload_too_large",
+			Help:      "Total count of UDP flows that have had origin payloads that are too large to proxy",
+		}),
+		retryFlowResponses: prometheus.NewCounter(prometheus.CounterOpts{
+			Namespace: namespace,
+			Subsystem: subsystem,
+			Name:      "retry_flow_responses",
+			Help:      "Total count of UDP flows that have had to send their registration response more than once",
+		}),
+		migratedFlows: prometheus.NewCounter(prometheus.CounterOpts{
+			Namespace: namespace,
+			Subsystem: subsystem,
+			Name:      "migrated_flows",
+			Help:
"Total count of UDP flows have been migrated across local connections", + }), + } + registerer.MustRegister( + m.activeUDPFlows, + m.totalUDPFlows, + m.payloadTooLarge, + m.retryFlowResponses, + m.migratedFlows, + ) + return m +} diff --git a/quic/v3/metrics_test.go b/quic/v3/metrics_test.go new file mode 100644 index 00000000..5f6b18a7 --- /dev/null +++ b/quic/v3/metrics_test.go @@ -0,0 +1,9 @@ +package v3_test + +type noopMetrics struct{} + +func (noopMetrics) IncrementFlows() {} +func (noopMetrics) DecrementFlows() {} +func (noopMetrics) PayloadTooLarge() {} +func (noopMetrics) RetryFlowResponse() {} +func (noopMetrics) MigrateFlow() {} diff --git a/quic/v3/muxer.go b/quic/v3/muxer.go new file mode 100644 index 00000000..79081762 --- /dev/null +++ b/quic/v3/muxer.go @@ -0,0 +1,301 @@ +package v3 + +import ( + "context" + "errors" + "time" + + "github.com/rs/zerolog" +) + +const ( + // Allocating a 16 channel buffer here allows for the writer to be slightly faster than the reader. + // This has worked previously well for datagramv2, so we will start with this as well + demuxChanCapacity = 16 + + logSrcKey = "src" + logDstKey = "dst" + logDurationKey = "durationMS" +) + +// DatagramConn is the bridge that multiplexes writes and reads of datagrams for UDP sessions and ICMP packets to +// a connection. +type DatagramConn interface { + DatagramWriter + // Serve provides a server interface to process and handle incoming QUIC datagrams and demux their datagram v3 payloads. + Serve(context.Context) error + // ID indicates connection index identifier + ID() uint8 +} + +// DatagramWriter provides the Muxer interface to create proper Datagrams when sending over a connection. +type DatagramWriter interface { + SendUDPSessionDatagram(datagram []byte) error + SendUDPSessionResponse(id RequestID, resp SessionRegistrationResp) error + //SendICMPPacket(packet packet.IP) error +} + +// QuicConnection provides an interface that matches [quic.Connection] for only the datagram operations. +// +// We currently rely on the mutex for the [quic.Connection.SendDatagram] and [quic.Connection.ReceiveDatagram] and +// do not have any locking for them. If the implementation in quic-go were to ever change, we would need to make +// sure that we lock properly on these operations. 
+type QuicConnection interface { + Context() context.Context + SendDatagram(payload []byte) error + ReceiveDatagram(context.Context) ([]byte, error) +} + +type datagramConn struct { + conn QuicConnection + index uint8 + sessionManager SessionManager + metrics Metrics + logger *zerolog.Logger + + datagrams chan []byte + readErrors chan error +} + +func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn { + log := logger.With().Uint8("datagramVersion", 3).Logger() + return &datagramConn{ + conn: conn, + index: index, + sessionManager: sessionManager, + metrics: metrics, + logger: &log, + datagrams: make(chan []byte, demuxChanCapacity), + readErrors: make(chan error, 2), + } +} + +func (c datagramConn) ID() uint8 { + return c.index +} + +func (c *datagramConn) SendUDPSessionDatagram(datagram []byte) error { + return c.conn.SendDatagram(datagram) +} + +func (c *datagramConn) SendUDPSessionResponse(id RequestID, resp SessionRegistrationResp) error { + datagram := UDPSessionRegistrationResponseDatagram{ + RequestID: id, + ResponseType: resp, + } + data, err := datagram.MarshalBinary() + if err != nil { + return err + } + return c.conn.SendDatagram(data) +} + +var errReadTimeout error = errors.New("receive datagram timeout") + +// pollDatagrams will read datagrams from the underlying connection until the provided context is done. +func (c *datagramConn) pollDatagrams(ctx context.Context) { + for ctx.Err() == nil { + datagram, err := c.conn.ReceiveDatagram(ctx) + // If the read returns an error, we want to return the failure to the channel. + if err != nil { + c.readErrors <- err + return + } + c.datagrams <- datagram + } + if ctx.Err() != nil { + c.readErrors <- ctx.Err() + } +} + +// Serve will begin the process of receiving datagrams from the [quic.Connection] and demuxing them to their destination. +// The [DatagramConn] when serving, will be responsible for the sessions it accepts. +func (c *datagramConn) Serve(ctx context.Context) error { + connCtx := c.conn.Context() + // We want to make sure that we cancel the reader context if the Serve method returns. This could also mean that the + // underlying connection is also closing, but that is handled outside of the context of the datagram muxer. + readCtx, cancel := context.WithCancel(connCtx) + defer cancel() + go c.pollDatagrams(readCtx) + for { + // We make sure to monitor the context of cloudflared and the underlying connection to return if any errors occur. + var datagram []byte + select { + // Monitor the context of cloudflared + case <-ctx.Done(): + return ctx.Err() + // Monitor the context of the underlying connection + case <-connCtx.Done(): + return connCtx.Err() + // Monitor for any hard errors from reading the connection + case err := <-c.readErrors: + return err + // Otherwise, wait and dequeue datagrams as they come in + case d := <-c.datagrams: + datagram = d + } + + // Each incoming datagram will be processed in a new go routine to handle the demuxing and action associated. 
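+		// Design note: handling each datagram in its own goroutine gives up strict
+		// ordering between datagrams in exchange for throughput; UDP itself makes
+		// no ordering guarantee, so this is presumably an acceptable trade-off.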
+ go func() { + typ, err := ParseDatagramType(datagram) + if err != nil { + c.logger.Err(err).Msgf("unable to parse datagram type: %d", typ) + return + } + switch typ { + case UDPSessionRegistrationType: + reg := &UDPSessionRegistrationDatagram{} + err := reg.UnmarshalBinary(datagram) + if err != nil { + c.logger.Err(err).Msgf("unable to unmarshal session registration datagram") + return + } + logger := c.logger.With().Str(logFlowID, reg.RequestID.String()).Logger() + // We bind the new session to the quic connection context instead of cloudflared context to allow for the + // quic connection to close and close only the sessions bound to it. Closing of cloudflared will also + // initiate the close of the quic connection, so we don't have to worry about the application context + // in the scope of a session. + c.handleSessionRegistrationDatagram(connCtx, reg, &logger) + case UDPSessionPayloadType: + payload := &UDPSessionPayloadDatagram{} + err := payload.UnmarshalBinary(datagram) + if err != nil { + c.logger.Err(err).Msgf("unable to unmarshal session payload datagram") + return + } + logger := c.logger.With().Str(logFlowID, payload.RequestID.String()).Logger() + c.handleSessionPayloadDatagram(payload, &logger) + case UDPSessionRegistrationResponseType: + // cloudflared should never expect to receive UDP session responses as it will not initiate new + // sessions towards the edge. + c.logger.Error().Msgf("unexpected datagram type received: %d", UDPSessionRegistrationResponseType) + return + default: + c.logger.Error().Msgf("unknown datagram type received: %d", typ) + } + }() + } +} + +// This method handles new registrations of a session and the serve loop for the session. +func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, datagram *UDPSessionRegistrationDatagram, logger *zerolog.Logger) { + log := logger.With(). + Str(logFlowID, datagram.RequestID.String()). + Str(logDstKey, datagram.Dest.String()). + Logger() + session, err := c.sessionManager.RegisterSession(datagram, c) + switch err { + case nil: + // Continue as normal + case ErrSessionAlreadyRegistered: + // Session is already registered and likely the response got lost + c.handleSessionAlreadyRegistered(datagram.RequestID, &log) + return + case ErrSessionBoundToOtherConn: + // Session is already registered but to a different connection + c.handleSessionMigration(datagram.RequestID, &log) + return + default: + log.Err(err).Msgf("flow registration failure") + c.handleSessionRegistrationFailure(datagram.RequestID, &log) + return + } + log = log.With().Str(logSrcKey, session.LocalAddr().String()).Logger() + c.metrics.IncrementFlows() + // Make sure to eventually remove the session from the session manager when the session is closed + defer c.sessionManager.UnregisterSession(session.ID()) + defer c.metrics.DecrementFlows() + + // Respond that we are able to process the new session + err = c.SendUDPSessionResponse(datagram.RequestID, ResponseOk) + if err != nil { + log.Err(err).Msgf("flow registration failure: unable to send session registration response") + return + } + + // We bind the context of the session to the [quic.Connection] that initiated the session. + // [Session.Serve] is blocking and will continue this go routine till the end of the session lifetime. 
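+	// Sketch of the happy-path flow lifetime at this point: RegisterSession ->
+	// SendUDPSessionResponse(ResponseOk) -> Serve(ctx) blocks for the session
+	// lifetime -> the deferred UnregisterSession and DecrementFlows run on return.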
+ start := time.Now() + err = session.Serve(ctx) + elapsedMS := time.Now().Sub(start).Milliseconds() + log = log.With().Int64(logDurationKey, elapsedMS).Logger() + if err == nil { + // We typically don't expect a session to close without some error response. [SessionIdleErr] is the typical + // expected error response. + log.Warn().Msg("flow closed: no explicit close or timeout elapsed") + return + } + // SessionIdleErr and SessionCloseErr are valid and successful error responses to end a session. + if errors.Is(err, SessionIdleErr{}) || errors.Is(err, SessionCloseErr) { + log.Debug().Msgf("flow closed: %s", err.Error()) + return + } + + // All other errors should be reported as errors + log.Err(err).Msgf("flow closed with an error") +} + +func (c *datagramConn) handleSessionAlreadyRegistered(requestID RequestID, logger *zerolog.Logger) { + // Send another registration response since the session is already active + err := c.SendUDPSessionResponse(requestID, ResponseOk) + if err != nil { + logger.Err(err).Msgf("flow registration failure: unable to send an additional flow registration response") + return + } + + session, err := c.sessionManager.GetSession(requestID) + if err != nil { + // If for some reason we can not find the session after attempting to register it, we can just return + // instead of trying to reset the idle timer for it. + return + } + // The session is already running in another routine so we want to restart the idle timeout since no proxied + // packets have come down yet. + session.ResetIdleTimer() + c.metrics.RetryFlowResponse() + logger.Debug().Msgf("flow registration response retry") +} + +func (c *datagramConn) handleSessionMigration(requestID RequestID, logger *zerolog.Logger) { + // We need to migrate the currently running session to this edge connection. + session, err := c.sessionManager.GetSession(requestID) + if err != nil { + // If for some reason we can not find the session after attempting to register it, we can just return + // instead of trying to reset the idle timer for it. + return + } + + // Migrate the session to use this edge connection instead of the currently running one. + // We also pass in this connection's logger to override the existing logger for the session. + session.Migrate(c, c.logger) + + // Send another registration response since the session is already active + err = c.SendUDPSessionResponse(requestID, ResponseOk) + if err != nil { + logger.Err(err).Msgf("flow registration failure: unable to send an additional flow registration response") + return + } + logger.Debug().Msgf("flow registration migration") +} + +func (c *datagramConn) handleSessionRegistrationFailure(requestID RequestID, logger *zerolog.Logger) { + err := c.SendUDPSessionResponse(requestID, ResponseUnableToBindSocket) + if err != nil { + logger.Err(err).Msgf("unable to send flow registration error response (%d)", ResponseUnableToBindSocket) + } +} + +// Handles incoming datagrams that need to be sent to a registered session. +func (c *datagramConn) handleSessionPayloadDatagram(datagram *UDPSessionPayloadDatagram, logger *zerolog.Logger) { + s, err := c.sessionManager.GetSession(datagram.RequestID) + if err != nil { + logger.Err(err).Msgf("unable to find flow") + return + } + // We ignore the bytes written to the socket because any partial write must return an error. 
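+	// See the [io.Writer] contract: Write must return a non-nil error if it
+	// returns n < len(p), so a nil error below implies the full payload reached
+	// the origin socket.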
+ _, err = s.Write(datagram.Payload) + if err != nil { + logger.Err(err).Msgf("unable to write payload for the flow") + return + } +} diff --git a/quic/v3/muxer_test.go b/quic/v3/muxer_test.go new file mode 100644 index 00000000..1b2149f9 --- /dev/null +++ b/quic/v3/muxer_test.go @@ -0,0 +1,643 @@ +package v3_test + +import ( + "bytes" + "context" + "errors" + "net" + "net/netip" + "slices" + "sync" + "testing" + "time" + + "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/ingress" + v3 "github.com/cloudflare/cloudflared/quic/v3" +) + +type noopEyeball struct { + connID uint8 +} + +func (noopEyeball) Serve(ctx context.Context) error { return nil } +func (n noopEyeball) ID() uint8 { return n.connID } +func (noopEyeball) SendUDPSessionDatagram(datagram []byte) error { return nil } +func (noopEyeball) SendUDPSessionResponse(id v3.RequestID, resp v3.SessionRegistrationResp) error { + return nil +} + +type mockEyeball struct { + connID uint8 + // datagram sent via SendUDPSessionDatagram + recvData chan []byte + // responses sent via SendUDPSessionResponse + recvResp chan struct { + id v3.RequestID + resp v3.SessionRegistrationResp + } +} + +func newMockEyeball() mockEyeball { + return mockEyeball{ + connID: 0, + recvData: make(chan []byte, 1), + recvResp: make(chan struct { + id v3.RequestID + resp v3.SessionRegistrationResp + }, 1), + } +} + +func (mockEyeball) Serve(ctx context.Context) error { return nil } +func (m *mockEyeball) ID() uint8 { return m.connID } + +func (m *mockEyeball) SendUDPSessionDatagram(datagram []byte) error { + b := make([]byte, len(datagram)) + copy(b, datagram) + m.recvData <- b + return nil +} + +func (m *mockEyeball) SendUDPSessionResponse(id v3.RequestID, resp v3.SessionRegistrationResp) error { + m.recvResp <- struct { + id v3.RequestID + resp v3.SessionRegistrationResp + }{ + id, resp, + } + return nil +} + +func TestDatagramConn_New(t *testing.T) { + log := zerolog.Nop() + conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), 0, &noopMetrics{}, &log) + if conn == nil { + t.Fatal("expected valid connection") + } +} + +func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), 0, &noopMetrics{}, &log) + + payload := []byte{0xef, 0xef} + conn.SendUDPSessionDatagram(payload) + p := <-quic.recv + if !slices.Equal(p, payload) { + t.Fatal("datagram sent does not match datagram received on quic side") + } +} + +func TestDatagramConn_SendUDPSessionResponse(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), 0, &noopMetrics{}, &log) + + conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable) + resp := <-quic.recv + var response v3.UDPSessionRegistrationResponseDatagram + err := response.UnmarshalBinary(resp) + if err != nil { + t.Fatal(err) + } + expected := v3.UDPSessionRegistrationResponseDatagram{ + RequestID: testRequestID, + ResponseType: v3.ResponseDestinationUnreachable, + } + if response != expected { + t.Fatal("datagram response sent does not match expected datagram response received") + } +} + +func TestDatagramConnServe_ApplicationClosed(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, 
ingress.DialUDPAddrPort), 0, &noopMetrics{}, &log) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + err := conn.Serve(ctx) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal(err) + } +} + +func TestDatagramConnServe_ConnectionClosed(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + quic.ctx = ctx + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), 0, &noopMetrics{}, &log) + + err := conn.Serve(context.Background()) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal(err) + } +} + +func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) { + log := zerolog.Nop() + quic := &mockQuicConnReadError{err: net.ErrClosed} + conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), 0, &noopMetrics{}, &log) + + err := conn.Serve(context.Background()) + if !errors.Is(err, net.ErrClosed) { + t.Fatal(err) + } +} + +func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) { + for _, test := range []struct { + name string + input []byte + expected string + }{ + { + "empty", + []byte{}, + "{\"level\":\"error\",\"datagramVersion\":3,\"error\":\"datagram should have at least 1 byte\",\"message\":\"unable to parse datagram type: 0\"}\n", + }, + { + "unexpected", + []byte{byte(v3.UDPSessionRegistrationResponseType)}, + "{\"level\":\"error\",\"datagramVersion\":3,\"message\":\"unexpected datagram type received: 3\"}\n", + }, + { + "unknown", + []byte{99}, + "{\"level\":\"error\",\"datagramVersion\":3,\"message\":\"unknown datagram type received: 99\"}\n", + }, + } { + t.Run(test.name, func(t *testing.T) { + logOutput := new(LockedBuffer) + log := zerolog.New(logOutput) + quic := newMockQuicConn() + quic.send <- test.input + conn := v3.NewDatagramConn(quic, &mockSessionManager{}, 0, &noopMetrics{}, &log) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + err := conn.Serve(ctx) + // we cancel the Serve method to check to see if the log output was written since the unsupported datagram + // is dropped with only a log message as a side-effect. 
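+			// context.DeadlineExceeded is the expected clean exit here; any other
+			// error means the muxer failed for an unrelated reason.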
+ if !errors.Is(err, context.DeadlineExceeded) { + t.Fatal(err) + } + + out := logOutput.String() + if out != test.expected { + t.Fatalf("incorrect log output expected: %s", out) + } + }) + } +} + +type LockedBuffer struct { + bytes.Buffer + l sync.Mutex +} + +func (b *LockedBuffer) Write(p []byte) (n int, err error) { + b.l.Lock() + defer b.l.Unlock() + return b.Buffer.Write(p) +} + +func (b *LockedBuffer) String() string { + b.l.Lock() + defer b.l.Unlock() + return b.Buffer.String() +} + +func TestDatagramConnServe_RegisterSession_SessionManagerError(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + expectedErr := errors.New("unable to register session") + sessionManager := mockSessionManager{expectedRegErr: expectedErr} + conn := v3.NewDatagramConn(quic, &sessionManager, 0, &noopMetrics{}, &log) + + // Setup the muxer + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("other error")) + done := make(chan error, 1) + go func() { + done <- conn.Serve(ctx) + }() + + // Send new session registration + datagram := newRegisterSessionDatagram(testRequestID) + quic.send <- datagram + + // Wait for session registration response with failure + datagram = <-quic.recv + var resp v3.UDPSessionRegistrationResponseDatagram + err := resp.UnmarshalBinary(datagram) + if err != nil { + t.Fatal(err) + } + + if resp.RequestID != testRequestID || resp.ResponseType != v3.ResponseUnableToBindSocket { + t.Fatalf("expected registration response failure") + } + + // Cancel the muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx, done, cancel) +} + +func TestDatagramConnServe(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + session := newMockSession() + sessionManager := mockSessionManager{session: &session} + conn := v3.NewDatagramConn(quic, &sessionManager, 0, &noopMetrics{}, &log) + + // Setup the muxer + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("other error")) + done := make(chan error, 1) + go func() { + done <- conn.Serve(ctx) + }() + + // Send new session registration + datagram := newRegisterSessionDatagram(testRequestID) + quic.send <- datagram + + // Wait for session registration response with success + datagram = <-quic.recv + var resp v3.UDPSessionRegistrationResponseDatagram + err := resp.UnmarshalBinary(datagram) + if err != nil { + t.Fatal(err) + } + + if resp.RequestID != testRequestID || resp.ResponseType != v3.ResponseOk { + t.Fatalf("expected registration response ok") + } + + // We expect the session to be served + timer := time.NewTimer(15 * time.Second) + defer timer.Stop() + select { + case <-session.served: + break + case <-timer.C: + t.Fatalf("expected session serve to be called") + } + + // Cancel the muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx, done, cancel) +} + +func TestDatagramConnServe_RegisterTwice(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + session := newMockSession() + sessionManager := mockSessionManager{session: &session} + conn := v3.NewDatagramConn(quic, &sessionManager, 0, &noopMetrics{}, &log) + + // Setup the muxer + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("other error")) + done := make(chan error, 1) + go func() { + done <- conn.Serve(ctx) + }() + + // Send new session registration + datagram := newRegisterSessionDatagram(testRequestID) + quic.send <- datagram + + // Wait for session 
registration response with success + datagram = <-quic.recv + var resp v3.UDPSessionRegistrationResponseDatagram + err := resp.UnmarshalBinary(datagram) + if err != nil { + t.Fatal(err) + } + + if resp.RequestID != testRequestID || resp.ResponseType != v3.ResponseOk { + t.Fatalf("expected registration response ok") + } + + // Set the session manager to return already registered + sessionManager.expectedRegErr = v3.ErrSessionAlreadyRegistered + // Send the registration again as if we didn't receive it at the edge + datagram = newRegisterSessionDatagram(testRequestID) + quic.send <- datagram + + // Wait for session registration response with success + datagram = <-quic.recv + err = resp.UnmarshalBinary(datagram) + if err != nil { + t.Fatal(err) + } + + if resp.RequestID != testRequestID || resp.ResponseType != v3.ResponseOk { + t.Fatalf("expected registration response ok") + } + + // We expect the session to be served + timer := time.NewTimer(15 * time.Second) + defer timer.Stop() + select { + case <-session.served: + break + case <-timer.C: + t.Fatalf("expected session serve to be called") + } + + // Cancel the muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx, done, cancel) +} + +func TestDatagramConnServe_MigrateConnection(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + session := newMockSession() + sessionManager := mockSessionManager{session: &session} + conn := v3.NewDatagramConn(quic, &sessionManager, 0, &noopMetrics{}, &log) + quic2 := newMockQuicConn() + conn2 := v3.NewDatagramConn(quic2, &sessionManager, 1, &noopMetrics{}, &log) + + // Setup the muxer + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("other error")) + done := make(chan error, 1) + go func() { + done <- conn.Serve(ctx) + }() + + ctx2, cancel2 := context.WithCancelCause(context.Background()) + defer cancel2(errors.New("other error")) + done2 := make(chan error, 1) + go func() { + done2 <- conn2.Serve(ctx2) + }() + + // Send new session registration + datagram := newRegisterSessionDatagram(testRequestID) + quic.send <- datagram + + // Wait for session registration response with success + datagram = <-quic.recv + var resp v3.UDPSessionRegistrationResponseDatagram + err := resp.UnmarshalBinary(datagram) + if err != nil { + t.Fatal(err) + } + + if resp.RequestID != testRequestID || resp.ResponseType != v3.ResponseOk { + t.Fatalf("expected registration response ok") + } + + // Set the session manager to return already registered to another connection + sessionManager.expectedRegErr = v3.ErrSessionBoundToOtherConn + // Send the registration again as if we didn't receive it at the edge for a new connection + datagram = newRegisterSessionDatagram(testRequestID) + quic2.send <- datagram + + // Wait for session registration response with success + datagram = <-quic2.recv + err = resp.UnmarshalBinary(datagram) + if err != nil { + t.Fatal(err) + } + + if resp.RequestID != testRequestID || resp.ResponseType != v3.ResponseOk { + t.Fatalf("expected registration response ok") + } + + // We expect the session to be served + timer := time.NewTimer(15 * time.Second) + defer timer.Stop() + select { + case <-session.served: + break + case <-timer.C: + t.Fatalf("expected session serve to be called") + } + + // Expect session to be migrated + select { + case id := <-session.migrated: + if id != conn2.ID() { + t.Fatalf("expected session to be migrated to connection 2") + } + case <-timer.C: + t.Fatalf("expected session migration to be 
called") + } + + // Cancel the muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx, done, cancel) + // Cancel the second muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx2, done2, cancel2) +} + +func TestDatagramConnServe_Payload_GetSessionError(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + // mockSessionManager will return the ErrSessionNotFound for any session attempting to be queried by the muxer + sessionManager := mockSessionManager{session: nil, expectedGetErr: v3.ErrSessionNotFound} + conn := v3.NewDatagramConn(quic, &sessionManager, 0, &noopMetrics{}, &log) + + // Setup the muxer + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("other error")) + done := make(chan error, 1) + go func() { + done <- conn.Serve(ctx) + }() + + // Send new session registration + datagram := newSessionPayloadDatagram(testRequestID, []byte{0xef, 0xef}) + quic.send <- datagram + + // Since the muxer should eventually discard a failed registration request, there is no side-effect + // that the registration was failed beyond the muxer accepting the registration request. As such, the + // test can only ensure that the quic.send channel was consumed and that the muxer closes normally + // afterwards with the expected context cancelled trigger. + + // Cancel the muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx, done, cancel) +} + +func TestDatagramConnServe_Payload(t *testing.T) { + log := zerolog.Nop() + quic := newMockQuicConn() + session := newMockSession() + sessionManager := mockSessionManager{session: &session} + conn := v3.NewDatagramConn(quic, &sessionManager, 0, &noopMetrics{}, &log) + + // Setup the muxer + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(errors.New("other error")) + done := make(chan error, 1) + go func() { + done <- conn.Serve(ctx) + }() + + // Send new session registration + expectedPayload := []byte{0xef, 0xef} + datagram := newSessionPayloadDatagram(testRequestID, expectedPayload) + quic.send <- datagram + + // Session should receive the payload + payload := <-session.recv + if !slices.Equal(expectedPayload, payload) { + t.Fatalf("expected session receieve the payload sent via the muxer") + } + + // Cancel the muxer Serve context and make sure it closes with the expected error + assertContextClosed(t, ctx, done, cancel) +} + +func newRegisterSessionDatagram(id v3.RequestID) []byte { + datagram := v3.UDPSessionRegistrationDatagram{ + RequestID: id, + Dest: netip.MustParseAddrPort("127.0.0.1:8080"), + IdleDurationHint: 5 * time.Second, + } + payload, err := datagram.MarshalBinary() + if err != nil { + panic(err) + } + return payload +} + +func newRegisterResponseSessionDatagram(id v3.RequestID, resp v3.SessionRegistrationResp) []byte { + datagram := v3.UDPSessionRegistrationResponseDatagram{ + RequestID: id, + ResponseType: resp, + } + payload, err := datagram.MarshalBinary() + if err != nil { + panic(err) + } + return payload +} + +func newSessionPayloadDatagram(id v3.RequestID, payload []byte) []byte { + datagram := make([]byte, len(payload)+17) + err := v3.MarshalPayloadHeaderTo(id, datagram[:]) + if err != nil { + panic(err) + } + copy(datagram[17:], payload) + return datagram +} + +// Cancel the provided context and make sure it closes with the expected cancellation error +func assertContextClosed(t *testing.T, ctx context.Context, done <-chan error, 
cancel context.CancelCauseFunc) { + cancel(expectedContextCanceled) + err := <-done + if !errors.Is(err, context.Canceled) { + t.Fatal(err) + } + if !errors.Is(context.Cause(ctx), expectedContextCanceled) { + t.Fatal(err) + } +} + +type mockQuicConn struct { + ctx context.Context + send chan []byte + recv chan []byte +} + +func newMockQuicConn() *mockQuicConn { + return &mockQuicConn{ + ctx: context.Background(), + send: make(chan []byte, 1), + recv: make(chan []byte, 1), + } +} + +func (m *mockQuicConn) Context() context.Context { + return m.ctx +} + +func (m *mockQuicConn) SendDatagram(payload []byte) error { + b := make([]byte, len(payload)) + copy(b, payload) + m.recv <- b + return nil +} + +func (m *mockQuicConn) ReceiveDatagram(_ context.Context) ([]byte, error) { + return <-m.send, nil +} + +type mockQuicConnReadError struct { + err error +} + +func (m *mockQuicConnReadError) Context() context.Context { + return context.Background() +} + +func (m *mockQuicConnReadError) SendDatagram(payload []byte) error { + return nil +} + +func (m *mockQuicConnReadError) ReceiveDatagram(_ context.Context) ([]byte, error) { + return nil, m.err +} + +type mockSessionManager struct { + session v3.Session + + expectedRegErr error + expectedGetErr error +} + +func (m *mockSessionManager) RegisterSession(request *v3.UDPSessionRegistrationDatagram, conn v3.DatagramConn) (v3.Session, error) { + return m.session, m.expectedRegErr +} + +func (m *mockSessionManager) GetSession(requestID v3.RequestID) (v3.Session, error) { + return m.session, m.expectedGetErr +} + +func (m *mockSessionManager) UnregisterSession(requestID v3.RequestID) {} + +type mockSession struct { + served chan struct{} + migrated chan uint8 + recv chan []byte +} + +func newMockSession() mockSession { + return mockSession{ + served: make(chan struct{}), + migrated: make(chan uint8, 2), + recv: make(chan []byte, 1), + } +} + +func (m *mockSession) ID() v3.RequestID { return testRequestID } +func (m *mockSession) RemoteAddr() net.Addr { return testOriginAddr } +func (m *mockSession) LocalAddr() net.Addr { return testLocalAddr } +func (m *mockSession) ConnectionID() uint8 { return 0 } +func (m *mockSession) Migrate(conn v3.DatagramConn, log *zerolog.Logger) { m.migrated <- conn.ID() } +func (m *mockSession) ResetIdleTimer() {} + +func (m *mockSession) Serve(ctx context.Context) error { + close(m.served) + return v3.SessionCloseErr +} + +func (m *mockSession) Write(payload []byte) (n int, err error) { + b := make([]byte, len(payload)) + copy(b, payload) + m.recv <- b + return len(b), nil +} + +func (m *mockSession) Close() error { + return nil +} diff --git a/quic/v3/request.go b/quic/v3/request.go index 29509e83..b716605d 100644 --- a/quic/v3/request.go +++ b/quic/v3/request.go @@ -3,6 +3,7 @@ package v3 import ( "encoding/binary" "errors" + "fmt" ) const ( @@ -37,6 +38,10 @@ func RequestIDFromSlice(data []byte) (RequestID, error) { }, nil } +func (id RequestID) String() string { + return fmt.Sprintf("%016x%016x", id.hi, id.lo) +} + // Compare returns an integer comparing two IPs. // The result will be 0 if id == id2, -1 if id < id2, and +1 if id > id2. // The definition of "less than" is the same as the [RequestID.Less] method. 
@@ -70,3 +75,15 @@ func (id RequestID) MarshalBinaryTo(data []byte) error { binary.BigEndian.PutUint64(data[8:], id.lo) return nil } + +func (id *RequestID) UnmarshalBinary(data []byte) error { + if len(data) != 16 { + return fmt.Errorf("invalid length slice provided to unmarshal: %d (expected 16)", len(data)) + } + + *id = RequestID{ + binary.BigEndian.Uint64(data[:8]), + binary.BigEndian.Uint64(data[8:]), + } + return nil +} diff --git a/quic/v3/request_test.go b/quic/v3/request_test.go index 519c2dd2..3c7915c7 100644 --- a/quic/v3/request_test.go +++ b/quic/v3/request_test.go @@ -5,6 +5,8 @@ import ( "slices" "testing" + "github.com/stretchr/testify/require" + v3 "github.com/cloudflare/cloudflared/quic/v3" ) @@ -48,3 +50,15 @@ func TestRequestIDParsing(t *testing.T) { t.Fatalf("buf1 != buf2: %+v %+v", buf1, buf2) } } + +func TestRequestID_MarshalBinary(t *testing.T) { + buf := make([]byte, 16) + err := testRequestID.MarshalBinaryTo(buf) + require.NoError(t, err) + require.Len(t, buf, 16) + + parsed := v3.RequestID{} + err = parsed.UnmarshalBinary(buf) + require.NoError(t, err) + require.Equal(t, testRequestID, parsed) +} diff --git a/quic/v3/session.go b/quic/v3/session.go new file mode 100644 index 00000000..7ebe02a7 --- /dev/null +++ b/quic/v3/session.go @@ -0,0 +1,258 @@ +package v3 + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/rs/zerolog" +) + +const ( + // A default is provided in the case that the client does not provide a close idle timeout. + defaultCloseIdleAfter = 210 * time.Second + + // The maximum payload from the origin that we will be able to read. However, even though we will + // read 1500 bytes from the origin, we limit the amount of bytes to be proxied to less than + // this value (maxDatagramPayloadLen). + maxOriginUDPPacketSize = 1500 + + logFlowID = "flowID" + logPacketSizeKey = "packetSize" +) + +// SessionCloseErr indicates that the session's Close method was called. +var SessionCloseErr error = errors.New("flow was closed directly") + +// SessionIdleErr is returned when the session was closed because there was no communication +// in either direction over the session for the timeout period. 
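+//
+// Callers can match this error with errors.Is(err, SessionIdleErr{}) thanks to
+// the Is method below; the muxer uses exactly that check to treat an idle close
+// as a successful termination.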
+type SessionIdleErr struct {
+	timeout time.Duration
+}
+
+func (e SessionIdleErr) Error() string {
+	return fmt.Sprintf("flow was idle for %v", e.timeout)
+}
+
+func (e SessionIdleErr) Is(target error) bool {
+	_, ok := target.(SessionIdleErr)
+	return ok
+}
+
+func newSessionIdleErr(timeout time.Duration) error {
+	return SessionIdleErr{timeout}
+}
+
+type Session interface {
+	io.WriteCloser
+	ID() RequestID
+	ConnectionID() uint8
+	RemoteAddr() net.Addr
+	LocalAddr() net.Addr
+	ResetIdleTimer()
+	Migrate(eyeball DatagramConn, logger *zerolog.Logger)
+	// Serve starts the event loop for processing UDP packets
+	Serve(ctx context.Context) error
+}
+
+type session struct {
+	id             RequestID
+	closeAfterIdle time.Duration
+	origin         io.ReadWriteCloser
+	originAddr     net.Addr
+	localAddr      net.Addr
+	eyeball        atomic.Pointer[DatagramConn]
+	// activeAtChan is used to communicate the last read/write time
+	activeAtChan chan time.Time
+	closeChan    chan error
+	// closeOnce guarantees that the origin connection is closed only once
+	closeOnce sync.Once
+	metrics   Metrics
+	log       *zerolog.Logger
+}
+
+func NewSession(
+	id RequestID,
+	closeAfterIdle time.Duration,
+	origin io.ReadWriteCloser,
+	originAddr net.Addr,
+	localAddr net.Addr,
+	eyeball DatagramConn,
+	metrics Metrics,
+	log *zerolog.Logger,
+) Session {
+	logger := log.With().Str(logFlowID, id.String()).Logger()
+	session := &session{
+		id:             id,
+		closeAfterIdle: closeAfterIdle,
+		origin:         origin,
+		originAddr:     originAddr,
+		localAddr:      localAddr,
+		eyeball:        atomic.Pointer[DatagramConn]{},
+		// activeAtChan has low capacity. It can be full when there are many concurrent reads/writes. markActive() will
+		// drop instead of blocking because the last active time only needs to be an approximation
+		activeAtChan: make(chan time.Time, 1),
+		closeChan:    make(chan error, 1),
+		metrics:      metrics,
+		log:          &logger,
+	}
+	session.eyeball.Store(&eyeball)
+	return session
+}
+
+func (s *session) ID() RequestID {
+	return s.id
+}
+
+func (s *session) RemoteAddr() net.Addr {
+	return s.originAddr
+}
+
+func (s *session) LocalAddr() net.Addr {
+	return s.localAddr
+}
+
+func (s *session) ConnectionID() uint8 {
+	eyeball := *(s.eyeball.Load())
+	return eyeball.ID()
+}
+
+func (s *session) Migrate(eyeball DatagramConn, logger *zerolog.Logger) {
+	current := *(s.eyeball.Load())
+	// Only migrate if the connection ids are different.
+	if current.ID() != eyeball.ID() {
+		s.eyeball.Store(&eyeball)
+		log := logger.With().Str(logFlowID, s.id.String()).Logger()
+		s.log = &log
+	}
+	// The session is already running so we want to restart the idle timeout since no proxied packets have come down yet.
+	s.markActive()
+	s.metrics.MigrateFlow()
+}
+
+func (s *session) Serve(ctx context.Context) error {
+	go func() {
+		// QUIC implementation copies data to another buffer before returning https://github.com/quic-go/quic-go/blob/v0.24.0/session.go#L1967-L1975
+		// This makes it safe to share readBuffer between iterations
+		readBuffer := [maxOriginUDPPacketSize + DatagramPayloadHeaderLen]byte{}
+		// To perform a zero copy write when passing the datagram to the connection, we prepare the buffer with
+		// the required datagram header information. We can reuse this buffer for this session since the header is the
+		// same for each read.
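+		// Sketch of the shared buffer layout (the header is 17 bytes: 1 type byte
+		// plus a 16 byte request id):
+		//
+		//	readBuffer: [ DatagramPayloadHeaderLen bytes | up to maxOriginUDPPacketSize bytes ]
+		//	              ^ header written once below      ^ refilled by each origin.Read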
+		MarshalPayloadHeaderTo(s.id, readBuffer[:DatagramPayloadHeaderLen])
+		for {
+			// Read from the origin UDP socket
+			n, err := s.origin.Read(readBuffer[DatagramPayloadHeaderLen:])
+			if err != nil {
+				if errors.Is(err, io.EOF) ||
+					errors.Is(err, io.ErrUnexpectedEOF) {
+					s.log.Debug().Msgf("flow (origin) connection closed: %v", err)
+				}
+				s.closeChan <- err
+				return
+			}
+			if n < 0 {
+				s.log.Warn().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was negative and was dropped")
+				continue
+			}
+			if n > maxDatagramPayloadLen {
+				s.metrics.PayloadTooLarge()
+				s.log.Error().Int(logPacketSizeKey, n).Msg("flow (origin) packet read was too large and was dropped")
+				continue
+			}
+			// We need to synchronize on the eyeball in case the connection was migrated. This should rarely be a point
+			// of lock contention, as a migration can only happen during startup of a session before traffic flows.
+			eyeball := *(s.eyeball.Load())
+			// Sending a packet to the session does block on the [quic.Connection]; however, this is okay because it
+			// will apply back-pressure to the kernel buffer if the writes to the edge are not fast enough.
+			err = eyeball.SendUDPSessionDatagram(readBuffer[:DatagramPayloadHeaderLen+n])
+			if err != nil {
+				s.closeChan <- err
+				return
+			}
+			// Mark the session as active since we proxied a valid packet from the origin.
+			s.markActive()
+		}
+	}()
+	return s.waitForCloseCondition(ctx, s.closeAfterIdle)
+}
+
+func (s *session) Write(payload []byte) (n int, err error) {
+	n, err = s.origin.Write(payload)
+	if err != nil {
+		s.log.Err(err).Msg("failed to write payload to flow (remote)")
+		return n, err
+	}
+	// Write must return a non-nil error if it returns n < len(p). https://pkg.go.dev/io#Writer
+	if n < len(payload) {
+		s.log.Err(io.ErrShortWrite).Msg("failed to write the full payload to flow (remote)")
+		return n, io.ErrShortWrite
+	}
+	// Mark the session as active since we proxied a packet to the origin.
+	s.markActive()
+	return n, err
+}
+
+// ResetIdleTimer will restart the current idle timer.
+//
+// This public method allows operators of sessions to extend the session using information that is
+// known externally to the session itself.
+func (s *session) ResetIdleTimer() {
+	s.markActive()
+}
+
+// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there
+// are many concurrent reads/writes. It is fine to lose some precision here.
+func (s *session) markActive() {
+	select {
+	case s.activeAtChan <- time.Now():
+	default:
+	}
+}
+
+func (s *session) Close() error {
+	// Make sure that we only close the origin connection once
+	return sync.OnceValue(func() error {
+		// We don't want to block on sending to the close channel if it is already full
+		select {
+		case s.closeChan <- SessionCloseErr:
+		default:
+		}
+		return s.origin.Close()
+	})()
+}
+
+func (s *session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error {
+	// Closing the session at the end cancels read so Serve() can return
+	defer s.Close()
+	if closeAfterIdle == 0 {
+		// provide a default if the caller doesn't specify one
+		closeAfterIdle = defaultCloseIdleAfter
+	}
+
+	checkIdleTimer := time.NewTimer(closeAfterIdle)
+	defer checkIdleTimer.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case reason := <-s.closeChan:
+			return reason
+		case <-checkIdleTimer.C:
+			// The check idle timer will only return after an idle period since the last active
+			// operation (read or write).
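+			// Returning the idle error unblocks Serve; the deferred Close then tears down the origin connection.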
+			return newSessionIdleErr(closeAfterIdle)
+		case <-s.activeAtChan:
+			// The session is still active, so we want to reset the timer. First we have to stop the timer, drain the
+			// current value if present, and then reset. It's okay if we lose some time on this operation, as we don't
+			// need to close an idle session exactly on time.
+			if !checkIdleTimer.Stop() {
+				<-checkIdleTimer.C
+			}
+			checkIdleTimer.Reset(closeAfterIdle)
+		}
+	}
+}
diff --git a/quic/v3/session_fuzz_test.go b/quic/v3/session_fuzz_test.go
new file mode 100644
index 00000000..0e4952c0
--- /dev/null
+++ b/quic/v3/session_fuzz_test.go
@@ -0,0 +1,23 @@
+package v3_test
+
+import (
+	"testing"
+)
+
+// FuzzSessionWrite verifies that we don't run into any panics when writing variable-sized payloads to the origin.
+func FuzzSessionWrite(f *testing.F) {
+	f.Fuzz(func(t *testing.T, b []byte) {
+		testSessionWrite(t, b)
+	})
+}
+
+// FuzzSessionServe verifies that we don't run into any panics when reading variable-sized payloads from the origin.
+func FuzzSessionServe(f *testing.F) {
+	f.Fuzz(func(t *testing.T, b []byte) {
+		// The origin transport read is bound to 1280 bytes
+		if len(b) > 1280 {
+			b = b[:1280]
+		}
+		testSessionServe_Origin(t, b)
+	})
+}
diff --git a/quic/v3/session_test.go b/quic/v3/session_test.go
new file mode 100644
index 00000000..1e31962a
--- /dev/null
+++ b/quic/v3/session_test.go
@@ -0,0 +1,330 @@
+package v3_test
+
+import (
+	"context"
+	"errors"
+	"net"
+	"net/netip"
+	"slices"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/rs/zerolog"
+
+	v3 "github.com/cloudflare/cloudflared/quic/v3"
+)
+
+var (
+	expectedContextCanceled = errors.New("expected context canceled")
+
+	testOriginAddr = net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:0"))
+	testLocalAddr  = net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:0"))
+)
+
+func TestSessionNew(t *testing.T) {
+	log := zerolog.Nop()
+	session := v3.NewSession(testRequestID, 5*time.Second, nil, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
+	if testRequestID != session.ID() {
+		t.Fatalf("session id doesn't match: %s != %s", testRequestID, session.ID())
+	}
+}
+
+func testSessionWrite(t *testing.T, payload []byte) {
+	log := zerolog.Nop()
+	origin := newTestOrigin(makePayload(1280))
+	session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
+	n, err := session.Write(payload)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if n != len(payload) {
+		t.Fatal("unable to write the whole payload")
+	}
+	if !slices.Equal(payload, origin.write[:len(payload)]) {
+		t.Fatal("payload written to the origin does not match the provided payload")
+	}
+}
+
+func TestSessionWrite_Max(t *testing.T) {
+	payload := makePayload(1280)
+	testSessionWrite(t, payload)
+}
+
+func TestSessionWrite_Min(t *testing.T) {
+	payload := makePayload(0)
+	testSessionWrite(t, payload)
+}
+
+func TestSessionServe_OriginMax(t *testing.T) {
+	payload := makePayload(1280)
+	testSessionServe_Origin(t, payload)
+}
+
+func TestSessionServe_OriginMin(t *testing.T) {
+	payload := makePayload(0)
+	testSessionServe_Origin(t, payload)
+}
+
+func testSessionServe_Origin(t *testing.T, payload []byte) {
+	log := zerolog.Nop()
+	eyeball := newMockEyeball()
+	origin := newTestOrigin(payload)
+	session := v3.NewSession(testRequestID, 3*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
+	defer session.Close()
+
+	ctx, cancel := context.WithCancelCause(context.Background())
+	defer cancel(context.Canceled)
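+	// Serve runs in its own goroutine so the test can receive the proxied datagram on the mock eyeball below.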
+	done := make(chan error)
+	go func() {
+		done <- session.Serve(ctx)
+	}()
+
+	select {
+	case data := <-eyeball.recvData:
+		// check that the received data matches what the origin provided
+		expectedData := makePayload(1500)
+		v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
+		copy(expectedData[17:], payload)
+		if !slices.Equal(expectedData[:17+len(payload)], data) {
+			t.Fatal("received datagram did not match the expected datagram")
+		}
+		cancel(expectedContextCanceled)
+	case err := <-ctx.Done():
+		// we expect the payload to arrive before the session's context is canceled
+		t.Fatal(err)
+	}
+
+	err := <-done
+	if !errors.Is(err, context.Canceled) {
+		t.Fatal(err)
+	}
+	if !errors.Is(context.Cause(ctx), expectedContextCanceled) {
+		t.Fatal(err)
+	}
+}
+
+func TestSessionServe_OriginTooLarge(t *testing.T) {
+	log := zerolog.Nop()
+	eyeball := newMockEyeball()
+	payload := makePayload(1281)
+	origin := newTestOrigin(payload)
+	session := v3.NewSession(testRequestID, 2*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
+	defer session.Close()
+
+	done := make(chan error)
+	go func() {
+		done <- session.Serve(context.Background())
+	}()
+
+	select {
+	case data := <-eyeball.recvData:
+		// we never expect a read to make it here because the origin provided a payload that is too large
+		// for cloudflared to proxy, so it is dropped.
+		t.Fatalf("we should never proxy a payload of this size: %d", len(data))
+	case err := <-done:
+		if !errors.Is(err, v3.SessionIdleErr{}) {
+			t.Error(err)
+		}
+	}
+}
+
+func TestSessionServe_Migrate(t *testing.T) {
+	log := zerolog.Nop()
+	eyeball := newMockEyeball()
+	pipe1, pipe2 := net.Pipe()
+	session := v3.NewSession(testRequestID, 2*time.Second, pipe2, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
+	defer session.Close()
+
+	done := make(chan error)
+	go func() {
+		done <- session.Serve(context.Background())
+	}()
+
+	// Migrate the session to a new connection before the origin sends data
+	eyeball2 := newMockEyeball()
+	eyeball2.connID = 1
+	session.Migrate(&eyeball2, &log)
+
+	// Origin sends data
+	payload2 := []byte{0xde}
+	pipe1.Write(payload2)
+
+	// Expect write to eyeball2
+	data := <-eyeball2.recvData
+	if len(data) <= 17 || !slices.Equal(payload2, data[17:]) {
+		t.Fatalf("expected data to write to eyeball2 after migration: %+v", data)
+	}
+
+	select {
+	case data := <-eyeball.recvData:
+		t.Fatalf("expected no data to write to eyeball1 after migration: %+v", data)
+	default:
+	}
+
+	err := <-done
+	if !errors.Is(err, v3.SessionIdleErr{}) {
+		t.Error(err)
+	}
+}
+
+func TestSessionClose_Multiple(t *testing.T) {
+	log := zerolog.Nop()
+	origin := newTestOrigin(makePayload(128))
+	session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
+	err := session.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !origin.closed.Load() {
+		t.Fatal("origin wasn't closed")
+	}
+	// subsequent closes shouldn't call close again or cause any errors
+	err = session.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestSessionServe_IdleTimeout(t *testing.T) {
+	log := zerolog.Nop()
+	origin := newTestIdleOrigin(10 * time.Second) // Make idle time longer than closeAfterIdle
+	closeAfterIdle := 2 * time.Second
+	session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
+	err := session.Serve(context.Background())
+	if !errors.Is(err, v3.SessionIdleErr{}) {
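+		// any error other than the idle timeout (or a nil error) means Serve did not time out as expected
+		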
t.Fatal(err)
+	}
+	// session should be closed
+	if !origin.closed {
+		t.Fatalf("session should be closed after Serve returns")
+	}
+	// closing a session again should not return an error
+	err = session.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestSessionServe_ParentContextCanceled(t *testing.T) {
+	log := zerolog.Nop()
+	// Make the origin idle time and closeAfterIdle longer than the parent context timeout
+	origin := newTestIdleOrigin(10 * time.Second)
+	closeAfterIdle := 10 * time.Second
+
+	session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+	err := session.Serve(ctx)
+	if !errors.Is(err, context.DeadlineExceeded) {
+		t.Fatal(err)
+	}
+	// session should be closed
+	if !origin.closed {
+		t.Fatalf("session should be closed after Serve returns")
+	}
+	// closing a session again should not return an error
+	err = session.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestSessionServe_ReadErrors(t *testing.T) {
+	log := zerolog.Nop()
+	origin := newTestErrOrigin(net.ErrClosed, nil)
+	session := v3.NewSession(testRequestID, 30*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
+	err := session.Serve(context.Background())
+	if !errors.Is(err, net.ErrClosed) {
+		t.Fatal(err)
+	}
+}
+
+type testOrigin struct {
+	// bytes from Write
+	write []byte
+	// bytes provided to Read
+	read     []byte
+	readOnce atomic.Bool
+	closed   atomic.Bool
+}
+
+func newTestOrigin(payload []byte) testOrigin {
+	return testOrigin{
+		read: payload,
+	}
+}
+
+func (o *testOrigin) Read(p []byte) (n int, err error) {
+	if o.closed.Load() {
+		return -1, net.ErrClosed
+	}
+	if o.readOnce.Load() {
+		// We only want to provide one read, so all subsequent reads will block
+		time.Sleep(10 * time.Second)
+	}
+	o.readOnce.Store(true)
+	return copy(p, o.read), nil
+}
+
+func (o *testOrigin) Write(p []byte) (n int, err error) {
+	if o.closed.Load() {
+		return -1, net.ErrClosed
+	}
+	o.write = make([]byte, len(p))
+	copy(o.write, p)
+	return len(p), nil
+}
+
+func (o *testOrigin) Close() error {
+	o.closed.Store(true)
+	return nil
+}
+
+type testIdleOrigin struct {
+	duration time.Duration
+	closed   bool
+}
+
+func newTestIdleOrigin(d time.Duration) testIdleOrigin {
+	return testIdleOrigin{
+		duration: d,
+	}
+}
+
+func (o *testIdleOrigin) Read(p []byte) (n int, err error) {
+	time.Sleep(o.duration)
+	return -1, nil
+}
+
+func (o *testIdleOrigin) Write(p []byte) (n int, err error) {
+	return 0, nil
+}
+
+func (o *testIdleOrigin) Close() error {
+	o.closed = true
+	return nil
+}
+
+type testErrOrigin struct {
+	readErr  error
+	writeErr error
+}
+
+func newTestErrOrigin(readErr error, writeErr error) testErrOrigin {
+	return testErrOrigin{readErr, writeErr}
+}
+
+func (o *testErrOrigin) Read(p []byte) (n int, err error) {
+	return 0, o.readErr
+}
+
+func (o *testErrOrigin) Write(p []byte) (n int, err error) {
+	return len(p), o.writeErr
+}
+
+func (o *testErrOrigin) Close() error {
+	return nil
+}
diff --git a/ssh_server_tests/Dockerfile b/ssh_server_tests/Dockerfile
deleted file mode 100644
index 0f4a1629..00000000
--- a/ssh_server_tests/Dockerfile
+++ /dev/null
@@ -1,14 +0,0 @@
-FROM python:3-buster
-
-RUN wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb \
-    && dpkg -i cloudflared-linux-amd64.deb
-
-RUN pip install pexpect
-
-COPY tests.py .
-COPY ssh /root/.ssh -RUN chmod 600 /root/.ssh/id_rsa - -ARG SSH_HOSTNAME -RUN bash -c 'sed -i "s/{{hostname}}/${SSH_HOSTNAME}/g" /root/.ssh/authorized_keys_config' -RUN bash -c 'sed -i "s/{{hostname}}/${SSH_HOSTNAME}/g" /root/.ssh/short_lived_cert_config' diff --git a/ssh_server_tests/README.md b/ssh_server_tests/README.md deleted file mode 100644 index 391d9040..00000000 --- a/ssh_server_tests/README.md +++ /dev/null @@ -1,23 +0,0 @@ -# Cloudflared SSH server smoke tests - -Runs several tests in a docker container against a server that is started out of band of these tests. -Cloudflared token also needs to be retrieved out of band. -SSH server hostname and user need to be configured in a docker environment file - - -## Running tests - -* Build cloudflared: -make cloudflared - -* Start server: -sudo ./cloudflared tunnel --hostname HOSTNAME --ssh-server - -* Fetch token: -./cloudflared access login HOSTNAME - -* Create docker env file: -echo "SSH_HOSTNAME=HOSTNAME\nSSH_USER=USERNAME\n" > ssh_server_tests/.env - -* Run tests: -make test-ssh-server diff --git a/ssh_server_tests/docker-compose.yml b/ssh_server_tests/docker-compose.yml deleted file mode 100644 index e292dc27..00000000 --- a/ssh_server_tests/docker-compose.yml +++ /dev/null @@ -1,18 +0,0 @@ -version: "3.1" - -services: - ssh_test: - build: - context: . - args: - - SSH_HOSTNAME=${SSH_HOSTNAME} - volumes: - - "~/.cloudflared/:/root/.cloudflared" - env_file: - - .env - environment: - - AUTHORIZED_KEYS_SSH_CONFIG=/root/.ssh/authorized_keys_config - - SHORT_LIVED_CERT_SSH_CONFIG=/root/.ssh/short_lived_cert_config - - REMOTE_SCP_FILENAME=scp_test.txt - - ROOT_ONLY_TEST_FILE_PATH=~/permission_test.txt - entrypoint: "python tests.py" diff --git a/ssh_server_tests/ssh/authorized_keys_config b/ssh_server_tests/ssh/authorized_keys_config deleted file mode 100644 index 288c0b70..00000000 --- a/ssh_server_tests/ssh/authorized_keys_config +++ /dev/null @@ -1,5 +0,0 @@ -Host * - AddressFamily inet - -Host {{hostname}} - ProxyCommand /usr/local/bin/cloudflared access ssh --hostname %h diff --git a/ssh_server_tests/ssh/id_rsa b/ssh_server_tests/ssh/id_rsa deleted file mode 100644 index 16e178b1..00000000 --- a/ssh_server_tests/ssh/id_rsa +++ /dev/null @@ -1,49 +0,0 @@ ------BEGIN OPENSSH PRIVATE KEY----- -b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAACFwAAAAdzc2gtcn -NhAAAAAwEAAQAAAgEAvi26NDQ8cYTTztqPe9ZgF5HR/rIo5FoDgL5NbbZKW6h0txP9Fd8s -id9Bgmo+aGkeM327tPVVMQ6UFmdRksOCIDWQDjNLF8b6S+Fu95tvMKSbGreRoR32OvgZKV -I6KmOsF4z4GIv9naPplZswtKEUhSSI+/gPdAs9wfwalqZ77e82QJ727bYMeC3lzuoT+KBI -dYufJ4OQhLtpHrqhB5sn7s6+oCv/u85GSln5SIC18Hi2t9lW4tgb5tH8P0kEDDWGfPS5ok -qGi4kFTvwBXOCS2r4dhi5hRkpP7PqG4np0OCfvK5IRRJ27fCnj0loc+puZJAxnPMbuJr64 -vwxRx78PM/V0PDUsl0P6aR/vbe0XmF9FGqbWf2Tar1p4r6C9/bMzcDz8seYT8hzLIHP3+R -l1hdlsTLm+1EzhaExKId+tjXegKGG4nU24h6qHEnRxLQDMwEsdkfj4E1pVypZJXVyNj99D -o84vi0EUnu7R4HmQb/C+Pu7qMDtLT3Zk7O5Mg4LQ+cTz9V0noYEAyG46nAB4U/nJzBnV1J -+aAdpioHmUAYhLYlQ9Kiy7LCJi92g9Wqa4wxMKxBbO5ZeH++p2p2lUi/oQNqx/2dLYFmy0 -wxvJHbZIhAaFbOeCvHg1ucIAQznli2jOr2qoB+yKRRPAp/3NXnZg1v7ce2CkwiAD52wjtC -kAAAdILMJUeyzCVHsAAAAHc3NoLXJzYQAAAgEAvi26NDQ8cYTTztqPe9ZgF5HR/rIo5FoD -gL5NbbZKW6h0txP9Fd8sid9Bgmo+aGkeM327tPVVMQ6UFmdRksOCIDWQDjNLF8b6S+Fu95 -tvMKSbGreRoR32OvgZKVI6KmOsF4z4GIv9naPplZswtKEUhSSI+/gPdAs9wfwalqZ77e82 -QJ727bYMeC3lzuoT+KBIdYufJ4OQhLtpHrqhB5sn7s6+oCv/u85GSln5SIC18Hi2t9lW4t -gb5tH8P0kEDDWGfPS5okqGi4kFTvwBXOCS2r4dhi5hRkpP7PqG4np0OCfvK5IRRJ27fCnj -0loc+puZJAxnPMbuJr64vwxRx78PM/V0PDUsl0P6aR/vbe0XmF9FGqbWf2Tar1p4r6C9/b 
-MzcDz8seYT8hzLIHP3+Rl1hdlsTLm+1EzhaExKId+tjXegKGG4nU24h6qHEnRxLQDMwEsd -kfj4E1pVypZJXVyNj99Do84vi0EUnu7R4HmQb/C+Pu7qMDtLT3Zk7O5Mg4LQ+cTz9V0noY -EAyG46nAB4U/nJzBnV1J+aAdpioHmUAYhLYlQ9Kiy7LCJi92g9Wqa4wxMKxBbO5ZeH++p2 -p2lUi/oQNqx/2dLYFmy0wxvJHbZIhAaFbOeCvHg1ucIAQznli2jOr2qoB+yKRRPAp/3NXn -Zg1v7ce2CkwiAD52wjtCkAAAADAQABAAACAQCbnVsyAFQ9J00Rg/HIiUATyTQlzq57O9SF -8jH1RiZOHedzLx32WaleH5rBFiJ+2RTnWUjQ57aP77fpJR2wk93UcT+w/vPBPwXsNUjRvx -Qan3ZzRCYbyiKDWiNslmYV7X0RwD36CAK8jTVDP7t48h2SXLTiSLaMY+5i3uD6yLu7k/O2 -qNyw4jgN1rCmwQ8acD0aQec3NAZ7NcbsaBX/3Uutsup0scwOZtlJWZoLY5Z8cKpCgcsAz4 -j1NHnNZvey7dFgSffj/ktdvf7kBH0w/GnuJ4aNF0Jte70u0kiw5TZYBQVFh74tgUu6a6SJ -qUbxIYUL5EJNjxGsDn+phHEemw3aMv0CwZG6Tqaionlna7bLsl9Bg1HTGclczVWx8uqC+M -6agLmkhYCHG0rVj8h5smjXAQXtmvIDVYDOlJZZoF9VAOCj6QfmJUH1NAGpCs1HDHbeOxGA -OLCh4d3F4rScPqhGdtSt4W13VFIvXn2Qqoz9ufepZsee1SZqpcerxywx2wN9ZAzu+X8lTN -i+TA2B3vWpqqucOEsp4JwDN+VMKZqKUGUDWcm/eHSaG6wq0q734LUlgM85TjaIg8QsNtWV -giB1nWwsYIuH4rsFNFGEwURYdGBcw6idH0GZ7I4RaIB5F9oOza1d601E0APHYrtnx9yOiK -nOtJ+5ZmVZovaDRfu1aQAAAQBU/EFaNUzoVhO04pS2L6BlByt963bOIsSJhdlEzek5AAli -eaf1S/PD6xWCc0IGY+GZE0HPbhsKYanjqOpWldcA2T7fzf4oz4vFBfUkPYo/MLSlLCYsDd -IH3wBkCssnfR5EkzNgxnOvq646Nl64BMvxwSIXGPktdq9ZALxViwricSRzCFURnh5vLHWU -wBzSgAA0UlZ9E64GtAv066+AoZCp83GhTLRC4o0naE2e/K4op4BCFHLrZ8eXmDRK3NJj80 -Vkn+uhrk+SHmbjIhmS57Pv9p8TWyRvemph/nMUuZGKBUu2X+JQxggck0KigIrXjsmciCsM -BIM3mYDDfjYbyVhTAAABAQDkV8O1bWUsAIqk7RU+iDZojN5kaO+zUvj1TafX8QX1sY6pu4 -Z2cfSEka1532BaehM95bQm7BCPw4cYg56XidmCQTZ9WaWqxVrOo48EKXUtZMZx6nKFOKlq -MT2XTMnGT9n7kFCfEjSVkAjuJ9ZTFLOaoXAaVRnxeHQwOKaup5KKP9GSzNIw328U+96s3V -WKHeT4pMjHBccgW/qX/tRRidZw5in5uBC9Ew5y3UACFTkNOnhUwVfyUNbBZJ2W36msQ3KD -AN7nOrQHqhd3NFyCEy2ovIAKVBacr/VEX6EsRUshIehJzz8EY9f3kXL7WT2QDoz2giPeBJ -HJdEpXt43UpszjAAABAQDVNpqNdHUlCs9XnbIvc6ZRrNh79wt65YFfvh/QEuA33KnA6Ri6 -EgnV5IdUWXS/UFaYcm2udydrBpVIVifSYl3sioHBylpri23BEy38PKwVXvghUtfpN6dWGn -NZUG25fQPtIzqi+lo953ZjIj+Adi17AeVv4P4NiLrZeM9lXfWf2pEPOecxXs1IwAf9IiDQ -WepAwRLsu42eEnHA+DSJPZUkSbISfM5X345k0g6EVATX/yLL3CsqClPzPtsqjh6rbEfFg3 -2OfIMcWV77gOlGWGQ+bUHc8kV6xJqV9QVacLWzfLvIqHF0wQMf8WLOVHEzkfiq4VjwhVqr -/+FFvljm5nSDAAAAEW1pa2VAQzAyWTUwVEdKR0g4AQ== ------END OPENSSH PRIVATE KEY----- diff --git a/ssh_server_tests/ssh/id_rsa.pub b/ssh_server_tests/ssh/id_rsa.pub deleted file mode 100644 index 024f80ee..00000000 --- a/ssh_server_tests/ssh/id_rsa.pub +++ /dev/null @@ -1 +0,0 @@ -ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQC+Lbo0NDxxhNPO2o971mAXkdH+sijkWgOAvk1ttkpbqHS3E/0V3yyJ30GCaj5oaR4zfbu09VUxDpQWZ1GSw4IgNZAOM0sXxvpL4W73m28wpJsat5GhHfY6+BkpUjoqY6wXjPgYi/2do+mVmzC0oRSFJIj7+A90Cz3B/BqWpnvt7zZAnvbttgx4LeXO6hP4oEh1i58ng5CEu2keuqEHmyfuzr6gK/+7zkZKWflIgLXweLa32Vbi2Bvm0fw/SQQMNYZ89LmiSoaLiQVO/AFc4JLavh2GLmFGSk/s+obienQ4J+8rkhFEnbt8KePSWhz6m5kkDGc8xu4mvri/DFHHvw8z9XQ8NSyXQ/ppH+9t7ReYX0UaptZ/ZNqvWnivoL39szNwPPyx5hPyHMsgc/f5GXWF2WxMub7UTOFoTEoh362Nd6AoYbidTbiHqocSdHEtAMzASx2R+PgTWlXKlkldXI2P30Ojzi+LQRSe7tHgeZBv8L4+7uowO0tPdmTs7kyDgtD5xPP1XSehgQDIbjqcAHhT+cnMGdXUn5oB2mKgeZQBiEtiVD0qLLssImL3aD1aprjDEwrEFs7ll4f76nanaVSL+hA2rH/Z0tgWbLTDG8kdtkiEBoVs54K8eDW5wgBDOeWLaM6vaqgH7IpFE8Cn/c1edmDW/tx7YKTCIAPnbCO0KQ== mike@C02Y50TGJGH8 diff --git a/ssh_server_tests/ssh/short_lived_cert_config b/ssh_server_tests/ssh/short_lived_cert_config deleted file mode 100644 index e29356fe..00000000 --- a/ssh_server_tests/ssh/short_lived_cert_config +++ /dev/null @@ -1,11 +0,0 @@ -Host * - AddressFamily inet - -Host {{hostname}} - ProxyCommand bash -c '/usr/local/bin/cloudflared access ssh-gen --hostname %h; ssh -F /root/.ssh/short_lived_cert_config -tt %r@cfpipe-{{hostname}} >&2 <&1' 
- -Host cfpipe-{{hostname}} - HostName {{hostname}} - ProxyCommand /usr/local/bin/cloudflared access ssh --hostname %h - IdentityFile ~/.cloudflared/{{hostname}}-cf_key - CertificateFile ~/.cloudflared/{{hostname}}-cf_key-cert.pub diff --git a/ssh_server_tests/tests.py b/ssh_server_tests/tests.py deleted file mode 100644 index d27fcd7f..00000000 --- a/ssh_server_tests/tests.py +++ /dev/null @@ -1,195 +0,0 @@ -""" -Cloudflared Integration tests -""" - -import unittest -import subprocess -import os -import tempfile -from contextlib import contextmanager - -from pexpect import pxssh - - -class TestSSHBase(unittest.TestCase): - """ - SSH test base class containing constants and helper funcs - """ - - HOSTNAME = os.environ["SSH_HOSTNAME"] - SSH_USER = os.environ["SSH_USER"] - SSH_TARGET = f"{SSH_USER}@{HOSTNAME}" - AUTHORIZED_KEYS_SSH_CONFIG = os.environ["AUTHORIZED_KEYS_SSH_CONFIG"] - SHORT_LIVED_CERT_SSH_CONFIG = os.environ["SHORT_LIVED_CERT_SSH_CONFIG"] - SSH_OPTIONS = {"StrictHostKeyChecking": "no"} - - @classmethod - def get_ssh_command(cls, pty=True): - """ - Return ssh command arg list. If pty is true, a PTY is forced for the session. - """ - cmd = [ - "ssh", - "-o", - "StrictHostKeyChecking=no", - "-F", - cls.AUTHORIZED_KEYS_SSH_CONFIG, - cls.SSH_TARGET, - ] - if not pty: - cmd += ["-T"] - else: - cmd += ["-tt"] - - return cmd - - @classmethod - @contextmanager - def ssh_session_manager(cls, *args, **kwargs): - """ - Context manager for interacting with a pxssh session. - Disables pty echo on the remote server and ensures session is terminated afterward. - """ - session = pxssh.pxssh(options=cls.SSH_OPTIONS) - - session.login( - cls.HOSTNAME, - username=cls.SSH_USER, - original_prompt=r"[#@$]", - ssh_config=kwargs.get("ssh_config", cls.AUTHORIZED_KEYS_SSH_CONFIG), - ssh_tunnels=kwargs.get("ssh_tunnels", {}), - ) - try: - session.sendline("stty -echo") - session.prompt() - yield session - finally: - session.logout() - - @staticmethod - def get_command_output(session, cmd): - """ - Executes command on remote ssh server and waits for prompt. - Returns command output - """ - session.sendline(cmd) - session.prompt() - return session.before.decode().strip() - - def exec_command(self, cmd, shell=False): - """ - Executes command locally. Raises Assertion error for non-zero return code. - Returns stdout and stderr - """ - proc = subprocess.Popen( - cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=shell - ) - raw_out, raw_err = proc.communicate() - - out = raw_out.decode() - err = raw_err.decode() - self.assertEqual(proc.returncode, 0, msg=f"stdout: {out} stderr: {err}") - return out.strip(), err.strip() - - -class TestSSHCommandExec(TestSSHBase): - """ - Tests inline ssh command exec - """ - - # Name of file to be downloaded over SCP on remote server. - REMOTE_SCP_FILENAME = os.environ["REMOTE_SCP_FILENAME"] - - @classmethod - def get_scp_base_command(cls): - return [ - "scp", - "-o", - "StrictHostKeyChecking=no", - "-v", - "-F", - cls.AUTHORIZED_KEYS_SSH_CONFIG, - ] - - @unittest.skip( - "This creates files on the remote. Should be skipped until server is dockerized." 
- ) - def test_verbose_scp_sink_mode(self): - with tempfile.NamedTemporaryFile() as fl: - self.exec_command( - self.get_scp_base_command() + [fl.name, f"{self.SSH_TARGET}:"] - ) - - def test_verbose_scp_source_mode(self): - with tempfile.TemporaryDirectory() as tmpdirname: - self.exec_command( - self.get_scp_base_command() - + [f"{self.SSH_TARGET}:{self.REMOTE_SCP_FILENAME}", tmpdirname] - ) - local_filename = os.path.join(tmpdirname, self.REMOTE_SCP_FILENAME) - - self.assertTrue(os.path.exists(local_filename)) - self.assertTrue(os.path.getsize(local_filename) > 0) - - def test_pty_command(self): - base_cmd = self.get_ssh_command() - - out, _ = self.exec_command(base_cmd + ["whoami"]) - self.assertEqual(out.strip().lower(), self.SSH_USER.lower()) - - out, _ = self.exec_command(base_cmd + ["tty"]) - self.assertNotEqual(out, "not a tty") - - def test_non_pty_command(self): - base_cmd = self.get_ssh_command(pty=False) - - out, _ = self.exec_command(base_cmd + ["whoami"]) - self.assertEqual(out.strip().lower(), self.SSH_USER.lower()) - - out, _ = self.exec_command(base_cmd + ["tty"]) - self.assertEqual(out, "not a tty") - - -class TestSSHShell(TestSSHBase): - """ - Tests interactive SSH shell - """ - - # File path to a file on the remote server with root only read privileges. - ROOT_ONLY_TEST_FILE_PATH = os.environ["ROOT_ONLY_TEST_FILE_PATH"] - - def test_ssh_pty(self): - with self.ssh_session_manager() as session: - - # Test shell launched as correct user - username = self.get_command_output(session, "whoami") - self.assertEqual(username.lower(), self.SSH_USER.lower()) - - # Test USER env variable set - user_var = self.get_command_output(session, "echo $USER") - self.assertEqual(user_var.lower(), self.SSH_USER.lower()) - - # Test HOME env variable set to true user home. - home_env = self.get_command_output(session, "echo $HOME") - pwd = self.get_command_output(session, "pwd") - self.assertEqual(pwd, home_env) - - # Test shell launched in correct user home dir. - self.assertIn(username, pwd) - - # Ensure shell launched with correct user's permissions and privs. - # Can't read root owned 0700 files. 
- output = self.get_command_output( - session, f"cat {self.ROOT_ONLY_TEST_FILE_PATH}" - ) - self.assertIn("Permission denied", output) - - def test_short_lived_cert_auth(self): - with self.ssh_session_manager( - ssh_config=self.SHORT_LIVED_CERT_SSH_CONFIG - ) as session: - username = self.get_command_output(session, "whoami") - self.assertEqual(username.lower(), self.SSH_USER.lower()) - - -unittest.main() diff --git a/supervisor/supervisor.go b/supervisor/supervisor.go index 8dde4e76..d4a75e77 100644 --- a/supervisor/supervisor.go +++ b/supervisor/supervisor.go @@ -7,12 +7,15 @@ import ( "strings" "time" + "github.com/prometheus/client_golang/prometheus" "github.com/quic-go/quic-go" "github.com/rs/zerolog" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" + "github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/orchestration" + v3 "github.com/cloudflare/cloudflared/quic/v3" "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/tunnelstate" @@ -80,9 +83,14 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato edgeAddrHandler := NewIPAddrFallback(config.MaxEdgeAddrRetries) edgeBindAddr := config.EdgeBindAddr + datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer) + sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort) + edgeTunnelServer := EdgeTunnelServer{ config: config, orchestrator: orchestrator, + sessionManager: sessionManager, + datagramMetrics: datagramMetrics, edgeAddrs: edgeIPs, edgeAddrHandler: edgeAddrHandler, edgeBindAddr: edgeBindAddr, diff --git a/supervisor/tunnel.go b/supervisor/tunnel.go index 7de2cbd0..6a456ca6 100644 --- a/supervisor/tunnel.go +++ b/supervisor/tunnel.go @@ -7,6 +7,7 @@ import ( "net" "net/netip" "runtime/debug" + "slices" "strings" "sync" "time" @@ -24,6 +25,7 @@ import ( "github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/orchestration" quicpogs "github.com/cloudflare/cloudflared/quic" + v3 "github.com/cloudflare/cloudflared/quic/v3" "github.com/cloudflare/cloudflared/retry" "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/tunnelrpc/pogs" @@ -87,14 +89,6 @@ func (c *TunnelConfig) connectionOptions(originLocalAddr string, numPreviousAtte } } -func (c *TunnelConfig) SupportedFeatures() []string { - supported := []string{features.FeatureSerializedHeaders} - if c.NamedTunnel == nil { - supported = append(supported, features.FeatureQuickReconnects) - } - return supported -} - func StartTunnelDaemon( ctx context.Context, config *TunnelConfig, @@ -181,6 +175,8 @@ func (f *ipAddrFallback) ShouldGetNewAddress(connIndex uint8, err error) (needsN type EdgeTunnelServer struct { config *TunnelConfig orchestrator *orchestration.Orchestrator + sessionManager v3.SessionManager + datagramMetrics v3.Metrics edgeAddrHandler EdgeAddrHandler edgeAddrs *edgediscovery.Edge edgeBindAddr net.IP @@ -605,14 +601,26 @@ func (e *EdgeTunnelServer) serveQUIC( return err, true } - datagramSessionManager := connection.NewDatagramV2Connection( - ctx, - conn, - e.config.PacketConfig, - e.config.RPCTimeout, - e.config.WriteStreamTimeout, - connLogger.Logger(), - ) + var datagramSessionManager connection.DatagramSessionHandler + if slices.Contains(connOptions.Client.Features, features.FeatureDatagramV3) { + datagramSessionManager = connection.NewDatagramV3Connection( + ctx, + conn, + e.sessionManager, + connIndex, + 
e.datagramMetrics, + connLogger.Logger(), + ) + } else { + datagramSessionManager = connection.NewDatagramV2Connection( + ctx, + conn, + e.config.PacketConfig, + e.config.RPCTimeout, + e.config.WriteStreamTimeout, + connLogger.Logger(), + ) + } // Wrap the [quic.Connection] as a TunnelConnection tunnelConn, err := connection.NewTunnelConnection(