Merge branch 'master' into tunnel_token_file
This commit is contained in:
commit
521220a042
|
@ -0,0 +1,24 @@
|
|||
on:
|
||||
pull_request: {}
|
||||
workflow_dispatch: {}
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- master
|
||||
schedule:
|
||||
- cron: '0 0 * * *'
|
||||
name: Semgrep config
|
||||
jobs:
|
||||
semgrep:
|
||||
name: semgrep/ci
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
SEMGREP_APP_TOKEN: ${{ secrets.SEMGREP_APP_TOKEN }}
|
||||
SEMGREP_URL: https://cloudflare.semgrep.dev
|
||||
SEMGREP_APP_URL: https://cloudflare.semgrep.dev
|
||||
SEMGREP_VERSION_CHECK_URL: https://cloudflare.semgrep.dev/api/check-version
|
||||
container:
|
||||
image: semgrep/semgrep
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- run: semgrep ci
|
|
@ -17,3 +17,4 @@ cscope.*
|
|||
ssh_server_tests/.env
|
||||
/.cover
|
||||
built_artifacts/
|
||||
component-tests/.venv
|
||||
|
|
|
@ -3,6 +3,6 @@
|
|||
cd /tmp
|
||||
git clone -q https://github.com/cloudflare/go
|
||||
cd go/src
|
||||
# https://github.com/cloudflare/go/tree/ec0a014545f180b0c74dfd687698657a9e86e310 is version go1.22.2-devel-cf
|
||||
git checkout -q ec0a014545f180b0c74dfd687698657a9e86e310
|
||||
./make.bash
|
||||
# https://github.com/cloudflare/go/tree/f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38 is version go1.22.5-devel-cf
|
||||
git checkout -q f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38
|
||||
./make.bash
|
||||
|
|
|
@ -37,7 +37,7 @@ if ($LASTEXITCODE -ne 0) { throw "Failed unit tests" }
|
|||
|
||||
Write-Output "Running component tests"
|
||||
|
||||
python -m pip --disable-pip-version-check install --upgrade -r component-tests/requirements.txt
|
||||
python -m pip --disable-pip-version-check install --upgrade -r component-tests/requirements.txt --use-pep517
|
||||
python component-tests/setup.py --type create
|
||||
python -m pytest component-tests -o log_cli=true --log-cli-level=INFO
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
|
|
|
@ -9,8 +9,8 @@ Set-Location "$Env:Temp"
|
|||
git clone -q https://github.com/cloudflare/go
|
||||
Write-Output "Building go..."
|
||||
cd go/src
|
||||
# https://github.com/cloudflare/go/tree/ec0a014545f180b0c74dfd687698657a9e86e310 is version go1.22.2-devel-cf
|
||||
git checkout -q ec0a014545f180b0c74dfd687698657a9e86e310
|
||||
# https://github.com/cloudflare/go/tree/f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38 is version go1.22.5-devel-cf
|
||||
git checkout -q f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38
|
||||
& ./make.bat
|
||||
|
||||
Write-Output "Installed"
|
||||
Write-Output "Installed"
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
$ErrorActionPreference = "Stop"
|
||||
$ProgressPreference = "SilentlyContinue"
|
||||
$GoMsiVersion = "go1.22.2.windows-amd64.msi"
|
||||
$GoMsiVersion = "go1.22.5.windows-amd64.msi"
|
||||
|
||||
Write-Output "Downloading go installer..."
|
||||
|
||||
|
@ -17,4 +17,4 @@ Install-Package "$Env:Temp\$GoMsiVersion" -Force
|
|||
# Go installer updates global $PATH
|
||||
go env
|
||||
|
||||
Write-Output "Installed"
|
||||
Write-Output "Installed"
|
||||
|
|
12
CHANGES.md
12
CHANGES.md
|
@ -1,3 +1,15 @@
|
|||
## 2024.12.2
|
||||
### New Features
|
||||
- This release introduces the ability to collect troubleshooting information from one instance of cloudflared running on the local machine. The command can be executed as `cloudflared tunnel diag`.
|
||||
|
||||
## 2024.12.1
|
||||
### Notices
|
||||
- The use of the `--metrics` is still honoured meaning that if this flag is set the metrics server will try to bind it, however, this version includes a change that makes the metrics server bind to a port with a semi-deterministic approach. If the metrics flag is not present the server will bind to the first available port of the range 20241 to 20245. In case of all ports being unavailable then the fallback is to bind to a random port.
|
||||
|
||||
## 2024.10.0
|
||||
### Bug Fixes
|
||||
- We fixed a bug related to `--grace-period`. Tunnels that use QUIC as transport weren't abiding by this waiting period before forcefully closing the connections to the edge. From now on, both QUIC and HTTP2 tunnels will wait for either the grace period to end (defaults to 30 seconds) or until the last in-flight request is handled. Users that wish to maintain the previous behavior should set `--grace-period` to 0 if `--protocol` is set to `quic`. This will force `cloudflared` to shutdown as soon as either SIGTERM or SIGINT is received.
|
||||
|
||||
## 2024.2.1
|
||||
### Notices
|
||||
- Starting from this version, tunnel diagnostics will be enabled by default. This will allow the engineering team to remotely get diagnostics from cloudflared during debug activities. Users still have the capability to opt-out of this feature by defining `--management-diagnostics=false` (or env `TUNNEL_MANAGEMENT_DIAGNOSTICS`).
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
# use a builder image for building cloudflare
|
||||
ARG TARGET_GOOS
|
||||
ARG TARGET_GOARCH
|
||||
FROM golang:1.22.2 as builder
|
||||
FROM golang:1.22.5 as builder
|
||||
ENV GO111MODULE=on \
|
||||
CGO_ENABLED=0 \
|
||||
TARGET_GOOS=${TARGET_GOOS} \
|
||||
TARGET_GOARCH=${TARGET_GOARCH}
|
||||
TARGET_GOARCH=${TARGET_GOARCH} \
|
||||
CONTAINER_BUILD=1
|
||||
|
||||
|
||||
WORKDIR /go/src/github.com/cloudflare/cloudflared/
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# use a builder image for building cloudflare
|
||||
FROM golang:1.22.2 as builder
|
||||
FROM golang:1.22.5 as builder
|
||||
ENV GO111MODULE=on \
|
||||
CGO_ENABLED=0
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# use a builder image for building cloudflare
|
||||
FROM golang:1.22.2 as builder
|
||||
FROM golang:1.22.5 as builder
|
||||
ENV GO111MODULE=on \
|
||||
CGO_ENABLED=0
|
||||
|
||||
|
|
18
Makefile
18
Makefile
|
@ -30,6 +30,10 @@ ifdef PACKAGE_MANAGER
|
|||
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)"
|
||||
endif
|
||||
|
||||
ifdef CONTAINER_BUILD
|
||||
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/metrics.Runtime=virtual"
|
||||
endif
|
||||
|
||||
LINK_FLAGS :=
|
||||
ifeq ($(FIPS), true)
|
||||
LINK_FLAGS := -linkmode=external -extldflags=-static $(LINK_FLAGS)
|
||||
|
@ -165,9 +169,17 @@ cover:
|
|||
# Generate the HTML report that can be viewed from the browser in CI.
|
||||
$Q go tool cover -html ".cover/c.out" -o .cover/all.html
|
||||
|
||||
.PHONY: test-ssh-server
|
||||
test-ssh-server:
|
||||
docker-compose -f ssh_server_tests/docker-compose.yml up
|
||||
.PHONY: fuzz
|
||||
fuzz:
|
||||
@go test -fuzz=FuzzIPDecoder -fuzztime=600s ./packet
|
||||
@go test -fuzz=FuzzICMPDecoder -fuzztime=600s ./packet
|
||||
@go test -fuzz=FuzzSessionWrite -fuzztime=600s ./quic/v3
|
||||
@go test -fuzz=FuzzSessionServe -fuzztime=600s ./quic/v3
|
||||
@go test -fuzz=FuzzRegistrationDatagram -fuzztime=600s ./quic/v3
|
||||
@go test -fuzz=FuzzPayloadDatagram -fuzztime=600s ./quic/v3
|
||||
@go test -fuzz=FuzzRegistrationResponseDatagram -fuzztime=600s ./quic/v3
|
||||
@go test -fuzz=FuzzNewIdentity -fuzztime=600s ./tracing
|
||||
@go test -fuzz=FuzzNewAccessValidator -fuzztime=600s ./validation
|
||||
|
||||
.PHONY: install-go
|
||||
install-go:
|
||||
|
|
|
@ -49,7 +49,7 @@ Once installed, you can authenticate `cloudflared` into your Cloudflare account
|
|||
|
||||
## TryCloudflare
|
||||
|
||||
Want to test Cloudflare Tunnel before adding a website to Cloudflare? You can do so with TryCloudflare using the documentation [available here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/trycloudflare).
|
||||
Want to test Cloudflare Tunnel before adding a website to Cloudflare? You can do so with TryCloudflare using the documentation [available here](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/do-more-with-tunnels/trycloudflare/).
|
||||
|
||||
## Deprecated versions
|
||||
|
||||
|
|
|
@ -1,3 +1,85 @@
|
|||
2024.12.1
|
||||
- 2024-12-10 TUN-8795: update createrepo to createrepo_c to fix the release_pkgs.py script
|
||||
|
||||
2024.12.0
|
||||
- 2024-12-09 TUN-8640: Add ICMP support for datagram V3
|
||||
- 2024-12-09 TUN-8789: make python package installation consistent
|
||||
- 2024-12-06 TUN-8781: Add Trixie, drop Buster. Default to Bookworm
|
||||
- 2024-12-05 TUN-8775: Make sure the session Close can only be called once
|
||||
- 2024-12-04 TUN-8725: implement diagnostic procedure
|
||||
- 2024-12-04 TUN-8767: include raw output from network collector in diagnostic zipfile
|
||||
- 2024-12-04 TUN-8770: add cli configuration and tunnel configuration to diagnostic zipfile
|
||||
- 2024-12-04 TUN-8768: add job report to diagnostic zipfile
|
||||
- 2024-12-03 TUN-8726: implement compression routine to be used in diagnostic procedure
|
||||
- 2024-12-03 TUN-8732: implement port selection algorithm
|
||||
- 2024-12-03 TUN-8762: fix argument order when invoking tracert and modify network info output parsing.
|
||||
- 2024-12-03 TUN-8769: fix k8s log collector arguments
|
||||
- 2024-12-03 TUN-8727: extend client to include function to get cli configuration and tunnel configuration
|
||||
- 2024-11-29 TUN-8729: implement network collection for diagnostic procedure
|
||||
- 2024-11-29 TUN-8727: implement metrics, runtime, system, and tunnelstate in diagnostic http client
|
||||
- 2024-11-27 TUN-8733: add log collection for docker
|
||||
- 2024-11-27 TUN-8734: add log collection for kubernetes
|
||||
- 2024-11-27 TUN-8640: Refactor ICMPRouter to support new ICMPResponders
|
||||
- 2024-11-26 TUN-8735: add managed/local log collection
|
||||
- 2024-11-25 TUN-8728: implement diag/tunnel endpoint
|
||||
- 2024-11-25 TUN-8730: implement diag/configuration
|
||||
- 2024-11-22 TUN-8737: update metrics server port selection
|
||||
- 2024-11-22 TUN-8731: Implement diag/system endpoint
|
||||
- 2024-11-21 TUN-8748: Migrated datagram V3 flows to use migrated context
|
||||
|
||||
2024.11.1
|
||||
- 2024-11-18 Add cloudflared tunnel ready command
|
||||
- 2024-11-14 Make metrics a requirement for tunnel ready command
|
||||
- 2024-11-12 TUN-8701: Simplify flow registration logs for datagram v3
|
||||
- 2024-11-11 add: new go-fuzz targets
|
||||
- 2024-11-07 TUN-8701: Add metrics and adjust logs for datagram v3
|
||||
- 2024-11-06 TUN-8709: Add session migration for datagram v3
|
||||
- 2024-11-04 Fixed 404 in README.md to TryCloudflare
|
||||
- 2024-09-24 Update semgrep.yml
|
||||
|
||||
2024.11.0
|
||||
- 2024-11-05 VULN-66059: remove ssh server tests
|
||||
- 2024-11-04 TUN-8700: Add datagram v3 muxer
|
||||
- 2024-11-04 TUN-8646: Allow experimental feature support for datagram v3
|
||||
- 2024-11-04 TUN-8641: Expose methods to simplify V3 Datagram parsing on the edge
|
||||
- 2024-10-31 TUN-8708: Bump python min version to 3.10
|
||||
- 2024-10-31 TUN-8667: Add datagram v3 session manager
|
||||
- 2024-10-25 TUN-8692: remove dashes from session id
|
||||
- 2024-10-24 TUN-8694: Rework release script
|
||||
- 2024-10-24 TUN-8661: Refactor connection methods to support future different datagram muxing methods
|
||||
- 2024-07-22 TUN-8553: Bump go to 1.22.5 and go-boring 1.22.5-1
|
||||
|
||||
2024.10.1
|
||||
- 2024-10-23 TUN-8694: Fix github release script
|
||||
- 2024-10-21 Revert "TUN-8592: Use metadata from the edge to determine if request body is empty for QUIC transport"
|
||||
- 2024-10-18 TUN-8688: Correct UDP bind for IPv6 edge connectivity on macOS
|
||||
- 2024-10-17 TUN-8685: Bump coredns dependency
|
||||
- 2024-10-16 TUN-8638: Add datagram v3 serializers and deserializers
|
||||
- 2024-10-15 chore: Remove h2mux code
|
||||
- 2024-10-11 TUN-8631: Abort release on version mismatch
|
||||
|
||||
2024.10.0
|
||||
- 2024-10-01 TUN-8646: Add datagram v3 support feature flag
|
||||
- 2024-09-30 TUN-8621: Fix cloudflared version in change notes to account for release date
|
||||
- 2024-09-19 Adding semgrep yaml file
|
||||
- 2024-09-12 TUN-8632: Delay checking auto-update by the provided frequency
|
||||
- 2024-09-11 TUN-8630: Check checksum of downloaded binary to compare to current for auto-updating
|
||||
- 2024-09-09 TUN-8629: Cloudflared update on Windows requires running it twice to update
|
||||
- 2024-09-06 PPIP-2310: Update quick tunnel disclaimer
|
||||
- 2024-08-30 TUN-8621: Prevent QUIC connection from closing before grace period after unregistering
|
||||
- 2024-08-09 TUN-8592: Use metadata from the edge to determine if request body is empty for QUIC transport
|
||||
- 2024-06-26 TUN-8484: Print response when QuickTunnel can't be unmarshalled
|
||||
|
||||
2024.9.1
|
||||
- 2024-09-10 Revert Release 2024.9.0
|
||||
|
||||
2024.9.0
|
||||
- 2024-09-10 TUN-8621: Fix cloudflared version in change notes.
|
||||
- 2024-09-06 PPIP-2310: Update quick tunnel disclaimer
|
||||
- 2024-08-30 TUN-8621: Prevent QUIC connection from closing before grace period after unregistering
|
||||
- 2024-08-09 TUN-8592: Use metadata from the edge to determine if request body is empty for QUIC transport
|
||||
- 2024-06-26 TUN-8484: Print response when QuickTunnel can't be unmarshalled
|
||||
|
||||
2024.8.3
|
||||
- 2024-08-15 TUN-8591 login command without extra text
|
||||
- 2024-03-25 remove code that will not be executed
|
||||
|
|
75
cfsetup.yaml
75
cfsetup.yaml
|
@ -1,8 +1,9 @@
|
|||
pinned_go: &pinned_go go-boring=1.22.2-1
|
||||
pinned_go: &pinned_go go-boring=1.22.5-1
|
||||
|
||||
build_dir: &build_dir /cfsetup_build
|
||||
default-flavor: bullseye
|
||||
buster: &buster
|
||||
default-flavor: bookworm
|
||||
|
||||
bullseye: &bullseye
|
||||
build-linux:
|
||||
build_dir: *build_dir
|
||||
builddeps: &build_deps
|
||||
|
@ -31,8 +32,8 @@ buster: &buster
|
|||
builddeps: *build_deps
|
||||
pre-cache: *build_pre_cache
|
||||
post-cache:
|
||||
- make cover
|
||||
# except FIPS and macos
|
||||
- make cover
|
||||
# except FIPS and macos
|
||||
build-linux-release:
|
||||
build_dir: *build_dir
|
||||
builddeps: &build_deps_release
|
||||
|
@ -46,19 +47,17 @@ buster: &buster
|
|||
- python3-pip
|
||||
- python3-setuptools
|
||||
- wget
|
||||
pre-cache: &build_release_pre_cache
|
||||
- pip3 install pynacl==1.4.0
|
||||
- pip3 install pygithub==1.55
|
||||
- pip3 install boto3==1.22.9
|
||||
- pip3 install python-gnupg==0.4.9
|
||||
- python3-venv
|
||||
post-cache:
|
||||
- python3 -m venv env
|
||||
- . /cfsetup_build/env/bin/activate
|
||||
- pip install pynacl==1.4.0 pygithub==1.55 boto3==1.22.9 python-gnupg==0.4.9
|
||||
# build all packages (except macos and FIPS) and move them to /cfsetup/built_artifacts
|
||||
- ./build-packages.sh
|
||||
# handle FIPS separately so that we built with gofips compiler
|
||||
build-linux-fips-release:
|
||||
build_dir: *build_dir
|
||||
builddeps: *build_deps_release
|
||||
pre-cache: *build_release_pre_cache
|
||||
post-cache:
|
||||
# same logic as above, but for FIPS packages only
|
||||
- ./build-packages-fips.sh
|
||||
|
@ -110,7 +109,7 @@ buster: &buster
|
|||
- export GOOS=linux
|
||||
- export GOARCH=arm64
|
||||
- export NIGHTLY=true
|
||||
#- export FIPS=true # TUN-7595
|
||||
# - export FIPS=true # TUN-7595
|
||||
- export ORIGINAL_NAME=true
|
||||
- make cloudflared-deb
|
||||
build-deb-arm64:
|
||||
|
@ -133,12 +132,14 @@ buster: &buster
|
|||
# libmsi and libgcab are libraries the wixl binary depends on.
|
||||
- libmsi-dev
|
||||
- libgcab-dev
|
||||
- python3-venv
|
||||
pre-cache:
|
||||
- wget https://github.com/sudarshan-reddy/msitools/releases/download/v0.101b/wixl -P /usr/local/bin
|
||||
- chmod a+x /usr/local/bin/wixl
|
||||
- pip3 install pynacl==1.4.0
|
||||
- pip3 install pygithub==1.55
|
||||
post-cache:
|
||||
- python3 -m venv env
|
||||
- . env/bin/activate
|
||||
- pip install pynacl==1.4.0 pygithub==1.55
|
||||
- .teamcity/package-windows.sh
|
||||
test:
|
||||
build_dir: *build_dir
|
||||
|
@ -172,18 +173,22 @@ buster: &buster
|
|||
build_dir: *build_dir
|
||||
builddeps: &build_deps_component_test
|
||||
- *pinned_go
|
||||
- python3.7
|
||||
- python3
|
||||
- python3-pip
|
||||
- python3-setuptools
|
||||
# procps installs the ps command which is needed in test_sysv_service because the init script
|
||||
# uses ps pid to determine if the agent is running
|
||||
# procps installs the ps command which is needed in test_sysv_service
|
||||
# because the init script uses ps pid to determine if the agent is
|
||||
# running
|
||||
- procps
|
||||
- python3-venv
|
||||
pre-cache-copy-paths:
|
||||
- component-tests/requirements.txt
|
||||
pre-cache: &component_test_pre_cache
|
||||
- sudo pip3 install --upgrade -r component-tests/requirements.txt
|
||||
post-cache: &component_test_post_cache
|
||||
# Creates and routes a Named Tunnel for this build. Also constructs config file from env vars.
|
||||
- python3 -m venv env
|
||||
- . env/bin/activate
|
||||
- pip install --upgrade -r component-tests/requirements.txt
|
||||
# Creates and routes a Named Tunnel for this build. Also constructs
|
||||
# config file from env vars.
|
||||
- python3 component-tests/setup.py --type create
|
||||
- pytest component-tests -o log_cli=true --log-cli-level=INFO
|
||||
# The Named Tunnel is deleted and its route unprovisioned here.
|
||||
|
@ -193,7 +198,6 @@ buster: &buster
|
|||
builddeps: *build_deps_component_test
|
||||
pre-cache-copy-paths:
|
||||
- component-tests/requirements.txt
|
||||
pre-cache: *component_test_pre_cache
|
||||
post-cache: *component_test_post_cache
|
||||
github-release-dryrun:
|
||||
build_dir: *build_dir
|
||||
|
@ -204,10 +208,11 @@ buster: &buster
|
|||
- libffi-dev
|
||||
- python3-setuptools
|
||||
- python3-pip
|
||||
pre-cache:
|
||||
- pip3 install pynacl==1.4.0
|
||||
- pip3 install pygithub==1.55
|
||||
- python3-venv
|
||||
post-cache:
|
||||
- python3 -m venv env
|
||||
- . env/bin/activate
|
||||
- pip install pynacl==1.4.0 pygithub==1.55
|
||||
- make github-release-dryrun
|
||||
github-release:
|
||||
build_dir: *build_dir
|
||||
|
@ -218,10 +223,11 @@ buster: &buster
|
|||
- libffi-dev
|
||||
- python3-setuptools
|
||||
- python3-pip
|
||||
pre-cache:
|
||||
- pip3 install pynacl==1.4.0
|
||||
- pip3 install pygithub==1.55
|
||||
- python3-venv
|
||||
post-cache:
|
||||
- python3 -m venv env
|
||||
- . env/bin/activate
|
||||
- pip install pynacl==1.4.0 pygithub==1.55
|
||||
- make github-release
|
||||
r2-linux-release:
|
||||
build_dir: *build_dir
|
||||
|
@ -237,14 +243,13 @@ buster: &buster
|
|||
- python3-setuptools
|
||||
- python3-pip
|
||||
- reprepro
|
||||
- createrepo
|
||||
pre-cache:
|
||||
- pip3 install pynacl==1.4.0
|
||||
- pip3 install pygithub==1.55
|
||||
- pip3 install boto3==1.22.9
|
||||
- pip3 install python-gnupg==0.4.9
|
||||
- createrepo-c
|
||||
- python3-venv
|
||||
post-cache:
|
||||
- python3 -m venv env
|
||||
- . env/bin/activate
|
||||
- pip install pynacl==1.4.0 pygithub==1.55 boto3==1.22.9 python-gnupg==0.4.9
|
||||
- make r2-linux-release
|
||||
|
||||
bullseye: *buster
|
||||
bookworm: *buster
|
||||
bookworm: *bullseye
|
||||
trixie: *bullseye
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
package cliutil
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
@ -13,6 +16,7 @@ type BuildInfo struct {
|
|||
GoArch string `json:"go_arch"`
|
||||
BuildType string `json:"build_type"`
|
||||
CloudflaredVersion string `json:"cloudflared_version"`
|
||||
Checksum string `json:"checksum"`
|
||||
}
|
||||
|
||||
func GetBuildInfo(buildType, version string) *BuildInfo {
|
||||
|
@ -22,11 +26,12 @@ func GetBuildInfo(buildType, version string) *BuildInfo {
|
|||
GoArch: runtime.GOARCH,
|
||||
BuildType: buildType,
|
||||
CloudflaredVersion: version,
|
||||
Checksum: currentBinaryChecksum(),
|
||||
}
|
||||
}
|
||||
|
||||
func (bi *BuildInfo) Log(log *zerolog.Logger) {
|
||||
log.Info().Msgf("Version %s", bi.CloudflaredVersion)
|
||||
log.Info().Msgf("Version %s (Checksum %s)", bi.CloudflaredVersion, bi.Checksum)
|
||||
if bi.BuildType != "" {
|
||||
log.Info().Msgf("Built%s", bi.GetBuildTypeMsg())
|
||||
}
|
||||
|
@ -51,3 +56,28 @@ func (bi *BuildInfo) GetBuildTypeMsg() string {
|
|||
func (bi *BuildInfo) UserAgent() string {
|
||||
return fmt.Sprintf("cloudflared/%s", bi.CloudflaredVersion)
|
||||
}
|
||||
|
||||
// FileChecksum opens a file and returns the SHA256 checksum.
|
||||
func FileChecksum(filePath string) (string, error) {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%x", h.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func currentBinaryChecksum() string {
|
||||
currentPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
sum, _ := FileChecksum(currentPath)
|
||||
return sum
|
||||
}
|
||||
|
|
|
@ -91,7 +91,7 @@ func main() {
|
|||
|
||||
tunnel.Init(bInfo, graceShutdownC) // we need this to support the tunnel sub command...
|
||||
access.Init(graceShutdownC, Version)
|
||||
updater.Init(Version)
|
||||
updater.Init(bInfo)
|
||||
tracing.Init(Version)
|
||||
token.Init(Version)
|
||||
tail.Init(bInfo)
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/trace"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -28,6 +29,7 @@ import (
|
|||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/credentials"
|
||||
"github.com/cloudflare/cloudflared/diagnostic"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/features"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
|
@ -39,6 +41,7 @@ import (
|
|||
"github.com/cloudflare/cloudflared/supervisor"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
"github.com/cloudflare/cloudflared/tunneldns"
|
||||
"github.com/cloudflare/cloudflared/tunnelstate"
|
||||
"github.com/cloudflare/cloudflared/validation"
|
||||
)
|
||||
|
||||
|
@ -125,6 +128,94 @@ var (
|
|||
"most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+
|
||||
"any existing DNS records for this hostname.", overwriteDNSFlag)
|
||||
deprecatedClassicTunnelErr = fmt.Errorf("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)")
|
||||
// TODO: TUN-8756 the list below denotes the flags that do not possess any kind of sensitive information
|
||||
// however this approach is not maintainble in the long-term.
|
||||
nonSecretFlagsList = []string{
|
||||
"config",
|
||||
"autoupdate-freq",
|
||||
"no-autoupdate",
|
||||
"metrics",
|
||||
"pidfile",
|
||||
"url",
|
||||
"hello-world",
|
||||
"socks5",
|
||||
"proxy-connect-timeout",
|
||||
"proxy-tls-timeout",
|
||||
"proxy-tcp-keepalive",
|
||||
"proxy-no-happy-eyeballs",
|
||||
"proxy-keepalive-connections",
|
||||
"proxy-keepalive-timeout",
|
||||
"proxy-connection-timeout",
|
||||
"proxy-expect-continue-timeout",
|
||||
"http-host-header",
|
||||
"origin-server-name",
|
||||
"unix-socket",
|
||||
"origin-ca-pool",
|
||||
"no-tls-verify",
|
||||
"no-chunked-encoding",
|
||||
"http2-origin",
|
||||
"management-hostname",
|
||||
"service-op-ip",
|
||||
"local-ssh-port",
|
||||
"ssh-idle-timeout",
|
||||
"ssh-max-timeout",
|
||||
"bucket-name",
|
||||
"region-name",
|
||||
"s3-url-host",
|
||||
"host-key-path",
|
||||
"ssh-server",
|
||||
"bastion",
|
||||
"proxy-address",
|
||||
"proxy-port",
|
||||
"loglevel",
|
||||
"transport-loglevel",
|
||||
"logfile",
|
||||
"log-directory",
|
||||
"trace-output",
|
||||
"proxy-dns",
|
||||
"proxy-dns-port",
|
||||
"proxy-dns-address",
|
||||
"proxy-dns-upstream",
|
||||
"proxy-dns-max-upstream-conns",
|
||||
"proxy-dns-bootstrap",
|
||||
"is-autoupdated",
|
||||
"edge",
|
||||
"region",
|
||||
"edge-ip-version",
|
||||
"edge-bind-address",
|
||||
"cacert",
|
||||
"hostname",
|
||||
"id",
|
||||
"lb-pool",
|
||||
"api-url",
|
||||
"metrics-update-freq",
|
||||
"tag",
|
||||
"heartbeat-interval",
|
||||
"heartbeat-count",
|
||||
"max-edge-addr-retries",
|
||||
"retries",
|
||||
"ha-connections",
|
||||
"rpc-timeout",
|
||||
"write-stream-timeout",
|
||||
"quic-disable-pmtu-discovery",
|
||||
"quic-connection-level-flow-control-limit",
|
||||
"quic-stream-level-flow-control-limit",
|
||||
"label",
|
||||
"grace-period",
|
||||
"compression-quality",
|
||||
"use-reconnect-token",
|
||||
"dial-edge-timeout",
|
||||
"stdin-control",
|
||||
"name",
|
||||
"ui",
|
||||
"quick-service",
|
||||
"max-fetch-size",
|
||||
"post-quantum",
|
||||
"management-diagnostics",
|
||||
"protocol",
|
||||
"overwrite-dns",
|
||||
"help",
|
||||
}
|
||||
)
|
||||
|
||||
func Flags() []cli.Flag {
|
||||
|
@ -139,11 +230,13 @@ func Commands() []*cli.Command {
|
|||
buildVirtualNetworkSubcommand(false),
|
||||
buildRunCommand(),
|
||||
buildListCommand(),
|
||||
buildReadyCommand(),
|
||||
buildInfoCommand(),
|
||||
buildIngressSubcommand(),
|
||||
buildDeleteCommand(),
|
||||
buildCleanupCommand(),
|
||||
buildTokenCommand(),
|
||||
buildDiagCommand(),
|
||||
// for compatibility, allow following as tunnel subcommands
|
||||
proxydns.Command(true),
|
||||
cliutil.RemovedCommand("db-connect"),
|
||||
|
@ -419,7 +512,7 @@ func StartServer(
|
|||
|
||||
// Disable ICMP packet routing for quick tunnels
|
||||
if quickTunnelURL != "" {
|
||||
tunnelConfig.PacketConfig = nil
|
||||
tunnelConfig.ICMPRouterServer = nil
|
||||
}
|
||||
|
||||
internalRules := []ingress.Rule{}
|
||||
|
@ -447,19 +540,42 @@ func StartServer(
|
|||
return err
|
||||
}
|
||||
|
||||
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
|
||||
metricsListener, err := metrics.CreateMetricsListener(&listeners, c.String("metrics"))
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Error opening metrics server listener")
|
||||
return errors.Wrap(err, "Error opening metrics server listener")
|
||||
}
|
||||
|
||||
defer metricsListener.Close()
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
readinessServer := metrics.NewReadyServer(log, clientID)
|
||||
observer.RegisterSink(readinessServer)
|
||||
tracker := tunnelstate.NewConnTracker(log)
|
||||
observer.RegisterSink(tracker)
|
||||
|
||||
ipv4, ipv6, err := determineICMPSources(c, log)
|
||||
sources := make([]string, 0)
|
||||
if err == nil {
|
||||
sources = append(sources, ipv4.String())
|
||||
sources = append(sources, ipv6.String())
|
||||
}
|
||||
|
||||
readinessServer := metrics.NewReadyServer(clientID, tracker)
|
||||
cliFlags := nonSecretCliFlags(log, c, nonSecretFlagsList)
|
||||
diagnosticHandler := diagnostic.NewDiagnosticHandler(
|
||||
log,
|
||||
0,
|
||||
diagnostic.NewSystemCollectorImpl(buildInfo.CloudflaredVersion),
|
||||
tunnelConfig.NamedTunnel.Credentials.TunnelID,
|
||||
clientID,
|
||||
tracker,
|
||||
cliFlags,
|
||||
sources,
|
||||
)
|
||||
metricsConfig := metrics.Config{
|
||||
ReadyServer: readinessServer,
|
||||
DiagnosticHandler: diagnosticHandler,
|
||||
QuickTunnelHostname: quickTunnelURL,
|
||||
Orchestrator: orchestrator,
|
||||
}
|
||||
|
@ -856,9 +972,15 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag {
|
|||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "metrics",
|
||||
Value: "localhost:",
|
||||
Usage: "Listen address for metrics reporting.",
|
||||
Name: "metrics",
|
||||
Value: metrics.GetMetricsDefaultAddress(metrics.Runtime),
|
||||
Usage: fmt.Sprintf(
|
||||
`Listen address for metrics reporting. If no address is passed cloudflared will try to bind to %v.
|
||||
If all are unavailable, a random port will be used. Note that when running cloudflared from an virtual
|
||||
environment the default address binds to all interfaces, hence, it is important to isolate the host
|
||||
and virtualized host network stacks from each other`,
|
||||
metrics.GetMetricsKnownAddresses(metrics.Runtime),
|
||||
),
|
||||
EnvVars: []string{"TUNNEL_METRICS"},
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
|
@ -1189,3 +1311,46 @@ reconnect [delay]
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func nonSecretCliFlags(log *zerolog.Logger, cli *cli.Context, flagInclusionList []string) map[string]string {
|
||||
flagsNames := cli.FlagNames()
|
||||
flags := make(map[string]string, len(flagsNames))
|
||||
|
||||
for _, flag := range flagsNames {
|
||||
value := cli.String(flag)
|
||||
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
isIncluded := isFlagIncluded(flagInclusionList, flag)
|
||||
if !isIncluded {
|
||||
continue
|
||||
}
|
||||
|
||||
switch flag {
|
||||
case logger.LogDirectoryFlag, logger.LogFileFlag:
|
||||
{
|
||||
absolute, err := filepath.Abs(value)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("could not convert %s path to absolute", flag)
|
||||
} else {
|
||||
flags[flag] = absolute
|
||||
}
|
||||
}
|
||||
default:
|
||||
flags[flag] = value
|
||||
}
|
||||
}
|
||||
return flags
|
||||
}
|
||||
|
||||
func isFlagIncluded(flagInclusionList []string, flag string) bool {
|
||||
for _, include := range flagInclusionList {
|
||||
if include == flag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -252,11 +252,11 @@ func prepareTunnelConfig(
|
|||
QUICConnectionLevelFlowControlLimit: c.Uint64(quicConnLevelFlowControlLimit),
|
||||
QUICStreamLevelFlowControlLimit: c.Uint64(quicStreamLevelFlowControlLimit),
|
||||
}
|
||||
packetConfig, err := newPacketConfig(c, log)
|
||||
icmpRouter, err := newICMPRouter(c, log)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("ICMP proxy feature is disabled")
|
||||
} else {
|
||||
tunnelConfig.PacketConfig = packetConfig
|
||||
tunnelConfig.ICMPRouterServer = icmpRouter
|
||||
}
|
||||
orchestratorConfig := &orchestration.Config{
|
||||
Ingress: &ingressRules,
|
||||
|
@ -351,33 +351,39 @@ func adjustIPVersionByBindAddress(ipVersion allregions.ConfigIPVersion, ip net.I
|
|||
}
|
||||
}
|
||||
|
||||
func newPacketConfig(c *cli.Context, logger *zerolog.Logger) (*ingress.GlobalRouterConfig, error) {
|
||||
func newICMPRouter(c *cli.Context, logger *zerolog.Logger) (ingress.ICMPRouterServer, error) {
|
||||
ipv4Src, ipv6Src, err := determineICMPSources(c, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
icmpRouter, err := ingress.NewICMPRouter(ipv4Src, ipv6Src, logger, icmpFunnelTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return icmpRouter, nil
|
||||
}
|
||||
|
||||
func determineICMPSources(c *cli.Context, logger *zerolog.Logger) (netip.Addr, netip.Addr, error) {
|
||||
ipv4Src, err := determineICMPv4Src(c.String("icmpv4-src"), logger)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to determine IPv4 source address for ICMP proxy")
|
||||
return netip.Addr{}, netip.Addr{}, errors.Wrap(err, "failed to determine IPv4 source address for ICMP proxy")
|
||||
}
|
||||
|
||||
logger.Info().Msgf("ICMP proxy will use %s as source for IPv4", ipv4Src)
|
||||
|
||||
ipv6Src, zone, err := determineICMPv6Src(c.String("icmpv6-src"), logger, ipv4Src)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to determine IPv6 source address for ICMP proxy")
|
||||
return netip.Addr{}, netip.Addr{}, errors.Wrap(err, "failed to determine IPv6 source address for ICMP proxy")
|
||||
}
|
||||
|
||||
if zone != "" {
|
||||
logger.Info().Msgf("ICMP proxy will use %s in zone %s as source for IPv6", ipv6Src, zone)
|
||||
} else {
|
||||
logger.Info().Msgf("ICMP proxy will use %s as source for IPv6", ipv6Src)
|
||||
}
|
||||
|
||||
icmpRouter, err := ingress.NewICMPRouter(ipv4Src, ipv6Src, zone, logger, icmpFunnelTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ingress.GlobalRouterConfig{
|
||||
ICMPRouter: icmpRouter,
|
||||
IPv4Src: ipv4Src,
|
||||
IPv6Src: ipv6Src,
|
||||
Zone: zone,
|
||||
}, nil
|
||||
return ipv4Src, ipv6Src, nil
|
||||
}
|
||||
|
||||
func determineICMPv4Src(userDefinedSrc string, logger *zerolog.Logger) (netip.Addr, error) {
|
||||
|
@ -407,13 +413,12 @@ type interfaceIP struct {
|
|||
|
||||
func determineICMPv6Src(userDefinedSrc string, logger *zerolog.Logger, ipv4Src netip.Addr) (addr netip.Addr, zone string, err error) {
|
||||
if userDefinedSrc != "" {
|
||||
userDefinedIP, zone, _ := strings.Cut(userDefinedSrc, "%")
|
||||
addr, err := netip.ParseAddr(userDefinedIP)
|
||||
addr, err := netip.ParseAddr(userDefinedSrc)
|
||||
if err != nil {
|
||||
return netip.Addr{}, "", err
|
||||
}
|
||||
if addr.Is6() {
|
||||
return addr, zone, nil
|
||||
return addr, addr.Zone(), nil
|
||||
}
|
||||
return netip.Addr{}, "", fmt.Errorf("expect IPv6, but %s is IPv4", userDefinedSrc)
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package tunnel
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -15,10 +16,7 @@ import (
|
|||
|
||||
const httpTimeout = 15 * time.Second
|
||||
|
||||
const disclaimer = "Thank you for trying Cloudflare Tunnel. Doing so, without a Cloudflare account, is a quick way to" +
|
||||
" experiment and try it out. However, be aware that these account-less Tunnels have no uptime guarantee. If you " +
|
||||
"intend to use Tunnels in production you should use a pre-created named tunnel by following: " +
|
||||
"https://developers.cloudflare.com/cloudflare-one/connections/connect-apps"
|
||||
const disclaimer = "Thank you for trying Cloudflare Tunnel. Doing so, without a Cloudflare account, is a quick way to experiment and try it out. However, be aware that these account-less Tunnels have no uptime guarantee, are subject to the Cloudflare Online Services Terms of Use (https://www.cloudflare.com/website-terms/), and Cloudflare reserves the right to investigate your use of Tunnels for violations of such terms. If you intend to use Tunnels in production you should use a pre-created named tunnel by following: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps"
|
||||
|
||||
// RunQuickTunnel requests a tunnel from the specified service.
|
||||
// We use this to power quick tunnels on trycloudflare.com, but the
|
||||
|
@ -47,8 +45,17 @@ func RunQuickTunnel(sc *subcommandContext) error {
|
|||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// This will read the entire response into memory so we can print it in case of error
|
||||
rsp_body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read quick-tunnel response")
|
||||
}
|
||||
|
||||
var data QuickTunnelResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
if err := json.Unmarshal(rsp_body, &data); err != nil {
|
||||
rsp_string := string(rsp_body)
|
||||
fields := map[string]interface{}{"status_code": resp.Status}
|
||||
sc.log.Err(err).Fields(fields).Msgf("Error unmarshaling QuickTunnel response: %s", rsp_string)
|
||||
return errors.Wrap(err, "failed to unmarshal quick Tunnel")
|
||||
}
|
||||
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
|
@ -26,17 +28,27 @@ import (
|
|||
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/diagnostic"
|
||||
"github.com/cloudflare/cloudflared/metrics"
|
||||
)
|
||||
|
||||
const (
|
||||
allSortByOptions = "name, id, createdAt, deletedAt, numConnections"
|
||||
connsSortByOptions = "id, startedAt, numConnections, version"
|
||||
CredFileFlagAlias = "cred-file"
|
||||
CredFileFlag = "credentials-file"
|
||||
CredContentsFlag = "credentials-contents"
|
||||
TunnelTokenFlag = "token"
|
||||
TunnelTokenFileFlag = "token-file"
|
||||
overwriteDNSFlagName = "overwrite-dns"
|
||||
allSortByOptions = "name, id, createdAt, deletedAt, numConnections"
|
||||
connsSortByOptions = "id, startedAt, numConnections, version"
|
||||
CredFileFlagAlias = "cred-file"
|
||||
CredFileFlag = "credentials-file"
|
||||
CredContentsFlag = "credentials-contents"
|
||||
TunnelTokenFlag = "token"
|
||||
TunnelTokenFileFlag = "token-file"
|
||||
overwriteDNSFlagName = "overwrite-dns"
|
||||
noDiagLogsFlagName = "no-diag-logs"
|
||||
noDiagMetricsFlagName = "no-diag-metrics"
|
||||
noDiagSystemFlagName = "no-diag-system"
|
||||
noDiagRuntimeFlagName = "no-diag-runtime"
|
||||
noDiagNetworkFlagName = "no-diag-network"
|
||||
diagContainerIDFlagName = "diag-container-id"
|
||||
diagPodFlagName = "diag-pod-id"
|
||||
metricsFlagName = "metrics"
|
||||
|
||||
LogFieldTunnelID = "tunnelID"
|
||||
)
|
||||
|
@ -183,6 +195,46 @@ var (
|
|||
Usage: "Source address and the interface name to send/receive ICMPv6 messages. If not provided cloudflared will dial a local address to determine the source IP or fallback to ::.",
|
||||
EnvVars: []string{"TUNNEL_ICMPV6_SRC"},
|
||||
}
|
||||
metricsFlag = &cli.StringFlag{
|
||||
Name: metricsFlagName,
|
||||
Usage: "The metrics server address i.e.: 127.0.0.1:12345. If your instance is running in a Docker/Kubernetes environment you need to setup port forwarding for your application.",
|
||||
Value: "",
|
||||
}
|
||||
diagContainerFlag = &cli.StringFlag{
|
||||
Name: diagContainerIDFlagName,
|
||||
Usage: "Container ID or Name to collect logs from",
|
||||
Value: "",
|
||||
}
|
||||
diagPodFlag = &cli.StringFlag{
|
||||
Name: diagPodFlagName,
|
||||
Usage: "Kubernetes POD to collect logs from",
|
||||
Value: "",
|
||||
}
|
||||
noDiagLogsFlag = &cli.BoolFlag{
|
||||
Name: noDiagLogsFlagName,
|
||||
Usage: "Log collection will not be performed",
|
||||
Value: false,
|
||||
}
|
||||
noDiagMetricsFlag = &cli.BoolFlag{
|
||||
Name: noDiagMetricsFlagName,
|
||||
Usage: "Metric collection will not be performed",
|
||||
Value: false,
|
||||
}
|
||||
noDiagSystemFlag = &cli.BoolFlag{
|
||||
Name: noDiagSystemFlagName,
|
||||
Usage: "System information collection will not be performed",
|
||||
Value: false,
|
||||
}
|
||||
noDiagRuntimeFlag = &cli.BoolFlag{
|
||||
Name: noDiagRuntimeFlagName,
|
||||
Usage: "Runtime information collection will not be performed",
|
||||
Value: false,
|
||||
}
|
||||
noDiagNetworkFlag = &cli.BoolFlag{
|
||||
Name: noDiagNetworkFlagName,
|
||||
Usage: "Network diagnostics won't be performed",
|
||||
Value: false,
|
||||
}
|
||||
)
|
||||
|
||||
func buildCreateCommand() *cli.Command {
|
||||
|
@ -379,7 +431,6 @@ func formatAndPrintTunnelList(tunnels []*cfapi.Tunnel, showRecentlyDisconnected
|
|||
}
|
||||
|
||||
func fmtConnections(connections []cfapi.Connection, showRecentlyDisconnected bool) string {
|
||||
|
||||
// Count connections per colo
|
||||
numConnsPerColo := make(map[string]uint, len(connections))
|
||||
for _, connection := range connections {
|
||||
|
@ -403,6 +454,39 @@ func fmtConnections(connections []cfapi.Connection, showRecentlyDisconnected boo
|
|||
return strings.Join(output, ", ")
|
||||
}
|
||||
|
||||
func buildReadyCommand() *cli.Command {
|
||||
return &cli.Command{
|
||||
Name: "ready",
|
||||
Action: cliutil.ConfiguredAction(readyCommand),
|
||||
Usage: "Call /ready endpoint and return proper exit code",
|
||||
UsageText: "cloudflared tunnel [tunnel command options] ready [subcommand options]",
|
||||
Description: "cloudflared tunnel ready will return proper exit code based on the /ready endpoint",
|
||||
Flags: []cli.Flag{},
|
||||
CustomHelpTemplate: commandHelpTemplate(),
|
||||
}
|
||||
}
|
||||
|
||||
func readyCommand(c *cli.Context) error {
|
||||
metricsOpts := c.String("metrics")
|
||||
if !c.IsSet("metrics") {
|
||||
return fmt.Errorf("--metrics has to be provided")
|
||||
}
|
||||
|
||||
requestURL := fmt.Sprintf("http://%s/ready", metricsOpts)
|
||||
res, err := http.Get(requestURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("http://%s/ready endpoint returned status code %d\n%s", metricsOpts, res.StatusCode, body)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildInfoCommand() *cli.Command {
|
||||
return &cli.Command{
|
||||
Name: "info",
|
||||
|
@ -882,8 +966,10 @@ func lbRouteFromArg(c *cli.Context) (cfapi.HostnameRoute, error) {
|
|||
return cfapi.NewLBRoute(lbName, lbPool), nil
|
||||
}
|
||||
|
||||
var nameRegex = regexp.MustCompile("^[_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
|
||||
var hostNameRegex = regexp.MustCompile("^[*_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
|
||||
var (
|
||||
nameRegex = regexp.MustCompile("^[_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
|
||||
hostNameRegex = regexp.MustCompile("^[*_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
|
||||
)
|
||||
|
||||
func validateName(s string, allowWildcardSubdomain bool) bool {
|
||||
if allowWildcardSubdomain {
|
||||
|
@ -971,3 +1057,78 @@ SUBCOMMAND OPTIONS:
|
|||
`
|
||||
return fmt.Sprintf(template, parentFlagsHelp)
|
||||
}
|
||||
|
||||
func buildDiagCommand() *cli.Command {
|
||||
return &cli.Command{
|
||||
Name: "diag",
|
||||
Action: cliutil.ConfiguredAction(diagCommand),
|
||||
Usage: "Creates a diagnostic report from a local cloudflared instance",
|
||||
UsageText: "cloudflared tunnel [tunnel command options] diag [subcommand options]",
|
||||
Description: "cloudflared tunnel diag will create a diagnostic report of a local cloudflared instance. The diagnostic procedure collects: logs, metrics, system information, traceroute to Cloudflare Edge, and runtime information. Since there may be multiple instances of cloudflared running the --metrics option may be provided to target a specific instance.",
|
||||
Flags: []cli.Flag{
|
||||
metricsFlag,
|
||||
diagContainerFlag,
|
||||
diagPodFlag,
|
||||
noDiagLogsFlag,
|
||||
noDiagMetricsFlag,
|
||||
noDiagSystemFlag,
|
||||
noDiagRuntimeFlag,
|
||||
noDiagNetworkFlag,
|
||||
},
|
||||
CustomHelpTemplate: commandHelpTemplate(),
|
||||
}
|
||||
}
|
||||
|
||||
func diagCommand(ctx *cli.Context) error {
|
||||
sctx, err := newSubcommandContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log := sctx.log
|
||||
options := diagnostic.Options{
|
||||
KnownAddresses: metrics.GetMetricsKnownAddresses(metrics.Runtime),
|
||||
Address: sctx.c.String(metricsFlagName),
|
||||
ContainerID: sctx.c.String(diagContainerIDFlagName),
|
||||
PodID: sctx.c.String(diagPodFlagName),
|
||||
Toggles: diagnostic.Toggles{
|
||||
NoDiagLogs: sctx.c.Bool(noDiagLogsFlagName),
|
||||
NoDiagMetrics: sctx.c.Bool(noDiagMetricsFlagName),
|
||||
NoDiagSystem: sctx.c.Bool(noDiagSystemFlagName),
|
||||
NoDiagRuntime: sctx.c.Bool(noDiagRuntimeFlagName),
|
||||
NoDiagNetwork: sctx.c.Bool(noDiagNetworkFlagName),
|
||||
},
|
||||
}
|
||||
|
||||
if options.Address == "" {
|
||||
log.Info().Msg("If your instance is running in a Docker/Kubernetes environment you need to setup port forwarding for your application.")
|
||||
}
|
||||
|
||||
states, err := diagnostic.RunDiagnostic(log, options)
|
||||
|
||||
if errors.Is(err, diagnostic.ErrMetricsServerNotFound) {
|
||||
log.Warn().Msg("No instances found")
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, diagnostic.ErrMultipleMetricsServerFound) {
|
||||
if states != nil {
|
||||
log.Info().Msgf("Found multiple instances running:")
|
||||
for _, state := range states {
|
||||
log.Info().Msgf("Instance: tunnel-id=%s connector-id=%s metrics-address=%s", state.TunnelID, state.ConnectorID, state.URL.String())
|
||||
}
|
||||
log.Info().Msgf("To select one instance use the option --metrics")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, diagnostic.ErrLogConfigurationIsInvalid) {
|
||||
log.Info().Msg("Couldn't extract logs from the instance. If the instance is running in a containerized environment use the option --diag-container-id or --diag-pod-id. If there is no logging configuration use --no-diag-logs.")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warn().Msg("Diagnostic completed with one or more errors")
|
||||
} else {
|
||||
log.Info().Msg("Diagnostic completed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/urfave/cli/v2"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
)
|
||||
|
@ -31,7 +32,7 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
version string
|
||||
buildInfo *cliutil.BuildInfo
|
||||
BuiltForPackageManager = ""
|
||||
)
|
||||
|
||||
|
@ -81,8 +82,8 @@ func (uo *UpdateOutcome) noUpdate() bool {
|
|||
return uo.Error == nil && uo.Updated == false
|
||||
}
|
||||
|
||||
func Init(v string) {
|
||||
version = v
|
||||
func Init(info *cliutil.BuildInfo) {
|
||||
buildInfo = info
|
||||
}
|
||||
|
||||
func CheckForUpdate(options updateOptions) (CheckResult, error) {
|
||||
|
@ -100,11 +101,12 @@ func CheckForUpdate(options updateOptions) (CheckResult, error) {
|
|||
cfdPath = encodeWindowsPath(cfdPath)
|
||||
}
|
||||
|
||||
s := NewWorkersService(version, url, cfdPath, Options{IsBeta: options.isBeta,
|
||||
s := NewWorkersService(buildInfo.CloudflaredVersion, url, cfdPath, Options{IsBeta: options.isBeta,
|
||||
IsForced: options.isForced, RequestedVersion: options.intendedVersion})
|
||||
|
||||
return s.Check()
|
||||
}
|
||||
|
||||
func encodeWindowsPath(path string) string {
|
||||
// We do this because Windows allows spaces in directories such as
|
||||
// Program Files but does not allow these directories to be spaced in batch files.
|
||||
|
@ -196,10 +198,9 @@ func loggedUpdate(log *zerolog.Logger, options updateOptions) UpdateOutcome {
|
|||
|
||||
// AutoUpdater periodically checks for new version of cloudflared.
|
||||
type AutoUpdater struct {
|
||||
configurable *configurable
|
||||
listeners *gracenet.Net
|
||||
updateConfigChan chan *configurable
|
||||
log *zerolog.Logger
|
||||
configurable *configurable
|
||||
listeners *gracenet.Net
|
||||
log *zerolog.Logger
|
||||
}
|
||||
|
||||
// AutoUpdaterConfigurable is the attributes of AutoUpdater that can be reconfigured during runtime
|
||||
|
@ -210,10 +211,9 @@ type configurable struct {
|
|||
|
||||
func NewAutoUpdater(updateDisabled bool, freq time.Duration, listeners *gracenet.Net, log *zerolog.Logger) *AutoUpdater {
|
||||
return &AutoUpdater{
|
||||
configurable: createUpdateConfig(updateDisabled, freq, log),
|
||||
listeners: listeners,
|
||||
updateConfigChan: make(chan *configurable),
|
||||
log: log,
|
||||
configurable: createUpdateConfig(updateDisabled, freq, log),
|
||||
listeners: listeners,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -232,12 +232,20 @@ func createUpdateConfig(updateDisabled bool, freq time.Duration, log *zerolog.Lo
|
|||
}
|
||||
}
|
||||
|
||||
// Run will perodically check for cloudflared updates, download them, and then restart the current cloudflared process
|
||||
// to use the new version. It delays the first update check by the configured frequency as to not attempt a
|
||||
// download immediately and restart after starting (in the case that there is an upgrade available).
|
||||
func (a *AutoUpdater) Run(ctx context.Context) error {
|
||||
ticker := time.NewTicker(a.configurable.freq)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
updateOutcome := loggedUpdate(a.log, updateOptions{updateDisabled: !a.configurable.enabled})
|
||||
if updateOutcome.Updated {
|
||||
Init(updateOutcome.Version)
|
||||
buildInfo.CloudflaredVersion = updateOutcome.Version
|
||||
if IsSysV() {
|
||||
// SysV doesn't have a mechanism to keep service alive, we have to restart the process
|
||||
a.log.Info().Msg("Restarting service managed by SysV...")
|
||||
|
@ -254,25 +262,9 @@ func (a *AutoUpdater) Run(ctx context.Context) error {
|
|||
} else if updateOutcome.UserMessage != "" {
|
||||
a.log.Warn().Msg(updateOutcome.UserMessage)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case newConfigurable := <-a.updateConfigChan:
|
||||
ticker.Stop()
|
||||
a.configurable = newConfigurable
|
||||
ticker = time.NewTicker(a.configurable.freq)
|
||||
// Check if there is new version of cloudflared after receiving new AutoUpdaterConfigurable
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update is the method to pass new AutoUpdaterConfigurable to a running AutoUpdater. It is safe to be called concurrently
|
||||
func (a *AutoUpdater) Update(updateDisabled bool, newFreq time.Duration) {
|
||||
a.updateConfigChan <- createUpdateConfig(updateDisabled, newFreq, a.log)
|
||||
}
|
||||
|
||||
func isAutoupdateEnabled(log *zerolog.Logger, updateDisabled bool, updateFreq time.Duration) bool {
|
||||
if !supportAutoUpdate(log) {
|
||||
return false
|
||||
|
|
|
@ -9,8 +9,14 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
Init(cliutil.GetBuildInfo("TEST", "TEST"))
|
||||
}
|
||||
|
||||
func TestDisabledAutoUpdater(t *testing.T) {
|
||||
listeners := &gracenet.Net{}
|
||||
log := zerolog.Nop()
|
||||
|
|
|
@ -3,6 +3,7 @@ package updater
|
|||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
)
|
||||
|
@ -79,6 +80,10 @@ func (s *WorkersService) Check() (CheckResult, error) {
|
|||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("unable to check for update: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var v VersionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&v); err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -3,7 +3,6 @@ package updater
|
|||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -16,6 +15,10 @@ import (
|
|||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -27,9 +30,9 @@ const (
|
|||
// start the service
|
||||
// exit with code 0 if we've reached this point indicating success.
|
||||
windowsUpdateCommandTemplate = `sc stop cloudflared >nul 2>&1
|
||||
del "{{.OldPath}}"
|
||||
rename "{{.TargetPath}}" {{.OldName}}
|
||||
rename "{{.NewPath}}" {{.BinaryName}}
|
||||
del "{{.OldPath}}"
|
||||
sc start cloudflared >nul 2>&1
|
||||
exit /b 0`
|
||||
batchFileName = "cfd_update.bat"
|
||||
|
@ -86,8 +89,25 @@ func (v *WorkersVersion) Apply() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// check that the file is what is expected
|
||||
if err := isValidChecksum(v.checksum, newFilePath); err != nil {
|
||||
downloadSum, err := cliutil.FileChecksum(newFilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check that the file downloaded matches what is expected.
|
||||
if v.checksum != downloadSum {
|
||||
return errors.New("checksum validation failed")
|
||||
}
|
||||
|
||||
// Check if the currently running version has the same checksum
|
||||
if downloadSum == buildInfo.Checksum {
|
||||
// Currently running binary matches the downloaded binary so we have no reason to update. This is
|
||||
// typically unexpected, as such we emit a sentry event.
|
||||
localHub := sentry.CurrentHub().Clone()
|
||||
err := errors.New("checksum validation matches currently running process")
|
||||
localHub.CaptureException(err)
|
||||
// Make sure to cleanup the new downloaded file since we aren't upgrading versions.
|
||||
os.Remove(newFilePath)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -189,27 +209,6 @@ func isCompressedFile(urlstring string) bool {
|
|||
return strings.HasSuffix(u.Path, ".tgz")
|
||||
}
|
||||
|
||||
// checks if the checksum in the json response matches the checksum of the file download
|
||||
func isValidChecksum(checksum, filePath string) error {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hash := fmt.Sprintf("%x", h.Sum(nil))
|
||||
|
||||
if checksum != hash {
|
||||
return errors.New("checksum validation failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeBatchFile writes a batch file out to disk
|
||||
// see the dicussion on why it has to be done this way
|
||||
func writeBatchFile(targetPath string, newPath string, oldPath string) error {
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# Requirements
|
||||
1. Python 3.7 or later with packages in the given `requirements.txt`
|
||||
- E.g. with conda:
|
||||
- `conda create -n component-tests python=3.7`
|
||||
- `conda activate component-tests`
|
||||
- `pip3 install -r requirements.txt`
|
||||
1. Python 3.10 or later with packages in the given `requirements.txt`
|
||||
- E.g. with venv:
|
||||
- `python3 -m venv ./.venv`
|
||||
- `source ./.venv/bin/activate`
|
||||
- `python3 -m pip install -r requirements.txt`
|
||||
|
||||
2. Create a config yaml file, for example:
|
||||
```
|
||||
|
|
|
@ -45,9 +45,10 @@ class TestTermination:
|
|||
with connected:
|
||||
connected.wait(self.timeout)
|
||||
# Send signal after the SSE connection is established
|
||||
self.terminate_by_signal(cloudflared, signal)
|
||||
self.wait_eyeball_thread(
|
||||
in_flight_req, self.grace_period + self.timeout)
|
||||
with self.within_grace_period():
|
||||
self.terminate_by_signal(cloudflared, signal)
|
||||
self.wait_eyeball_thread(
|
||||
in_flight_req, self.grace_period + self.timeout)
|
||||
|
||||
# test cloudflared terminates before grace period expires when all eyeball
|
||||
# connections are drained
|
||||
|
@ -66,7 +67,7 @@ class TestTermination:
|
|||
|
||||
with connected:
|
||||
connected.wait(self.timeout)
|
||||
with self.within_grace_period():
|
||||
with self.within_grace_period(has_connection=False):
|
||||
# Send signal after the SSE connection is established
|
||||
self.terminate_by_signal(cloudflared, signal)
|
||||
self.wait_eyeball_thread(in_flight_req, self.grace_period)
|
||||
|
@ -78,7 +79,7 @@ class TestTermination:
|
|||
with start_cloudflared(
|
||||
tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], new_process=True, capture_output=False) as cloudflared:
|
||||
wait_tunnel_ready(tunnel_url=config.get_url())
|
||||
with self.within_grace_period():
|
||||
with self.within_grace_period(has_connection=False):
|
||||
self.terminate_by_signal(cloudflared, signal)
|
||||
|
||||
def terminate_by_signal(self, cloudflared, sig):
|
||||
|
@ -92,13 +93,21 @@ class TestTermination:
|
|||
|
||||
# Using this context asserts logic within the context is executed within grace period
|
||||
@contextmanager
|
||||
def within_grace_period(self):
|
||||
def within_grace_period(self, has_connection=True):
|
||||
try:
|
||||
start = time.time()
|
||||
yield
|
||||
finally:
|
||||
|
||||
# If the request takes longer than the grace period then we need to wait at most the grace period.
|
||||
# If the request fell within the grace period cloudflared can close earlier, but to ensure that it doesn't
|
||||
# close immediately we add a minimum boundary. If cloudflared shutdown in less than 1s it's likely that
|
||||
# it shutdown as soon as it received SIGINT. The only way cloudflared can close immediately is if it has no
|
||||
# in-flight requests
|
||||
minimum = 1 if has_connection else 0
|
||||
duration = time.time() - start
|
||||
assert duration < self.grace_period
|
||||
# Here we truncate to ensure that we don't fail on minute differences like 10.1 instead of 10
|
||||
assert minimum <= int(duration) <= self.grace_period
|
||||
|
||||
def stream_request(self, config, connected, early_terminate):
|
||||
expected_terminate_message = "502 Bad Gateway"
|
||||
|
|
|
@ -36,6 +36,13 @@ var (
|
|||
flushableContentTypes = []string{sseContentType, grpcContentType}
|
||||
)
|
||||
|
||||
// TunnelConnection represents the connection to the edge.
|
||||
// The Serve method is provided to allow clients to handle any errors from the connection encountered during
|
||||
// processing of the connection. Cancelling of the context provided to Serve will close the connection.
|
||||
type TunnelConnection interface {
|
||||
Serve(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Orchestrator interface {
|
||||
UpdateConfig(version int32, config []byte) *pogs.UpdateConfigurationResponse
|
||||
GetConfigJSON() ([]byte, error)
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
|
@ -100,7 +102,7 @@ func (c *controlStream) ServeControlStream(
|
|||
c.observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
|
||||
|
||||
c.observer.logConnected(registrationDetails.UUID, c.connIndex, registrationDetails.Location, c.edgeAddress, c.protocol)
|
||||
c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location)
|
||||
c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location, c.edgeAddress)
|
||||
c.connectedFuse.Connected()
|
||||
|
||||
// if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration
|
||||
|
@ -116,27 +118,32 @@ func (c *controlStream) ServeControlStream(
|
|||
}
|
||||
}
|
||||
|
||||
c.waitForUnregister(ctx, registrationClient)
|
||||
return nil
|
||||
return c.waitForUnregister(ctx, registrationClient)
|
||||
}
|
||||
|
||||
func (c *controlStream) waitForUnregister(ctx context.Context, registrationClient tunnelrpc.RegistrationClient) {
|
||||
func (c *controlStream) waitForUnregister(ctx context.Context, registrationClient tunnelrpc.RegistrationClient) error {
|
||||
// wait for connection termination or start of graceful shutdown
|
||||
defer registrationClient.Close()
|
||||
var shutdownError error
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
shutdownError = ctx.Err()
|
||||
break
|
||||
case <-c.gracefulShutdownC:
|
||||
c.stoppedGracefully = true
|
||||
}
|
||||
|
||||
c.observer.sendUnregisteringEvent(c.connIndex)
|
||||
registrationClient.GracefulShutdown(ctx, c.gracePeriod)
|
||||
err := registrationClient.GracefulShutdown(ctx, c.gracePeriod)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Error shutting down control stream")
|
||||
}
|
||||
c.observer.log.Info().
|
||||
Int(management.EventTypeKey, int(management.Cloudflared)).
|
||||
Uint8(LogFieldConnIndex, c.connIndex).
|
||||
IPAddr(LogFieldIPAddress, c.edgeAddress).
|
||||
Msg("Unregistered tunnel connection")
|
||||
return shutdownError
|
||||
}
|
||||
|
||||
func (c *controlStream) IsStopped() bool {
|
||||
|
|
|
@ -2,7 +2,6 @@ package connection
|
|||
|
||||
import (
|
||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
|
@ -71,8 +70,6 @@ func isHandshakeErrRecoverable(err error, connIndex uint8, observer *Observer) b
|
|||
switch err.(type) {
|
||||
case edgediscovery.DialError:
|
||||
log.Error().Msg("Connection unable to dial edge")
|
||||
case h2mux.MuxerHandshakeError:
|
||||
log.Error().Msg("Connection handshake with edge server failed")
|
||||
default:
|
||||
log.Error().Msg("Connection failed")
|
||||
return false
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
package connection
|
||||
|
||||
import "net"
|
||||
|
||||
// Event is something that happened to a connection, e.g. disconnection or registration.
|
||||
type Event struct {
|
||||
Index uint8
|
||||
EventType Status
|
||||
Location string
|
||||
Protocol Protocol
|
||||
URL string
|
||||
Index uint8
|
||||
EventType Status
|
||||
Location string
|
||||
Protocol Protocol
|
||||
URL string
|
||||
EdgeAddress net.IP
|
||||
}
|
||||
|
||||
// Status is the status of a connection.
|
||||
|
|
|
@ -1,32 +0,0 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
)
|
||||
|
||||
const (
|
||||
muxerTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type MuxerConfig struct {
|
||||
HeartbeatInterval time.Duration
|
||||
MaxHeartbeats uint64
|
||||
CompressionSetting h2mux.CompressionSetting
|
||||
MetricsUpdateFreq time.Duration
|
||||
}
|
||||
|
||||
func (mc *MuxerConfig) H2MuxerConfig(h h2mux.MuxedStreamHandler, log *zerolog.Logger) *h2mux.MuxerConfig {
|
||||
return &h2mux.MuxerConfig{
|
||||
Timeout: muxerTimeout,
|
||||
Handler: h,
|
||||
IsClient: true,
|
||||
HeartbeatInterval: mc.HeartbeatInterval,
|
||||
MaxHeartbeats: mc.MaxHeartbeats,
|
||||
Log: log,
|
||||
CompressionQuality: mc.CompressionSetting,
|
||||
}
|
||||
}
|
|
@ -1,128 +0,0 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
)
|
||||
|
||||
// H2RequestHeadersToH1Request converts the HTTP/2 headers coming from origintunneld
|
||||
// to an HTTP/1 Request object destined for the local origin web service.
|
||||
// This operation includes conversion of the pseudo-headers into their closest
|
||||
// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3
|
||||
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
|
||||
for _, header := range h2 {
|
||||
name := strings.ToLower(header.Name)
|
||||
if !IsH2muxControlRequestHeader(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch name {
|
||||
case ":method":
|
||||
h1.Method = header.Value
|
||||
case ":scheme":
|
||||
// noop - use the preexisting scheme from h1.URL
|
||||
case ":authority":
|
||||
// Otherwise the host header will be based on the origin URL
|
||||
h1.Host = header.Value
|
||||
case ":path":
|
||||
// We don't want to be an "opinionated" proxy, so ideally we would use :path as-is.
|
||||
// However, this HTTP/1 Request object belongs to the Go standard library,
|
||||
// whose URL package makes some opinionated decisions about the encoding of
|
||||
// URL characters: see the docs of https://godoc.org/net/url#URL,
|
||||
// in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath,
|
||||
// which is always used when computing url.URL.String(), whether we'd like it or not.
|
||||
//
|
||||
// Well, not *always*. We could circumvent this by using url.URL.Opaque. But
|
||||
// that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque
|
||||
// is treated differently when HTTP_PROXY is set!
|
||||
// See https://github.com/golang/go/issues/5684#issuecomment-66080888
|
||||
//
|
||||
// This means we are subject to the behavior of net/url's function `shouldEscape`
|
||||
// (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101
|
||||
|
||||
if header.Value == "*" {
|
||||
h1.URL.Path = "*"
|
||||
continue
|
||||
}
|
||||
// Due to the behavior of validation.ValidateUrl, h1.URL may
|
||||
// already have a partial value, with or without a trailing slash.
|
||||
base := h1.URL.String()
|
||||
base = strings.TrimRight(base, "/")
|
||||
// But we know :path begins with '/', because we handled '*' above - see RFC7540
|
||||
requestURL, err := url.Parse(base + header.Value)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value))
|
||||
}
|
||||
h1.URL = requestURL
|
||||
case "content-length":
|
||||
contentLength, err := strconv.ParseInt(header.Value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unparseable content length")
|
||||
}
|
||||
h1.ContentLength = contentLength
|
||||
case RequestUserHeaders:
|
||||
// Do not forward the serialized headers to the origin -- deserialize them, and ditch the serialized version
|
||||
// Find and parse user headers serialized into a single one
|
||||
userHeaders, err := DeserializeHeaders(header.Value)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Unable to parse user headers")
|
||||
}
|
||||
for _, userHeader := range userHeaders {
|
||||
h1.Header.Add(userHeader.Name, userHeader.Value)
|
||||
}
|
||||
default:
|
||||
// All other control headers shall just be proxied transparently
|
||||
h1.Header.Add(header.Name, header.Value)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []h2mux.Header) {
|
||||
h2 = []h2mux.Header{
|
||||
{Name: ":status", Value: strconv.Itoa(status)},
|
||||
}
|
||||
userHeaders := make(http.Header, len(h1))
|
||||
for header, values := range h1 {
|
||||
h2name := strings.ToLower(header)
|
||||
if h2name == "content-length" {
|
||||
// This header has meaning in HTTP/2 and will be used by the edge,
|
||||
// so it should be sent as an HTTP/2 response header.
|
||||
|
||||
// Since these are http2 headers, they're required to be lowercase
|
||||
h2 = append(h2, h2mux.Header{Name: "content-length", Value: values[0]})
|
||||
} else if !IsH2muxControlResponseHeader(h2name) || IsWebsocketClientHeader(h2name) {
|
||||
// User headers, on the other hand, must all be serialized so that
|
||||
// HTTP/2 header validation won't be applied to HTTP/1 header values
|
||||
userHeaders[header] = values
|
||||
}
|
||||
}
|
||||
|
||||
// Perform user header serialization and set them in the single header
|
||||
h2 = append(h2, h2mux.Header{Name: ResponseUserHeaders, Value: SerializeHeaders(userHeaders)})
|
||||
return h2
|
||||
}
|
||||
|
||||
// IsH2muxControlRequestHeader is called in the direction of eyeball -> origin.
|
||||
func IsH2muxControlRequestHeader(headerName string) bool {
|
||||
return headerName == "content-length" ||
|
||||
headerName == "connection" || headerName == "upgrade" || // Websocket request headers
|
||||
strings.HasPrefix(headerName, ":") ||
|
||||
strings.HasPrefix(headerName, "cf-")
|
||||
}
|
||||
|
||||
// IsH2muxControlResponseHeader is called in the direction of eyeball <- origin.
|
||||
func IsH2muxControlResponseHeader(headerName string) bool {
|
||||
return headerName == "content-length" ||
|
||||
strings.HasPrefix(headerName, ":") ||
|
||||
strings.HasPrefix(headerName, "cf-int-") ||
|
||||
strings.HasPrefix(headerName, "cf-cloudflared-")
|
||||
}
|
|
@ -1,642 +0,0 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
)
|
||||
|
||||
type ByName []h2mux.Header
|
||||
|
||||
func (a ByName) Len() int { return len(a) }
|
||||
func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a ByName) Less(i, j int) bool {
|
||||
if a[i].Name == a[j].Name {
|
||||
return a[i].Value < a[j].Value
|
||||
}
|
||||
|
||||
return a[i].Name < a[j].Name
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockHeaders := http.Header{
|
||||
"Mock header 1": {"Mock value 1"},
|
||||
"Mock header 2": {"Mock value 2"},
|
||||
}
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(createSerializedHeaders(RequestUserHeaders, mockHeaders), request)
|
||||
|
||||
assert.True(t, reflect.DeepEqual(mockHeaders, request.Header))
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func createSerializedHeaders(headersField string, headers http.Header) []h2mux.Header {
|
||||
return []h2mux.Header{{
|
||||
Name: headersField,
|
||||
Value: SerializeHeaders(headers),
|
||||
}}
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
emptyHeaders := make(http.Header)
|
||||
headersConversionErr := H2RequestHeadersToH1Request(
|
||||
[]h2mux.Header{{
|
||||
Name: RequestUserHeaders,
|
||||
Value: SerializeHeaders(emptyHeaders),
|
||||
}},
|
||||
request,
|
||||
)
|
||||
|
||||
assert.True(t, reflect.DeepEqual(emptyHeaders, request.Header))
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockRequestHeaders := []h2mux.Header{
|
||||
{Name: ":path", Value: "//bad_path/"},
|
||||
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
||||
}
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
||||
|
||||
assert.Equal(t, http.Header{
|
||||
"Mock header": []string{"Mock value"},
|
||||
}, request.Header)
|
||||
|
||||
assert.Equal(t, "http://example.com//bad_path/", request.URL.String())
|
||||
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockRequestHeaders := []h2mux.Header{
|
||||
{Name: ":path", Value: "/?query=mock%20value"},
|
||||
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
||||
}
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
||||
|
||||
assert.Equal(t, http.Header{
|
||||
"Mock header": []string{"Mock value"},
|
||||
}, request.Header)
|
||||
|
||||
assert.Equal(t, "http://example.com/?query=mock%20value", request.URL.String())
|
||||
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockRequestHeaders := []h2mux.Header{
|
||||
{Name: ":path", Value: "/mock%20path"},
|
||||
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
||||
}
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
||||
|
||||
assert.Equal(t, http.Header{
|
||||
"Mock header": []string{"Mock value"},
|
||||
}, request.Header)
|
||||
|
||||
assert.Equal(t, "http://example.com/mock%20path", request.URL.String())
|
||||
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) {
|
||||
type testCase struct {
|
||||
path string
|
||||
want string
|
||||
}
|
||||
testCases := []testCase{
|
||||
{
|
||||
path: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
path: "/",
|
||||
want: "/",
|
||||
},
|
||||
{
|
||||
path: "//",
|
||||
want: "//",
|
||||
},
|
||||
{
|
||||
path: "/test",
|
||||
want: "/test",
|
||||
},
|
||||
{
|
||||
path: "//test",
|
||||
want: "//test",
|
||||
},
|
||||
{
|
||||
// https://github.com/cloudflare/cloudflared/issues/81
|
||||
path: "//test/",
|
||||
want: "//test/",
|
||||
},
|
||||
{
|
||||
path: "/%2Ftest",
|
||||
want: "/%2Ftest",
|
||||
},
|
||||
{
|
||||
path: "//%20test",
|
||||
want: "//%20test",
|
||||
},
|
||||
{
|
||||
// https://github.com/cloudflare/cloudflared/issues/124
|
||||
path: "/test?get=somthing%20a",
|
||||
want: "/test?get=somthing%20a",
|
||||
},
|
||||
{
|
||||
path: "/%20",
|
||||
want: "/%20",
|
||||
},
|
||||
{
|
||||
// stdlib's EscapedPath() will always percent-encode ' '
|
||||
path: "/ ",
|
||||
want: "/%20",
|
||||
},
|
||||
{
|
||||
path: "/ a ",
|
||||
want: "/%20a%20",
|
||||
},
|
||||
{
|
||||
path: "/a%20b",
|
||||
want: "/a%20b",
|
||||
},
|
||||
{
|
||||
path: "/foo/bar;param?query#frag",
|
||||
want: "/foo/bar;param?query#frag",
|
||||
},
|
||||
{
|
||||
// stdlib's EscapedPath() will always percent-encode non-ASCII chars
|
||||
path: "/a␠b",
|
||||
want: "/a%E2%90%A0b",
|
||||
},
|
||||
{
|
||||
path: "/a-umlaut-ä",
|
||||
want: "/a-umlaut-%C3%A4",
|
||||
},
|
||||
{
|
||||
path: "/a-umlaut-%C3%A4",
|
||||
want: "/a-umlaut-%C3%A4",
|
||||
},
|
||||
{
|
||||
path: "/a-umlaut-%c3%a4",
|
||||
want: "/a-umlaut-%c3%a4",
|
||||
},
|
||||
{
|
||||
// here the second '#' is treated as part of the fragment
|
||||
path: "/a#b#c",
|
||||
want: "/a#b%23c",
|
||||
},
|
||||
{
|
||||
path: "/a#b␠c",
|
||||
want: "/a#b%E2%90%A0c",
|
||||
},
|
||||
{
|
||||
path: "/a#b%20c",
|
||||
want: "/a#b%20c",
|
||||
},
|
||||
{
|
||||
path: "/a#b c",
|
||||
want: "/a#b%20c",
|
||||
},
|
||||
{
|
||||
// stdlib's EscapedPath() will always percent-encode '\'
|
||||
path: "/\\",
|
||||
want: "/%5C",
|
||||
},
|
||||
{
|
||||
path: "/a\\",
|
||||
want: "/a%5C",
|
||||
},
|
||||
{
|
||||
path: "/a,b.c.",
|
||||
want: "/a,b.c.",
|
||||
},
|
||||
{
|
||||
path: "/.",
|
||||
want: "/.",
|
||||
},
|
||||
{
|
||||
// stdlib's EscapedPath() will always percent-encode '`'
|
||||
path: "/a`",
|
||||
want: "/a%60",
|
||||
},
|
||||
{
|
||||
path: "/a[0]",
|
||||
want: "/a[0]",
|
||||
},
|
||||
{
|
||||
path: "/?a[0]=5 &b[]=",
|
||||
want: "/?a[0]=5 &b[]=",
|
||||
},
|
||||
{
|
||||
path: "/?a=%22b%20%22",
|
||||
want: "/?a=%22b%20%22",
|
||||
},
|
||||
}
|
||||
|
||||
for index, testCase := range testCases {
|
||||
requestURL := "https://example.com"
|
||||
|
||||
request, err := http.NewRequest(http.MethodGet, requestURL, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockRequestHeaders := []h2mux.Header{
|
||||
{Name: ":path", Value: testCase.path},
|
||||
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
||||
}
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
||||
assert.NoError(t, headersConversionErr)
|
||||
|
||||
assert.Equal(t,
|
||||
http.Header{
|
||||
"Mock header": []string{"Mock value"},
|
||||
},
|
||||
request.Header)
|
||||
|
||||
assert.Equal(t,
|
||||
"https://example.com"+testCase.want,
|
||||
request.URL.String(),
|
||||
"Failed URL index: %v %#v", index, testCase)
|
||||
}
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) {
|
||||
config := &quick.Config{
|
||||
Values: func(args []reflect.Value, rand *rand.Rand) {
|
||||
args[0] = reflect.ValueOf(randomHTTP2Path(t, rand))
|
||||
},
|
||||
}
|
||||
|
||||
type testOrigin struct {
|
||||
url string
|
||||
|
||||
expectedScheme string
|
||||
expectedBasePath string
|
||||
}
|
||||
testOrigins := []testOrigin{
|
||||
{
|
||||
url: "http://origin.hostname.example.com:8080",
|
||||
expectedScheme: "http",
|
||||
expectedBasePath: "http://origin.hostname.example.com:8080",
|
||||
},
|
||||
{
|
||||
url: "http://origin.hostname.example.com:8080/",
|
||||
expectedScheme: "http",
|
||||
expectedBasePath: "http://origin.hostname.example.com:8080",
|
||||
},
|
||||
{
|
||||
url: "http://origin.hostname.example.com:8080/api",
|
||||
expectedScheme: "http",
|
||||
expectedBasePath: "http://origin.hostname.example.com:8080/api",
|
||||
},
|
||||
{
|
||||
url: "http://origin.hostname.example.com:8080/api/",
|
||||
expectedScheme: "http",
|
||||
expectedBasePath: "http://origin.hostname.example.com:8080/api",
|
||||
},
|
||||
{
|
||||
url: "https://origin.hostname.example.com:8080/api",
|
||||
expectedScheme: "https",
|
||||
expectedBasePath: "https://origin.hostname.example.com:8080/api",
|
||||
},
|
||||
}
|
||||
|
||||
// use multiple schemes to demonstrate that the URL is based on the
|
||||
// origin's scheme, not the :scheme header
|
||||
for _, testScheme := range []string{"http", "https"} {
|
||||
for _, testOrigin := range testOrigins {
|
||||
assertion := func(testPath string) bool {
|
||||
const expectedMethod = "POST"
|
||||
const expectedHostname = "request.hostname.example.com"
|
||||
|
||||
h2 := []h2mux.Header{
|
||||
{Name: ":method", Value: expectedMethod},
|
||||
{Name: ":scheme", Value: testScheme},
|
||||
{Name: ":authority", Value: expectedHostname},
|
||||
{Name: ":path", Value: testPath},
|
||||
{Name: RequestUserHeaders, Value: ""},
|
||||
}
|
||||
h1, err := http.NewRequest("GET", testOrigin.url, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = H2RequestHeadersToH1Request(h2, h1)
|
||||
return assert.NoError(t, err) &&
|
||||
assert.Equal(t, expectedMethod, h1.Method) &&
|
||||
assert.Equal(t, expectedHostname, h1.Host) &&
|
||||
assert.Equal(t, testOrigin.expectedScheme, h1.URL.Scheme) &&
|
||||
assert.Equal(t, testOrigin.expectedBasePath+testPath, h1.URL.String())
|
||||
}
|
||||
err := quick.Check(assertion, config)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func randomASCIIPrintableChar(rand *rand.Rand) int {
|
||||
// smallest printable ASCII char is 32, largest is 126
|
||||
const startPrintable = 32
|
||||
const endPrintable = 127
|
||||
return startPrintable + rand.Intn(endPrintable-startPrintable)
|
||||
}
|
||||
|
||||
// randomASCIIText generates an ASCII string, some of whose characters may be
|
||||
// percent-encoded. Its "logical length" (ignoring percent-encoding) is
|
||||
// between 1 and `maxLength`.
|
||||
func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string {
|
||||
length := minLength + rand.Intn(maxLength)
|
||||
var result strings.Builder
|
||||
for i := 0; i < length; i++ {
|
||||
c := randomASCIIPrintableChar(rand)
|
||||
|
||||
// 1/4 chance of using percent encoding when not necessary
|
||||
if c == '%' || rand.Intn(4) == 0 {
|
||||
result.WriteString(fmt.Sprintf("%%%02X", c))
|
||||
} else {
|
||||
result.WriteByte(byte(c))
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// Calls `randomASCIIText` and ensures the result is a valid URL path,
|
||||
// i.e. one that can pass unchanged through url.URL.String()
|
||||
func randomHTTP1Path(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
|
||||
text := randomASCIIText(rand, minLength, maxLength)
|
||||
re, err := regexp.Compile("[^/;,]*")
|
||||
require.NoError(t, err)
|
||||
return "/" + re.ReplaceAllStringFunc(text, url.PathEscape)
|
||||
}
|
||||
|
||||
// Calls `randomASCIIText` and ensures the result is a valid URL query,
|
||||
// i.e. one that can pass unchanged through url.URL.String()
|
||||
func randomHTTP1Query(rand *rand.Rand, minLength int, maxLength int) string {
|
||||
text := randomASCIIText(rand, minLength, maxLength)
|
||||
return "?" + strings.ReplaceAll(text, "#", "%23")
|
||||
}
|
||||
|
||||
// Calls `randomASCIIText` and ensures the result is a valid URL fragment,
|
||||
// i.e. one that can pass unchanged through url.URL.String()
|
||||
func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
|
||||
text := randomASCIIText(rand, minLength, maxLength)
|
||||
u, err := url.Parse("#" + text)
|
||||
require.NoError(t, err)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// Assemble a random :path pseudoheader that is legal by Go stdlib standards
|
||||
// (i.e. all characters will satisfy "net/url".shouldEscape for their respective locations)
|
||||
func randomHTTP2Path(t *testing.T, rand *rand.Rand) string {
|
||||
result := randomHTTP1Path(t, rand, 1, 64)
|
||||
if rand.Intn(2) == 1 {
|
||||
result += randomHTTP1Query(rand, 1, 32)
|
||||
}
|
||||
if rand.Intn(2) == 1 {
|
||||
result += randomHTTP1Fragment(t, rand, 1, 16)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func stdlibHeaderToH2muxHeader(headers http.Header) (h2muxHeaders []h2mux.Header) {
|
||||
for name, values := range headers {
|
||||
for _, value := range values {
|
||||
h2muxHeaders = append(h2muxHeaders, h2mux.Header{Name: name, Value: value})
|
||||
}
|
||||
}
|
||||
|
||||
return h2muxHeaders
|
||||
}
|
||||
|
||||
func TestParseRequestHeaders(t *testing.T) {
|
||||
mockUserHeadersToSerialize := http.Header{
|
||||
"Mock-Header-One": {"1", "1.5"},
|
||||
"Mock-Header-Two": {"2"},
|
||||
"Mock-Header-Three": {"3"},
|
||||
}
|
||||
|
||||
mockHeaders := []h2mux.Header{
|
||||
{Name: "One", Value: "1"}, // will be dropped
|
||||
{Name: "Cf-Two", Value: "cf-value-1"},
|
||||
{Name: "Cf-Two", Value: "cf-value-2"},
|
||||
{Name: RequestUserHeaders, Value: SerializeHeaders(mockUserHeadersToSerialize)},
|
||||
}
|
||||
|
||||
expectedHeaders := []h2mux.Header{
|
||||
{Name: "Cf-Two", Value: "cf-value-1"},
|
||||
{Name: "Cf-Two", Value: "cf-value-2"},
|
||||
{Name: "Mock-Header-One", Value: "1"},
|
||||
{Name: "Mock-Header-One", Value: "1.5"},
|
||||
{Name: "Mock-Header-Two", Value: "2"},
|
||||
{Name: "Mock-Header-Three", Value: "3"},
|
||||
}
|
||||
h1 := &http.Request{
|
||||
Header: make(http.Header),
|
||||
}
|
||||
err := H2RequestHeadersToH1Request(mockHeaders, h1)
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, expectedHeaders, stdlibHeaderToH2muxHeader(h1.Header))
|
||||
}
|
||||
|
||||
func TestIsH2muxControlRequestHeader(t *testing.T) {
|
||||
controlRequestHeaders := []string{
|
||||
// Anything that begins with cf-
|
||||
"cf-sample-header",
|
||||
|
||||
// Any http2 pseudoheader
|
||||
":sample-pseudo-header",
|
||||
|
||||
// content-length is a special case, it has to be there
|
||||
// for some requests to work (per the HTTP2 spec)
|
||||
"content-length",
|
||||
|
||||
// Websocket request headers
|
||||
"connection",
|
||||
"upgrade",
|
||||
}
|
||||
|
||||
for _, header := range controlRequestHeaders {
|
||||
assert.True(t, IsH2muxControlRequestHeader(header))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsH2muxControlResponseHeader(t *testing.T) {
|
||||
controlResponseHeaders := []string{
|
||||
// Anything that begins with cf-int- or cf-cloudflared-
|
||||
"cf-int-sample-header",
|
||||
"cf-cloudflared-sample-header",
|
||||
|
||||
// Any http2 pseudoheader
|
||||
":sample-pseudo-header",
|
||||
|
||||
// content-length is a special case, it has to be there
|
||||
// for some requests to work (per the HTTP2 spec)
|
||||
"content-length",
|
||||
}
|
||||
|
||||
for _, header := range controlResponseHeaders {
|
||||
assert.True(t, IsH2muxControlResponseHeader(header))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNotH2muxControlRequestHeader(t *testing.T) {
|
||||
notControlRequestHeaders := []string{
|
||||
"mock-header",
|
||||
"another-sample-header",
|
||||
}
|
||||
|
||||
for _, header := range notControlRequestHeaders {
|
||||
assert.False(t, IsH2muxControlRequestHeader(header))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNotH2muxControlResponseHeader(t *testing.T) {
|
||||
notControlResponseHeaders := []string{
|
||||
"mock-header",
|
||||
"another-sample-header",
|
||||
"upgrade",
|
||||
"connection",
|
||||
"cf-whatever", // On the response path, we only want to filter cf-int- and cf-cloudflared-
|
||||
}
|
||||
|
||||
for _, header := range notControlResponseHeaders {
|
||||
assert.False(t, IsH2muxControlResponseHeader(header))
|
||||
}
|
||||
}
|
||||
|
||||
func TestH1ResponseToH2ResponseHeaders(t *testing.T) {
|
||||
mockHeaders := http.Header{
|
||||
"User-header-one": {""},
|
||||
"User-header-two": {"1", "2"},
|
||||
"cf-header": {"cf-value"},
|
||||
"cf-int-header": {"cf-int-value"},
|
||||
"cf-cloudflared-header": {"cf-cloudflared-value"},
|
||||
"Content-Length": {"123"},
|
||||
}
|
||||
mockResponse := http.Response{
|
||||
StatusCode: 200,
|
||||
Header: mockHeaders,
|
||||
}
|
||||
|
||||
headers := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header)
|
||||
|
||||
serializedHeadersIndex := -1
|
||||
for i, header := range headers {
|
||||
if header.Name == ResponseUserHeaders {
|
||||
serializedHeadersIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEqual(t, -1, serializedHeadersIndex)
|
||||
actualControlHeaders := append(
|
||||
headers[:serializedHeadersIndex],
|
||||
headers[serializedHeadersIndex+1:]...,
|
||||
)
|
||||
expectedControlHeaders := []h2mux.Header{
|
||||
{Name: ":status", Value: "200"},
|
||||
{Name: "content-length", Value: "123"},
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, expectedControlHeaders, actualControlHeaders)
|
||||
|
||||
actualUserHeaders, err := DeserializeHeaders(headers[serializedHeadersIndex].Value)
|
||||
expectedUserHeaders := []h2mux.Header{
|
||||
{Name: "User-header-one", Value: ""},
|
||||
{Name: "User-header-two", Value: "1"},
|
||||
{Name: "User-header-two", Value: "2"},
|
||||
{Name: "cf-header", Value: "cf-value"},
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, expectedUserHeaders, actualUserHeaders)
|
||||
}
|
||||
|
||||
// The purpose of this test is to check that our code and the http.Header
|
||||
// implementation don't throw validation errors about header size
|
||||
func TestHeaderSize(t *testing.T) {
|
||||
largeValue := randSeq(5 * 1024 * 1024) // 5Mb
|
||||
largeHeaders := http.Header{
|
||||
"User-header": {largeValue},
|
||||
}
|
||||
mockResponse := http.Response{
|
||||
StatusCode: 200,
|
||||
Header: largeHeaders,
|
||||
}
|
||||
|
||||
serializedHeaders := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header)
|
||||
request, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
assert.NoError(t, err)
|
||||
for _, header := range serializedHeaders {
|
||||
request.Header.Set(header.Name, header.Value)
|
||||
}
|
||||
|
||||
for _, header := range serializedHeaders {
|
||||
if header.Name != ResponseUserHeaders {
|
||||
continue
|
||||
}
|
||||
|
||||
deserializedHeaders, err := DeserializeHeaders(header.Value)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, largeValue, deserializedHeaders[0].Value)
|
||||
}
|
||||
}
|
||||
|
||||
func randSeq(n int) string {
|
||||
randomizer := rand.New(rand.NewSource(17))
|
||||
var letters = []rune(":;,+/=abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
b := make([]rune, n)
|
||||
for i := range b {
|
||||
b[i] = letters[randomizer.Intn(len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) {
|
||||
ser := "eC1mb3J3YXJkZWQtcHJvdG8:aHR0cHM;dXBncmFkZS1pbnNlY3VyZS1yZXF1ZXN0cw:MQ;YWNjZXB0LWxhbmd1YWdl:ZW4tVVMsZW47cT0wLjkscnU7cT0wLjg;YWNjZXB0LWVuY29kaW5n:Z3ppcA;eC1mb3J3YXJkZWQtZm9y:MTczLjI0NS42MC42;dXNlci1hZ2VudA:TW96aWxsYS81LjAgKE1hY2ludG9zaDsgSW50ZWwgTWFjIE9TIFggMTBfMTRfNikgQXBwbGVXZWJLaXQvNTM3LjM2IChLSFRNTCwgbGlrZSBHZWNrbykgQ2hyb21lLzg0LjAuNDE0Ny44OSBTYWZhcmkvNTM3LjM2;c2VjLWZldGNoLW1vZGU:bmF2aWdhdGU;Y2RuLWxvb3A:Y2xvdWRmbGFyZQ;c2VjLWZldGNoLWRlc3Q:ZG9jdW1lbnQ;c2VjLWZldGNoLXVzZXI:PzE;c2VjLWZldGNoLXNpdGU:bm9uZQ;Y29va2ll:X19jZmR1aWQ9ZGNkOWZjOGNjNWMxMzE0NTMyYTFkMjhlZDEyOWRhOTYwMTU2OTk1MTYzNDsgX19jZl9ibT1mYzY2MzMzYzAzZmM0MWFiZTZmOWEyYzI2ZDUwOTA0YzIxYzZhMTQ2LTE1OTU2MjIzNDEtMTgwMC1BZTVzS2pIU2NiWGVFM05mMUhrTlNQMG1tMHBLc2pQWkloVnM1Z2g1SkNHQkFhS1UxVDB2b003alBGN3FjMHVSR2NjZGcrWHdhL1EzbTJhQzdDVU4xZ2M9;YWNjZXB0:dGV4dC9odG1sLGFwcGxpY2F0aW9uL3hodG1sK3htbCxhcHBsaWNhdGlvbi94bWw7cT0wLjksaW1hZ2Uvd2VicCxpbWFnZS9hcG5nLCovKjtxPTAuOCxhcHBsaWNhdGlvbi9zaWduZWQtZXhjaGFuZ2U7dj1iMztxPTAuOQ"
|
||||
h2, _ := DeserializeHeaders(ser)
|
||||
h1 := make(http.Header)
|
||||
for _, header := range h2 {
|
||||
h1.Add(header.Name, header.Value)
|
||||
}
|
||||
h1.Add("Content-Length", "200")
|
||||
h1.Add("Cf-Something", "Else")
|
||||
h1.Add("Upgrade", "websocket")
|
||||
|
||||
h1resp := &http.Response{
|
||||
StatusCode: 200,
|
||||
Header: h1,
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = H1ResponseToH2ResponseHeaders(h1resp.StatusCode, h1resp.Header)
|
||||
}
|
||||
}
|
|
@ -7,17 +7,15 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
)
|
||||
|
||||
var (
|
||||
// h2mux-style special headers
|
||||
// internal special headers
|
||||
RequestUserHeaders = "cf-cloudflared-request-headers"
|
||||
ResponseUserHeaders = "cf-cloudflared-response-headers"
|
||||
ResponseMetaHeader = "cf-cloudflared-response-meta"
|
||||
|
||||
// h2mux-style special headers
|
||||
// internal special headers
|
||||
CanonicalResponseUserHeaders = http.CanonicalHeaderKey(ResponseUserHeaders)
|
||||
CanonicalResponseMetaHeader = http.CanonicalHeaderKey(ResponseMetaHeader)
|
||||
)
|
||||
|
@ -28,6 +26,13 @@ var (
|
|||
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
|
||||
)
|
||||
|
||||
// HTTPHeader is a custom header struct that expects only ever one value for the header.
|
||||
// This structure is used to serialize the headers and attach them to the HTTP2 request when proxying.
|
||||
type HTTPHeader struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
type responseMetaHeader struct {
|
||||
Source string `json:"src"`
|
||||
}
|
||||
|
@ -104,10 +109,10 @@ func SerializeHeaders(h1Headers http.Header) string {
|
|||
}
|
||||
|
||||
// Deserialize headers serialized by `SerializeHeader`
|
||||
func DeserializeHeaders(serializedHeaders string) ([]h2mux.Header, error) {
|
||||
func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) {
|
||||
const unableToDeserializeErr = "Unable to deserialize headers"
|
||||
|
||||
var deserialized []h2mux.Header
|
||||
var deserialized []HTTPHeader
|
||||
for _, serializedPair := range strings.Split(serializedHeaders, ";") {
|
||||
if len(serializedPair) == 0 {
|
||||
continue
|
||||
|
@ -130,7 +135,7 @@ func DeserializeHeaders(serializedHeaders string) ([]h2mux.Header, error) {
|
|||
return nil, errors.Wrap(err, unableToDeserializeErr)
|
||||
}
|
||||
|
||||
deserialized = append(deserialized, h2mux.Header{
|
||||
deserialized = append(deserialized, HTTPHeader{
|
||||
Name: string(deserializedName),
|
||||
Value: string(deserializedValue),
|
||||
})
|
||||
|
|
|
@ -46,18 +46,40 @@ func TestSerializeHeaders(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 13, len(deserializedHeaders))
|
||||
h2muxExpectedHeaders := stdlibHeaderToH2muxHeader(mockHeaders)
|
||||
expectedHeaders := headerToReqHeader(mockHeaders)
|
||||
|
||||
sort.Sort(ByName(deserializedHeaders))
|
||||
sort.Sort(ByName(h2muxExpectedHeaders))
|
||||
sort.Sort(ByName(expectedHeaders))
|
||||
|
||||
assert.True(
|
||||
t,
|
||||
reflect.DeepEqual(h2muxExpectedHeaders, deserializedHeaders),
|
||||
fmt.Sprintf("got = %#v, want = %#v\n", deserializedHeaders, h2muxExpectedHeaders),
|
||||
reflect.DeepEqual(expectedHeaders, deserializedHeaders),
|
||||
fmt.Sprintf("got = %#v, want = %#v\n", deserializedHeaders, expectedHeaders),
|
||||
)
|
||||
}
|
||||
|
||||
type ByName []HTTPHeader
|
||||
|
||||
func (a ByName) Len() int { return len(a) }
|
||||
func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a ByName) Less(i, j int) bool {
|
||||
if a[i].Name == a[j].Name {
|
||||
return a[i].Value < a[j].Value
|
||||
}
|
||||
|
||||
return a[i].Name < a[j].Name
|
||||
}
|
||||
|
||||
func headerToReqHeader(headers http.Header) (reqHeaders []HTTPHeader) {
|
||||
for name, values := range headers {
|
||||
for _, value := range values {
|
||||
reqHeaders = append(reqHeaders, HTTPHeader{Name: name, Value: value})
|
||||
}
|
||||
}
|
||||
|
||||
return reqHeaders
|
||||
}
|
||||
|
||||
func TestSerializeNoHeaders(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
|
|
@ -385,8 +385,7 @@ func determineHTTP2Type(r *http.Request) Type {
|
|||
func handleMissingRequestParts(connType Type, r *http.Request) {
|
||||
if connType == TypeHTTP {
|
||||
// http library has no guarantees that we receive a filled URL. If not, then we fill it, as we reuse the request
|
||||
// for proxying. We use the same values as we used to in h2mux. For proxying they should not matter since we
|
||||
// control the dialer on every egress proxied.
|
||||
// for proxying. For proxying they should not matter since we control the dialer on every egress proxied.
|
||||
if len(r.URL.Scheme) == 0 {
|
||||
r.URL.Scheme = "http"
|
||||
}
|
||||
|
|
|
@ -192,8 +192,9 @@ func (mc mockNamedTunnelRPCClient) RegisterConnection(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (mc mockNamedTunnelRPCClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
|
||||
func (mc mockNamedTunnelRPCClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) error {
|
||||
close(mc.unregistered)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mockNamedTunnelRPCClient) Close() {}
|
||||
|
|
|
@ -2,11 +2,8 @@ package connection
|
|||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -16,27 +13,6 @@ const (
|
|||
configSubsystem = "config"
|
||||
)
|
||||
|
||||
type muxerMetrics struct {
|
||||
rtt *prometheus.GaugeVec
|
||||
rttMin *prometheus.GaugeVec
|
||||
rttMax *prometheus.GaugeVec
|
||||
receiveWindowAve *prometheus.GaugeVec
|
||||
sendWindowAve *prometheus.GaugeVec
|
||||
receiveWindowMin *prometheus.GaugeVec
|
||||
receiveWindowMax *prometheus.GaugeVec
|
||||
sendWindowMin *prometheus.GaugeVec
|
||||
sendWindowMax *prometheus.GaugeVec
|
||||
inBoundRateCurr *prometheus.GaugeVec
|
||||
inBoundRateMin *prometheus.GaugeVec
|
||||
inBoundRateMax *prometheus.GaugeVec
|
||||
outBoundRateCurr *prometheus.GaugeVec
|
||||
outBoundRateMin *prometheus.GaugeVec
|
||||
outBoundRateMax *prometheus.GaugeVec
|
||||
compBytesBefore *prometheus.GaugeVec
|
||||
compBytesAfter *prometheus.GaugeVec
|
||||
compRateAve *prometheus.GaugeVec
|
||||
}
|
||||
|
||||
type localConfigMetrics struct {
|
||||
pushes prometheus.Counter
|
||||
pushesErrors prometheus.Counter
|
||||
|
@ -53,7 +29,6 @@ type tunnelMetrics struct {
|
|||
regFail *prometheus.CounterVec
|
||||
rpcFail *prometheus.CounterVec
|
||||
|
||||
muxerMetrics *muxerMetrics
|
||||
tunnelsHA tunnelsForHA
|
||||
userHostnamesCounts *prometheus.CounterVec
|
||||
|
||||
|
@ -91,252 +66,6 @@ func newLocalConfigMetrics() *localConfigMetrics {
|
|||
}
|
||||
}
|
||||
|
||||
func newMuxerMetrics() *muxerMetrics {
|
||||
rtt := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "rtt",
|
||||
Help: "Round-trip time in millisecond",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(rtt)
|
||||
|
||||
rttMin := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "rtt_min",
|
||||
Help: "Shortest round-trip time in millisecond",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(rttMin)
|
||||
|
||||
rttMax := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "rtt_max",
|
||||
Help: "Longest round-trip time in millisecond",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(rttMax)
|
||||
|
||||
receiveWindowAve := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "receive_window_ave",
|
||||
Help: "Average receive window size in bytes",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(receiveWindowAve)
|
||||
|
||||
sendWindowAve := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "send_window_ave",
|
||||
Help: "Average send window size in bytes",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(sendWindowAve)
|
||||
|
||||
receiveWindowMin := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "receive_window_min",
|
||||
Help: "Smallest receive window size in bytes",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(receiveWindowMin)
|
||||
|
||||
receiveWindowMax := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "receive_window_max",
|
||||
Help: "Largest receive window size in bytes",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(receiveWindowMax)
|
||||
|
||||
sendWindowMin := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "send_window_min",
|
||||
Help: "Smallest send window size in bytes",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(sendWindowMin)
|
||||
|
||||
sendWindowMax := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "send_window_max",
|
||||
Help: "Largest send window size in bytes",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(sendWindowMax)
|
||||
|
||||
inBoundRateCurr := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "inbound_bytes_per_sec_curr",
|
||||
Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(inBoundRateCurr)
|
||||
|
||||
inBoundRateMin := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "inbound_bytes_per_sec_min",
|
||||
Help: "Minimum non-zero inbounding bytes per second",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(inBoundRateMin)
|
||||
|
||||
inBoundRateMax := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "inbound_bytes_per_sec_max",
|
||||
Help: "Maximum inbounding bytes per second",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(inBoundRateMax)
|
||||
|
||||
outBoundRateCurr := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "outbound_bytes_per_sec_curr",
|
||||
Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(outBoundRateCurr)
|
||||
|
||||
outBoundRateMin := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "outbound_bytes_per_sec_min",
|
||||
Help: "Minimum non-zero outbounding bytes per second",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(outBoundRateMin)
|
||||
|
||||
outBoundRateMax := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "outbound_bytes_per_sec_max",
|
||||
Help: "Maximum outbounding bytes per second",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(outBoundRateMax)
|
||||
|
||||
compBytesBefore := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "comp_bytes_before",
|
||||
Help: "Bytes sent via cross-stream compression, pre compression",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(compBytesBefore)
|
||||
|
||||
compBytesAfter := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "comp_bytes_after",
|
||||
Help: "Bytes sent via cross-stream compression, post compression",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(compBytesAfter)
|
||||
|
||||
compRateAve := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
Subsystem: muxerSubsystem,
|
||||
Name: "comp_rate_ave",
|
||||
Help: "Average outbound cross-stream compression ratio",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(compRateAve)
|
||||
|
||||
return &muxerMetrics{
|
||||
rtt: rtt,
|
||||
rttMin: rttMin,
|
||||
rttMax: rttMax,
|
||||
receiveWindowAve: receiveWindowAve,
|
||||
sendWindowAve: sendWindowAve,
|
||||
receiveWindowMin: receiveWindowMin,
|
||||
receiveWindowMax: receiveWindowMax,
|
||||
sendWindowMin: sendWindowMin,
|
||||
sendWindowMax: sendWindowMax,
|
||||
inBoundRateCurr: inBoundRateCurr,
|
||||
inBoundRateMin: inBoundRateMin,
|
||||
inBoundRateMax: inBoundRateMax,
|
||||
outBoundRateCurr: outBoundRateCurr,
|
||||
outBoundRateMin: outBoundRateMin,
|
||||
outBoundRateMax: outBoundRateMax,
|
||||
compBytesBefore: compBytesBefore,
|
||||
compBytesAfter: compBytesAfter,
|
||||
compRateAve: compRateAve,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
|
||||
m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
|
||||
m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
|
||||
m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
|
||||
m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
|
||||
m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
|
||||
m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
|
||||
m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
|
||||
m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
|
||||
m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
|
||||
m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
|
||||
m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
|
||||
m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
|
||||
m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
|
||||
m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
|
||||
m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
|
||||
m.compBytesBefore.WithLabelValues(connectionID).Set(float64(metrics.CompBytesBefore.Value()))
|
||||
m.compBytesAfter.WithLabelValues(connectionID).Set(float64(metrics.CompBytesAfter.Value()))
|
||||
m.compRateAve.WithLabelValues(connectionID).Set(float64(metrics.CompRateAve()))
|
||||
}
|
||||
|
||||
func convertRTTMilliSec(t time.Duration) float64 {
|
||||
return float64(t / time.Millisecond)
|
||||
}
|
||||
|
||||
// Metrics that can be collected without asking the edge
|
||||
func initTunnelMetrics() *tunnelMetrics {
|
||||
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
|
||||
|
@ -408,7 +137,6 @@ func initTunnelMetrics() *tunnelMetrics {
|
|||
return &tunnelMetrics{
|
||||
serverLocations: serverLocations,
|
||||
oldServerLocations: make(map[string]string),
|
||||
muxerMetrics: newMuxerMetrics(),
|
||||
tunnelsHA: newTunnelsForHA(),
|
||||
regSuccess: registerSuccess,
|
||||
regFail: registerFail,
|
||||
|
@ -418,10 +146,6 @@ func initTunnelMetrics() *tunnelMetrics {
|
|||
}
|
||||
}
|
||||
|
||||
func (t *tunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
|
||||
t.muxerMetrics.update(connectionID, metrics)
|
||||
}
|
||||
|
||||
func (t *tunnelMetrics) registerServerLocation(connectionID, loc string) {
|
||||
t.locationLock.Lock()
|
||||
defer t.locationLock.Unlock()
|
||||
|
|
|
@ -47,7 +47,6 @@ func (o *Observer) RegisterSink(sink EventSink) {
|
|||
}
|
||||
|
||||
func (o *Observer) logConnected(connectionID uuid.UUID, connIndex uint8, location string, address net.IP, protocol Protocol) {
|
||||
o.sendEvent(Event{Index: connIndex, EventType: Connected, Location: location})
|
||||
o.log.Info().
|
||||
Int(management.EventTypeKey, int(management.Cloudflared)).
|
||||
Str(LogFieldConnectionID, connectionID.String()).
|
||||
|
@ -63,8 +62,8 @@ func (o *Observer) sendRegisteringEvent(connIndex uint8) {
|
|||
o.sendEvent(Event{Index: connIndex, EventType: RegisteringTunnel})
|
||||
}
|
||||
|
||||
func (o *Observer) sendConnectedEvent(connIndex uint8, protocol Protocol, location string) {
|
||||
o.sendEvent(Event{Index: connIndex, EventType: Connected, Protocol: protocol, Location: location})
|
||||
func (o *Observer) sendConnectedEvent(connIndex uint8, protocol Protocol, location string, edgeAddress net.IP) {
|
||||
o.sendEvent(Event{Index: connIndex, EventType: Connected, Protocol: protocol, Location: location, EdgeAddress: edgeAddress})
|
||||
}
|
||||
|
||||
func (o *Observer) SendURL(url string) {
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
|
||||
const (
|
||||
AvailableProtocolFlagMessage = "Available protocols: 'auto' - automatically chooses the best protocol over time (the default; and also the recommended one); 'quic' - based on QUIC, relying on UDP egress to Cloudflare edge; 'http2' - using Go's HTTP2 library, relying on TCP egress to Cloudflare edge"
|
||||
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
|
||||
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge (unused, but kept for legacy reference).
|
||||
edgeH2muxTLSServerName = "cftunnel.com"
|
||||
// edgeH2TLSServerName is the server name to establish http2 connection with edge
|
||||
edgeH2TLSServerName = "h2.cftunnel.com"
|
||||
|
|
|
@ -1,51 +1,16 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/rs/zerolog"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/cloudflare/cloudflared/datagramsession"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
"github.com/cloudflare/cloudflared/packet"
|
||||
cfdquic "github.com/cloudflare/cloudflared/quic"
|
||||
"github.com/cloudflare/cloudflared/tracing"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic"
|
||||
)
|
||||
|
||||
const (
|
||||
// HTTPHeaderKey is used to get or set http headers in QUIC ALPN if the underlying proxy connection type is HTTP.
|
||||
HTTPHeaderKey = "HttpHeader"
|
||||
// HTTPMethodKey is used to get or set http method in QUIC ALPN if the underlying proxy connection type is HTTP.
|
||||
HTTPMethodKey = "HttpMethod"
|
||||
// HTTPHostKey is used to get or set http Method in QUIC ALPN if the underlying proxy connection type is HTTP.
|
||||
HTTPHostKey = "HttpHost"
|
||||
|
||||
QUICMetadataFlowID = "FlowID"
|
||||
// emperically this capacity has been working well
|
||||
demuxChanCapacity = 16
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -53,46 +18,21 @@ var (
|
|||
portMapMutex sync.Mutex
|
||||
)
|
||||
|
||||
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
|
||||
type QUICConnection struct {
|
||||
session quic.Connection
|
||||
logger *zerolog.Logger
|
||||
orchestrator Orchestrator
|
||||
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
|
||||
sessionManager datagramsession.Manager
|
||||
// datagramMuxer mux/demux datagrams from quic connection
|
||||
datagramMuxer *cfdquic.DatagramMuxerV2
|
||||
packetRouter *ingress.PacketRouter
|
||||
controlStreamHandler ControlStreamHandler
|
||||
connOptions *tunnelpogs.ConnectionOptions
|
||||
connIndex uint8
|
||||
|
||||
rpcTimeout time.Duration
|
||||
streamWriteTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewQUICConnection returns a new instance of QUICConnection.
|
||||
func NewQUICConnection(
|
||||
func DialQuic(
|
||||
ctx context.Context,
|
||||
quicConfig *quic.Config,
|
||||
edgeAddr net.Addr,
|
||||
tlsConfig *tls.Config,
|
||||
edgeAddr netip.AddrPort,
|
||||
localAddr net.IP,
|
||||
connIndex uint8,
|
||||
tlsConfig *tls.Config,
|
||||
orchestrator Orchestrator,
|
||||
connOptions *tunnelpogs.ConnectionOptions,
|
||||
controlStreamHandler ControlStreamHandler,
|
||||
logger *zerolog.Logger,
|
||||
packetRouterConfig *ingress.GlobalRouterConfig,
|
||||
rpcTimeout time.Duration,
|
||||
streamWriteTimeout time.Duration,
|
||||
) (*QUICConnection, error) {
|
||||
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
|
||||
) (quic.Connection, error) {
|
||||
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, edgeAddr, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := quic.Dial(ctx, udpConn, edgeAddr, tlsConfig, quicConfig)
|
||||
conn, err := quic.Dial(ctx, udpConn, net.UDPAddrFromAddrPort(edgeAddr), tlsConfig, quicConfig)
|
||||
if err != nil {
|
||||
// close the udp server socket in case of error connecting to the edge
|
||||
udpConn.Close()
|
||||
|
@ -100,510 +40,22 @@ func NewQUICConnection(
|
|||
}
|
||||
|
||||
// wrap the session, so that the UDPConn is closed after session is closed.
|
||||
session = &wrapCloseableConnQuicConnection{
|
||||
session,
|
||||
conn = &wrapCloseableConnQuicConnection{
|
||||
conn,
|
||||
udpConn,
|
||||
}
|
||||
|
||||
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
|
||||
datagramMuxer := cfdquic.NewDatagramMuxerV2(session, logger, sessionDemuxChan)
|
||||
sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan)
|
||||
packetRouter := ingress.NewPacketRouter(packetRouterConfig, datagramMuxer, logger)
|
||||
|
||||
return &QUICConnection{
|
||||
session: session,
|
||||
orchestrator: orchestrator,
|
||||
logger: logger,
|
||||
sessionManager: sessionManager,
|
||||
datagramMuxer: datagramMuxer,
|
||||
packetRouter: packetRouter,
|
||||
controlStreamHandler: controlStreamHandler,
|
||||
connOptions: connOptions,
|
||||
connIndex: connIndex,
|
||||
rpcTimeout: rpcTimeout,
|
||||
streamWriteTimeout: streamWriteTimeout,
|
||||
}, nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Serve starts a QUIC session that begins accepting streams.
|
||||
func (q *QUICConnection) Serve(ctx context.Context) error {
|
||||
// origintunneld assumes the first stream is used for the control plane
|
||||
controlStream, err := q.session.OpenStream()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open a registration control stream: %w", err)
|
||||
}
|
||||
|
||||
// If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits
|
||||
// as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this
|
||||
// connection).
|
||||
// If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the
|
||||
// other goroutine as fast as possible.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
// In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control
|
||||
// stream is already fully registered before the other goroutines can proceed.
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
return q.serveControlStream(ctx, controlStream)
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
return q.acceptStream(ctx)
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
return q.sessionManager.Serve(ctx)
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
return q.datagramMuxer.ServeReceive(ctx)
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
return q.packetRouter.Serve(ctx)
|
||||
})
|
||||
|
||||
return errGroup.Wait()
|
||||
}
|
||||
|
||||
func (q *QUICConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error {
|
||||
// This blocks until the control plane is done.
|
||||
err := q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions, q.orchestrator)
|
||||
if err != nil {
|
||||
// Not wrapping error here to be consistent with the http2 message.
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the session with no errors specified.
|
||||
func (q *QUICConnection) Close() {
|
||||
q.session.CloseWithError(0, "")
|
||||
}
|
||||
|
||||
func (q *QUICConnection) acceptStream(ctx context.Context) error {
|
||||
defer q.Close()
|
||||
for {
|
||||
quicStream, err := q.session.AcceptStream(ctx)
|
||||
if err != nil {
|
||||
// context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional.
|
||||
if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to accept QUIC stream: %w", err)
|
||||
}
|
||||
go q.runStream(quicStream)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *QUICConnection) runStream(quicStream quic.Stream) {
|
||||
ctx := quicStream.Context()
|
||||
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
|
||||
defer stream.Close()
|
||||
|
||||
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
|
||||
// code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream.
|
||||
// So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream.
|
||||
// A call to close will simulate a close to the read-side, which will fail subsequent reads.
|
||||
noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream}
|
||||
ss := rpcquic.NewCloudflaredServer(q.handleDataStream, q, q, q.rpcTimeout)
|
||||
if err := ss.Serve(ctx, noCloseStream); err != nil {
|
||||
q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream")
|
||||
|
||||
// if we received an error at this level, then close write side of stream with an error, which will result in
|
||||
// RST_STREAM frame.
|
||||
quicStream.CancelWrite(0)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *QUICConnection) handleDataStream(ctx context.Context, stream *rpcquic.RequestServerStream) error {
|
||||
request, err := stream.ReadConnectRequestData()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err, connectResponseSent := q.dispatchRequest(ctx, stream, err, request); err != nil {
|
||||
q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed")
|
||||
|
||||
// if the connectResponse was already sent and we had an error, we need to propagate it up, so that the stream is
|
||||
// closed with an RST_STREAM frame
|
||||
if connectResponseSent {
|
||||
return err
|
||||
}
|
||||
|
||||
if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil {
|
||||
return writeRespErr
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dispatchRequest will dispatch the request depending on the type and returns an error if it occurs.
|
||||
// More importantly, it also tells if the during processing of the request the ConnectResponse metadata was sent downstream.
|
||||
// This is important since it informs
|
||||
func (q *QUICConnection) dispatchRequest(ctx context.Context, stream *rpcquic.RequestServerStream, err error, request *pogs.ConnectRequest) (error, bool) {
|
||||
originProxy, err := q.orchestrator.GetOriginProxy()
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
switch request.Type {
|
||||
case pogs.ConnectionTypeHTTP, pogs.ConnectionTypeWebsocket:
|
||||
tracedReq, err := buildHTTPRequest(ctx, request, stream, q.connIndex, q.logger)
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
w := newHTTPResponseAdapter(stream)
|
||||
return originProxy.ProxyHTTP(&w, tracedReq, request.Type == pogs.ConnectionTypeWebsocket), w.connectResponseSent
|
||||
|
||||
case pogs.ConnectionTypeTCP:
|
||||
rwa := &streamReadWriteAcker{RequestServerStream: stream}
|
||||
metadata := request.MetadataMap()
|
||||
return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{
|
||||
Dest: request.Dest,
|
||||
FlowID: metadata[QUICMetadataFlowID],
|
||||
CfTraceID: metadata[tracing.TracerContextName],
|
||||
ConnIndex: q.connIndex,
|
||||
}), rwa.connectResponseSent
|
||||
default:
|
||||
return errors.Errorf("unsupported error type: %s", request.Type), false
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterUdpSession is the RPC method invoked by edge to register and run a session
|
||||
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*tunnelpogs.RegisterUdpSessionResponse, error) {
|
||||
traceCtx := tracing.NewTracedContext(ctx, traceContext, q.logger)
|
||||
ctx, registerSpan := traceCtx.Tracer().Start(traceCtx, "register-session", trace.WithAttributes(
|
||||
attribute.String("session-id", sessionID.String()),
|
||||
attribute.String("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)),
|
||||
))
|
||||
log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger()
|
||||
// Each session is a series of datagram from an eyeball to a dstIP:dstPort.
|
||||
// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
|
||||
originProxy, err := ingress.DialUDP(dstIP, dstPort)
|
||||
if err != nil {
|
||||
log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
|
||||
tracing.EndWithErrorStatus(registerSpan, err)
|
||||
return nil, err
|
||||
}
|
||||
registerSpan.SetAttributes(
|
||||
attribute.Bool("socket-bind-success", true),
|
||||
attribute.String("src", originProxy.LocalAddr().String()),
|
||||
)
|
||||
|
||||
session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy)
|
||||
if err != nil {
|
||||
originProxy.Close()
|
||||
log.Err(err).Str("sessionID", sessionID.String()).Msgf("Failed to register udp session")
|
||||
tracing.EndWithErrorStatus(registerSpan, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go q.serveUDPSession(session, closeAfterIdleHint)
|
||||
|
||||
log.Debug().
|
||||
Str("sessionID", sessionID.String()).
|
||||
Str("src", originProxy.LocalAddr().String()).
|
||||
Str("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)).
|
||||
Msgf("Registered session")
|
||||
tracing.End(registerSpan)
|
||||
|
||||
resp := tunnelpogs.RegisterUdpSessionResponse{
|
||||
Spans: traceCtx.GetProtoSpans(),
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (q *QUICConnection) serveUDPSession(session *datagramsession.Session, closeAfterIdleHint time.Duration) {
|
||||
ctx := q.session.Context()
|
||||
closedByRemote, err := session.Serve(ctx, closeAfterIdleHint)
|
||||
// If session is terminated by remote, then we know it has been unregistered from session manager and edge
|
||||
if !closedByRemote {
|
||||
if err != nil {
|
||||
q.closeUDPSession(ctx, session.ID, err.Error())
|
||||
} else {
|
||||
q.closeUDPSession(ctx, session.ID, "terminated without error")
|
||||
}
|
||||
}
|
||||
q.logger.Debug().Err(err).
|
||||
Int(management.EventTypeKey, int(management.UDP)).
|
||||
Str("sessionID", session.ID.String()).
|
||||
Msg("Session terminated")
|
||||
}
|
||||
|
||||
// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge
|
||||
func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) {
|
||||
q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
|
||||
quicStream, err := q.session.OpenStream()
|
||||
if err != nil {
|
||||
// Log this at debug because this is not an error if session was closed due to lost connection
|
||||
// with edge
|
||||
q.logger.Debug().Err(err).
|
||||
Int(management.EventTypeKey, int(management.UDP)).
|
||||
Str("sessionID", sessionID.String()).
|
||||
Msgf("Failed to open quic stream to unregister udp session with edge")
|
||||
return
|
||||
}
|
||||
|
||||
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
|
||||
defer stream.Close()
|
||||
rpcClientStream, err := rpcquic.NewSessionClient(ctx, stream, q.rpcTimeout)
|
||||
if err != nil {
|
||||
// Log this at debug because this is not an error if session was closed due to lost connection
|
||||
// with edge
|
||||
q.logger.Err(err).Str("sessionID", sessionID.String()).
|
||||
Msgf("Failed to open rpc stream to unregister udp session with edge")
|
||||
return
|
||||
}
|
||||
defer rpcClientStream.Close()
|
||||
|
||||
if err := rpcClientStream.UnregisterUdpSession(ctx, sessionID, message); err != nil {
|
||||
q.logger.Err(err).Str("sessionID", sessionID.String()).
|
||||
Msgf("Failed to unregister udp session with edge")
|
||||
}
|
||||
}
|
||||
|
||||
// UnregisterUdpSession is the RPC method invoked by edge to unregister and terminate a sesssion
|
||||
func (q *QUICConnection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
|
||||
return q.sessionManager.UnregisterSession(ctx, sessionID, message, true)
|
||||
}
|
||||
|
||||
// UpdateConfiguration is the RPC method invoked by edge when there is a new configuration
|
||||
func (q *QUICConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
|
||||
return q.orchestrator.UpdateConfig(version, config)
|
||||
}
|
||||
|
||||
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
||||
// the client.
|
||||
type streamReadWriteAcker struct {
|
||||
*rpcquic.RequestServerStream
|
||||
connectResponseSent bool
|
||||
}
|
||||
|
||||
// AckConnection acks response back to the proxy.
|
||||
func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error {
|
||||
metadata := []pogs.Metadata{}
|
||||
// Only add tracing if provided by origintunneld
|
||||
if tracePropagation != "" {
|
||||
metadata = append(metadata, pogs.Metadata{
|
||||
Key: tracing.CanonicalCloudflaredTracingHeader,
|
||||
Val: tracePropagation,
|
||||
})
|
||||
}
|
||||
s.connectResponseSent = true
|
||||
return s.WriteConnectResponseData(nil, metadata...)
|
||||
}
|
||||
|
||||
// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
|
||||
type httpResponseAdapter struct {
|
||||
*rpcquic.RequestServerStream
|
||||
headers http.Header
|
||||
connectResponseSent bool
|
||||
}
|
||||
|
||||
func newHTTPResponseAdapter(s *rpcquic.RequestServerStream) httpResponseAdapter {
|
||||
return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header)}
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) {
|
||||
// we do not support trailers over QUIC
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
|
||||
metadata := make([]pogs.Metadata, 0)
|
||||
metadata = append(metadata, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})
|
||||
for k, vv := range header {
|
||||
for _, v := range vv {
|
||||
httpHeaderKey := fmt.Sprintf("%s:%s", HTTPHeaderKey, k)
|
||||
metadata = append(metadata, pogs.Metadata{Key: httpHeaderKey, Val: v})
|
||||
}
|
||||
}
|
||||
|
||||
return hrw.WriteConnectResponseData(nil, metadata...)
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) Write(p []byte) (int, error) {
|
||||
// Make sure to send WriteHeader response if not called yet
|
||||
if !hrw.connectResponseSent {
|
||||
hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
|
||||
}
|
||||
return hrw.RequestServerStream.Write(p)
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) Header() http.Header {
|
||||
return hrw.headers
|
||||
}
|
||||
|
||||
// This is a no-op Flush because this adapter is over a quic.Stream and we don't need Flush here.
|
||||
func (hrw *httpResponseAdapter) Flush() {}
|
||||
|
||||
func (hrw *httpResponseAdapter) WriteHeader(status int) {
|
||||
hrw.WriteRespHeaders(status, hrw.headers)
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
conn := &localProxyConnection{hrw.ReadWriteCloser}
|
||||
readWriter := bufio.NewReadWriter(
|
||||
bufio.NewReader(hrw.ReadWriteCloser),
|
||||
bufio.NewWriter(hrw.ReadWriteCloser),
|
||||
)
|
||||
return conn, readWriter, nil
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
|
||||
hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error {
|
||||
hrw.connectResponseSent = true
|
||||
return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...)
|
||||
}
|
||||
|
||||
func buildHTTPRequest(
|
||||
ctx context.Context,
|
||||
connectRequest *pogs.ConnectRequest,
|
||||
body io.ReadCloser,
|
||||
connIndex uint8,
|
||||
log *zerolog.Logger,
|
||||
) (*tracing.TracedHTTPRequest, error) {
|
||||
metadata := connectRequest.MetadataMap()
|
||||
dest := connectRequest.Dest
|
||||
method := metadata[HTTPMethodKey]
|
||||
host := metadata[HTTPHostKey]
|
||||
isWebsocket := connectRequest.Type == pogs.ConnectionTypeWebsocket
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, dest, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Host = host
|
||||
for _, metadata := range connectRequest.Metadata {
|
||||
if strings.Contains(metadata.Key, HTTPHeaderKey) {
|
||||
// metadata.Key is off the format httpHeaderKey:<HTTPHeader>
|
||||
httpHeaderKey := strings.Split(metadata.Key, ":")
|
||||
if len(httpHeaderKey) != 2 {
|
||||
return nil, fmt.Errorf("header Key: %s malformed", metadata.Key)
|
||||
}
|
||||
req.Header.Add(httpHeaderKey[1], metadata.Val)
|
||||
}
|
||||
}
|
||||
// Go's http.Client automatically sends chunked request body if this value is not set on the
|
||||
// *http.Request struct regardless of header:
|
||||
// https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154.
|
||||
if err := setContentLength(req); err != nil {
|
||||
return nil, fmt.Errorf("Error setting content-length: %w", err)
|
||||
}
|
||||
|
||||
// Go's client defaults to chunked encoding after a 200ms delay if the following cases are true:
|
||||
// * the request body blocks
|
||||
// * the content length is not set (or set to -1)
|
||||
// * the method doesn't usually have a body (GET, HEAD, DELETE, ...)
|
||||
// * there is no transfer-encoding=chunked already set.
|
||||
// So, if transfer cannot be chunked and content length is 0, we dont set a request body.
|
||||
if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 {
|
||||
req.Body = http.NoBody
|
||||
}
|
||||
stripWebsocketUpgradeHeader(req)
|
||||
|
||||
// Check for tracing on request
|
||||
tracedReq := tracing.NewTracedHTTPRequest(req, connIndex, log)
|
||||
return tracedReq, err
|
||||
}
|
||||
|
||||
func setContentLength(req *http.Request) error {
|
||||
var err error
|
||||
if contentLengthStr := req.Header.Get("Content-Length"); contentLengthStr != "" {
|
||||
req.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func isTransferEncodingChunked(req *http.Request) bool {
|
||||
transferEncodingVal := req.Header.Get("Transfer-Encoding")
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding suggests that this can be a comma
|
||||
// separated value as well.
|
||||
return strings.Contains(strings.ToLower(transferEncodingVal), "chunked")
|
||||
}
|
||||
|
||||
// A helper struct that guarantees a call to close only affects read side, but not write side.
|
||||
type nopCloserReadWriter struct {
|
||||
io.ReadWriteCloser
|
||||
|
||||
// for use by Read only
|
||||
// we don't need a memory barrier here because there is an implicit assumption that
|
||||
// Read calls can't happen concurrently by different go-routines.
|
||||
sawEOF bool
|
||||
// should be updated and read using atomic primitives.
|
||||
// value is read in Read method and written in Close method, which could be done by different
|
||||
// go-routines.
|
||||
closed uint32
|
||||
}
|
||||
|
||||
func (np *nopCloserReadWriter) Read(p []byte) (n int, err error) {
|
||||
if np.sawEOF {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&np.closed) > 0 {
|
||||
return 0, fmt.Errorf("closed by handler")
|
||||
}
|
||||
|
||||
n, err = np.ReadWriteCloser.Read(p)
|
||||
if err == io.EOF {
|
||||
np.sawEOF = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (np *nopCloserReadWriter) Close() error {
|
||||
atomic.StoreUint32(&np.closed, 1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// muxerWrapper wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface
|
||||
type muxerWrapper struct {
|
||||
muxer *cfdquic.DatagramMuxerV2
|
||||
}
|
||||
|
||||
func (rp *muxerWrapper) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
|
||||
return rp.muxer.SendPacket(cfdquic.RawPacket(pk))
|
||||
}
|
||||
|
||||
func (rp *muxerWrapper) ReceivePacket(ctx context.Context) (packet.RawPacket, error) {
|
||||
pk, err := rp.muxer.ReceivePacket(ctx)
|
||||
if err != nil {
|
||||
return packet.RawPacket{}, err
|
||||
}
|
||||
rawPacket, ok := pk.(cfdquic.RawPacket)
|
||||
if ok {
|
||||
return packet.RawPacket(rawPacket), nil
|
||||
}
|
||||
return packet.RawPacket{}, fmt.Errorf("unexpected packet type %+v", pk)
|
||||
}
|
||||
|
||||
func (rp *muxerWrapper) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, logger *zerolog.Logger) (*net.UDPConn, error) {
|
||||
func createUDPConnForConnIndex(connIndex uint8, localIP net.IP, edgeIP netip.AddrPort, logger *zerolog.Logger) (*net.UDPConn, error) {
|
||||
portMapMutex.Lock()
|
||||
defer portMapMutex.Unlock()
|
||||
|
||||
if localIP == nil {
|
||||
localIP = net.IPv4zero
|
||||
}
|
||||
|
||||
listenNetwork := "udp"
|
||||
// https://github.com/quic-go/quic-go/issues/3793 DF bit cannot be set for dual stack listener on OSX
|
||||
// https://github.com/quic-go/quic-go/issues/3793 DF bit cannot be set for dual stack listener ("udp") on macOS,
|
||||
// to set the DF bit properly, the network string needs to be specific to the IP family.
|
||||
if runtime.GOOS == "darwin" {
|
||||
if localIP.To4() != nil {
|
||||
if edgeIP.Addr().Is4() {
|
||||
listenNetwork = "udp4"
|
||||
} else {
|
||||
listenNetwork = "udp6"
|
||||
|
|
|
@ -0,0 +1,417 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
cfdquic "github.com/cloudflare/cloudflared/quic"
|
||||
"github.com/cloudflare/cloudflared/tracing"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic"
|
||||
)
|
||||
|
||||
const (
|
||||
// HTTPHeaderKey is used to get or set http headers in QUIC ALPN if the underlying proxy connection type is HTTP.
|
||||
HTTPHeaderKey = "HttpHeader"
|
||||
// HTTPMethodKey is used to get or set http method in QUIC ALPN if the underlying proxy connection type is HTTP.
|
||||
HTTPMethodKey = "HttpMethod"
|
||||
// HTTPHostKey is used to get or set http host in QUIC ALPN if the underlying proxy connection type is HTTP.
|
||||
HTTPHostKey = "HttpHost"
|
||||
|
||||
QUICMetadataFlowID = "FlowID"
|
||||
)
|
||||
|
||||
// quicConnection represents the type that facilitates Proxying via QUIC streams.
|
||||
type quicConnection struct {
|
||||
conn quic.Connection
|
||||
logger *zerolog.Logger
|
||||
orchestrator Orchestrator
|
||||
datagramHandler DatagramSessionHandler
|
||||
controlStreamHandler ControlStreamHandler
|
||||
connOptions *tunnelpogs.ConnectionOptions
|
||||
connIndex uint8
|
||||
|
||||
rpcTimeout time.Duration
|
||||
streamWriteTimeout time.Duration
|
||||
gracePeriod time.Duration
|
||||
}
|
||||
|
||||
// NewTunnelConnection takes a [quic.Connection] to wrap it for use with cloudflared application logic.
|
||||
func NewTunnelConnection(
|
||||
ctx context.Context,
|
||||
conn quic.Connection,
|
||||
connIndex uint8,
|
||||
orchestrator Orchestrator,
|
||||
datagramSessionHandler DatagramSessionHandler,
|
||||
controlStreamHandler ControlStreamHandler,
|
||||
connOptions *pogs.ConnectionOptions,
|
||||
rpcTimeout time.Duration,
|
||||
streamWriteTimeout time.Duration,
|
||||
gracePeriod time.Duration,
|
||||
logger *zerolog.Logger,
|
||||
) (TunnelConnection, error) {
|
||||
return &quicConnection{
|
||||
conn: conn,
|
||||
logger: logger,
|
||||
orchestrator: orchestrator,
|
||||
datagramHandler: datagramSessionHandler,
|
||||
controlStreamHandler: controlStreamHandler,
|
||||
connOptions: connOptions,
|
||||
connIndex: connIndex,
|
||||
rpcTimeout: rpcTimeout,
|
||||
streamWriteTimeout: streamWriteTimeout,
|
||||
gracePeriod: gracePeriod,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Serve starts a QUIC connection that begins accepting streams.
|
||||
func (q *quicConnection) Serve(ctx context.Context) error {
|
||||
// The edge assumes the first stream is used for the control plane
|
||||
controlStream, err := q.conn.OpenStream()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open a registration control stream: %w", err)
|
||||
}
|
||||
|
||||
// If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits
|
||||
// as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this
|
||||
// connection).
|
||||
// If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the
|
||||
// other goroutine as fast as possible.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
// In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control
|
||||
// stream is already fully registered before the other goroutines can proceed.
|
||||
errGroup.Go(func() error {
|
||||
// err is equal to nil if we exit due to unregistration. If that happens we want to wait the full
|
||||
// amount of the grace period, allowing requests to finish before we cancel the context, which will
|
||||
// make cloudflared exit.
|
||||
if err := q.serveControlStream(ctx, controlStream); err == nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.Tick(q.gracePeriod):
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
return err
|
||||
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
return q.acceptStream(ctx)
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
return q.datagramHandler.Serve(ctx)
|
||||
})
|
||||
|
||||
return errGroup.Wait()
|
||||
}
|
||||
|
||||
// serveControlStream will serve the RPC; blocking until the control plane is done.
|
||||
func (q *quicConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error {
|
||||
return q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions, q.orchestrator)
|
||||
}
|
||||
|
||||
// Close the connection with no errors specified.
|
||||
func (q *quicConnection) Close() {
|
||||
q.conn.CloseWithError(0, "")
|
||||
}
|
||||
|
||||
func (q *quicConnection) acceptStream(ctx context.Context) error {
|
||||
defer q.Close()
|
||||
for {
|
||||
quicStream, err := q.conn.AcceptStream(ctx)
|
||||
if err != nil {
|
||||
// context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional.
|
||||
if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to accept QUIC stream: %w", err)
|
||||
}
|
||||
go q.runStream(quicStream)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *quicConnection) runStream(quicStream quic.Stream) {
|
||||
ctx := quicStream.Context()
|
||||
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
|
||||
defer stream.Close()
|
||||
|
||||
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
|
||||
// code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream.
|
||||
// So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream.
|
||||
// A call to close will simulate a close to the read-side, which will fail subsequent reads.
|
||||
noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream}
|
||||
ss := rpcquic.NewCloudflaredServer(q.handleDataStream, q.datagramHandler, q, q.rpcTimeout)
|
||||
if err := ss.Serve(ctx, noCloseStream); err != nil {
|
||||
q.logger.Debug().Err(err).Msg("Failed to handle QUIC stream")
|
||||
|
||||
// if we received an error at this level, then close write side of stream with an error, which will result in
|
||||
// RST_STREAM frame.
|
||||
quicStream.CancelWrite(0)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.RequestServerStream) error {
|
||||
request, err := stream.ReadConnectRequestData()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err, connectResponseSent := q.dispatchRequest(ctx, stream, request); err != nil {
|
||||
q.logger.Err(err).Str("type", request.Type.String()).Str("dest", request.Dest).Msg("Request failed")
|
||||
|
||||
// if the connectResponse was already sent and we had an error, we need to propagate it up, so that the stream is
|
||||
// closed with an RST_STREAM frame
|
||||
if connectResponseSent {
|
||||
return err
|
||||
}
|
||||
|
||||
if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil {
|
||||
return writeRespErr
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dispatchRequest will dispatch the request to the origin depending on the type and returns an error if it occurs.
|
||||
// Also returns if the connect response was sent to the downstream during processing of the origin request.
|
||||
func (q *quicConnection) dispatchRequest(ctx context.Context, stream *rpcquic.RequestServerStream, request *pogs.ConnectRequest) (err error, connectResponseSent bool) {
|
||||
originProxy, err := q.orchestrator.GetOriginProxy()
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
|
||||
switch request.Type {
|
||||
case pogs.ConnectionTypeHTTP, pogs.ConnectionTypeWebsocket:
|
||||
tracedReq, err := buildHTTPRequest(ctx, request, stream, q.connIndex, q.logger)
|
||||
if err != nil {
|
||||
return err, false
|
||||
}
|
||||
w := newHTTPResponseAdapter(stream)
|
||||
return originProxy.ProxyHTTP(&w, tracedReq, request.Type == pogs.ConnectionTypeWebsocket), w.connectResponseSent
|
||||
|
||||
case pogs.ConnectionTypeTCP:
|
||||
rwa := &streamReadWriteAcker{RequestServerStream: stream}
|
||||
metadata := request.MetadataMap()
|
||||
return originProxy.ProxyTCP(ctx, rwa, &TCPRequest{
|
||||
Dest: request.Dest,
|
||||
FlowID: metadata[QUICMetadataFlowID],
|
||||
CfTraceID: metadata[tracing.TracerContextName],
|
||||
ConnIndex: q.connIndex,
|
||||
}), rwa.connectResponseSent
|
||||
default:
|
||||
return errors.Errorf("unsupported error type: %s", request.Type), false
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfiguration is the RPC method invoked by edge when there is a new configuration
|
||||
func (q *quicConnection) UpdateConfiguration(ctx context.Context, version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
|
||||
return q.orchestrator.UpdateConfig(version, config)
|
||||
}
|
||||
|
||||
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
|
||||
// the client.
|
||||
type streamReadWriteAcker struct {
|
||||
*rpcquic.RequestServerStream
|
||||
connectResponseSent bool
|
||||
}
|
||||
|
||||
// AckConnection acks response back to the proxy.
|
||||
func (s *streamReadWriteAcker) AckConnection(tracePropagation string) error {
|
||||
metadata := []pogs.Metadata{}
|
||||
// Only add tracing if provided by the edge request
|
||||
if tracePropagation != "" {
|
||||
metadata = append(metadata, pogs.Metadata{
|
||||
Key: tracing.CanonicalCloudflaredTracingHeader,
|
||||
Val: tracePropagation,
|
||||
})
|
||||
}
|
||||
s.connectResponseSent = true
|
||||
return s.WriteConnectResponseData(nil, metadata...)
|
||||
}
|
||||
|
||||
// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
|
||||
type httpResponseAdapter struct {
|
||||
*rpcquic.RequestServerStream
|
||||
headers http.Header
|
||||
connectResponseSent bool
|
||||
}
|
||||
|
||||
func newHTTPResponseAdapter(s *rpcquic.RequestServerStream) httpResponseAdapter {
|
||||
return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header)}
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) {
|
||||
// we do not support trailers over QUIC
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
|
||||
metadata := make([]pogs.Metadata, 0)
|
||||
metadata = append(metadata, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(status)})
|
||||
for k, vv := range header {
|
||||
for _, v := range vv {
|
||||
httpHeaderKey := fmt.Sprintf("%s:%s", HTTPHeaderKey, k)
|
||||
metadata = append(metadata, pogs.Metadata{Key: httpHeaderKey, Val: v})
|
||||
}
|
||||
}
|
||||
|
||||
return hrw.WriteConnectResponseData(nil, metadata...)
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) Write(p []byte) (int, error) {
|
||||
// Make sure to send WriteHeader response if not called yet
|
||||
if !hrw.connectResponseSent {
|
||||
hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
|
||||
}
|
||||
return hrw.RequestServerStream.Write(p)
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) Header() http.Header {
|
||||
return hrw.headers
|
||||
}
|
||||
|
||||
// This is a no-op Flush because this adapter is over a quic.Stream and we don't need Flush here.
|
||||
func (hrw *httpResponseAdapter) Flush() {}
|
||||
|
||||
func (hrw *httpResponseAdapter) WriteHeader(status int) {
|
||||
hrw.WriteRespHeaders(status, hrw.headers)
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
conn := &localProxyConnection{hrw.ReadWriteCloser}
|
||||
readWriter := bufio.NewReadWriter(
|
||||
bufio.NewReader(hrw.ReadWriteCloser),
|
||||
bufio.NewWriter(hrw.ReadWriteCloser),
|
||||
)
|
||||
return conn, readWriter, nil
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
|
||||
hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error {
|
||||
hrw.connectResponseSent = true
|
||||
return hrw.RequestServerStream.WriteConnectResponseData(respErr, metadata...)
|
||||
}
|
||||
|
||||
func buildHTTPRequest(
|
||||
ctx context.Context,
|
||||
connectRequest *pogs.ConnectRequest,
|
||||
body io.ReadCloser,
|
||||
connIndex uint8,
|
||||
log *zerolog.Logger,
|
||||
) (*tracing.TracedHTTPRequest, error) {
|
||||
metadata := connectRequest.MetadataMap()
|
||||
dest := connectRequest.Dest
|
||||
method := metadata[HTTPMethodKey]
|
||||
host := metadata[HTTPHostKey]
|
||||
isWebsocket := connectRequest.Type == pogs.ConnectionTypeWebsocket
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, dest, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Host = host
|
||||
for _, metadata := range connectRequest.Metadata {
|
||||
if strings.Contains(metadata.Key, HTTPHeaderKey) {
|
||||
// metadata.Key is off the format httpHeaderKey:<HTTPHeader>
|
||||
httpHeaderKey := strings.Split(metadata.Key, ":")
|
||||
if len(httpHeaderKey) != 2 {
|
||||
return nil, fmt.Errorf("header Key: %s malformed", metadata.Key)
|
||||
}
|
||||
req.Header.Add(httpHeaderKey[1], metadata.Val)
|
||||
}
|
||||
}
|
||||
// Go's http.Client automatically sends chunked request body if this value is not set on the
|
||||
// *http.Request struct regardless of header:
|
||||
// https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154.
|
||||
if err := setContentLength(req); err != nil {
|
||||
return nil, fmt.Errorf("Error setting content-length: %w", err)
|
||||
}
|
||||
|
||||
// Go's client defaults to chunked encoding after a 200ms delay if the following cases are true:
|
||||
// * the request body blocks
|
||||
// * the content length is not set (or set to -1)
|
||||
// * the method doesn't usually have a body (GET, HEAD, DELETE, ...)
|
||||
// * there is no transfer-encoding=chunked already set.
|
||||
// So, if transfer cannot be chunked and content length is 0, we dont set a request body.
|
||||
if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 {
|
||||
req.Body = http.NoBody
|
||||
}
|
||||
stripWebsocketUpgradeHeader(req)
|
||||
|
||||
// Check for tracing on request
|
||||
tracedReq := tracing.NewTracedHTTPRequest(req, connIndex, log)
|
||||
return tracedReq, err
|
||||
}
|
||||
|
||||
func setContentLength(req *http.Request) error {
|
||||
var err error
|
||||
if contentLengthStr := req.Header.Get("Content-Length"); contentLengthStr != "" {
|
||||
req.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func isTransferEncodingChunked(req *http.Request) bool {
|
||||
transferEncodingVal := req.Header.Get("Transfer-Encoding")
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding suggests that this can be a comma
|
||||
// separated value as well.
|
||||
return strings.Contains(strings.ToLower(transferEncodingVal), "chunked")
|
||||
}
|
||||
|
||||
// A helper struct that guarantees a call to close only affects read side, but not write side.
|
||||
type nopCloserReadWriter struct {
|
||||
io.ReadWriteCloser
|
||||
|
||||
// for use by Read only
|
||||
// we don't need a memory barrier here because there is an implicit assumption that
|
||||
// Read calls can't happen concurrently by different go-routines.
|
||||
sawEOF bool
|
||||
// should be updated and read using atomic primitives.
|
||||
// value is read in Read method and written in Close method, which could be done by different
|
||||
// go-routines.
|
||||
closed uint32
|
||||
}
|
||||
|
||||
func (np *nopCloserReadWriter) Read(p []byte) (n int, err error) {
|
||||
if np.sawEOF {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if atomic.LoadUint32(&np.closed) > 0 {
|
||||
return 0, fmt.Errorf("closed by handler")
|
||||
}
|
||||
|
||||
n, err = np.ReadWriteCloser.Read(p)
|
||||
if err == io.EOF {
|
||||
np.sawEOF = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (np *nopCloserReadWriter) Close() error {
|
||||
atomic.StoreUint32(&np.closed, 1)
|
||||
|
||||
return nil
|
||||
}
|
|
@ -13,8 +13,8 @@ import (
|
|||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -26,12 +26,14 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/nettest"
|
||||
|
||||
"github.com/cloudflare/cloudflared/datagramsession"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/packet"
|
||||
cfdquic "github.com/cloudflare/cloudflared/quic"
|
||||
"github.com/cloudflare/cloudflared/tracing"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic"
|
||||
)
|
||||
|
||||
|
@ -162,11 +164,11 @@ func TestQUICServer(t *testing.T) {
|
|||
close(serverDone)
|
||||
}()
|
||||
|
||||
qc := testQUICConnection(udpListener.LocalAddr(), t, uint8(i))
|
||||
tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(i))
|
||||
|
||||
connDone := make(chan struct{})
|
||||
go func() {
|
||||
qc.Serve(ctx)
|
||||
tunnelConn.Serve(ctx)
|
||||
close(connDone)
|
||||
}()
|
||||
|
||||
|
@ -513,7 +515,6 @@ func TestServeUDPSession(t *testing.T) {
|
|||
defer udpListener.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
val := udpListener.LocalAddr()
|
||||
|
||||
// Establish QUIC connection with edge
|
||||
edgeQUICSessionChan := make(chan quic.Connection)
|
||||
|
@ -527,13 +528,14 @@ func TestServeUDPSession(t *testing.T) {
|
|||
}()
|
||||
|
||||
// Random index to avoid reusing port
|
||||
qc := testQUICConnection(val, t, 28)
|
||||
go qc.Serve(ctx)
|
||||
tunnelConn, datagramConn := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), 28)
|
||||
go tunnelConn.Serve(ctx)
|
||||
|
||||
edgeQUICSession := <-edgeQUICSessionChan
|
||||
serveSession(ctx, qc, edgeQUICSession, closedByOrigin, io.EOF.Error(), t)
|
||||
serveSession(ctx, qc, edgeQUICSession, closedByTimeout, datagramsession.SessionIdleErr(time.Millisecond*50).Error(), t)
|
||||
serveSession(ctx, qc, edgeQUICSession, closedByRemote, "eyeball closed connection", t)
|
||||
|
||||
serveSession(ctx, datagramConn, edgeQUICSession, closedByOrigin, io.EOF.Error(), t)
|
||||
serveSession(ctx, datagramConn, edgeQUICSession, closedByTimeout, datagramsession.SessionIdleErr(time.Millisecond*50).Error(), t)
|
||||
serveSession(ctx, datagramConn, edgeQUICSession, closedByRemote, "eyeball closed connection", t)
|
||||
cancel()
|
||||
}
|
||||
|
||||
|
@ -576,8 +578,20 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCreateUDPConnReuseSourcePort(t *testing.T) {
|
||||
edgeIPv4 := netip.MustParseAddrPort("0.0.0.0:0")
|
||||
edgeIPv6 := netip.MustParseAddrPort("[::]:0")
|
||||
|
||||
// We assume the test environment has access to an IPv4 interface
|
||||
testCreateUDPConnReuseSourcePortForEdgeIP(t, edgeIPv4)
|
||||
|
||||
if nettest.SupportsIPv6() {
|
||||
testCreateUDPConnReuseSourcePortForEdgeIP(t, edgeIPv6)
|
||||
}
|
||||
}
|
||||
|
||||
func testCreateUDPConnReuseSourcePortForEdgeIP(t *testing.T, edgeIP netip.AddrPort) {
|
||||
logger := zerolog.Nop()
|
||||
conn, err := createUDPConnForConnIndex(0, nil, &logger)
|
||||
conn, err := createUDPConnForConnIndex(0, nil, edgeIP, &logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
getPortFunc := func(conn *net.UDPConn) int {
|
||||
|
@ -591,34 +605,34 @@ func TestCreateUDPConnReuseSourcePort(t *testing.T) {
|
|||
conn.Close()
|
||||
|
||||
// should get the same port as before.
|
||||
conn, err = createUDPConnForConnIndex(0, nil, &logger)
|
||||
conn, err = createUDPConnForConnIndex(0, nil, edgeIP, &logger)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, initialPort, getPortFunc(conn))
|
||||
|
||||
// new index, should get a different port
|
||||
conn1, err := createUDPConnForConnIndex(1, nil, &logger)
|
||||
conn1, err := createUDPConnForConnIndex(1, nil, edgeIP, &logger)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, initialPort, getPortFunc(conn1))
|
||||
|
||||
// not closing the conn and trying to obtain a new conn for same index should give a different random port
|
||||
conn, err = createUDPConnForConnIndex(0, nil, &logger)
|
||||
conn, err = createUDPConnForConnIndex(0, nil, edgeIP, &logger)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, initialPort, getPortFunc(conn))
|
||||
}
|
||||
|
||||
func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) {
|
||||
func serveSession(ctx context.Context, datagramConn *datagramV2Connection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) {
|
||||
var (
|
||||
payload = []byte(t.Name())
|
||||
)
|
||||
sessionID := uuid.New()
|
||||
cfdConn, originConn := net.Pipe()
|
||||
// Registers and run a new session
|
||||
session, err := qc.sessionManager.RegisterSession(ctx, sessionID, cfdConn)
|
||||
session, err := datagramConn.sessionManager.RegisterSession(ctx, sessionID, cfdConn)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionDone := make(chan struct{})
|
||||
go func() {
|
||||
qc.serveUDPSession(session, time.Millisecond*50)
|
||||
datagramConn.serveUDPSession(session, time.Millisecond*50)
|
||||
close(sessionDone)
|
||||
}()
|
||||
|
||||
|
@ -642,7 +656,7 @@ func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.
|
|||
case closedByOrigin:
|
||||
originConn.Close()
|
||||
case closedByRemote:
|
||||
err = qc.UnregisterUdpSession(ctx, sessionID, expectedReason)
|
||||
err = datagramConn.UnregisterUdpSession(ctx, sessionID, expectedReason)
|
||||
require.NoError(t, err)
|
||||
case closedByTimeout:
|
||||
}
|
||||
|
@ -713,32 +727,59 @@ func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionI
|
|||
return nil
|
||||
}
|
||||
|
||||
func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QUICConnection {
|
||||
func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) (TunnelConnection, *datagramV2Connection) {
|
||||
tlsClientConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
NextProtos: []string{"argotunnel"},
|
||||
}
|
||||
// Start a mock httpProxy
|
||||
log := zerolog.New(os.Stdout)
|
||||
log := zerolog.New(io.Discard)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
qc, err := NewQUICConnection(
|
||||
|
||||
// Dial the QUIC connection to the edge
|
||||
conn, err := DialQuic(
|
||||
ctx,
|
||||
testQUICConfig,
|
||||
udpListenerAddr,
|
||||
nil,
|
||||
index,
|
||||
tlsClientConfig,
|
||||
&mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}},
|
||||
&tunnelpogs.ConnectionOptions{},
|
||||
fakeControlStream{},
|
||||
serverAddr,
|
||||
nil, // connect on a random port
|
||||
index,
|
||||
&log,
|
||||
nil,
|
||||
)
|
||||
|
||||
// Start a session manager for the connection
|
||||
sessionDemuxChan := make(chan *packet.Session, 4)
|
||||
datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, &log, sessionDemuxChan)
|
||||
sessionManager := datagramsession.NewManager(&log, datagramMuxer.SendToSession, sessionDemuxChan)
|
||||
var connIndex uint8 = 0
|
||||
packetRouter := ingress.NewPacketRouter(nil, datagramMuxer, connIndex, &log)
|
||||
|
||||
datagramConn := &datagramV2Connection{
|
||||
conn,
|
||||
sessionManager,
|
||||
datagramMuxer,
|
||||
packetRouter,
|
||||
15 * time.Second,
|
||||
0 * time.Second,
|
||||
&log,
|
||||
}
|
||||
|
||||
tunnelConn, err := NewTunnelConnection(
|
||||
ctx,
|
||||
conn,
|
||||
index,
|
||||
&mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}},
|
||||
datagramConn,
|
||||
fakeControlStream{},
|
||||
&pogs.ConnectionOptions{},
|
||||
15*time.Second,
|
||||
0*time.Second,
|
||||
0*time.Second,
|
||||
&log,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return qc
|
||||
return tunnelConn, datagramConn
|
||||
}
|
||||
|
||||
type mockReaderNoopWriter struct {
|
|
@ -0,0 +1,201 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/rs/zerolog"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/cloudflare/cloudflared/datagramsession"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
"github.com/cloudflare/cloudflared/packet"
|
||||
cfdquic "github.com/cloudflare/cloudflared/quic"
|
||||
"github.com/cloudflare/cloudflared/tracing"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
rpcquic "github.com/cloudflare/cloudflared/tunnelrpc/quic"
|
||||
)
|
||||
|
||||
const (
|
||||
// emperically this capacity has been working well
|
||||
demuxChanCapacity = 16
|
||||
)
|
||||
|
||||
// DatagramSessionHandler is a service that can serve datagrams for a connection and handle sessions from incoming
|
||||
// connection streams.
|
||||
type DatagramSessionHandler interface {
|
||||
Serve(context.Context) error
|
||||
|
||||
pogs.SessionManager
|
||||
}
|
||||
|
||||
type datagramV2Connection struct {
|
||||
conn quic.Connection
|
||||
|
||||
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
|
||||
sessionManager datagramsession.Manager
|
||||
// datagramMuxer mux/demux datagrams from quic connection
|
||||
datagramMuxer *cfdquic.DatagramMuxerV2
|
||||
packetRouter *ingress.PacketRouter
|
||||
|
||||
rpcTimeout time.Duration
|
||||
streamWriteTimeout time.Duration
|
||||
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
func NewDatagramV2Connection(ctx context.Context,
|
||||
conn quic.Connection,
|
||||
icmpRouter ingress.ICMPRouter,
|
||||
index uint8,
|
||||
rpcTimeout time.Duration,
|
||||
streamWriteTimeout time.Duration,
|
||||
logger *zerolog.Logger,
|
||||
) DatagramSessionHandler {
|
||||
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
|
||||
datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, logger, sessionDemuxChan)
|
||||
sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan)
|
||||
packetRouter := ingress.NewPacketRouter(icmpRouter, datagramMuxer, index, logger)
|
||||
|
||||
return &datagramV2Connection{
|
||||
conn,
|
||||
sessionManager,
|
||||
datagramMuxer,
|
||||
packetRouter,
|
||||
rpcTimeout,
|
||||
streamWriteTimeout,
|
||||
logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *datagramV2Connection) Serve(ctx context.Context) error {
|
||||
// If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits
|
||||
// as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this
|
||||
// connection).
|
||||
// If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the
|
||||
// other goroutine as fast as possible.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
return d.sessionManager.Serve(ctx)
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
return d.datagramMuxer.ServeReceive(ctx)
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
return d.packetRouter.Serve(ctx)
|
||||
})
|
||||
|
||||
return errGroup.Wait()
|
||||
}
|
||||
|
||||
// RegisterUdpSession is the RPC method invoked by edge to register and run a session
|
||||
func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*tunnelpogs.RegisterUdpSessionResponse, error) {
|
||||
traceCtx := tracing.NewTracedContext(ctx, traceContext, q.logger)
|
||||
ctx, registerSpan := traceCtx.Tracer().Start(traceCtx, "register-session", trace.WithAttributes(
|
||||
attribute.String("session-id", sessionID.String()),
|
||||
attribute.String("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)),
|
||||
))
|
||||
log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger()
|
||||
// Each session is a series of datagram from an eyeball to a dstIP:dstPort.
|
||||
// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
|
||||
originProxy, err := ingress.DialUDP(dstIP, dstPort)
|
||||
if err != nil {
|
||||
log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
|
||||
tracing.EndWithErrorStatus(registerSpan, err)
|
||||
return nil, err
|
||||
}
|
||||
registerSpan.SetAttributes(
|
||||
attribute.Bool("socket-bind-success", true),
|
||||
attribute.String("src", originProxy.LocalAddr().String()),
|
||||
)
|
||||
|
||||
session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy)
|
||||
if err != nil {
|
||||
originProxy.Close()
|
||||
log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
|
||||
tracing.EndWithErrorStatus(registerSpan, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go q.serveUDPSession(session, closeAfterIdleHint)
|
||||
|
||||
log.Debug().
|
||||
Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).
|
||||
Str("src", originProxy.LocalAddr().String()).
|
||||
Str("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)).
|
||||
Msgf("Registered session")
|
||||
tracing.End(registerSpan)
|
||||
|
||||
resp := tunnelpogs.RegisterUdpSessionResponse{
|
||||
Spans: traceCtx.GetProtoSpans(),
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UnregisterUdpSession is the RPC method invoked by edge to unregister and terminate a sesssion
|
||||
func (q *datagramV2Connection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
|
||||
return q.sessionManager.UnregisterSession(ctx, sessionID, message, true)
|
||||
}
|
||||
|
||||
func (q *datagramV2Connection) serveUDPSession(session *datagramsession.Session, closeAfterIdleHint time.Duration) {
|
||||
ctx := q.conn.Context()
|
||||
closedByRemote, err := session.Serve(ctx, closeAfterIdleHint)
|
||||
// If session is terminated by remote, then we know it has been unregistered from session manager and edge
|
||||
if !closedByRemote {
|
||||
if err != nil {
|
||||
q.closeUDPSession(ctx, session.ID, err.Error())
|
||||
} else {
|
||||
q.closeUDPSession(ctx, session.ID, "terminated without error")
|
||||
}
|
||||
}
|
||||
q.logger.Debug().Err(err).
|
||||
Int(management.EventTypeKey, int(management.UDP)).
|
||||
Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(session.ID)).
|
||||
Msg("Session terminated")
|
||||
}
|
||||
|
||||
// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge
|
||||
func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) {
|
||||
q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
|
||||
quicStream, err := q.conn.OpenStream()
|
||||
if err != nil {
|
||||
// Log this at debug because this is not an error if session was closed due to lost connection
|
||||
// with edge
|
||||
q.logger.Debug().Err(err).
|
||||
Int(management.EventTypeKey, int(management.UDP)).
|
||||
Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).
|
||||
Msgf("Failed to open quic stream to unregister udp session with edge")
|
||||
return
|
||||
}
|
||||
|
||||
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
|
||||
defer stream.Close()
|
||||
rpcClientStream, err := rpcquic.NewSessionClient(ctx, stream, q.rpcTimeout)
|
||||
if err != nil {
|
||||
// Log this at debug because this is not an error if session was closed due to lost connection
|
||||
// with edge
|
||||
q.logger.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).
|
||||
Msgf("Failed to open rpc stream to unregister udp session with edge")
|
||||
return
|
||||
}
|
||||
defer rpcClientStream.Close()
|
||||
|
||||
if err := rpcClientStream.UnregisterUdpSession(ctx, sessionID, message); err != nil {
|
||||
q.logger.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).
|
||||
Msgf("Failed to unregister udp session with edge")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
cfdquic "github.com/cloudflare/cloudflared/quic/v3"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
type datagramV3Connection struct {
|
||||
conn quic.Connection
|
||||
// datagramMuxer mux/demux datagrams from quic connection
|
||||
datagramMuxer cfdquic.DatagramConn
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
func NewDatagramV3Connection(ctx context.Context,
|
||||
conn quic.Connection,
|
||||
sessionManager cfdquic.SessionManager,
|
||||
icmpRouter ingress.ICMPRouter,
|
||||
index uint8,
|
||||
metrics cfdquic.Metrics,
|
||||
logger *zerolog.Logger,
|
||||
) DatagramSessionHandler {
|
||||
log := logger.
|
||||
With().
|
||||
Int(management.EventTypeKey, int(management.UDP)).
|
||||
Uint8(LogFieldConnIndex, index).
|
||||
Logger()
|
||||
datagramMuxer := cfdquic.NewDatagramConn(conn, sessionManager, icmpRouter, index, metrics, &log)
|
||||
|
||||
return &datagramV3Connection{
|
||||
conn,
|
||||
datagramMuxer,
|
||||
logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *datagramV3Connection) Serve(ctx context.Context) error {
|
||||
return d.datagramMuxer.Serve(ctx)
|
||||
}
|
||||
|
||||
func (d *datagramV3Connection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration, traceContext string) (*pogs.RegisterUdpSessionResponse, error) {
|
||||
return nil, fmt.Errorf("datagram v3 does not support RegisterUdpSession RPC")
|
||||
}
|
||||
|
||||
func (d *datagramV3Connection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
|
||||
return fmt.Errorf("datagram v3 does not support UnregisterUdpSession RPC")
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
@ -20,8 +21,15 @@ const (
|
|||
|
||||
var (
|
||||
errSessionManagerClosed = fmt.Errorf("session manager closed")
|
||||
LogFieldSessionID = "sessionID"
|
||||
)
|
||||
|
||||
func FormatSessionID(sessionID uuid.UUID) string {
|
||||
sessionIDStr := sessionID.String()
|
||||
sessionIDStr = strings.ReplaceAll(sessionIDStr, "-", "")
|
||||
return sessionIDStr
|
||||
}
|
||||
|
||||
// Manager defines the APIs to manage sessions from the same transport.
|
||||
type Manager interface {
|
||||
// Serve starts the event loop
|
||||
|
@ -127,7 +135,7 @@ func (m *manager) registerSession(ctx context.Context, registration *registerSes
|
|||
func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session {
|
||||
logger := m.log.With().
|
||||
Int(management.EventTypeKey, int(management.UDP)).
|
||||
Str("sessionID", id.String()).Logger()
|
||||
Str(LogFieldSessionID, FormatSessionID(id)).Logger()
|
||||
return &Session{
|
||||
ID: id,
|
||||
sendFunc: m.sendFunc,
|
||||
|
@ -174,7 +182,7 @@ func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) {
|
|||
func (m *manager) sendToSession(datagram *packet.Session) {
|
||||
session, ok := m.sessions[datagram.ID]
|
||||
if !ok {
|
||||
m.log.Error().Str("sessionID", datagram.ID.String()).Msg("session not found")
|
||||
m.log.Error().Str(LogFieldSessionID, FormatSessionID(datagram.ID)).Msg("session not found")
|
||||
return
|
||||
}
|
||||
// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
FROM golang:1.22.2 as builder
|
||||
FROM golang:1.22.5 as builder
|
||||
ENV GO111MODULE=on \
|
||||
CGO_ENABLED=0
|
||||
WORKDIR /go/src/github.com/cloudflare/cloudflared/
|
||||
|
|
|
@ -0,0 +1,216 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
)
|
||||
|
||||
type httpClient struct {
|
||||
http.Client
|
||||
baseURL *url.URL
|
||||
}
|
||||
|
||||
func NewHTTPClient() *httpClient {
|
||||
httpTransport := http.Transport{
|
||||
TLSHandshakeTimeout: defaultTimeout,
|
||||
ResponseHeaderTimeout: defaultTimeout,
|
||||
}
|
||||
|
||||
return &httpClient{
|
||||
http.Client{
|
||||
Transport: &httpTransport,
|
||||
Timeout: defaultTimeout,
|
||||
},
|
||||
nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (client *httpClient) SetBaseURL(baseURL *url.URL) {
|
||||
client.baseURL = baseURL
|
||||
}
|
||||
|
||||
func (client *httpClient) GET(ctx context.Context, endpoint string) (*http.Response, error) {
|
||||
if client.baseURL == nil {
|
||||
return nil, ErrNoBaseURL
|
||||
}
|
||||
url := client.baseURL.JoinPath(endpoint)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating GET request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Add("Accept", "application/json;version=1")
|
||||
|
||||
response, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error GET request: %w", err)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
type LogConfiguration struct {
|
||||
logFile string
|
||||
logDirectory string
|
||||
uid int // the uid of the user that started cloudflared
|
||||
}
|
||||
|
||||
func (client *httpClient) GetLogConfiguration(ctx context.Context) (*LogConfiguration, error) {
|
||||
response, err := client.GET(ctx, cliConfigurationEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
|
||||
var data map[string]string
|
||||
if err := json.NewDecoder(response.Body).Decode(&data); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode body: %w", err)
|
||||
}
|
||||
|
||||
uidStr, exists := data[configurationKeyUID]
|
||||
if !exists {
|
||||
return nil, ErrKeyNotFound
|
||||
}
|
||||
|
||||
uid, err := strconv.Atoi(uidStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error convertin pid to int: %w", err)
|
||||
}
|
||||
|
||||
logFile, exists := data[logger.LogFileFlag]
|
||||
if exists {
|
||||
return &LogConfiguration{logFile, "", uid}, nil
|
||||
}
|
||||
|
||||
logDirectory, exists := data[logger.LogDirectoryFlag]
|
||||
if exists {
|
||||
return &LogConfiguration{"", logDirectory, uid}, nil
|
||||
}
|
||||
|
||||
// No log configured may happen when cloudflared is executed as a managed service or
|
||||
// when containerized
|
||||
return &LogConfiguration{"", "", uid}, nil
|
||||
}
|
||||
|
||||
func (client *httpClient) GetMemoryDump(ctx context.Context, writer io.Writer) error {
|
||||
response, err := client.GET(ctx, memoryDumpEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return copyToWriter(response, writer)
|
||||
}
|
||||
|
||||
func (client *httpClient) GetGoroutineDump(ctx context.Context, writer io.Writer) error {
|
||||
response, err := client.GET(ctx, goroutineDumpEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return copyToWriter(response, writer)
|
||||
}
|
||||
|
||||
func (client *httpClient) GetTunnelState(ctx context.Context) (*TunnelState, error) {
|
||||
response, err := client.GET(ctx, tunnelStateEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
|
||||
var state TunnelState
|
||||
if err := json.NewDecoder(response.Body).Decode(&state); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode body: %w", err)
|
||||
}
|
||||
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
func (client *httpClient) GetSystemInformation(ctx context.Context, writer io.Writer) error {
|
||||
response, err := client.GET(ctx, systemInformationEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return copyJSONToWriter(response, writer)
|
||||
}
|
||||
|
||||
func (client *httpClient) GetMetrics(ctx context.Context, writer io.Writer) error {
|
||||
response, err := client.GET(ctx, metricsEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return copyToWriter(response, writer)
|
||||
}
|
||||
|
||||
func (client *httpClient) GetTunnelConfiguration(ctx context.Context, writer io.Writer) error {
|
||||
response, err := client.GET(ctx, tunnelConfigurationEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return copyJSONToWriter(response, writer)
|
||||
}
|
||||
|
||||
func (client *httpClient) GetCliConfiguration(ctx context.Context, writer io.Writer) error {
|
||||
response, err := client.GET(ctx, cliConfigurationEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return copyJSONToWriter(response, writer)
|
||||
}
|
||||
|
||||
func copyToWriter(response *http.Response, writer io.Writer) error {
|
||||
defer response.Body.Close()
|
||||
|
||||
_, err := io.Copy(writer, response.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing response: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyJSONToWriter(response *http.Response, writer io.Writer) error {
|
||||
defer response.Body.Close()
|
||||
|
||||
var data interface{}
|
||||
|
||||
decoder := json.NewDecoder(response.Body)
|
||||
|
||||
err := decoder.Decode(&data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("diagnostic client error whilst reading response: %w", err)
|
||||
}
|
||||
|
||||
encoder := newFormattedEncoder(writer)
|
||||
|
||||
err = encoder.Encode(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("diagnostic client error whilst writing json: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type HTTPClient interface {
|
||||
GetLogConfiguration(ctx context.Context) (*LogConfiguration, error)
|
||||
GetMemoryDump(ctx context.Context, writer io.Writer) error
|
||||
GetGoroutineDump(ctx context.Context, writer io.Writer) error
|
||||
GetTunnelState(ctx context.Context) (*TunnelState, error)
|
||||
GetSystemInformation(ctx context.Context, writer io.Writer) error
|
||||
GetMetrics(ctx context.Context, writer io.Writer) error
|
||||
GetCliConfiguration(ctx context.Context, writer io.Writer) error
|
||||
GetTunnelConfiguration(ctx context.Context, writer io.Writer) error
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package diagnostic
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
defaultCollectorTimeout = time.Second * 10 // This const define the timeout value of a collector operation.
|
||||
collectorField = "collector" // used for logging purposes
|
||||
systemCollectorName = "system" // used for logging purposes
|
||||
tunnelStateCollectorName = "tunnelState" // used for logging purposes
|
||||
configurationCollectorName = "configuration" // used for logging purposes
|
||||
defaultTimeout = 15 * time.Second // timeout for the collectors
|
||||
twoWeeksOffset = -14 * 24 * time.Hour // maximum offset for the logs
|
||||
logFilename = "cloudflared_logs.txt" // name of the output log file
|
||||
configurationKeyUID = "uid" // Key used to set and get the UID value from the configuration map
|
||||
tailMaxNumberOfLines = "10000" // maximum number of log lines from a virtual runtime (docker or kubernetes)
|
||||
|
||||
// Endpoints used by the diagnostic HTTP Client.
|
||||
cliConfigurationEndpoint = "/diag/configuration"
|
||||
tunnelStateEndpoint = "/diag/tunnel"
|
||||
systemInformationEndpoint = "/diag/system"
|
||||
memoryDumpEndpoint = "debug/pprof/heap"
|
||||
goroutineDumpEndpoint = "debug/pprof/goroutine"
|
||||
metricsEndpoint = "metrics"
|
||||
tunnelConfigurationEndpoint = "/config"
|
||||
// Base for filenames of the diagnostic procedure
|
||||
systemInformationBaseName = "systeminformation.json"
|
||||
metricsBaseName = "metrics.txt"
|
||||
zipName = "cloudflared-diag"
|
||||
heapPprofBaseName = "heap.pprof"
|
||||
goroutinePprofBaseName = "goroutine.pprof"
|
||||
networkBaseName = "network.json"
|
||||
rawNetworkBaseName = "raw-network.txt"
|
||||
tunnelStateBaseName = "tunnelstate.json"
|
||||
cliConfigurationBaseName = "cli-configuration.json"
|
||||
configurationBaseName = "configuration.json"
|
||||
taskResultBaseName = "task-result.json"
|
||||
)
|
|
@ -0,0 +1,561 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
network "github.com/cloudflare/cloudflared/diagnostic/network"
|
||||
)
|
||||
|
||||
const (
|
||||
taskSuccess = "success"
|
||||
taskFailure = "failure"
|
||||
jobReportName = "job report"
|
||||
tunnelStateJobName = "tunnel state"
|
||||
systemInformationJobName = "system information"
|
||||
goroutineJobName = "goroutine profile"
|
||||
heapJobName = "heap profile"
|
||||
metricsJobName = "metrics"
|
||||
logInformationJobName = "log information"
|
||||
rawNetworkInformationJobName = "raw network information"
|
||||
networkInformationJobName = "network information"
|
||||
cliConfigurationJobName = "cli configuration"
|
||||
configurationJobName = "configuration"
|
||||
)
|
||||
|
||||
// Struct used to hold the results of different routines executing the network collection.
|
||||
type taskResult struct {
|
||||
Result string `json:"result,omitempty"`
|
||||
Err error `json:"error,omitempty"`
|
||||
path string
|
||||
}
|
||||
|
||||
func (result taskResult) MarshalJSON() ([]byte, error) {
|
||||
s := map[string]string{
|
||||
"result": result.Result,
|
||||
}
|
||||
if result.Err != nil {
|
||||
s["error"] = result.Err.Error()
|
||||
}
|
||||
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
// Struct used to hold the results of different routines executing the network collection.
|
||||
type networkCollectionResult struct {
|
||||
name string
|
||||
info []*network.Hop
|
||||
raw string
|
||||
err error
|
||||
}
|
||||
|
||||
// This type represents the most common functions from the diagnostic http client
|
||||
// functions.
|
||||
type collectToWriterFunc func(ctx context.Context, writer io.Writer) error
|
||||
|
||||
// This type represents the common denominator among all the collection procedures.
|
||||
type collectFunc func(ctx context.Context) (string, error)
|
||||
|
||||
// collectJob is an internal struct that denotes holds the information necessary
|
||||
// to run a collection job.
|
||||
type collectJob struct {
|
||||
jobName string
|
||||
fn collectFunc
|
||||
bypass bool
|
||||
}
|
||||
|
||||
// The Toggles structure denotes the available toggles for the diagnostic procedure.
|
||||
// Each toggle enables/disables tasks from the diagnostic.
|
||||
type Toggles struct {
|
||||
NoDiagLogs bool
|
||||
NoDiagMetrics bool
|
||||
NoDiagSystem bool
|
||||
NoDiagRuntime bool
|
||||
NoDiagNetwork bool
|
||||
}
|
||||
|
||||
// The Options structure holds every option necessary for
|
||||
// the diagnostic procedure to work.
|
||||
type Options struct {
|
||||
KnownAddresses []string
|
||||
Address string
|
||||
ContainerID string
|
||||
PodID string
|
||||
Toggles Toggles
|
||||
}
|
||||
|
||||
func collectLogs(
|
||||
ctx context.Context,
|
||||
client HTTPClient,
|
||||
diagContainer, diagPod string,
|
||||
) (string, error) {
|
||||
var collector LogCollector
|
||||
if diagPod != "" {
|
||||
collector = NewKubernetesLogCollector(diagContainer, diagPod)
|
||||
} else if diagContainer != "" {
|
||||
collector = NewDockerLogCollector(diagContainer)
|
||||
} else {
|
||||
collector = NewHostLogCollector(client)
|
||||
}
|
||||
|
||||
logInformation, err := collector.Collect(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error collecting logs: %w", err)
|
||||
}
|
||||
|
||||
if logInformation.isDirectory {
|
||||
return CopyFilesFromDirectory(logInformation.path)
|
||||
}
|
||||
|
||||
if logInformation.wasCreated {
|
||||
return logInformation.path, nil
|
||||
}
|
||||
|
||||
logHandle, err := os.Open(logInformation.path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error opening log file while collecting logs: %w", err)
|
||||
}
|
||||
defer logHandle.Close()
|
||||
|
||||
outputLogHandle, err := os.Create(filepath.Join(os.TempDir(), logFilename))
|
||||
if err != nil {
|
||||
return "", ErrCreatingTemporaryFile
|
||||
}
|
||||
defer outputLogHandle.Close()
|
||||
|
||||
_, err = io.Copy(outputLogHandle, logHandle)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error copying logs while collecting logs: %w", err)
|
||||
}
|
||||
|
||||
return outputLogHandle.Name(), err
|
||||
}
|
||||
|
||||
func collectNetworkResultRoutine(
|
||||
ctx context.Context,
|
||||
collector network.NetworkCollector,
|
||||
hostname string,
|
||||
useIPv4 bool,
|
||||
results chan networkCollectionResult,
|
||||
) {
|
||||
const (
|
||||
hopsNo = 5
|
||||
timeout = time.Second * 5
|
||||
)
|
||||
|
||||
name := hostname
|
||||
|
||||
if useIPv4 {
|
||||
name += "-v4"
|
||||
} else {
|
||||
name += "-v6"
|
||||
}
|
||||
|
||||
hops, raw, err := collector.Collect(ctx, network.NewTraceOptions(hopsNo, timeout, hostname, useIPv4))
|
||||
results <- networkCollectionResult{name, hops, raw, err}
|
||||
}
|
||||
|
||||
func gatherNetworkInformation(ctx context.Context) map[string]networkCollectionResult {
|
||||
networkCollector := network.NetworkCollectorImpl{}
|
||||
|
||||
hostAndIPversionPairs := []struct {
|
||||
host string
|
||||
useV4 bool
|
||||
}{
|
||||
{"region1.v2.argotunnel.com", true},
|
||||
{"region1.v2.argotunnel.com", false},
|
||||
{"region2.v2.argotunnel.com", true},
|
||||
{"region2.v2.argotunnel.com", false},
|
||||
}
|
||||
|
||||
// the number of results is known thus use len to avoid footguns
|
||||
results := make(chan networkCollectionResult, len(hostAndIPversionPairs))
|
||||
|
||||
var wgroup sync.WaitGroup
|
||||
|
||||
for _, item := range hostAndIPversionPairs {
|
||||
wgroup.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wgroup.Done()
|
||||
collectNetworkResultRoutine(ctx, &networkCollector, item.host, item.useV4, results)
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for routines to end.
|
||||
wgroup.Wait()
|
||||
|
||||
resultMap := make(map[string]networkCollectionResult)
|
||||
|
||||
for range len(hostAndIPversionPairs) {
|
||||
result := <-results
|
||||
resultMap[result.name] = result
|
||||
}
|
||||
|
||||
return resultMap
|
||||
}
|
||||
|
||||
func networkInformationCollectors() (rawNetworkCollector, jsonNetworkCollector collectFunc) {
|
||||
// The network collector is an operation that takes most of the diagnostic time, thus,
|
||||
// the sync.Once is used to memoize the result of the collector and then create different
|
||||
// outputs.
|
||||
var once sync.Once
|
||||
|
||||
var resultMap map[string]networkCollectionResult
|
||||
|
||||
rawNetworkCollector = func(ctx context.Context) (string, error) {
|
||||
once.Do(func() { resultMap = gatherNetworkInformation(ctx) })
|
||||
|
||||
return rawNetworkInformationWriter(resultMap)
|
||||
}
|
||||
jsonNetworkCollector = func(ctx context.Context) (string, error) {
|
||||
once.Do(func() { resultMap = gatherNetworkInformation(ctx) })
|
||||
|
||||
return jsonNetworkInformationWriter(resultMap)
|
||||
}
|
||||
|
||||
return rawNetworkCollector, jsonNetworkCollector
|
||||
}
|
||||
|
||||
func rawNetworkInformationWriter(resultMap map[string]networkCollectionResult) (string, error) {
|
||||
networkDumpHandle, err := os.Create(filepath.Join(os.TempDir(), rawNetworkBaseName))
|
||||
if err != nil {
|
||||
return "", ErrCreatingTemporaryFile
|
||||
}
|
||||
|
||||
defer networkDumpHandle.Close()
|
||||
|
||||
var exitErr error
|
||||
|
||||
for k, v := range resultMap {
|
||||
if v.err != nil {
|
||||
if exitErr == nil {
|
||||
exitErr = v.err
|
||||
}
|
||||
|
||||
_, err := networkDumpHandle.WriteString(k + "\nno content\n")
|
||||
if err != nil {
|
||||
return networkDumpHandle.Name(), fmt.Errorf("error writing 'no content' to raw network file: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, err := networkDumpHandle.WriteString(k + "\n" + v.raw + "\n")
|
||||
if err != nil {
|
||||
return networkDumpHandle.Name(), fmt.Errorf("error writing raw network information: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return networkDumpHandle.Name(), exitErr
|
||||
}
|
||||
|
||||
func jsonNetworkInformationWriter(resultMap map[string]networkCollectionResult) (string, error) {
|
||||
networkDumpHandle, err := os.Create(filepath.Join(os.TempDir(), networkBaseName))
|
||||
if err != nil {
|
||||
return "", ErrCreatingTemporaryFile
|
||||
}
|
||||
|
||||
defer networkDumpHandle.Close()
|
||||
|
||||
encoder := newFormattedEncoder(networkDumpHandle)
|
||||
|
||||
var exitErr error
|
||||
|
||||
jsonMap := make(map[string][]*network.Hop, len(resultMap))
|
||||
for k, v := range resultMap {
|
||||
jsonMap[k] = v.info
|
||||
|
||||
if exitErr == nil && v.err != nil {
|
||||
exitErr = v.err
|
||||
}
|
||||
}
|
||||
|
||||
err = encoder.Encode(jsonMap)
|
||||
if err != nil {
|
||||
return networkDumpHandle.Name(), fmt.Errorf("error encoding network information results: %w", err)
|
||||
}
|
||||
|
||||
return networkDumpHandle.Name(), exitErr
|
||||
}
|
||||
|
||||
func collectFromEndpointAdapter(collect collectToWriterFunc, fileName string) collectFunc {
|
||||
return func(ctx context.Context) (string, error) {
|
||||
dumpHandle, err := os.Create(filepath.Join(os.TempDir(), fileName))
|
||||
if err != nil {
|
||||
return "", ErrCreatingTemporaryFile
|
||||
}
|
||||
defer dumpHandle.Close()
|
||||
|
||||
err = collect(ctx, dumpHandle)
|
||||
if err != nil {
|
||||
return dumpHandle.Name(), fmt.Errorf("error running collector: %w", err)
|
||||
}
|
||||
|
||||
return dumpHandle.Name(), nil
|
||||
}
|
||||
}
|
||||
|
||||
func tunnelStateCollectEndpointAdapter(client HTTPClient, tunnel *TunnelState, fileName string) collectFunc {
|
||||
endpointFunc := func(ctx context.Context, writer io.Writer) error {
|
||||
if tunnel == nil {
|
||||
// When the metrics server is not passed the diagnostic will query all known hosts
|
||||
// and get the tunnel state, however, when the metrics server is passed that won't
|
||||
// happen hence the check for nil in this function.
|
||||
tunnelResponse, err := client.GetTunnelState(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error retrieving tunnel state: %w", err)
|
||||
}
|
||||
|
||||
tunnel = tunnelResponse
|
||||
}
|
||||
|
||||
encoder := newFormattedEncoder(writer)
|
||||
|
||||
err := encoder.Encode(tunnel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error encoding tunnel state: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return collectFromEndpointAdapter(endpointFunc, fileName)
|
||||
}
|
||||
|
||||
// resolveInstanceBaseURL is responsible to
|
||||
// resolve the base URL of the instance that should be diagnosed.
|
||||
// To resolve the instance it may be necessary to query the
|
||||
// /diag/tunnel endpoint of the known instances, thus, if a single
|
||||
// instance is found its state is also returned; if multiple instances
|
||||
// are found then their states are returned in an array along with an
|
||||
// error.
|
||||
func resolveInstanceBaseURL(
|
||||
metricsServerAddress string,
|
||||
log *zerolog.Logger,
|
||||
client *httpClient,
|
||||
addresses []string,
|
||||
) (*url.URL, *TunnelState, []*AddressableTunnelState, error) {
|
||||
if metricsServerAddress != "" {
|
||||
if !strings.HasPrefix(metricsServerAddress, "http://") {
|
||||
metricsServerAddress = "http://" + metricsServerAddress
|
||||
}
|
||||
url, err := url.Parse(metricsServerAddress)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("provided address is not valid: %w", err)
|
||||
}
|
||||
|
||||
return url, nil, nil, nil
|
||||
}
|
||||
|
||||
tunnelState, foundTunnelStates, err := FindMetricsServer(log, client, addresses)
|
||||
if err != nil {
|
||||
return nil, nil, foundTunnelStates, err
|
||||
}
|
||||
|
||||
return tunnelState.URL, tunnelState.TunnelState, nil, nil
|
||||
}
|
||||
|
||||
func createJobs(
|
||||
client *httpClient,
|
||||
tunnel *TunnelState,
|
||||
diagContainer string,
|
||||
diagPod string,
|
||||
noDiagSystem bool,
|
||||
noDiagRuntime bool,
|
||||
noDiagMetrics bool,
|
||||
noDiagLogs bool,
|
||||
noDiagNetwork bool,
|
||||
) []collectJob {
|
||||
rawNetworkCollectorFunc, jsonNetworkCollectorFunc := networkInformationCollectors()
|
||||
jobs := []collectJob{
|
||||
{
|
||||
jobName: tunnelStateJobName,
|
||||
fn: tunnelStateCollectEndpointAdapter(client, tunnel, tunnelStateBaseName),
|
||||
bypass: false,
|
||||
},
|
||||
{
|
||||
jobName: systemInformationJobName,
|
||||
fn: collectFromEndpointAdapter(client.GetSystemInformation, systemInformationBaseName),
|
||||
bypass: noDiagSystem,
|
||||
},
|
||||
{
|
||||
jobName: goroutineJobName,
|
||||
fn: collectFromEndpointAdapter(client.GetGoroutineDump, goroutinePprofBaseName),
|
||||
bypass: noDiagRuntime,
|
||||
},
|
||||
{
|
||||
jobName: heapJobName,
|
||||
fn: collectFromEndpointAdapter(client.GetMemoryDump, heapPprofBaseName),
|
||||
bypass: noDiagRuntime,
|
||||
},
|
||||
{
|
||||
jobName: metricsJobName,
|
||||
fn: collectFromEndpointAdapter(client.GetMetrics, metricsBaseName),
|
||||
bypass: noDiagMetrics,
|
||||
},
|
||||
{
|
||||
jobName: logInformationJobName,
|
||||
fn: func(ctx context.Context) (string, error) {
|
||||
return collectLogs(ctx, client, diagContainer, diagPod)
|
||||
},
|
||||
bypass: noDiagLogs,
|
||||
},
|
||||
{
|
||||
jobName: rawNetworkInformationJobName,
|
||||
fn: rawNetworkCollectorFunc,
|
||||
bypass: noDiagNetwork,
|
||||
},
|
||||
{
|
||||
jobName: networkInformationJobName,
|
||||
fn: jsonNetworkCollectorFunc,
|
||||
bypass: noDiagNetwork,
|
||||
},
|
||||
{
|
||||
jobName: cliConfigurationJobName,
|
||||
fn: collectFromEndpointAdapter(client.GetCliConfiguration, cliConfigurationBaseName),
|
||||
bypass: false,
|
||||
},
|
||||
{
|
||||
jobName: configurationJobName,
|
||||
fn: collectFromEndpointAdapter(client.GetTunnelConfiguration, configurationBaseName),
|
||||
bypass: false,
|
||||
},
|
||||
}
|
||||
|
||||
return jobs
|
||||
}
|
||||
|
||||
func createTaskReport(taskReport map[string]taskResult) (string, error) {
|
||||
dumpHandle, err := os.Create(filepath.Join(os.TempDir(), taskResultBaseName))
|
||||
if err != nil {
|
||||
return "", ErrCreatingTemporaryFile
|
||||
}
|
||||
defer dumpHandle.Close()
|
||||
|
||||
encoder := newFormattedEncoder(dumpHandle)
|
||||
|
||||
err = encoder.Encode(taskReport)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error encoding task results: %w", err)
|
||||
}
|
||||
|
||||
return dumpHandle.Name(), nil
|
||||
}
|
||||
|
||||
func runJobs(ctx context.Context, jobs []collectJob, log *zerolog.Logger) map[string]taskResult {
|
||||
jobReport := make(map[string]taskResult, len(jobs))
|
||||
|
||||
for _, job := range jobs {
|
||||
if job.bypass {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info().Msgf("Collecting %s...", job.jobName)
|
||||
path, err := job.fn(ctx)
|
||||
|
||||
var result taskResult
|
||||
if err != nil {
|
||||
result = taskResult{Result: taskFailure, Err: err, path: path}
|
||||
|
||||
log.Error().Err(err).Msgf("Job: %s finished with error.", job.jobName)
|
||||
} else {
|
||||
result = taskResult{Result: taskSuccess, Err: nil, path: path}
|
||||
|
||||
log.Info().Msgf("Collected %s.", job.jobName)
|
||||
}
|
||||
|
||||
jobReport[job.jobName] = result
|
||||
}
|
||||
|
||||
taskReportName, err := createTaskReport(jobReport)
|
||||
|
||||
var result taskResult
|
||||
|
||||
if err != nil {
|
||||
result = taskResult{
|
||||
Result: taskFailure,
|
||||
path: taskReportName,
|
||||
Err: err,
|
||||
}
|
||||
} else {
|
||||
result = taskResult{
|
||||
Result: taskSuccess,
|
||||
path: taskReportName,
|
||||
Err: nil,
|
||||
}
|
||||
}
|
||||
|
||||
jobReport[jobReportName] = result
|
||||
|
||||
return jobReport
|
||||
}
|
||||
|
||||
func RunDiagnostic(
|
||||
log *zerolog.Logger,
|
||||
options Options,
|
||||
) ([]*AddressableTunnelState, error) {
|
||||
client := NewHTTPClient()
|
||||
|
||||
baseURL, tunnel, foundTunnels, err := resolveInstanceBaseURL(options.Address, log, client, options.KnownAddresses)
|
||||
if err != nil {
|
||||
return foundTunnels, err
|
||||
}
|
||||
|
||||
log.Info().Msgf("Selected server %s starting diagnostic...", baseURL.String())
|
||||
client.SetBaseURL(baseURL)
|
||||
|
||||
const timeout = 45 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
|
||||
defer cancel()
|
||||
|
||||
jobs := createJobs(
|
||||
client,
|
||||
tunnel,
|
||||
options.ContainerID,
|
||||
options.PodID,
|
||||
options.Toggles.NoDiagSystem,
|
||||
options.Toggles.NoDiagRuntime,
|
||||
options.Toggles.NoDiagMetrics,
|
||||
options.Toggles.NoDiagLogs,
|
||||
options.Toggles.NoDiagNetwork,
|
||||
)
|
||||
|
||||
jobsReport := runJobs(ctx, jobs, log)
|
||||
paths := make([]string, 0)
|
||||
|
||||
var gerr error
|
||||
|
||||
for _, v := range jobsReport {
|
||||
paths = append(paths, v.path)
|
||||
|
||||
if gerr == nil && v.Err != nil {
|
||||
gerr = v.Err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if !errors.Is(v.Err, ErrCreatingTemporaryFile) {
|
||||
os.Remove(v.path)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
zipfile, err := CreateDiagnosticZipFile(zipName, paths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info().Msgf("Diagnostic file written: %v", zipfile)
|
||||
|
||||
return nil, gerr
|
||||
}
|
|
@ -0,0 +1,148 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// CreateDiagnosticZipFile create a zip file with the contents from the all
|
||||
// files paths. The files will be written in the root of the zip file.
|
||||
// In case of an error occurs after whilst writing to the zip file
|
||||
// this will be removed.
|
||||
func CreateDiagnosticZipFile(base string, paths []string) (zipFileName string, err error) {
|
||||
// Create a zip file with all files from paths added to the root
|
||||
suffix := time.Now().Format(time.RFC3339)
|
||||
zipFileName = base + "-" + suffix + ".zip"
|
||||
zipFileName = strings.ReplaceAll(zipFileName, ":", "-")
|
||||
|
||||
archive, cerr := os.Create(zipFileName)
|
||||
if cerr != nil {
|
||||
return "", fmt.Errorf("error creating file %s: %w", zipFileName, cerr)
|
||||
}
|
||||
|
||||
archiveWriter := zip.NewWriter(archive)
|
||||
|
||||
defer func() {
|
||||
archiveWriter.Close()
|
||||
archive.Close()
|
||||
|
||||
if err != nil {
|
||||
os.Remove(zipFileName)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, file := range paths {
|
||||
if file == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var handle *os.File
|
||||
|
||||
handle, err = os.Open(file)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error opening file %s: %w", zipFileName, err)
|
||||
}
|
||||
|
||||
defer handle.Close()
|
||||
|
||||
// Keep the base only to not create sub directories in the
|
||||
// zip file.
|
||||
var writer io.Writer
|
||||
|
||||
writer, err = archiveWriter.Create(filepath.Base(file))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error creating archive writer from %s: %w", file, err)
|
||||
}
|
||||
|
||||
if _, err = io.Copy(writer, handle); err != nil {
|
||||
return "", fmt.Errorf("error copying file %s: %w", file, err)
|
||||
}
|
||||
}
|
||||
|
||||
zipFileName = archive.Name()
|
||||
return zipFileName, nil
|
||||
}
|
||||
|
||||
type AddressableTunnelState struct {
|
||||
*TunnelState
|
||||
URL *url.URL
|
||||
}
|
||||
|
||||
func findMetricsServerPredicate(tunnelID, connectorID uuid.UUID) func(state *TunnelState) bool {
|
||||
if tunnelID != uuid.Nil && connectorID != uuid.Nil {
|
||||
return func(state *TunnelState) bool {
|
||||
return state.ConnectorID == connectorID && state.TunnelID == tunnelID
|
||||
}
|
||||
} else if tunnelID == uuid.Nil && connectorID != uuid.Nil {
|
||||
return func(state *TunnelState) bool {
|
||||
return state.ConnectorID == connectorID
|
||||
}
|
||||
} else if tunnelID != uuid.Nil && connectorID == uuid.Nil {
|
||||
return func(state *TunnelState) bool {
|
||||
return state.TunnelID == tunnelID
|
||||
}
|
||||
}
|
||||
|
||||
return func(*TunnelState) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// The FindMetricsServer will try to find the metrics server url.
|
||||
// There are two possible error scenarios:
|
||||
// 1. No instance is found which will only return ErrMetricsServerNotFound
|
||||
// 2. Multiple instances are found which will return an array of state and ErrMultipleMetricsServerFound
|
||||
// In case of success, only the state for the instance is returned.
|
||||
func FindMetricsServer(
|
||||
log *zerolog.Logger,
|
||||
client *httpClient,
|
||||
addresses []string,
|
||||
) (*AddressableTunnelState, []*AddressableTunnelState, error) {
|
||||
instances := make([]*AddressableTunnelState, 0)
|
||||
|
||||
for _, address := range addresses {
|
||||
url, err := url.Parse("http://" + address)
|
||||
if err != nil {
|
||||
log.Debug().Err(err).Msgf("error parsing address %s", address)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
client.SetBaseURL(url)
|
||||
|
||||
state, err := client.GetTunnelState(context.Background())
|
||||
if err == nil {
|
||||
instances = append(instances, &AddressableTunnelState{state, url})
|
||||
} else {
|
||||
log.Debug().Err(err).Msgf("error getting tunnel state from address %s", address)
|
||||
}
|
||||
}
|
||||
|
||||
if len(instances) == 0 {
|
||||
return nil, nil, ErrMetricsServerNotFound
|
||||
}
|
||||
|
||||
if len(instances) == 1 {
|
||||
return instances[0], nil, nil
|
||||
}
|
||||
|
||||
return nil, instances, ErrMultipleMetricsServerFound
|
||||
}
|
||||
|
||||
// newFormattedEncoder return a JSON encoder with identation
|
||||
func newFormattedEncoder(w io.Writer) *json.Encoder {
|
||||
encoder := json.NewEncoder(w)
|
||||
encoder.SetIndent("", " ")
|
||||
return encoder
|
||||
}
|
|
@ -0,0 +1,147 @@
|
|||
package diagnostic_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/facebookgo/grace/gracenet"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cloudflare/cloudflared/diagnostic"
|
||||
"github.com/cloudflare/cloudflared/metrics"
|
||||
"github.com/cloudflare/cloudflared/tunnelstate"
|
||||
)
|
||||
|
||||
func helperCreateServer(t *testing.T, listeners *gracenet.Net, tunnelID uuid.UUID, connectorID uuid.UUID) func() {
|
||||
t.Helper()
|
||||
listener, err := metrics.CreateMetricsListener(listeners, "localhost:0")
|
||||
require.NoError(t, err)
|
||||
log := zerolog.Nop()
|
||||
tracker := tunnelstate.NewConnTracker(&log)
|
||||
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, map[string]string{}, []string{})
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("/diag/tunnel", handler.TunnelStateHandler)
|
||||
server := &http.Server{
|
||||
ReadTimeout: 10 * time.Second,
|
||||
WriteTimeout: 10 * time.Second,
|
||||
Handler: router,
|
||||
}
|
||||
|
||||
var wgroup sync.WaitGroup
|
||||
|
||||
wgroup.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wgroup.Done()
|
||||
|
||||
_ = server.Serve(listener)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
|
||||
cleanUp := func() {
|
||||
_ = server.Shutdown(ctx)
|
||||
|
||||
cancel()
|
||||
wgroup.Wait()
|
||||
}
|
||||
|
||||
return cleanUp
|
||||
}
|
||||
|
||||
func TestFindMetricsServer_WhenSingleServerIsRunning_ReturnState(t *testing.T) {
|
||||
listeners := gracenet.Net{}
|
||||
tid1 := uuid.New()
|
||||
cid1 := uuid.New()
|
||||
|
||||
cleanUp := helperCreateServer(t, &listeners, tid1, cid1)
|
||||
defer cleanUp()
|
||||
|
||||
log := zerolog.Nop()
|
||||
client := diagnostic.NewHTTPClient()
|
||||
addresses := metrics.GetMetricsKnownAddresses("host")
|
||||
url1, err := url.Parse("http://localhost:20241")
|
||||
require.NoError(t, err)
|
||||
|
||||
tunnel1 := &diagnostic.AddressableTunnelState{
|
||||
TunnelState: &diagnostic.TunnelState{
|
||||
TunnelID: tid1,
|
||||
ConnectorID: cid1,
|
||||
Connections: nil,
|
||||
},
|
||||
URL: url1,
|
||||
}
|
||||
|
||||
state, tunnels, err := diagnostic.FindMetricsServer(&log, client, addresses[:])
|
||||
if err != nil {
|
||||
require.ErrorIs(t, err, diagnostic.ErrMultipleMetricsServerFound)
|
||||
}
|
||||
|
||||
assert.Equal(t, tunnel1, state)
|
||||
assert.Nil(t, tunnels)
|
||||
}
|
||||
|
||||
func TestFindMetricsServer_WhenMultipleServerAreRunning_ReturnError(t *testing.T) {
|
||||
listeners := gracenet.Net{}
|
||||
tid1 := uuid.New()
|
||||
cid1 := uuid.New()
|
||||
cid2 := uuid.New()
|
||||
|
||||
cleanUp := helperCreateServer(t, &listeners, tid1, cid1)
|
||||
defer cleanUp()
|
||||
|
||||
cleanUp = helperCreateServer(t, &listeners, tid1, cid2)
|
||||
defer cleanUp()
|
||||
|
||||
log := zerolog.Nop()
|
||||
client := diagnostic.NewHTTPClient()
|
||||
addresses := metrics.GetMetricsKnownAddresses("host")
|
||||
url1, err := url.Parse("http://localhost:20241")
|
||||
require.NoError(t, err)
|
||||
url2, err := url.Parse("http://localhost:20242")
|
||||
require.NoError(t, err)
|
||||
|
||||
tunnel1 := &diagnostic.AddressableTunnelState{
|
||||
TunnelState: &diagnostic.TunnelState{
|
||||
TunnelID: tid1,
|
||||
ConnectorID: cid1,
|
||||
Connections: nil,
|
||||
},
|
||||
URL: url1,
|
||||
}
|
||||
tunnel2 := &diagnostic.AddressableTunnelState{
|
||||
TunnelState: &diagnostic.TunnelState{
|
||||
TunnelID: tid1,
|
||||
ConnectorID: cid2,
|
||||
Connections: nil,
|
||||
},
|
||||
URL: url2,
|
||||
}
|
||||
|
||||
state, tunnels, err := diagnostic.FindMetricsServer(&log, client, addresses[:])
|
||||
if err != nil {
|
||||
require.ErrorIs(t, err, diagnostic.ErrMultipleMetricsServerFound)
|
||||
}
|
||||
|
||||
assert.Nil(t, state)
|
||||
assert.Equal(t, []*diagnostic.AddressableTunnelState{tunnel1, tunnel2}, tunnels)
|
||||
}
|
||||
|
||||
func TestFindMetricsServer_WhenNoInstanceIsRuning_ReturnError(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
client := diagnostic.NewHTTPClient()
|
||||
addresses := metrics.GetMetricsKnownAddresses("host")
|
||||
|
||||
state, tunnels, err := diagnostic.FindMetricsServer(&log, client, addresses[:])
|
||||
require.ErrorIs(t, err, diagnostic.ErrMetricsServerNotFound)
|
||||
|
||||
assert.Nil(t, state)
|
||||
assert.Nil(t, tunnels)
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// Error used when there is no log directory available.
|
||||
ErrManagedLogNotFound = errors.New("managed log directory not found")
|
||||
// Error used when it is not possible to collect logs using the log configuration.
|
||||
ErrLogConfigurationIsInvalid = errors.New("provided log configuration is invalid")
|
||||
// Error used when parsing the fields of the output of collector.
|
||||
ErrInsufficientLines = errors.New("insufficient lines")
|
||||
// Error used when parsing the lines of the output of collector.
|
||||
ErrInsuficientFields = errors.New("insufficient fields")
|
||||
// Error used when given key is not found while parsing KV.
|
||||
ErrKeyNotFound = errors.New("key not found")
|
||||
// Error used when there is no disk volume information available.
|
||||
ErrNoVolumeFound = errors.New("no disk volume information found")
|
||||
// Error user when the base url of the diagnostic client is not provided.
|
||||
ErrNoBaseURL = errors.New("no base url")
|
||||
// Error used when no metrics server is found listening to the known addresses list (check [metrics.GetMetricsKnownAddresses]).
|
||||
ErrMetricsServerNotFound = errors.New("metrics server not found")
|
||||
// Error used when multiple metrics server are found listening to the known addresses list (check [metrics.GetMetricsKnownAddresses]).
|
||||
ErrMultipleMetricsServerFound = errors.New("multiple metrics server found")
|
||||
// Error used when a temporary file creation fails within the diagnostic procedure
|
||||
ErrCreatingTemporaryFile = errors.New("temporary file creation failed")
|
||||
)
|
|
@ -0,0 +1,144 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelstate"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
log *zerolog.Logger
|
||||
timeout time.Duration
|
||||
systemCollector SystemCollector
|
||||
tunnelID uuid.UUID
|
||||
connectorID uuid.UUID
|
||||
tracker *tunnelstate.ConnTracker
|
||||
cliFlags map[string]string
|
||||
icmpSources []string
|
||||
}
|
||||
|
||||
func NewDiagnosticHandler(
|
||||
log *zerolog.Logger,
|
||||
timeout time.Duration,
|
||||
systemCollector SystemCollector,
|
||||
tunnelID uuid.UUID,
|
||||
connectorID uuid.UUID,
|
||||
tracker *tunnelstate.ConnTracker,
|
||||
cliFlags map[string]string,
|
||||
icmpSources []string,
|
||||
) *Handler {
|
||||
logger := log.With().Logger()
|
||||
if timeout == 0 {
|
||||
timeout = defaultCollectorTimeout
|
||||
}
|
||||
|
||||
cliFlags[configurationKeyUID] = strconv.Itoa(os.Getuid())
|
||||
return &Handler{
|
||||
log: &logger,
|
||||
timeout: timeout,
|
||||
systemCollector: systemCollector,
|
||||
tunnelID: tunnelID,
|
||||
connectorID: connectorID,
|
||||
tracker: tracker,
|
||||
cliFlags: cliFlags,
|
||||
icmpSources: icmpSources,
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *Handler) InstallEndpoints(router *http.ServeMux) {
|
||||
router.HandleFunc(cliConfigurationEndpoint, handler.ConfigurationHandler)
|
||||
router.HandleFunc(tunnelStateEndpoint, handler.TunnelStateHandler)
|
||||
router.HandleFunc(systemInformationEndpoint, handler.SystemHandler)
|
||||
}
|
||||
|
||||
type SystemInformationResponse struct {
|
||||
Info *SystemInformation `json:"info"`
|
||||
Err error `json:"errors"`
|
||||
}
|
||||
|
||||
func (handler *Handler) SystemHandler(writer http.ResponseWriter, request *http.Request) {
|
||||
logger := handler.log.With().Str(collectorField, systemCollectorName).Logger()
|
||||
logger.Info().Msg("Collection started")
|
||||
|
||||
defer logger.Info().Msg("Collection finished")
|
||||
|
||||
ctx, cancel := context.WithTimeout(request.Context(), handler.timeout)
|
||||
|
||||
defer cancel()
|
||||
|
||||
info, err := handler.systemCollector.Collect(ctx)
|
||||
|
||||
response := SystemInformationResponse{
|
||||
Info: info,
|
||||
Err: err,
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(writer)
|
||||
err = encoder.Encode(response)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msgf("error occurred whilst serializing information")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
type TunnelState struct {
|
||||
TunnelID uuid.UUID `json:"tunnelID,omitempty"`
|
||||
ConnectorID uuid.UUID `json:"connectorID,omitempty"`
|
||||
Connections []tunnelstate.IndexedConnectionInfo `json:"connections,omitempty"`
|
||||
ICMPSources []string `json:"icmp_sources,omitempty"`
|
||||
}
|
||||
|
||||
func (handler *Handler) TunnelStateHandler(writer http.ResponseWriter, _ *http.Request) {
|
||||
log := handler.log.With().Str(collectorField, tunnelStateCollectorName).Logger()
|
||||
log.Info().Msg("Collection started")
|
||||
|
||||
defer log.Info().Msg("Collection finished")
|
||||
|
||||
body := TunnelState{
|
||||
handler.tunnelID,
|
||||
handler.connectorID,
|
||||
handler.tracker.GetActiveConnections(),
|
||||
handler.icmpSources,
|
||||
}
|
||||
encoder := json.NewEncoder(writer)
|
||||
|
||||
err := encoder.Encode(body)
|
||||
if err != nil {
|
||||
handler.log.Error().Err(err).Msgf("error occurred whilst serializing information")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *Handler) ConfigurationHandler(writer http.ResponseWriter, _ *http.Request) {
|
||||
log := handler.log.With().Str(collectorField, configurationCollectorName).Logger()
|
||||
log.Info().Msg("Collection started")
|
||||
|
||||
defer func() {
|
||||
log.Info().Msg("Collection finished")
|
||||
}()
|
||||
|
||||
encoder := json.NewEncoder(writer)
|
||||
|
||||
err := encoder.Encode(handler.cliFlags)
|
||||
if err != nil {
|
||||
handler.log.Error().Err(err).Msgf("error occurred whilst serializing response")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func writeResponse(w http.ResponseWriter, bytes []byte, logger *zerolog.Logger) {
|
||||
bytesWritten, err := w.Write(bytes)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("error occurred writing response")
|
||||
} else if bytesWritten != len(bytes) {
|
||||
logger.Error().Msgf("error incomplete write response %d/%d", bytesWritten, len(bytes))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,224 @@
|
|||
package diagnostic_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/diagnostic"
|
||||
"github.com/cloudflare/cloudflared/tunnelstate"
|
||||
)
|
||||
|
||||
type SystemCollectorMock struct {
|
||||
systemInfo *diagnostic.SystemInformation
|
||||
err error
|
||||
}
|
||||
|
||||
const (
|
||||
systemInformationKey = "sikey"
|
||||
errorKey = "errkey"
|
||||
)
|
||||
|
||||
func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker {
|
||||
t.Helper()
|
||||
|
||||
log := zerolog.Nop()
|
||||
tracker := tunnelstate.NewConnTracker(&log)
|
||||
|
||||
for _, conn := range connections {
|
||||
tracker.OnTunnelEvent(connection.Event{
|
||||
Index: conn.Index,
|
||||
EventType: connection.Connected,
|
||||
Protocol: conn.Protocol,
|
||||
EdgeAddress: conn.EdgeAddress,
|
||||
})
|
||||
}
|
||||
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (collector *SystemCollectorMock) Collect(context.Context) (*diagnostic.SystemInformation, error) {
|
||||
return collector.systemInfo, collector.err
|
||||
}
|
||||
|
||||
func TestSystemHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := zerolog.Nop()
|
||||
tests := []struct {
|
||||
name string
|
||||
systemInfo *diagnostic.SystemInformation
|
||||
err error
|
||||
statusCode int
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
systemInfo: diagnostic.NewSystemInformation(
|
||||
0, 0, 0, 0,
|
||||
"string", "string", "string", "string",
|
||||
"string", "string",
|
||||
runtime.Version(), runtime.GOARCH, nil,
|
||||
),
|
||||
|
||||
err: nil,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "on error and no raw info", systemInfo: nil,
|
||||
err: errors.New("an error"), statusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range tests {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{
|
||||
systemInfo: tCase.systemInfo,
|
||||
err: tCase.err,
|
||||
}, uuid.New(), uuid.New(), nil, map[string]string{}, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx := context.Background()
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/system", nil)
|
||||
require.NoError(t, err)
|
||||
handler.SystemHandler(recorder, request)
|
||||
|
||||
assert.Equal(t, tCase.statusCode, recorder.Code)
|
||||
if tCase.statusCode == http.StatusOK && tCase.systemInfo != nil {
|
||||
var response diagnostic.SystemInformationResponse
|
||||
decoder := json.NewDecoder(recorder.Body)
|
||||
err := decoder.Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tCase.systemInfo, response.Info)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTunnelStateHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := zerolog.Nop()
|
||||
tests := []struct {
|
||||
name string
|
||||
tunnelID uuid.UUID
|
||||
clientID uuid.UUID
|
||||
connections []tunnelstate.IndexedConnectionInfo
|
||||
icmpSources []string
|
||||
}{
|
||||
{
|
||||
name: "case1",
|
||||
tunnelID: uuid.New(),
|
||||
clientID: uuid.New(),
|
||||
},
|
||||
{
|
||||
name: "case2",
|
||||
tunnelID: uuid.New(),
|
||||
clientID: uuid.New(),
|
||||
icmpSources: []string{"172.17.0.3", "::1"},
|
||||
connections: []tunnelstate.IndexedConnectionInfo{{
|
||||
ConnectionInfo: tunnelstate.ConnectionInfo{
|
||||
IsConnected: true,
|
||||
Protocol: connection.QUIC,
|
||||
EdgeAddress: net.IPv4(100, 100, 100, 100),
|
||||
},
|
||||
Index: 0,
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range tests {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tracker := newTrackerFromConns(t, tCase.connections)
|
||||
handler := diagnostic.NewDiagnosticHandler(
|
||||
&log,
|
||||
0,
|
||||
nil,
|
||||
tCase.tunnelID,
|
||||
tCase.clientID,
|
||||
tracker,
|
||||
map[string]string{},
|
||||
tCase.icmpSources,
|
||||
)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.TunnelStateHandler(recorder, nil)
|
||||
decoder := json.NewDecoder(recorder.Body)
|
||||
|
||||
var response diagnostic.TunnelState
|
||||
err := decoder.Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, tCase.tunnelID, response.TunnelID)
|
||||
assert.Equal(t, tCase.clientID, response.ConnectorID)
|
||||
assert.Equal(t, tCase.connections, response.Connections)
|
||||
assert.Equal(t, tCase.icmpSources, response.ICMPSources)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigurationHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
log := zerolog.Nop()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
flags map[string]string
|
||||
expected map[string]string
|
||||
}{
|
||||
{
|
||||
name: "empty cli",
|
||||
flags: make(map[string]string),
|
||||
expected: map[string]string{
|
||||
"uid": "0",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cli with flags",
|
||||
flags: map[string]string{
|
||||
"b": "a",
|
||||
"c": "a",
|
||||
"d": "a",
|
||||
"uid": "0",
|
||||
},
|
||||
expected: map[string]string{
|
||||
"b": "a",
|
||||
"c": "a",
|
||||
"d": "a",
|
||||
"uid": "0",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range tests {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var response map[string]string
|
||||
|
||||
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, tCase.flags, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ConfigurationHandler(recorder, nil)
|
||||
decoder := json.NewDecoder(recorder.Body)
|
||||
err := decoder.Decode(&response)
|
||||
require.NoError(t, err)
|
||||
_, ok := response["uid"]
|
||||
assert.True(t, ok)
|
||||
delete(tCase.expected, "uid")
|
||||
delete(response, "uid")
|
||||
assert.Equal(t, http.StatusOK, recorder.Code)
|
||||
assert.Equal(t, tCase.expected, response)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Represents the path of the log file or log directory.
|
||||
// This struct is meant to give some ergonimics regarding
|
||||
// the logging information.
|
||||
type LogInformation struct {
|
||||
path string // path to a file or directory
|
||||
wasCreated bool // denotes if `path` was created
|
||||
isDirectory bool // denotes if `path` is a directory
|
||||
}
|
||||
|
||||
func NewLogInformation(
|
||||
path string,
|
||||
wasCreated bool,
|
||||
isDirectory bool,
|
||||
) *LogInformation {
|
||||
return &LogInformation{
|
||||
path,
|
||||
wasCreated,
|
||||
isDirectory,
|
||||
}
|
||||
}
|
||||
|
||||
type LogCollector interface {
|
||||
// This function is responsible for returning a path to a single file
|
||||
// whose contents are the logs of a cloudflared instance.
|
||||
// A new file may be create by a LogCollector, thus, its the caller
|
||||
// responsibility to remove the newly create file.
|
||||
Collect(ctx context.Context) (*LogInformation, error)
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DockerLogCollector struct {
|
||||
containerID string // This member identifies the container by identifier or name
|
||||
}
|
||||
|
||||
func NewDockerLogCollector(containerID string) *DockerLogCollector {
|
||||
return &DockerLogCollector{
|
||||
containerID,
|
||||
}
|
||||
}
|
||||
|
||||
func (collector *DockerLogCollector) Collect(ctx context.Context) (*LogInformation, error) {
|
||||
tmp := os.TempDir()
|
||||
|
||||
outputHandle, err := os.Create(filepath.Join(tmp, logFilename))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening output file: %w", err)
|
||||
}
|
||||
|
||||
defer outputHandle.Close()
|
||||
|
||||
// Calculate 2 weeks ago
|
||||
since := time.Now().Add(twoWeeksOffset).Format(time.RFC3339)
|
||||
|
||||
command := exec.CommandContext(
|
||||
ctx,
|
||||
"docker",
|
||||
"logs",
|
||||
"--tail",
|
||||
tailMaxNumberOfLines,
|
||||
"--since",
|
||||
since,
|
||||
collector.containerID,
|
||||
)
|
||||
|
||||
return PipeCommandOutputToFile(command, outputHandle)
|
||||
}
|
|
@ -0,0 +1,105 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const (
|
||||
linuxManagedLogsPath = "/var/log/cloudflared.err"
|
||||
darwinManagedLogsPath = "/Library/Logs/com.cloudflare.cloudflared.err.log"
|
||||
linuxServiceConfigurationPath = "/etc/systemd/system/cloudflared.service"
|
||||
linuxSystemdPath = "/run/systemd/system"
|
||||
)
|
||||
|
||||
type HostLogCollector struct {
|
||||
client HTTPClient
|
||||
}
|
||||
|
||||
func NewHostLogCollector(client HTTPClient) *HostLogCollector {
|
||||
return &HostLogCollector{
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
||||
func extractLogsFromJournalCtl(ctx context.Context) (*LogInformation, error) {
|
||||
tmp := os.TempDir()
|
||||
|
||||
outputHandle, err := os.Create(filepath.Join(tmp, logFilename))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening output file: %w", err)
|
||||
}
|
||||
|
||||
defer outputHandle.Close()
|
||||
|
||||
command := exec.CommandContext(
|
||||
ctx,
|
||||
"journalctl",
|
||||
"--since",
|
||||
"2 weeks ago",
|
||||
"-u",
|
||||
"cloudflared.service",
|
||||
)
|
||||
|
||||
return PipeCommandOutputToFile(command, outputHandle)
|
||||
}
|
||||
|
||||
func getServiceLogPath() (string, error) {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
{
|
||||
path := darwinManagedLogsPath
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path, nil
|
||||
}
|
||||
|
||||
userHomeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting user home: %w", err)
|
||||
}
|
||||
|
||||
return filepath.Join(userHomeDir, darwinManagedLogsPath), nil
|
||||
}
|
||||
case "linux":
|
||||
{
|
||||
return linuxManagedLogsPath, nil
|
||||
}
|
||||
default:
|
||||
return "", ErrManagedLogNotFound
|
||||
}
|
||||
}
|
||||
|
||||
func (collector *HostLogCollector) Collect(ctx context.Context) (*LogInformation, error) {
|
||||
logConfiguration, err := collector.client.GetLogConfiguration(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting log configuration: %w", err)
|
||||
}
|
||||
|
||||
if logConfiguration.uid == 0 {
|
||||
_, statSystemdErr := os.Stat(linuxServiceConfigurationPath)
|
||||
|
||||
_, statServiceConfigurationErr := os.Stat(linuxServiceConfigurationPath)
|
||||
if statSystemdErr == nil && statServiceConfigurationErr == nil && runtime.GOOS == "linux" {
|
||||
return extractLogsFromJournalCtl(ctx)
|
||||
}
|
||||
|
||||
path, err := getServiceLogPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewLogInformation(path, false, false), nil
|
||||
}
|
||||
|
||||
if logConfiguration.logFile != "" {
|
||||
return NewLogInformation(logConfiguration.logFile, false, false), nil
|
||||
} else if logConfiguration.logDirectory != "" {
|
||||
return NewLogInformation(logConfiguration.logDirectory, false, true), nil
|
||||
}
|
||||
|
||||
return nil, ErrLogConfigurationIsInvalid
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
type KubernetesLogCollector struct {
|
||||
containerID string // This member identifies the container by identifier or name
|
||||
pod string // This member identifies the pod where the container is deployed
|
||||
}
|
||||
|
||||
func NewKubernetesLogCollector(containerID, pod string) *KubernetesLogCollector {
|
||||
return &KubernetesLogCollector{
|
||||
containerID,
|
||||
pod,
|
||||
}
|
||||
}
|
||||
|
||||
func (collector *KubernetesLogCollector) Collect(ctx context.Context) (*LogInformation, error) {
|
||||
tmp := os.TempDir()
|
||||
outputHandle, err := os.Create(filepath.Join(tmp, logFilename))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening output file: %w", err)
|
||||
}
|
||||
|
||||
defer outputHandle.Close()
|
||||
|
||||
var command *exec.Cmd
|
||||
// Calculate 2 weeks ago
|
||||
since := time.Now().Add(twoWeeksOffset).Format(time.RFC3339)
|
||||
if collector.containerID != "" {
|
||||
command = exec.CommandContext(
|
||||
ctx,
|
||||
"kubectl",
|
||||
"logs",
|
||||
collector.pod,
|
||||
"--since-time",
|
||||
since,
|
||||
"--tail",
|
||||
tailMaxNumberOfLines,
|
||||
"-c",
|
||||
collector.containerID,
|
||||
)
|
||||
} else {
|
||||
command = exec.CommandContext(
|
||||
ctx,
|
||||
"kubectl",
|
||||
"logs",
|
||||
collector.pod,
|
||||
"--since-time",
|
||||
since,
|
||||
"--tail",
|
||||
tailMaxNumberOfLines,
|
||||
)
|
||||
}
|
||||
|
||||
return PipeCommandOutputToFile(command, outputHandle)
|
||||
}
|
|
@ -0,0 +1,109 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func PipeCommandOutputToFile(command *exec.Cmd, outputHandle *os.File) (*LogInformation, error) {
|
||||
stdoutReader, err := command.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error retrieving stdout from command '%s': %w",
|
||||
command.String(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
stderrReader, err := command.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error retrieving stderr from command '%s': %w",
|
||||
command.String(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
if err := command.Start(); err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error running command '%s': %w",
|
||||
command.String(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
_, err = io.Copy(outputHandle, stdoutReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error copying stdout from %s to file %s: %w",
|
||||
command.String(),
|
||||
outputHandle.Name(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
_, err = io.Copy(outputHandle, stderrReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error copying stderr from %s to file %s: %w",
|
||||
command.String(),
|
||||
outputHandle.Name(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
if err := command.Wait(); err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error waiting from command '%s': %w",
|
||||
command.String(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return NewLogInformation(outputHandle.Name(), true, false), nil
|
||||
}
|
||||
|
||||
func CopyFilesFromDirectory(path string) (string, error) {
|
||||
// rolling logs have as suffix the current date thus
|
||||
// when iterating the path files they are already in
|
||||
// chronological order
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error reading directory %s: %w", path, err)
|
||||
}
|
||||
|
||||
outputHandle, err := os.Create(filepath.Join(os.TempDir(), logFilename))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating file %s: %w", outputHandle.Name(), err)
|
||||
}
|
||||
defer outputHandle.Close()
|
||||
|
||||
for _, file := range files {
|
||||
logHandle, err := os.Open(filepath.Join(path, file.Name()))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error opening file %s:%w", file.Name(), err)
|
||||
}
|
||||
defer logHandle.Close()
|
||||
|
||||
_, err = io.Copy(outputHandle, logHandle)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error copying file %s:%w", logHandle.Name(), err)
|
||||
}
|
||||
}
|
||||
|
||||
logHandle, err := os.Open(filepath.Join(path, "cloudflared.log"))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error opening file %s:%w", logHandle.Name(), err)
|
||||
}
|
||||
defer logHandle.Close()
|
||||
|
||||
_, err = io.Copy(outputHandle, logHandle)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error copying file %s:%w", logHandle.Name(), err)
|
||||
}
|
||||
|
||||
return outputHandle.Name(), nil
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
const MicrosecondsFactor = 1000.0
|
||||
|
||||
var ErrEmptyDomain = errors.New("domain must not be empty")
|
||||
|
||||
// For now only support ICMP is provided.
|
||||
type IPVersion int
|
||||
|
||||
const (
|
||||
V4 IPVersion = iota
|
||||
V6 IPVersion = iota
|
||||
)
|
||||
|
||||
type Hop struct {
|
||||
Hop uint8 `json:"hop,omitempty"` // hop number along the route
|
||||
Domain string `json:"domain,omitempty"` // domain and/or ip of the hop, this field will be '*' if the hop is a timeout
|
||||
Rtts []time.Duration `json:"rtts,omitempty"` // RTT measurements in microseconds
|
||||
}
|
||||
|
||||
type TraceOptions struct {
|
||||
ttl uint64 // number of hops to perform
|
||||
timeout time.Duration // wait timeout for each response
|
||||
address string // address to trace
|
||||
useV4 bool
|
||||
}
|
||||
|
||||
func NewTimeoutHop(
|
||||
hop uint8,
|
||||
) *Hop {
|
||||
// Whenever there is a hop in the format of 'N * * *'
|
||||
// it means that the hop in the path didn't answer to
|
||||
// any probe.
|
||||
return NewHop(
|
||||
hop,
|
||||
"*",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
func NewHop(hop uint8, domain string, rtts []time.Duration) *Hop {
|
||||
return &Hop{
|
||||
hop,
|
||||
domain,
|
||||
rtts,
|
||||
}
|
||||
}
|
||||
|
||||
func NewTraceOptions(
|
||||
ttl uint64,
|
||||
timeout time.Duration,
|
||||
address string,
|
||||
useV4 bool,
|
||||
) TraceOptions {
|
||||
return TraceOptions{
|
||||
ttl,
|
||||
timeout,
|
||||
address,
|
||||
useV4,
|
||||
}
|
||||
}
|
||||
|
||||
type NetworkCollector interface {
|
||||
// Performs a trace route operation with the specified options.
|
||||
// In case the trace fails, it will return a non-nil error and
|
||||
// it may return a string which represents the raw information
|
||||
// obtained.
|
||||
// In case it is successful it will only return an array of Hops
|
||||
// an empty string and a nil error.
|
||||
Collect(ctx context.Context, options TraceOptions) ([]*Hop, string, error)
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
//go:build darwin || linux
|
||||
|
||||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type NetworkCollectorImpl struct{}
|
||||
|
||||
func (tracer *NetworkCollectorImpl) Collect(ctx context.Context, options TraceOptions) ([]*Hop, string, error) {
|
||||
args := []string{
|
||||
"-I",
|
||||
"-w",
|
||||
strconv.FormatInt(int64(options.timeout.Seconds()), 10),
|
||||
"-m",
|
||||
strconv.FormatUint(options.ttl, 10),
|
||||
options.address,
|
||||
}
|
||||
|
||||
var command string
|
||||
|
||||
switch options.useV4 {
|
||||
case false:
|
||||
command = "traceroute6"
|
||||
default:
|
||||
command = "traceroute"
|
||||
}
|
||||
|
||||
process := exec.CommandContext(ctx, command, args...)
|
||||
|
||||
return decodeNetworkOutputToFile(process, DecodeLine)
|
||||
}
|
||||
|
||||
func DecodeLine(text string) (*Hop, error) {
|
||||
fields := strings.Fields(text)
|
||||
parts := []string{}
|
||||
filter := func(s string) bool { return s != "*" && s != "ms" }
|
||||
|
||||
for _, field := range fields {
|
||||
if filter(field) {
|
||||
parts = append(parts, field)
|
||||
}
|
||||
}
|
||||
|
||||
index, err := strconv.ParseUint(parts[0], 10, 8)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't parse index from timeout hop: %w", err)
|
||||
}
|
||||
|
||||
if len(parts) == 1 {
|
||||
return NewTimeoutHop(uint8(index)), nil
|
||||
}
|
||||
|
||||
domain := ""
|
||||
rtts := []time.Duration{}
|
||||
|
||||
for _, part := range parts[1:] {
|
||||
rtt, err := strconv.ParseFloat(part, 64)
|
||||
if err != nil {
|
||||
domain += part + " "
|
||||
} else {
|
||||
rtts = append(rtts, time.Duration(rtt*MicrosecondsFactor))
|
||||
}
|
||||
}
|
||||
|
||||
domain, _ = strings.CutSuffix(domain, " ")
|
||||
if domain == "" {
|
||||
return nil, ErrEmptyDomain
|
||||
}
|
||||
|
||||
return NewHop(uint8(index), domain, rtts), nil
|
||||
}
|
|
@ -0,0 +1,173 @@
|
|||
//go:build darwin || linux
|
||||
|
||||
package diagnostic_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
diagnostic "github.com/cloudflare/cloudflared/diagnostic/network"
|
||||
)
|
||||
|
||||
func TestDecode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
expectedHops []*diagnostic.Hop
|
||||
}{
|
||||
{
|
||||
"repeated hop index parse failure",
|
||||
`1 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
|
||||
2 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
|
||||
someletters * * *
|
||||
4 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms `,
|
||||
[]*diagnostic.Hop{
|
||||
diagnostic.NewHop(
|
||||
uint8(1),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
diagnostic.NewHop(
|
||||
uint8(2),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
diagnostic.NewHop(
|
||||
uint8(4),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
},
|
||||
},
|
||||
{
|
||||
"hop index parse failure",
|
||||
`1 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
|
||||
2 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
|
||||
someletters 8.8.8.8 8.8.8.9 abc ms 0.456 ms 0.789 ms`,
|
||||
[]*diagnostic.Hop{
|
||||
diagnostic.NewHop(
|
||||
uint8(1),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
diagnostic.NewHop(
|
||||
uint8(2),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
},
|
||||
},
|
||||
{
|
||||
"missing rtt",
|
||||
`1 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
|
||||
2 * 8.8.8.8 8.8.8.9 0.456 ms 0.789 ms`,
|
||||
[]*diagnostic.Hop{
|
||||
diagnostic.NewHop(
|
||||
uint8(1),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
diagnostic.NewHop(
|
||||
uint8(2),
|
||||
"8.8.8.8 8.8.8.9",
|
||||
[]time.Duration{
|
||||
time.Duration(456),
|
||||
time.Duration(789),
|
||||
},
|
||||
),
|
||||
},
|
||||
},
|
||||
{
|
||||
"simple example ipv4",
|
||||
`1 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
|
||||
2 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
|
||||
3 * * *`,
|
||||
[]*diagnostic.Hop{
|
||||
diagnostic.NewHop(
|
||||
uint8(1),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
diagnostic.NewHop(
|
||||
uint8(2),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
diagnostic.NewTimeoutHop(uint8(3)),
|
||||
},
|
||||
},
|
||||
{
|
||||
"simple example ipv6",
|
||||
` 1 2400:cb00:107:1024::ac44:6550 12.780 ms 9.118 ms 10.046 ms
|
||||
2 2a09:bac1:: 9.945 ms 10.033 ms 11.562 ms`,
|
||||
[]*diagnostic.Hop{
|
||||
diagnostic.NewHop(
|
||||
uint8(1),
|
||||
"2400:cb00:107:1024::ac44:6550",
|
||||
[]time.Duration{
|
||||
time.Duration(12780),
|
||||
time.Duration(9118),
|
||||
time.Duration(10046),
|
||||
},
|
||||
),
|
||||
diagnostic.NewHop(
|
||||
uint8(2),
|
||||
"2a09:bac1::",
|
||||
[]time.Duration{
|
||||
time.Duration(9945),
|
||||
time.Duration(10033),
|
||||
time.Duration(11562),
|
||||
},
|
||||
),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
hops, err := diagnostic.Decode(strings.NewReader(test.text), diagnostic.DecodeLine)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expectedHops, hops)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
type DecodeLineFunc func(text string) (*Hop, error)
|
||||
|
||||
func decodeNetworkOutputToFile(command *exec.Cmd, decodeLine DecodeLineFunc) ([]*Hop, string, error) {
|
||||
stdout, err := command.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error piping traceroute's output: %w", err)
|
||||
}
|
||||
|
||||
if err := command.Start(); err != nil {
|
||||
return nil, "", fmt.Errorf("error starting traceroute: %w", err)
|
||||
}
|
||||
|
||||
// Tee the output to a string to have the raw information
|
||||
// in case the decode call fails
|
||||
// This error is handled only after the Wait call below returns
|
||||
// otherwise the process can become a zombie
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
tee := io.TeeReader(stdout, buf)
|
||||
hops, err := Decode(tee, decodeLine)
|
||||
// regardless of success of the decoding
|
||||
// consume all output to have available in buf
|
||||
_, _ = io.ReadAll(tee)
|
||||
|
||||
if werr := command.Wait(); werr != nil {
|
||||
return nil, "", fmt.Errorf("error finishing traceroute: %w", werr)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, buf.String(), err
|
||||
}
|
||||
|
||||
return hops, buf.String(), nil
|
||||
}
|
||||
|
||||
func Decode(reader io.Reader, decodeLine DecodeLineFunc) ([]*Hop, error) {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
var hops []*Hop
|
||||
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
hop, err := decodeLine(text)
|
||||
if err != nil {
|
||||
// This continue is here on the error case because there are lines at the start and end
|
||||
// that may not be parsable. (check windows tracert output)
|
||||
// The skip is here because aside from the start and end lines the other lines should
|
||||
// always be parsable without errors.
|
||||
continue
|
||||
}
|
||||
|
||||
hops = append(hops, hop)
|
||||
}
|
||||
|
||||
if scanner.Err() != nil {
|
||||
return nil, fmt.Errorf("scanner reported an error: %w", scanner.Err())
|
||||
}
|
||||
|
||||
return hops, nil
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
//go:build windows
|
||||
|
||||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type NetworkCollectorImpl struct{}
|
||||
|
||||
func (tracer *NetworkCollectorImpl) Collect(ctx context.Context, options TraceOptions) ([]*Hop, string, error) {
|
||||
ipversion := "-4"
|
||||
if !options.useV4 {
|
||||
ipversion = "-6"
|
||||
}
|
||||
|
||||
args := []string{
|
||||
ipversion,
|
||||
"-w",
|
||||
strconv.FormatInt(int64(options.timeout.Seconds()), 10),
|
||||
"-h",
|
||||
strconv.FormatUint(options.ttl, 10),
|
||||
// Do not resolve host names (can add 30+ seconds to run time)
|
||||
"-d",
|
||||
options.address,
|
||||
}
|
||||
command := exec.CommandContext(ctx, "tracert.exe", args...)
|
||||
|
||||
return decodeNetworkOutputToFile(command, DecodeLine)
|
||||
}
|
||||
|
||||
func DecodeLine(text string) (*Hop, error) {
|
||||
const requestTimedOut = "Request timed out."
|
||||
|
||||
fields := strings.Fields(text)
|
||||
parts := []string{}
|
||||
filter := func(s string) bool { return s != "*" && s != "ms" }
|
||||
|
||||
for _, field := range fields {
|
||||
if filter(field) {
|
||||
parts = append(parts, field)
|
||||
}
|
||||
}
|
||||
|
||||
index, err := strconv.ParseUint(parts[0], 10, 8)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("couldn't parse index from timeout hop: %w", err)
|
||||
}
|
||||
|
||||
domain := ""
|
||||
rtts := []time.Duration{}
|
||||
|
||||
for _, part := range parts[1:] {
|
||||
|
||||
rtt, err := strconv.ParseFloat(strings.TrimLeft(part, "<"), 64)
|
||||
|
||||
if err != nil {
|
||||
domain += part + " "
|
||||
} else {
|
||||
rtts = append(rtts, time.Duration(rtt*MicrosecondsFactor))
|
||||
}
|
||||
}
|
||||
|
||||
domain, _ = strings.CutSuffix(domain, " ")
|
||||
// If the domain is equal to "Request timed out." then we build a
|
||||
// timeout hop.
|
||||
if domain == requestTimedOut {
|
||||
return NewTimeoutHop(uint8(index)), nil
|
||||
}
|
||||
|
||||
if domain == "" {
|
||||
return nil, ErrEmptyDomain
|
||||
}
|
||||
|
||||
return NewHop(uint8(index), domain, rtts), nil
|
||||
}
|
|
@ -0,0 +1,210 @@
|
|||
//go:build windows
|
||||
|
||||
package diagnostic_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
diagnostic "github.com/cloudflare/cloudflared/diagnostic/network"
|
||||
)
|
||||
|
||||
func TestDecode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
expectedHops []*diagnostic.Hop
|
||||
}{
|
||||
|
||||
{
|
||||
"tracert output",
|
||||
`
|
||||
Tracing route to region2.v2.argotunnel.com [198.41.200.73]
|
||||
over a maximum of 5 hops:
|
||||
|
||||
1 10 ms <1 ms 1 ms 192.168.64.1
|
||||
2 27 ms 14 ms 5 ms 192.168.1.254
|
||||
3 * * * Request timed out.
|
||||
4 * * * Request timed out.
|
||||
5 27 ms 5 ms 5 ms 195.8.30.245
|
||||
|
||||
Trace complete.
|
||||
`,
|
||||
[]*diagnostic.Hop{
|
||||
diagnostic.NewHop(
|
||||
uint8(1),
|
||||
"192.168.64.1",
|
||||
[]time.Duration{
|
||||
time.Duration(10000),
|
||||
time.Duration(1000),
|
||||
time.Duration(1000),
|
||||
},
|
||||
),
|
||||
diagnostic.NewHop(
|
||||
uint8(2),
|
||||
"192.168.1.254",
|
||||
[]time.Duration{
|
||||
time.Duration(27000),
|
||||
time.Duration(14000),
|
||||
time.Duration(5000),
|
||||
},
|
||||
),
|
||||
diagnostic.NewTimeoutHop(uint8(3)),
|
||||
diagnostic.NewTimeoutHop(uint8(4)),
|
||||
diagnostic.NewHop(
|
||||
uint8(5),
|
||||
"195.8.30.245",
|
||||
[]time.Duration{
|
||||
time.Duration(27000),
|
||||
time.Duration(5000),
|
||||
time.Duration(5000),
|
||||
},
|
||||
),
|
||||
},
|
||||
},
|
||||
{
|
||||
"repeated hop index parse failure",
|
||||
`1 12.874 ms 15.517 ms 15.311 ms 172.68.101.121 (172.68.101.121)
|
||||
2 12.874 ms 15.517 ms 15.311 ms 172.68.101.121 (172.68.101.121)
|
||||
someletters * * *`,
|
||||
[]*diagnostic.Hop{
|
||||
diagnostic.NewHop(
|
||||
uint8(1),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
diagnostic.NewHop(
|
||||
uint8(2),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
},
|
||||
},
|
||||
{
|
||||
"hop index parse failure",
|
||||
`1 12.874 ms 15.517 ms 15.311 ms 172.68.101.121 (172.68.101.121)
|
||||
2 12.874 ms 15.517 ms 15.311 ms 172.68.101.121 (172.68.101.121)
|
||||
someletters abc ms 0.456 ms 0.789 ms 8.8.8.8 8.8.8.9`,
|
||||
[]*diagnostic.Hop{
|
||||
diagnostic.NewHop(
|
||||
uint8(1),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
diagnostic.NewHop(
|
||||
uint8(2),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
},
|
||||
},
|
||||
{
|
||||
"missing rtt",
|
||||
`1 <12.874 ms <15.517 ms <15.311 ms 172.68.101.121 (172.68.101.121)
|
||||
2 * 0.456 ms 0.789 ms 8.8.8.8 8.8.8.9`,
|
||||
[]*diagnostic.Hop{
|
||||
diagnostic.NewHop(
|
||||
uint8(1),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
diagnostic.NewHop(
|
||||
uint8(2),
|
||||
"8.8.8.8 8.8.8.9",
|
||||
[]time.Duration{
|
||||
time.Duration(456),
|
||||
time.Duration(789),
|
||||
},
|
||||
),
|
||||
},
|
||||
},
|
||||
{
|
||||
"simple example ipv4",
|
||||
`1 12.874 ms 15.517 ms 15.311 ms 172.68.101.121 (172.68.101.121)
|
||||
2 12.874 ms 15.517 ms 15.311 ms 172.68.101.121 (172.68.101.121)
|
||||
3 * * * Request timed out.`,
|
||||
[]*diagnostic.Hop{
|
||||
diagnostic.NewHop(
|
||||
uint8(1),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
diagnostic.NewHop(
|
||||
uint8(2),
|
||||
"172.68.101.121 (172.68.101.121)",
|
||||
[]time.Duration{
|
||||
time.Duration(12874),
|
||||
time.Duration(15517),
|
||||
time.Duration(15311),
|
||||
},
|
||||
),
|
||||
diagnostic.NewTimeoutHop(uint8(3)),
|
||||
},
|
||||
},
|
||||
{
|
||||
"simple example ipv6",
|
||||
` 1 12.780 ms 9.118 ms 10.046 ms 2400:cb00:107:1024::ac44:6550
|
||||
2 9.945 ms 10.033 ms 11.562 ms 2a09:bac1::`,
|
||||
[]*diagnostic.Hop{
|
||||
diagnostic.NewHop(
|
||||
uint8(1),
|
||||
"2400:cb00:107:1024::ac44:6550",
|
||||
[]time.Duration{
|
||||
time.Duration(12780),
|
||||
time.Duration(9118),
|
||||
time.Duration(10046),
|
||||
},
|
||||
),
|
||||
diagnostic.NewHop(
|
||||
uint8(2),
|
||||
"2a09:bac1::",
|
||||
[]time.Duration{
|
||||
time.Duration(9945),
|
||||
time.Duration(10033),
|
||||
time.Duration(11562),
|
||||
},
|
||||
),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
hops, err := diagnostic.Decode(strings.NewReader(test.text), diagnostic.DecodeLine)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expectedHops, hops)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SystemInformationError struct {
|
||||
Err error `json:"error"`
|
||||
RawInfo string `json:"rawInfo"`
|
||||
}
|
||||
|
||||
func (err SystemInformationError) Error() string {
|
||||
return err.Err.Error()
|
||||
}
|
||||
|
||||
func (err SystemInformationError) MarshalJSON() ([]byte, error) {
|
||||
s := map[string]string{
|
||||
"error": err.Err.Error(),
|
||||
"rawInfo": err.RawInfo,
|
||||
}
|
||||
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
type SystemInformationGeneralError struct {
|
||||
OperatingSystemInformationError error
|
||||
MemoryInformationError error
|
||||
FileDescriptorsInformationError error
|
||||
DiskVolumeInformationError error
|
||||
}
|
||||
|
||||
func (err SystemInformationGeneralError) Error() string {
|
||||
builder := &strings.Builder{}
|
||||
builder.WriteString("errors found:")
|
||||
|
||||
if err.OperatingSystemInformationError != nil {
|
||||
builder.WriteString(err.OperatingSystemInformationError.Error() + ", ")
|
||||
}
|
||||
|
||||
if err.MemoryInformationError != nil {
|
||||
builder.WriteString(err.MemoryInformationError.Error() + ", ")
|
||||
}
|
||||
|
||||
if err.FileDescriptorsInformationError != nil {
|
||||
builder.WriteString(err.FileDescriptorsInformationError.Error() + ", ")
|
||||
}
|
||||
|
||||
if err.DiskVolumeInformationError != nil {
|
||||
builder.WriteString(err.DiskVolumeInformationError.Error() + ", ")
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (err SystemInformationGeneralError) MarshalJSON() ([]byte, error) {
|
||||
data := map[string]SystemInformationError{}
|
||||
|
||||
var sysErr SystemInformationError
|
||||
if errors.As(err.OperatingSystemInformationError, &sysErr) {
|
||||
data["operatingSystemInformationError"] = sysErr
|
||||
}
|
||||
|
||||
if errors.As(err.MemoryInformationError, &sysErr) {
|
||||
data["memoryInformationError"] = sysErr
|
||||
}
|
||||
|
||||
if errors.As(err.FileDescriptorsInformationError, &sysErr) {
|
||||
data["fileDescriptorsInformationError"] = sysErr
|
||||
}
|
||||
|
||||
if errors.As(err.DiskVolumeInformationError, &sysErr) {
|
||||
data["diskVolumeInformationError"] = sysErr
|
||||
}
|
||||
|
||||
return json.Marshal(data)
|
||||
}
|
||||
|
||||
type DiskVolumeInformation struct {
|
||||
Name string `json:"name"` // represents the filesystem in linux/macos or device name in windows
|
||||
SizeMaximum uint64 `json:"sizeMaximum"` // represents the maximum size of the disk in kilobytes
|
||||
SizeCurrent uint64 `json:"sizeCurrent"` // represents the current size of the disk in kilobytes
|
||||
}
|
||||
|
||||
func NewDiskVolumeInformation(name string, maximum, current uint64) *DiskVolumeInformation {
|
||||
return &DiskVolumeInformation{
|
||||
name,
|
||||
maximum,
|
||||
current,
|
||||
}
|
||||
}
|
||||
|
||||
type SystemInformation struct {
|
||||
MemoryMaximum uint64 `json:"memoryMaximum,omitempty"` // represents the maximum memory of the system in kilobytes
|
||||
MemoryCurrent uint64 `json:"memoryCurrent,omitempty"` // represents the system's memory in use in kilobytes
|
||||
FileDescriptorMaximum uint64 `json:"fileDescriptorMaximum,omitempty"` // represents the maximum number of file descriptors of the system
|
||||
FileDescriptorCurrent uint64 `json:"fileDescriptorCurrent,omitempty"` // represents the system's file descriptors in use
|
||||
OsSystem string `json:"osSystem,omitempty"` // represents the operating system name i.e.: linux, windows, darwin
|
||||
HostName string `json:"hostName,omitempty"` // represents the system host name
|
||||
OsVersion string `json:"osVersion,omitempty"` // detailed information about the system's release version level
|
||||
OsRelease string `json:"osRelease,omitempty"` // detailed information about the system's release
|
||||
Architecture string `json:"architecture,omitempty"` // represents the system's hardware platform i.e: arm64/amd64
|
||||
CloudflaredVersion string `json:"cloudflaredVersion,omitempty"` // the runtime version of cloudflared
|
||||
GoVersion string `json:"goVersion,omitempty"`
|
||||
GoArch string `json:"goArch,omitempty"`
|
||||
Disk []*DiskVolumeInformation `json:"disk,omitempty"`
|
||||
}
|
||||
|
||||
func NewSystemInformation(
|
||||
memoryMaximum,
|
||||
memoryCurrent,
|
||||
filesMaximum,
|
||||
filesCurrent uint64,
|
||||
osystem,
|
||||
name,
|
||||
osVersion,
|
||||
osRelease,
|
||||
architecture,
|
||||
cloudflaredVersion,
|
||||
goVersion,
|
||||
goArchitecture string,
|
||||
disk []*DiskVolumeInformation,
|
||||
) *SystemInformation {
|
||||
return &SystemInformation{
|
||||
memoryMaximum,
|
||||
memoryCurrent,
|
||||
filesMaximum,
|
||||
filesCurrent,
|
||||
osystem,
|
||||
name,
|
||||
osVersion,
|
||||
osRelease,
|
||||
architecture,
|
||||
cloudflaredVersion,
|
||||
goVersion,
|
||||
goArchitecture,
|
||||
disk,
|
||||
}
|
||||
}
|
||||
|
||||
type SystemCollector interface {
|
||||
// If the collection is successful it will return `SystemInformation` struct,
|
||||
// and a nil error.
|
||||
//
|
||||
// This function expects that the caller sets the context timeout to prevent
|
||||
// long-lived collectors.
|
||||
Collect(ctx context.Context) (*SystemInformation, error)
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
//go:build linux
|
||||
|
||||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SystemCollectorImpl struct {
|
||||
version string
|
||||
}
|
||||
|
||||
func NewSystemCollectorImpl(
|
||||
version string,
|
||||
) *SystemCollectorImpl {
|
||||
return &SystemCollectorImpl{
|
||||
version,
|
||||
}
|
||||
}
|
||||
|
||||
func (collector *SystemCollectorImpl) Collect(ctx context.Context) (*SystemInformation, error) {
|
||||
memoryInfo, memoryInfoRaw, memoryInfoErr := collectMemoryInformation(ctx)
|
||||
fdInfo, fdInfoRaw, fdInfoErr := collectFileDescriptorInformation(ctx)
|
||||
disks, disksRaw, diskErr := collectDiskVolumeInformationUnix(ctx)
|
||||
osInfo, osInfoRaw, osInfoErr := collectOSInformationUnix(ctx)
|
||||
|
||||
var memoryMaximum, memoryCurrent, fileDescriptorMaximum, fileDescriptorCurrent uint64
|
||||
var osSystem, name, osVersion, osRelease, architecture string
|
||||
gerror := SystemInformationGeneralError{}
|
||||
|
||||
if memoryInfoErr != nil {
|
||||
gerror.MemoryInformationError = SystemInformationError{
|
||||
Err: memoryInfoErr,
|
||||
RawInfo: memoryInfoRaw,
|
||||
}
|
||||
} else {
|
||||
memoryMaximum = memoryInfo.MemoryMaximum
|
||||
memoryCurrent = memoryInfo.MemoryCurrent
|
||||
}
|
||||
|
||||
if fdInfoErr != nil {
|
||||
gerror.FileDescriptorsInformationError = SystemInformationError{
|
||||
Err: fdInfoErr,
|
||||
RawInfo: fdInfoRaw,
|
||||
}
|
||||
} else {
|
||||
fileDescriptorMaximum = fdInfo.FileDescriptorMaximum
|
||||
fileDescriptorCurrent = fdInfo.FileDescriptorCurrent
|
||||
}
|
||||
|
||||
if diskErr != nil {
|
||||
gerror.DiskVolumeInformationError = SystemInformationError{
|
||||
Err: diskErr,
|
||||
RawInfo: disksRaw,
|
||||
}
|
||||
}
|
||||
|
||||
if osInfoErr != nil {
|
||||
gerror.OperatingSystemInformationError = SystemInformationError{
|
||||
Err: osInfoErr,
|
||||
RawInfo: osInfoRaw,
|
||||
}
|
||||
} else {
|
||||
osSystem = osInfo.OsSystem
|
||||
name = osInfo.Name
|
||||
osVersion = osInfo.OsVersion
|
||||
osRelease = osInfo.OsRelease
|
||||
architecture = osInfo.Architecture
|
||||
}
|
||||
|
||||
cloudflaredVersion := collector.version
|
||||
info := NewSystemInformation(
|
||||
memoryMaximum,
|
||||
memoryCurrent,
|
||||
fileDescriptorMaximum,
|
||||
fileDescriptorCurrent,
|
||||
osSystem,
|
||||
name,
|
||||
osVersion,
|
||||
osRelease,
|
||||
architecture,
|
||||
cloudflaredVersion,
|
||||
runtime.Version(),
|
||||
runtime.GOARCH,
|
||||
disks,
|
||||
)
|
||||
|
||||
return info, gerror
|
||||
}
|
||||
|
||||
func collectMemoryInformation(ctx context.Context) (*MemoryInformation, string, error) {
|
||||
// This function relies on the output of `cat /proc/meminfo` to retrieve
|
||||
// memoryMax and memoryCurrent.
|
||||
// The expected output is in the format of `KEY VALUE UNIT`.
|
||||
const (
|
||||
memTotalPrefix = "MemTotal"
|
||||
memAvailablePrefix = "MemAvailable"
|
||||
)
|
||||
|
||||
command := exec.CommandContext(ctx, "cat", "/proc/meminfo")
|
||||
|
||||
stdout, err := command.Output()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
|
||||
}
|
||||
|
||||
output := string(stdout)
|
||||
|
||||
mapper := func(field string) (uint64, error) {
|
||||
field = strings.TrimRight(field, " kB")
|
||||
|
||||
return strconv.ParseUint(field, 10, 64)
|
||||
}
|
||||
|
||||
memoryInfo, err := ParseMemoryInformationFromKV(output, memTotalPrefix, memAvailablePrefix, mapper)
|
||||
if err != nil {
|
||||
return nil, output, err
|
||||
}
|
||||
|
||||
// returning raw output in case other collected information
|
||||
// resulted in errors
|
||||
return memoryInfo, output, nil
|
||||
}
|
||||
|
||||
func collectFileDescriptorInformation(ctx context.Context) (*FileDescriptorInformation, string, error) {
|
||||
// Command retrieved from https://docs.kernel.org/admin-guide/sysctl/fs.html#file-max-file-nr.
|
||||
// If the sysctl is not available the command with fail.
|
||||
command := exec.CommandContext(ctx, "sysctl", "-n", "fs.file-nr")
|
||||
|
||||
stdout, err := command.Output()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
|
||||
}
|
||||
|
||||
output := string(stdout)
|
||||
|
||||
fileDescriptorInfo, err := ParseSysctlFileDescriptorInformation(output)
|
||||
if err != nil {
|
||||
return nil, output, err
|
||||
}
|
||||
|
||||
// returning raw output in case other collected information
|
||||
// resulted in errors
|
||||
return fileDescriptorInfo, output, nil
|
||||
}
|
|
@ -0,0 +1,172 @@
|
|||
//go:build darwin
|
||||
|
||||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type SystemCollectorImpl struct {
|
||||
version string
|
||||
}
|
||||
|
||||
func NewSystemCollectorImpl(
|
||||
version string,
|
||||
) *SystemCollectorImpl {
|
||||
return &SystemCollectorImpl{
|
||||
version,
|
||||
}
|
||||
}
|
||||
|
||||
func (collector *SystemCollectorImpl) Collect(ctx context.Context) (*SystemInformation, error) {
|
||||
memoryInfo, memoryInfoRaw, memoryInfoErr := collectMemoryInformation(ctx)
|
||||
fdInfo, fdInfoRaw, fdInfoErr := collectFileDescriptorInformation(ctx)
|
||||
disks, disksRaw, diskErr := collectDiskVolumeInformationUnix(ctx)
|
||||
osInfo, osInfoRaw, osInfoErr := collectOSInformationUnix(ctx)
|
||||
|
||||
var memoryMaximum, memoryCurrent, fileDescriptorMaximum, fileDescriptorCurrent uint64
|
||||
var osSystem, name, osVersion, osRelease, architecture string
|
||||
|
||||
err := SystemInformationGeneralError{
|
||||
OperatingSystemInformationError: nil,
|
||||
MemoryInformationError: nil,
|
||||
FileDescriptorsInformationError: nil,
|
||||
DiskVolumeInformationError: nil,
|
||||
}
|
||||
|
||||
if memoryInfoErr != nil {
|
||||
err.MemoryInformationError = SystemInformationError{
|
||||
Err: memoryInfoErr,
|
||||
RawInfo: memoryInfoRaw,
|
||||
}
|
||||
} else {
|
||||
memoryMaximum = memoryInfo.MemoryMaximum
|
||||
memoryCurrent = memoryInfo.MemoryCurrent
|
||||
}
|
||||
|
||||
if fdInfoErr != nil {
|
||||
err.FileDescriptorsInformationError = SystemInformationError{
|
||||
Err: fdInfoErr,
|
||||
RawInfo: fdInfoRaw,
|
||||
}
|
||||
} else {
|
||||
fileDescriptorMaximum = fdInfo.FileDescriptorMaximum
|
||||
fileDescriptorCurrent = fdInfo.FileDescriptorCurrent
|
||||
}
|
||||
|
||||
if diskErr != nil {
|
||||
err.DiskVolumeInformationError = SystemInformationError{
|
||||
Err: diskErr,
|
||||
RawInfo: disksRaw,
|
||||
}
|
||||
}
|
||||
|
||||
if osInfoErr != nil {
|
||||
err.OperatingSystemInformationError = SystemInformationError{
|
||||
Err: osInfoErr,
|
||||
RawInfo: osInfoRaw,
|
||||
}
|
||||
} else {
|
||||
osSystem = osInfo.OsSystem
|
||||
name = osInfo.Name
|
||||
osVersion = osInfo.OsVersion
|
||||
osRelease = osInfo.OsRelease
|
||||
architecture = osInfo.Architecture
|
||||
}
|
||||
|
||||
cloudflaredVersion := collector.version
|
||||
info := NewSystemInformation(
|
||||
memoryMaximum,
|
||||
memoryCurrent,
|
||||
fileDescriptorMaximum,
|
||||
fileDescriptorCurrent,
|
||||
osSystem,
|
||||
name,
|
||||
osVersion,
|
||||
osRelease,
|
||||
architecture,
|
||||
cloudflaredVersion,
|
||||
runtime.Version(),
|
||||
runtime.GOARCH,
|
||||
disks,
|
||||
)
|
||||
|
||||
return info, err
|
||||
}
|
||||
|
||||
func collectFileDescriptorInformation(ctx context.Context) (
|
||||
*FileDescriptorInformation,
|
||||
string,
|
||||
error,
|
||||
) {
|
||||
const (
|
||||
fileDescriptorMaximumKey = "kern.maxfiles"
|
||||
fileDescriptorCurrentKey = "kern.num_files"
|
||||
)
|
||||
|
||||
command := exec.CommandContext(ctx, "sysctl", fileDescriptorMaximumKey, fileDescriptorCurrentKey)
|
||||
|
||||
stdout, err := command.Output()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
|
||||
}
|
||||
|
||||
output := string(stdout)
|
||||
|
||||
fileDescriptorInfo, err := ParseFileDescriptorInformationFromKV(
|
||||
output,
|
||||
fileDescriptorMaximumKey,
|
||||
fileDescriptorCurrentKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, output, err
|
||||
}
|
||||
|
||||
// returning raw output in case other collected information
|
||||
// resulted in errors
|
||||
return fileDescriptorInfo, output, nil
|
||||
}
|
||||
|
||||
func collectMemoryInformation(ctx context.Context) (
|
||||
*MemoryInformation,
|
||||
string,
|
||||
error,
|
||||
) {
|
||||
const (
|
||||
memoryMaximumKey = "hw.memsize"
|
||||
memoryAvailableKey = "hw.memsize_usable"
|
||||
)
|
||||
|
||||
command := exec.CommandContext(
|
||||
ctx,
|
||||
"sysctl",
|
||||
memoryMaximumKey,
|
||||
memoryAvailableKey,
|
||||
)
|
||||
|
||||
stdout, err := command.Output()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
|
||||
}
|
||||
|
||||
output := string(stdout)
|
||||
|
||||
mapper := func(field string) (uint64, error) {
|
||||
const kiloBytes = 1024
|
||||
value, err := strconv.ParseUint(field, 10, 64)
|
||||
return value / kiloBytes, err
|
||||
}
|
||||
|
||||
memoryInfo, err := ParseMemoryInformationFromKV(output, memoryMaximumKey, memoryAvailableKey, mapper)
|
||||
if err != nil {
|
||||
return nil, output, err
|
||||
}
|
||||
|
||||
// returning raw output in case other collected information
|
||||
// resulted in errors
|
||||
return memoryInfo, output, nil
|
||||
}
|
|
@ -0,0 +1,466 @@
|
|||
package diagnostic_test
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cloudflare/cloudflared/diagnostic"
|
||||
)
|
||||
|
||||
func TestParseMemoryInformationFromKV(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mapper := func(field string) (uint64, error) {
|
||||
value, err := strconv.ParseUint(field, 10, 64)
|
||||
return value, err
|
||||
}
|
||||
|
||||
linuxMapper := func(field string) (uint64, error) {
|
||||
field = strings.TrimRight(field, " kB")
|
||||
return strconv.ParseUint(field, 10, 64)
|
||||
}
|
||||
|
||||
windowsMemoryOutput := `
|
||||
|
||||
FreeVirtualMemory : 5350472
|
||||
TotalVirtualMemorySize : 8903424
|
||||
|
||||
|
||||
`
|
||||
macosMemoryOutput := `hw.memsize: 38654705664
|
||||
hw.memsize_usable: 38009012224`
|
||||
memoryOutputWithMissingKey := `hw.memsize: 38654705664`
|
||||
|
||||
linuxMemoryOutput := `MemTotal: 8028860 kB
|
||||
MemFree: 731396 kB
|
||||
MemAvailable: 4678844 kB
|
||||
Buffers: 472632 kB
|
||||
Cached: 3186492 kB
|
||||
SwapCached: 4196 kB
|
||||
Active: 3088988 kB
|
||||
Inactive: 3468560 kB`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
memoryMaximumKey string
|
||||
memoryAvailableKey string
|
||||
expected *diagnostic.MemoryInformation
|
||||
expectedErr bool
|
||||
mapper func(string) (uint64, error)
|
||||
}{
|
||||
{
|
||||
name: "parse linux memory values",
|
||||
output: linuxMemoryOutput,
|
||||
memoryMaximumKey: "MemTotal",
|
||||
memoryAvailableKey: "MemAvailable",
|
||||
expected: &diagnostic.MemoryInformation{
|
||||
8028860,
|
||||
8028860 - 4678844,
|
||||
},
|
||||
expectedErr: false,
|
||||
mapper: linuxMapper,
|
||||
},
|
||||
{
|
||||
name: "parse memory values with missing key",
|
||||
output: memoryOutputWithMissingKey,
|
||||
memoryMaximumKey: "hw.memsize",
|
||||
memoryAvailableKey: "hw.memsize_usable",
|
||||
expected: nil,
|
||||
expectedErr: true,
|
||||
mapper: mapper,
|
||||
},
|
||||
{
|
||||
name: "parse macos memory values",
|
||||
output: macosMemoryOutput,
|
||||
memoryMaximumKey: "hw.memsize",
|
||||
memoryAvailableKey: "hw.memsize_usable",
|
||||
expected: &diagnostic.MemoryInformation{
|
||||
38654705664,
|
||||
38654705664 - 38009012224,
|
||||
},
|
||||
expectedErr: false,
|
||||
mapper: mapper,
|
||||
},
|
||||
{
|
||||
name: "parse windows memory values",
|
||||
output: windowsMemoryOutput,
|
||||
memoryMaximumKey: "TotalVirtualMemorySize",
|
||||
memoryAvailableKey: "FreeVirtualMemory",
|
||||
expected: &diagnostic.MemoryInformation{
|
||||
8903424,
|
||||
8903424 - 5350472,
|
||||
},
|
||||
expectedErr: false,
|
||||
mapper: mapper,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range tests {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
memoryInfo, err := diagnostic.ParseMemoryInformationFromKV(
|
||||
tCase.output,
|
||||
tCase.memoryMaximumKey,
|
||||
tCase.memoryAvailableKey,
|
||||
tCase.mapper,
|
||||
)
|
||||
|
||||
if tCase.expectedErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tCase.expected, memoryInfo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUnameOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
os string
|
||||
expected *diagnostic.OsInfo
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "darwin machine",
|
||||
output: "Darwin APC 23.6.0 Darwin Kernel Version 99.6.0: Wed Jul 31 20:48:04 PDT 1997; root:xnu-66666.666.6.666.6~1/RELEASE_ARM64_T6666 arm64",
|
||||
os: "darwin",
|
||||
expected: &diagnostic.OsInfo{
|
||||
Architecture: "arm64",
|
||||
Name: "APC",
|
||||
OsSystem: "Darwin",
|
||||
OsRelease: "Darwin Kernel Version 99.6.0: Wed Jul 31 20:48:04 PDT 1997; root:xnu-66666.666.6.666.6~1/RELEASE_ARM64_T6666",
|
||||
OsVersion: "23.6.0",
|
||||
},
|
||||
expectedErr: false,
|
||||
},
|
||||
{
|
||||
name: "linux machine",
|
||||
output: "Linux dab00d565591 6.6.31-linuxkit #1 SMP Thu May 23 08:36:57 UTC 2024 aarch64 GNU/Linux",
|
||||
os: "linux",
|
||||
expected: &diagnostic.OsInfo{
|
||||
Architecture: "aarch64",
|
||||
Name: "dab00d565591",
|
||||
OsSystem: "Linux",
|
||||
OsRelease: "#1 SMP Thu May 23 08:36:57 UTC 2024",
|
||||
OsVersion: "6.6.31-linuxkit",
|
||||
},
|
||||
expectedErr: false,
|
||||
},
|
||||
{
|
||||
name: "not enough fields",
|
||||
output: "Linux ",
|
||||
os: "linux",
|
||||
expected: nil,
|
||||
expectedErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range tests {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
memoryInfo, err := diagnostic.ParseUnameOutput(
|
||||
tCase.output,
|
||||
tCase.os,
|
||||
)
|
||||
|
||||
if tCase.expectedErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tCase.expected, memoryInfo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFileDescriptorInformationFromKV(t *testing.T) {
|
||||
const (
|
||||
fileDescriptorMaximumKey = "kern.maxfiles"
|
||||
fileDescriptorCurrentKey = "kern.num_files"
|
||||
)
|
||||
|
||||
t.Parallel()
|
||||
|
||||
memoryOutput := `kern.maxfiles: 276480
|
||||
kern.num_files: 11787`
|
||||
memoryOutputWithMissingKey := `kern.maxfiles: 276480`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
expected *diagnostic.FileDescriptorInformation
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "parse memory values with missing key",
|
||||
output: memoryOutputWithMissingKey,
|
||||
expected: nil,
|
||||
expectedErr: true,
|
||||
},
|
||||
{
|
||||
name: "parse macos memory values",
|
||||
output: memoryOutput,
|
||||
expected: &diagnostic.FileDescriptorInformation{
|
||||
276480,
|
||||
11787,
|
||||
},
|
||||
expectedErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range tests {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fdInfo, err := diagnostic.ParseFileDescriptorInformationFromKV(
|
||||
tCase.output,
|
||||
fileDescriptorMaximumKey,
|
||||
fileDescriptorCurrentKey,
|
||||
)
|
||||
|
||||
if tCase.expectedErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tCase.expected, fdInfo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSysctlFileDescriptorInformation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
expected *diagnostic.FileDescriptorInformation
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "expected output",
|
||||
output: "111 0 1111111",
|
||||
expected: &diagnostic.FileDescriptorInformation{
|
||||
FileDescriptorMaximum: 1111111,
|
||||
FileDescriptorCurrent: 111,
|
||||
},
|
||||
expectedErr: false,
|
||||
},
|
||||
{
|
||||
name: "not enough fields",
|
||||
output: "111 111 ",
|
||||
expected: nil,
|
||||
expectedErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range tests {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fdsInfo, err := diagnostic.ParseSysctlFileDescriptorInformation(
|
||||
tCase.output,
|
||||
)
|
||||
|
||||
if tCase.expectedErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tCase.expected, fdsInfo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseWinOperatingSystemInfo(t *testing.T) {
|
||||
const (
|
||||
architecturePrefix = "OSArchitecture"
|
||||
osSystemPrefix = "Caption"
|
||||
osVersionPrefix = "Version"
|
||||
osReleasePrefix = "BuildNumber"
|
||||
namePrefix = "CSName"
|
||||
)
|
||||
|
||||
t.Parallel()
|
||||
|
||||
windowsIncompleteOsInfo := `
|
||||
OSArchitecture : ARM 64 bits
|
||||
Caption : Microsoft Windows 11 Home
|
||||
Morekeys : 121314
|
||||
CSName : UTILIZA-QO859QP
|
||||
`
|
||||
windowsCompleteOsInfo := `
|
||||
OSArchitecture : ARM 64 bits
|
||||
Caption : Microsoft Windows 11 Home
|
||||
Version : 10.0.22631
|
||||
BuildNumber : 22631
|
||||
Morekeys : 121314
|
||||
CSName : UTILIZA-QO859QP
|
||||
`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
expected *diagnostic.OsInfo
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "expected output",
|
||||
output: windowsCompleteOsInfo,
|
||||
expected: &diagnostic.OsInfo{
|
||||
Architecture: "ARM 64 bits",
|
||||
Name: "UTILIZA-QO859QP",
|
||||
OsSystem: "Microsoft Windows 11 Home",
|
||||
OsRelease: "22631",
|
||||
OsVersion: "10.0.22631",
|
||||
},
|
||||
expectedErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing keys",
|
||||
output: windowsIncompleteOsInfo,
|
||||
expected: nil,
|
||||
expectedErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range tests {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
osInfo, err := diagnostic.ParseWinOperatingSystemInfo(
|
||||
tCase.output,
|
||||
architecturePrefix,
|
||||
osSystemPrefix,
|
||||
osVersionPrefix,
|
||||
osReleasePrefix,
|
||||
namePrefix,
|
||||
)
|
||||
|
||||
if tCase.expectedErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tCase.expected, osInfo)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDiskVolumeInformationOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
invalidUnixDiskVolumeInfo := `Filesystem Size Used Avail Use% Mounted on
|
||||
overlay 59G 19G 38G 33% /
|
||||
tmpfs 64M 0 64M 0% /dev
|
||||
shm 64M 0 64M 0% /dev/shm
|
||||
/run/host_mark/Users 461G 266G 195G 58% /tmp/cloudflared
|
||||
/dev/vda1 59G 19G 38G 33% /etc/hosts
|
||||
tmpfs 3.9G 0 3.9G 0% /sys/firmware
|
||||
`
|
||||
|
||||
unixDiskVolumeInfo := `Filesystem Size Used Avail Use% Mounted on
|
||||
overlay 61202244 18881444 39179476 33% /
|
||||
tmpfs 65536 0 65536 0% /dev
|
||||
shm 65536 0 65536 0% /dev/shm
|
||||
/run/host_mark/Users 482797652 278648468 204149184 58% /tmp/cloudflared
|
||||
/dev/vda1 61202244 18881444 39179476 33% /etc/hosts
|
||||
tmpfs 4014428 0 4014428 0% /sys/firmware`
|
||||
missingFields := ` DeviceID Size
|
||||
-------- ----
|
||||
C: size
|
||||
E: 235563008
|
||||
Z: 67754782720
|
||||
`
|
||||
invalidTypeField := ` DeviceID Size FreeSpace
|
||||
-------- ---- ---------
|
||||
C: size 31318736896
|
||||
D:
|
||||
E: 235563008 0
|
||||
Z: 67754782720 31318732800
|
||||
`
|
||||
|
||||
windowsDiskVolumeInfo := `
|
||||
|
||||
DeviceID Size FreeSpace
|
||||
-------- ---- ---------
|
||||
C: 67754782720 31318736896
|
||||
E: 235563008 0
|
||||
Z: 67754782720 31318732800`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
output string
|
||||
expected []*diagnostic.DiskVolumeInformation
|
||||
skipLines int
|
||||
expectedErr bool
|
||||
}{
|
||||
{
|
||||
name: "invalid unix disk volume information (numbers have units)",
|
||||
output: invalidUnixDiskVolumeInfo,
|
||||
expected: []*diagnostic.DiskVolumeInformation{},
|
||||
skipLines: 1,
|
||||
expectedErr: true,
|
||||
},
|
||||
{
|
||||
name: "unix disk volume information",
|
||||
output: unixDiskVolumeInfo,
|
||||
skipLines: 1,
|
||||
expected: []*diagnostic.DiskVolumeInformation{
|
||||
diagnostic.NewDiskVolumeInformation("overlay", 61202244, 18881444),
|
||||
diagnostic.NewDiskVolumeInformation("tmpfs", 65536, 0),
|
||||
diagnostic.NewDiskVolumeInformation("shm", 65536, 0),
|
||||
diagnostic.NewDiskVolumeInformation("/run/host_mark/Users", 482797652, 278648468),
|
||||
diagnostic.NewDiskVolumeInformation("/dev/vda1", 61202244, 18881444),
|
||||
diagnostic.NewDiskVolumeInformation("tmpfs", 4014428, 0),
|
||||
},
|
||||
expectedErr: false,
|
||||
},
|
||||
{
|
||||
name: "windows disk volume information",
|
||||
output: windowsDiskVolumeInfo,
|
||||
expected: []*diagnostic.DiskVolumeInformation{
|
||||
diagnostic.NewDiskVolumeInformation("C:", 67754782720, 31318736896),
|
||||
diagnostic.NewDiskVolumeInformation("E:", 235563008, 0),
|
||||
diagnostic.NewDiskVolumeInformation("Z:", 67754782720, 31318732800),
|
||||
},
|
||||
skipLines: 4,
|
||||
expectedErr: false,
|
||||
},
|
||||
{
|
||||
name: "insuficient fields",
|
||||
output: missingFields,
|
||||
expected: nil,
|
||||
skipLines: 2,
|
||||
expectedErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid field",
|
||||
output: invalidTypeField,
|
||||
expected: nil,
|
||||
skipLines: 2,
|
||||
expectedErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range tests {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
disks, err := diagnostic.ParseDiskVolumeInformationOutput(tCase.output, tCase.skipLines, 1)
|
||||
|
||||
if tCase.expectedErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tCase.expected, disks)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -0,0 +1,377 @@
|
|||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func findColonSeparatedPairs[V any](output string, keys []string, mapper func(string) (V, error)) map[string]V {
|
||||
const (
|
||||
memoryField = 1
|
||||
memoryInformationFields = 2
|
||||
)
|
||||
|
||||
lines := strings.Split(output, "\n")
|
||||
pairs := make(map[string]V, 0)
|
||||
|
||||
// sort keys and lines to allow incremental search
|
||||
sort.Strings(lines)
|
||||
sort.Strings(keys)
|
||||
|
||||
// keeps track of the last key found
|
||||
lastIndex := 0
|
||||
|
||||
for _, line := range lines {
|
||||
if lastIndex == len(keys) {
|
||||
// already found all keys no need to continue iterating
|
||||
// over the other values
|
||||
break
|
||||
}
|
||||
|
||||
for index, key := range keys[lastIndex:] {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, key) {
|
||||
fields := strings.Split(line, ":")
|
||||
if len(fields) < memoryInformationFields {
|
||||
lastIndex = index + 1
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
field, err := mapper(strings.TrimSpace(fields[memoryField]))
|
||||
if err != nil {
|
||||
lastIndex = lastIndex + index + 1
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
pairs[key] = field
|
||||
lastIndex = lastIndex + index + 1
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return pairs
|
||||
}
|
||||
|
||||
func ParseDiskVolumeInformationOutput(output string, skipLines int, scale float64) ([]*DiskVolumeInformation, error) {
|
||||
const (
|
||||
diskFieldsMinimum = 3
|
||||
nameField = 0
|
||||
sizeMaximumField = 1
|
||||
sizeCurrentField = 2
|
||||
)
|
||||
|
||||
disksRaw := strings.Split(output, "\n")
|
||||
disks := make([]*DiskVolumeInformation, 0)
|
||||
|
||||
if skipLines > len(disksRaw) || skipLines < 0 {
|
||||
skipLines = 0
|
||||
}
|
||||
|
||||
for _, disk := range disksRaw[skipLines:] {
|
||||
if disk == "" {
|
||||
// skip empty line
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Fields(disk)
|
||||
if len(fields) < diskFieldsMinimum {
|
||||
return nil, fmt.Errorf("expected disk volume to have %d fields got %d: %w",
|
||||
diskFieldsMinimum, len(fields), ErrInsuficientFields,
|
||||
)
|
||||
}
|
||||
|
||||
name := fields[nameField]
|
||||
|
||||
sizeMaximum, err := strconv.ParseUint(fields[sizeMaximumField], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
sizeCurrent, err := strconv.ParseUint(fields[sizeCurrentField], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
diskInfo := NewDiskVolumeInformation(
|
||||
name, uint64(float64(sizeMaximum)*scale), uint64(float64(sizeCurrent)*scale),
|
||||
)
|
||||
disks = append(disks, diskInfo)
|
||||
}
|
||||
|
||||
if len(disks) == 0 {
|
||||
return nil, ErrNoVolumeFound
|
||||
}
|
||||
|
||||
return disks, nil
|
||||
}
|
||||
|
||||
type OsInfo struct {
|
||||
OsSystem string
|
||||
Name string
|
||||
OsVersion string
|
||||
OsRelease string
|
||||
Architecture string
|
||||
}
|
||||
|
||||
func ParseUnameOutput(output string, system string) (*OsInfo, error) {
|
||||
const (
|
||||
osystemField = 0
|
||||
nameField = 1
|
||||
osVersionField = 2
|
||||
osReleaseStartField = 3
|
||||
osInformationFieldsMinimum = 6
|
||||
darwin = "darwin"
|
||||
)
|
||||
|
||||
architectureOffset := 2
|
||||
if system == darwin {
|
||||
architectureOffset = 1
|
||||
}
|
||||
|
||||
fields := strings.Fields(output)
|
||||
if len(fields) < osInformationFieldsMinimum {
|
||||
return nil, fmt.Errorf("expected system information to have %d fields got %d: %w",
|
||||
osInformationFieldsMinimum, len(fields), ErrInsuficientFields,
|
||||
)
|
||||
}
|
||||
|
||||
architectureField := len(fields) - architectureOffset
|
||||
osystem := fields[osystemField]
|
||||
name := fields[nameField]
|
||||
osVersion := fields[osVersionField]
|
||||
osRelease := strings.Join(fields[osReleaseStartField:architectureField], " ")
|
||||
architecture := fields[architectureField]
|
||||
|
||||
return &OsInfo{
|
||||
osystem,
|
||||
name,
|
||||
osVersion,
|
||||
osRelease,
|
||||
architecture,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ParseWinOperatingSystemInfo(
|
||||
output string,
|
||||
architectureKey string,
|
||||
osSystemKey string,
|
||||
osVersionKey string,
|
||||
osReleaseKey string,
|
||||
nameKey string,
|
||||
) (*OsInfo, error) {
|
||||
identity := func(s string) (string, error) { return s, nil }
|
||||
|
||||
keys := []string{architectureKey, osSystemKey, osVersionKey, osReleaseKey, nameKey}
|
||||
pairs := findColonSeparatedPairs(
|
||||
output,
|
||||
keys,
|
||||
identity,
|
||||
)
|
||||
|
||||
architecture, exists := pairs[architectureKey]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("parsing os information: %w, key=%s", ErrKeyNotFound, architectureKey)
|
||||
}
|
||||
|
||||
osSystem, exists := pairs[osSystemKey]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("parsing os information: %w, key=%s", ErrKeyNotFound, osSystemKey)
|
||||
}
|
||||
|
||||
osVersion, exists := pairs[osVersionKey]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("parsing os information: %w, key=%s", ErrKeyNotFound, osVersionKey)
|
||||
}
|
||||
|
||||
osRelease, exists := pairs[osReleaseKey]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("parsing os information: %w, key=%s", ErrKeyNotFound, osReleaseKey)
|
||||
}
|
||||
|
||||
name, exists := pairs[nameKey]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("parsing os information: %w, key=%s", ErrKeyNotFound, nameKey)
|
||||
}
|
||||
|
||||
return &OsInfo{osSystem, name, osVersion, osRelease, architecture}, nil
|
||||
}
|
||||
|
||||
type FileDescriptorInformation struct {
|
||||
FileDescriptorMaximum uint64
|
||||
FileDescriptorCurrent uint64
|
||||
}
|
||||
|
||||
func ParseSysctlFileDescriptorInformation(output string) (*FileDescriptorInformation, error) {
|
||||
const (
|
||||
openFilesField = 0
|
||||
maxFilesField = 2
|
||||
fileDescriptorLimitsFields = 3
|
||||
)
|
||||
|
||||
fields := strings.Fields(output)
|
||||
|
||||
if len(fields) != fileDescriptorLimitsFields {
|
||||
return nil,
|
||||
fmt.Errorf(
|
||||
"expected file descriptor information to have %d fields got %d: %w",
|
||||
fileDescriptorLimitsFields,
|
||||
len(fields),
|
||||
ErrInsuficientFields,
|
||||
)
|
||||
}
|
||||
|
||||
fileDescriptorCurrent, err := strconv.ParseUint(fields[openFilesField], 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error parsing files current field '%s': %w",
|
||||
fields[openFilesField],
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
fileDescriptorMaximum, err := strconv.ParseUint(fields[maxFilesField], 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing files max field '%s': %w", fields[maxFilesField], err)
|
||||
}
|
||||
|
||||
return &FileDescriptorInformation{fileDescriptorMaximum, fileDescriptorCurrent}, nil
|
||||
}
|
||||
|
||||
func ParseFileDescriptorInformationFromKV(
|
||||
output string,
|
||||
fileDescriptorMaximumKey string,
|
||||
fileDescriptorCurrentKey string,
|
||||
) (*FileDescriptorInformation, error) {
|
||||
mapper := func(field string) (uint64, error) {
|
||||
return strconv.ParseUint(field, 10, 64)
|
||||
}
|
||||
|
||||
pairs := findColonSeparatedPairs(output, []string{fileDescriptorMaximumKey, fileDescriptorCurrentKey}, mapper)
|
||||
|
||||
fileDescriptorMaximum, exists := pairs[fileDescriptorMaximumKey]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf(
|
||||
"parsing file descriptor information: %w, key=%s",
|
||||
ErrKeyNotFound,
|
||||
fileDescriptorMaximumKey,
|
||||
)
|
||||
}
|
||||
|
||||
fileDescriptorCurrent, exists := pairs[fileDescriptorCurrentKey]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf(
|
||||
"parsing file descriptor information: %w, key=%s",
|
||||
ErrKeyNotFound,
|
||||
fileDescriptorCurrentKey,
|
||||
)
|
||||
}
|
||||
|
||||
return &FileDescriptorInformation{fileDescriptorMaximum, fileDescriptorCurrent}, nil
|
||||
}
|
||||
|
||||
type MemoryInformation struct {
|
||||
MemoryMaximum uint64 // size in KB
|
||||
MemoryCurrent uint64 // size in KB
|
||||
}
|
||||
|
||||
func ParseMemoryInformationFromKV(
|
||||
output string,
|
||||
memoryMaximumKey string,
|
||||
memoryAvailableKey string,
|
||||
mapper func(field string) (uint64, error),
|
||||
) (*MemoryInformation, error) {
|
||||
pairs := findColonSeparatedPairs(output, []string{memoryMaximumKey, memoryAvailableKey}, mapper)
|
||||
|
||||
memoryMaximum, exists := pairs[memoryMaximumKey]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("parsing memory information: %w, key=%s", ErrKeyNotFound, memoryMaximumKey)
|
||||
}
|
||||
|
||||
memoryAvailable, exists := pairs[memoryAvailableKey]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("parsing memory information: %w, key=%s", ErrKeyNotFound, memoryAvailableKey)
|
||||
}
|
||||
|
||||
memoryCurrent := memoryMaximum - memoryAvailable
|
||||
|
||||
return &MemoryInformation{memoryMaximum, memoryCurrent}, nil
|
||||
}
|
||||
|
||||
func RawSystemInformation(osInfoRaw string, memoryInfoRaw string, fdInfoRaw string, disksRaw string) string {
|
||||
var builder strings.Builder
|
||||
|
||||
formatInfo := func(info string, builder *strings.Builder) {
|
||||
if info == "" {
|
||||
builder.WriteString("No information\n")
|
||||
} else {
|
||||
builder.WriteString(info)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
builder.WriteString("---BEGIN Operating system information\n")
|
||||
formatInfo(osInfoRaw, &builder)
|
||||
builder.WriteString("---END Operating system information\n")
|
||||
builder.WriteString("---BEGIN Memory information\n")
|
||||
formatInfo(memoryInfoRaw, &builder)
|
||||
builder.WriteString("---END Memory information\n")
|
||||
builder.WriteString("---BEGIN File descriptors information\n")
|
||||
formatInfo(fdInfoRaw, &builder)
|
||||
builder.WriteString("---END File descriptors information\n")
|
||||
builder.WriteString("---BEGIN Disks information\n")
|
||||
formatInfo(disksRaw, &builder)
|
||||
builder.WriteString("---END Disks information\n")
|
||||
|
||||
rawInformation := builder.String()
|
||||
|
||||
return rawInformation
|
||||
}
|
||||
|
||||
func collectDiskVolumeInformationUnix(ctx context.Context) ([]*DiskVolumeInformation, string, error) {
|
||||
command := exec.CommandContext(ctx, "df", "-k")
|
||||
|
||||
stdout, err := command.Output()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
|
||||
}
|
||||
|
||||
output := string(stdout)
|
||||
|
||||
disks, err := ParseDiskVolumeInformationOutput(output, 1, 1)
|
||||
if err != nil {
|
||||
return nil, output, err
|
||||
}
|
||||
|
||||
// returning raw output in case other collected information
|
||||
// resulted in errors
|
||||
return disks, output, nil
|
||||
}
|
||||
|
||||
func collectOSInformationUnix(ctx context.Context) (*OsInfo, string, error) {
|
||||
command := exec.CommandContext(ctx, "uname", "-a")
|
||||
|
||||
stdout, err := command.Output()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
|
||||
}
|
||||
|
||||
output := string(stdout)
|
||||
|
||||
osInfo, err := ParseUnameOutput(output, runtime.GOOS)
|
||||
if err != nil {
|
||||
return nil, output, err
|
||||
}
|
||||
|
||||
// returning raw output in case other collected information
|
||||
// resulted in errors
|
||||
return osInfo, output, nil
|
||||
}
|
|
@ -0,0 +1,183 @@
|
|||
//go:build windows
|
||||
|
||||
package diagnostic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const kiloBytesScale = 1.0 / 1024
|
||||
|
||||
type SystemCollectorImpl struct {
|
||||
version string
|
||||
}
|
||||
|
||||
func NewSystemCollectorImpl(
|
||||
version string,
|
||||
) *SystemCollectorImpl {
|
||||
return &SystemCollectorImpl{
|
||||
version,
|
||||
}
|
||||
}
|
||||
|
||||
func (collector *SystemCollectorImpl) Collect(ctx context.Context) (*SystemInformation, error) {
|
||||
memoryInfo, memoryInfoRaw, memoryInfoErr := collectMemoryInformation(ctx)
|
||||
disks, disksRaw, diskErr := collectDiskVolumeInformation(ctx)
|
||||
osInfo, osInfoRaw, osInfoErr := collectOSInformation(ctx)
|
||||
|
||||
var memoryMaximum, memoryCurrent, fileDescriptorMaximum, fileDescriptorCurrent uint64
|
||||
var osSystem, name, osVersion, osRelease, architecture string
|
||||
|
||||
err := SystemInformationGeneralError{
|
||||
OperatingSystemInformationError: nil,
|
||||
MemoryInformationError: nil,
|
||||
FileDescriptorsInformationError: nil,
|
||||
DiskVolumeInformationError: nil,
|
||||
}
|
||||
|
||||
if memoryInfoErr != nil {
|
||||
err.MemoryInformationError = SystemInformationError{
|
||||
Err: memoryInfoErr,
|
||||
RawInfo: memoryInfoRaw,
|
||||
}
|
||||
} else {
|
||||
memoryMaximum = memoryInfo.MemoryMaximum
|
||||
memoryCurrent = memoryInfo.MemoryCurrent
|
||||
}
|
||||
|
||||
if diskErr != nil {
|
||||
err.DiskVolumeInformationError = SystemInformationError{
|
||||
Err: diskErr,
|
||||
RawInfo: disksRaw,
|
||||
}
|
||||
}
|
||||
|
||||
if osInfoErr != nil {
|
||||
err.OperatingSystemInformationError = SystemInformationError{
|
||||
Err: osInfoErr,
|
||||
RawInfo: osInfoRaw,
|
||||
}
|
||||
} else {
|
||||
osSystem = osInfo.OsSystem
|
||||
name = osInfo.Name
|
||||
osVersion = osInfo.OsVersion
|
||||
osRelease = osInfo.OsRelease
|
||||
architecture = osInfo.Architecture
|
||||
}
|
||||
|
||||
cloudflaredVersion := collector.version
|
||||
info := NewSystemInformation(
|
||||
memoryMaximum,
|
||||
memoryCurrent,
|
||||
fileDescriptorMaximum,
|
||||
fileDescriptorCurrent,
|
||||
osSystem,
|
||||
name,
|
||||
osVersion,
|
||||
osRelease,
|
||||
architecture,
|
||||
cloudflaredVersion,
|
||||
runtime.Version(),
|
||||
runtime.GOARCH,
|
||||
disks,
|
||||
)
|
||||
|
||||
return info, err
|
||||
}
|
||||
|
||||
func collectMemoryInformation(ctx context.Context) (*MemoryInformation, string, error) {
|
||||
const (
|
||||
memoryTotalPrefix = "TotalVirtualMemorySize"
|
||||
memoryAvailablePrefix = "FreeVirtualMemory"
|
||||
)
|
||||
|
||||
command := exec.CommandContext(
|
||||
ctx,
|
||||
"powershell",
|
||||
"-Command",
|
||||
"Get-CimInstance -Class Win32_OperatingSystem | Select-Object FreeVirtualMemory, TotalVirtualMemorySize | Format-List",
|
||||
)
|
||||
|
||||
stdout, err := command.Output()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
|
||||
}
|
||||
|
||||
output := string(stdout)
|
||||
|
||||
// the result of the command above will return values in bytes hence
|
||||
// they need to be converted to kilobytes
|
||||
mapper := func(field string) (uint64, error) {
|
||||
value, err := strconv.ParseUint(field, 10, 64)
|
||||
return uint64(float64(value) * kiloBytesScale), err
|
||||
}
|
||||
|
||||
memoryInfo, err := ParseMemoryInformationFromKV(output, memoryTotalPrefix, memoryAvailablePrefix, mapper)
|
||||
if err != nil {
|
||||
return nil, output, err
|
||||
}
|
||||
|
||||
// returning raw output in case other collected information
|
||||
// resulted in errors
|
||||
return memoryInfo, output, nil
|
||||
}
|
||||
|
||||
func collectDiskVolumeInformation(ctx context.Context) ([]*DiskVolumeInformation, string, error) {
|
||||
|
||||
command := exec.CommandContext(
|
||||
ctx,
|
||||
"powershell", "-Command", "Get-CimInstance -Class Win32_LogicalDisk | Select-Object DeviceID, Size, FreeSpace")
|
||||
|
||||
stdout, err := command.Output()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
|
||||
}
|
||||
|
||||
output := string(stdout)
|
||||
|
||||
disks, err := ParseDiskVolumeInformationOutput(output, 2, kiloBytesScale)
|
||||
if err != nil {
|
||||
return nil, output, err
|
||||
}
|
||||
|
||||
// returning raw output in case other collected information
|
||||
// resulted in errors
|
||||
return disks, output, nil
|
||||
}
|
||||
|
||||
func collectOSInformation(ctx context.Context) (*OsInfo, string, error) {
|
||||
const (
|
||||
architecturePrefix = "OSArchitecture"
|
||||
osSystemPrefix = "Caption"
|
||||
osVersionPrefix = "Version"
|
||||
osReleasePrefix = "BuildNumber"
|
||||
namePrefix = "CSName"
|
||||
)
|
||||
|
||||
command := exec.CommandContext(
|
||||
ctx,
|
||||
"powershell",
|
||||
"-Command",
|
||||
"Get-CimInstance -Class Win32_OperatingSystem | Select-Object OSArchitecture, Caption, Version, BuildNumber, CSName | Format-List",
|
||||
)
|
||||
|
||||
stdout, err := command.Output()
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
|
||||
}
|
||||
|
||||
output := string(stdout)
|
||||
|
||||
osInfo, err := ParseWinOperatingSystemInfo(output, architecturePrefix, osSystemPrefix, osVersionPrefix, osReleasePrefix, namePrefix)
|
||||
if err != nil {
|
||||
return nil, output, err
|
||||
}
|
||||
|
||||
// returning raw output in case other collected information
|
||||
// resulted in errors
|
||||
return osInfo, output, nil
|
||||
}
|
|
@ -9,7 +9,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// DialEdgeWithH2Mux makes a TLS connection to a Cloudflare edge node
|
||||
// DialEdge makes a TLS connection to a Cloudflare edge node
|
||||
func DialEdge(
|
||||
ctx context.Context,
|
||||
timeout time.Duration,
|
||||
|
@ -36,7 +36,7 @@ func DialEdge(
|
|||
if err = tlsEdgeConn.Handshake(); err != nil {
|
||||
return nil, newDialError(err, "TLS handshake with edge error")
|
||||
}
|
||||
// clear the deadline on the conn; h2mux has its own timeouts
|
||||
// clear the deadline on the conn; http2 has its own timeouts
|
||||
tlsEdgeConn.SetDeadline(time.Time{})
|
||||
return tlsEdgeConn, nil
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ const (
|
|||
FeaturePostQuantum = "postquantum"
|
||||
FeatureQUICSupportEOF = "support_quic_eof"
|
||||
FeatureManagementLogs = "management_logs"
|
||||
FeatureDatagramV3 = "support_datagram_v3"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
@ -11,8 +11,9 @@ import hashlib
|
|||
import requests
|
||||
import tarfile
|
||||
from os import listdir
|
||||
from os.path import isfile, join
|
||||
from os.path import isfile, join, splitext
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
from github import Github, GithubException, UnknownObjectException
|
||||
|
||||
|
@ -210,6 +211,61 @@ def move_asset(filepath, filename):
|
|||
except shutil.SameFileError:
|
||||
pass # the macOS release copy fails with being the same file (already in the artifacts directory)
|
||||
|
||||
def get_binary_version(binary_path):
|
||||
"""
|
||||
Sample output from go version -m <binary>:
|
||||
...
|
||||
build -compiler=gc
|
||||
build -ldflags="-X \"main.Version=2024.8.3-6-gec072691\" -X \"main.BuildTime=2024-09-10-1027 UTC\" "
|
||||
build CGO_ENABLED=1
|
||||
...
|
||||
|
||||
This function parses the above output to retrieve the following substring 2024.8.3-6-gec072691.
|
||||
To do this a start and end indexes are computed and the a slice is extracted from the output using them.
|
||||
"""
|
||||
needle = "main.Version="
|
||||
cmd = ['go','version', '-m', binary_path]
|
||||
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
output, _ = process.communicate()
|
||||
version_info = output.decode()
|
||||
|
||||
# Find start of needle
|
||||
needle_index = version_info.find(needle)
|
||||
# Find backward slash relative to the beggining of the needle
|
||||
relative_end_index = version_info[needle_index:].find("\\")
|
||||
# Calculate needle position plus needle length to find version beggining
|
||||
start_index = needle_index + len(needle)
|
||||
# Calculate needle position plus relative position of the backward slash
|
||||
end_index = needle_index + relative_end_index
|
||||
return version_info[start_index:end_index]
|
||||
|
||||
def assert_asset_version(binary_path, release_version):
|
||||
"""
|
||||
Asserts that the artifacts have the correct release_version.
|
||||
The artifacts that are checked must not have an extension expecting .exe and .tgz.
|
||||
In the occurrence of any other extension the function exits early.
|
||||
"""
|
||||
try:
|
||||
shutil.rmtree('tmp')
|
||||
except OSError:
|
||||
pass
|
||||
_, ext = os.path.splitext(binary_path)
|
||||
if ext == '.exe' or ext == '':
|
||||
binary_version = get_binary_version(binary_path)
|
||||
elif ext == '.tgz':
|
||||
tar = tarfile.open(binary_path, "r:gz")
|
||||
tar.extractall("tmp")
|
||||
tar.close()
|
||||
binary_path = os.path.join(os.getcwd(), 'tmp', 'cloudflared')
|
||||
binary_version = get_binary_version(binary_path)
|
||||
else:
|
||||
return
|
||||
|
||||
if binary_version != release_version:
|
||||
logging.error(f"Version mismatch {binary_path}, binary_version {binary_version} release_version {release_version}")
|
||||
exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
""" Attempts to upload Asset to Github Release. Creates Release if it doesn't exist """
|
||||
try:
|
||||
|
@ -221,6 +277,7 @@ def main():
|
|||
for filename in onlyfiles:
|
||||
binary_path = os.path.join(args.path, filename)
|
||||
logging.info("binary: " + binary_path)
|
||||
assert_asset_version(binary_path, args.release_version)
|
||||
elif os.path.isfile(args.path):
|
||||
logging.info("binary: " + binary_path)
|
||||
else:
|
||||
|
@ -229,18 +286,20 @@ def main():
|
|||
else:
|
||||
client = Github(args.api_key)
|
||||
repo = client.get_repo(CLOUDFLARED_REPO)
|
||||
release = get_or_create_release(repo, args.release_version, args.dry_run)
|
||||
|
||||
if os.path.isdir(args.path):
|
||||
onlyfiles = [f for f in listdir(args.path) if isfile(join(args.path, f))]
|
||||
for filename in onlyfiles:
|
||||
binary_path = os.path.join(args.path, filename)
|
||||
assert_asset_version(binary_path, args.release_version)
|
||||
release = get_or_create_release(repo, args.release_version, args.dry_run)
|
||||
for filename in onlyfiles:
|
||||
binary_path = os.path.join(args.path, filename)
|
||||
upload_asset(release, binary_path, filename, args.release_version, args.kv_account_id, args.namespace_id,
|
||||
args.kv_api_token)
|
||||
move_asset(binary_path, filename)
|
||||
else:
|
||||
upload_asset(release, args.path, args.name, args.release_version, args.kv_account_id, args.namespace_id,
|
||||
args.kv_api_token)
|
||||
raise Exception("the argument path must be a directory")
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
|
29
go.mod
29
go.mod
|
@ -3,7 +3,7 @@ module github.com/cloudflare/cloudflared
|
|||
go 1.22
|
||||
|
||||
require (
|
||||
github.com/coredns/coredns v1.10.0
|
||||
github.com/coredns/coredns v1.11.3
|
||||
github.com/coreos/go-oidc/v3 v3.10.0
|
||||
github.com/coreos/go-systemd/v22 v22.5.0
|
||||
github.com/facebookgo/grace v0.0.0-20180706040059-75cf19382434
|
||||
|
@ -13,18 +13,17 @@ require (
|
|||
github.com/go-chi/chi/v5 v5.0.8
|
||||
github.com/go-chi/cors v1.2.1
|
||||
github.com/go-jose/go-jose/v4 v4.0.1
|
||||
github.com/gobwas/ws v1.0.4
|
||||
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
|
||||
github.com/gobwas/ws v1.2.1
|
||||
github.com/google/gopacket v1.1.19
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
github.com/json-iterator/go v1.1.12
|
||||
github.com/mattn/go-colorable v0.1.13
|
||||
github.com/miekg/dns v1.1.50
|
||||
github.com/miekg/dns v1.1.58
|
||||
github.com/mitchellh/go-homedir v1.1.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/prometheus/client_model v0.5.0
|
||||
github.com/prometheus/client_model v0.6.0
|
||||
github.com/quic-go/quic-go v0.45.0
|
||||
github.com/rs/zerolog v1.20.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
|
@ -55,7 +54,7 @@ require (
|
|||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/coredns/caddy v1.1.1 // indirect
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51 // indirect
|
||||
github.com/facebookgo/freeport v0.0.0-20150612182905-d4adf43b75b9 // indirect
|
||||
github.com/facebookgo/stack v0.0.0-20160209184415-751773369052 // indirect
|
||||
|
@ -64,36 +63,36 @@ require (
|
|||
github.com/go-logr/logr v1.4.1 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 // indirect
|
||||
github.com/gobwas/httphead v0.1.0 // indirect
|
||||
github.com/gobwas/pool v0.2.1 // indirect
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 // indirect
|
||||
github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 // indirect
|
||||
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect
|
||||
github.com/klauspost/compress v1.15.11 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.16 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.13.0 // indirect
|
||||
github.com/opentracing/opentracing-go v1.2.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/common v0.48.0 // indirect
|
||||
github.com/prometheus/common v0.53.0 // indirect
|
||||
github.com/prometheus/procfs v0.12.0 // indirect
|
||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.26.0 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/oauth2 v0.17.0 // indirect
|
||||
golang.org/x/oauth2 v0.18.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
golang.org/x/tools v0.21.0 // indirect
|
||||
google.golang.org/appengine v1.6.8 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect
|
||||
google.golang.org/grpc v1.63.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240311132316-a219d84964c2 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
|
||||
google.golang.org/grpc v1.63.2 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
|
||||
|
|
91
go.sum
91
go.sum
|
@ -7,13 +7,10 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
|||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
|
||||
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
|
||||
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
|
||||
github.com/coredns/caddy v1.1.1 h1:2eYKZT7i6yxIfGP3qLJoJ7HAsDJqYB+X68g4NYjSrE0=
|
||||
github.com/coredns/caddy v1.1.1/go.mod h1:A6ntJQlAWuQfFlsd9hvigKbo2WS0VUs2l1e2F+BawD4=
|
||||
github.com/coredns/coredns v1.10.0 h1:jCfuWsBjTs0dapkkhISfPCzn5LqvSRtrFtaf/Tjj4DI=
|
||||
github.com/coredns/coredns v1.10.0/go.mod h1:CIfRU5TgpuoIiJBJ4XrofQzfFQpPFh32ERpUevrSlaw=
|
||||
github.com/coredns/coredns v1.11.3 h1:8RjnpZc42db5th84/QJKH2i137ecJdzZK1HJwhetSPk=
|
||||
github.com/coredns/coredns v1.11.3/go.mod h1:lqFkDsHjEUdY7LJ75Nib3lwqJGip6ewWOqNIf8OavIQ=
|
||||
github.com/coreos/go-oidc/v3 v3.10.0 h1:tDnXHnLyiTVyT/2zLDGj09pFPkhND8Gl8lnTRhoEaJU=
|
||||
github.com/coreos/go-oidc/v3 v3.10.0/go.mod h1:5j11xcw0D3+SGxn6Z/WFADsgcWVMyNAlSQupk0KK3ac=
|
||||
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||
|
@ -24,8 +21,9 @@ github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng
|
|||
github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51 h1:0JZ+dUmQeA8IIVUMzysrX4/AKuQwWhV2dYQuPZdvdSQ=
|
||||
github.com/facebookgo/ensure v0.0.0-20160127193407-b4ab57deab51/go.mod h1:Yg+htXGokKKdzcwhuNDwVvN+uBxDGXJ7G/VN1d8fa64=
|
||||
github.com/facebookgo/freeport v0.0.0-20150612182905-d4adf43b75b9 h1:wWke/RUCl7VRjQhwPlR/v0glZXNYzBHdNUzf/Am2Nmg=
|
||||
|
@ -75,19 +73,18 @@ github.com/go-playground/validator/v10 v10.11.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4
|
|||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58 h1:YyrUZvJaU8Q0QsoVo+xLFBgWDTam29PKea6GYmwvSiQ=
|
||||
github.com/gobwas/httphead v0.0.0-20200921212729-da3d93bc3c58/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
|
||||
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
|
||||
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
|
||||
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
|
||||
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
|
||||
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
||||
github.com/gobwas/ws v1.0.4 h1:5eXU1CZhpQdq5kXbKb+sECH5Ia5KiO6CYzIzdlVx6Bs=
|
||||
github.com/gobwas/ws v1.0.4/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
|
||||
github.com/gobwas/ws v1.2.1 h1:F2aeBZrm2NDsc7vbovKrWSogd4wvfAxg0FQ89/iqOTk=
|
||||
github.com/gobwas/ws v1.2.1/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY=
|
||||
github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk=
|
||||
github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 h1:zN2lZNZRflqFyxVaTIU61KNKQ9C0055u9CAfpmqUvo4=
|
||||
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3/go.mod h1:nPpo7qLxd6XL3hWJG/O60sR8ZKfMCiIoNap5GvD12KU=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
|
||||
github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
|
@ -102,8 +99,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
|
|||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
|
||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec=
|
||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b h1:h9U78+dx9a4BKdQkBBos92HalKpaGKHrp+3Uo6yTodo=
|
||||
github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik=
|
||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
|
@ -114,7 +111,6 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 h1:/c3QmbOGMGTOumP2iT/rCwB7b0Q
|
|||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1/go.mod h1:5SN9VR2LTsRFsrEC6FHgRbTWrTHu6tqPeKxEQv15giM=
|
||||
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 h1:MJG/KsmcqMwFAkh8mTnAwhyKoB+sTAnY4CACC110tbU=
|
||||
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645/go.mod h1:6iZfnjpejD4L/4DwD7NryNaJyCQdzwWwH2MWhCA90Kw=
|
||||
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
|
||||
github.com/ipostelnik/cli/v2 v2.3.1-0.20210324024421-b6ea8234fe3d h1:PRDnysJ9dF1vUMmEzBu6aHQeUluSQy4eWH3RsSSy/vI=
|
||||
github.com/ipostelnik/cli/v2 v2.3.1-0.20210324024421-b6ea8234fe3d/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
|
||||
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
|
@ -140,10 +136,10 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk
|
|||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||
github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA=
|
||||
github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
|
||||
github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
|
||||
github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
|
||||
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
|
@ -152,16 +148,16 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
|
|||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
|
||||
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
|
||||
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
|
||||
github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4=
|
||||
github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o=
|
||||
github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI=
|
||||
github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M=
|
||||
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
|
||||
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
|
||||
github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg=
|
||||
github.com/pelletier/go-toml/v2 v2.0.5/go.mod h1:OMHamSCAODeSsVrwwvcJOaoN0LIUIaFVNZzmWyNfXas=
|
||||
github.com/philhofer/fwd v1.1.1 h1:GdGcTjf5RNAxwS4QLsiMzJYj5KEvPJD3Abr261yRQXQ=
|
||||
github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
|
||||
github.com/philhofer/fwd v1.1.2 h1:bnDivRJ1EWPjUIRXV5KfORO897HTbpFAQddBdE8t7Gw=
|
||||
github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0=
|
||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
@ -171,10 +167,10 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE=
|
||||
github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho=
|
||||
github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
|
||||
github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
|
||||
github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE=
|
||||
github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc=
|
||||
github.com/prometheus/client_model v0.6.0 h1:k1v3CzpSRUTrKMppY35TLwPvxHqBu0bYgxZzqGIgaos=
|
||||
github.com/prometheus/client_model v0.6.0/go.mod h1:NTQHnmxFpouOD0DpvP4XujX3CdOAGQPoaGhyTchlyt8=
|
||||
github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+aLCE=
|
||||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
|
||||
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
|
||||
github.com/quic-go/quic-go v0.45.0 h1:OHmkQGM37luZITyTSu6ff03HP/2IrwDX1ZFiNEhSFUE=
|
||||
|
@ -195,14 +191,13 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
|||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tinylib/msgp v1.1.2 h1:gWmO7n0Ys2RBEb7GPYB9Ujq8Mk5p2U08lRnmMcGy6BQ=
|
||||
github.com/tinylib/msgp v1.1.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
|
||||
github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0=
|
||||
github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw=
|
||||
github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo=
|
||||
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
|
||||
github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
|
||||
github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
|
||||
github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
|
||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.opentelemetry.io/contrib/propagators v0.22.0 h1:KGdv58M2//veiYLIhb31mofaI2LgkIPXXAZVeYVyfd8=
|
||||
go.opentelemetry.io/contrib/propagators v0.22.0/go.mod h1:xGOuXr6lLIF9BXipA4pm6UuOSI0M98U6tsI3khbOiwU=
|
||||
|
@ -233,39 +228,32 @@ golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJ
|
|||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/oauth2 v0.17.0 h1:6m3ZPmLEFdVxKKWnKq4VqZ60gutO35zm+zrAHVmHyDQ=
|
||||
golang.org/x/oauth2 v0.17.0/go.mod h1:OzPDGQiuQMguemayvdylqddI7qcD9lnSDb+1FiwQ5HA=
|
||||
golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI=
|
||||
golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
|
@ -275,7 +263,6 @@ golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
|||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
|
@ -287,24 +274,20 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
|
|||
golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
|
||||
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
|
||||
google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
|
||||
google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de h1:F6qOa9AZTYJXOUEr4jDysRDLrm4PHePlge4v4TGAlxY=
|
||||
google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:VUhTRKeHn9wwcdrk73nvdC9gF178Tzhmt/qyaFcPLSo=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de h1:jFNzHPIeuzhdRwVhbZdiym9q0ory/xY3sA+v2wPg8I0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:5iCWqnniDlqZHrd3neWVTOwvh/v6s3232omMecelax8=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de h1:cZGRis4/ot9uVm639a+rHCUaG0JJHEsdyzSQTMX+suY=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:H4O17MA/PE9BsGx3w+a+W2VOLLD1Qf7oJneAoU6WktY=
|
||||
google.golang.org/grpc v1.63.0 h1:WjKe+dnvABXyPJMD7KDNLxtoGk5tgk+YFWN6cBWjZE8=
|
||||
google.golang.org/grpc v1.63.0/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240311132316-a219d84964c2 h1:rIo7ocm2roD9DcFIX67Ym8icoGCKSARAiPljFhh5suQ=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240311132316-a219d84964c2/go.mod h1:O1cOfN1Cy6QEYr7VxtjOyP5AdAuR0aJ/MYZaaof623Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 h1:NnYq6UN9ReLM9/Y01KWNOWyI5xQ9kbIms5GGJVwS/Yc=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY=
|
||||
google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM=
|
||||
google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
|
|
|
@ -1,195 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
ActiveStreams = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: "cloudflared",
|
||||
Subsystem: "tunnel",
|
||||
Name: "active_streams",
|
||||
Help: "Number of active streams created by all muxers.",
|
||||
})
|
||||
)
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(ActiveStreams)
|
||||
}
|
||||
|
||||
// activeStreamMap is used to moderate access to active streams between the read and write
|
||||
// threads, and deny access to new peer streams while shutting down.
|
||||
type activeStreamMap struct {
|
||||
sync.RWMutex
|
||||
// streams tracks open streams.
|
||||
streams map[uint32]*MuxedStream
|
||||
// nextStreamID is the next ID to use on our side of the connection.
|
||||
// This is odd for clients, even for servers.
|
||||
nextStreamID uint32
|
||||
// maxPeerStreamID is the ID of the most recent stream opened by the peer.
|
||||
maxPeerStreamID uint32
|
||||
// activeStreams is a gauge shared by all muxers of this process to expose the total number of active streams
|
||||
activeStreams prometheus.Gauge
|
||||
|
||||
// ignoreNewStreams is true when the connection is being shut down. New streams
|
||||
// cannot be registered.
|
||||
ignoreNewStreams bool
|
||||
// streamsEmpty is a chan that will be closed when no more streams are open.
|
||||
streamsEmptyChan chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func newActiveStreamMap(useClientStreamNumbers bool, activeStreams prometheus.Gauge) *activeStreamMap {
|
||||
m := &activeStreamMap{
|
||||
streams: make(map[uint32]*MuxedStream),
|
||||
streamsEmptyChan: make(chan struct{}),
|
||||
nextStreamID: 1,
|
||||
activeStreams: activeStreams,
|
||||
}
|
||||
// Client initiated stream uses odd stream ID, server initiated stream uses even stream ID
|
||||
if !useClientStreamNumbers {
|
||||
m.nextStreamID = 2
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// This function should be called while `m` is locked.
|
||||
func (m *activeStreamMap) notifyStreamsEmpty() {
|
||||
m.closeOnce.Do(func() {
|
||||
close(m.streamsEmptyChan)
|
||||
})
|
||||
}
|
||||
|
||||
// Len returns the number of active streams.
|
||||
func (m *activeStreamMap) Len() int {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return len(m.streams)
|
||||
}
|
||||
|
||||
func (m *activeStreamMap) Get(streamID uint32) (*MuxedStream, bool) {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
stream, ok := m.streams[streamID]
|
||||
return stream, ok
|
||||
}
|
||||
|
||||
// Set returns true if the stream was assigned successfully. If a stream
|
||||
// already existed with that ID or we are shutting down, return false.
|
||||
func (m *activeStreamMap) Set(newStream *MuxedStream) bool {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
if _, ok := m.streams[newStream.streamID]; ok {
|
||||
return false
|
||||
}
|
||||
if m.ignoreNewStreams {
|
||||
return false
|
||||
}
|
||||
m.streams[newStream.streamID] = newStream
|
||||
m.activeStreams.Inc()
|
||||
return true
|
||||
}
|
||||
|
||||
// Delete stops tracking the stream. It should be called only after it is closed and reset.
|
||||
func (m *activeStreamMap) Delete(streamID uint32) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
if _, ok := m.streams[streamID]; ok {
|
||||
delete(m.streams, streamID)
|
||||
m.activeStreams.Dec()
|
||||
}
|
||||
|
||||
// shutting down, and now the map is empty
|
||||
if m.ignoreNewStreams && len(m.streams) == 0 {
|
||||
m.notifyStreamsEmpty()
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown blocks new streams from being created.
|
||||
// It returns `done`, a channel that is closed once the last stream has closed
|
||||
// and `progress`, whether a shutdown was already in progress
|
||||
func (m *activeStreamMap) Shutdown() (done <-chan struct{}, alreadyInProgress bool) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
if m.ignoreNewStreams {
|
||||
// already shutting down
|
||||
return m.streamsEmptyChan, true
|
||||
}
|
||||
m.ignoreNewStreams = true
|
||||
if len(m.streams) == 0 {
|
||||
// there are no streams to wait for
|
||||
m.notifyStreamsEmpty()
|
||||
}
|
||||
return m.streamsEmptyChan, false
|
||||
}
|
||||
|
||||
// AcquireLocalID acquires a new stream ID for a stream you're opening.
|
||||
func (m *activeStreamMap) AcquireLocalID() uint32 {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
x := m.nextStreamID
|
||||
m.nextStreamID += 2
|
||||
return x
|
||||
}
|
||||
|
||||
// ObservePeerID observes the ID of a stream opened by the peer. It returns true if we should accept
|
||||
// the new stream, or false to reject it. The ErrCode gives the reason why.
|
||||
func (m *activeStreamMap) AcquirePeerID(streamID uint32) (bool, http2.ErrCode) {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
switch {
|
||||
case m.ignoreNewStreams:
|
||||
return false, http2.ErrCodeStreamClosed
|
||||
case streamID > m.maxPeerStreamID:
|
||||
m.maxPeerStreamID = streamID
|
||||
return true, http2.ErrCodeNo
|
||||
default:
|
||||
return false, http2.ErrCodeStreamClosed
|
||||
}
|
||||
}
|
||||
|
||||
// IsPeerStreamID is true if the stream ID belongs to the peer.
|
||||
func (m *activeStreamMap) IsPeerStreamID(streamID uint32) bool {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return (streamID % 2) != (m.nextStreamID % 2)
|
||||
}
|
||||
|
||||
// IsLocalStreamID is true if it is a stream we have opened, even if it is now closed.
|
||||
func (m *activeStreamMap) IsLocalStreamID(streamID uint32) bool {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return (streamID%2) == (m.nextStreamID%2) && streamID < m.nextStreamID
|
||||
}
|
||||
|
||||
// LastPeerStreamID returns the most recently opened peer stream ID.
|
||||
func (m *activeStreamMap) LastPeerStreamID() uint32 {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return m.maxPeerStreamID
|
||||
}
|
||||
|
||||
// LastLocalStreamID returns the most recently opened local stream ID.
|
||||
func (m *activeStreamMap) LastLocalStreamID() uint32 {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
if m.nextStreamID > 1 {
|
||||
return m.nextStreamID - 2
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Abort closes every active stream and prevents new ones being created. This should be used to
|
||||
// return errors in pending read/writes when the underlying connection goes away.
|
||||
func (m *activeStreamMap) Abort() {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
for _, stream := range m.streams {
|
||||
stream.Close()
|
||||
}
|
||||
m.ignoreNewStreams = true
|
||||
m.notifyStreamsEmpty()
|
||||
}
|
|
@ -1,195 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
const numStreams = 1000
|
||||
m := newActiveStreamMap(true, ActiveStreams)
|
||||
|
||||
// Add all the streams
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func(streamID int) {
|
||||
defer wg.Done()
|
||||
stream := &MuxedStream{streamID: uint32(streamID)}
|
||||
ok := m.Set(stream)
|
||||
assert.True(t, ok)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
|
||||
|
||||
shutdownChan, alreadyInProgress := m.Shutdown()
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed")
|
||||
default:
|
||||
}
|
||||
assert.False(t, alreadyInProgress)
|
||||
|
||||
shutdownChan2, alreadyInProgress2 := m.Shutdown()
|
||||
assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
|
||||
assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")
|
||||
|
||||
// Delete all the streams
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func(streamID int) {
|
||||
defer wg.Done()
|
||||
m.Delete(uint32(streamID))
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")
|
||||
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
default:
|
||||
assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyBeforeShutdown(t *testing.T) {
|
||||
const numStreams = 1000
|
||||
m := newActiveStreamMap(true, ActiveStreams)
|
||||
|
||||
// Add all the streams
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func(streamID int) {
|
||||
defer wg.Done()
|
||||
stream := &MuxedStream{streamID: uint32(streamID)}
|
||||
ok := m.Set(stream)
|
||||
assert.True(t, ok)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
|
||||
|
||||
// Delete all the streams, bringing m to size 0
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func(streamID int) {
|
||||
defer wg.Done()
|
||||
m.Delete(uint32(streamID))
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
assert.Equal(t, 0, m.Len(), "All the streams should have been deleted")
|
||||
|
||||
// Add one stream back
|
||||
const soloStreamID = uint32(0)
|
||||
ok := m.Set(&MuxedStream{streamID: soloStreamID})
|
||||
assert.True(t, ok)
|
||||
|
||||
shutdownChan, alreadyInProgress := m.Shutdown()
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
assert.Fail(t, "before Shutdown(), shutdownChan shouldn't be closed")
|
||||
default:
|
||||
}
|
||||
assert.False(t, alreadyInProgress)
|
||||
|
||||
shutdownChan2, alreadyInProgress2 := m.Shutdown()
|
||||
assert.Equal(t, shutdownChan, shutdownChan2, "repeated calls to Shutdown() should return the same channel")
|
||||
assert.True(t, alreadyInProgress2, "repeated calls to Shutdown() should return true for 'in progress'")
|
||||
|
||||
// Remove the remaining stream
|
||||
m.Delete(soloStreamID)
|
||||
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
default:
|
||||
assert.Fail(t, "After all the streams are deleted, shutdownChan should have been closed")
|
||||
}
|
||||
}
|
||||
|
||||
type noopBuffer struct {
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func (t *noopBuffer) Read(p []byte) (n int, err error) { return len(p), nil }
|
||||
func (t *noopBuffer) Write(p []byte) (n int, err error) { return len(p), nil }
|
||||
func (t *noopBuffer) Reset() {}
|
||||
func (t *noopBuffer) Len() int { return 0 }
|
||||
func (t *noopBuffer) Close() error { t.isClosed = true; return nil }
|
||||
func (t *noopBuffer) Closed() bool { return t.isClosed }
|
||||
|
||||
type noopReadyList struct{}
|
||||
|
||||
func (_ *noopReadyList) Signal(streamID uint32) {}
|
||||
|
||||
func TestAbort(t *testing.T) {
|
||||
const numStreams = 1000
|
||||
m := newActiveStreamMap(true, ActiveStreams)
|
||||
|
||||
var openedStreams sync.Map
|
||||
|
||||
// Add all the streams
|
||||
{
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numStreams)
|
||||
for i := 0; i < numStreams; i++ {
|
||||
go func(streamID int) {
|
||||
defer wg.Done()
|
||||
stream := &MuxedStream{
|
||||
streamID: uint32(streamID),
|
||||
readBuffer: &noopBuffer{},
|
||||
writeBuffer: &noopBuffer{},
|
||||
readyList: &noopReadyList{},
|
||||
}
|
||||
ok := m.Set(stream)
|
||||
assert.True(t, ok)
|
||||
|
||||
openedStreams.Store(stream.streamID, stream)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
assert.Equal(t, numStreams, m.Len(), "All the streams should have been added")
|
||||
|
||||
shutdownChan, alreadyInProgress := m.Shutdown()
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
assert.Fail(t, "before Abort(), shutdownChan shouldn't be closed")
|
||||
default:
|
||||
}
|
||||
assert.False(t, alreadyInProgress)
|
||||
|
||||
m.Abort()
|
||||
assert.Equal(t, numStreams, m.Len(), "Abort() shouldn't delete any streams")
|
||||
openedStreams.Range(func(key interface{}, value interface{}) bool {
|
||||
stream := value.(*MuxedStream)
|
||||
readBuffer := stream.readBuffer.(*noopBuffer)
|
||||
writeBuffer := stream.writeBuffer.(*noopBuffer)
|
||||
return assert.True(t, readBuffer.isClosed && writeBuffer.isClosed, "Abort() should have closed all the streams")
|
||||
})
|
||||
|
||||
select {
|
||||
case <-shutdownChan:
|
||||
default:
|
||||
assert.Fail(t, "after Abort(), shutdownChan should have been closed")
|
||||
}
|
||||
|
||||
// multiple aborts shouldn't cause any issues
|
||||
m.Abort()
|
||||
m.Abort()
|
||||
m.Abort()
|
||||
}
|
|
@ -1,27 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type AtomicCounter struct {
|
||||
count uint64
|
||||
}
|
||||
|
||||
func NewAtomicCounter(initCount uint64) *AtomicCounter {
|
||||
return &AtomicCounter{count: initCount}
|
||||
}
|
||||
|
||||
func (c *AtomicCounter) IncrementBy(number uint64) {
|
||||
atomic.AddUint64(&c.count, number)
|
||||
}
|
||||
|
||||
// Count returns the current value of counter and reset it to 0
|
||||
func (c *AtomicCounter) Count() uint64 {
|
||||
return atomic.SwapUint64(&c.count, 0)
|
||||
}
|
||||
|
||||
// Value returns the current value of counter
|
||||
func (c *AtomicCounter) Value() uint64 {
|
||||
return atomic.LoadUint64(&c.count)
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCounter(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(dataPoints)
|
||||
c := AtomicCounter{}
|
||||
for i := 0; i < dataPoints; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c.IncrementBy(uint64(1))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
assert.Equal(t, uint64(dataPoints), c.Count())
|
||||
assert.Equal(t, uint64(0), c.Count())
|
||||
}
|
|
@ -1,66 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
// HTTP2 error codes: https://http2.github.io/http2-spec/#ErrorCodes
|
||||
ErrHandshakeTimeout = MuxerHandshakeError{"1000 handshake timeout"}
|
||||
ErrBadHandshakeNotSettings = MuxerHandshakeError{"1001 unexpected response"}
|
||||
ErrBadHandshakeUnexpectedAck = MuxerHandshakeError{"1002 unexpected response"}
|
||||
ErrBadHandshakeNoMagic = MuxerHandshakeError{"1003 unexpected response"}
|
||||
ErrBadHandshakeWrongMagic = MuxerHandshakeError{"1004 connected to endpoint of wrong type"}
|
||||
ErrBadHandshakeNotSettingsAck = MuxerHandshakeError{"1005 unexpected response"}
|
||||
ErrBadHandshakeUnexpectedSettings = MuxerHandshakeError{"1006 unexpected response"}
|
||||
|
||||
ErrUnexpectedFrameType = MuxerProtocolError{"2001 unexpected frame type", http2.ErrCodeProtocol}
|
||||
ErrUnknownStream = MuxerProtocolError{"2002 unknown stream", http2.ErrCodeProtocol}
|
||||
ErrInvalidStream = MuxerProtocolError{"2003 invalid stream", http2.ErrCodeProtocol}
|
||||
ErrNotRPCStream = MuxerProtocolError{"2004 not RPC stream", http2.ErrCodeProtocol}
|
||||
|
||||
ErrStreamHeadersSent = MuxerApplicationError{"3000 headers already sent"}
|
||||
ErrStreamRequestConnectionClosed = MuxerApplicationError{"3001 connection closed while opening stream"}
|
||||
ErrConnectionDropped = MuxerApplicationError{"3002 connection dropped"}
|
||||
ErrStreamRequestTimeout = MuxerApplicationError{"3003 open stream timeout"}
|
||||
ErrResponseHeadersTimeout = MuxerApplicationError{"3004 timeout waiting for initial response headers"}
|
||||
ErrResponseHeadersConnectionClosed = MuxerApplicationError{"3005 connection closed while waiting for initial response headers"}
|
||||
|
||||
ErrClosedStream = MuxerStreamError{"4000 stream closed", http2.ErrCodeStreamClosed}
|
||||
)
|
||||
|
||||
type MuxerHandshakeError struct {
|
||||
cause string
|
||||
}
|
||||
|
||||
func (e MuxerHandshakeError) Error() string {
|
||||
return fmt.Sprintf("Handshake error: %s", e.cause)
|
||||
}
|
||||
|
||||
type MuxerProtocolError struct {
|
||||
cause string
|
||||
h2code http2.ErrCode
|
||||
}
|
||||
|
||||
func (e MuxerProtocolError) Error() string {
|
||||
return fmt.Sprintf("Protocol error: %s", e.cause)
|
||||
}
|
||||
|
||||
type MuxerApplicationError struct {
|
||||
cause string
|
||||
}
|
||||
|
||||
func (e MuxerApplicationError) Error() string {
|
||||
return fmt.Sprintf("Application error: %s", e.cause)
|
||||
}
|
||||
|
||||
type MuxerStreamError struct {
|
||||
cause string
|
||||
h2code http2.ErrCode
|
||||
}
|
||||
|
||||
func (e MuxerStreamError) Error() string {
|
||||
return fmt.Sprintf("Stream error: %s", e.cause)
|
||||
}
|
|
@ -1,17 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
func CompressionIsSupported() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func newDecompressor(src io.Reader) decompressor {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newCompressor(dst io.Writer, quality, lgwin int) compressor {
|
||||
return nil
|
||||
}
|
|
@ -1,596 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
/* This is an implementation of https://github.com/vkrasnov/h2-compression-dictionaries
|
||||
but modified for tunnels in a few key ways:
|
||||
Since tunnels is a server-to-server service, some aspects of the spec would cause
|
||||
unnecessary head-of-line blocking on the CPU and on the network, hence this implementation
|
||||
allows for parallel compression on the "client", and buffering on the "server" to solve
|
||||
this problem. */
|
||||
|
||||
// Assign temporary values
|
||||
const SettingCompression http2.SettingID = 0xff20
|
||||
|
||||
const (
|
||||
FrameSetCompressionContext http2.FrameType = 0xf0
|
||||
FrameUseDictionary http2.FrameType = 0xf1
|
||||
FrameSetDictionary http2.FrameType = 0xf2
|
||||
)
|
||||
|
||||
const (
|
||||
FlagSetDictionaryAppend http2.Flags = 0x1
|
||||
FlagSetDictionaryOffset http2.Flags = 0x2
|
||||
)
|
||||
|
||||
const compressionVersion = uint8(1)
|
||||
const compressionFormat = uint8(2)
|
||||
|
||||
type CompressionSetting uint
|
||||
|
||||
const (
|
||||
CompressionNone CompressionSetting = iota
|
||||
CompressionLow
|
||||
CompressionMedium
|
||||
CompressionMax
|
||||
)
|
||||
|
||||
type CompressionPreset struct {
|
||||
nDicts, dictSize, quality uint8
|
||||
}
|
||||
|
||||
type compressor interface {
|
||||
Write([]byte) (int, error)
|
||||
Flush() error
|
||||
SetDictionary([]byte)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type decompressor interface {
|
||||
Read([]byte) (int, error)
|
||||
SetDictionary([]byte)
|
||||
Close() error
|
||||
}
|
||||
|
||||
var compressionPresets = map[CompressionSetting]CompressionPreset{
|
||||
CompressionNone: {0, 0, 0},
|
||||
CompressionLow: {32, 17, 5},
|
||||
CompressionMedium: {64, 18, 6},
|
||||
CompressionMax: {255, 19, 9},
|
||||
}
|
||||
|
||||
func compressionSettingVal(version, fmt, sz, nd uint8) uint32 {
|
||||
// Currently the compression settings are include:
|
||||
// * version: only 1 is supported
|
||||
// * fmt: only 2 for brotli is supported
|
||||
// * sz: log2 of the maximal allowed dictionary size
|
||||
// * nd: max allowed number of dictionaries
|
||||
return uint32(version)<<24 + uint32(fmt)<<16 + uint32(sz)<<8 + uint32(nd)
|
||||
}
|
||||
|
||||
func parseCompressionSettingVal(setting uint32) (version, fmt, sz, nd uint8) {
|
||||
version = uint8(setting >> 24)
|
||||
fmt = uint8(setting >> 16)
|
||||
sz = uint8(setting >> 8)
|
||||
nd = uint8(setting)
|
||||
return
|
||||
}
|
||||
|
||||
func (c CompressionSetting) toH2Setting() uint32 {
|
||||
p, ok := compressionPresets[c]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return compressionSettingVal(compressionVersion, compressionFormat, p.dictSize, p.nDicts)
|
||||
}
|
||||
|
||||
func (c CompressionSetting) getPreset() CompressionPreset {
|
||||
return compressionPresets[c]
|
||||
}
|
||||
|
||||
type dictUpdate struct {
|
||||
reader *h2DictionaryReader
|
||||
dictionary *h2ReadDictionary
|
||||
buff []byte
|
||||
isReady bool
|
||||
isUse bool
|
||||
s setDictRequest
|
||||
}
|
||||
|
||||
type h2ReadDictionary struct {
|
||||
dictionary []byte
|
||||
queue []*dictUpdate
|
||||
maxSize int
|
||||
}
|
||||
|
||||
type h2ReadDictionaries struct {
|
||||
d []h2ReadDictionary
|
||||
maxSize int
|
||||
}
|
||||
|
||||
type h2DictionaryReader struct {
|
||||
*SharedBuffer // Propagate the decompressed output into the original buffer
|
||||
decompBuffer *bytes.Buffer // Intermediate buffer for the brotli compressor
|
||||
dictionary []byte // The content of the dictionary being used by this reader
|
||||
internalBuffer []byte
|
||||
s, e int // Start and end of the buffer
|
||||
decomp decompressor // The brotli compressor
|
||||
isClosed bool // Indicates that Close was called for this reader
|
||||
queue []*dictUpdate // List of dictionaries to update, when the data is available
|
||||
}
|
||||
|
||||
type h2WriteDictionary []byte
|
||||
|
||||
type setDictRequest struct {
|
||||
streamID uint32
|
||||
dictID uint8
|
||||
dictSZ uint64
|
||||
truncate, offset uint64
|
||||
P, E, D bool
|
||||
}
|
||||
|
||||
type useDictRequest struct {
|
||||
dictID uint8
|
||||
streamID uint32
|
||||
setDict []setDictRequest
|
||||
}
|
||||
|
||||
type h2WriteDictionaries struct {
|
||||
dictLock sync.Mutex
|
||||
dictChan chan useDictRequest
|
||||
dictionaries []h2WriteDictionary
|
||||
nextAvail int // next unused dictionary slot
|
||||
maxAvail int // max ID, defined by SETTINGS
|
||||
maxSize int // max size, defined by SETTINGS
|
||||
typeToDict map[string]uint8 // map from content type to dictionary that encodes it
|
||||
pathToDict map[string]uint8 // map from path to dictionary that encodes it
|
||||
quality int
|
||||
window int
|
||||
compIn, compOut *AtomicCounter
|
||||
}
|
||||
|
||||
type h2DictWriter struct {
|
||||
*bytes.Buffer
|
||||
comp compressor
|
||||
dicts *h2WriteDictionaries
|
||||
writerLock sync.Mutex
|
||||
|
||||
streamID uint32
|
||||
path string
|
||||
contentType string
|
||||
}
|
||||
|
||||
type h2Dictionaries struct {
|
||||
write *h2WriteDictionaries
|
||||
read *h2ReadDictionaries
|
||||
}
|
||||
|
||||
func (o *dictUpdate) update(buff []byte) {
|
||||
o.buff = make([]byte, len(buff))
|
||||
copy(o.buff, buff)
|
||||
o.isReady = true
|
||||
}
|
||||
|
||||
func (d *h2ReadDictionary) update() {
|
||||
for len(d.queue) > 0 {
|
||||
o := d.queue[0]
|
||||
if !o.isReady {
|
||||
break
|
||||
}
|
||||
if o.isUse {
|
||||
reader := o.reader
|
||||
reader.dictionary = make([]byte, len(d.dictionary))
|
||||
copy(reader.dictionary, d.dictionary)
|
||||
reader.decomp = newDecompressor(reader.decompBuffer)
|
||||
if len(reader.dictionary) > 0 {
|
||||
reader.decomp.SetDictionary(reader.dictionary)
|
||||
}
|
||||
reader.Write([]byte{})
|
||||
} else {
|
||||
d.dictionary = adjustDictionary(d.dictionary, o.buff, o.s, d.maxSize)
|
||||
}
|
||||
d.queue = d.queue[1:]
|
||||
}
|
||||
}
|
||||
|
||||
func newH2ReadDictionaries(nd, sz uint8) h2ReadDictionaries {
|
||||
d := make([]h2ReadDictionary, int(nd))
|
||||
for i := range d {
|
||||
d[i].maxSize = 1 << uint(sz)
|
||||
}
|
||||
return h2ReadDictionaries{d: d, maxSize: 1 << uint(sz)}
|
||||
}
|
||||
|
||||
func (dicts *h2ReadDictionaries) getDictByID(dictID uint8) (*h2ReadDictionary, error) {
|
||||
if int(dictID) > len(dicts.d) {
|
||||
return nil, MuxerStreamError{"dictID too big", http2.ErrCodeProtocol}
|
||||
}
|
||||
|
||||
return &dicts.d[dictID], nil
|
||||
}
|
||||
|
||||
func (dicts *h2ReadDictionaries) newReader(b *SharedBuffer, dictID uint8) *h2DictionaryReader {
|
||||
if int(dictID) > len(dicts.d) {
|
||||
return nil
|
||||
}
|
||||
|
||||
dictionary := &dicts.d[dictID]
|
||||
reader := &h2DictionaryReader{SharedBuffer: b, decompBuffer: &bytes.Buffer{}, internalBuffer: make([]byte, dicts.maxSize)}
|
||||
|
||||
if len(dictionary.queue) == 0 {
|
||||
reader.dictionary = make([]byte, len(dictionary.dictionary))
|
||||
copy(reader.dictionary, dictionary.dictionary)
|
||||
reader.decomp = newDecompressor(reader.decompBuffer)
|
||||
if len(reader.dictionary) > 0 {
|
||||
reader.decomp.SetDictionary(reader.dictionary)
|
||||
}
|
||||
} else {
|
||||
dictionary.queue = append(dictionary.queue, &dictUpdate{isUse: true, isReady: true, reader: reader})
|
||||
}
|
||||
return reader
|
||||
}
|
||||
|
||||
func (r *h2DictionaryReader) updateWaitingDictionaries() {
|
||||
// Update all the waiting dictionaries
|
||||
for _, o := range r.queue {
|
||||
if o.isReady {
|
||||
continue
|
||||
}
|
||||
if r.isClosed || uint64(r.e) >= o.s.dictSZ {
|
||||
o.update(r.internalBuffer[:r.e])
|
||||
if o == o.dictionary.queue[0] {
|
||||
defer o.dictionary.update()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write actually happens when reading from network, this is therefore the stage where we decompress the buffer
|
||||
func (r *h2DictionaryReader) Write(p []byte) (n int, err error) {
|
||||
// Every write goes into brotli buffer first
|
||||
n, err = r.decompBuffer.Write(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if r.decomp == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
m, err := r.decomp.Read(r.internalBuffer[r.e:])
|
||||
if err != nil && err != io.EOF {
|
||||
r.SharedBuffer.Close()
|
||||
r.decomp.Close()
|
||||
return n, err
|
||||
}
|
||||
|
||||
r.SharedBuffer.Write(r.internalBuffer[r.e : r.e+m])
|
||||
r.e += m
|
||||
|
||||
if m == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if r.e == len(r.internalBuffer) {
|
||||
r.updateWaitingDictionaries()
|
||||
r.e = 0
|
||||
}
|
||||
}
|
||||
|
||||
r.updateWaitingDictionaries()
|
||||
|
||||
if r.isClosed {
|
||||
r.SharedBuffer.Close()
|
||||
r.decomp.Close()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (r *h2DictionaryReader) Close() error {
|
||||
if r.isClosed {
|
||||
return nil
|
||||
}
|
||||
r.isClosed = true
|
||||
r.Write([]byte{})
|
||||
return nil
|
||||
}
|
||||
|
||||
var compressibleTypes = map[string]bool{
|
||||
"application/atom+xml": true,
|
||||
"application/javascript": true,
|
||||
"application/json": true,
|
||||
"application/ld+json": true,
|
||||
"application/manifest+json": true,
|
||||
"application/rss+xml": true,
|
||||
"application/vnd.geo+json": true,
|
||||
"application/vnd.ms-fontobject": true,
|
||||
"application/x-font-ttf": true,
|
||||
"application/x-yaml": true,
|
||||
"application/x-web-app-manifest+json": true,
|
||||
"application/xhtml+xml": true,
|
||||
"application/xml": true,
|
||||
"font/opentype": true,
|
||||
"image/bmp": true,
|
||||
"image/svg+xml": true,
|
||||
"image/x-icon": true,
|
||||
"text/cache-manifest": true,
|
||||
"text/css": true,
|
||||
"text/html": true,
|
||||
"text/plain": true,
|
||||
"text/vcard": true,
|
||||
"text/vnd.rim.location.xloc": true,
|
||||
"text/vtt": true,
|
||||
"text/x-component": true,
|
||||
"text/x-cross-domain-policy": true,
|
||||
"text/x-yaml": true,
|
||||
}
|
||||
|
||||
func getContentType(headers []Header) string {
|
||||
for _, h := range headers {
|
||||
if strings.ToLower(h.Name) == "content-type" {
|
||||
val := strings.ToLower(h.Value)
|
||||
sep := strings.IndexRune(val, ';')
|
||||
if sep != -1 {
|
||||
return val[:sep]
|
||||
}
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func newH2WriteDictionaries(nd, sz, quality uint8, compIn, compOut *AtomicCounter) (*h2WriteDictionaries, chan useDictRequest) {
|
||||
useDictChan := make(chan useDictRequest)
|
||||
return &h2WriteDictionaries{
|
||||
dictionaries: make([]h2WriteDictionary, nd),
|
||||
nextAvail: 0,
|
||||
maxAvail: int(nd),
|
||||
maxSize: 1 << uint(sz),
|
||||
dictChan: useDictChan,
|
||||
typeToDict: make(map[string]uint8),
|
||||
pathToDict: make(map[string]uint8),
|
||||
quality: int(quality),
|
||||
window: 1 << uint(sz+1),
|
||||
compIn: compIn,
|
||||
compOut: compOut,
|
||||
}, useDictChan
|
||||
}
|
||||
|
||||
func adjustDictionary(currentDictionary, newData []byte, set setDictRequest, maxSize int) []byte {
|
||||
currentDictionary = append(currentDictionary, newData[:set.dictSZ]...)
|
||||
|
||||
if len(currentDictionary) > maxSize {
|
||||
currentDictionary = currentDictionary[len(currentDictionary)-maxSize:]
|
||||
}
|
||||
|
||||
return currentDictionary
|
||||
}
|
||||
|
||||
func (h2d *h2WriteDictionaries) getNextDictID() (dictID uint8, ok bool) {
|
||||
if h2d.nextAvail < h2d.maxAvail {
|
||||
dictID, ok = uint8(h2d.nextAvail), true
|
||||
h2d.nextAvail++
|
||||
return
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (h2d *h2WriteDictionaries) getGenericDictID() (dictID uint8, ok bool) {
|
||||
if h2d.maxAvail == 0 {
|
||||
return 0, false
|
||||
}
|
||||
return uint8(h2d.maxAvail - 1), true
|
||||
}
|
||||
|
||||
func (h2d *h2WriteDictionaries) getDictWriter(s *MuxedStream, headers []Header) *h2DictWriter {
|
||||
w := s.writeBuffer
|
||||
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.method != "GET" && s.method != "POST" {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.contentType = getContentType(headers)
|
||||
if _, ok := compressibleTypes[s.contentType]; !ok && !strings.HasPrefix(s.contentType, "text") {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &h2DictWriter{
|
||||
Buffer: w.(*bytes.Buffer),
|
||||
path: s.path,
|
||||
contentType: s.contentType,
|
||||
streamID: s.streamID,
|
||||
dicts: h2d,
|
||||
}
|
||||
}
|
||||
|
||||
func assignDictToStream(s *MuxedStream, p []byte) bool {
|
||||
|
||||
// On first write to stream:
|
||||
// * assign the right dictionary
|
||||
// * update relevant dictionaries
|
||||
// * send the required USE_DICT and SET_DICT frames
|
||||
|
||||
h2d := s.dictionaries.write
|
||||
if h2d == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
w, ok := s.writeBuffer.(*h2DictWriter)
|
||||
if !ok || w.comp != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
h2d.dictLock.Lock()
|
||||
|
||||
if w.comp != nil {
|
||||
// Check again with lock, in therory the interface allows for unordered writes
|
||||
h2d.dictLock.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
// The logic of dictionary generation is below
|
||||
|
||||
// Is there a dictionary for the exact path or content-type?
|
||||
var useID uint8
|
||||
pathID, pathFound := h2d.pathToDict[w.path]
|
||||
typeID, typeFound := h2d.typeToDict[w.contentType]
|
||||
|
||||
if pathFound {
|
||||
// Use dictionary for path as top priority
|
||||
useID = pathID
|
||||
if !typeFound { // Shouldn't really happen, unless type changes between requests
|
||||
typeID, typeFound = h2d.getNextDictID()
|
||||
if typeFound {
|
||||
h2d.typeToDict[w.contentType] = typeID
|
||||
}
|
||||
}
|
||||
} else if typeFound {
|
||||
// Use dictionary for same content type as second priority
|
||||
useID = typeID
|
||||
pathID, pathFound = h2d.getNextDictID()
|
||||
if pathFound { // If a slot is available, generate new dictionary for path
|
||||
h2d.pathToDict[w.path] = pathID
|
||||
}
|
||||
} else {
|
||||
// Use the overflow dictionary as last resort
|
||||
// If slots are available generate new dictionaries for path and content-type
|
||||
useID, _ = h2d.getGenericDictID()
|
||||
pathID, pathFound = h2d.getNextDictID()
|
||||
if pathFound {
|
||||
h2d.pathToDict[w.path] = pathID
|
||||
}
|
||||
typeID, typeFound = h2d.getNextDictID()
|
||||
if typeFound {
|
||||
h2d.typeToDict[w.contentType] = typeID
|
||||
}
|
||||
}
|
||||
|
||||
useLen := h2d.maxSize
|
||||
if len(p) < useLen {
|
||||
useLen = len(p)
|
||||
}
|
||||
|
||||
// Update all the dictionaries using the new data
|
||||
setDicts := make([]setDictRequest, 0, 3)
|
||||
setDict := setDictRequest{
|
||||
streamID: w.streamID,
|
||||
dictID: useID,
|
||||
dictSZ: uint64(useLen),
|
||||
}
|
||||
setDicts = append(setDicts, setDict)
|
||||
if pathID != useID {
|
||||
setDict.dictID = pathID
|
||||
setDicts = append(setDicts, setDict)
|
||||
}
|
||||
if typeID != useID {
|
||||
setDict.dictID = typeID
|
||||
setDicts = append(setDicts, setDict)
|
||||
}
|
||||
|
||||
h2d.dictChan <- useDictRequest{streamID: w.streamID, dictID: uint8(useID), setDict: setDicts}
|
||||
|
||||
dict := h2d.dictionaries[useID]
|
||||
|
||||
// Brolti requires the dictionary to be immutable
|
||||
copyDict := make([]byte, len(dict))
|
||||
copy(copyDict, dict)
|
||||
|
||||
for _, set := range setDicts {
|
||||
h2d.dictionaries[set.dictID] = adjustDictionary(h2d.dictionaries[set.dictID], p, set, h2d.maxSize)
|
||||
}
|
||||
|
||||
w.comp = newCompressor(w.Buffer, h2d.quality, h2d.window)
|
||||
|
||||
s.writeLock.Lock()
|
||||
h2d.dictLock.Unlock()
|
||||
|
||||
if len(copyDict) > 0 {
|
||||
w.comp.SetDictionary(copyDict)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *h2DictWriter) Write(p []byte) (n int, err error) {
|
||||
bufLen := w.Buffer.Len()
|
||||
if w.comp != nil {
|
||||
n, err = w.comp.Write(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = w.comp.Flush()
|
||||
w.dicts.compIn.IncrementBy(uint64(n))
|
||||
w.dicts.compOut.IncrementBy(uint64(w.Buffer.Len() - bufLen))
|
||||
return
|
||||
}
|
||||
return w.Buffer.Write(p)
|
||||
}
|
||||
|
||||
func (w *h2DictWriter) Close() error {
|
||||
if w.comp != nil {
|
||||
return w.comp.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// From http2/hpack
|
||||
func http2ReadVarInt(n byte, p []byte) (remain []byte, v uint64, err error) {
|
||||
if n < 1 || n > 8 {
|
||||
panic("bad n")
|
||||
}
|
||||
if len(p) == 0 {
|
||||
return nil, 0, MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
|
||||
}
|
||||
v = uint64(p[0])
|
||||
if n < 8 {
|
||||
v &= (1 << uint64(n)) - 1
|
||||
}
|
||||
if v < (1<<uint64(n))-1 {
|
||||
return p[1:], v, nil
|
||||
}
|
||||
|
||||
origP := p
|
||||
p = p[1:]
|
||||
var m uint64
|
||||
for len(p) > 0 {
|
||||
b := p[0]
|
||||
p = p[1:]
|
||||
v += uint64(b&127) << m
|
||||
if b&128 == 0 {
|
||||
return p, v, nil
|
||||
}
|
||||
m += 7
|
||||
if m >= 63 {
|
||||
return origP, 0, MuxerStreamError{"invalid integer", http2.ErrCodeProtocol}
|
||||
}
|
||||
}
|
||||
return nil, 0, MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
|
||||
}
|
||||
|
||||
func appendVarInt(dst []byte, n byte, i uint64) []byte {
|
||||
k := uint64((1 << n) - 1)
|
||||
if i < k {
|
||||
return append(dst, byte(i))
|
||||
}
|
||||
dst = append(dst, byte(k))
|
||||
i -= k
|
||||
for ; i >= 128; i >>= 7 {
|
||||
dst = append(dst, byte(0x80|(i&0x7f)))
|
||||
}
|
||||
return append(dst, byte(i))
|
||||
}
|
506
h2mux/h2mux.go
506
h2mux/h2mux.go
|
@ -1,506 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultFrameSize uint32 = 1 << 14 // Minimum frame size in http2 spec
|
||||
defaultWindowSize uint32 = (1 << 16) - 1 // Minimum window size in http2 spec
|
||||
maxWindowSize uint32 = (1 << 31) - 1 // 2^31-1 = 2147483647, max window size in http2 spec
|
||||
defaultTimeout time.Duration = 5 * time.Second
|
||||
defaultRetries uint64 = 5
|
||||
defaultWriteBufferMaxLen int = 1024 * 1024 // 1mb
|
||||
writeBufferInitialSize int = 16 * 1024 // 16KB
|
||||
|
||||
SettingMuxerMagic http2.SettingID = 0x42db
|
||||
MuxerMagicOrigin uint32 = 0xa2e43c8b
|
||||
MuxerMagicEdge uint32 = 0x1088ebf9
|
||||
)
|
||||
|
||||
type MuxedStreamHandler interface {
|
||||
ServeStream(*MuxedStream) error
|
||||
}
|
||||
|
||||
type MuxedStreamFunc func(stream *MuxedStream) error
|
||||
|
||||
func (f MuxedStreamFunc) ServeStream(stream *MuxedStream) error {
|
||||
return f(stream)
|
||||
}
|
||||
|
||||
type MuxerConfig struct {
|
||||
Timeout time.Duration
|
||||
Handler MuxedStreamHandler
|
||||
IsClient bool
|
||||
// Name is used to identify this muxer instance when logging.
|
||||
Name string
|
||||
// The minimum time this connection can be idle before sending a heartbeat.
|
||||
HeartbeatInterval time.Duration
|
||||
// The minimum number of heartbeats to send before terminating the connection.
|
||||
MaxHeartbeats uint64
|
||||
// Logger to use
|
||||
Log *zerolog.Logger
|
||||
CompressionQuality CompressionSetting
|
||||
// Initial size for HTTP2 flow control windows
|
||||
DefaultWindowSize uint32
|
||||
// Largest allowable size for HTTP2 flow control windows
|
||||
MaxWindowSize uint32
|
||||
// Largest allowable capacity for the buffer of data to be sent
|
||||
StreamWriteBufferMaxLen int
|
||||
}
|
||||
|
||||
type Muxer struct {
|
||||
// f is used to read and write HTTP2 frames on the wire.
|
||||
f *http2.Framer
|
||||
// config is the MuxerConfig given in Handshake.
|
||||
config MuxerConfig
|
||||
// w, r are references to the underlying connection used.
|
||||
w io.WriteCloser
|
||||
r io.ReadCloser
|
||||
// muxReader is the read process.
|
||||
muxReader *MuxReader
|
||||
// muxWriter is the write process.
|
||||
muxWriter *MuxWriter
|
||||
// muxMetricsUpdater is the process to update metrics
|
||||
muxMetricsUpdater muxMetricsUpdater
|
||||
// newStreamChan is used to create new streams on the writer thread.
|
||||
// The writer will assign the next available stream ID.
|
||||
newStreamChan chan MuxedStreamRequest
|
||||
// abortChan is used to abort the writer event loop.
|
||||
abortChan chan struct{}
|
||||
// abortOnce is used to ensure abortChan is closed once only.
|
||||
abortOnce sync.Once
|
||||
// readyList is used to signal writable streams.
|
||||
readyList *ReadyList
|
||||
// streams tracks currently-open streams.
|
||||
streams *activeStreamMap
|
||||
// explicitShutdown records whether the Muxer is closing because Shutdown was called, or due to another
|
||||
// error.
|
||||
explicitShutdown *BooleanFuse
|
||||
|
||||
compressionQuality CompressionPreset
|
||||
}
|
||||
|
||||
func RPCHeaders() []Header {
|
||||
return []Header{
|
||||
{Name: ":method", Value: "RPC"},
|
||||
{Name: ":scheme", Value: "capnp"},
|
||||
{Name: ":path", Value: "*"},
|
||||
}
|
||||
}
|
||||
|
||||
// Handshake establishes a muxed connection with the peer.
|
||||
// After the handshake completes, it is possible to open and accept streams.
|
||||
func Handshake(
|
||||
w io.WriteCloser,
|
||||
r io.ReadCloser,
|
||||
config MuxerConfig,
|
||||
activeStreamsMetrics prometheus.Gauge,
|
||||
) (*Muxer, error) {
|
||||
// Set default config values
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = defaultTimeout
|
||||
}
|
||||
if config.DefaultWindowSize == 0 {
|
||||
config.DefaultWindowSize = defaultWindowSize
|
||||
}
|
||||
if config.MaxWindowSize == 0 {
|
||||
config.MaxWindowSize = maxWindowSize
|
||||
}
|
||||
if config.StreamWriteBufferMaxLen == 0 {
|
||||
config.StreamWriteBufferMaxLen = defaultWriteBufferMaxLen
|
||||
}
|
||||
// Initialise connection state fields
|
||||
m := &Muxer{
|
||||
f: http2.NewFramer(w, r), // A framer that writes to w and reads from r
|
||||
config: config,
|
||||
w: w,
|
||||
r: r,
|
||||
newStreamChan: make(chan MuxedStreamRequest),
|
||||
abortChan: make(chan struct{}),
|
||||
readyList: NewReadyList(),
|
||||
streams: newActiveStreamMap(config.IsClient, activeStreamsMetrics),
|
||||
}
|
||||
|
||||
m.f.ReadMetaHeaders = hpack.NewDecoder(4096, func(hpack.HeaderField) {})
|
||||
// Initialise the settings to identify this connection and confirm the other end is sane.
|
||||
handshakeSetting := http2.Setting{ID: SettingMuxerMagic, Val: MuxerMagicEdge}
|
||||
compressionSetting := http2.Setting{ID: SettingCompression, Val: 0}
|
||||
|
||||
expectedMagic := MuxerMagicOrigin
|
||||
if config.IsClient {
|
||||
handshakeSetting.Val = MuxerMagicOrigin
|
||||
expectedMagic = MuxerMagicEdge
|
||||
}
|
||||
errChan := make(chan error, 2)
|
||||
// Simultaneously send our settings and verify the peer's settings.
|
||||
go func() { errChan <- m.f.WriteSettings(handshakeSetting, compressionSetting) }()
|
||||
go func() { errChan <- m.readPeerSettings(expectedMagic) }()
|
||||
err := joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Confirm sanity by ACKing the frame and expecting an ACK for our frame.
|
||||
// Not strictly necessary, but let's pretend to be H2-like.
|
||||
go func() { errChan <- m.f.WriteSettingsAck() }()
|
||||
go func() { errChan <- m.readPeerSettingsAck() }()
|
||||
err = joinErrorsWithTimeout(errChan, 2, config.Timeout, ErrHandshakeTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// set up reader/writer pair ready for serve
|
||||
streamErrors := NewStreamErrorMap()
|
||||
goAwayChan := make(chan http2.ErrCode, 1)
|
||||
inBoundCounter := NewAtomicCounter(0)
|
||||
outBoundCounter := NewAtomicCounter(0)
|
||||
pingTimestamp := NewPingTimestamp()
|
||||
connActive := NewSignal()
|
||||
idleDuration := config.HeartbeatInterval
|
||||
// Sanity check to ensure idelDuration is sane
|
||||
if idleDuration == 0 || idleDuration < defaultTimeout {
|
||||
idleDuration = defaultTimeout
|
||||
config.Log.Info().Msgf("muxer: Minimum idle time has been adjusted to %d", defaultTimeout)
|
||||
}
|
||||
maxRetries := config.MaxHeartbeats
|
||||
if maxRetries == 0 {
|
||||
maxRetries = defaultRetries
|
||||
config.Log.Info().Msgf("muxer: Minimum number of unacked heartbeats to send before closing the connection has been adjusted to %d", maxRetries)
|
||||
}
|
||||
|
||||
compBytesBefore, compBytesAfter := NewAtomicCounter(0), NewAtomicCounter(0)
|
||||
|
||||
m.muxMetricsUpdater = newMuxMetricsUpdater(
|
||||
m.abortChan,
|
||||
compBytesBefore,
|
||||
compBytesAfter,
|
||||
)
|
||||
|
||||
m.explicitShutdown = NewBooleanFuse()
|
||||
m.muxReader = &MuxReader{
|
||||
f: m.f,
|
||||
handler: m.config.Handler,
|
||||
streams: m.streams,
|
||||
readyList: m.readyList,
|
||||
streamErrors: streamErrors,
|
||||
goAwayChan: goAwayChan,
|
||||
abortChan: m.abortChan,
|
||||
pingTimestamp: pingTimestamp,
|
||||
connActive: connActive,
|
||||
initialStreamWindow: m.config.DefaultWindowSize,
|
||||
streamWindowMax: m.config.MaxWindowSize,
|
||||
streamWriteBufferMaxLen: m.config.StreamWriteBufferMaxLen,
|
||||
r: m.r,
|
||||
metricsUpdater: m.muxMetricsUpdater,
|
||||
bytesRead: inBoundCounter,
|
||||
}
|
||||
m.muxWriter = &MuxWriter{
|
||||
f: m.f,
|
||||
streams: m.streams,
|
||||
streamErrors: streamErrors,
|
||||
readyStreamChan: m.readyList.ReadyChannel(),
|
||||
newStreamChan: m.newStreamChan,
|
||||
goAwayChan: goAwayChan,
|
||||
abortChan: m.abortChan,
|
||||
pingTimestamp: pingTimestamp,
|
||||
idleTimer: NewIdleTimer(idleDuration, maxRetries),
|
||||
connActiveChan: connActive.WaitChannel(),
|
||||
maxFrameSize: defaultFrameSize,
|
||||
metricsUpdater: m.muxMetricsUpdater,
|
||||
bytesWrote: outBoundCounter,
|
||||
}
|
||||
m.muxWriter.headerEncoder = hpack.NewEncoder(&m.muxWriter.headerBuffer)
|
||||
|
||||
if m.compressionQuality.dictSize > 0 && m.compressionQuality.nDicts > 0 {
|
||||
nd, sz := m.compressionQuality.nDicts, m.compressionQuality.dictSize
|
||||
writeDicts, dictChan := newH2WriteDictionaries(
|
||||
nd,
|
||||
sz,
|
||||
m.compressionQuality.quality,
|
||||
compBytesBefore,
|
||||
compBytesAfter,
|
||||
)
|
||||
readDicts := newH2ReadDictionaries(nd, sz)
|
||||
m.muxReader.dictionaries = h2Dictionaries{read: &readDicts, write: writeDicts}
|
||||
m.muxWriter.useDictChan = dictChan
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Muxer) readPeerSettings(magic uint32) error {
|
||||
frame, err := m.f.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
settingsFrame, ok := frame.(*http2.SettingsFrame)
|
||||
if !ok {
|
||||
return ErrBadHandshakeNotSettings
|
||||
}
|
||||
if settingsFrame.Header().Flags != 0 {
|
||||
return ErrBadHandshakeUnexpectedAck
|
||||
}
|
||||
peerMagic, ok := settingsFrame.Value(SettingMuxerMagic)
|
||||
if !ok {
|
||||
return ErrBadHandshakeNoMagic
|
||||
}
|
||||
if magic != peerMagic {
|
||||
return ErrBadHandshakeWrongMagic
|
||||
}
|
||||
peerCompression, ok := settingsFrame.Value(SettingCompression)
|
||||
if !ok {
|
||||
m.compressionQuality = compressionPresets[CompressionNone]
|
||||
return nil
|
||||
}
|
||||
ver, fmt, sz, nd := parseCompressionSettingVal(peerCompression)
|
||||
if ver != compressionVersion || fmt != compressionFormat || sz == 0 || nd == 0 {
|
||||
m.compressionQuality = compressionPresets[CompressionNone]
|
||||
return nil
|
||||
}
|
||||
// Values used for compression are the minimum between the two peers
|
||||
if sz < m.compressionQuality.dictSize {
|
||||
m.compressionQuality.dictSize = sz
|
||||
}
|
||||
if nd < m.compressionQuality.nDicts {
|
||||
m.compressionQuality.nDicts = nd
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Muxer) readPeerSettingsAck() error {
|
||||
frame, err := m.f.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
settingsFrame, ok := frame.(*http2.SettingsFrame)
|
||||
if !ok {
|
||||
return ErrBadHandshakeNotSettingsAck
|
||||
}
|
||||
if settingsFrame.Header().Flags != http2.FlagSettingsAck {
|
||||
return ErrBadHandshakeUnexpectedSettings
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func joinErrorsWithTimeout(errChan <-chan error, receiveCount int, timeout time.Duration, timeoutError error) error {
|
||||
for i := 0; i < receiveCount; i++ {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case <-time.After(timeout):
|
||||
return timeoutError
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve runs the event loops that comprise h2mux:
|
||||
// - MuxReader.run()
|
||||
// - MuxWriter.run()
|
||||
// - muxMetricsUpdater.run()
|
||||
// In the normal case, Shutdown() is called concurrently with Serve() to stop
|
||||
// these loops.
|
||||
func (m *Muxer) Serve(ctx context.Context) error {
|
||||
errGroup, _ := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
err := m.muxReader.run(m.config.Log)
|
||||
m.explicitShutdown.Fuse(false)
|
||||
m.r.Close()
|
||||
m.abort()
|
||||
// don't block if parent goroutine quit early
|
||||
select {
|
||||
case ch <- err:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case err := <-ch:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
err := m.muxWriter.run(m.config.Log)
|
||||
m.explicitShutdown.Fuse(false)
|
||||
m.w.Close()
|
||||
m.abort()
|
||||
// don't block if parent goroutine quit early
|
||||
select {
|
||||
case ch <- err:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case err := <-ch:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
ch := make(chan error)
|
||||
go func() {
|
||||
err := m.muxMetricsUpdater.run(m.config.Log)
|
||||
// don't block if parent goroutine quit early
|
||||
select {
|
||||
case ch <- err:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case err := <-ch:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
err := errGroup.Wait()
|
||||
if isUnexpectedTunnelError(err, m.explicitShutdown.Value()) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown is called to initiate the "happy path" of muxer termination.
|
||||
// It blocks new streams from being created.
|
||||
// It returns a channel that is closed when the last stream has been closed.
|
||||
func (m *Muxer) Shutdown() <-chan struct{} {
|
||||
m.explicitShutdown.Fuse(true)
|
||||
return m.muxReader.Shutdown()
|
||||
}
|
||||
|
||||
// IsUnexpectedTunnelError identifies errors that are expected when shutting down the h2mux tunnel.
|
||||
// The set of expected errors change depending on whether we initiated shutdown or not.
|
||||
func isUnexpectedTunnelError(err error, expectedShutdown bool) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if !expectedShutdown {
|
||||
return true
|
||||
}
|
||||
return !isConnectionClosedError(err)
|
||||
}
|
||||
|
||||
func isConnectionClosedError(err error) bool {
|
||||
if err == io.EOF {
|
||||
return true
|
||||
}
|
||||
if err == io.ErrClosedPipe {
|
||||
return true
|
||||
}
|
||||
if err.Error() == "tls: use of closed connection" {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(err.Error(), "use of closed network connection") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// OpenStream opens a new data stream with the given headers.
|
||||
// Called by proxy server and tunnel
|
||||
func (m *Muxer) OpenStream(ctx context.Context, headers []Header, body io.Reader) (*MuxedStream, error) {
|
||||
stream := m.NewStream(headers)
|
||||
if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, body)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (m *Muxer) OpenRPCStream(ctx context.Context) (*MuxedStream, error) {
|
||||
stream := m.NewStream(RPCHeaders())
|
||||
if err := m.MakeMuxedStreamRequest(ctx, NewMuxedStreamRequest(stream, nil)); err != nil {
|
||||
stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
if err := m.AwaitResponseHeaders(ctx, stream); err != nil {
|
||||
stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !IsRPCStreamResponse(stream) {
|
||||
stream.Close()
|
||||
return nil, ErrNotRPCStream
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (m *Muxer) NewStream(headers []Header) *MuxedStream {
|
||||
return NewStream(m.config, headers, m.readyList, m.muxReader.dictionaries)
|
||||
}
|
||||
|
||||
func (m *Muxer) MakeMuxedStreamRequest(ctx context.Context, request MuxedStreamRequest) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ErrStreamRequestTimeout
|
||||
case <-m.abortChan:
|
||||
return ErrStreamRequestConnectionClosed
|
||||
// Will be received by mux writer
|
||||
case m.newStreamChan <- request:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Muxer) CloseStreamRead(stream *MuxedStream) {
|
||||
stream.CloseRead()
|
||||
if stream.WriteClosed() {
|
||||
m.streams.Delete(stream.streamID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Muxer) AwaitResponseHeaders(ctx context.Context, stream *MuxedStream) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ErrResponseHeadersTimeout
|
||||
case <-m.abortChan:
|
||||
return ErrResponseHeadersConnectionClosed
|
||||
case <-stream.responseHeadersReceived:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Muxer) Metrics() *MuxerMetrics {
|
||||
return m.muxMetricsUpdater.metrics()
|
||||
}
|
||||
|
||||
func (m *Muxer) abort() {
|
||||
m.abortOnce.Do(func() {
|
||||
close(m.abortChan)
|
||||
m.readyList.Close()
|
||||
m.streams.Abort()
|
||||
})
|
||||
}
|
||||
|
||||
// Return how many retries/ticks since the connection was last marked active
|
||||
func (m *Muxer) TimerRetries() uint64 {
|
||||
return m.muxWriter.idleTimer.RetryCount()
|
||||
}
|
||||
|
||||
func IsRPCStreamResponse(stream *MuxedStream) bool {
|
||||
headers := stream.Headers
|
||||
return len(headers) == 1 &&
|
||||
headers[0].Name == ":status" &&
|
||||
headers[0].Value == "200"
|
||||
}
|
|
@ -1,909 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
testOpenStreamTimeout = time.Millisecond * 5000
|
||||
testHandshakeTimeout = time.Millisecond * 1000
|
||||
)
|
||||
|
||||
var log = zerolog.Nop()
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if os.Getenv("VERBOSE") == "1" {
|
||||
//TODO: set log level
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
type DefaultMuxerPair struct {
|
||||
OriginMuxConfig MuxerConfig
|
||||
OriginMux *Muxer
|
||||
OriginConn net.Conn
|
||||
EdgeMuxConfig MuxerConfig
|
||||
EdgeMux *Muxer
|
||||
EdgeConn net.Conn
|
||||
doneC chan struct{}
|
||||
}
|
||||
|
||||
func NewDefaultMuxerPair(t assert.TestingT, testName string, f MuxedStreamFunc) *DefaultMuxerPair {
|
||||
origin, edge := net.Pipe()
|
||||
p := &DefaultMuxerPair{
|
||||
OriginMuxConfig: MuxerConfig{
|
||||
Timeout: testHandshakeTimeout,
|
||||
Handler: f,
|
||||
IsClient: true,
|
||||
Name: "origin",
|
||||
Log: &log,
|
||||
DefaultWindowSize: (1 << 8) - 1,
|
||||
MaxWindowSize: (1 << 15) - 1,
|
||||
StreamWriteBufferMaxLen: 1024,
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
OriginConn: origin,
|
||||
EdgeMuxConfig: MuxerConfig{
|
||||
Timeout: testHandshakeTimeout,
|
||||
IsClient: false,
|
||||
Name: "edge",
|
||||
Log: &log,
|
||||
DefaultWindowSize: (1 << 8) - 1,
|
||||
MaxWindowSize: (1 << 15) - 1,
|
||||
StreamWriteBufferMaxLen: 1024,
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
EdgeConn: edge,
|
||||
doneC: make(chan struct{}),
|
||||
}
|
||||
assert.NoError(t, p.Handshake(testName))
|
||||
return p
|
||||
}
|
||||
|
||||
func NewCompressedMuxerPair(t assert.TestingT, testName string, quality CompressionSetting, f MuxedStreamFunc) *DefaultMuxerPair {
|
||||
origin, edge := net.Pipe()
|
||||
p := &DefaultMuxerPair{
|
||||
OriginMuxConfig: MuxerConfig{
|
||||
Timeout: time.Second,
|
||||
Handler: f,
|
||||
IsClient: true,
|
||||
Name: "origin",
|
||||
CompressionQuality: quality,
|
||||
Log: &log,
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
OriginConn: origin,
|
||||
EdgeMuxConfig: MuxerConfig{
|
||||
Timeout: time.Second,
|
||||
IsClient: false,
|
||||
Name: "edge",
|
||||
CompressionQuality: quality,
|
||||
Log: &log,
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
EdgeConn: edge,
|
||||
doneC: make(chan struct{}),
|
||||
}
|
||||
assert.NoError(t, p.Handshake(testName))
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *DefaultMuxerPair) Handshake(testName string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testHandshakeTimeout)
|
||||
defer cancel()
|
||||
errGroup, _ := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() (err error) {
|
||||
p.EdgeMux, err = Handshake(p.EdgeConn, p.EdgeConn, p.EdgeMuxConfig, ActiveStreams)
|
||||
return errors.Wrap(err, "edge handshake failure")
|
||||
})
|
||||
errGroup.Go(func() (err error) {
|
||||
p.OriginMux, err = Handshake(p.OriginConn, p.OriginConn, p.OriginMuxConfig, ActiveStreams)
|
||||
return errors.Wrap(err, "origin handshake failure")
|
||||
})
|
||||
|
||||
return errGroup.Wait()
|
||||
}
|
||||
|
||||
func (p *DefaultMuxerPair) Serve(t assert.TestingT) {
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
err := p.EdgeMux.Serve(ctx)
|
||||
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
|
||||
t.Errorf("error in edge muxer Serve(): %s", err)
|
||||
}
|
||||
p.OriginMux.Shutdown()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
err := p.OriginMux.Serve(ctx)
|
||||
if err != nil && err != io.EOF && err != io.ErrClosedPipe {
|
||||
t.Errorf("error in origin muxer Serve(): %s", err)
|
||||
}
|
||||
p.EdgeMux.Shutdown()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
// notify when both muxes have stopped serving
|
||||
wg.Wait()
|
||||
close(p.doneC)
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *DefaultMuxerPair) Wait(t *testing.T) {
|
||||
select {
|
||||
case <-p.doneC:
|
||||
return
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *DefaultMuxerPair) OpenEdgeMuxStream(headers []Header, body io.Reader) (*MuxedStream, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testOpenStreamTimeout)
|
||||
defer cancel()
|
||||
return p.EdgeMux.OpenStream(ctx, headers, body)
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
f := func(stream *MuxedStream) error {
|
||||
return nil
|
||||
}
|
||||
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
|
||||
AssertIfPipeReadable(t, muxPair.OriginConn)
|
||||
AssertIfPipeReadable(t, muxPair.EdgeConn)
|
||||
}
|
||||
|
||||
func TestSingleStream(t *testing.T) {
|
||||
f := MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
if len(stream.Headers) != 1 {
|
||||
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "test-header" {
|
||||
t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "headerValue" {
|
||||
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
||||
}
|
||||
_ = stream.WriteHeaders([]Header{
|
||||
{Name: "response-header", Value: "responseValue"},
|
||||
})
|
||||
buf := []byte("Hello world")
|
||||
_, _ = stream.Write(buf)
|
||||
n, err := io.ReadFull(stream, buf)
|
||||
if n > 0 {
|
||||
t.Fatalf("read %d bytes after EOF", n)
|
||||
}
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %s", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
|
||||
muxPair.Serve(t)
|
||||
|
||||
stream, err := muxPair.OpenEdgeMuxStream(
|
||||
[]Header{{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
if len(stream.Headers) != 1 {
|
||||
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "response-header" {
|
||||
t.Fatalf("expected header name %s, got %s", "response-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "responseValue" {
|
||||
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
||||
}
|
||||
responseBody := make([]byte, 11)
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error from (*MuxedStream).Read: %s", err)
|
||||
}
|
||||
if n != len(responseBody) {
|
||||
t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
|
||||
}
|
||||
if string(responseBody) != "Hello world" {
|
||||
t.Fatalf("expected response body %s, got %s", "Hello world", responseBody)
|
||||
}
|
||||
_ = stream.Close()
|
||||
n, err = stream.Write([]byte("aaaaa"))
|
||||
if n > 0 {
|
||||
t.Fatalf("wrote %d bytes after EOF", n)
|
||||
}
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected EOF, got %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingleStreamLargeResponseBody(t *testing.T) {
|
||||
bodySize := 1 << 24
|
||||
f := MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
if len(stream.Headers) != 1 {
|
||||
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "test-header" {
|
||||
t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "headerValue" {
|
||||
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
||||
}
|
||||
_ = stream.WriteHeaders([]Header{
|
||||
{Name: "response-header", Value: "responseValue"},
|
||||
})
|
||||
payload := make([]byte, bodySize)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i % 256)
|
||||
}
|
||||
t.Log("Writing payload...")
|
||||
n, err := stream.Write(payload)
|
||||
t.Logf("Wrote %d bytes into the stream", n)
|
||||
if err != nil {
|
||||
t.Fatalf("origin write error: %s", err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
t.Fatalf("origin short write: %d/%d bytes", n, len(payload))
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
|
||||
muxPair.Serve(t)
|
||||
|
||||
stream, err := muxPair.OpenEdgeMuxStream(
|
||||
[]Header{{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
if len(stream.Headers) != 1 {
|
||||
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "response-header" {
|
||||
t.Fatalf("expected header name %s, got %s", "response-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "responseValue" {
|
||||
t.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
||||
}
|
||||
responseBody := make([]byte, bodySize)
|
||||
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error from (*MuxedStream).Read: %s", err)
|
||||
}
|
||||
if n != len(responseBody) {
|
||||
t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleStreams(t *testing.T) {
|
||||
f := MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
if len(stream.Headers) != 1 {
|
||||
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "client-token" {
|
||||
t.Fatalf("expected header name %s, got %s", "client-token", stream.Headers[0].Name)
|
||||
}
|
||||
log.Debug().Msgf("Got request for stream %s", stream.Headers[0].Value)
|
||||
_ = stream.WriteHeaders([]Header{
|
||||
{Name: "response-token", Value: stream.Headers[0].Value},
|
||||
})
|
||||
log.Debug().Msgf("Wrote headers for stream %s", stream.Headers[0].Value)
|
||||
_, _ = stream.Write([]byte("OK"))
|
||||
log.Debug().Msgf("Wrote body for stream %s", stream.Headers[0].Value)
|
||||
return nil
|
||||
})
|
||||
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
|
||||
muxPair.Serve(t)
|
||||
|
||||
maxStreams := 64
|
||||
errorsC := make(chan error, maxStreams)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(maxStreams)
|
||||
for i := 0; i < maxStreams; i++ {
|
||||
go func(tokenId int) {
|
||||
defer wg.Done()
|
||||
tokenString := fmt.Sprintf("%d", tokenId)
|
||||
stream, err := muxPair.OpenEdgeMuxStream(
|
||||
[]Header{{Name: "client-token", Value: tokenString}},
|
||||
nil,
|
||||
)
|
||||
log.Debug().Msgf("Got headers for stream %d", tokenId)
|
||||
if err != nil {
|
||||
errorsC <- err
|
||||
return
|
||||
}
|
||||
if len(stream.Headers) != 1 {
|
||||
errorsC <- fmt.Errorf("stream %d has error: expected %d headers, got %d", stream.streamID, 1, len(stream.Headers))
|
||||
return
|
||||
}
|
||||
if stream.Headers[0].Name != "response-token" {
|
||||
errorsC <- fmt.Errorf("stream %d has error: expected header name %s, got %s", stream.streamID, "response-token", stream.Headers[0].Name)
|
||||
return
|
||||
}
|
||||
if stream.Headers[0].Value != tokenString {
|
||||
errorsC <- fmt.Errorf("stream %d has error: expected header value %s, got %s", stream.streamID, tokenString, stream.Headers[0].Value)
|
||||
return
|
||||
}
|
||||
responseBody := make([]byte, 2)
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != nil {
|
||||
errorsC <- fmt.Errorf("stream %d has error: error from (*MuxedStream).Read: %s", stream.streamID, err)
|
||||
return
|
||||
}
|
||||
if n != len(responseBody) {
|
||||
errorsC <- fmt.Errorf("stream %d has error: expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n)
|
||||
return
|
||||
}
|
||||
if string(responseBody) != "OK" {
|
||||
errorsC <- fmt.Errorf("stream %d has error: expected response body %s, got %s", stream.streamID, "OK", responseBody)
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
close(errorsC)
|
||||
testFail := false
|
||||
for err := range errorsC {
|
||||
testFail = true
|
||||
log.Error().Msgf("%s", err)
|
||||
}
|
||||
if testFail {
|
||||
t.Fatalf("TestMultipleStreams failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleStreamsFlowControl(t *testing.T) {
|
||||
maxStreams := 32
|
||||
responseSizes := make([]int32, maxStreams)
|
||||
for i := 0; i < maxStreams; i++ {
|
||||
responseSizes[i] = rand.Int31n(int32(defaultWindowSize << 4))
|
||||
}
|
||||
|
||||
f := MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
if len(stream.Headers) != 1 {
|
||||
t.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "test-header" {
|
||||
t.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "headerValue" {
|
||||
t.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
||||
}
|
||||
_ = stream.WriteHeaders([]Header{
|
||||
{Name: "response-header", Value: "responseValue"},
|
||||
})
|
||||
payload := make([]byte, responseSizes[(stream.streamID-2)/2])
|
||||
for i := range payload {
|
||||
payload[i] = byte(i % 256)
|
||||
}
|
||||
n, err := stream.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("origin write error: %s", err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
t.Fatalf("origin short write: %d/%d bytes", n, len(payload))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
|
||||
muxPair.Serve(t)
|
||||
|
||||
errGroup, _ := errgroup.WithContext(context.Background())
|
||||
for i := 0; i < maxStreams; i++ {
|
||||
errGroup.Go(func() error {
|
||||
stream, err := muxPair.OpenEdgeMuxStream(
|
||||
[]Header{{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error in OpenStream: %d %s", stream.streamID, err)
|
||||
}
|
||||
if len(stream.Headers) != 1 {
|
||||
return fmt.Errorf("stream %d expected %d headers, got %d", stream.streamID, 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "response-header" {
|
||||
return fmt.Errorf("stream %d expected header name %s, got %s", stream.streamID, "response-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "responseValue" {
|
||||
return fmt.Errorf("stream %d expected header value %s, got %s", stream.streamID, "responseValue", stream.Headers[0].Value)
|
||||
}
|
||||
|
||||
responseBody := make([]byte, responseSizes[(stream.streamID-2)/2])
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stream %d error from (*MuxedStream).Read: %s", stream.streamID, err)
|
||||
}
|
||||
if n != len(responseBody) {
|
||||
return fmt.Errorf("stream %d expected response body to have %d bytes, got %d", stream.streamID, len(responseBody), n)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
assert.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
func TestGracefulShutdown(t *testing.T) {
|
||||
sendC := make(chan struct{})
|
||||
responseBuf := bytes.Repeat([]byte("Hello world"), 65536)
|
||||
|
||||
f := MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
_ = stream.WriteHeaders([]Header{
|
||||
{Name: "response-header", Value: "responseValue"},
|
||||
})
|
||||
<-sendC
|
||||
log.Debug().Msgf("Writing %d bytes", len(responseBuf))
|
||||
_, _ = stream.Write(responseBuf)
|
||||
_ = stream.CloseWrite()
|
||||
log.Debug().Msgf("Wrote %d bytes", len(responseBuf))
|
||||
// Reading from the stream will block until the edge closes its end of the stream.
|
||||
// Otherwise, we'll close the whole connection before receiving the 'stream closed'
|
||||
// message from the edge.
|
||||
// Graceful shutdown works if you omit this, it just gives spurious errors for now -
|
||||
// TODO ignore errors when writing 'stream closed' and we're shutting down.
|
||||
_, _ = stream.Read([]byte{0})
|
||||
log.Debug().Msgf("Handler ends")
|
||||
return nil
|
||||
})
|
||||
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
|
||||
muxPair.Serve(t)
|
||||
|
||||
stream, err := muxPair.OpenEdgeMuxStream(
|
||||
[]Header{{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
// Start graceful shutdown of the edge mux - this should also close the origin mux when done
|
||||
muxPair.EdgeMux.Shutdown()
|
||||
close(sendC)
|
||||
responseBody := make([]byte, len(responseBuf))
|
||||
log.Debug().Msgf("Waiting for %d bytes", len(responseBuf))
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error from (*MuxedStream).Read with %d bytes read: %s", n, err)
|
||||
}
|
||||
if n != len(responseBody) {
|
||||
t.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
|
||||
}
|
||||
if !bytes.Equal(responseBuf, responseBody) {
|
||||
t.Fatalf("response body mismatch")
|
||||
}
|
||||
_ = stream.Close()
|
||||
muxPair.Wait(t)
|
||||
}
|
||||
|
||||
func TestUnexpectedShutdown(t *testing.T) {
|
||||
sendC := make(chan struct{})
|
||||
handlerFinishC := make(chan struct{})
|
||||
responseBuf := bytes.Repeat([]byte("Hello world"), 65536)
|
||||
|
||||
f := MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
defer close(handlerFinishC)
|
||||
_ = stream.WriteHeaders([]Header{
|
||||
{Name: "response-header", Value: "responseValue"},
|
||||
})
|
||||
<-sendC
|
||||
n, err := stream.Read([]byte{0})
|
||||
if err != io.EOF {
|
||||
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Fatalf("expected empty read, got %d bytes", n)
|
||||
}
|
||||
// Write comes after read, because write buffers data before it is flushed. It wouldn't know about EOF
|
||||
// until some time later. Calling read first forces it to know about EOF now.
|
||||
_, err = stream.Write(responseBuf)
|
||||
if err != io.EOF {
|
||||
t.Fatalf("unexpected error from (*MuxedStream).Write: %s", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
muxPair := NewDefaultMuxerPair(t, t.Name(), f)
|
||||
muxPair.Serve(t)
|
||||
|
||||
stream, err := muxPair.OpenEdgeMuxStream(
|
||||
[]Header{{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
// Close the underlying connection before telling the origin to write.
|
||||
_ = muxPair.EdgeConn.Close()
|
||||
close(sendC)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
responseBody := make([]byte, len(responseBuf))
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != io.EOF {
|
||||
t.Fatalf("unexpected error from (*MuxedStream).Read: %s", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Fatalf("expected response body to have %d bytes, got %d", 0, n)
|
||||
}
|
||||
// The write ordering requirement explained in the origin handler applies here too.
|
||||
_, err = stream.Write(responseBuf)
|
||||
if err != io.EOF {
|
||||
t.Fatalf("unexpected error from (*MuxedStream).Write: %s", err)
|
||||
}
|
||||
<-handlerFinishC
|
||||
}
|
||||
|
||||
func EchoHandler(stream *MuxedStream) error {
|
||||
var buf bytes.Buffer
|
||||
_, _ = fmt.Fprintf(&buf, "Hello, world!\n\n# REQUEST HEADERS:\n\n")
|
||||
for _, header := range stream.Headers {
|
||||
_, _ = fmt.Fprintf(&buf, "[%s] = %s\n", header.Name, header.Value)
|
||||
}
|
||||
_ = stream.WriteHeaders([]Header{
|
||||
{Name: ":status", Value: "200"},
|
||||
{Name: "server", Value: "Echo-server/1.0"},
|
||||
{Name: "date", Value: time.Now().Format(time.RFC850)},
|
||||
{Name: "content-type", Value: "text/html; charset=utf-8"},
|
||||
{Name: "content-length", Value: strconv.Itoa(buf.Len())},
|
||||
})
|
||||
_, _ = buf.WriteTo(stream)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOpenAfterDisconnect(t *testing.T) {
|
||||
for i := 0; i < 3; i++ {
|
||||
muxPair := NewDefaultMuxerPair(t, fmt.Sprintf("%s_%d", t.Name(), i), EchoHandler)
|
||||
muxPair.Serve(t)
|
||||
|
||||
switch i {
|
||||
case 0:
|
||||
// Close both directions of the connection to cause EOF on both peers.
|
||||
_ = muxPair.OriginConn.Close()
|
||||
_ = muxPair.EdgeConn.Close()
|
||||
case 1:
|
||||
// Close origin conn to cause EOF on origin first.
|
||||
_ = muxPair.OriginConn.Close()
|
||||
case 2:
|
||||
// Close edge conn to cause EOF on edge first.
|
||||
_ = muxPair.EdgeConn.Close()
|
||||
}
|
||||
|
||||
_, err := muxPair.OpenEdgeMuxStream(
|
||||
[]Header{{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
if err != ErrStreamRequestConnectionClosed && err != ErrResponseHeadersConnectionClosed {
|
||||
t.Fatalf("case %v: unexpected error in OpenStream: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHPACK(t *testing.T) {
|
||||
muxPair := NewDefaultMuxerPair(t, t.Name(), EchoHandler)
|
||||
muxPair.Serve(t)
|
||||
|
||||
stream, err := muxPair.OpenEdgeMuxStream(
|
||||
[]Header{
|
||||
{Name: ":method", Value: "RPC"},
|
||||
{Name: ":scheme", Value: "capnp"},
|
||||
{Name: ":path", Value: "*"},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
_ = stream.Close()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
stream, err := muxPair.OpenEdgeMuxStream(
|
||||
[]Header{
|
||||
{Name: ":method", Value: "GET"},
|
||||
{Name: ":scheme", Value: "https"},
|
||||
{Name: ":authority", Value: "tunnel.otterlyadorable.co.uk"},
|
||||
{Name: ":path", Value: "/get"},
|
||||
{Name: "accept-encoding", Value: "gzip"},
|
||||
{Name: "cf-ray", Value: "378948953f044408-SFO-DOG"},
|
||||
{Name: "cf-visitor", Value: "{\"scheme\":\"https\"}"},
|
||||
{Name: "cf-connecting-ip", Value: "2400:cb00:0025:010d:0000:0000:0000:0001"},
|
||||
{Name: "x-forwarded-for", Value: "2400:cb00:0025:010d:0000:0000:0000:0001"},
|
||||
{Name: "x-forwarded-proto", Value: "https"},
|
||||
{Name: "accept-language", Value: "en-gb"},
|
||||
{Name: "referer", Value: "https://tunnel.otterlyadorable.co.uk/"},
|
||||
{Name: "cookie", Value: "__cfduid=d4555095065f92daedc059490771967d81493032162"},
|
||||
{Name: "connection", Value: "Keep-Alive"},
|
||||
{Name: "cf-ipcountry", Value: "US"},
|
||||
{Name: "accept", Value: "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"},
|
||||
{Name: "user-agent", Value: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_5) AppleWebKit/603.2.4 (KHTML, like Gecko) Version/10.1.1 Safari/603.2.4"},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
if len(stream.Headers) == 0 {
|
||||
t.Fatal("response has no headers")
|
||||
}
|
||||
if stream.Headers[0].Name != ":status" {
|
||||
t.Fatalf("first header should be status, found %s instead", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "200" {
|
||||
t.Fatalf("expected status 200, got %s", stream.Headers[0].Value)
|
||||
}
|
||||
_, _ = io.ReadAll(stream)
|
||||
_ = stream.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func AssertIfPipeReadable(t *testing.T, pipe io.ReadCloser) {
|
||||
errC := make(chan error)
|
||||
go func() {
|
||||
b := []byte{0}
|
||||
n, err := pipe.Read(b)
|
||||
if n > 0 {
|
||||
t.Errorf("read pipe was not empty")
|
||||
return
|
||||
}
|
||||
errC <- err
|
||||
}()
|
||||
select {
|
||||
case err := <-errC:
|
||||
if err != nil {
|
||||
t.Fatalf("read error: %s", err)
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// nothing to read
|
||||
}
|
||||
}
|
||||
|
||||
func sampleSiteHandler(files map[string][]byte) MuxedStreamFunc {
|
||||
return func(stream *MuxedStream) error {
|
||||
var contentType string
|
||||
var pathHeader Header
|
||||
|
||||
for _, h := range stream.Headers {
|
||||
if h.Name == ":path" {
|
||||
pathHeader = h
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if pathHeader.Name != ":path" {
|
||||
return fmt.Errorf("Couldn't find :path header in test")
|
||||
}
|
||||
|
||||
if strings.Contains(pathHeader.Value, "html") {
|
||||
contentType = "text/html; charset=utf-8"
|
||||
} else if strings.Contains(pathHeader.Value, "js") {
|
||||
contentType = "application/javascript"
|
||||
} else if strings.Contains(pathHeader.Value, "css") {
|
||||
contentType = "text/css"
|
||||
} else {
|
||||
contentType = "img/gif"
|
||||
}
|
||||
_ = stream.WriteHeaders([]Header{
|
||||
{Name: "content-type", Value: contentType},
|
||||
})
|
||||
log.Debug().Msgf("Wrote headers for stream %s", pathHeader.Value)
|
||||
file, ok := files[pathHeader.Value]
|
||||
if !ok {
|
||||
return fmt.Errorf("%s content is not preloaded", pathHeader.Value)
|
||||
}
|
||||
_, _ = stream.Write(file)
|
||||
log.Debug().Msgf("Wrote body for stream %s", pathHeader.Value)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func sampleSiteTest(muxPair *DefaultMuxerPair, path string, files map[string][]byte) error {
|
||||
stream, err := muxPair.OpenEdgeMuxStream(
|
||||
[]Header{
|
||||
{Name: ":method", Value: "GET"},
|
||||
{Name: ":scheme", Value: "https"},
|
||||
{Name: ":authority", Value: "tunnel.otterlyadorable.co.uk"},
|
||||
{Name: ":path", Value: path},
|
||||
{Name: "accept-encoding", Value: "br, gzip"},
|
||||
{Name: "cf-ray", Value: "378948953f044408-SFO-DOG"},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error in OpenStream: %v", err)
|
||||
}
|
||||
file, ok := files[path]
|
||||
if !ok {
|
||||
return fmt.Errorf("%s content is not preloaded", path)
|
||||
}
|
||||
responseBody := make([]byte, len(file))
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error from (*MuxedStream).Read: %v", err)
|
||||
}
|
||||
if n != len(file) {
|
||||
return fmt.Errorf("expected response body to have %d bytes, got %d", len(file), n)
|
||||
}
|
||||
if string(responseBody[:n]) != string(file) {
|
||||
return fmt.Errorf("expected response body %s, got %s", file, responseBody[:n])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadSampleFiles(paths []string) (map[string][]byte, error) {
|
||||
files := make(map[string][]byte)
|
||||
for _, path := range paths {
|
||||
if _, ok := files[path]; !ok {
|
||||
expectBody, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
files[path] = expectBody
|
||||
}
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
|
||||
func BenchmarkOpenStream(b *testing.B) {
|
||||
const streams = 5000
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
f := MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
if len(stream.Headers) != 1 {
|
||||
b.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "test-header" {
|
||||
b.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "headerValue" {
|
||||
b.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
||||
}
|
||||
_ = stream.WriteHeaders([]Header{
|
||||
{Name: "response-header", Value: "responseValue"},
|
||||
})
|
||||
return nil
|
||||
})
|
||||
muxPair := NewDefaultMuxerPair(b, fmt.Sprintf("%s_%d", b.Name(), i), f)
|
||||
muxPair.Serve(b)
|
||||
b.StartTimer()
|
||||
openStreams(b, muxPair, streams)
|
||||
}
|
||||
}
|
||||
|
||||
func openStreams(b *testing.B, muxPair *DefaultMuxerPair, n int) {
|
||||
errGroup, _ := errgroup.WithContext(context.Background())
|
||||
for i := 0; i < n; i++ {
|
||||
errGroup.Go(func() error {
|
||||
_, err := muxPair.OpenEdgeMuxStream(
|
||||
[]Header{{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
assert.NoError(b, errGroup.Wait())
|
||||
}
|
||||
|
||||
func BenchmarkSingleStreamLargeResponseBody(b *testing.B) {
|
||||
const bodySize = 1 << 24
|
||||
|
||||
const writeBufferSize = 16 << 10
|
||||
const writeN = bodySize / writeBufferSize
|
||||
payload := make([]byte, writeBufferSize)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
const readBufferSize = 16 << 10
|
||||
const readN = bodySize / readBufferSize
|
||||
responseBody := make([]byte, readBufferSize)
|
||||
|
||||
f := MuxedStreamFunc(func(stream *MuxedStream) error {
|
||||
if len(stream.Headers) != 1 {
|
||||
b.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "test-header" {
|
||||
b.Fatalf("expected header name %s, got %s", "test-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "headerValue" {
|
||||
b.Fatalf("expected header value %s, got %s", "headerValue", stream.Headers[0].Value)
|
||||
}
|
||||
_ = stream.WriteHeaders([]Header{
|
||||
{Name: "response-header", Value: "responseValue"},
|
||||
})
|
||||
for i := 0; i < writeN; i++ {
|
||||
n, err := stream.Write(payload)
|
||||
if err != nil {
|
||||
b.Fatalf("origin write error: %s", err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
b.Fatalf("origin short write: %d/%d bytes", n, len(payload))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
name := fmt.Sprintf("%s_%d", b.Name(), rand.Int())
|
||||
origin, edge := net.Pipe()
|
||||
|
||||
muxPair := &DefaultMuxerPair{
|
||||
OriginMuxConfig: MuxerConfig{
|
||||
Timeout: testHandshakeTimeout,
|
||||
Handler: f,
|
||||
IsClient: true,
|
||||
Name: "origin",
|
||||
Log: &log,
|
||||
DefaultWindowSize: defaultWindowSize,
|
||||
MaxWindowSize: maxWindowSize,
|
||||
StreamWriteBufferMaxLen: defaultWriteBufferMaxLen,
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
OriginConn: origin,
|
||||
EdgeMuxConfig: MuxerConfig{
|
||||
Timeout: testHandshakeTimeout,
|
||||
IsClient: false,
|
||||
Name: "edge",
|
||||
Log: &log,
|
||||
DefaultWindowSize: defaultWindowSize,
|
||||
MaxWindowSize: maxWindowSize,
|
||||
StreamWriteBufferMaxLen: defaultWriteBufferMaxLen,
|
||||
HeartbeatInterval: defaultTimeout,
|
||||
MaxHeartbeats: defaultRetries,
|
||||
},
|
||||
EdgeConn: edge,
|
||||
doneC: make(chan struct{}),
|
||||
}
|
||||
assert.NoError(b, muxPair.Handshake(name))
|
||||
muxPair.Serve(b)
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
stream, err := muxPair.OpenEdgeMuxStream(
|
||||
[]Header{{Name: "test-header", Value: "headerValue"}},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
b.Fatalf("error in OpenStream: %s", err)
|
||||
}
|
||||
if len(stream.Headers) != 1 {
|
||||
b.Fatalf("expected %d headers, got %d", 1, len(stream.Headers))
|
||||
}
|
||||
if stream.Headers[0].Name != "response-header" {
|
||||
b.Fatalf("expected header name %s, got %s", "response-header", stream.Headers[0].Name)
|
||||
}
|
||||
if stream.Headers[0].Value != "responseValue" {
|
||||
b.Fatalf("expected header value %s, got %s", "responseValue", stream.Headers[0].Value)
|
||||
}
|
||||
|
||||
for k := 0; k < readN; k++ {
|
||||
n, err := io.ReadFull(stream, responseBody)
|
||||
if err != nil {
|
||||
b.Fatalf("error from (*MuxedStream).Read: %s", err)
|
||||
}
|
||||
if n != len(responseBody) {
|
||||
b.Fatalf("expected response body to have %d bytes, got %d", len(responseBody), n)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IdleTimer is a type of Timer designed for managing heartbeats on an idle connection.
|
||||
// The timer ticks on an interval with added jitter to avoid accidental synchronisation
|
||||
// between two endpoints. It tracks the number of retries/ticks since the connection was
|
||||
// last marked active.
|
||||
//
|
||||
// The methods of IdleTimer must not be called while a goroutine is reading from C.
|
||||
type IdleTimer struct {
|
||||
// The channel on which ticks are delivered.
|
||||
C <-chan time.Time
|
||||
|
||||
// A timer used to measure idle connection time. Reset after sending data.
|
||||
idleTimer *time.Timer
|
||||
// The maximum length of time a connection is idle before sending a ping.
|
||||
idleDuration time.Duration
|
||||
// A pseudorandom source used to add jitter to the idle duration.
|
||||
randomSource *rand.Rand
|
||||
// The maximum number of retries allowed.
|
||||
maxRetries uint64
|
||||
// The number of retries since the connection was last marked active.
|
||||
retries uint64
|
||||
// A lock to prevent race condition while checking retries
|
||||
stateLock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewIdleTimer(idleDuration time.Duration, maxRetries uint64) *IdleTimer {
|
||||
t := &IdleTimer{
|
||||
idleTimer: time.NewTimer(idleDuration),
|
||||
idleDuration: idleDuration,
|
||||
randomSource: rand.New(rand.NewSource(time.Now().Unix())),
|
||||
maxRetries: maxRetries,
|
||||
}
|
||||
t.C = t.idleTimer.C
|
||||
return t
|
||||
}
|
||||
|
||||
// Retry should be called when retrying the idle timeout. If the maximum number of retries
|
||||
// has been met, returns false.
|
||||
// After calling this function and sending a heartbeat, call ResetTimer. Since sending the
|
||||
// heartbeat could be a blocking operation, we resetting the timer after the write completes
|
||||
// to avoid it expiring during the write.
|
||||
func (t *IdleTimer) Retry() bool {
|
||||
t.stateLock.Lock()
|
||||
defer t.stateLock.Unlock()
|
||||
if t.retries >= t.maxRetries {
|
||||
return false
|
||||
}
|
||||
t.retries++
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *IdleTimer) RetryCount() uint64 {
|
||||
t.stateLock.RLock()
|
||||
defer t.stateLock.RUnlock()
|
||||
return t.retries
|
||||
}
|
||||
|
||||
// MarkActive resets the idle connection timer and suppresses any outstanding idle events.
|
||||
func (t *IdleTimer) MarkActive() {
|
||||
if !t.idleTimer.Stop() {
|
||||
// eat the timer event to prevent spurious pings
|
||||
<-t.idleTimer.C
|
||||
}
|
||||
t.stateLock.Lock()
|
||||
t.retries = 0
|
||||
t.stateLock.Unlock()
|
||||
t.ResetTimer()
|
||||
}
|
||||
|
||||
// Reset the idle timer according to the configured duration, with some added jitter.
|
||||
func (t *IdleTimer) ResetTimer() {
|
||||
jitter := time.Duration(t.randomSource.Int63n(int64(t.idleDuration)))
|
||||
t.idleTimer.Reset(t.idleDuration + jitter)
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRetry(t *testing.T) {
|
||||
timer := NewIdleTimer(time.Second, 2)
|
||||
assert.Equal(t, uint64(0), timer.RetryCount())
|
||||
ok := timer.Retry()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, uint64(1), timer.RetryCount())
|
||||
ok = timer.Retry()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, uint64(2), timer.RetryCount())
|
||||
ok = timer.Retry()
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestMarkActive(t *testing.T) {
|
||||
timer := NewIdleTimer(time.Second, 2)
|
||||
assert.Equal(t, uint64(0), timer.RetryCount())
|
||||
ok := timer.Retry()
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, uint64(1), timer.RetryCount())
|
||||
timer.MarkActive()
|
||||
assert.Equal(t, uint64(0), timer.RetryCount())
|
||||
}
|
|
@ -1,457 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ReadWriteLengther interface {
|
||||
io.ReadWriter
|
||||
Reset()
|
||||
Len() int
|
||||
}
|
||||
|
||||
type ReadWriteClosedCloser interface {
|
||||
io.ReadWriteCloser
|
||||
Closed() bool
|
||||
}
|
||||
|
||||
// MuxedStreamDataSignaller is a write-only *ReadyList
|
||||
type MuxedStreamDataSignaller interface {
|
||||
// Non-blocking: call this when data is ready to be sent for the given stream ID.
|
||||
Signal(ID uint32)
|
||||
}
|
||||
|
||||
type Header struct {
|
||||
Name, Value string
|
||||
}
|
||||
|
||||
// MuxedStream is logically an HTTP/2 stream, with an additional buffer for outgoing data.
|
||||
type MuxedStream struct {
|
||||
streamID uint32
|
||||
|
||||
// The "Receive" end of the stream
|
||||
readBufferLock sync.RWMutex
|
||||
readBuffer ReadWriteClosedCloser
|
||||
// This is the amount of bytes that are in our receive window
|
||||
// (how much data we can receive into this stream).
|
||||
receiveWindow uint32
|
||||
// current receive window size limit. Exponentially increase it when it's exhausted
|
||||
receiveWindowCurrentMax uint32
|
||||
// hard limit set in http2 spec. 2^31-1
|
||||
receiveWindowMax uint32
|
||||
// The desired size increment for receiveWindow.
|
||||
// If this is nonzero, a WINDOW_UPDATE frame needs to be sent.
|
||||
windowUpdate uint32
|
||||
// The headers that were most recently received.
|
||||
// Particularly:
|
||||
// * for an eyeball-initiated stream (as passed to TunnelHandler::ServeStream),
|
||||
// these are the request headers
|
||||
// * for a cloudflared-initiated stream (as created by Register/UnregisterTunnel),
|
||||
// these are the response headers.
|
||||
// They are useful in both of these contexts; hence `Headers` is public.
|
||||
Headers []Header
|
||||
// For use in the context of a cloudflared-initiated stream.
|
||||
responseHeadersReceived chan struct{}
|
||||
|
||||
// The "Send" end of the stream
|
||||
writeLock sync.Mutex
|
||||
writeBuffer ReadWriteLengther
|
||||
// The maximum capacity that the send buffer should grow to.
|
||||
writeBufferMaxLen int
|
||||
// A channel to be notified when the send buffer is not full.
|
||||
writeBufferHasSpace chan struct{}
|
||||
// This is the amount of bytes that are in the peer's receive window
|
||||
// (how much data we can send from this stream).
|
||||
sendWindow uint32
|
||||
// The muxer's readyList
|
||||
readyList MuxedStreamDataSignaller
|
||||
// The headers that should be sent, and a flag so we only send them once.
|
||||
headersSent bool
|
||||
writeHeaders []Header
|
||||
|
||||
// EOF-related fields
|
||||
// true if the write end of this stream has been closed
|
||||
writeEOF bool
|
||||
// true if we have sent EOF to the peer
|
||||
sentEOF bool
|
||||
// true if the peer sent us an EOF
|
||||
receivedEOF bool
|
||||
// Compression-related fields
|
||||
receivedUseDict bool
|
||||
method string
|
||||
contentType string
|
||||
path string
|
||||
dictionaries h2Dictionaries
|
||||
}
|
||||
|
||||
type TunnelHostname string
|
||||
|
||||
func (th TunnelHostname) String() string {
|
||||
return string(th)
|
||||
}
|
||||
|
||||
func (th TunnelHostname) IsSet() bool {
|
||||
return th != ""
|
||||
}
|
||||
|
||||
func NewStream(config MuxerConfig, writeHeaders []Header, readyList MuxedStreamDataSignaller, dictionaries h2Dictionaries) *MuxedStream {
|
||||
return &MuxedStream{
|
||||
responseHeadersReceived: make(chan struct{}),
|
||||
readBuffer: NewSharedBuffer(),
|
||||
writeBuffer: &bytes.Buffer{},
|
||||
writeBufferMaxLen: config.StreamWriteBufferMaxLen,
|
||||
writeBufferHasSpace: make(chan struct{}, 1),
|
||||
receiveWindow: config.DefaultWindowSize,
|
||||
receiveWindowCurrentMax: config.DefaultWindowSize,
|
||||
receiveWindowMax: config.MaxWindowSize,
|
||||
sendWindow: config.DefaultWindowSize,
|
||||
readyList: readyList,
|
||||
writeHeaders: writeHeaders,
|
||||
dictionaries: dictionaries,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MuxedStream) Read(p []byte) (n int, err error) {
|
||||
var readBuffer ReadWriteClosedCloser
|
||||
if s.dictionaries.read != nil {
|
||||
s.readBufferLock.RLock()
|
||||
readBuffer = s.readBuffer
|
||||
s.readBufferLock.RUnlock()
|
||||
} else {
|
||||
readBuffer = s.readBuffer
|
||||
}
|
||||
n, err = readBuffer.Read(p)
|
||||
s.replenishReceiveWindow(uint32(n))
|
||||
return
|
||||
}
|
||||
|
||||
// Blocks until len(p) bytes have been written to the buffer
|
||||
func (s *MuxedStream) Write(p []byte) (int, error) {
|
||||
// If assignDictToStream returns success, then it will have acquired the
|
||||
// writeLock. Otherwise we must acquire it ourselves.
|
||||
ok := assignDictToStream(s, p)
|
||||
if !ok {
|
||||
s.writeLock.Lock()
|
||||
}
|
||||
defer s.writeLock.Unlock()
|
||||
|
||||
if s.writeEOF {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// pre-allocate some space in the write buffer if possible
|
||||
if buffer, ok := s.writeBuffer.(*bytes.Buffer); ok {
|
||||
if buffer.Cap() == 0 {
|
||||
buffer.Grow(writeBufferInitialSize)
|
||||
}
|
||||
}
|
||||
|
||||
totalWritten := 0
|
||||
for totalWritten < len(p) {
|
||||
// If the buffer is full, block till there is more room.
|
||||
// Use a loop to recheck the buffer size after the lock is reacquired.
|
||||
for s.writeBufferMaxLen <= s.writeBuffer.Len() {
|
||||
s.awaitWriteBufferHasSpace()
|
||||
if s.writeEOF {
|
||||
return totalWritten, io.EOF
|
||||
}
|
||||
}
|
||||
amountToWrite := len(p) - totalWritten
|
||||
spaceAvailable := s.writeBufferMaxLen - s.writeBuffer.Len()
|
||||
if spaceAvailable < amountToWrite {
|
||||
amountToWrite = spaceAvailable
|
||||
}
|
||||
amountWritten, err := s.writeBuffer.Write(p[totalWritten : totalWritten+amountToWrite])
|
||||
totalWritten += amountWritten
|
||||
if err != nil {
|
||||
return totalWritten, err
|
||||
}
|
||||
s.writeNotify()
|
||||
}
|
||||
return totalWritten, nil
|
||||
}
|
||||
|
||||
func (s *MuxedStream) Close() error {
|
||||
// TUN-115: Close the write buffer before the read buffer.
|
||||
// In the case of shutdown, read will not get new data, but the write buffer can still receive
|
||||
// new data. Closing read before write allows application to race between a failed read and a
|
||||
// successful write, even though this close should appear to be atomic.
|
||||
// This can't happen the other way because reads may succeed after a failed write; if we read
|
||||
// past EOF the application will block until we close the buffer.
|
||||
err := s.CloseWrite()
|
||||
if err != nil {
|
||||
if s.CloseRead() == nil {
|
||||
// don't bother the caller with errors if at least one close succeeded
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return s.CloseRead()
|
||||
}
|
||||
|
||||
func (s *MuxedStream) CloseRead() error {
|
||||
return s.readBuffer.Close()
|
||||
}
|
||||
|
||||
func (s *MuxedStream) CloseWrite() error {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
if s.writeEOF {
|
||||
return io.EOF
|
||||
}
|
||||
s.writeEOF = true
|
||||
if c, ok := s.writeBuffer.(io.Closer); ok {
|
||||
c.Close()
|
||||
}
|
||||
// Allow MuxedStream::Write() to terminate its loop with err=io.EOF, if needed
|
||||
s.notifyWriteBufferHasSpace()
|
||||
// We need to send something over the wire, even if it's an END_STREAM with no data
|
||||
s.writeNotify()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MuxedStream) WriteClosed() bool {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
return s.writeEOF
|
||||
}
|
||||
|
||||
func (s *MuxedStream) WriteHeaders(headers []Header) error {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
if s.writeHeaders != nil {
|
||||
return ErrStreamHeadersSent
|
||||
}
|
||||
|
||||
if s.dictionaries.write != nil {
|
||||
dictWriter := s.dictionaries.write.getDictWriter(s, headers)
|
||||
if dictWriter != nil {
|
||||
s.writeBuffer = dictWriter
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
s.writeHeaders = headers
|
||||
s.headersSent = false
|
||||
s.writeNotify()
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsRPCStream returns if the stream is used to transport RPC.
|
||||
func (s *MuxedStream) IsRPCStream() bool {
|
||||
rpcHeaders := RPCHeaders()
|
||||
if len(s.Headers) != len(rpcHeaders) {
|
||||
return false
|
||||
}
|
||||
// The headers order matters, so RPC stream should be opened with OpenRPCStream method and let MuxWriter serializes the headers.
|
||||
for i, rpcHeader := range rpcHeaders {
|
||||
if s.Headers[i] != rpcHeader {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Block until a value is sent on writeBufferHasSpace.
|
||||
// Must be called while holding writeLock
|
||||
func (s *MuxedStream) awaitWriteBufferHasSpace() {
|
||||
s.writeLock.Unlock()
|
||||
<-s.writeBufferHasSpace
|
||||
s.writeLock.Lock()
|
||||
}
|
||||
|
||||
// Send a value on writeBufferHasSpace without blocking.
|
||||
// Must be called while holding writeLock
|
||||
func (s *MuxedStream) notifyWriteBufferHasSpace() {
|
||||
select {
|
||||
case s.writeBufferHasSpace <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MuxedStream) getReceiveWindow() uint32 {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
return s.receiveWindow
|
||||
}
|
||||
|
||||
func (s *MuxedStream) getSendWindow() uint32 {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
return s.sendWindow
|
||||
}
|
||||
|
||||
// writeNotify must happen while holding writeLock.
|
||||
func (s *MuxedStream) writeNotify() {
|
||||
s.readyList.Signal(s.streamID)
|
||||
}
|
||||
|
||||
// Call by muxreader when it gets a WindowUpdateFrame. This is an update of the peer's
|
||||
// receive window (how much data we can send).
|
||||
func (s *MuxedStream) replenishSendWindow(bytes uint32) {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
s.sendWindow += bytes
|
||||
s.writeNotify()
|
||||
}
|
||||
|
||||
// Call by muxreader when it receives a data frame
|
||||
func (s *MuxedStream) consumeReceiveWindow(bytes uint32) bool {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
// received data size is greater than receive window/buffer
|
||||
if s.receiveWindow < bytes {
|
||||
return false
|
||||
}
|
||||
s.receiveWindow -= bytes
|
||||
if s.receiveWindow < s.receiveWindowCurrentMax/2 && s.receiveWindowCurrentMax < s.receiveWindowMax {
|
||||
// exhausting client send window (how much data client can send)
|
||||
// and there is room to grow the receive window
|
||||
newMax := s.receiveWindowCurrentMax << 1
|
||||
if newMax > s.receiveWindowMax {
|
||||
newMax = s.receiveWindowMax
|
||||
}
|
||||
s.windowUpdate += newMax - s.receiveWindowCurrentMax
|
||||
s.receiveWindowCurrentMax = newMax
|
||||
// notify MuxWriter to write WINDOW_UPDATE frame
|
||||
s.writeNotify()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Arranges for the MuxWriter to send a WINDOW_UPDATE
|
||||
// Called by MuxedStream::Read when data has left the read buffer.
|
||||
func (s *MuxedStream) replenishReceiveWindow(bytes uint32) {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
s.windowUpdate += bytes
|
||||
s.writeNotify()
|
||||
}
|
||||
|
||||
// receiveEOF should be called when the peer indicates no more data will be sent.
|
||||
// Returns true if the socket is now closed (i.e. the write side is already closed).
|
||||
func (s *MuxedStream) receiveEOF() (closed bool) {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
s.receivedEOF = true
|
||||
s.CloseRead()
|
||||
return s.writeEOF && s.writeBuffer.Len() == 0
|
||||
}
|
||||
|
||||
func (s *MuxedStream) gotReceiveEOF() bool {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
return s.receivedEOF
|
||||
}
|
||||
|
||||
// MuxedStreamReader implements io.ReadCloser for the read end of the stream.
|
||||
// This is useful for passing to functions that close the object after it is done reading,
|
||||
// but you still want to be able to write data afterwards (e.g. http.Client).
|
||||
type MuxedStreamReader struct {
|
||||
*MuxedStream
|
||||
}
|
||||
|
||||
func (s MuxedStreamReader) Read(p []byte) (n int, err error) {
|
||||
return s.MuxedStream.Read(p)
|
||||
}
|
||||
|
||||
func (s MuxedStreamReader) Close() error {
|
||||
return s.MuxedStream.CloseRead()
|
||||
}
|
||||
|
||||
// streamChunk represents a chunk of data to be written.
|
||||
type streamChunk struct {
|
||||
streamID uint32
|
||||
// true if a HEADERS frame should be sent
|
||||
sendHeaders bool
|
||||
headers []Header
|
||||
// nonzero if a WINDOW_UPDATE frame should be sent;
|
||||
// in that case, it is the increment value to use
|
||||
windowUpdate uint32
|
||||
// true if data frames should be sent
|
||||
sendData bool
|
||||
eof bool
|
||||
|
||||
buffer []byte
|
||||
offset int
|
||||
}
|
||||
|
||||
// getChunk atomically extracts a chunk of data to be written by MuxWriter.
|
||||
// The data returned will not exceed the send window for this stream.
|
||||
func (s *MuxedStream) getChunk() *streamChunk {
|
||||
s.writeLock.Lock()
|
||||
defer s.writeLock.Unlock()
|
||||
|
||||
chunk := &streamChunk{
|
||||
streamID: s.streamID,
|
||||
sendHeaders: !s.headersSent,
|
||||
headers: s.writeHeaders,
|
||||
windowUpdate: s.windowUpdate,
|
||||
sendData: !s.sentEOF,
|
||||
eof: s.writeEOF && uint32(s.writeBuffer.Len()) <= s.sendWindow,
|
||||
}
|
||||
// Copy at most s.sendWindow bytes, adjust the sendWindow accordingly
|
||||
toCopy := int(s.sendWindow)
|
||||
if toCopy > s.writeBuffer.Len() {
|
||||
toCopy = s.writeBuffer.Len()
|
||||
}
|
||||
|
||||
if toCopy > 0 {
|
||||
buf := make([]byte, toCopy)
|
||||
writeLen, _ := s.writeBuffer.Read(buf)
|
||||
chunk.buffer = buf[:writeLen]
|
||||
s.sendWindow -= uint32(writeLen)
|
||||
}
|
||||
|
||||
// Allow MuxedStream::Write() to continue, if needed
|
||||
if s.writeBuffer.Len() < s.writeBufferMaxLen {
|
||||
s.notifyWriteBufferHasSpace()
|
||||
}
|
||||
|
||||
// When we write the chunk, we'll write the WINDOW_UPDATE frame if needed
|
||||
s.receiveWindow += s.windowUpdate
|
||||
s.windowUpdate = 0
|
||||
|
||||
// When we write the chunk, we'll write the headers if needed
|
||||
s.headersSent = true
|
||||
|
||||
// if this chunk contains the end of the stream, close the stream now
|
||||
if chunk.sendData && chunk.eof {
|
||||
s.sentEOF = true
|
||||
}
|
||||
|
||||
return chunk
|
||||
}
|
||||
|
||||
func (c *streamChunk) sendHeadersFrame() bool {
|
||||
return c.sendHeaders
|
||||
}
|
||||
|
||||
func (c *streamChunk) sendWindowUpdateFrame() bool {
|
||||
return c.windowUpdate > 0
|
||||
}
|
||||
|
||||
func (c *streamChunk) sendDataFrame() bool {
|
||||
return c.sendData
|
||||
}
|
||||
|
||||
func (c *streamChunk) nextDataFrame(frameSize int) (payload []byte, endStream bool) {
|
||||
bytesLeft := len(c.buffer) - c.offset
|
||||
if frameSize > bytesLeft {
|
||||
frameSize = bytesLeft
|
||||
}
|
||||
nextOffset := c.offset + frameSize
|
||||
payload = c.buffer[c.offset:nextOffset]
|
||||
c.offset = nextOffset
|
||||
|
||||
if c.offset == len(c.buffer) {
|
||||
// this is the last data frame in this chunk
|
||||
c.sendData = false
|
||||
if c.eof {
|
||||
endStream = true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -1,127 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const testWindowSize uint32 = 65535
|
||||
const testMaxWindowSize uint32 = testWindowSize << 2
|
||||
|
||||
// Only sending WINDOW_UPDATE frame, so sendWindow should never change
|
||||
func TestFlowControlSingleStream(t *testing.T) {
|
||||
stream := &MuxedStream{
|
||||
responseHeadersReceived: make(chan struct{}),
|
||||
readBuffer: NewSharedBuffer(),
|
||||
writeBuffer: &bytes.Buffer{},
|
||||
receiveWindow: testWindowSize,
|
||||
receiveWindowCurrentMax: testWindowSize,
|
||||
receiveWindowMax: testMaxWindowSize,
|
||||
sendWindow: testWindowSize,
|
||||
readyList: NewReadyList(),
|
||||
}
|
||||
var tempWindowUpdate uint32
|
||||
var tempStreamChunk *streamChunk
|
||||
|
||||
assert.True(t, stream.consumeReceiveWindow(testWindowSize/2))
|
||||
dataSent := testWindowSize / 2
|
||||
assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow)
|
||||
assert.Equal(t, testWindowSize, stream.receiveWindowCurrentMax)
|
||||
assert.Equal(t, testWindowSize, stream.sendWindow)
|
||||
assert.Equal(t, uint32(0), stream.windowUpdate)
|
||||
|
||||
tempStreamChunk = stream.getChunk()
|
||||
assert.Equal(t, uint32(0), tempStreamChunk.windowUpdate)
|
||||
assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow)
|
||||
assert.Equal(t, testWindowSize, stream.receiveWindowCurrentMax)
|
||||
assert.Equal(t, testWindowSize, stream.sendWindow)
|
||||
assert.Equal(t, uint32(0), stream.windowUpdate)
|
||||
|
||||
assert.True(t, stream.consumeReceiveWindow(2))
|
||||
dataSent += 2
|
||||
assert.Equal(t, testWindowSize-dataSent, stream.receiveWindow)
|
||||
assert.Equal(t, testWindowSize<<1, stream.receiveWindowCurrentMax)
|
||||
assert.Equal(t, testWindowSize, stream.sendWindow)
|
||||
assert.Equal(t, testWindowSize, stream.windowUpdate)
|
||||
tempWindowUpdate = stream.windowUpdate
|
||||
|
||||
tempStreamChunk = stream.getChunk()
|
||||
assert.Equal(t, tempWindowUpdate, tempStreamChunk.windowUpdate)
|
||||
assert.Equal(t, (testWindowSize<<1)-dataSent, stream.receiveWindow)
|
||||
assert.Equal(t, testWindowSize<<1, stream.receiveWindowCurrentMax)
|
||||
assert.Equal(t, testWindowSize, stream.sendWindow)
|
||||
assert.Equal(t, uint32(0), stream.windowUpdate)
|
||||
|
||||
assert.True(t, stream.consumeReceiveWindow(testWindowSize+10))
|
||||
dataSent += testWindowSize + 10
|
||||
assert.Equal(t, (testWindowSize<<1)-dataSent, stream.receiveWindow)
|
||||
assert.Equal(t, testWindowSize<<2, stream.receiveWindowCurrentMax)
|
||||
assert.Equal(t, testWindowSize, stream.sendWindow)
|
||||
assert.Equal(t, testWindowSize<<1, stream.windowUpdate)
|
||||
tempWindowUpdate = stream.windowUpdate
|
||||
|
||||
tempStreamChunk = stream.getChunk()
|
||||
assert.Equal(t, tempWindowUpdate, tempStreamChunk.windowUpdate)
|
||||
assert.Equal(t, (testWindowSize<<2)-dataSent, stream.receiveWindow)
|
||||
assert.Equal(t, testWindowSize<<2, stream.receiveWindowCurrentMax)
|
||||
assert.Equal(t, testWindowSize, stream.sendWindow)
|
||||
assert.Equal(t, uint32(0), stream.windowUpdate)
|
||||
|
||||
assert.False(t, stream.consumeReceiveWindow(testMaxWindowSize+1))
|
||||
assert.Equal(t, (testWindowSize<<2)-dataSent, stream.receiveWindow)
|
||||
assert.Equal(t, testMaxWindowSize, stream.receiveWindowCurrentMax)
|
||||
}
|
||||
|
||||
func TestMuxedStreamEOF(t *testing.T) {
|
||||
for i := 0; i < 4096; i++ {
|
||||
readyList := NewReadyList()
|
||||
stream := &MuxedStream{
|
||||
streamID: 1,
|
||||
readBuffer: NewSharedBuffer(),
|
||||
receiveWindow: 65536,
|
||||
receiveWindowMax: 65536,
|
||||
sendWindow: 65536,
|
||||
readyList: readyList,
|
||||
}
|
||||
|
||||
go func() { stream.Close() }()
|
||||
n, err := stream.Read([]byte{0})
|
||||
assert.Equal(t, io.EOF, err)
|
||||
assert.Equal(t, 0, n)
|
||||
// Write comes after read, because write buffers data before it is flushed. It wouldn't know about EOF
|
||||
// until some time later. Calling read first forces it to know about EOF now.
|
||||
n, err = stream.Write([]byte{1})
|
||||
assert.Equal(t, io.EOF, err)
|
||||
assert.Equal(t, 0, n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRPCStream(t *testing.T) {
|
||||
tests := []struct {
|
||||
stream *MuxedStream
|
||||
isRPCStream bool
|
||||
}{
|
||||
{
|
||||
stream: &MuxedStream{},
|
||||
isRPCStream: false,
|
||||
},
|
||||
{
|
||||
stream: &MuxedStream{Headers: RPCHeaders()},
|
||||
isRPCStream: true,
|
||||
},
|
||||
{
|
||||
stream: &MuxedStream{Headers: []Header{
|
||||
{Name: ":method", Value: "rpc"},
|
||||
{Name: ":scheme", Value: "Capnp"},
|
||||
{Name: ":path", Value: "/"},
|
||||
}},
|
||||
isRPCStream: false,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
assert.Equal(t, test.isRPCStream, test.stream.IsRPCStream())
|
||||
}
|
||||
}
|
|
@ -1,296 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-collections/collections/queue"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// data points used to compute average receive window and send window size
|
||||
const (
|
||||
// data points used to compute average receive window and send window size
|
||||
dataPoints = 100
|
||||
// updateFreq is set to 1 sec so we can get inbound & outbound byes/sec
|
||||
updateFreq = time.Second
|
||||
)
|
||||
|
||||
type muxMetricsUpdater interface {
|
||||
// metrics returns the latest metrics
|
||||
metrics() *MuxerMetrics
|
||||
// run is a blocking call to start the event loop
|
||||
run(log *zerolog.Logger) error
|
||||
// updateRTTChan is called by muxReader to report new RTT measurements
|
||||
updateRTT(rtt *roundTripMeasurement)
|
||||
//updateReceiveWindowChan is called by muxReader and muxWriter when receiveWindow size is updated
|
||||
updateReceiveWindow(receiveWindow uint32)
|
||||
//updateSendWindowChan is called by muxReader and muxWriter when sendWindow size is updated
|
||||
updateSendWindow(sendWindow uint32)
|
||||
// updateInBoundBytesChan is called periodicallyby muxReader to report bytesRead
|
||||
updateInBoundBytes(inBoundBytes uint64)
|
||||
// updateOutBoundBytesChan is called periodically by muxWriter to report bytesWrote
|
||||
updateOutBoundBytes(outBoundBytes uint64)
|
||||
}
|
||||
|
||||
type muxMetricsUpdaterImpl struct {
|
||||
// rttData keeps record of rtt, rttMin, rttMax and last measured time
|
||||
rttData *rttData
|
||||
// receiveWindowData keeps record of receive window measurement
|
||||
receiveWindowData *flowControlData
|
||||
// sendWindowData keeps record of send window measurement
|
||||
sendWindowData *flowControlData
|
||||
// inBoundRate is incoming bytes/sec
|
||||
inBoundRate *rate
|
||||
// outBoundRate is outgoing bytes/sec
|
||||
outBoundRate *rate
|
||||
// updateRTTChan is the channel to receive new RTT measurement
|
||||
updateRTTChan chan *roundTripMeasurement
|
||||
//updateReceiveWindowChan is the channel to receive updated receiveWindow size
|
||||
updateReceiveWindowChan chan uint32
|
||||
//updateSendWindowChan is the channel to receive updated sendWindow size
|
||||
updateSendWindowChan chan uint32
|
||||
// updateInBoundBytesChan us the channel to receive bytesRead
|
||||
updateInBoundBytesChan chan uint64
|
||||
// updateOutBoundBytesChan us the channel to receive bytesWrote
|
||||
updateOutBoundBytesChan chan uint64
|
||||
// shutdownC is to signal the muxerMetricsUpdater to shutdown
|
||||
abortChan <-chan struct{}
|
||||
|
||||
compBytesBefore, compBytesAfter *AtomicCounter
|
||||
}
|
||||
|
||||
type MuxerMetrics struct {
|
||||
RTT, RTTMin, RTTMax time.Duration
|
||||
ReceiveWindowAve, SendWindowAve float64
|
||||
ReceiveWindowMin, ReceiveWindowMax, SendWindowMin, SendWindowMax uint32
|
||||
InBoundRateCurr, InBoundRateMin, InBoundRateMax uint64
|
||||
OutBoundRateCurr, OutBoundRateMin, OutBoundRateMax uint64
|
||||
CompBytesBefore, CompBytesAfter *AtomicCounter
|
||||
}
|
||||
|
||||
func (m *MuxerMetrics) CompRateAve() float64 {
|
||||
if m.CompBytesBefore.Value() == 0 {
|
||||
return 1.
|
||||
}
|
||||
return float64(m.CompBytesAfter.Value()) / float64(m.CompBytesBefore.Value())
|
||||
}
|
||||
|
||||
type roundTripMeasurement struct {
|
||||
receiveTime, sendTime time.Time
|
||||
}
|
||||
|
||||
type rttData struct {
|
||||
rtt, rttMin, rttMax time.Duration
|
||||
lastMeasurementTime time.Time
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
type flowControlData struct {
|
||||
sum uint64
|
||||
min, max uint32
|
||||
queue *queue.Queue
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
type rate struct {
|
||||
curr uint64
|
||||
min, max uint64
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func newMuxMetricsUpdater(
|
||||
abortChan <-chan struct{},
|
||||
compBytesBefore, compBytesAfter *AtomicCounter,
|
||||
) muxMetricsUpdater {
|
||||
updateRTTChan := make(chan *roundTripMeasurement, 1)
|
||||
updateReceiveWindowChan := make(chan uint32, 1)
|
||||
updateSendWindowChan := make(chan uint32, 1)
|
||||
updateInBoundBytesChan := make(chan uint64)
|
||||
updateOutBoundBytesChan := make(chan uint64)
|
||||
|
||||
return &muxMetricsUpdaterImpl{
|
||||
rttData: newRTTData(),
|
||||
receiveWindowData: newFlowControlData(),
|
||||
sendWindowData: newFlowControlData(),
|
||||
inBoundRate: newRate(),
|
||||
outBoundRate: newRate(),
|
||||
updateRTTChan: updateRTTChan,
|
||||
updateReceiveWindowChan: updateReceiveWindowChan,
|
||||
updateSendWindowChan: updateSendWindowChan,
|
||||
updateInBoundBytesChan: updateInBoundBytesChan,
|
||||
updateOutBoundBytesChan: updateOutBoundBytesChan,
|
||||
abortChan: abortChan,
|
||||
compBytesBefore: compBytesBefore,
|
||||
compBytesAfter: compBytesAfter,
|
||||
}
|
||||
}
|
||||
|
||||
func (updater *muxMetricsUpdaterImpl) metrics() *MuxerMetrics {
|
||||
m := &MuxerMetrics{}
|
||||
m.RTT, m.RTTMin, m.RTTMax = updater.rttData.metrics()
|
||||
m.ReceiveWindowAve, m.ReceiveWindowMin, m.ReceiveWindowMax = updater.receiveWindowData.metrics()
|
||||
m.SendWindowAve, m.SendWindowMin, m.SendWindowMax = updater.sendWindowData.metrics()
|
||||
m.InBoundRateCurr, m.InBoundRateMin, m.InBoundRateMax = updater.inBoundRate.get()
|
||||
m.OutBoundRateCurr, m.OutBoundRateMin, m.OutBoundRateMax = updater.outBoundRate.get()
|
||||
m.CompBytesBefore, m.CompBytesAfter = updater.compBytesBefore, updater.compBytesAfter
|
||||
return m
|
||||
}
|
||||
|
||||
func (updater *muxMetricsUpdaterImpl) run(log *zerolog.Logger) error {
|
||||
defer log.Debug().Msg("mux - metrics: event loop finished")
|
||||
for {
|
||||
select {
|
||||
case <-updater.abortChan:
|
||||
log.Debug().Msgf("mux - metrics: Stopping mux metrics updater")
|
||||
return nil
|
||||
case roundTripMeasurement := <-updater.updateRTTChan:
|
||||
go updater.rttData.update(roundTripMeasurement)
|
||||
log.Debug().Msg("mux - metrics: Update rtt")
|
||||
case receiveWindow := <-updater.updateReceiveWindowChan:
|
||||
go updater.receiveWindowData.update(receiveWindow)
|
||||
log.Debug().Msg("mux - metrics: Update receive window")
|
||||
case sendWindow := <-updater.updateSendWindowChan:
|
||||
go updater.sendWindowData.update(sendWindow)
|
||||
log.Debug().Msg("mux - metrics: Update send window")
|
||||
case inBoundBytes := <-updater.updateInBoundBytesChan:
|
||||
// inBoundBytes is bytes/sec because the update interval is 1 sec
|
||||
go updater.inBoundRate.update(inBoundBytes)
|
||||
log.Debug().Msgf("mux - metrics: Inbound bytes %d", inBoundBytes)
|
||||
case outBoundBytes := <-updater.updateOutBoundBytesChan:
|
||||
// outBoundBytes is bytes/sec because the update interval is 1 sec
|
||||
go updater.outBoundRate.update(outBoundBytes)
|
||||
log.Debug().Msgf("mux - metrics: Outbound bytes %d", outBoundBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (updater *muxMetricsUpdaterImpl) updateRTT(rtt *roundTripMeasurement) {
|
||||
select {
|
||||
case updater.updateRTTChan <- rtt:
|
||||
case <-updater.abortChan:
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (updater *muxMetricsUpdaterImpl) updateReceiveWindow(receiveWindow uint32) {
|
||||
select {
|
||||
case updater.updateReceiveWindowChan <- receiveWindow:
|
||||
case <-updater.abortChan:
|
||||
}
|
||||
}
|
||||
|
||||
func (updater *muxMetricsUpdaterImpl) updateSendWindow(sendWindow uint32) {
|
||||
select {
|
||||
case updater.updateSendWindowChan <- sendWindow:
|
||||
case <-updater.abortChan:
|
||||
}
|
||||
}
|
||||
|
||||
func (updater *muxMetricsUpdaterImpl) updateInBoundBytes(inBoundBytes uint64) {
|
||||
select {
|
||||
case updater.updateInBoundBytesChan <- inBoundBytes:
|
||||
case <-updater.abortChan:
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (updater *muxMetricsUpdaterImpl) updateOutBoundBytes(outBoundBytes uint64) {
|
||||
select {
|
||||
case updater.updateOutBoundBytesChan <- outBoundBytes:
|
||||
case <-updater.abortChan:
|
||||
}
|
||||
}
|
||||
|
||||
func newRTTData() *rttData {
|
||||
return &rttData{}
|
||||
}
|
||||
|
||||
func (r *rttData) update(measurement *roundTripMeasurement) {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
// discard pings before lastMeasurementTime
|
||||
if r.lastMeasurementTime.After(measurement.sendTime) {
|
||||
return
|
||||
}
|
||||
r.lastMeasurementTime = measurement.sendTime
|
||||
r.rtt = measurement.receiveTime.Sub(measurement.sendTime)
|
||||
if r.rttMax < r.rtt {
|
||||
r.rttMax = r.rtt
|
||||
}
|
||||
if r.rttMin == 0 || r.rttMin > r.rtt {
|
||||
r.rttMin = r.rtt
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rttData) metrics() (rtt, rttMin, rttMax time.Duration) {
|
||||
r.lock.RLock()
|
||||
defer r.lock.RUnlock()
|
||||
return r.rtt, r.rttMin, r.rttMax
|
||||
}
|
||||
|
||||
func newFlowControlData() *flowControlData {
|
||||
return &flowControlData{queue: queue.New()}
|
||||
}
|
||||
|
||||
func (f *flowControlData) update(measurement uint32) {
|
||||
f.lock.Lock()
|
||||
defer f.lock.Unlock()
|
||||
var firstItem uint32
|
||||
// store new data into queue, remove oldest data if queue is full
|
||||
f.queue.Enqueue(measurement)
|
||||
if f.queue.Len() > dataPoints {
|
||||
// data type should always be uint32
|
||||
firstItem = f.queue.Dequeue().(uint32)
|
||||
}
|
||||
// if (measurement - firstItem) < 0, uint64(measurement - firstItem)
|
||||
// will overflow and become a large positive number
|
||||
f.sum += uint64(measurement)
|
||||
f.sum -= uint64(firstItem)
|
||||
if measurement > f.max {
|
||||
f.max = measurement
|
||||
}
|
||||
if f.min == 0 || measurement < f.min {
|
||||
f.min = measurement
|
||||
}
|
||||
}
|
||||
|
||||
// caller of ave() should acquire lock first
|
||||
func (f *flowControlData) ave() float64 {
|
||||
if f.queue.Len() == 0 {
|
||||
return 0
|
||||
}
|
||||
return float64(f.sum) / float64(f.queue.Len())
|
||||
}
|
||||
|
||||
func (f *flowControlData) metrics() (ave float64, min, max uint32) {
|
||||
f.lock.RLock()
|
||||
defer f.lock.RUnlock()
|
||||
return f.ave(), f.min, f.max
|
||||
}
|
||||
|
||||
func newRate() *rate {
|
||||
return &rate{}
|
||||
}
|
||||
|
||||
func (r *rate) update(measurement uint64) {
|
||||
r.lock.Lock()
|
||||
defer r.lock.Unlock()
|
||||
r.curr = measurement
|
||||
// if measurement is 0, then there is no incoming/outgoing connection, don't update min/max
|
||||
if r.curr == 0 {
|
||||
return
|
||||
}
|
||||
if measurement > r.max {
|
||||
r.max = measurement
|
||||
}
|
||||
if r.min == 0 || measurement < r.min {
|
||||
r.min = measurement
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rate) get() (curr, min, max uint64) {
|
||||
r.lock.RLock()
|
||||
defer r.lock.RUnlock()
|
||||
return r.curr, r.min, r.max
|
||||
}
|
|
@ -1,169 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func ave(sum uint64, len int) float64 {
|
||||
return float64(sum) / float64(len)
|
||||
}
|
||||
|
||||
func TestRTTUpdate(t *testing.T) {
|
||||
r := newRTTData()
|
||||
start := time.Now()
|
||||
// send at 0 ms, receive at 2 ms, RTT = 2ms
|
||||
m := &roundTripMeasurement{receiveTime: start.Add(2 * time.Millisecond), sendTime: start}
|
||||
r.update(m)
|
||||
assert.Equal(t, start, r.lastMeasurementTime)
|
||||
assert.Equal(t, 2*time.Millisecond, r.rtt)
|
||||
assert.Equal(t, 2*time.Millisecond, r.rttMin)
|
||||
assert.Equal(t, 2*time.Millisecond, r.rttMax)
|
||||
|
||||
// send at 3 ms, receive at 6 ms, RTT = 3ms
|
||||
m = &roundTripMeasurement{receiveTime: start.Add(6 * time.Millisecond), sendTime: start.Add(3 * time.Millisecond)}
|
||||
r.update(m)
|
||||
assert.Equal(t, start.Add(3*time.Millisecond), r.lastMeasurementTime)
|
||||
assert.Equal(t, 3*time.Millisecond, r.rtt)
|
||||
assert.Equal(t, 2*time.Millisecond, r.rttMin)
|
||||
assert.Equal(t, 3*time.Millisecond, r.rttMax)
|
||||
|
||||
// send at 7 ms, receive at 8 ms, RTT = 1ms
|
||||
m = &roundTripMeasurement{receiveTime: start.Add(8 * time.Millisecond), sendTime: start.Add(7 * time.Millisecond)}
|
||||
r.update(m)
|
||||
assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime)
|
||||
assert.Equal(t, 1*time.Millisecond, r.rtt)
|
||||
assert.Equal(t, 1*time.Millisecond, r.rttMin)
|
||||
assert.Equal(t, 3*time.Millisecond, r.rttMax)
|
||||
|
||||
// send at -4 ms, receive at 0 ms, RTT = 4ms, but this ping is before last measurement
|
||||
// so it will be discarded
|
||||
m = &roundTripMeasurement{receiveTime: start, sendTime: start.Add(-2 * time.Millisecond)}
|
||||
r.update(m)
|
||||
assert.Equal(t, start.Add(7*time.Millisecond), r.lastMeasurementTime)
|
||||
assert.Equal(t, 1*time.Millisecond, r.rtt)
|
||||
assert.Equal(t, 1*time.Millisecond, r.rttMin)
|
||||
assert.Equal(t, 3*time.Millisecond, r.rttMax)
|
||||
}
|
||||
|
||||
func TestFlowControlDataUpdate(t *testing.T) {
|
||||
f := newFlowControlData()
|
||||
assert.Equal(t, 0, f.queue.Len())
|
||||
assert.Equal(t, float64(0), f.ave())
|
||||
|
||||
var sum uint64
|
||||
min := maxWindowSize - dataPoints
|
||||
max := maxWindowSize
|
||||
for i := 1; i <= dataPoints; i++ {
|
||||
size := maxWindowSize - uint32(i)
|
||||
f.update(size)
|
||||
assert.Equal(t, max-uint32(1), f.max)
|
||||
assert.Equal(t, size, f.min)
|
||||
|
||||
assert.Equal(t, i, f.queue.Len())
|
||||
|
||||
sum += uint64(size)
|
||||
assert.Equal(t, sum, f.sum)
|
||||
assert.Equal(t, ave(sum, f.queue.Len()), f.ave())
|
||||
}
|
||||
|
||||
// queue is full, should start to dequeue first element
|
||||
for i := 1; i <= dataPoints; i++ {
|
||||
f.update(max)
|
||||
assert.Equal(t, max, f.max)
|
||||
assert.Equal(t, min, f.min)
|
||||
|
||||
assert.Equal(t, dataPoints, f.queue.Len())
|
||||
|
||||
sum += uint64(i)
|
||||
assert.Equal(t, sum, f.sum)
|
||||
assert.Equal(t, ave(sum, dataPoints), f.ave())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMuxMetricsUpdater(t *testing.T) {
|
||||
t.Skip("Inherently racy test due to muxMetricsUpdaterImpl.run()")
|
||||
errChan := make(chan error)
|
||||
abortChan := make(chan struct{})
|
||||
compBefore, compAfter := NewAtomicCounter(0), NewAtomicCounter(0)
|
||||
m := newMuxMetricsUpdater(abortChan, compBefore, compAfter)
|
||||
log := zerolog.Nop()
|
||||
|
||||
go func() {
|
||||
errChan <- m.run(&log)
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// mock muxReader
|
||||
readerStart := time.Now()
|
||||
rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart}
|
||||
m.updateRTT(rm)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
assert.Equal(t, 0, dataPoints%4,
|
||||
"dataPoints is not divisible by 4; this test should be adjusted accordingly")
|
||||
readerSend := readerStart.Add(time.Millisecond)
|
||||
for i := 1; i <= dataPoints/4; i++ {
|
||||
readerReceive := readerSend.Add(time.Duration(i) * time.Millisecond)
|
||||
rm := &roundTripMeasurement{receiveTime: readerReceive, sendTime: readerSend}
|
||||
m.updateRTT(rm)
|
||||
readerSend = readerReceive.Add(time.Millisecond)
|
||||
m.updateReceiveWindow(uint32(i))
|
||||
m.updateSendWindow(uint32(i))
|
||||
|
||||
m.updateInBoundBytes(uint64(i))
|
||||
}
|
||||
}()
|
||||
|
||||
// mock muxWriter
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
assert.Equal(t, 0, dataPoints%4,
|
||||
"dataPoints is not divisible by 4; this test should be adjusted accordingly")
|
||||
for j := dataPoints/4 + 1; j <= dataPoints/2; j++ {
|
||||
m.updateReceiveWindow(uint32(j))
|
||||
m.updateSendWindow(uint32(j))
|
||||
|
||||
// should always be discarded since the send time is before readerSend
|
||||
rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart.Add(-time.Duration(j*dataPoints) * time.Millisecond)}
|
||||
m.updateRTT(rm)
|
||||
|
||||
m.updateOutBoundBytes(uint64(j))
|
||||
}
|
||||
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
metrics := m.metrics()
|
||||
points := dataPoints / 2
|
||||
assert.Equal(t, time.Millisecond, metrics.RTTMin)
|
||||
assert.Equal(t, time.Duration(dataPoints/4)*time.Millisecond, metrics.RTTMax)
|
||||
|
||||
// sum(1..i) = i*(i+1)/2, ave(1..i) = i*(i+1)/2/i = (i+1)/2
|
||||
assert.Equal(t, float64(points+1)/float64(2), metrics.ReceiveWindowAve)
|
||||
assert.Equal(t, uint32(1), metrics.ReceiveWindowMin)
|
||||
assert.Equal(t, uint32(points), metrics.ReceiveWindowMax)
|
||||
|
||||
assert.Equal(t, float64(points+1)/float64(2), metrics.SendWindowAve)
|
||||
assert.Equal(t, uint32(1), metrics.SendWindowMin)
|
||||
assert.Equal(t, uint32(points), metrics.SendWindowMax)
|
||||
|
||||
assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateCurr)
|
||||
assert.Equal(t, uint64(1), metrics.InBoundRateMin)
|
||||
assert.Equal(t, uint64(dataPoints/4), metrics.InBoundRateMax)
|
||||
|
||||
assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateCurr)
|
||||
assert.Equal(t, uint64(dataPoints/4+1), metrics.OutBoundRateMin)
|
||||
assert.Equal(t, uint64(dataPoints/2), metrics.OutBoundRateMax)
|
||||
|
||||
close(abortChan)
|
||||
assert.Nil(t, <-errChan)
|
||||
close(errChan)
|
||||
|
||||
}
|
|
@ -1,508 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type MuxReader struct {
|
||||
// f is used to read HTTP2 frames.
|
||||
f *http2.Framer
|
||||
// handler provides a callback to receive new streams. if nil, new streams cannot be accepted.
|
||||
handler MuxedStreamHandler
|
||||
// streams tracks currently-open streams.
|
||||
streams *activeStreamMap
|
||||
// readyList is used to signal writable streams.
|
||||
readyList *ReadyList
|
||||
// streamErrors lets us report stream errors to the MuxWriter.
|
||||
streamErrors *StreamErrorMap
|
||||
// goAwayChan is used to tell the writer to send a GOAWAY message.
|
||||
goAwayChan chan<- http2.ErrCode
|
||||
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
|
||||
abortChan <-chan struct{}
|
||||
// pingTimestamp is an atomic value containing the latest received ping timestamp.
|
||||
pingTimestamp *PingTimestamp
|
||||
// connActive is used to signal to the writer that something happened on the connection.
|
||||
// This is used to clear idle timeout disconnection deadlines.
|
||||
connActive Signal
|
||||
// The initial value for the send and receive window of a new stream.
|
||||
initialStreamWindow uint32
|
||||
// The max value for the send window of a stream.
|
||||
streamWindowMax uint32
|
||||
// The max size for the write buffer of a stream
|
||||
streamWriteBufferMaxLen int
|
||||
// r is a reference to the underlying connection used when shutting down.
|
||||
r io.Closer
|
||||
// metricsUpdater is used to report metrics
|
||||
metricsUpdater muxMetricsUpdater
|
||||
// bytesRead is the amount of bytes read from data frames since the last time we called metricsUpdater.updateInBoundBytes()
|
||||
bytesRead *AtomicCounter
|
||||
// dictionaries holds the h2 cross-stream compression dictionaries
|
||||
dictionaries h2Dictionaries
|
||||
}
|
||||
|
||||
// Shutdown blocks new streams from being created.
|
||||
// It returns a channel that is closed once the last stream has closed.
|
||||
func (r *MuxReader) Shutdown() <-chan struct{} {
|
||||
done, alreadyInProgress := r.streams.Shutdown()
|
||||
if alreadyInProgress {
|
||||
return done
|
||||
}
|
||||
r.sendGoAway(http2.ErrCodeNo)
|
||||
go func() {
|
||||
// close reader side when last stream ends; this will cause the writer to abort
|
||||
<-done
|
||||
r.r.Close()
|
||||
}()
|
||||
return done
|
||||
}
|
||||
|
||||
func (r *MuxReader) run(log *zerolog.Logger) error {
|
||||
defer log.Debug().Msg("mux - read: event loop finished")
|
||||
|
||||
// routine to periodically update bytesRead
|
||||
go func() {
|
||||
ticker := time.NewTicker(updateFreq)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-r.abortChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.metricsUpdater.updateInBoundBytes(r.bytesRead.Count())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
frame, err := r.f.ReadFrame()
|
||||
if err != nil {
|
||||
errorString := fmt.Sprintf("mux - read: %s", err)
|
||||
if errorDetail := r.f.ErrorDetail(); errorDetail != nil {
|
||||
errorString = fmt.Sprintf("%s: errorDetail: %s", errorString, errorDetail)
|
||||
}
|
||||
switch e := err.(type) {
|
||||
case http2.StreamError:
|
||||
log.Info().Msgf("%s: stream error", errorString)
|
||||
// Ideally we wouldn't return here, since that aborts the muxer.
|
||||
// We should communicate the error to the relevant MuxedStream
|
||||
// data structure, so that callers of MuxedStream.Read() and
|
||||
// MuxedStream.Write() would see it. Then we could `continue`
|
||||
// and keep the muxer going.
|
||||
return r.streamError(e.StreamID, e.Code)
|
||||
case http2.ConnectionError:
|
||||
log.Info().Msgf("%s: stream error", errorString)
|
||||
return r.connectionError(err)
|
||||
default:
|
||||
if isConnectionClosedError(err) {
|
||||
if r.streams.Len() == 0 {
|
||||
// don't log the error here -- that would just be extra noise
|
||||
log.Debug().Msg("mux - read: shutting down")
|
||||
return nil
|
||||
}
|
||||
log.Info().Msgf("%s: connection closed unexpectedly", errorString)
|
||||
return err
|
||||
} else {
|
||||
log.Info().Msgf("%s: frame read error", errorString)
|
||||
return r.connectionError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
r.connActive.Signal()
|
||||
log.Debug().Msgf("mux - read: read frame: data %v", frame)
|
||||
switch f := frame.(type) {
|
||||
case *http2.DataFrame:
|
||||
err = r.receiveFrameData(f, log)
|
||||
case *http2.MetaHeadersFrame:
|
||||
err = r.receiveHeaderData(f)
|
||||
case *http2.RSTStreamFrame:
|
||||
streamID := f.Header().StreamID
|
||||
if streamID == 0 {
|
||||
return ErrInvalidStream
|
||||
}
|
||||
if stream, ok := r.streams.Get(streamID); ok {
|
||||
stream.Close()
|
||||
}
|
||||
r.streams.Delete(streamID)
|
||||
case *http2.PingFrame:
|
||||
r.receivePingData(f)
|
||||
case *http2.GoAwayFrame:
|
||||
err = r.receiveGoAway(f)
|
||||
// The receiver of a flow-controlled frame sends a WINDOW_UPDATE frame as it
|
||||
// consumes data and frees up space in flow-control windows
|
||||
case *http2.WindowUpdateFrame:
|
||||
err = r.updateStreamWindow(f)
|
||||
case *http2.UnknownFrame:
|
||||
switch f.Header().Type {
|
||||
case FrameUseDictionary:
|
||||
err = r.receiveUseDictionary(f)
|
||||
case FrameSetDictionary:
|
||||
err = r.receiveSetDictionary(f)
|
||||
default:
|
||||
err = ErrUnexpectedFrameType
|
||||
}
|
||||
default:
|
||||
err = ErrUnexpectedFrameType
|
||||
}
|
||||
if err != nil {
|
||||
log.Debug().Msgf("mux - read: read error: data %v", frame)
|
||||
return r.connectionError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MuxReader) newMuxedStream(streamID uint32) *MuxedStream {
|
||||
return &MuxedStream{
|
||||
streamID: streamID,
|
||||
readBuffer: NewSharedBuffer(),
|
||||
writeBuffer: &bytes.Buffer{},
|
||||
writeBufferMaxLen: r.streamWriteBufferMaxLen,
|
||||
writeBufferHasSpace: make(chan struct{}, 1),
|
||||
receiveWindow: r.initialStreamWindow,
|
||||
receiveWindowCurrentMax: r.initialStreamWindow,
|
||||
receiveWindowMax: r.streamWindowMax,
|
||||
sendWindow: r.initialStreamWindow,
|
||||
readyList: r.readyList,
|
||||
dictionaries: r.dictionaries,
|
||||
}
|
||||
}
|
||||
|
||||
// getStreamForFrame returns a stream if valid, or an error describing why the stream could not be returned.
|
||||
func (r *MuxReader) getStreamForFrame(frame http2.Frame) (*MuxedStream, error) {
|
||||
sid := frame.Header().StreamID
|
||||
if sid == 0 {
|
||||
return nil, ErrUnexpectedFrameType
|
||||
}
|
||||
if stream, ok := r.streams.Get(sid); ok {
|
||||
return stream, nil
|
||||
}
|
||||
if r.streams.IsLocalStreamID(sid) {
|
||||
// no stream available, but no error
|
||||
return nil, ErrClosedStream
|
||||
}
|
||||
if sid < r.streams.LastPeerStreamID() {
|
||||
// no stream available, stream closed error
|
||||
return nil, ErrClosedStream
|
||||
}
|
||||
return nil, ErrUnknownStream
|
||||
}
|
||||
|
||||
func (r *MuxReader) defaultStreamErrorHandler(err error, header http2.FrameHeader) error {
|
||||
if header.Flags.Has(http2.FlagHeadersEndStream) {
|
||||
return nil
|
||||
} else if err == ErrUnknownStream || err == ErrClosedStream {
|
||||
return r.streamError(header.StreamID, http2.ErrCodeStreamClosed)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Receives header frames from a stream. A non-nil error is a connection error.
|
||||
func (r *MuxReader) receiveHeaderData(frame *http2.MetaHeadersFrame) error {
|
||||
var stream *MuxedStream
|
||||
sid := frame.Header().StreamID
|
||||
if sid == 0 {
|
||||
return ErrUnexpectedFrameType
|
||||
}
|
||||
newStream := r.streams.IsPeerStreamID(sid)
|
||||
if newStream {
|
||||
// header request
|
||||
// TODO support trailers (if stream exists)
|
||||
ok, err := r.streams.AcquirePeerID(sid)
|
||||
if !ok {
|
||||
// ignore new streams while shutting down
|
||||
return r.streamError(sid, err)
|
||||
}
|
||||
stream = r.newMuxedStream(sid)
|
||||
// Set stream. Returns false if a stream already existed with that ID or we are shutting down, return false.
|
||||
if !r.streams.Set(stream) {
|
||||
// got HEADERS frame for an existing stream
|
||||
// TODO support trailers
|
||||
return r.streamError(sid, http2.ErrCodeInternal)
|
||||
}
|
||||
} else {
|
||||
// header response
|
||||
var err error
|
||||
if stream, err = r.getStreamForFrame(frame); err != nil {
|
||||
return r.defaultStreamErrorHandler(err, frame.Header())
|
||||
}
|
||||
}
|
||||
headers := make([]Header, 0, len(frame.Fields))
|
||||
for _, header := range frame.Fields {
|
||||
switch header.Name {
|
||||
case ":method":
|
||||
stream.method = header.Value
|
||||
case ":path":
|
||||
u, err := url.Parse(header.Value)
|
||||
if err == nil {
|
||||
stream.path = u.Path
|
||||
}
|
||||
case "accept-encoding":
|
||||
// remove accept-encoding if dictionaries are enabled
|
||||
if r.dictionaries.write != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
headers = append(headers, Header{Name: header.Name, Value: header.Value})
|
||||
}
|
||||
stream.Headers = headers
|
||||
if frame.Header().Flags.Has(http2.FlagHeadersEndStream) {
|
||||
stream.receiveEOF()
|
||||
return nil
|
||||
}
|
||||
if newStream {
|
||||
go r.handleStream(stream)
|
||||
} else {
|
||||
close(stream.responseHeadersReceived)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MuxReader) handleStream(stream *MuxedStream) {
|
||||
defer stream.Close()
|
||||
r.handler.ServeStream(stream)
|
||||
}
|
||||
|
||||
// Receives a data frame from a stream. A non-nil error is a connection error.
|
||||
func (r *MuxReader) receiveFrameData(frame *http2.DataFrame, log *zerolog.Logger) error {
|
||||
stream, err := r.getStreamForFrame(frame)
|
||||
if err != nil {
|
||||
return r.defaultStreamErrorHandler(err, frame.Header())
|
||||
}
|
||||
data := frame.Data()
|
||||
if len(data) > 0 {
|
||||
n, err := stream.readBuffer.Write(data)
|
||||
if err != nil {
|
||||
return r.streamError(stream.streamID, http2.ErrCodeInternal)
|
||||
}
|
||||
r.bytesRead.IncrementBy(uint64(n))
|
||||
}
|
||||
if frame.Header().Flags.Has(http2.FlagDataEndStream) {
|
||||
if stream.receiveEOF() {
|
||||
r.streams.Delete(stream.streamID)
|
||||
log.Debug().Msgf("mux - read: stream closed: streamID: %d", frame.Header().StreamID)
|
||||
} else {
|
||||
log.Debug().Msgf("mux - read: shutdown receive side: streamID: %d", frame.Header().StreamID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if !stream.consumeReceiveWindow(uint32(len(data))) {
|
||||
return r.streamError(stream.streamID, http2.ErrCodeFlowControl)
|
||||
}
|
||||
r.metricsUpdater.updateReceiveWindow(stream.getReceiveWindow())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Receive a PING from the peer. Update RTT and send/receive window metrics if it's an ACK.
|
||||
func (r *MuxReader) receivePingData(frame *http2.PingFrame) {
|
||||
ts := int64(binary.LittleEndian.Uint64(frame.Data[:]))
|
||||
if !frame.IsAck() {
|
||||
r.pingTimestamp.Set(ts)
|
||||
return
|
||||
}
|
||||
|
||||
// Update the computed RTT aggregations with a new measurement.
|
||||
// `ts` is the time that the probe was sent.
|
||||
// We assume that `time.Now()` is the time we received that probe.
|
||||
r.metricsUpdater.updateRTT(&roundTripMeasurement{
|
||||
receiveTime: time.Now(),
|
||||
sendTime: time.Unix(0, ts),
|
||||
})
|
||||
}
|
||||
|
||||
// Receive a GOAWAY from the peer. Gracefully shut down our connection.
|
||||
func (r *MuxReader) receiveGoAway(frame *http2.GoAwayFrame) error {
|
||||
r.Shutdown()
|
||||
// Close all streams above the last processed stream
|
||||
lastStream := r.streams.LastLocalStreamID()
|
||||
for i := frame.LastStreamID + 2; i <= lastStream; i++ {
|
||||
if stream, ok := r.streams.Get(i); ok {
|
||||
stream.Close()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Receive a USE_DICTIONARY from the peer. Setup dictionary for stream.
|
||||
func (r *MuxReader) receiveUseDictionary(frame *http2.UnknownFrame) error {
|
||||
payload := frame.Payload()
|
||||
streamID := frame.StreamID
|
||||
|
||||
// Check frame is formatted properly
|
||||
if len(payload) != 1 {
|
||||
return r.streamError(streamID, http2.ErrCodeProtocol)
|
||||
}
|
||||
|
||||
stream, err := r.getStreamForFrame(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if stream.receivedUseDict == true || stream.dictionaries.read == nil {
|
||||
return r.streamError(streamID, http2.ErrCodeInternal)
|
||||
}
|
||||
|
||||
stream.receivedUseDict = true
|
||||
dictID := payload[0]
|
||||
|
||||
dictReader := stream.dictionaries.read.newReader(stream.readBuffer.(*SharedBuffer), dictID)
|
||||
if dictReader == nil {
|
||||
return r.streamError(streamID, http2.ErrCodeInternal)
|
||||
}
|
||||
|
||||
stream.readBufferLock.Lock()
|
||||
stream.readBuffer = dictReader
|
||||
stream.readBufferLock.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Receive a SET_DICTIONARY from the peer. Update dictionaries accordingly.
|
||||
func (r *MuxReader) receiveSetDictionary(frame *http2.UnknownFrame) (err error) {
|
||||
|
||||
payload := frame.Payload()
|
||||
flags := frame.Flags
|
||||
|
||||
stream, err := r.getStreamForFrame(frame)
|
||||
if err != nil && err != ErrClosedStream {
|
||||
return err
|
||||
}
|
||||
reader, ok := stream.readBuffer.(*h2DictionaryReader)
|
||||
if !ok {
|
||||
return r.streamError(frame.StreamID, http2.ErrCodeProtocol)
|
||||
}
|
||||
|
||||
// A SetDictionary frame consists of several
|
||||
// Dictionary-Entries that specify how existing dictionaries
|
||||
// are to be updated using the current stream data
|
||||
// +---------------+---------------+
|
||||
// | Dictionary-Entry (+) ...
|
||||
// +---------------+---------------+
|
||||
|
||||
for {
|
||||
// Each Dictionary-Entry is formatted as follows:
|
||||
// +-------------------------------+
|
||||
// | Dictionary-ID (8) |
|
||||
// +---+---------------------------+
|
||||
// | P | Size (7+) |
|
||||
// +---+---------------------------+
|
||||
// | E?| D?| Truncate? (6+) |
|
||||
// +---+---------------------------+
|
||||
// | Offset? (8+) |
|
||||
// +-------------------------------+
|
||||
|
||||
var size, truncate, offset uint64
|
||||
var p, e, d bool
|
||||
|
||||
// Parse a single Dictionary-Entry
|
||||
if len(payload) < 2 { // Must have at least id and size
|
||||
return MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
|
||||
}
|
||||
|
||||
dictID := uint8(payload[0])
|
||||
p = (uint8(payload[1]) >> 7) == 1
|
||||
payload, size, err = http2ReadVarInt(7, payload[1:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if flags.Has(FlagSetDictionaryAppend) {
|
||||
// Presence of FlagSetDictionaryAppend means we expect e, d and truncate
|
||||
if len(payload) < 1 {
|
||||
return MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
|
||||
}
|
||||
e = (uint8(payload[0]) >> 7) == 1
|
||||
d = (uint8((payload[0])>>6) & 1) == 1
|
||||
payload, truncate, err = http2ReadVarInt(6, payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if flags.Has(FlagSetDictionaryOffset) {
|
||||
// Presence of FlagSetDictionaryOffset means we expect offset
|
||||
if len(payload) < 1 {
|
||||
return MuxerStreamError{"unexpected EOF", http2.ErrCodeProtocol}
|
||||
}
|
||||
payload, offset, err = http2ReadVarInt(8, payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
setdict := setDictRequest{streamID: stream.streamID,
|
||||
dictID: dictID,
|
||||
dictSZ: size,
|
||||
truncate: truncate,
|
||||
offset: offset,
|
||||
P: p,
|
||||
E: e,
|
||||
D: d}
|
||||
|
||||
// Find the right dictionary
|
||||
dict, err := r.dictionaries.read.getDictByID(dictID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register a dictionary update order for the dictionary and reader
|
||||
updateEntry := &dictUpdate{reader: reader, dictionary: dict, s: setdict}
|
||||
dict.queue = append(dict.queue, updateEntry)
|
||||
reader.queue = append(reader.queue, updateEntry)
|
||||
// End of frame
|
||||
if len(payload) == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Receives header frames from a stream. A non-nil error is a connection error.
|
||||
func (r *MuxReader) updateStreamWindow(frame *http2.WindowUpdateFrame) error {
|
||||
stream, err := r.getStreamForFrame(frame)
|
||||
if err != nil && err != ErrUnknownStream && err != ErrClosedStream {
|
||||
return err
|
||||
}
|
||||
if stream == nil {
|
||||
// ignore window updates on closed streams
|
||||
return nil
|
||||
}
|
||||
stream.replenishSendWindow(frame.Increment)
|
||||
r.metricsUpdater.updateSendWindow(stream.getSendWindow())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Raise a stream processing error, closing the stream. Runs on the write thread.
|
||||
func (r *MuxReader) streamError(streamID uint32, e http2.ErrCode) error {
|
||||
r.streamErrors.RaiseError(streamID, e)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MuxReader) connectionError(err error) error {
|
||||
http2Code := http2.ErrCodeInternal
|
||||
switch e := err.(type) {
|
||||
case http2.ConnectionError:
|
||||
http2Code = http2.ErrCode(e)
|
||||
case MuxerProtocolError:
|
||||
http2Code = e.h2code
|
||||
}
|
||||
r.sendGoAway(http2Code)
|
||||
return err
|
||||
}
|
||||
|
||||
// Instruct the writer to send a GOAWAY message if possible. This may fail in
|
||||
// the case where an existing GOAWAY message is in flight or the writer event
|
||||
// loop already ended.
|
||||
func (r *MuxReader) sendGoAway(errCode http2.ErrCode) {
|
||||
select {
|
||||
case r.goAwayChan <- errCode:
|
||||
default:
|
||||
}
|
||||
}
|
|
@ -1,88 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
methodHeader = Header{
|
||||
Name: ":method",
|
||||
Value: "GET",
|
||||
}
|
||||
schemeHeader = Header{
|
||||
Name: ":scheme",
|
||||
Value: "https",
|
||||
}
|
||||
pathHeader = Header{
|
||||
Name: ":path",
|
||||
Value: "/api/tunnels",
|
||||
}
|
||||
respStatusHeader = Header{
|
||||
Name: ":status",
|
||||
Value: "200",
|
||||
}
|
||||
)
|
||||
|
||||
type mockOriginStreamHandler struct {
|
||||
stream *MuxedStream
|
||||
}
|
||||
|
||||
func (mosh *mockOriginStreamHandler) ServeStream(stream *MuxedStream) error {
|
||||
mosh.stream = stream
|
||||
// Echo tunnel hostname in header
|
||||
stream.WriteHeaders([]Header{respStatusHeader})
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertOpenStreamSucceed(t *testing.T, stream *MuxedStream, err error) {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, stream.Headers, 1)
|
||||
assert.Equal(t, respStatusHeader, stream.Headers[0])
|
||||
}
|
||||
|
||||
func TestMissingHeaders(t *testing.T) {
|
||||
originHandler := &mockOriginStreamHandler{}
|
||||
muxPair := NewDefaultMuxerPair(t, t.Name(), originHandler.ServeStream)
|
||||
muxPair.Serve(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
reqHeaders := []Header{
|
||||
{
|
||||
Name: "content-type",
|
||||
Value: "application/json",
|
||||
},
|
||||
}
|
||||
|
||||
stream, err := muxPair.EdgeMux.OpenStream(ctx, reqHeaders, nil)
|
||||
assertOpenStreamSucceed(t, stream, err)
|
||||
|
||||
assert.Empty(t, originHandler.stream.method)
|
||||
assert.Empty(t, originHandler.stream.path)
|
||||
}
|
||||
|
||||
func TestReceiveHeaderData(t *testing.T) {
|
||||
originHandler := &mockOriginStreamHandler{}
|
||||
muxPair := NewDefaultMuxerPair(t, t.Name(), originHandler.ServeStream)
|
||||
muxPair.Serve(t)
|
||||
|
||||
reqHeaders := []Header{
|
||||
methodHeader,
|
||||
schemeHeader,
|
||||
pathHeader,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
stream, err := muxPair.EdgeMux.OpenStream(ctx, reqHeaders, nil)
|
||||
assertOpenStreamSucceed(t, stream, err)
|
||||
|
||||
assert.Equal(t, methodHeader.Value, originHandler.stream.method)
|
||||
assert.Equal(t, pathHeader.Value, originHandler.stream.path)
|
||||
}
|
|
@ -1,311 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
type MuxWriter struct {
|
||||
// f is used to write HTTP2 frames.
|
||||
f *http2.Framer
|
||||
// streams tracks currently-open streams.
|
||||
streams *activeStreamMap
|
||||
// streamErrors receives stream errors raised by the MuxReader.
|
||||
streamErrors *StreamErrorMap
|
||||
// readyStreamChan is used to multiplex writable streams onto the single connection.
|
||||
// When a stream becomes writable its ID is sent on this channel.
|
||||
readyStreamChan <-chan uint32
|
||||
// newStreamChan is used to create new streams with a given set of headers.
|
||||
newStreamChan <-chan MuxedStreamRequest
|
||||
// goAwayChan is used to send a single GOAWAY message to the peer. The element received
|
||||
// is the HTTP/2 error code to send.
|
||||
goAwayChan <-chan http2.ErrCode
|
||||
// abortChan is used when shutting down ungracefully. When this becomes readable, all activity should stop.
|
||||
abortChan <-chan struct{}
|
||||
// pingTimestamp is an atomic value containing the latest received ping timestamp.
|
||||
pingTimestamp *PingTimestamp
|
||||
// A timer used to measure idle connection time. Reset after sending data.
|
||||
idleTimer *IdleTimer
|
||||
// connActiveChan receives a signal that the connection received some (read) activity.
|
||||
connActiveChan <-chan struct{}
|
||||
// Maximum size of all frames that can be sent on this connection.
|
||||
maxFrameSize uint32
|
||||
// headerEncoder is the stateful header encoder for this connection
|
||||
headerEncoder *hpack.Encoder
|
||||
// headerBuffer is the temporary buffer used by headerEncoder.
|
||||
headerBuffer bytes.Buffer
|
||||
|
||||
// metricsUpdater is used to report metrics
|
||||
metricsUpdater muxMetricsUpdater
|
||||
// bytesWrote is the amount of bytes written to data frames since the last time we called metricsUpdater.updateOutBoundBytes()
|
||||
bytesWrote *AtomicCounter
|
||||
|
||||
useDictChan <-chan useDictRequest
|
||||
}
|
||||
|
||||
type MuxedStreamRequest struct {
|
||||
stream *MuxedStream
|
||||
body io.Reader
|
||||
}
|
||||
|
||||
func NewMuxedStreamRequest(stream *MuxedStream, body io.Reader) MuxedStreamRequest {
|
||||
return MuxedStreamRequest{
|
||||
stream: stream,
|
||||
body: body,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MuxedStreamRequest) flushBody() {
|
||||
io.Copy(r.stream, r.body)
|
||||
r.stream.CloseWrite()
|
||||
}
|
||||
|
||||
func tsToPingData(ts int64) [8]byte {
|
||||
pingData := [8]byte{}
|
||||
binary.LittleEndian.PutUint64(pingData[:], uint64(ts))
|
||||
return pingData
|
||||
}
|
||||
|
||||
func (w *MuxWriter) run(log *zerolog.Logger) error {
|
||||
defer log.Debug().Msg("mux - write: event loop finished")
|
||||
|
||||
// routine to periodically communicate bytesWrote
|
||||
go func() {
|
||||
ticker := time.NewTicker(updateFreq)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-w.abortChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.metricsUpdater.updateOutBoundBytes(w.bytesWrote.Count())
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.abortChan:
|
||||
log.Debug().Msg("mux - write: aborting writer thread")
|
||||
return nil
|
||||
case errCode := <-w.goAwayChan:
|
||||
log.Debug().Msgf("mux - write: sending GOAWAY code %v", errCode)
|
||||
err := w.f.WriteGoAway(w.streams.LastPeerStreamID(), errCode, []byte{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case <-w.pingTimestamp.GetUpdateChan():
|
||||
log.Debug().Msg("mux - write: sending PING ACK")
|
||||
err := w.f.WritePing(true, tsToPingData(w.pingTimestamp.Get()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case <-w.idleTimer.C:
|
||||
if !w.idleTimer.Retry() {
|
||||
return ErrConnectionDropped
|
||||
}
|
||||
log.Debug().Msg("mux - write: sending PING")
|
||||
err := w.f.WritePing(false, tsToPingData(time.Now().UnixNano()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.ResetTimer()
|
||||
case <-w.connActiveChan:
|
||||
w.idleTimer.MarkActive()
|
||||
case <-w.streamErrors.GetSignalChan():
|
||||
for streamID, errCode := range w.streamErrors.GetErrors() {
|
||||
log.Debug().Msgf("mux - write: resetting stream with code: %v streamID: %d", errCode, streamID)
|
||||
err := w.f.WriteRSTStream(streamID, errCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case streamRequest := <-w.newStreamChan:
|
||||
streamID := w.streams.AcquireLocalID()
|
||||
streamRequest.stream.streamID = streamID
|
||||
if !w.streams.Set(streamRequest.stream) {
|
||||
// Race between OpenStream and Shutdown, and Shutdown won. Let Shutdown (and the eventual abort) take
|
||||
// care of this stream. Ideally we'd pass the error directly to the stream object somehow so the
|
||||
// caller can be unblocked sooner, but the value of that optimisation is minimal for most of the
|
||||
// reasons why you'd call Shutdown anyway.
|
||||
continue
|
||||
}
|
||||
if streamRequest.body != nil {
|
||||
go streamRequest.flushBody()
|
||||
}
|
||||
err := w.writeStreamData(streamRequest.stream, log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case streamID := <-w.readyStreamChan:
|
||||
stream, ok := w.streams.Get(streamID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
err := w.writeStreamData(stream, log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
case useDict := <-w.useDictChan:
|
||||
err := w.writeUseDictionary(useDict)
|
||||
if err != nil {
|
||||
log.Error().Msgf("mux - write: error writing use dictionary: %s", err)
|
||||
return err
|
||||
}
|
||||
w.idleTimer.MarkActive()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *MuxWriter) writeStreamData(stream *MuxedStream, log *zerolog.Logger) error {
|
||||
log.Debug().Msgf("mux - write: writable: streamID: %d", stream.streamID)
|
||||
chunk := stream.getChunk()
|
||||
w.metricsUpdater.updateReceiveWindow(stream.getReceiveWindow())
|
||||
w.metricsUpdater.updateSendWindow(stream.getSendWindow())
|
||||
if chunk.sendHeadersFrame() {
|
||||
err := w.writeHeaders(chunk.streamID, chunk.headers)
|
||||
if err != nil {
|
||||
log.Error().Msgf("mux - write: error writing headers: %s: streamID: %d", err, stream.streamID)
|
||||
return err
|
||||
}
|
||||
log.Debug().Msgf("mux - write: output headers: streamID: %d", stream.streamID)
|
||||
}
|
||||
|
||||
if chunk.sendWindowUpdateFrame() {
|
||||
// Send a WINDOW_UPDATE frame to update our receive window.
|
||||
// If the Stream ID is zero, the window update applies to the connection as a whole
|
||||
// RFC7540 section-6.9.1 "A receiver that receives a flow-controlled frame MUST
|
||||
// always account for its contribution against the connection flow-control
|
||||
// window, unless the receiver treats this as a connection error"
|
||||
err := w.f.WriteWindowUpdate(chunk.streamID, chunk.windowUpdate)
|
||||
if err != nil {
|
||||
log.Error().Msgf("mux - write: error writing window update: %s: streamID: %d", err, stream.streamID)
|
||||
return err
|
||||
}
|
||||
log.Debug().Msgf("mux - write: increment receive window by %d streamID: %d", chunk.windowUpdate, stream.streamID)
|
||||
}
|
||||
|
||||
for chunk.sendDataFrame() {
|
||||
payload, sentEOF := chunk.nextDataFrame(int(w.maxFrameSize))
|
||||
err := w.f.WriteData(chunk.streamID, sentEOF, payload)
|
||||
if err != nil {
|
||||
log.Error().Msgf("mux - write: error writing data: %s: streamID: %d", err, stream.streamID)
|
||||
return err
|
||||
}
|
||||
// update the amount of data wrote
|
||||
w.bytesWrote.IncrementBy(uint64(len(payload)))
|
||||
log.Debug().Msgf("mux - write: output data: %d: streamID: %d", len(payload), stream.streamID)
|
||||
|
||||
if sentEOF {
|
||||
if stream.readBuffer.Closed() {
|
||||
// transition into closed state
|
||||
if !stream.gotReceiveEOF() {
|
||||
// the peer may send data that we no longer want to receive. Force them into the
|
||||
// closed state.
|
||||
log.Debug().Msgf("mux - write: resetting stream: streamID: %d", stream.streamID)
|
||||
w.f.WriteRSTStream(chunk.streamID, http2.ErrCodeNo)
|
||||
} else {
|
||||
// Half-open stream transitioned into closed
|
||||
log.Debug().Msgf("mux - write: closing stream: streamID: %d", stream.streamID)
|
||||
}
|
||||
w.streams.Delete(chunk.streamID)
|
||||
} else {
|
||||
log.Debug().Msgf("mux - write: closing stream write side: streamID: %d", stream.streamID)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *MuxWriter) encodeHeaders(headers []Header) ([]byte, error) {
|
||||
w.headerBuffer.Reset()
|
||||
for _, header := range headers {
|
||||
err := w.headerEncoder.WriteField(hpack.HeaderField{
|
||||
Name: header.Name,
|
||||
Value: header.Value,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return w.headerBuffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// writeHeaders writes a block of encoded headers, splitting it into multiple frames if necessary.
|
||||
func (w *MuxWriter) writeHeaders(streamID uint32, headers []Header) error {
|
||||
encodedHeaders, err := w.encodeHeaders(headers)
|
||||
if err != nil || len(encodedHeaders) == 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
blockSize := int(w.maxFrameSize)
|
||||
// CONTINUATION is unnecessary; the headers fit within the blockSize
|
||||
if len(encodedHeaders) < blockSize {
|
||||
return w.f.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: streamID,
|
||||
EndHeaders: true,
|
||||
BlockFragment: encodedHeaders,
|
||||
})
|
||||
}
|
||||
|
||||
choppedHeaders := chopEncodedHeaders(encodedHeaders, blockSize)
|
||||
// len(choppedHeaders) is at least 2
|
||||
if err := w.f.WriteHeaders(http2.HeadersFrameParam{StreamID: streamID, EndHeaders: false, BlockFragment: choppedHeaders[0]}); err != nil {
|
||||
return err
|
||||
}
|
||||
for i := 1; i < len(choppedHeaders)-1; i++ {
|
||||
if err := w.f.WriteContinuation(streamID, false, choppedHeaders[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := w.f.WriteContinuation(streamID, true, choppedHeaders[len(choppedHeaders)-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Partition a slice of bytes into `len(slice) / blockSize` slices of length `blockSize`
|
||||
func chopEncodedHeaders(headers []byte, chunkSize int) [][]byte {
|
||||
var divided [][]byte
|
||||
|
||||
for i := 0; i < len(headers); i += chunkSize {
|
||||
end := i + chunkSize
|
||||
|
||||
if end > len(headers) {
|
||||
end = len(headers)
|
||||
}
|
||||
|
||||
divided = append(divided, headers[i:end])
|
||||
}
|
||||
|
||||
return divided
|
||||
}
|
||||
|
||||
func (w *MuxWriter) writeUseDictionary(dictRequest useDictRequest) error {
|
||||
err := w.f.WriteRawFrame(FrameUseDictionary, 0, dictRequest.streamID, []byte{byte(dictRequest.dictID)})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := make([]byte, 0, 64)
|
||||
for _, set := range dictRequest.setDict {
|
||||
payload = append(payload, byte(set.dictID))
|
||||
payload = appendVarInt(payload, 7, uint64(set.dictSZ))
|
||||
payload = append(payload, 0x80) // E = 1, D = 0, Truncate = 0
|
||||
}
|
||||
|
||||
err = w.f.WriteRawFrame(FrameSetDictionary, FlagSetDictionaryAppend, dictRequest.streamID, payload)
|
||||
return err
|
||||
}
|
|
@ -1,26 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestChopEncodedHeaders(t *testing.T) {
|
||||
mockEncodedHeaders := make([]byte, 5)
|
||||
for i := range mockEncodedHeaders {
|
||||
mockEncodedHeaders[i] = byte(i)
|
||||
}
|
||||
chopped := chopEncodedHeaders(mockEncodedHeaders, 4)
|
||||
|
||||
assert.Equal(t, 2, len(chopped))
|
||||
assert.Equal(t, []byte{0, 1, 2, 3}, chopped[0])
|
||||
assert.Equal(t, []byte{4}, chopped[1])
|
||||
}
|
||||
|
||||
func TestChopEncodedEmptyHeaders(t *testing.T) {
|
||||
mockEncodedHeaders := make([]byte, 0)
|
||||
chopped := chopEncodedHeaders(mockEncodedHeaders, 3)
|
||||
|
||||
assert.Equal(t, 0, len(chopped))
|
||||
}
|
|
@ -1,151 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import "sync"
|
||||
|
||||
// ReadyList multiplexes several event signals onto a single channel.
|
||||
type ReadyList struct {
|
||||
// signalC is used to signal that a stream can be enqueued
|
||||
signalC chan uint32
|
||||
// waitC is used to signal the ID of the first ready descriptor
|
||||
waitC chan uint32
|
||||
// doneC is used to signal that run should terminate
|
||||
doneC chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewReadyList() *ReadyList {
|
||||
rl := &ReadyList{
|
||||
signalC: make(chan uint32),
|
||||
waitC: make(chan uint32),
|
||||
doneC: make(chan struct{}),
|
||||
}
|
||||
go rl.run()
|
||||
return rl
|
||||
}
|
||||
|
||||
// ID is the stream ID
|
||||
func (r *ReadyList) Signal(ID uint32) {
|
||||
select {
|
||||
case r.signalC <- ID:
|
||||
// ReadyList already closed
|
||||
case <-r.doneC:
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ReadyList) ReadyChannel() <-chan uint32 {
|
||||
return r.waitC
|
||||
}
|
||||
|
||||
func (r *ReadyList) Close() {
|
||||
r.closeOnce.Do(func() {
|
||||
close(r.doneC)
|
||||
})
|
||||
}
|
||||
|
||||
func (r *ReadyList) run() {
|
||||
defer close(r.waitC)
|
||||
var queue readyDescriptorQueue
|
||||
var firstReady *readyDescriptor
|
||||
activeDescriptors := newReadyDescriptorMap()
|
||||
for {
|
||||
if firstReady == nil {
|
||||
select {
|
||||
case i := <-r.signalC:
|
||||
firstReady = activeDescriptors.SetIfMissing(i)
|
||||
case <-r.doneC:
|
||||
return
|
||||
}
|
||||
}
|
||||
select {
|
||||
case r.waitC <- firstReady.ID:
|
||||
activeDescriptors.Delete(firstReady.ID)
|
||||
firstReady = queue.Dequeue()
|
||||
case i := <-r.signalC:
|
||||
newReady := activeDescriptors.SetIfMissing(i)
|
||||
if newReady != nil {
|
||||
// key doesn't exist
|
||||
queue.Enqueue(newReady)
|
||||
}
|
||||
case <-r.doneC:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type readyDescriptor struct {
|
||||
ID uint32
|
||||
Next *readyDescriptor
|
||||
}
|
||||
|
||||
// readyDescriptorQueue is a queue of readyDescriptors in the form of a singly-linked list.
|
||||
// The nil readyDescriptorQueue is an empty queue ready for use.
|
||||
type readyDescriptorQueue struct {
|
||||
Head *readyDescriptor
|
||||
Tail *readyDescriptor
|
||||
}
|
||||
|
||||
func (q *readyDescriptorQueue) Empty() bool {
|
||||
return q.Head == nil
|
||||
}
|
||||
|
||||
func (q *readyDescriptorQueue) Enqueue(x *readyDescriptor) {
|
||||
if x.Next != nil {
|
||||
panic("enqueued already queued item")
|
||||
}
|
||||
if q.Empty() {
|
||||
q.Head = x
|
||||
q.Tail = x
|
||||
} else {
|
||||
q.Tail.Next = x
|
||||
q.Tail = x
|
||||
}
|
||||
}
|
||||
|
||||
// Dequeue returns the first readyDescriptor in the queue, or nil if empty.
|
||||
func (q *readyDescriptorQueue) Dequeue() *readyDescriptor {
|
||||
if q.Empty() {
|
||||
return nil
|
||||
}
|
||||
x := q.Head
|
||||
q.Head = x.Next
|
||||
x.Next = nil
|
||||
return x
|
||||
}
|
||||
|
||||
// readyDescriptorQueue is a map of readyDescriptors keyed by ID.
|
||||
// It maintains a free list of deleted ready descriptors.
|
||||
type readyDescriptorMap struct {
|
||||
descriptors map[uint32]*readyDescriptor
|
||||
free []*readyDescriptor
|
||||
}
|
||||
|
||||
func newReadyDescriptorMap() *readyDescriptorMap {
|
||||
return &readyDescriptorMap{descriptors: make(map[uint32]*readyDescriptor)}
|
||||
}
|
||||
|
||||
// create or reuse a readyDescriptor if the stream is not in the queue.
|
||||
// This avoid stream starvation caused by a single high-bandwidth stream monopolising the writer goroutine
|
||||
func (m *readyDescriptorMap) SetIfMissing(key uint32) *readyDescriptor {
|
||||
if _, ok := m.descriptors[key]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var newDescriptor *readyDescriptor
|
||||
if len(m.free) > 0 {
|
||||
// reuse deleted ready descriptors
|
||||
newDescriptor = m.free[len(m.free)-1]
|
||||
m.free = m.free[:len(m.free)-1]
|
||||
} else {
|
||||
newDescriptor = &readyDescriptor{}
|
||||
}
|
||||
newDescriptor.ID = key
|
||||
m.descriptors[key] = newDescriptor
|
||||
return newDescriptor
|
||||
}
|
||||
|
||||
func (m *readyDescriptorMap) Delete(key uint32) {
|
||||
if descriptor, ok := m.descriptors[key]; ok {
|
||||
m.free = append(m.free, descriptor)
|
||||
delete(m.descriptors, key)
|
||||
}
|
||||
}
|
|
@ -1,171 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func assertEmpty(t *testing.T, rl *ReadyList) {
|
||||
select {
|
||||
case <-rl.ReadyChannel():
|
||||
t.Fatal("Spurious wakeup")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func assertClosed(t *testing.T, rl *ReadyList) {
|
||||
select {
|
||||
case _, ok := <-rl.ReadyChannel():
|
||||
assert.False(t, ok, "ReadyChannel was not closed")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatalf("Timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func receiveWithTimeout(t *testing.T, rl *ReadyList) uint32 {
|
||||
select {
|
||||
case i := <-rl.ReadyChannel():
|
||||
return i
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatalf("Timeout")
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadyListEmpty(t *testing.T) {
|
||||
rl := NewReadyList()
|
||||
|
||||
// no signals, receive should fail
|
||||
assertEmpty(t, rl)
|
||||
}
|
||||
func TestReadyListSignal(t *testing.T) {
|
||||
rl := NewReadyList()
|
||||
assertEmpty(t, rl)
|
||||
|
||||
rl.Signal(0)
|
||||
if receiveWithTimeout(t, rl) != 0 {
|
||||
t.Fatalf("Received wrong ID of signalled event")
|
||||
}
|
||||
|
||||
assertEmpty(t, rl)
|
||||
}
|
||||
|
||||
func TestReadyListMultipleSignals(t *testing.T) {
|
||||
rl := NewReadyList()
|
||||
assertEmpty(t, rl)
|
||||
|
||||
// Signals should not block;
|
||||
// Duplicate unhandled signals should not cause multiple wakeups
|
||||
signalled := [5]bool{}
|
||||
for i := range signalled {
|
||||
rl.Signal(uint32(i))
|
||||
rl.Signal(uint32(i))
|
||||
}
|
||||
// All signals should be received once (in any order)
|
||||
for range signalled {
|
||||
i := receiveWithTimeout(t, rl)
|
||||
if signalled[i] {
|
||||
t.Fatalf("Received signal %d more than once", i)
|
||||
}
|
||||
signalled[i] = true
|
||||
}
|
||||
for i := range signalled {
|
||||
if !signalled[i] {
|
||||
t.Fatalf("Never received signal %d", i)
|
||||
}
|
||||
}
|
||||
assertEmpty(t, rl)
|
||||
}
|
||||
|
||||
func TestReadyListClose(t *testing.T) {
|
||||
rl := NewReadyList()
|
||||
rl.Close()
|
||||
|
||||
// readyList.run() occurs in a separate goroutine,
|
||||
// so there's no way to directly check that run() has terminated.
|
||||
// Perform an indirect check: is the ready channel closed?
|
||||
assertClosed(t, rl)
|
||||
|
||||
// a second rl.Close() shouldn't cause a panic
|
||||
rl.Close()
|
||||
|
||||
// Signal shouldn't block after Close()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for i := 0; i < 5; i++ {
|
||||
rl.Signal(uint32(i))
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Test timed out")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadyDescriptorQueue(t *testing.T) {
|
||||
var queue readyDescriptorQueue
|
||||
items := [4]readyDescriptor{}
|
||||
for i := range items {
|
||||
items[i].ID = uint32(i)
|
||||
}
|
||||
|
||||
if !queue.Empty() {
|
||||
t.Fatalf("nil queue should be empty")
|
||||
}
|
||||
queue.Enqueue(&items[3])
|
||||
queue.Enqueue(&items[1])
|
||||
queue.Enqueue(&items[0])
|
||||
queue.Enqueue(&items[2])
|
||||
if queue.Empty() {
|
||||
t.Fatalf("Empty should be false after enqueue")
|
||||
}
|
||||
i := queue.Dequeue().ID
|
||||
if i != 3 {
|
||||
t.Fatalf("item 3 should have been dequeued, got %d instead", i)
|
||||
}
|
||||
i = queue.Dequeue().ID
|
||||
if i != 1 {
|
||||
t.Fatalf("item 1 should have been dequeued, got %d instead", i)
|
||||
}
|
||||
i = queue.Dequeue().ID
|
||||
if i != 0 {
|
||||
t.Fatalf("item 0 should have been dequeued, got %d instead", i)
|
||||
}
|
||||
i = queue.Dequeue().ID
|
||||
if i != 2 {
|
||||
t.Fatalf("item 2 should have been dequeued, got %d instead", i)
|
||||
}
|
||||
if !queue.Empty() {
|
||||
t.Fatal("queue should be empty after dequeuing all items")
|
||||
}
|
||||
if queue.Dequeue() != nil {
|
||||
t.Fatal("dequeue on empty queue should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadyDescriptorMap(t *testing.T) {
|
||||
m := newReadyDescriptorMap()
|
||||
m.Delete(42)
|
||||
// (delete of missing key should be a noop)
|
||||
x := m.SetIfMissing(42)
|
||||
if x == nil {
|
||||
t.Fatal("SetIfMissing for new key returned nil")
|
||||
}
|
||||
if m.SetIfMissing(42) != nil {
|
||||
t.Fatal("SetIfMissing for existing key returned non-nil")
|
||||
}
|
||||
// this delete has effect
|
||||
m.Delete(42)
|
||||
// the next set should reuse the old object
|
||||
y := m.SetIfMissing(666)
|
||||
if y == nil {
|
||||
t.Fatal("SetIfMissing for new key returned nil")
|
||||
}
|
||||
if x != y {
|
||||
t.Fatal("SetIfMissing didn't reuse freed object")
|
||||
}
|
||||
}
|
29
h2mux/rtt.go
29
h2mux/rtt.go
|
@ -1,29 +0,0 @@
|
|||
package h2mux
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// PingTimestamp is an atomic interface around ping timestamping and signalling.
|
||||
type PingTimestamp struct {
|
||||
ts int64
|
||||
signal Signal
|
||||
}
|
||||
|
||||
func NewPingTimestamp() *PingTimestamp {
|
||||
return &PingTimestamp{signal: NewSignal()}
|
||||
}
|
||||
|
||||
func (pt *PingTimestamp) Set(v int64) {
|
||||
if atomic.SwapInt64(&pt.ts, v) != 0 {
|
||||
pt.signal.Signal()
|
||||
}
|
||||
}
|
||||
|
||||
func (pt *PingTimestamp) Get() int64 {
|
||||
return atomic.SwapInt64(&pt.ts, 0)
|
||||
}
|
||||
|
||||
func (pt *PingTimestamp) GetUpdateChan() <-chan struct{} {
|
||||
return pt.signal.WaitChannel()
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
!function(){"use strict";function a(a){var b,c=[];if(!a)return"";for(b in a)a.hasOwnProperty(b)&&(a[b]||a[b]===!1)&&c.push(b+"="+encodeURIComponent(a[b]));return c.length?"?"+c.join("&"):""}var b,c,d,e,f="https://cloudflare.ghost.io/ghost/api/v0.1/";d={api:function(){var d,e=Array.prototype.slice.call(arguments),g=f;return d=e.pop(),d&&"object"!=typeof d&&(e.push(d),d={}),d=d||{},d.client_id=b,d.client_secret=c,e.length&&e.forEach(function(a){g+=a.replace(/^\/|\/$/g,"")+"/"}),g+a(d)}},e=function(a){b=a.clientId?a.clientId:"",c=a.clientSecret?a.clientSecret:"",f=a.url?a.url:f.match(/{\{api-url}}/)?"":f},"undefined"!=typeof window&&(window.ghost=window.ghost||{},window.ghost.url=d,window.ghost.init=e),"undefined"!=typeof module&&(module.exports={url:d,init:e})}();
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue