Compare commits

..

No commits in common. "master" and "2024.11.1" have entirely different histories.

328 changed files with 8192 additions and 29951 deletions

View File

@ -1,89 +0,0 @@
linters:
enable:
# Some of the linters below are commented out. We should uncomment and start running them, but they return
# too many problems to fix in one commit. Something for later.
- asasalint # Check for pass []any as any in variadic func(...any).
- asciicheck # Checks that all code identifiers does not have non-ASCII symbols in the name.
- bidichk # Checks for dangerous unicode character sequences.
- bodyclose # Checks whether HTTP response body is closed successfully.
- decorder # Check declaration order and count of types, constants, variables and functions.
- dogsled # Checks assignments with too many blank identifiers (e.g. x, , , _, := f()).
- dupl # Tool for code clone detection.
- dupword # Checks for duplicate words in the source code.
- durationcheck # Check for two durations multiplied together.
- errcheck # Errcheck is a program for checking for unchecked errors in Go code. These unchecked errors can be critical bugs in some cases.
- errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error.
- exhaustive # Check exhaustiveness of enum switch statements.
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification.
- goimports # Check import statements are formatted according to the 'goimport' command. Reformat imports in autofix mode.
- gosec # Inspects source code for security problems.
- gosimple # Linter for Go source code that specializes in simplifying code.
- govet # Vet examines Go source code and reports suspicious constructs. It is roughly the same as 'go vet' and uses its passes.
- ineffassign # Detects when assignments to existing variables are not used.
- importas # Enforces consistent import aliases.
- misspell # Finds commonly misspelled English words.
- prealloc # Finds slice declarations that could potentially be pre-allocated.
- promlinter # Check Prometheus metrics naming via promlint.
- sloglint # Ensure consistent code style when using log/slog.
- sqlclosecheck # Checks that sql.Rows, sql.Stmt, sqlx.NamedStmt, pgx.Query are closed.
- staticcheck # It's a set of rules from staticcheck. It's not the same thing as the staticcheck binary.
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
- testableexamples # Linter checks if examples are testable (have an expected output).
- testifylint # Checks usage of github.com/stretchr/testify.
- tparallel # Tparallel detects inappropriate usage of t.Parallel() method in your Go test codes.
- unconvert # Remove unnecessary type conversions.
- unused # Checks Go code for unused constants, variables, functions and types.
- wastedassign # Finds wasted assignment statements.
- whitespace # Whitespace is a linter that checks for unnecessary newlines at the start and end of functions, if, for, etc.
- zerologlint # Detects the wrong usage of zerolog that a user forgets to dispatch with Send or Msg.
# Other linters are disabled, list of all is here: https://golangci-lint.run/usage/linters/
run:
timeout: 5m
modules-download-mode: vendor
# output configuration options
output:
formats:
- format: 'colored-line-number'
print-issued-lines: true
print-linter-name: true
issues:
# Maximum issues count per one linter.
# Set to 0 to disable.
# Default: 50
max-issues-per-linter: 50
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 15
# Show only new issues: if there are unstaged changes or untracked files,
# only those changes are analyzed, else only changes in HEAD~ are analyzed.
# It's a super-useful option for integration of golangci-lint into existing large codebase.
# It's not practical to fix all existing issues at the moment of integration:
# much better don't allow issues in new code.
#
# Default: false
new: true
# Show only new issues created after git revision `REV`.
# Default: ""
new-from-rev: ac34f94d423273c8fa8fdbb5f2ac60e55f2c77d5
# Show issues in any part of update files (requires new-from-rev or new-from-patch).
# Default: false
whole-files: true
# Which dirs to exclude: issues from them won't be reported.
# Can use regexp here: `generated.*`, regexp is applied on full path,
# including the path prefix if one is set.
# Default dirs are skipped independently of this option's value (see exclude-dirs-use-default).
# "/" will be replaced by current OS file path separator to properly work on Windows.
# Default: []
exclude-dirs:
- vendor
linters-settings:
# Check exhaustiveness of enum switch statements.
exhaustive:
# Presence of "default" case in switch statements satisfies exhaustiveness,
# even if all enum members are not listed.
# Default: false
default-signifies-exhaustive: true

View File

@ -3,6 +3,6 @@
cd /tmp cd /tmp
git clone -q https://github.com/cloudflare/go git clone -q https://github.com/cloudflare/go
cd go/src cd go/src
# https://github.com/cloudflare/go/tree/af19da5605ca11f85776ef7af3384a02a315a52b is version go1.22.5-devel-cf # https://github.com/cloudflare/go/tree/f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38 is version go1.22.5-devel-cf
git checkout -q af19da5605ca11f85776ef7af3384a02a315a52b git checkout -q f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38
./make.bash ./make.bash

150
.teamcity/mac/build.sh vendored
View File

@ -22,7 +22,6 @@ TARGET_DIRECTORY=".build"
BINARY_NAME="cloudflared" BINARY_NAME="cloudflared"
VERSION=$(git describe --tags --always --dirty="-dev") VERSION=$(git describe --tags --always --dirty="-dev")
PRODUCT="cloudflared" PRODUCT="cloudflared"
APPLE_CA_CERT="apple_dev_ca.cert"
CODE_SIGN_PRIV="code_sign.p12" CODE_SIGN_PRIV="code_sign.p12"
CODE_SIGN_CERT="code_sign.cer" CODE_SIGN_CERT="code_sign.cer"
INSTALLER_PRIV="installer.p12" INSTALLER_PRIV="installer.p12"
@ -36,84 +35,91 @@ mkdir -p ../src/github.com/cloudflare/
cp -r . ../src/github.com/cloudflare/cloudflared cp -r . ../src/github.com/cloudflare/cloudflared
cd ../src/github.com/cloudflare/cloudflared cd ../src/github.com/cloudflare/cloudflared
# Imports certificates to the Apple KeyChain
import_certificate() {
local CERTIFICATE_NAME=$1
local CERTIFICATE_ENV_VAR=$2
local CERTIFICATE_FILE_NAME=$3
echo "Importing $CERTIFICATE_NAME"
if [[ ! -z "$CERTIFICATE_ENV_VAR" ]]; then
# write certificate to disk and then import it keychain
echo -n -e ${CERTIFICATE_ENV_VAR} | base64 -D > ${CERTIFICATE_FILE_NAME}
# we set || true here and for every `security import invoke` because the "duplicate SecKeychainItemImport" error
# will cause set -e to exit 1. It is okay we do this because we deliberately handle this error in the lines below.
local out=$(security import ${CERTIFICATE_FILE_NAME} -A 2>&1) || true
local exitcode=$?
# delete the certificate from disk
rm -rf ${CERTIFICATE_FILE_NAME}
if [ -n "$out" ]; then
if [ $exitcode -eq 0 ]; then
echo "$out"
else
if [ "$out" != "${SEC_DUP_MSG}" ]; then
echo "$out" >&2
exit $exitcode
else
echo "already imported code signing certificate"
fi
fi
fi
fi
}
# Imports private keys to the Apple KeyChain
import_private_keys() {
local PRIVATE_KEY_NAME=$1
local PRIVATE_KEY_ENV_VAR=$2
local PRIVATE_KEY_FILE_NAME=$3
local PRIVATE_KEY_PASS=$4
echo "Importing $PRIVATE_KEY_NAME"
if [[ ! -z "$PRIVATE_KEY_ENV_VAR" ]]; then
if [[ ! -z "$PRIVATE_KEY_PASS" ]]; then
# write private key to disk and then import it keychain
echo -n -e ${PRIVATE_KEY_ENV_VAR} | base64 -D > ${PRIVATE_KEY_FILE_NAME}
# we set || true here and for every `security import invoke` because the "duplicate SecKeychainItemImport" error
# will cause set -e to exit 1. It is okay we do this because we deliberately handle this error in the lines below.
local out=$(security import ${PRIVATE_KEY_FILE_NAME} -A -P "${PRIVATE_KEY_PASS}" 2>&1) || true
local exitcode=$?
rm -rf ${PRIVATE_KEY_FILE_NAME}
if [ -n "$out" ]; then
if [ $exitcode -eq 0 ]; then
echo "$out"
else
if [ "$out" != "${SEC_DUP_MSG}" ]; then
echo "$out" >&2
exit $exitcode
fi
fi
fi
fi
fi
}
# Add Apple Root Developer certificate to the key chain
import_certificate "Apple Developer CA" "${APPLE_DEV_CA_CERT}" "${APPLE_CA_CERT}"
# Add code signing private key to the key chain # Add code signing private key to the key chain
import_private_keys "Developer ID Application" "${CFD_CODE_SIGN_KEY}" "${CODE_SIGN_PRIV}" "${CFD_CODE_SIGN_PASS}" if [[ ! -z "$CFD_CODE_SIGN_KEY" ]]; then
if [[ ! -z "$CFD_CODE_SIGN_PASS" ]]; then
# write private key to disk and then import it keychain
echo -n -e ${CFD_CODE_SIGN_KEY} | base64 -D > ${CODE_SIGN_PRIV}
# we set || true here and for every `security import invoke` because the "duplicate SecKeychainItemImport" error
# will cause set -e to exit 1. It is okay we do this because we deliberately handle this error in the lines below.
out=$(security import ${CODE_SIGN_PRIV} -A -P "${CFD_CODE_SIGN_PASS}" 2>&1) || true
exitcode=$?
if [ -n "$out" ]; then
if [ $exitcode -eq 0 ]; then
echo "$out"
else
if [ "$out" != "${SEC_DUP_MSG}" ]; then
echo "$out" >&2
exit $exitcode
fi
fi
fi
rm ${CODE_SIGN_PRIV}
fi
fi
# Add code signing certificate to the key chain # Add code signing certificate to the key chain
import_certificate "Developer ID Application" "${CFD_CODE_SIGN_CERT}" "${CODE_SIGN_CERT}" if [[ ! -z "$CFD_CODE_SIGN_CERT" ]]; then
# write certificate to disk and then import it keychain
echo -n -e ${CFD_CODE_SIGN_CERT} | base64 -D > ${CODE_SIGN_CERT}
out1=$(security import ${CODE_SIGN_CERT} -A 2>&1) || true
exitcode1=$?
if [ -n "$out1" ]; then
if [ $exitcode1 -eq 0 ]; then
echo "$out1"
else
if [ "$out1" != "${SEC_DUP_MSG}" ]; then
echo "$out1" >&2
exit $exitcode1
else
echo "already imported code signing certificate"
fi
fi
fi
rm ${CODE_SIGN_CERT}
fi
# Add package signing private key to the key chain # Add package signing private key to the key chain
import_private_keys "Developer ID Installer" "${CFD_INSTALLER_KEY}" "${INSTALLER_PRIV}" "${CFD_INSTALLER_PASS}" if [[ ! -z "$CFD_INSTALLER_KEY" ]]; then
if [[ ! -z "$CFD_INSTALLER_PASS" ]]; then
# write private key to disk and then import it into the keychain
echo -n -e ${CFD_INSTALLER_KEY} | base64 -D > ${INSTALLER_PRIV}
out2=$(security import ${INSTALLER_PRIV} -A -P "${CFD_INSTALLER_PASS}" 2>&1) || true
exitcode2=$?
if [ -n "$out2" ]; then
if [ $exitcode2 -eq 0 ]; then
echo "$out2"
else
if [ "$out2" != "${SEC_DUP_MSG}" ]; then
echo "$out2" >&2
exit $exitcode2
fi
fi
fi
rm ${INSTALLER_PRIV}
fi
fi
# Add package signing certificate to the key chain # Add package signing certificate to the key chain
import_certificate "Developer ID Installer" "${CFD_INSTALLER_CERT}" "${INSTALLER_CERT}" if [[ ! -z "$CFD_INSTALLER_CERT" ]]; then
# write certificate to disk and then import it keychain
echo -n -e ${CFD_INSTALLER_CERT} | base64 -D > ${INSTALLER_CERT}
out3=$(security import ${INSTALLER_CERT} -A 2>&1) || true
exitcode3=$?
if [ -n "$out3" ]; then
if [ $exitcode3 -eq 0 ]; then
echo "$out3"
else
if [ "$out3" != "${SEC_DUP_MSG}" ]; then
echo "$out3" >&2
exit $exitcode3
else
echo "already imported installer certificate"
fi
fi
fi
rm ${INSTALLER_CERT}
fi
# get the code signing certificate name # get the code signing certificate name
if [[ ! -z "$CFD_CODE_SIGN_NAME" ]]; then if [[ ! -z "$CFD_CODE_SIGN_NAME" ]]; then

View File

@ -9,8 +9,8 @@ Set-Location "$Env:Temp"
git clone -q https://github.com/cloudflare/go git clone -q https://github.com/cloudflare/go
Write-Output "Building go..." Write-Output "Building go..."
cd go/src cd go/src
# https://github.com/cloudflare/go/tree/af19da5605ca11f85776ef7af3384a02a315a52b is version go1.22.5-devel-cf # https://github.com/cloudflare/go/tree/f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38 is version go1.22.5-devel-cf
git checkout -q af19da5605ca11f85776ef7af3384a02a315a52b git checkout -q f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38
& ./make.bat & ./make.bat
Write-Output "Installed" Write-Output "Installed"

View File

@ -1,15 +1,3 @@
## 2025.1.1
### New Features
- This release introduces the use of new Post Quantum curves and the ability to use Post Quantum curves when running tunnels with the QUIC protocol this applies to non-FIPS and FIPS builds.
## 2024.12.2
### New Features
- This release introduces the ability to collect troubleshooting information from one instance of cloudflared running on the local machine. The command can be executed as `cloudflared tunnel diag`.
## 2024.12.1
### Notices
- The use of the `--metrics` is still honoured meaning that if this flag is set the metrics server will try to bind it, however, this version includes a change that makes the metrics server bind to a port with a semi-deterministic approach. If the metrics flag is not present the server will bind to the first available port of the range 20241 to 20245. In case of all ports being unavailable then the fallback is to bind to a random port.
## 2024.10.0 ## 2024.10.0
### Bug Fixes ### Bug Fixes
- We fixed a bug related to `--grace-period`. Tunnels that use QUIC as transport weren't abiding by this waiting period before forcefully closing the connections to the edge. From now on, both QUIC and HTTP2 tunnels will wait for either the grace period to end (defaults to 30 seconds) or until the last in-flight request is handled. Users that wish to maintain the previous behavior should set `--grace-period` to 0 if `--protocol` is set to `quic`. This will force `cloudflared` to shutdown as soon as either SIGTERM or SIGINT is received. - We fixed a bug related to `--grace-period`. Tunnels that use QUIC as transport weren't abiding by this waiting period before forcefully closing the connections to the edge. From now on, both QUIC and HTTP2 tunnels will wait for either the grace period to end (defaults to 30 seconds) or until the last in-flight request is handled. Users that wish to maintain the previous behavior should set `--grace-period` to 0 if `--protocol` is set to `quic`. This will force `cloudflared` to shutdown as soon as either SIGTERM or SIGINT is received.

View File

@ -1,15 +1,11 @@
# use a builder image for building cloudflare # use a builder image for building cloudflare
ARG TARGET_GOOS ARG TARGET_GOOS
ARG TARGET_GOARCH ARG TARGET_GOARCH
FROM golang:1.22.10 as builder FROM golang:1.22.5 as builder
ENV GO111MODULE=on \ ENV GO111MODULE=on \
CGO_ENABLED=0 \ CGO_ENABLED=0 \
TARGET_GOOS=${TARGET_GOOS} \ TARGET_GOOS=${TARGET_GOOS} \
TARGET_GOARCH=${TARGET_GOARCH} \ TARGET_GOARCH=${TARGET_GOARCH}
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
# which changes how cloudflared binds the metrics server
CONTAINER_BUILD=1
WORKDIR /go/src/github.com/cloudflare/cloudflared/ WORKDIR /go/src/github.com/cloudflare/cloudflared/
@ -22,7 +18,7 @@ RUN .teamcity/install-cloudflare-go.sh
RUN PATH="/tmp/go/bin:$PATH" make cloudflared RUN PATH="/tmp/go/bin:$PATH" make cloudflared
# use a distroless base image with glibc # use a distroless base image with glibc
FROM gcr.io/distroless/base-debian12:nonroot FROM gcr.io/distroless/base-debian11:nonroot
LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared" LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared"

View File

@ -1,10 +1,7 @@
# use a builder image for building cloudflare # use a builder image for building cloudflare
FROM golang:1.22.10 as builder FROM golang:1.22.5 as builder
ENV GO111MODULE=on \ ENV GO111MODULE=on \
CGO_ENABLED=0 \ CGO_ENABLED=0
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
# which changes how cloudflared binds the metrics server
CONTAINER_BUILD=1
WORKDIR /go/src/github.com/cloudflare/cloudflared/ WORKDIR /go/src/github.com/cloudflare/cloudflared/
@ -17,7 +14,7 @@ RUN .teamcity/install-cloudflare-go.sh
RUN GOOS=linux GOARCH=amd64 PATH="/tmp/go/bin:$PATH" make cloudflared RUN GOOS=linux GOARCH=amd64 PATH="/tmp/go/bin:$PATH" make cloudflared
# use a distroless base image with glibc # use a distroless base image with glibc
FROM gcr.io/distroless/base-debian12:nonroot FROM gcr.io/distroless/base-debian11:nonroot
LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared" LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared"

View File

@ -1,10 +1,7 @@
# use a builder image for building cloudflare # use a builder image for building cloudflare
FROM golang:1.22.10 as builder FROM golang:1.22.5 as builder
ENV GO111MODULE=on \ ENV GO111MODULE=on \
CGO_ENABLED=0 \ CGO_ENABLED=0
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
# which changes how cloudflared binds the metrics server
CONTAINER_BUILD=1
WORKDIR /go/src/github.com/cloudflare/cloudflared/ WORKDIR /go/src/github.com/cloudflare/cloudflared/
@ -17,7 +14,7 @@ RUN .teamcity/install-cloudflare-go.sh
RUN GOOS=linux GOARCH=arm64 PATH="/tmp/go/bin:$PATH" make cloudflared RUN GOOS=linux GOARCH=arm64 PATH="/tmp/go/bin:$PATH" make cloudflared
# use a distroless base image with glibc # use a distroless base image with glibc
FROM gcr.io/distroless/base-debian12:nonroot-arm64 FROM gcr.io/distroless/base-debian11:nonroot-arm64
LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared" LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared"

View File

@ -24,16 +24,12 @@ else
DEB_PACKAGE_NAME := $(BINARY_NAME) DEB_PACKAGE_NAME := $(BINARY_NAME)
endif endif
DATE := $(shell date -u -r RELEASE_NOTES '+%Y-%m-%d-%H%M UTC') DATE := $(shell date -u '+%Y-%m-%d-%H%M UTC')
VERSION_FLAGS := -X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)" VERSION_FLAGS := -X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)"
ifdef PACKAGE_MANAGER ifdef PACKAGE_MANAGER
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)" VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)"
endif endif
ifdef CONTAINER_BUILD
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/metrics.Runtime=virtual"
endif
LINK_FLAGS := LINK_FLAGS :=
ifeq ($(FIPS), true) ifeq ($(FIPS), true)
LINK_FLAGS := -linkmode=external -extldflags=-static $(LINK_FLAGS) LINK_FLAGS := -linkmode=external -extldflags=-static $(LINK_FLAGS)
@ -133,9 +129,11 @@ clean:
cloudflared: cloudflared:
ifeq ($(FIPS), true) ifeq ($(FIPS), true)
$(info Building cloudflared with go-fips) $(info Building cloudflared with go-fips)
cp -f fips/fips.go.linux-amd64 cmd/cloudflared/fips.go
endif endif
GOOS=$(TARGET_OS) GOARCH=$(TARGET_ARCH) $(ARM_COMMAND) go build -mod=vendor $(GO_BUILD_TAGS) $(LDFLAGS) $(IMPORT_PATH)/cmd/cloudflared GOOS=$(TARGET_OS) GOARCH=$(TARGET_ARCH) $(ARM_COMMAND) go build -mod=vendor $(GO_BUILD_TAGS) $(LDFLAGS) $(IMPORT_PATH)/cmd/cloudflared
ifeq ($(FIPS), true) ifeq ($(FIPS), true)
rm -f cmd/cloudflared/fips.go
./check-fips.sh cloudflared ./check-fips.sh cloudflared
endif endif
@ -253,17 +251,4 @@ vet:
.PHONY: fmt .PHONY: fmt
fmt: fmt:
@goimports -l -w -local github.com/cloudflare/cloudflared $$(go list -mod=vendor -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc/proto) goimports -l -w -local github.com/cloudflare/cloudflared $$(go list -mod=vendor -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc/proto)
@go fmt $$(go list -mod=vendor -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc/proto)
.PHONY: fmt-check
fmt-check:
@./fmt-check.sh
.PHONY: lint
lint:
@golangci-lint run
.PHONY: mocks
mocks:
go generate mocks/mockgen.go

View File

@ -40,7 +40,7 @@ User documentation for Cloudflare Tunnel can be found at https://developers.clou
Once installed, you can authenticate `cloudflared` into your Cloudflare account and begin creating Tunnels to serve traffic to your origins. Once installed, you can authenticate `cloudflared` into your Cloudflare account and begin creating Tunnels to serve traffic to your origins.
* Create a Tunnel with [these instructions](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/get-started/) * Create a Tunnel with [these instructions](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/create-tunnel)
* Route traffic to that Tunnel: * Route traffic to that Tunnel:
* Via public [DNS records in Cloudflare](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/dns) * Via public [DNS records in Cloudflare](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/dns)
* Or via a public hostname guided by a [Cloudflare Load Balancer](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/lb) * Or via a public hostname guided by a [Cloudflare Load Balancer](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/lb)
@ -56,27 +56,3 @@ Want to test Cloudflare Tunnel before adding a website to Cloudflare? You can do
Cloudflare currently supports versions of cloudflared that are **within one year** of the most recent release. Breaking changes unrelated to feature availability may be introduced that will impact versions released more than one year ago. You can read more about upgrading cloudflared in our [developer documentation](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/downloads/#updating-cloudflared). Cloudflare currently supports versions of cloudflared that are **within one year** of the most recent release. Breaking changes unrelated to feature availability may be introduced that will impact versions released more than one year ago. You can read more about upgrading cloudflared in our [developer documentation](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/downloads/#updating-cloudflared).
For example, as of January 2023 Cloudflare will support cloudflared version 2023.1.1 to cloudflared 2022.1.1. For example, as of January 2023 Cloudflare will support cloudflared version 2023.1.1 to cloudflared 2022.1.1.
## Development
### Requirements
- [GNU Make](https://www.gnu.org/software/make/)
- [capnp](https://capnproto.org/install.html)
- [cloudflare go toolchain](https://github.com/cloudflare/go)
- Optional tools:
- [capnpc-go](https://pkg.go.dev/zombiezen.com/go/capnproto2/capnpc-go)
- [goimports](https://pkg.go.dev/golang.org/x/tools/cmd/goimports)
- [golangci-lint](https://github.com/golangci/golangci-lint)
- [gomocks](https://pkg.go.dev/go.uber.org/mock)
### Build
To build cloudflared locally run `make cloudflared`
### Test
To locally run the tests run `make test`
### Linting
To format the code and keep a good code quality use `make fmt` and `make lint`
### Mocks
After changes on interfaces you might need to regenerate the mocks, so run `make mock`

View File

@ -1,91 +1,3 @@
2025.4.0
- 2025-04-02 Fix broken links in `cmd/cloudflared/*.go` related to running tunnel as a service
- 2025-04-02 chore: remove repetitive words
- 2025-04-01 Fix messages to point to one.dash.cloudflare.com
- 2025-04-01 feat: emit explicit errors for the `service` command on unsupported OSes
- 2025-04-01 Use RELEASE_NOTES date instead of build date
- 2025-04-01 chore: Update tunnel configuration link in the readme
- 2025-04-01 fix: expand home directory for credentials file
- 2025-04-01 fix: Use path and filepath operation appropriately
- 2025-04-01 feat: Adds a new command line for tunnel run for token file
- 2025-04-01 chore: fix linter rules
- 2025-03-17 TUN-9101: Don't ignore errors on `cloudflared access ssh`
- 2025-03-06 TUN-9089: Pin go import to v0.30.0, v0.31.0 requires go 1.23
2025.2.1
- 2025-02-26 TUN-9016: update base-debian to v12
- 2025-02-25 TUN-8960: Connect to FED API GW based on the OriginCert's endpoint
- 2025-02-25 TUN-9007: modify logic to resolve region when the tunnel token has an endpoint field
- 2025-02-13 SDLC-3762: Remove backstage.io/source-location from catalog-info.yaml
- 2025-02-06 TUN-8914: Create a flags module to group all cloudflared cli flags
2025.2.0
- 2025-02-03 TUN-8914: Add a new configuration to locally override the max-active-flows
- 2025-02-03 Bump x/crypto to 0.31.0
2025.1.1
- 2025-01-30 TUN-8858: update go to 1.22.10 and include quic-go FIPS changes
- 2025-01-30 TUN-8855: fix lint issues
- 2025-01-30 TUN-8855: Update PQ curve preferences
- 2025-01-30 TUN-8857: remove restriction for using FIPS and PQ
- 2025-01-30 TUN-8894: report FIPS+PQ error to Sentry when dialling to the edge
- 2025-01-22 TUN-8904: Rename Connect Response Flow Rate Limited metadata
- 2025-01-21 AUTH-6633 Fix cloudflared access login + warp as auth
- 2025-01-20 TUN-8861: Add session limiter to UDP session manager
- 2025-01-20 TUN-8861: Rename Session Limiter to Flow Limiter
- 2025-01-17 TUN-8900: Add import of Apple Developer Certificate Authority to macOS Pipeline
- 2025-01-17 TUN-8871: Accept login flag to authenticate with Fedramp environment
- 2025-01-16 TUN-8866: Add linter to cloudflared repository
- 2025-01-14 TUN-8861: Add session limiter to TCP session manager
- 2025-01-13 TUN-8861: Add configuration for active sessions limiter
- 2025-01-09 TUN-8848: Don't treat connection shutdown as an error condition when RPC server is done
2025.1.0
- 2025-01-06 TUN-8842: Add Ubuntu Noble and 'any' debian distributions to release script
- 2025-01-06 TUN-8807: Add support_datagram_v3 to remote feature rollout
- 2024-12-20 TUN-8829: add CONTAINER_BUILD to dockerfiles
2024.12.2
- 2024-12-19 TUN-8822: Prevent concurrent usage of ICMPDecoder
- 2024-12-18 TUN-8818: update changes document to reflect newly added diag subcommand
- 2024-12-17 TUN-8817: Increase close session channel by one since there are two writers
- 2024-12-13 TUN-8797: update CHANGES.md with note about semi-deterministic approach used to bind metrics server
- 2024-12-13 TUN-8724: Add CLI command for diagnostic procedure
- 2024-12-11 TUN-8786: calculate cli flags once for the diagnostic procedure
- 2024-12-11 TUN-8792: Make diag/system endpoint always return a JSON
- 2024-12-10 TUN-8783: fix log collectors for the diagnostic procedure
- 2024-12-10 TUN-8785: include the icmp sources in the diag's tunnel state
- 2024-12-10 TUN-8784: Set JSON encoder options to print formatted JSON when writing diag files
2024.12.1
- 2024-12-10 TUN-8795: update createrepo to createrepo_c to fix the release_pkgs.py script
2024.12.0
- 2024-12-09 TUN-8640: Add ICMP support for datagram V3
- 2024-12-09 TUN-8789: make python package installation consistent
- 2024-12-06 TUN-8781: Add Trixie, drop Buster. Default to Bookworm
- 2024-12-05 TUN-8775: Make sure the session Close can only be called once
- 2024-12-04 TUN-8725: implement diagnostic procedure
- 2024-12-04 TUN-8767: include raw output from network collector in diagnostic zipfile
- 2024-12-04 TUN-8770: add cli configuration and tunnel configuration to diagnostic zipfile
- 2024-12-04 TUN-8768: add job report to diagnostic zipfile
- 2024-12-03 TUN-8726: implement compression routine to be used in diagnostic procedure
- 2024-12-03 TUN-8732: implement port selection algorithm
- 2024-12-03 TUN-8762: fix argument order when invoking tracert and modify network info output parsing.
- 2024-12-03 TUN-8769: fix k8s log collector arguments
- 2024-12-03 TUN-8727: extend client to include function to get cli configuration and tunnel configuration
- 2024-11-29 TUN-8729: implement network collection for diagnostic procedure
- 2024-11-29 TUN-8727: implement metrics, runtime, system, and tunnelstate in diagnostic http client
- 2024-11-27 TUN-8733: add log collection for docker
- 2024-11-27 TUN-8734: add log collection for kubernetes
- 2024-11-27 TUN-8640: Refactor ICMPRouter to support new ICMPResponders
- 2024-11-26 TUN-8735: add managed/local log collection
- 2024-11-25 TUN-8728: implement diag/tunnel endpoint
- 2024-11-25 TUN-8730: implement diag/configuration
- 2024-11-22 TUN-8737: update metrics server port selection
- 2024-11-22 TUN-8731: Implement diag/system endpoint
- 2024-11-21 TUN-8748: Migrated datagram V3 flows to use migrated context
2024.11.1 2024.11.1
- 2024-11-18 Add cloudflared tunnel ready command - 2024-11-18 Add cloudflared tunnel ready command
- 2024-11-14 Make metrics a requirement for tunnel ready command - 2024-11-14 Make metrics a requirement for tunnel ready command

View File

@ -17,7 +17,7 @@ make cloudflared-deb
mv cloudflared-fips\_$VERSION\_$arch.deb $ARTIFACT_DIR/cloudflared-fips-linux-$arch.deb mv cloudflared-fips\_$VERSION\_$arch.deb $ARTIFACT_DIR/cloudflared-fips-linux-$arch.deb
# rpm packages invert the - and _ and use x86_64 instead of amd64. # rpm packages invert the - and _ and use x86_64 instead of amd64.
RPMVERSION=$(echo $VERSION | sed -r 's/-/_/g') RPMVERSION=$(echo $VERSION|sed -r 's/-/_/g')
RPMARCH="x86_64" RPMARCH="x86_64"
make cloudflared-rpm make cloudflared-rpm
mv cloudflared-fips-$RPMVERSION-1.$RPMARCH.rpm $ARTIFACT_DIR/cloudflared-fips-linux-$RPMARCH.rpm mv cloudflared-fips-$RPMVERSION-1.$RPMARCH.rpm $ARTIFACT_DIR/cloudflared-fips-linux-$RPMARCH.rpm

View File

@ -4,6 +4,7 @@ metadata:
name: cloudflared name: cloudflared
description: Client for Cloudflare Tunnels description: Client for Cloudflare Tunnels
annotations: annotations:
backstage.io/source-location: url:https://bitbucket.cfdata.org/projects/TUN/repos/cloudflared/browse
cloudflare.com/software-excellence-opt-in: "true" cloudflare.com/software-excellence-opt-in: "true"
cloudflare.com/jira-project-key: "TUN" cloudflare.com/jira-project-key: "TUN"
cloudflare.com/jira-project-component: "Cloudflare Tunnel" cloudflare.com/jira-project-component: "Cloudflare Tunnel"

View File

@ -1,9 +1,8 @@
pinned_go: &pinned_go go-boring=1.22.10-1 pinned_go: &pinned_go go-boring=1.22.5-1
build_dir: &build_dir /cfsetup_build build_dir: &build_dir /cfsetup_build
default-flavor: bookworm default-flavor: bullseye
buster: &buster
bullseye: &bullseye
build-linux: build-linux:
build_dir: *build_dir build_dir: *build_dir
builddeps: &build_deps builddeps: &build_deps
@ -13,14 +12,10 @@ bullseye: &bullseye
- rubygem-fpm - rubygem-fpm
- rpm - rpm
- libffi-dev - libffi-dev
- golangci-lint
pre-cache: &build_pre_cache pre-cache: &build_pre_cache
- export GOCACHE=/cfsetup_build/.cache/go-build - export GOCACHE=/cfsetup_build/.cache/go-build
- go install golang.org/x/tools/cmd/goimports@v0.30.0 - go install golang.org/x/tools/cmd/goimports@latest
post-cache: post-cache:
# Linting
- make lint
- make fmt-check
# Build binary for component test # Build binary for component test
- GOOS=linux GOARCH=amd64 make cloudflared - GOOS=linux GOARCH=amd64 make cloudflared
build-linux-fips: build-linux-fips:
@ -36,8 +31,8 @@ bullseye: &bullseye
builddeps: *build_deps builddeps: *build_deps
pre-cache: *build_pre_cache pre-cache: *build_pre_cache
post-cache: post-cache:
- make cover - make cover
# except FIPS and macos # except FIPS and macos
build-linux-release: build-linux-release:
build_dir: *build_dir build_dir: *build_dir
builddeps: &build_deps_release builddeps: &build_deps_release
@ -51,17 +46,19 @@ bullseye: &bullseye
- python3-pip - python3-pip
- python3-setuptools - python3-setuptools
- wget - wget
- python3-venv pre-cache: &build_release_pre_cache
- pip3 install pynacl==1.4.0
- pip3 install pygithub==1.55
- pip3 install boto3==1.22.9
- pip3 install python-gnupg==0.4.9
post-cache: post-cache:
- python3 -m venv env
- . /cfsetup_build/env/bin/activate
- pip install pynacl==1.4.0 pygithub==1.55 boto3==1.22.9 python-gnupg==0.4.9
# build all packages (except macos and FIPS) and move them to /cfsetup/built_artifacts # build all packages (except macos and FIPS) and move them to /cfsetup/built_artifacts
- ./build-packages.sh - ./build-packages.sh
# handle FIPS separately so that we built with gofips compiler # handle FIPS separately so that we built with gofips compiler
build-linux-fips-release: build-linux-fips-release:
build_dir: *build_dir build_dir: *build_dir
builddeps: *build_deps_release builddeps: *build_deps_release
pre-cache: *build_release_pre_cache
post-cache: post-cache:
# same logic as above, but for FIPS packages only # same logic as above, but for FIPS packages only
- ./build-packages-fips.sh - ./build-packages-fips.sh
@ -113,7 +110,7 @@ bullseye: &bullseye
- export GOOS=linux - export GOOS=linux
- export GOARCH=arm64 - export GOARCH=arm64
- export NIGHTLY=true - export NIGHTLY=true
# - export FIPS=true # TUN-7595 #- export FIPS=true # TUN-7595
- export ORIGINAL_NAME=true - export ORIGINAL_NAME=true
- make cloudflared-deb - make cloudflared-deb
build-deb-arm64: build-deb-arm64:
@ -136,14 +133,12 @@ bullseye: &bullseye
# libmsi and libgcab are libraries the wixl binary depends on. # libmsi and libgcab are libraries the wixl binary depends on.
- libmsi-dev - libmsi-dev
- libgcab-dev - libgcab-dev
- python3-venv
pre-cache: pre-cache:
- wget https://github.com/sudarshan-reddy/msitools/releases/download/v0.101b/wixl -P /usr/local/bin - wget https://github.com/sudarshan-reddy/msitools/releases/download/v0.101b/wixl -P /usr/local/bin
- chmod a+x /usr/local/bin/wixl - chmod a+x /usr/local/bin/wixl
- pip3 install pynacl==1.4.0
- pip3 install pygithub==1.55
post-cache: post-cache:
- python3 -m venv env
- . env/bin/activate
- pip install pynacl==1.4.0 pygithub==1.55
- .teamcity/package-windows.sh - .teamcity/package-windows.sh
test: test:
build_dir: *build_dir build_dir: *build_dir
@ -160,6 +155,7 @@ bullseye: &bullseye
- export GOOS=linux - export GOOS=linux
- export GOARCH=amd64 - export GOARCH=amd64
- export PATH="$HOME/go/bin:$PATH" - export PATH="$HOME/go/bin:$PATH"
- ./fmt-check.sh
- make test | gotest-to-teamcity - make test | gotest-to-teamcity
test-fips: test-fips:
build_dir: *build_dir build_dir: *build_dir
@ -170,27 +166,24 @@ bullseye: &bullseye
- export GOARCH=amd64 - export GOARCH=amd64
- export FIPS=true - export FIPS=true
- export PATH="$HOME/go/bin:$PATH" - export PATH="$HOME/go/bin:$PATH"
- ./fmt-check.sh
- make test | gotest-to-teamcity - make test | gotest-to-teamcity
component-test: component-test:
build_dir: *build_dir build_dir: *build_dir
builddeps: &build_deps_component_test builddeps: &build_deps_component_test
- *pinned_go - *pinned_go
- python3 - python3.7
- python3-pip - python3-pip
- python3-setuptools - python3-setuptools
# procps installs the ps command which is needed in test_sysv_service # procps installs the ps command which is needed in test_sysv_service because the init script
# because the init script uses ps pid to determine if the agent is # uses ps pid to determine if the agent is running
# running
- procps - procps
- python3-venv
pre-cache-copy-paths: pre-cache-copy-paths:
- component-tests/requirements.txt - component-tests/requirements.txt
pre-cache: &component_test_pre_cache
- sudo pip3 install --upgrade -r component-tests/requirements.txt
post-cache: &component_test_post_cache post-cache: &component_test_post_cache
- python3 -m venv env # Creates and routes a Named Tunnel for this build. Also constructs config file from env vars.
- . env/bin/activate
- pip install --upgrade -r component-tests/requirements.txt
# Creates and routes a Named Tunnel for this build. Also constructs
# config file from env vars.
- python3 component-tests/setup.py --type create - python3 component-tests/setup.py --type create
- pytest component-tests -o log_cli=true --log-cli-level=INFO - pytest component-tests -o log_cli=true --log-cli-level=INFO
# The Named Tunnel is deleted and its route unprovisioned here. # The Named Tunnel is deleted and its route unprovisioned here.
@ -200,6 +193,7 @@ bullseye: &bullseye
builddeps: *build_deps_component_test builddeps: *build_deps_component_test
pre-cache-copy-paths: pre-cache-copy-paths:
- component-tests/requirements.txt - component-tests/requirements.txt
pre-cache: *component_test_pre_cache
post-cache: *component_test_post_cache post-cache: *component_test_post_cache
github-release-dryrun: github-release-dryrun:
build_dir: *build_dir build_dir: *build_dir
@ -210,11 +204,10 @@ bullseye: &bullseye
- libffi-dev - libffi-dev
- python3-setuptools - python3-setuptools
- python3-pip - python3-pip
- python3-venv pre-cache:
- pip3 install pynacl==1.4.0
- pip3 install pygithub==1.55
post-cache: post-cache:
- python3 -m venv env
- . env/bin/activate
- pip install pynacl==1.4.0 pygithub==1.55
- make github-release-dryrun - make github-release-dryrun
github-release: github-release:
build_dir: *build_dir build_dir: *build_dir
@ -225,11 +218,10 @@ bullseye: &bullseye
- libffi-dev - libffi-dev
- python3-setuptools - python3-setuptools
- python3-pip - python3-pip
- python3-venv pre-cache:
- pip3 install pynacl==1.4.0
- pip3 install pygithub==1.55
post-cache: post-cache:
- python3 -m venv env
- . env/bin/activate
- pip install pynacl==1.4.0 pygithub==1.55
- make github-release - make github-release
r2-linux-release: r2-linux-release:
build_dir: *build_dir build_dir: *build_dir
@ -245,13 +237,14 @@ bullseye: &bullseye
- python3-setuptools - python3-setuptools
- python3-pip - python3-pip
- reprepro - reprepro
- createrepo-c - createrepo
- python3-venv pre-cache:
- pip3 install pynacl==1.4.0
- pip3 install pygithub==1.55
- pip3 install boto3==1.22.9
- pip3 install python-gnupg==0.4.9
post-cache: post-cache:
- python3 -m venv env
- . env/bin/activate
- pip install pynacl==1.4.0 pygithub==1.55 boto3==1.22.9 python-gnupg==0.4.9
- make r2-linux-release - make r2-linux-release
bookworm: *bullseye bullseye: *buster
trixie: *bullseye bookworm: *buster

View File

@ -104,7 +104,7 @@ func ssh(c *cli.Context) error {
case 3: case 3:
options.OriginURL = fmt.Sprintf("https://%s:%s", parts[2], parts[1]) options.OriginURL = fmt.Sprintf("https://%s:%s", parts[2], parts[1])
options.TLSClientConfig = &tls.Config{ options.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true, // #nosec G402 InsecureSkipVerify: true,
ServerName: parts[0], ServerName: parts[0],
} }
log.Warn().Msgf("Using insecure SSL connection because SNI overridden to %s", parts[0]) log.Warn().Msgf("Using insecure SSL connection because SNI overridden to %s", parts[0])
@ -141,5 +141,6 @@ func ssh(c *cli.Context) error {
logger := log.With().Str("host", url.Host).Logger() logger := log.With().Str("host", url.Host).Logger()
s = stream.NewDebugStream(s, &logger, maxMessages) s = stream.NewDebugStream(s, &logger, maxMessages)
} }
return carrier.StartClient(wsConn, s, options) carrier.StartClient(wsConn, s, options)
return nil
} }

View File

@ -19,7 +19,6 @@ import (
"github.com/cloudflare/cloudflared/carrier" "github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/sshgen" "github.com/cloudflare/cloudflared/sshgen"
"github.com/cloudflare/cloudflared/token" "github.com/cloudflare/cloudflared/token"
@ -173,15 +172,15 @@ func Commands() []*cli.Command {
EnvVars: []string{"TUNNEL_SERVICE_TOKEN_SECRET"}, EnvVars: []string{"TUNNEL_SERVICE_TOKEN_SECRET"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: cfdflags.LogFile, Name: logger.LogFileFlag,
Usage: "Save application log to this file for reporting issues.", Usage: "Save application log to this file for reporting issues.",
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: cfdflags.LogDirectory, Name: logger.LogSSHDirectoryFlag,
Usage: "Save application log to this directory for reporting issues.", Usage: "Save application log to this directory for reporting issues.",
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: cfdflags.LogLevelSSH, Name: logger.LogSSHLevelFlag,
Aliases: []string{"loglevel"}, //added to match the tunnel side Aliases: []string{"loglevel"}, //added to match the tunnel side
Usage: "Application logging level {debug, info, warn, error, fatal}. ", Usage: "Application logging level {debug, info, warn, error, fatal}. ",
}, },
@ -343,7 +342,7 @@ func run(cmd string, args ...string) error {
return err return err
} }
go func() { go func() {
_, _ = io.Copy(os.Stderr, stderr) io.Copy(os.Stderr, stderr)
}() }()
stdout, err := c.StdoutPipe() stdout, err := c.StdoutPipe()
@ -351,7 +350,7 @@ func run(cmd string, args ...string) error {
return err return err
} }
go func() { go func() {
_, _ = io.Copy(os.Stdout, stdout) io.Copy(os.Stdout, stdout)
}() }()
return c.Run() return c.Run()
} }
@ -532,7 +531,7 @@ func isFileThere(candidate string) bool {
} }
// verifyTokenAtEdge checks for a token on disk, or generates a new one. // verifyTokenAtEdge checks for a token on disk, or generates a new one.
// Then makes a request to the origin with the token to ensure it is valid. // Then makes a request to to the origin with the token to ensure it is valid.
// Returns nil if token is valid. // Returns nil if token is valid.
func verifyTokenAtEdge(appUrl *url.URL, appInfo *token.AppInfo, c *cli.Context, log *zerolog.Logger) error { func verifyTokenAtEdge(appUrl *url.URL, appInfo *token.AppInfo, c *cli.Context, log *zerolog.Logger) error {
headers := parseRequestHeaders(c.StringSlice(sshHeaderFlag)) headers := parseRequestHeaders(c.StringSlice(sshHeaderFlag))

View File

@ -4,7 +4,7 @@ import (
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc" "github.com/urfave/cli/v2/altsrc"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags" "github.com/cloudflare/cloudflared/logger"
) )
var ( var (
@ -15,14 +15,14 @@ var (
func ConfigureLoggingFlags(shouldHide bool) []cli.Flag { func ConfigureLoggingFlags(shouldHide bool) []cli.Flag {
return []cli.Flag{ return []cli.Flag{
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.LogLevel, Name: logger.LogLevelFlag,
Value: "info", Value: "info",
Usage: "Application logging level {debug, info, warn, error, fatal}. " + debugLevelWarning, Usage: "Application logging level {debug, info, warn, error, fatal}. " + debugLevelWarning,
EnvVars: []string{"TUNNEL_LOGLEVEL"}, EnvVars: []string{"TUNNEL_LOGLEVEL"},
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.TransportLogLevel, Name: logger.LogTransportLevelFlag,
Aliases: []string{"proto-loglevel"}, // This flag used to be called proto-loglevel Aliases: []string{"proto-loglevel"}, // This flag used to be called proto-loglevel
Value: "info", Value: "info",
Usage: "Transport logging level(previously called protocol logging level) {debug, info, warn, error, fatal}", Usage: "Transport logging level(previously called protocol logging level) {debug, info, warn, error, fatal}",
@ -30,19 +30,19 @@ func ConfigureLoggingFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.LogFile, Name: logger.LogFileFlag,
Usage: "Save application log to this file for reporting issues.", Usage: "Save application log to this file for reporting issues.",
EnvVars: []string{"TUNNEL_LOGFILE"}, EnvVars: []string{"TUNNEL_LOGFILE"},
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.LogDirectory, Name: logger.LogDirectoryFlag,
Usage: "Save application log to this directory for reporting issues.", Usage: "Save application log to this directory for reporting issues.",
EnvVars: []string{"TUNNEL_LOGDIRECTORY"}, EnvVars: []string{"TUNNEL_LOGDIRECTORY"},
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.TraceOutput, Name: "trace-output",
Usage: "Name of trace output file, generated when cloudflared stops.", Usage: "Name of trace output file, generated when cloudflared stops.",
EnvVars: []string{"TUNNEL_TRACE_OUTPUT"}, EnvVars: []string{"TUNNEL_TRACE_OUTPUT"},
Hidden: shouldHide, Hidden: shouldHide,

View File

@ -1,155 +0,0 @@
package flags
const (
// HaConnections specifies how many connections to make to the edge
HaConnections = "ha-connections"
// SshPort is the port on localhost the cloudflared ssh server will run on
SshPort = "local-ssh-port"
// SshIdleTimeout defines the duration a SSH session can remain idle before being closed
SshIdleTimeout = "ssh-idle-timeout"
// SshMaxTimeout defines the max duration a SSH session can remain open for
SshMaxTimeout = "ssh-max-timeout"
// SshLogUploaderBucketName is the bucket name to use for the SSH log uploader
SshLogUploaderBucketName = "bucket-name"
// SshLogUploaderRegionName is the AWS region name to use for the SSH log uploader
SshLogUploaderRegionName = "region-name"
// SshLogUploaderSecretID is the Secret id of SSH log uploader
SshLogUploaderSecretID = "secret-id"
// SshLogUploaderAccessKeyID is the Access key id of SSH log uploader
SshLogUploaderAccessKeyID = "access-key-id"
// SshLogUploaderSessionTokenID is the Session token of SSH log uploader
SshLogUploaderSessionTokenID = "session-token"
// SshLogUploaderS3URL is the S3 URL of SSH log uploader (e.g. don't use AWS s3 and use google storage bucket instead)
SshLogUploaderS3URL = "s3-url-host"
// HostKeyPath is the path of the dir to save SSH host keys too
HostKeyPath = "host-key-path"
// RpcTimeout is how long to wait for a Capnp RPC request to the edge
RpcTimeout = "rpc-timeout"
// WriteStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
WriteStreamTimeout = "write-stream-timeout"
// QuicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
QuicDisablePathMTUDiscovery = "quic-disable-pmtu-discovery"
// QuicConnLevelFlowControlLimit controls the max flow control limit allocated for a QUIC connection. This controls how much data is the
// receiver willing to buffer. Once the limit is reached, the sender will send a DATA_BLOCKED frame to indicate it has more data to write,
// but it's blocked by flow control
QuicConnLevelFlowControlLimit = "quic-connection-level-flow-control-limit"
// QuicStreamLevelFlowControlLimit is similar to quicConnLevelFlowControlLimit but for each QUIC stream. When the sender is blocked,
// it will send a STREAM_DATA_BLOCKED frame
QuicStreamLevelFlowControlLimit = "quic-stream-level-flow-control-limit"
// Ui is to enable launching cloudflared in interactive UI mode
Ui = "ui"
// ConnectorLabel is the command line flag to give a meaningful label to a specific connector
ConnectorLabel = "label"
// MaxActiveFlows is the command line flag to set the maximum number of flows that cloudflared can be processing at the same time
MaxActiveFlows = "max-active-flows"
// Tag is the command line flag to set custom tags used to identify this tunnel via added HTTP request headers to the origin
Tag = "tag"
// Protocol is the command line flag to set the protocol to use to connect to the Cloudflare Edge
Protocol = "protocol"
// PostQuantum is the command line flag to force the connection to Cloudflare Edge to use Post Quantum cryptography
PostQuantum = "post-quantum"
// Features is the command line flag to opt into various features that are still being developed or tested
Features = "features"
// EdgeIpVersion is the command line flag to set the Cloudflare Edge IP address version to connect with
EdgeIpVersion = "edge-ip-version"
// EdgeBindAddress is the command line flag to bind to IP address for outgoing connections to Cloudflare Edge
EdgeBindAddress = "edge-bind-address"
// Force is the command line flag to specify if you wish to force an action
Force = "force"
// Edge is the command line flag to set the address of the Cloudflare tunnel server. Only works in Cloudflare's internal testing environment
Edge = "edge"
// Region is the command line flag to set the Cloudflare Edge region to connect to
Region = "region"
// IsAutoUpdated is the command line flag to signal the new process that cloudflared has been autoupdated
IsAutoUpdated = "is-autoupdated"
// LBPool is the command line flag to set the name of the load balancing pool to add this origin to
LBPool = "lb-pool"
// Retries is the command line flag to set the maximum number of retries for connection/protocol errors
Retries = "retries"
// MaxEdgeAddrRetries is the command line flag to set the maximum number of times to retry on edge addrs before falling back to a lower protocol
MaxEdgeAddrRetries = "max-edge-addr-retries"
// GracePeriod is the command line flag to set the maximum amount of time that cloudflared waits to shut down if it is still serving requests
GracePeriod = "grace-period"
// ICMPV4Src is the command line flag to set the source address and the interface name to send/receive ICMPv4 messages
ICMPV4Src = "icmpv4-src"
// ICMPV6Src is the command line flag to set the source address and the interface name to send/receive ICMPv6 messages
ICMPV6Src = "icmpv6-src"
// ProxyDns is the command line flag to run DNS server over HTTPS
ProxyDns = "proxy-dns"
// Name is the command line to set the name of the tunnel
Name = "name"
// AutoUpdateFreq is the command line for setting the frequency that cloudflared checks for updates
AutoUpdateFreq = "autoupdate-freq"
// NoAutoUpdate is the command line flag to disable cloudflared from checking for updates
NoAutoUpdate = "no-autoupdate"
// LogLevel is the command line flag for the cloudflared logging level
LogLevel = "loglevel"
// LogLevelSSH is the command line flag for the cloudflared ssh logging level
LogLevelSSH = "log-level"
// TransportLogLevel is the command line flag for the transport logging level
TransportLogLevel = "transport-loglevel"
// LogFile is the command line flag to define the file where application logs will be stored
LogFile = "logfile"
// LogDirectory is the command line flag to define the directory where application logs will be stored.
LogDirectory = "log-directory"
// TraceOutput is the command line flag to set the name of trace output file
TraceOutput = "trace-output"
// OriginCert is the command line flag to define the path for the origin certificate used by cloudflared
OriginCert = "origincert"
// Metrics is the command line flag to define the address of the metrics server
Metrics = "metrics"
// MetricsUpdateFreq is the command line flag to define how frequently tunnel metrics are updated
MetricsUpdateFreq = "metrics-update-freq"
// ApiURL is the command line flag used to define the base URL of the API
ApiURL = "api-url"
)

View File

@ -3,38 +3,11 @@
package main package main
import ( import (
"fmt"
"os" "os"
cli "github.com/urfave/cli/v2" cli "github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
) )
func runApp(app *cli.App, graceShutdownC chan struct{}) { func runApp(app *cli.App, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the cloudflared system service (not supported on this operating system)",
Subcommands: []*cli.Command{
{
Name: "install",
Usage: "Install cloudflared as a system service (not supported on this operating system)",
Action: cliutil.ConfiguredAction(installGenericService),
},
{
Name: "uninstall",
Usage: "Uninstall the cloudflared service (not supported on this operating system)",
Action: cliutil.ConfiguredAction(uninstallGenericService),
},
},
})
app.Run(os.Args) app.Run(os.Args)
} }
func installGenericService(c *cli.Context) error {
return fmt.Errorf("service installation is not supported on this operating system")
}
func uninstallGenericService(c *cli.Context) error {
return fmt.Errorf("service uninstallation is not supported on this operating system")
}

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"os" "os"
homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
@ -18,7 +17,7 @@ const (
launchdIdentifier = "com.cloudflare.cloudflared" launchdIdentifier = "com.cloudflare.cloudflared"
) )
func runApp(app *cli.App, _ chan struct{}) { func runApp(app *cli.App, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{ app.Commands = append(app.Commands, &cli.Command{
Name: "service", Name: "service",
Usage: "Manages the cloudflared launch agent", Usage: "Manages the cloudflared launch agent",
@ -120,7 +119,7 @@ func installLaunchd(c *cli.Context) error {
log.Info().Msg("Installing cloudflared client as an user launch agent. " + log.Info().Msg("Installing cloudflared client as an user launch agent. " +
"Note that cloudflared client will only run when the user is logged in. " + "Note that cloudflared client will only run when the user is logged in. " +
"If you want to run cloudflared client at boot, install with root permission. " + "If you want to run cloudflared client at boot, install with root permission. " +
"For more information, visit https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/configure-tunnels/local-management/as-a-service/macos/") "For more information, visit https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/run-as-service")
} }
etPath, err := os.Executable() etPath, err := os.Executable()
if err != nil { if err != nil {
@ -208,15 +207,3 @@ func uninstallLaunchd(c *cli.Context) error {
} }
return err return err
} }
func userHomeDir() (string, error) {
// This returns the home dir of the executing user using OS-specific method
// for discovering the home dir. It's not recommended to call this function
// when the user has root permission as $HOME depends on what options the user
// use with sudo.
homeDir, err := homedir.Dir()
if err != nil {
return "", errors.Wrap(err, "Cannot determine home directory for the user")
}
return homeDir, nil
}

View File

@ -2,17 +2,19 @@ package main
import ( import (
"fmt" "fmt"
"math/rand"
"os" "os"
"strings" "strings"
"time" "time"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"go.uber.org/automaxprocs/maxprocs" "go.uber.org/automaxprocs/maxprocs"
"github.com/cloudflare/cloudflared/cmd/cloudflared/access" "github.com/cloudflare/cloudflared/cmd/cloudflared/access"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/cmd/cloudflared/proxydns" "github.com/cloudflare/cloudflared/cmd/cloudflared/proxydns"
"github.com/cloudflare/cloudflared/cmd/cloudflared/tail" "github.com/cloudflare/cloudflared/cmd/cloudflared/tail"
"github.com/cloudflare/cloudflared/cmd/cloudflared/tunnel" "github.com/cloudflare/cloudflared/cmd/cloudflared/tunnel"
@ -50,8 +52,10 @@ var (
func main() { func main() {
// FIXME: TUN-8148: Disable QUIC_GO ECN due to bugs in proper detection if supported // FIXME: TUN-8148: Disable QUIC_GO ECN due to bugs in proper detection if supported
os.Setenv("QUIC_GO_DISABLE_ECN", "1") os.Setenv("QUIC_GO_DISABLE_ECN", "1")
rand.Seed(time.Now().UnixNano())
metrics.RegisterBuildInfo(BuildType, BuildTime, Version) metrics.RegisterBuildInfo(BuildType, BuildTime, Version)
_, _ = maxprocs.Set() maxprocs.Set()
bInfo := cliutil.GetBuildInfo(BuildType, Version) bInfo := cliutil.GetBuildInfo(BuildType, Version)
// Graceful shutdown channel used by the app. When closed, app must terminate gracefully. // Graceful shutdown channel used by the app. When closed, app must terminate gracefully.
@ -106,7 +110,7 @@ func commands(version func(c *cli.Context)) []*cli.Command {
Usage: "specify if you wish to update to the latest beta version", Usage: "specify if you wish to update to the latest beta version",
}, },
&cli.BoolFlag{ &cli.BoolFlag{
Name: cfdflags.Force, Name: "force",
Usage: "specify if you wish to force an upgrade to the latest version regardless of the current version", Usage: "specify if you wish to force an upgrade to the latest version regardless of the current version",
Hidden: true, Hidden: true,
}, },
@ -180,6 +184,18 @@ func action(graceShutdownC chan struct{}) cli.ActionFunc {
}) })
} }
func userHomeDir() (string, error) {
// This returns the home dir of the executing user using OS-specific method
// for discovering the home dir. It's not recommended to call this function
// when the user has root permission as $HOME depends on what options the user
// use with sudo.
homeDir, err := homedir.Dir()
if err != nil {
return "", errors.Wrap(err, "Cannot determine home directory for the user")
}
return homeDir, nil
}
// In order to keep the amount of noise sent to Sentry low, typical network errors can be filtered out here by a substring match. // In order to keep the amount of noise sent to Sentry low, typical network errors can be filtered out here by a substring match.
func captureError(err error) { func captureError(err error) {
errorMessage := err.Error() errorMessage := err.Error()

View File

@ -1,13 +1,13 @@
package main package main
import ( import (
"bufio"
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path"
"text/template" "text/template"
homedir "github.com/mitchellh/go-homedir" homedir "github.com/mitchellh/go-homedir"
@ -44,7 +44,7 @@ func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
return err return err
} }
if _, err = os.Stat(resolvedPath); err == nil { if _, err = os.Stat(resolvedPath); err == nil {
return errors.New(serviceAlreadyExistsWarn(resolvedPath)) return fmt.Errorf(serviceAlreadyExistsWarn(resolvedPath))
} }
var buffer bytes.Buffer var buffer bytes.Buffer
@ -57,7 +57,7 @@ func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
fileMode = st.FileMode fileMode = st.FileMode
} }
plistFolder := filepath.Dir(resolvedPath) plistFolder := path.Dir(resolvedPath)
err = os.MkdirAll(plistFolder, 0o755) err = os.MkdirAll(plistFolder, 0o755)
if err != nil { if err != nil {
return fmt.Errorf("error creating %s: %v", plistFolder, err) return fmt.Errorf("error creating %s: %v", plistFolder, err)
@ -118,6 +118,49 @@ func ensureConfigDirExists(configDir string) error {
return err return err
} }
// openFile opens the file at path. If create is set and the file exists, returns nil, true, nil
func openFile(path string, create bool) (file *os.File, exists bool, err error) {
expandedPath, err := homedir.Expand(path)
if err != nil {
return nil, false, err
}
if create {
fileInfo, err := os.Stat(expandedPath)
if err == nil && fileInfo.Size() > 0 {
return nil, true, nil
}
file, err = os.OpenFile(expandedPath, os.O_RDWR|os.O_CREATE, 0600)
} else {
file, err = os.Open(expandedPath)
}
return file, false, err
}
func copyCredential(srcCredentialPath, destCredentialPath string) error {
destFile, exists, err := openFile(destCredentialPath, true)
if err != nil {
return err
} else if exists {
// credentials already exist, do nothing
return nil
}
defer destFile.Close()
srcFile, _, err := openFile(srcCredentialPath, false)
if err != nil {
return err
}
defer srcFile.Close()
// Copy certificate
_, err = io.Copy(destFile, srcFile)
if err != nil {
return fmt.Errorf("unable to copy %s to %s: %v", srcCredentialPath, destCredentialPath, err)
}
return nil
}
func copyFile(src, dest string) error { func copyFile(src, dest string) error {
srcFile, err := os.Open(src) srcFile, err := os.Open(src)
if err != nil { if err != nil {
@ -144,3 +187,36 @@ func copyFile(src, dest string) error {
ok = true ok = true
return nil return nil
} }
func copyConfig(srcConfigPath, destConfigPath string) error {
// Copy or create config
destFile, exists, err := openFile(destConfigPath, true)
if err != nil {
return fmt.Errorf("cannot open %s with error: %s", destConfigPath, err)
} else if exists {
// config already exists, do nothing
return nil
}
defer destFile.Close()
srcFile, _, err := openFile(srcConfigPath, false)
if err != nil {
fmt.Println("Your service needs a config file that at least specifies the hostname option.")
fmt.Println("Type in a hostname now, or leave it blank and create the config file later.")
fmt.Print("Hostname: ")
reader := bufio.NewReader(os.Stdin)
input, _ := reader.ReadString('\n')
if input == "" {
return err
}
fmt.Fprintf(destFile, "hostname: %s\n", input)
} else {
defer srcFile.Close()
_, err = io.Copy(destFile, srcFile)
if err != nil {
return fmt.Errorf("unable to copy %s to %s: %v", srcConfigPath, destConfigPath, err)
}
}
return nil
}

View File

@ -18,12 +18,14 @@ import (
"nhooyr.io/websocket" "nhooyr.io/websocket"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/management"
) )
var buildInfo *cliutil.BuildInfo var (
buildInfo *cliutil.BuildInfo
)
func Init(bi *cliutil.BuildInfo) { func Init(bi *cliutil.BuildInfo) {
buildInfo = bi buildInfo = bi
@ -54,7 +56,7 @@ func managementTokenCommand(c *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
tokenResponse := struct { var tokenResponse = struct {
Token string `json:"token"` Token string `json:"token"`
}{Token: token} }{Token: token}
@ -117,13 +119,13 @@ func buildTailCommand(subcommands []*cli.Command) *cli.Command {
Value: "", Value: "",
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: cfdflags.LogLevel, Name: logger.LogLevelFlag,
Value: "info", Value: "info",
Usage: "Application logging level {debug, info, warn, error, fatal}", Usage: "Application logging level {debug, info, warn, error, fatal}",
EnvVars: []string{"TUNNEL_LOGLEVEL"}, EnvVars: []string{"TUNNEL_LOGLEVEL"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: cfdflags.OriginCert, Name: credentials.OriginCertFlag,
Usage: "Path to the certificate generated for your origin when you run cloudflared login.", Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
EnvVars: []string{"TUNNEL_ORIGIN_CERT"}, EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: credentials.FindDefaultOriginCertPath(), Value: credentials.FindDefaultOriginCertPath(),
@ -167,7 +169,7 @@ func handleValidationError(resp *http.Response, log *zerolog.Logger) {
// logger will be created to emit only against the os.Stderr as to not obstruct with normal output from // logger will be created to emit only against the os.Stderr as to not obstruct with normal output from
// management requests // management requests
func createLogger(c *cli.Context) *zerolog.Logger { func createLogger(c *cli.Context) *zerolog.Logger {
level, levelErr := zerolog.ParseLevel(c.String(cfdflags.LogLevel)) level, levelErr := zerolog.ParseLevel(c.String(logger.LogLevelFlag))
if levelErr != nil { if levelErr != nil {
level = zerolog.InfoLevel level = zerolog.InfoLevel
} }
@ -181,10 +183,9 @@ func createLogger(c *cli.Context) *zerolog.Logger {
// parseFilters will attempt to parse provided filters to send to with the EventStartStreaming // parseFilters will attempt to parse provided filters to send to with the EventStartStreaming
func parseFilters(c *cli.Context) (*management.StreamingFilters, error) { func parseFilters(c *cli.Context) (*management.StreamingFilters, error) {
var level *management.LogLevel var level *management.LogLevel
var events []management.LogEventType
var sample float64 var sample float64
events := make([]management.LogEventType, 0)
argLevel := c.String("level") argLevel := c.String("level")
argEvents := c.StringSlice("event") argEvents := c.StringSlice("event")
argSample := c.Float64("sample") argSample := c.Float64("sample")
@ -224,12 +225,12 @@ func parseFilters(c *cli.Context) (*management.StreamingFilters, error) {
// getManagementToken will make a call to the Cloudflare API to acquire a management token for the requested tunnel. // getManagementToken will make a call to the Cloudflare API to acquire a management token for the requested tunnel.
func getManagementToken(c *cli.Context, log *zerolog.Logger) (string, error) { func getManagementToken(c *cli.Context, log *zerolog.Logger) (string, error) {
userCreds, err := credentials.Read(c.String(cfdflags.OriginCert), log) userCreds, err := credentials.Read(c.String(credentials.OriginCertFlag), log)
if err != nil { if err != nil {
return "", err return "", err
} }
client, err := userCreds.Client(c.String(cfdflags.ApiURL), buildInfo.UserAgent(), log) client, err := userCreds.Client(c.String("api-url"), buildInfo.UserAgent(), log)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -330,7 +331,6 @@ func Run(c *cli.Context) error {
header["cf-trace-id"] = []string{trace} header["cf-trace-id"] = []string{trace}
} }
ctx := c.Context ctx := c.Context
// nolint: bodyclose
conn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{ conn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
HTTPHeader: header, HTTPHeader: header,
}) })

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
"path/filepath"
"runtime/trace" "runtime/trace"
"strings" "strings"
"sync" "sync"
@ -16,7 +15,7 @@ import (
"github.com/facebookgo/grace/gracenet" "github.com/facebookgo/grace/gracenet"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mitchellh/go-homedir" homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
@ -24,14 +23,13 @@ import (
"github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/cmd/cloudflared/proxydns" "github.com/cloudflare/cloudflared/cmd/cloudflared/proxydns"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater" "github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/diagnostic"
"github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/features"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/management"
@ -41,13 +39,67 @@ import (
"github.com/cloudflare/cloudflared/supervisor" "github.com/cloudflare/cloudflared/supervisor"
"github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns" "github.com/cloudflare/cloudflared/tunneldns"
"github.com/cloudflare/cloudflared/tunnelstate"
"github.com/cloudflare/cloudflared/validation" "github.com/cloudflare/cloudflared/validation"
) )
const ( const (
sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878" sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
// ha-Connections specifies how many connections to make to the edge
haConnectionsFlag = "ha-connections"
// sshPortFlag is the port on localhost the cloudflared ssh server will run on
sshPortFlag = "local-ssh-port"
// sshIdleTimeoutFlag defines the duration a SSH session can remain idle before being closed
sshIdleTimeoutFlag = "ssh-idle-timeout"
// sshMaxTimeoutFlag defines the max duration a SSH session can remain open for
sshMaxTimeoutFlag = "ssh-max-timeout"
// bucketNameFlag is the bucket name to use for the SSH log uploader
bucketNameFlag = "bucket-name"
// regionNameFlag is the AWS region name to use for the SSH log uploader
regionNameFlag = "region-name"
// secretIDFlag is the Secret id of SSH log uploader
secretIDFlag = "secret-id"
// accessKeyIDFlag is the Access key id of SSH log uploader
accessKeyIDFlag = "access-key-id"
// sessionTokenIDFlag is the Session token of SSH log uploader
sessionTokenIDFlag = "session-token"
// s3URLFlag is the S3 URL of SSH log uploader (e.g. don't use AWS s3 and use google storage bucket instead)
s3URLFlag = "s3-url-host"
// hostKeyPath is the path of the dir to save SSH host keys too
hostKeyPath = "host-key-path"
// rpcTimeout is how long to wait for a Capnp RPC request to the edge
rpcTimeout = "rpc-timeout"
// writeStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
writeStreamTimeout = "write-stream-timeout"
// quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
quicDisablePathMTUDiscovery = "quic-disable-pmtu-discovery"
// quicConnLevelFlowControlLimit controls the max flow control limit allocated for a QUIC connection. This controls how much data is the
// receiver willing to buffer. Once the limit is reached, the sender will send a DATA_BLOCKED frame to indicate it has more data to write,
// but it's blocked by flow control
quicConnLevelFlowControlLimit = "quic-connection-level-flow-control-limit"
// quicStreamLevelFlowControlLimit is similar to quicConnLevelFlowControlLimit but for each QUIC stream. When the sender is blocked,
// it will send a STREAM_DATA_BLOCKED frame
quicStreamLevelFlowControlLimit = "quic-stream-level-flow-control-limit"
// uiFlag is to enable launching cloudflared in interactive UI mode
uiFlag = "ui"
LogFieldCommand = "command" LogFieldCommand = "command"
LogFieldExpandedPath = "expandedPath" LogFieldExpandedPath = "expandedPath"
LogFieldPIDPathname = "pidPathname" LogFieldPIDPathname = "pidPathname"
@ -62,6 +114,7 @@ Eg. cloudflared tunnel --url localhost:8080/.
Please note that Quick Tunnels are meant to be ephemeral and should only be used for testing purposes. Please note that Quick Tunnels are meant to be ephemeral and should only be used for testing purposes.
For production usage, we recommend creating Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/) For production usage, we recommend creating Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)
` `
connectorLabelFlag = "label"
) )
var ( var (
@ -71,96 +124,7 @@ var (
routeFailMsg = fmt.Sprintf("failed to provision routing, please create it manually via Cloudflare dashboard or UI; "+ routeFailMsg = fmt.Sprintf("failed to provision routing, please create it manually via Cloudflare dashboard or UI; "+
"most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+ "most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+
"any existing DNS records for this hostname.", overwriteDNSFlag) "any existing DNS records for this hostname.", overwriteDNSFlag)
errDeprecatedClassicTunnel = errors.New("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)") deprecatedClassicTunnelErr = fmt.Errorf("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)")
// TODO: TUN-8756 the list below denotes the flags that do not possess any kind of sensitive information
// however this approach is not maintainble in the long-term.
nonSecretFlagsList = []string{
"config",
cfdflags.AutoUpdateFreq,
cfdflags.NoAutoUpdate,
cfdflags.Metrics,
"pidfile",
"url",
"hello-world",
"socks5",
"proxy-connect-timeout",
"proxy-tls-timeout",
"proxy-tcp-keepalive",
"proxy-no-happy-eyeballs",
"proxy-keepalive-connections",
"proxy-keepalive-timeout",
"proxy-connection-timeout",
"proxy-expect-continue-timeout",
"http-host-header",
"origin-server-name",
"unix-socket",
"origin-ca-pool",
"no-tls-verify",
"no-chunked-encoding",
"http2-origin",
"management-hostname",
"service-op-ip",
"local-ssh-port",
"ssh-idle-timeout",
"ssh-max-timeout",
"bucket-name",
"region-name",
"s3-url-host",
"host-key-path",
"ssh-server",
"bastion",
"proxy-address",
"proxy-port",
cfdflags.LogLevel,
cfdflags.TransportLogLevel,
cfdflags.LogFile,
cfdflags.LogDirectory,
cfdflags.TraceOutput,
cfdflags.ProxyDns,
"proxy-dns-port",
"proxy-dns-address",
"proxy-dns-upstream",
"proxy-dns-max-upstream-conns",
"proxy-dns-bootstrap",
cfdflags.IsAutoUpdated,
cfdflags.Edge,
cfdflags.Region,
cfdflags.EdgeIpVersion,
cfdflags.EdgeBindAddress,
"cacert",
"hostname",
"id",
cfdflags.LBPool,
cfdflags.ApiURL,
cfdflags.MetricsUpdateFreq,
cfdflags.Tag,
"heartbeat-interval",
"heartbeat-count",
cfdflags.MaxEdgeAddrRetries,
cfdflags.Retries,
"ha-connections",
"rpc-timeout",
"write-stream-timeout",
"quic-disable-pmtu-discovery",
"quic-connection-level-flow-control-limit",
"quic-stream-level-flow-control-limit",
cfdflags.ConnectorLabel,
cfdflags.GracePeriod,
"compression-quality",
"use-reconnect-token",
"dial-edge-timeout",
"stdin-control",
cfdflags.Name,
cfdflags.Ui,
"quick-service",
"max-fetch-size",
cfdflags.PostQuantum,
"management-diagnostics",
cfdflags.Protocol,
"overwrite-dns",
"help",
cfdflags.MaxActiveFlows,
}
) )
func Flags() []cli.Flag { func Flags() []cli.Flag {
@ -181,7 +145,6 @@ func Commands() []*cli.Command {
buildDeleteCommand(), buildDeleteCommand(),
buildCleanupCommand(), buildCleanupCommand(),
buildTokenCommand(), buildTokenCommand(),
buildDiagCommand(),
// for compatibility, allow following as tunnel subcommands // for compatibility, allow following as tunnel subcommands
proxydns.Command(true), proxydns.Command(true),
cliutil.RemovedCommand("db-connect"), cliutil.RemovedCommand("db-connect"),
@ -208,7 +171,7 @@ then protect with Cloudflare Access).
B) Locally reachable TCP/UDP-based private services to Cloudflare connected private users in the same account, e.g., B) Locally reachable TCP/UDP-based private services to Cloudflare connected private users in the same account, e.g.,
those enrolled to a Zero Trust WARP Client. those enrolled to a Zero Trust WARP Client.
You can manage your Tunnels via one.dash.cloudflare.com. This approach will only require you to run a single command You can manage your Tunnels via dash.teams.cloudflare.com. This approach will only require you to run a single command
later in each machine where you wish to run a Tunnel. later in each machine where you wish to run a Tunnel.
Alternatively, you can manage your Tunnels via the command line. Begin by obtaining a certificate to be able to do so: Alternatively, you can manage your Tunnels via the command line. Begin by obtaining a certificate to be able to do so:
@ -244,7 +207,7 @@ func TunnelCommand(c *cli.Context) error {
// --name required // --name required
// --url or --hello-world required // --url or --hello-world required
// --hostname optional // --hostname optional
if name := c.String(cfdflags.Name); name != "" { if name := c.String("name"); name != "" {
hostname, err := validation.ValidateHostname(c.String("hostname")) hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil { if err != nil {
return errors.Wrap(err, "Invalid hostname provided") return errors.Wrap(err, "Invalid hostname provided")
@ -261,7 +224,7 @@ func TunnelCommand(c *cli.Context) error {
// A unauthenticated named tunnel hosted on <random>.<quick-tunnels-service>.com // A unauthenticated named tunnel hosted on <random>.<quick-tunnels-service>.com
// We don't support running proxy-dns and a quick tunnel at the same time as the same process // We don't support running proxy-dns and a quick tunnel at the same time as the same process
shouldRunQuickTunnel := c.IsSet("url") || c.IsSet(ingress.HelloWorldFlag) shouldRunQuickTunnel := c.IsSet("url") || c.IsSet(ingress.HelloWorldFlag)
if !c.IsSet(cfdflags.ProxyDns) && c.String("quick-service") != "" && shouldRunQuickTunnel { if !c.IsSet("proxy-dns") && c.String("quick-service") != "" && shouldRunQuickTunnel {
return RunQuickTunnel(sc) return RunQuickTunnel(sc)
} }
@ -272,10 +235,10 @@ func TunnelCommand(c *cli.Context) error {
// Classic tunnel usage is no longer supported // Classic tunnel usage is no longer supported
if c.String("hostname") != "" { if c.String("hostname") != "" {
return errDeprecatedClassicTunnel return deprecatedClassicTunnelErr
} }
if c.IsSet(cfdflags.ProxyDns) { if c.IsSet("proxy-dns") {
if shouldRunQuickTunnel { if shouldRunQuickTunnel {
return fmt.Errorf("running a quick tunnel with `proxy-dns` is not supported") return fmt.Errorf("running a quick tunnel with `proxy-dns` is not supported")
} }
@ -322,7 +285,7 @@ func runAdhocNamedTunnel(sc *subcommandContext, name, credentialsOutputPath stri
func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) { func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) {
if hostname := c.String("hostname"); hostname != "" { if hostname := c.String("hostname"); hostname != "" {
if lbPool := c.String(cfdflags.LBPool); lbPool != "" { if lbPool := c.String("lb-pool"); lbPool != "" {
return cfapi.NewLBRoute(hostname, lbPool), true return cfapi.NewLBRoute(hostname, lbPool), true
} }
return cfapi.NewDNSRoute(hostname, c.Bool(overwriteDNSFlagName)), true return cfapi.NewDNSRoute(hostname, c.Bool(overwriteDNSFlagName)), true
@ -352,7 +315,7 @@ func StartServer(
log.Info().Msg(config.ErrNoConfigFile.Error()) log.Info().Msg(config.ErrNoConfigFile.Error())
} }
if c.IsSet(cfdflags.TraceOutput) { if c.IsSet("trace-output") {
tmpTraceFile, err := os.CreateTemp("", "trace") tmpTraceFile, err := os.CreateTemp("", "trace")
if err != nil { if err != nil {
log.Err(err).Msg("Failed to create new temporary file to save trace output") log.Err(err).Msg("Failed to create new temporary file to save trace output")
@ -364,7 +327,7 @@ func StartServer(
if err := tmpTraceFile.Close(); err != nil { if err := tmpTraceFile.Close(); err != nil {
traceLog.Err(err).Msg("Failed to close temporary trace output file") traceLog.Err(err).Msg("Failed to close temporary trace output file")
} }
traceOutputFilepath := c.String(cfdflags.TraceOutput) traceOutputFilepath := c.String("trace-output")
if err := os.Rename(tmpTraceFile.Name(), traceOutputFilepath); err != nil { if err := os.Rename(tmpTraceFile.Name(), traceOutputFilepath); err != nil {
traceLog. traceLog.
Err(err). Err(err).
@ -394,7 +357,7 @@ func StartServer(
go waitForSignal(graceShutdownC, log) go waitForSignal(graceShutdownC, log)
if c.IsSet(cfdflags.ProxyDns) { if c.IsSet("proxy-dns") {
dnsReadySignal := make(chan struct{}) dnsReadySignal := make(chan struct{})
wg.Add(1) wg.Add(1)
go func() { go func() {
@ -416,7 +379,7 @@ func StartServer(
go func() { go func() {
defer wg.Done() defer wg.Done()
autoupdater := updater.NewAutoUpdater( autoupdater := updater.NewAutoUpdater(
c.Bool(cfdflags.NoAutoUpdate), c.Duration(cfdflags.AutoUpdateFreq), &listeners, log, c.Bool("no-autoupdate"), c.Duration("autoupdate-freq"), &listeners, log,
) )
errC <- autoupdater.Run(ctx) errC <- autoupdater.Run(ctx)
}() }()
@ -457,74 +420,54 @@ func StartServer(
// Disable ICMP packet routing for quick tunnels // Disable ICMP packet routing for quick tunnels
if quickTunnelURL != "" { if quickTunnelURL != "" {
tunnelConfig.ICMPRouterServer = nil tunnelConfig.PacketConfig = nil
} }
serviceIP := c.String("service-op-ip") internalRules := []ingress.Rule{}
if edgeAddrs, err := edgediscovery.ResolveEdge(log, tunnelConfig.Region, tunnelConfig.EdgeIPVersion); err == nil { if features.Contains(features.FeatureManagementLogs) {
if serviceAddr, err := edgeAddrs.GetAddrForRPC(); err == nil { serviceIP := c.String("service-op-ip")
serviceIP = serviceAddr.TCP.String() if edgeAddrs, err := edgediscovery.ResolveEdge(log, tunnelConfig.Region, tunnelConfig.EdgeIPVersion); err == nil {
if serviceAddr, err := edgeAddrs.GetAddrForRPC(); err == nil {
serviceIP = serviceAddr.TCP.String()
}
} }
}
mgmt := management.New( mgmt := management.New(
c.String("management-hostname"), c.String("management-hostname"),
c.Bool("management-diagnostics"), c.Bool("management-diagnostics"),
serviceIP, serviceIP,
clientID, clientID,
c.String(cfdflags.ConnectorLabel), c.String(connectorLabelFlag),
logger.ManagementLogger.Log, logger.ManagementLogger.Log,
logger.ManagementLogger, logger.ManagementLogger,
) )
internalRules := []ingress.Rule{ingress.NewManagementRule(mgmt)} internalRules = []ingress.Rule{ingress.NewManagementRule(mgmt)}
}
orchestrator, err := orchestration.NewOrchestrator(ctx, orchestratorConfig, tunnelConfig.Tags, internalRules, tunnelConfig.Log) orchestrator, err := orchestration.NewOrchestrator(ctx, orchestratorConfig, tunnelConfig.Tags, internalRules, tunnelConfig.Log)
if err != nil { if err != nil {
return err return err
} }
metricsListener, err := metrics.CreateMetricsListener(&listeners, c.String("metrics")) metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
if err != nil { if err != nil {
log.Err(err).Msg("Error opening metrics server listener") log.Err(err).Msg("Error opening metrics server listener")
return errors.Wrap(err, "Error opening metrics server listener") return errors.Wrap(err, "Error opening metrics server listener")
} }
defer metricsListener.Close() defer metricsListener.Close()
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
tracker := tunnelstate.NewConnTracker(log) readinessServer := metrics.NewReadyServer(log, clientID)
observer.RegisterSink(tracker) observer.RegisterSink(readinessServer)
ipv4, ipv6, err := determineICMPSources(c, log)
sources := make([]string, 0)
if err == nil {
sources = append(sources, ipv4.String())
sources = append(sources, ipv6.String())
}
readinessServer := metrics.NewReadyServer(clientID, tracker)
cliFlags := nonSecretCliFlags(log, c, nonSecretFlagsList)
diagnosticHandler := diagnostic.NewDiagnosticHandler(
log,
0,
diagnostic.NewSystemCollectorImpl(buildInfo.CloudflaredVersion),
tunnelConfig.NamedTunnel.Credentials.TunnelID,
clientID,
tracker,
cliFlags,
sources,
)
metricsConfig := metrics.Config{ metricsConfig := metrics.Config{
ReadyServer: readinessServer, ReadyServer: readinessServer,
DiagnosticHandler: diagnosticHandler,
QuickTunnelHostname: quickTunnelURL, QuickTunnelHostname: quickTunnelURL,
Orchestrator: orchestrator, Orchestrator: orchestrator,
} }
errC <- metrics.ServeMetrics(metricsListener, ctx, metricsConfig, log) errC <- metrics.ServeMetrics(metricsListener, ctx, metricsConfig, log)
}() }()
reconnectCh := make(chan supervisor.ReconnectSignal, c.Int(cfdflags.HaConnections)) reconnectCh := make(chan supervisor.ReconnectSignal, c.Int(haConnectionsFlag))
if c.IsSet("stdin-control") { if c.IsSet("stdin-control") {
log.Info().Msg("Enabling control through stdin") log.Info().Msg("Enabling control through stdin")
go stdinControl(reconnectCh, log) go stdinControl(reconnectCh, log)
@ -561,10 +504,8 @@ func waitToShutdown(wg *sync.WaitGroup,
log.Debug().Msg("Graceful shutdown signalled") log.Debug().Msg("Graceful shutdown signalled")
if gracePeriod > 0 { if gracePeriod > 0 {
// wait for either grace period or service termination // wait for either grace period or service termination
ticker := time.NewTicker(gracePeriod)
defer ticker.Stop()
select { select {
case <-ticker.C: case <-time.Tick(gracePeriod):
case <-errC: case <-errC:
} }
} }
@ -592,7 +533,7 @@ func waitToShutdown(wg *sync.WaitGroup,
func notifySystemd(waitForSignal *signal.Signal) { func notifySystemd(waitForSignal *signal.Signal) {
<-waitForSignal.Wait() <-waitForSignal.Wait()
_, _ = daemon.SdNotify(false, "READY=1") daemon.SdNotify(false, "READY=1")
} }
func writePidFile(waitForSignal *signal.Signal, pidPathname string, log *zerolog.Logger) { func writePidFile(waitForSignal *signal.Signal, pidPathname string, log *zerolog.Logger) {
@ -644,31 +585,31 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
flags = append(flags, []cli.Flag{ flags = append(flags, []cli.Flag{
credentialsFileFlag, credentialsFileFlag,
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: cfdflags.IsAutoUpdated, Name: "is-autoupdated",
Usage: "Signal the new process that Cloudflare Tunnel connector has been autoupdated", Usage: "Signal the new process that Cloudflare Tunnel connector has been autoupdated",
Value: false, Value: false,
Hidden: true, Hidden: true,
}), }),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: cfdflags.Edge, Name: "edge",
Usage: "Address of the Cloudflare tunnel server. Only works in Cloudflare's internal testing environment.", Usage: "Address of the Cloudflare tunnel server. Only works in Cloudflare's internal testing environment.",
EnvVars: []string{"TUNNEL_EDGE"}, EnvVars: []string{"TUNNEL_EDGE"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.Region, Name: "region",
Usage: "Cloudflare Edge region to connect to. Omit or set to empty to connect to the global region.", Usage: "Cloudflare Edge region to connect to. Omit or set to empty to connect to the global region.",
EnvVars: []string{"TUNNEL_REGION"}, EnvVars: []string{"TUNNEL_REGION"},
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.EdgeIpVersion, Name: "edge-ip-version",
Usage: "Cloudflare Edge IP address version to connect with. {4, 6, auto}", Usage: "Cloudflare Edge IP address version to connect with. {4, 6, auto}",
EnvVars: []string{"TUNNEL_EDGE_IP_VERSION"}, EnvVars: []string{"TUNNEL_EDGE_IP_VERSION"},
Value: "4", Value: "4",
Hidden: false, Hidden: false,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.EdgeBindAddress, Name: "edge-bind-address",
Usage: "Bind to IP address for outgoing connections to Cloudflare Edge.", Usage: "Bind to IP address for outgoing connections to Cloudflare Edge.",
EnvVars: []string{"TUNNEL_EDGE_BIND_ADDRESS"}, EnvVars: []string{"TUNNEL_EDGE_BIND_ADDRESS"},
Hidden: false, Hidden: false,
@ -692,7 +633,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: true, Hidden: true,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.LBPool, Name: "lb-pool",
Usage: "The name of a (new/existing) load balancing pool to add this origin to.", Usage: "The name of a (new/existing) load balancing pool to add this origin to.",
EnvVars: []string{"TUNNEL_LB_POOL"}, EnvVars: []string{"TUNNEL_LB_POOL"},
Hidden: shouldHide, Hidden: shouldHide,
@ -716,21 +657,21 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: true, Hidden: true,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.ApiURL, Name: "api-url",
Usage: "Base URL for Cloudflare API v4", Usage: "Base URL for Cloudflare API v4",
EnvVars: []string{"TUNNEL_API_URL"}, EnvVars: []string{"TUNNEL_API_URL"},
Value: "https://api.cloudflare.com/client/v4", Value: "https://api.cloudflare.com/client/v4",
Hidden: true, Hidden: true,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: cfdflags.MetricsUpdateFreq, Name: "metrics-update-freq",
Usage: "Frequency to update tunnel metrics", Usage: "Frequency to update tunnel metrics",
Value: time.Second * 5, Value: time.Second * 5,
EnvVars: []string{"TUNNEL_METRICS_UPDATE_FREQ"}, EnvVars: []string{"TUNNEL_METRICS_UPDATE_FREQ"},
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: cfdflags.Tag, Name: "tag",
Usage: "Custom tags used to identify this tunnel via added HTTP request headers to the origin, in format `KEY=VALUE`. Multiple tags may be specified.", Usage: "Custom tags used to identify this tunnel via added HTTP request headers to the origin, in format `KEY=VALUE`. Multiple tags may be specified.",
EnvVars: []string{"TUNNEL_TAG"}, EnvVars: []string{"TUNNEL_TAG"},
Hidden: true, Hidden: true,
@ -749,64 +690,64 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: true, Hidden: true,
}), }),
altsrc.NewIntFlag(&cli.IntFlag{ altsrc.NewIntFlag(&cli.IntFlag{
Name: cfdflags.MaxEdgeAddrRetries, Name: "max-edge-addr-retries",
Usage: "Maximum number of times to retry on edge addrs before falling back to a lower protocol", Usage: "Maximum number of times to retry on edge addrs before falling back to a lower protocol",
Value: 8, Value: 8,
Hidden: true, Hidden: true,
}), }),
// Note TUN-3758 , we use Int because UInt is not supported with altsrc // Note TUN-3758 , we use Int because UInt is not supported with altsrc
altsrc.NewIntFlag(&cli.IntFlag{ altsrc.NewIntFlag(&cli.IntFlag{
Name: cfdflags.Retries, Name: "retries",
Value: 5, Value: 5,
Usage: "Maximum number of retries for connection/protocol errors.", Usage: "Maximum number of retries for connection/protocol errors.",
EnvVars: []string{"TUNNEL_RETRIES"}, EnvVars: []string{"TUNNEL_RETRIES"},
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewIntFlag(&cli.IntFlag{ altsrc.NewIntFlag(&cli.IntFlag{
Name: cfdflags.HaConnections, Name: haConnectionsFlag,
Value: 4, Value: 4,
Hidden: true, Hidden: true,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: cfdflags.RpcTimeout, Name: rpcTimeout,
Value: 5 * time.Second, Value: 5 * time.Second,
Hidden: true, Hidden: true,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: cfdflags.WriteStreamTimeout, Name: writeStreamTimeout,
EnvVars: []string{"TUNNEL_STREAM_WRITE_TIMEOUT"}, EnvVars: []string{"TUNNEL_STREAM_WRITE_TIMEOUT"},
Usage: "Use this option to add a stream write timeout for connections when writing towards the origin or edge. Default is 0 which disables the write timeout.", Usage: "Use this option to add a stream write timeout for connections when writing towards the origin or edge. Default is 0 which disables the write timeout.",
Value: 0 * time.Second, Value: 0 * time.Second,
Hidden: true, Hidden: true,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: cfdflags.QuicDisablePathMTUDiscovery, Name: quicDisablePathMTUDiscovery,
EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"}, EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"},
Usage: "Use this option to disable PTMU discovery for QUIC connections. This will result in lower packet sizes. Not however, that this may cause instability for UDP proxying.", Usage: "Use this option to disable PTMU discovery for QUIC connections. This will result in lower packet sizes. Not however, that this may cause instability for UDP proxying.",
Value: false, Value: false,
Hidden: true, Hidden: true,
}), }),
altsrc.NewIntFlag(&cli.IntFlag{ altsrc.NewIntFlag(&cli.IntFlag{
Name: cfdflags.QuicConnLevelFlowControlLimit, Name: quicConnLevelFlowControlLimit,
EnvVars: []string{"TUNNEL_QUIC_CONN_LEVEL_FLOW_CONTROL_LIMIT"}, EnvVars: []string{"TUNNEL_QUIC_CONN_LEVEL_FLOW_CONTROL_LIMIT"},
Usage: "Use this option to change the connection-level flow control limit for QUIC transport.", Usage: "Use this option to change the connection-level flow control limit for QUIC transport.",
Value: 30 * (1 << 20), // 30 MB Value: 30 * (1 << 20), // 30 MB
Hidden: true, Hidden: true,
}), }),
altsrc.NewIntFlag(&cli.IntFlag{ altsrc.NewIntFlag(&cli.IntFlag{
Name: cfdflags.QuicStreamLevelFlowControlLimit, Name: quicStreamLevelFlowControlLimit,
EnvVars: []string{"TUNNEL_QUIC_STREAM_LEVEL_FLOW_CONTROL_LIMIT"}, EnvVars: []string{"TUNNEL_QUIC_STREAM_LEVEL_FLOW_CONTROL_LIMIT"},
Usage: "Use this option to change the connection-level flow control limit for QUIC transport.", Usage: "Use this option to change the connection-level flow control limit for QUIC transport.",
Value: 6 * (1 << 20), // 6 MB Value: 6 * (1 << 20), // 6 MB
Hidden: true, Hidden: true,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.ConnectorLabel, Name: connectorLabelFlag,
Usage: "Use this option to give a meaningful label to a specific connector. When a tunnel starts up, a connector id unique to the tunnel is generated. This is a uuid. To make it easier to identify a connector, we will use the hostname of the machine the tunnel is running on along with the connector ID. This option exists if one wants to have more control over what their individual connectors are called.", Usage: "Use this option to give a meaningful label to a specific connector. When a tunnel starts up, a connector id unique to the tunnel is generated. This is a uuid. To make it easier to identify a connector, we will use the hostname of the machine the tunnel is running on along with the connector ID. This option exists if one wants to have more control over what their individual connectors are called.",
Value: "", Value: "",
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: cfdflags.GracePeriod, Name: "grace-period",
Usage: "When cloudflared receives SIGINT/SIGTERM it will stop accepting new requests, wait for in-progress requests to terminate, then shutdown. Waiting for in-progress requests will timeout after this grace period, or when a second SIGTERM/SIGINT is received.", Usage: "When cloudflared receives SIGINT/SIGTERM it will stop accepting new requests, wait for in-progress requests to terminate, then shutdown. Waiting for in-progress requests will timeout after this grace period, or when a second SIGTERM/SIGINT is received.",
Value: time.Second * 30, Value: time.Second * 30,
EnvVars: []string{"TUNNEL_GRACE_PERIOD"}, EnvVars: []string{"TUNNEL_GRACE_PERIOD"},
@ -842,14 +783,14 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Value: false, Value: false,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.Name, Name: "name",
Aliases: []string{"n"}, Aliases: []string{"n"},
EnvVars: []string{"TUNNEL_NAME"}, EnvVars: []string{"TUNNEL_NAME"},
Usage: "Stable name to identify the tunnel. Using this flag will create, route and run a tunnel. For production usage, execute each command separately", Usage: "Stable name to identify the tunnel. Using this flag will create, route and run a tunnel. For production usage, execute each command separately",
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: cfdflags.Ui, Name: uiFlag,
Usage: "(depreciated) Launch tunnel UI. Tunnel logs are scrollable via 'j', 'k', or arrow keys.", Usage: "(depreciated) Launch tunnel UI. Tunnel logs are scrollable via 'j', 'k', or arrow keys.",
Value: false, Value: false,
Hidden: true, Hidden: true,
@ -867,10 +808,11 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: true, Hidden: true,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: cfdflags.PostQuantum, Name: "post-quantum",
Usage: "When given creates an experimental post-quantum secure tunnel", Usage: "When given creates an experimental post-quantum secure tunnel",
Aliases: []string{"pq"}, Aliases: []string{"pq"},
EnvVars: []string{"TUNNEL_POST_QUANTUM"}, EnvVars: []string{"TUNNEL_POST_QUANTUM"},
Hidden: FipsEnabled,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "management-diagnostics", Name: "management-diagnostics",
@ -895,35 +837,29 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide, Hidden: shouldHide,
}, },
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.OriginCert, Name: credentials.OriginCertFlag,
Usage: "Path to the certificate generated for your origin when you run cloudflared login.", Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
EnvVars: []string{"TUNNEL_ORIGIN_CERT"}, EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: credentials.FindDefaultOriginCertPath(), Value: credentials.FindDefaultOriginCertPath(),
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: cfdflags.AutoUpdateFreq, Name: "autoupdate-freq",
Usage: fmt.Sprintf("Autoupdate frequency. Default is %v.", updater.DefaultCheckUpdateFreq), Usage: fmt.Sprintf("Autoupdate frequency. Default is %v.", updater.DefaultCheckUpdateFreq),
Value: updater.DefaultCheckUpdateFreq, Value: updater.DefaultCheckUpdateFreq,
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: cfdflags.NoAutoUpdate, Name: "no-autoupdate",
Usage: "Disable periodic check for updates, restarting the server with the new version.", Usage: "Disable periodic check for updates, restarting the server with the new version.",
EnvVars: []string{"NO_AUTOUPDATE"}, EnvVars: []string{"NO_AUTOUPDATE"},
Value: false, Value: false,
Hidden: shouldHide, Hidden: shouldHide,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.Metrics, Name: "metrics",
Value: metrics.GetMetricsDefaultAddress(metrics.Runtime), Value: "localhost:",
Usage: fmt.Sprintf( Usage: "Listen address for metrics reporting.",
`Listen address for metrics reporting. If no address is passed cloudflared will try to bind to %v.
If all are unavailable, a random port will be used. Note that when running cloudflared from an virtual
environment the default address binds to all interfaces, hence, it is important to isolate the host
and virtualized host network stacks from each other`,
metrics.GetMetricsKnownAddresses(metrics.Runtime),
),
EnvVars: []string{"TUNNEL_METRICS"}, EnvVars: []string{"TUNNEL_METRICS"},
Hidden: shouldHide, Hidden: shouldHide,
}), }),
@ -1079,62 +1015,62 @@ func legacyTunnelFlag(msg string) string {
func sshFlags(shouldHide bool) []cli.Flag { func sshFlags(shouldHide bool) []cli.Flag {
return []cli.Flag{ return []cli.Flag{
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.SshPort, Name: sshPortFlag,
Usage: "Localhost port that cloudflared SSH server will run on", Usage: "Localhost port that cloudflared SSH server will run on",
Value: "2222", Value: "2222",
EnvVars: []string{"LOCAL_SSH_PORT"}, EnvVars: []string{"LOCAL_SSH_PORT"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: cfdflags.SshIdleTimeout, Name: sshIdleTimeoutFlag,
Usage: "Connection timeout after no activity", Usage: "Connection timeout after no activity",
EnvVars: []string{"SSH_IDLE_TIMEOUT"}, EnvVars: []string{"SSH_IDLE_TIMEOUT"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{ altsrc.NewDurationFlag(&cli.DurationFlag{
Name: cfdflags.SshMaxTimeout, Name: sshMaxTimeoutFlag,
Usage: "Absolute connection timeout", Usage: "Absolute connection timeout",
EnvVars: []string{"SSH_MAX_TIMEOUT"}, EnvVars: []string{"SSH_MAX_TIMEOUT"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.SshLogUploaderBucketName, Name: bucketNameFlag,
Usage: "Bucket name of where to upload SSH logs", Usage: "Bucket name of where to upload SSH logs",
EnvVars: []string{"BUCKET_ID"}, EnvVars: []string{"BUCKET_ID"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.SshLogUploaderRegionName, Name: regionNameFlag,
Usage: "Region name of where to upload SSH logs", Usage: "Region name of where to upload SSH logs",
EnvVars: []string{"REGION_ID"}, EnvVars: []string{"REGION_ID"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.SshLogUploaderSecretID, Name: secretIDFlag,
Usage: "Secret ID of where to upload SSH logs", Usage: "Secret ID of where to upload SSH logs",
EnvVars: []string{"SECRET_ID"}, EnvVars: []string{"SECRET_ID"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.SshLogUploaderAccessKeyID, Name: accessKeyIDFlag,
Usage: "Access Key ID of where to upload SSH logs", Usage: "Access Key ID of where to upload SSH logs",
EnvVars: []string{"ACCESS_CLIENT_ID"}, EnvVars: []string{"ACCESS_CLIENT_ID"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.SshLogUploaderSessionTokenID, Name: sessionTokenIDFlag,
Usage: "Session Token to use in the configuration of SSH logs uploading", Usage: "Session Token to use in the configuration of SSH logs uploading",
EnvVars: []string{"SESSION_TOKEN_ID"}, EnvVars: []string{"SESSION_TOKEN_ID"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewStringFlag(&cli.StringFlag{ altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.SshLogUploaderS3URL, Name: s3URLFlag,
Usage: "S3 url of where to upload SSH logs", Usage: "S3 url of where to upload SSH logs",
EnvVars: []string{"S3_URL"}, EnvVars: []string{"S3_URL"},
Hidden: true, Hidden: true,
}), }),
altsrc.NewPathFlag(&cli.PathFlag{ altsrc.NewPathFlag(&cli.PathFlag{
Name: cfdflags.HostKeyPath, Name: hostKeyPath,
Usage: "Absolute path of directory to save SSH host keys in", Usage: "Absolute path of directory to save SSH host keys in",
EnvVars: []string{"HOST_KEY_PATH"}, EnvVars: []string{"HOST_KEY_PATH"},
Hidden: true, Hidden: true,
@ -1174,7 +1110,7 @@ func sshFlags(shouldHide bool) []cli.Flag {
func configureProxyDNSFlags(shouldHide bool) []cli.Flag { func configureProxyDNSFlags(shouldHide bool) []cli.Flag {
return []cli.Flag{ return []cli.Flag{
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: cfdflags.ProxyDns, Name: "proxy-dns",
Usage: "Run a DNS over HTTPS proxy server.", Usage: "Run a DNS over HTTPS proxy server.",
EnvVars: []string{"TUNNEL_DNS"}, EnvVars: []string{"TUNNEL_DNS"},
Hidden: shouldHide, Hidden: shouldHide,
@ -1254,46 +1190,3 @@ reconnect [delay]
} }
} }
} }
func nonSecretCliFlags(log *zerolog.Logger, cli *cli.Context, flagInclusionList []string) map[string]string {
flagsNames := cli.FlagNames()
flags := make(map[string]string, len(flagsNames))
for _, flag := range flagsNames {
value := cli.String(flag)
if value == "" {
continue
}
isIncluded := isFlagIncluded(flagInclusionList, flag)
if !isIncluded {
continue
}
switch flag {
case cfdflags.LogDirectory, cfdflags.LogFile:
{
absolute, err := filepath.Abs(value)
if err != nil {
log.Error().Err(err).Msgf("could not convert %s path to absolute", flag)
} else {
flags[flag] = absolute
}
}
default:
flags[flag] = value
}
}
return flags
}
func isFlagIncluded(flagInclusionList []string, flag string) bool {
for _, include := range flagInclusionList {
if include == flag {
return true
}
}
return false
}

View File

@ -18,7 +18,6 @@ import (
"golang.org/x/term" "golang.org/x/term"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery" "github.com/cloudflare/cloudflared/edgediscovery"
@ -34,28 +33,27 @@ import (
const ( const (
secretValue = "*****" secretValue = "*****"
icmpFunnelTimeout = time.Second * 10 icmpFunnelTimeout = time.Second * 10
fedRampRegion = "fed" // const string denoting the region used to connect to FEDRamp servers
) )
var ( var (
developerPortal = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup"
serviceUrl = developerPortal + "/tunnel-guide/local/as-a-service/"
argumentsUrl = developerPortal + "/tunnel-guide/local/local-management/arguments/"
secretFlags = [2]*altsrc.StringFlag{credentialsContentsFlag, tunnelTokenFlag} secretFlags = [2]*altsrc.StringFlag{credentialsContentsFlag, tunnelTokenFlag}
configFlags = []string{ configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"}
flags.AutoUpdateFreq,
flags.NoAutoUpdate,
flags.Retries,
flags.Protocol,
flags.LogLevel,
flags.TransportLogLevel,
flags.OriginCert,
flags.Metrics,
flags.MetricsUpdateFreq,
flags.EdgeIpVersion,
flags.EdgeBindAddress,
flags.MaxActiveFlows,
}
) )
func generateRandomClientID(log *zerolog.Logger) (string, error) {
u, err := uuid.NewRandom()
if err != nil {
log.Error().Msgf("couldn't create UUID for client ID %s", err)
return "", err
}
return u.String(), nil
}
func logClientOptions(c *cli.Context, log *zerolog.Logger) { func logClientOptions(c *cli.Context, log *zerolog.Logger) {
flags := make(map[string]interface{}) flags := make(map[string]interface{})
for _, flag := range c.FlagNames() { for _, flag := range c.FlagNames() {
@ -111,8 +109,8 @@ func isSecretEnvVar(key string) bool {
} }
func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.TunnelProperties) bool { func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.TunnelProperties) bool {
return c.IsSet(flags.ProxyDns) && return c.IsSet("proxy-dns") &&
!(c.IsSet(flags.Name) || // adhoc-named tunnel !(c.IsSet("name") || // adhoc-named tunnel
c.IsSet(ingress.HelloWorldFlag) || // quick or named tunnel c.IsSet(ingress.HelloWorldFlag) || // quick or named tunnel
namedTunnel != nil) // named tunnel namedTunnel != nil) // named tunnel
} }
@ -130,21 +128,29 @@ func prepareTunnelConfig(
return nil, nil, errors.Wrap(err, "can't generate connector UUID") return nil, nil, errors.Wrap(err, "can't generate connector UUID")
} }
log.Info().Msgf("Generated Connector ID: %s", clientID) log.Info().Msgf("Generated Connector ID: %s", clientID)
tags, err := NewTagSliceFromCLI(c.StringSlice(flags.Tag)) tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil { if err != nil {
log.Err(err).Msg("Tag parse failure") log.Err(err).Msg("Tag parse failure")
return nil, nil, errors.Wrap(err, "Tag parse failure") return nil, nil, errors.Wrap(err, "Tag parse failure")
} }
tags = append(tags, pogs.Tag{Name: "ID", Value: clientID.String()}) tags = append(tags, pogs.Tag{Name: "ID", Value: clientID.String()})
transportProtocol := c.String(flags.Protocol) transportProtocol := c.String("protocol")
isPostQuantumEnforced := c.Bool(flags.PostQuantum)
featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, c.StringSlice("features"), c.Bool("post-quantum"), log) clientFeatures := features.Dedup(append(c.StringSlice("features"), features.DefaultFeatures...))
staticFeatures := features.StaticFeatures{}
if c.Bool("post-quantum") {
if FipsEnabled {
return nil, nil, fmt.Errorf("post-quantum not supported in FIPS mode")
}
pqMode := features.PostQuantumStrict
staticFeatures.PostQuantumMode = &pqMode
}
featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, staticFeatures, log)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "Failed to create feature selector") return nil, nil, errors.Wrap(err, "Failed to create feature selector")
} }
clientFeatures := featureSelector.ClientFeatures()
pqMode := featureSelector.PostQuantumMode() pqMode := featureSelector.PostQuantumMode()
if pqMode == features.PostQuantumStrict { if pqMode == features.PostQuantumStrict {
// Error if the user tries to force a non-quic transport protocol // Error if the user tries to force a non-quic transport protocol
@ -152,6 +158,12 @@ func prepareTunnelConfig(
return nil, nil, fmt.Errorf("post-quantum is only supported with the quic transport") return nil, nil, fmt.Errorf("post-quantum is only supported with the quic transport")
} }
transportProtocol = connection.QUIC.String() transportProtocol = connection.QUIC.String()
clientFeatures = append(clientFeatures, features.FeaturePostQuantum)
log.Info().Msgf(
"Using hybrid post-quantum key agreement %s",
supervisor.PQKexName,
)
} }
namedTunnel.Client = pogs.ClientInfo{ namedTunnel.Client = pogs.ClientInfo{
@ -166,7 +178,7 @@ func prepareTunnelConfig(
return nil, nil, err return nil, nil, err
} }
protocolSelector, err := connection.NewProtocolSelector(transportProtocol, namedTunnel.Credentials.AccountTag, c.IsSet(TunnelTokenFlag), isPostQuantumEnforced, edgediscovery.ProtocolPercentage, connection.ResolveTTL, log) protocolSelector, err := connection.NewProtocolSelector(transportProtocol, namedTunnel.Credentials.AccountTag, c.IsSet(TunnelTokenFlag), c.Bool("post-quantum"), edgediscovery.ProtocolPercentage, connection.ResolveTTL, log)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -192,11 +204,11 @@ func prepareTunnelConfig(
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
edgeIPVersion, err := parseConfigIPVersion(c.String(flags.EdgeIpVersion)) edgeIPVersion, err := parseConfigIPVersion(c.String("edge-ip-version"))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
edgeBindAddr, err := parseConfigBindAddress(c.String(flags.EdgeBindAddress)) edgeBindAddr, err := parseConfigBindAddress(c.String("edge-bind-address"))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -209,62 +221,48 @@ func prepareTunnelConfig(
log.Warn().Str("edgeIPVersion", edgeIPVersion.String()).Err(err).Msg("Overriding edge-ip-version") log.Warn().Str("edgeIPVersion", edgeIPVersion.String()).Err(err).Msg("Overriding edge-ip-version")
} }
region := c.String(flags.Region)
endpoint := namedTunnel.Credentials.Endpoint
var resolvedRegion string
// set resolvedRegion to either the region passed as argument
// or to the endpoint in the credentials.
// Region and endpoint are interchangeable
if region != "" && endpoint != "" {
return nil, nil, fmt.Errorf("region provided with a token that has an endpoint")
} else if region != "" {
resolvedRegion = region
} else if endpoint != "" {
resolvedRegion = endpoint
}
tunnelConfig := &supervisor.TunnelConfig{ tunnelConfig := &supervisor.TunnelConfig{
GracePeriod: gracePeriod, GracePeriod: gracePeriod,
ReplaceExisting: c.Bool(flags.Force), ReplaceExisting: c.Bool("force"),
OSArch: info.OSArch(), OSArch: info.OSArch(),
ClientID: clientID.String(), ClientID: clientID.String(),
EdgeAddrs: c.StringSlice(flags.Edge), EdgeAddrs: c.StringSlice("edge"),
Region: resolvedRegion, Region: c.String("region"),
EdgeIPVersion: edgeIPVersion, EdgeIPVersion: edgeIPVersion,
EdgeBindAddr: edgeBindAddr, EdgeBindAddr: edgeBindAddr,
HAConnections: c.Int(flags.HaConnections), HAConnections: c.Int(haConnectionsFlag),
IsAutoupdated: c.Bool(flags.IsAutoUpdated), IsAutoupdated: c.Bool("is-autoupdated"),
LBPool: c.String(flags.LBPool), LBPool: c.String("lb-pool"),
Tags: tags, Tags: tags,
Log: log, Log: log,
LogTransport: logTransport, LogTransport: logTransport,
Observer: observer, Observer: observer,
ReportedVersion: info.Version(), ReportedVersion: info.Version(),
// Note TUN-3758 , we use Int because UInt is not supported with altsrc // Note TUN-3758 , we use Int because UInt is not supported with altsrc
Retries: uint(c.Int(flags.Retries)), // nolint: gosec Retries: uint(c.Int("retries")),
RunFromTerminal: isRunningFromTerminal(), RunFromTerminal: isRunningFromTerminal(),
NamedTunnel: namedTunnel, NamedTunnel: namedTunnel,
ProtocolSelector: protocolSelector, ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs, EdgeTLSConfigs: edgeTLSConfigs,
FeatureSelector: featureSelector, FeatureSelector: featureSelector,
MaxEdgeAddrRetries: uint8(c.Int(flags.MaxEdgeAddrRetries)), // nolint: gosec MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
RPCTimeout: c.Duration(flags.RpcTimeout), RPCTimeout: c.Duration(rpcTimeout),
WriteStreamTimeout: c.Duration(flags.WriteStreamTimeout), WriteStreamTimeout: c.Duration(writeStreamTimeout),
DisableQUICPathMTUDiscovery: c.Bool(flags.QuicDisablePathMTUDiscovery), DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
QUICConnectionLevelFlowControlLimit: c.Uint64(flags.QuicConnLevelFlowControlLimit), QUICConnectionLevelFlowControlLimit: c.Uint64(quicConnLevelFlowControlLimit),
QUICStreamLevelFlowControlLimit: c.Uint64(flags.QuicStreamLevelFlowControlLimit), QUICStreamLevelFlowControlLimit: c.Uint64(quicStreamLevelFlowControlLimit),
} }
icmpRouter, err := newICMPRouter(c, log) packetConfig, err := newPacketConfig(c, log)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("ICMP proxy feature is disabled") log.Warn().Err(err).Msg("ICMP proxy feature is disabled")
} else { } else {
tunnelConfig.ICMPRouterServer = icmpRouter tunnelConfig.PacketConfig = packetConfig
} }
orchestratorConfig := &orchestration.Config{ orchestratorConfig := &orchestration.Config{
Ingress: &ingressRules, Ingress: &ingressRules,
WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting), WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting),
ConfigurationFlags: parseConfigFlags(c), ConfigurationFlags: parseConfigFlags(c),
WriteTimeout: tunnelConfig.WriteStreamTimeout, WriteTimeout: c.Duration(writeStreamTimeout),
} }
return tunnelConfig, orchestratorConfig, nil return tunnelConfig, orchestratorConfig, nil
} }
@ -282,9 +280,9 @@ func parseConfigFlags(c *cli.Context) map[string]string {
} }
func gracePeriod(c *cli.Context) (time.Duration, error) { func gracePeriod(c *cli.Context) (time.Duration, error) {
period := c.Duration(flags.GracePeriod) period := c.Duration("grace-period")
if period > connection.MaxGracePeriod { if period > connection.MaxGracePeriod {
return time.Duration(0), fmt.Errorf("%s must be equal or less than %v", flags.GracePeriod, connection.MaxGracePeriod) return time.Duration(0), fmt.Errorf("grace-period must be equal or less than %v", connection.MaxGracePeriod)
} }
return period, nil return period, nil
} }
@ -353,39 +351,33 @@ func adjustIPVersionByBindAddress(ipVersion allregions.ConfigIPVersion, ip net.I
} }
} }
func newICMPRouter(c *cli.Context, logger *zerolog.Logger) (ingress.ICMPRouterServer, error) { func newPacketConfig(c *cli.Context, logger *zerolog.Logger) (*ingress.GlobalRouterConfig, error) {
ipv4Src, ipv6Src, err := determineICMPSources(c, logger) ipv4Src, err := determineICMPv4Src(c.String("icmpv4-src"), logger)
if err != nil { if err != nil {
return nil, err return nil, errors.Wrap(err, "failed to determine IPv4 source address for ICMP proxy")
} }
icmpRouter, err := ingress.NewICMPRouter(ipv4Src, ipv6Src, logger, icmpFunnelTimeout)
if err != nil {
return nil, err
}
return icmpRouter, nil
}
func determineICMPSources(c *cli.Context, logger *zerolog.Logger) (netip.Addr, netip.Addr, error) {
ipv4Src, err := determineICMPv4Src(c.String(flags.ICMPV4Src), logger)
if err != nil {
return netip.Addr{}, netip.Addr{}, errors.Wrap(err, "failed to determine IPv4 source address for ICMP proxy")
}
logger.Info().Msgf("ICMP proxy will use %s as source for IPv4", ipv4Src) logger.Info().Msgf("ICMP proxy will use %s as source for IPv4", ipv4Src)
ipv6Src, zone, err := determineICMPv6Src(c.String(flags.ICMPV6Src), logger, ipv4Src) ipv6Src, zone, err := determineICMPv6Src(c.String("icmpv6-src"), logger, ipv4Src)
if err != nil { if err != nil {
return netip.Addr{}, netip.Addr{}, errors.Wrap(err, "failed to determine IPv6 source address for ICMP proxy") return nil, errors.Wrap(err, "failed to determine IPv6 source address for ICMP proxy")
} }
if zone != "" { if zone != "" {
logger.Info().Msgf("ICMP proxy will use %s in zone %s as source for IPv6", ipv6Src, zone) logger.Info().Msgf("ICMP proxy will use %s in zone %s as source for IPv6", ipv6Src, zone)
} else { } else {
logger.Info().Msgf("ICMP proxy will use %s as source for IPv6", ipv6Src) logger.Info().Msgf("ICMP proxy will use %s as source for IPv6", ipv6Src)
} }
return ipv4Src, ipv6Src, nil icmpRouter, err := ingress.NewICMPRouter(ipv4Src, ipv6Src, zone, logger, icmpFunnelTimeout)
if err != nil {
return nil, err
}
return &ingress.GlobalRouterConfig{
ICMPRouter: icmpRouter,
IPv4Src: ipv4Src,
IPv6Src: ipv6Src,
Zone: zone,
}, nil
} }
func determineICMPv4Src(userDefinedSrc string, logger *zerolog.Logger) (netip.Addr, error) { func determineICMPv4Src(userDefinedSrc string, logger *zerolog.Logger) (netip.Addr, error) {
@ -415,12 +407,13 @@ type interfaceIP struct {
func determineICMPv6Src(userDefinedSrc string, logger *zerolog.Logger, ipv4Src netip.Addr) (addr netip.Addr, zone string, err error) { func determineICMPv6Src(userDefinedSrc string, logger *zerolog.Logger, ipv4Src netip.Addr) (addr netip.Addr, zone string, err error) {
if userDefinedSrc != "" { if userDefinedSrc != "" {
addr, err := netip.ParseAddr(userDefinedSrc) userDefinedIP, zone, _ := strings.Cut(userDefinedSrc, "%")
addr, err := netip.ParseAddr(userDefinedIP)
if err != nil { if err != nil {
return netip.Addr{}, "", err return netip.Addr{}, "", err
} }
if addr.Is6() { if addr.Is6() {
return addr, addr.Zone(), nil return addr, zone, nil
} }
return netip.Addr{}, "", fmt.Errorf("expect IPv6, but %s is IPv4", userDefinedSrc) return netip.Addr{}, "", fmt.Errorf("expect IPv6, but %s is IPv4", userDefinedSrc)
} }

View File

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"path/filepath" "path/filepath"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/credentials"
@ -58,7 +57,7 @@ func newSearchByID(id uuid.UUID, c *cli.Context, log *zerolog.Logger, fs fileSys
} }
func (s searchByID) Path() (string, error) { func (s searchByID) Path() (string, error) {
originCertPath := s.c.String(cfdflags.OriginCert) originCertPath := s.c.String(credentials.OriginCertFlag)
originCertLog := s.log.With(). originCertLog := s.log.With().
Str("originCertPath", originCertPath). Str("originCertPath", originCertPath).
Logger() Logger()

View File

@ -0,0 +1,3 @@
package tunnel
var FipsEnabled bool

View File

@ -19,32 +19,8 @@ import (
) )
const ( const (
baseLoginURL = "https://dash.cloudflare.com/argotunnel" baseLoginURL = "https://dash.cloudflare.com/argotunnel"
callbackURL = "https://login.cloudflareaccess.org/" callbackStoreURL = "https://login.cloudflareaccess.org/"
// For now these are the same but will change in the future once we know which URLs to use (TUN-8872)
fedBaseLoginURL = "https://dash.cloudflare.com/argotunnel"
fedCallbackStoreURL = "https://login.cloudflareaccess.org/"
fedRAMPParamName = "fedramp"
loginURLParamName = "loginURL"
callbackURLParamName = "callbackURL"
)
var (
loginURL = &cli.StringFlag{
Name: loginURLParamName,
Value: baseLoginURL,
Usage: "The URL used to login (default is https://dash.cloudflare.com/argotunnel)",
}
callbackStore = &cli.StringFlag{
Name: callbackURLParamName,
Value: callbackURL,
Usage: "The URL used for the callback (default is https://login.cloudflareaccess.org/)",
}
fedramp = &cli.BoolFlag{
Name: fedRAMPParamName,
Aliases: []string{"f"},
Usage: "Login with FedRAMP High environment.",
}
) )
func buildLoginSubcommand(hidden bool) *cli.Command { func buildLoginSubcommand(hidden bool) *cli.Command {
@ -54,11 +30,6 @@ func buildLoginSubcommand(hidden bool) *cli.Command {
Usage: "Generate a configuration file with your login details", Usage: "Generate a configuration file with your login details",
ArgsUsage: " ", ArgsUsage: " ",
Hidden: hidden, Hidden: hidden,
Flags: []cli.Flag{
loginURL,
callbackStore,
fedramp,
},
} }
} }
@ -67,25 +38,15 @@ func login(c *cli.Context) error {
path, ok, err := checkForExistingCert() path, ok, err := checkForExistingCert()
if ok { if ok {
log.Error().Err(err).Msgf("You have an existing certificate at %s which login would overwrite.\nIf this is intentional, please move or delete that file then run this command again.\n", path) fmt.Fprintf(os.Stdout, "You have an existing certificate at %s which login would overwrite.\nIf this is intentional, please move or delete that file then run this command again.\n", path)
return nil return nil
} else if err != nil { } else if err != nil {
return err return err
} }
var ( loginURL, err := url.Parse(baseLoginURL)
baseloginURL = c.String(loginURLParamName)
callbackStoreURL = c.String(callbackURLParamName)
)
isFEDRamp := c.Bool(fedRAMPParamName)
if isFEDRamp {
baseloginURL = fedBaseLoginURL
callbackStoreURL = fedCallbackStoreURL
}
loginURL, err := url.Parse(baseloginURL)
if err != nil { if err != nil {
// shouldn't happen, URL is hardcoded
return err return err
} }
@ -100,23 +61,7 @@ func login(c *cli.Context) error {
log, log,
) )
if err != nil { if err != nil {
log.Error().Err(err).Msgf("Failed to write the certificate.\n\nYour browser will download the certificate instead. You will have to manually\ncopy it to the following path:\n\n%s\n", path) fmt.Fprintf(os.Stderr, "Failed to write the certificate due to the following error:\n%v\n\nYour browser will download the certificate instead. You will have to manually\ncopy it to the following path:\n\n%s\n", err, path)
return err
}
cert, err := credentials.DecodeOriginCert(resourceData)
if err != nil {
log.Error().Err(err).Msg("failed to decode origin certificate")
return err
}
if isFEDRamp {
cert.Endpoint = credentials.FedEndpoint
}
resourceData, err = cert.EncodeOriginCert()
if err != nil {
log.Error().Err(err).Msg("failed to encode origin certificate")
return err return err
} }
@ -124,7 +69,7 @@ func login(c *cli.Context) error {
return errors.Wrap(err, fmt.Sprintf("error writing cert to %s", path)) return errors.Wrap(err, fmt.Sprintf("error writing cert to %s", path))
} }
log.Info().Msgf("You have successfully logged in.\nIf you wish to copy your credentials to a server, they have been saved to:\n%s\n", path) fmt.Fprintf(os.Stdout, "You have successfully logged in.\nIf you wish to copy your credentials to a server, they have been saved to:\n%s\n", path)
return nil return nil
} }

View File

@ -11,7 +11,6 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
) )
@ -83,13 +82,13 @@ func RunQuickTunnel(sc *subcommandContext) error {
sc.log.Info().Msg(line) sc.log.Info().Msg(line)
} }
if !sc.c.IsSet(flags.Protocol) { if !sc.c.IsSet("protocol") {
_ = sc.c.Set(flags.Protocol, "quic") sc.c.Set("protocol", "quic")
} }
// Override the number of connections used. Quick tunnels shouldn't be used for production usage, // Override the number of connections used. Quick tunnels shouldn't be used for production usage,
// so, use a single connection instead. // so, use a single connection instead.
_ = sc.c.Set(flags.HaConnections, "1") sc.c.Set(haConnectionsFlag, "1")
return StartServer( return StartServer(
sc.c, sc.c,
buildInfo, buildInfo,

View File

@ -9,26 +9,22 @@ import (
"strings" "strings"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/cfapi"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/credentials" "github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
) )
const fedRampBaseApiURL = "https://api.fed.cloudflare.com/client/v4" type errInvalidJSONCredential struct {
type invalidJSONCredentialError struct {
err error err error
path string path string
} }
func (e invalidJSONCredentialError) Error() string { func (e errInvalidJSONCredential) Error() string {
return "Invalid JSON when parsing tunnel credentials file" return "Invalid JSON when parsing tunnel credentials file"
} }
@ -55,12 +51,7 @@ func newSubcommandContext(c *cli.Context) (*subcommandContext, error) {
// Returns something that can find the given tunnel's credentials file. // Returns something that can find the given tunnel's credentials file.
func (sc *subcommandContext) credentialFinder(tunnelID uuid.UUID) CredFinder { func (sc *subcommandContext) credentialFinder(tunnelID uuid.UUID) CredFinder {
if path := sc.c.String(CredFileFlag); path != "" { if path := sc.c.String(CredFileFlag); path != "" {
// Expand path if CredFileFlag contains `~` return newStaticPath(path, sc.fs)
absPath, err := homedir.Expand(path)
if err != nil {
return newStaticPath(path, sc.fs)
}
return newStaticPath(absPath, sc.fs)
} }
return newSearchByID(tunnelID, sc.c, sc.log, sc.fs) return newSearchByID(tunnelID, sc.c, sc.log, sc.fs)
} }
@ -73,16 +64,7 @@ func (sc *subcommandContext) client() (cfapi.Client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
sc.tunnelstoreClient, err = cred.Client(sc.c.String("api-url"), buildInfo.UserAgent(), sc.log)
var apiURL string
if cred.IsFEDEndpoint() {
sc.log.Info().Str("api-url", fedRampBaseApiURL).Msg("using fedramp base api")
apiURL = fedRampBaseApiURL
} else {
apiURL = sc.c.String(cfdflags.ApiURL)
}
sc.tunnelstoreClient, err = cred.Client(apiURL, buildInfo.UserAgent(), sc.log)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -91,7 +73,7 @@ func (sc *subcommandContext) client() (cfapi.Client, error) {
func (sc *subcommandContext) credential() (*credentials.User, error) { func (sc *subcommandContext) credential() (*credentials.User, error) {
if sc.userCredential == nil { if sc.userCredential == nil {
uc, err := credentials.Read(sc.c.String(cfdflags.OriginCert), sc.log) uc, err := credentials.Read(sc.c.String(credentials.OriginCertFlag), sc.log)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,13 +94,13 @@ func (sc *subcommandContext) readTunnelCredentials(credFinder CredFinder) (conne
var credentials connection.Credentials var credentials connection.Credentials
if err = json.Unmarshal(body, &credentials); err != nil { if err = json.Unmarshal(body, &credentials); err != nil {
if filepath.Ext(filePath) == ".pem" { if strings.HasSuffix(filePath, ".pem") {
return connection.Credentials{}, fmt.Errorf("The tunnel credentials file should be .json but you gave a .pem. " + return connection.Credentials{}, fmt.Errorf("The tunnel credentials file should be .json but you gave a .pem. " +
"The tunnel credentials file was originally created by `cloudflared tunnel create`. " + "The tunnel credentials file was originally created by `cloudflared tunnel create`. " +
"You may have accidentally used the filepath to cert.pem, which is generated by `cloudflared tunnel " + "You may have accidentally used the filepath to cert.pem, which is generated by `cloudflared tunnel " +
"login`.") "login`.")
} }
return connection.Credentials{}, invalidJSONCredentialError{path: filePath, err: err} return connection.Credentials{}, errInvalidJSONCredential{path: filePath, err: err}
} }
return credentials, nil return credentials, nil
} }
@ -140,7 +122,7 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Couldn't decode tunnel secret from base64") return nil, errors.Wrap(err, "Couldn't decode tunnel secret from base64")
} }
tunnelSecret = decodedSecret tunnelSecret = []byte(decodedSecret)
if len(tunnelSecret) < 32 { if len(tunnelSecret) < 32 {
return nil, errors.New("Decoded tunnel secret must be at least 32 bytes long") return nil, errors.New("Decoded tunnel secret must be at least 32 bytes long")
} }
@ -178,7 +160,7 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec
errorLines = append(errorLines, fmt.Sprintf("Cloudflared tried to delete the tunnel for you, but encountered an error. You should use `cloudflared tunnel delete %v` to delete the tunnel yourself, because the tunnel can't be run without the tunnelfile.", tunnel.ID)) errorLines = append(errorLines, fmt.Sprintf("Cloudflared tried to delete the tunnel for you, but encountered an error. You should use `cloudflared tunnel delete %v` to delete the tunnel yourself, because the tunnel can't be run without the tunnelfile.", tunnel.ID))
errorLines = append(errorLines, fmt.Sprintf("The delete tunnel error is: %v", deleteErr)) errorLines = append(errorLines, fmt.Sprintf("The delete tunnel error is: %v", deleteErr))
} else { } else {
errorLines = append(errorLines, "The tunnel was deleted, because the tunnel can't be run without the credentials file") errorLines = append(errorLines, fmt.Sprintf("The tunnel was deleted, because the tunnel can't be run without the credentials file"))
} }
errorMsg := strings.Join(errorLines, "\n") errorMsg := strings.Join(errorLines, "\n")
return nil, errors.New(errorMsg) return nil, errors.New(errorMsg)
@ -207,7 +189,7 @@ func (sc *subcommandContext) list(filter *cfapi.TunnelFilter) ([]*cfapi.Tunnel,
} }
func (sc *subcommandContext) delete(tunnelIDs []uuid.UUID) error { func (sc *subcommandContext) delete(tunnelIDs []uuid.UUID) error {
forceFlagSet := sc.c.Bool(cfdflags.Force) forceFlagSet := sc.c.Bool("force")
client, err := sc.client() client, err := sc.client()
if err != nil { if err != nil {
@ -247,7 +229,7 @@ func (sc *subcommandContext) findCredentials(tunnelID uuid.UUID) (connection.Cre
var err error var err error
if credentialsContents := sc.c.String(CredContentsFlag); credentialsContents != "" { if credentialsContents := sc.c.String(CredContentsFlag); credentialsContents != "" {
if err = json.Unmarshal([]byte(credentialsContents), &credentials); err != nil { if err = json.Unmarshal([]byte(credentialsContents), &credentials); err != nil {
err = invalidJSONCredentialError{path: "TUNNEL_CRED_CONTENTS", err: err} err = errInvalidJSONCredential{path: "TUNNEL_CRED_CONTENTS", err: err}
} }
} else { } else {
credFinder := sc.credentialFinder(tunnelID) credFinder := sc.credentialFinder(tunnelID)
@ -263,7 +245,7 @@ func (sc *subcommandContext) findCredentials(tunnelID uuid.UUID) (connection.Cre
func (sc *subcommandContext) run(tunnelID uuid.UUID) error { func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
credentials, err := sc.findCredentials(tunnelID) credentials, err := sc.findCredentials(tunnelID)
if err != nil { if err != nil {
if e, ok := err.(invalidJSONCredentialError); ok { if e, ok := err.(errInvalidJSONCredential); ok {
sc.log.Error().Msgf("The credentials file at %s contained invalid JSON. This is probably caused by passing the wrong filepath. Reminder: the credentials file is a .json file created via `cloudflared tunnel create`.", e.path) sc.log.Error().Msgf("The credentials file at %s contained invalid JSON. This is probably caused by passing the wrong filepath. Reminder: the credentials file is a .json file created via `cloudflared tunnel create`.", e.path)
sc.log.Error().Msgf("Invalid JSON when parsing credentials file: %s", e.err.Error()) sc.log.Error().Msgf("Invalid JSON when parsing credentials file: %s", e.err.Error())
} }

View File

@ -16,40 +16,28 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mitchellh/go-homedir" homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc" "github.com/urfave/cli/v2/altsrc"
"golang.org/x/net/idna" "golang.org/x/net/idna"
"gopkg.in/yaml.v3" yaml "gopkg.in/yaml.v3"
"github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater" "github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/diagnostic"
"github.com/cloudflare/cloudflared/fips"
"github.com/cloudflare/cloudflared/metrics"
) )
const ( const (
allSortByOptions = "name, id, createdAt, deletedAt, numConnections" allSortByOptions = "name, id, createdAt, deletedAt, numConnections"
connsSortByOptions = "id, startedAt, numConnections, version" connsSortByOptions = "id, startedAt, numConnections, version"
CredFileFlagAlias = "cred-file" CredFileFlagAlias = "cred-file"
CredFileFlag = "credentials-file" CredFileFlag = "credentials-file"
CredContentsFlag = "credentials-contents" CredContentsFlag = "credentials-contents"
TunnelTokenFlag = "token" TunnelTokenFlag = "token"
TunnelTokenFileFlag = "token-file" overwriteDNSFlagName = "overwrite-dns"
overwriteDNSFlagName = "overwrite-dns"
noDiagLogsFlagName = "no-diag-logs"
noDiagMetricsFlagName = "no-diag-metrics"
noDiagSystemFlagName = "no-diag-system"
noDiagRuntimeFlagName = "no-diag-runtime"
noDiagNetworkFlagName = "no-diag-network"
diagContainerIDFlagName = "diag-container-id"
diagPodFlagName = "diag-pod-id"
LogFieldTunnelID = "tunnelID" LogFieldTunnelID = "tunnelID"
) )
@ -61,7 +49,7 @@ var (
Usage: "Include deleted tunnels in the list", Usage: "Include deleted tunnels in the list",
} }
listNameFlag = &cli.StringFlag{ listNameFlag = &cli.StringFlag{
Name: flags.Name, Name: "name",
Aliases: []string{"n"}, Aliases: []string{"n"},
Usage: "List tunnels with the given `NAME`", Usage: "List tunnels with the given `NAME`",
} }
@ -109,7 +97,7 @@ var (
EnvVars: []string{"TUNNEL_LIST_INVERT_SORT"}, EnvVars: []string{"TUNNEL_LIST_INVERT_SORT"},
} }
featuresFlag = altsrc.NewStringSliceFlag(&cli.StringSliceFlag{ featuresFlag = altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: flags.Features, Name: "features",
Aliases: []string{"F"}, Aliases: []string{"F"},
Usage: "Opt into various features that are still being developed or tested.", Usage: "Opt into various features that are still being developed or tested.",
}) })
@ -127,23 +115,18 @@ var (
}) })
tunnelTokenFlag = altsrc.NewStringFlag(&cli.StringFlag{ tunnelTokenFlag = altsrc.NewStringFlag(&cli.StringFlag{
Name: TunnelTokenFlag, Name: TunnelTokenFlag,
Usage: "The Tunnel token. When provided along with credentials, this will take precedence. Also takes precedence over token-file", Usage: "The Tunnel token. When provided along with credentials, this will take precedence.",
EnvVars: []string{"TUNNEL_TOKEN"}, EnvVars: []string{"TUNNEL_TOKEN"},
}) })
tunnelTokenFileFlag = altsrc.NewStringFlag(&cli.StringFlag{
Name: TunnelTokenFileFlag,
Usage: "Filepath at which to read the tunnel token. When provided along with credentials, this will take precedence.",
EnvVars: []string{"TUNNEL_TOKEN_FILE"},
})
forceDeleteFlag = &cli.BoolFlag{ forceDeleteFlag = &cli.BoolFlag{
Name: flags.Force, Name: "force",
Aliases: []string{"f"}, Aliases: []string{"f"},
Usage: "Deletes a tunnel even if tunnel is connected and it has dependencies associated to it. (eg. IP routes)." + Usage: "Deletes a tunnel even if tunnel is connected and it has dependencies associated to it. (eg. IP routes)." +
" It is not possible to delete tunnels that have connections or non-deleted dependencies, without this flag.", " It is not possible to delete tunnels that have connections or non-deleted dependencies, without this flag.",
EnvVars: []string{"TUNNEL_RUN_FORCE_OVERWRITE"}, EnvVars: []string{"TUNNEL_RUN_FORCE_OVERWRITE"},
} }
selectProtocolFlag = altsrc.NewStringFlag(&cli.StringFlag{ selectProtocolFlag = altsrc.NewStringFlag(&cli.StringFlag{
Name: flags.Protocol, Name: "protocol",
Value: connection.AutoSelectFlag, Value: connection.AutoSelectFlag,
Aliases: []string{"p"}, Aliases: []string{"p"},
Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", connection.AvailableProtocolFlagMessage), Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", connection.AvailableProtocolFlagMessage),
@ -151,11 +134,11 @@ var (
Hidden: true, Hidden: true,
}) })
postQuantumFlag = altsrc.NewBoolFlag(&cli.BoolFlag{ postQuantumFlag = altsrc.NewBoolFlag(&cli.BoolFlag{
Name: flags.PostQuantum, Name: "post-quantum",
Usage: "When given creates an experimental post-quantum secure tunnel", Usage: "When given creates an experimental post-quantum secure tunnel",
Aliases: []string{"pq"}, Aliases: []string{"pq"},
EnvVars: []string{"TUNNEL_POST_QUANTUM"}, EnvVars: []string{"TUNNEL_POST_QUANTUM"},
Hidden: fips.IsFipsEnabled(), Hidden: FipsEnabled,
}) })
sortInfoByFlag = &cli.StringFlag{ sortInfoByFlag = &cli.StringFlag{
Name: "sort-by", Name: "sort-by",
@ -187,60 +170,15 @@ var (
EnvVars: []string{"TUNNEL_CREATE_SECRET"}, EnvVars: []string{"TUNNEL_CREATE_SECRET"},
} }
icmpv4SrcFlag = &cli.StringFlag{ icmpv4SrcFlag = &cli.StringFlag{
Name: flags.ICMPV4Src, Name: "icmpv4-src",
Usage: "Source address to send/receive ICMPv4 messages. If not provided cloudflared will dial a local address to determine the source IP or fallback to 0.0.0.0.", Usage: "Source address to send/receive ICMPv4 messages. If not provided cloudflared will dial a local address to determine the source IP or fallback to 0.0.0.0.",
EnvVars: []string{"TUNNEL_ICMPV4_SRC"}, EnvVars: []string{"TUNNEL_ICMPV4_SRC"},
} }
icmpv6SrcFlag = &cli.StringFlag{ icmpv6SrcFlag = &cli.StringFlag{
Name: flags.ICMPV6Src, Name: "icmpv6-src",
Usage: "Source address and the interface name to send/receive ICMPv6 messages. If not provided cloudflared will dial a local address to determine the source IP or fallback to ::.", Usage: "Source address and the interface name to send/receive ICMPv6 messages. If not provided cloudflared will dial a local address to determine the source IP or fallback to ::.",
EnvVars: []string{"TUNNEL_ICMPV6_SRC"}, EnvVars: []string{"TUNNEL_ICMPV6_SRC"},
} }
metricsFlag = &cli.StringFlag{
Name: flags.Metrics,
Usage: "The metrics server address i.e.: 127.0.0.1:12345. If your instance is running in a Docker/Kubernetes environment you need to setup port forwarding for your application.",
Value: "",
}
diagContainerFlag = &cli.StringFlag{
Name: diagContainerIDFlagName,
Usage: "Container ID or Name to collect logs from",
Value: "",
}
diagPodFlag = &cli.StringFlag{
Name: diagPodFlagName,
Usage: "Kubernetes POD to collect logs from",
Value: "",
}
noDiagLogsFlag = &cli.BoolFlag{
Name: noDiagLogsFlagName,
Usage: "Log collection will not be performed",
Value: false,
}
noDiagMetricsFlag = &cli.BoolFlag{
Name: noDiagMetricsFlagName,
Usage: "Metric collection will not be performed",
Value: false,
}
noDiagSystemFlag = &cli.BoolFlag{
Name: noDiagSystemFlagName,
Usage: "System information collection will not be performed",
Value: false,
}
noDiagRuntimeFlag = &cli.BoolFlag{
Name: noDiagRuntimeFlagName,
Usage: "Runtime information collection will not be performed",
Value: false,
}
noDiagNetworkFlag = &cli.BoolFlag{
Name: noDiagNetworkFlagName,
Usage: "Network diagnostics won't be performed",
Value: false,
}
maxActiveFlowsFlag = &cli.Uint64Flag{
Name: flags.MaxActiveFlows,
Usage: "Overrides the remote configuration for max active private network flows (TCP/UDP) that this cloudflared instance supports",
EnvVars: []string{"TUNNEL_MAX_ACTIVE_FLOWS"},
}
) )
func buildCreateCommand() *cli.Command { func buildCreateCommand() *cli.Command {
@ -343,7 +281,7 @@ func listCommand(c *cli.Context) error {
if !c.Bool("show-deleted") { if !c.Bool("show-deleted") {
filter.NoDeleted() filter.NoDeleted()
} }
if name := c.String(flags.Name); name != "" { if name := c.String("name"); name != "" {
filter.ByName(name) filter.ByName(name)
} }
if namePrefix := c.String("name-prefix"); namePrefix != "" { if namePrefix := c.String("name-prefix"); namePrefix != "" {
@ -437,6 +375,7 @@ func formatAndPrintTunnelList(tunnels []*cfapi.Tunnel, showRecentlyDisconnected
} }
func fmtConnections(connections []cfapi.Connection, showRecentlyDisconnected bool) string { func fmtConnections(connections []cfapi.Connection, showRecentlyDisconnected bool) string {
// Count connections per colo // Count connections per colo
numConnsPerColo := make(map[string]uint, len(connections)) numConnsPerColo := make(map[string]uint, len(connections))
for _, connection := range connections { for _, connection := range connections {
@ -453,7 +392,7 @@ func fmtConnections(connections []cfapi.Connection, showRecentlyDisconnected boo
sort.Strings(sortedColos) sort.Strings(sortedColos)
// Map each colo to its frequency, combine into output string. // Map each colo to its frequency, combine into output string.
output := make([]string, 0, len(sortedColos)) var output []string
for _, coloName := range sortedColos { for _, coloName := range sortedColos {
output = append(output, fmt.Sprintf("%dx%s", numConnsPerColo[coloName], coloName)) output = append(output, fmt.Sprintf("%dx%s", numConnsPerColo[coloName], coloName))
} }
@ -473,21 +412,16 @@ func buildReadyCommand() *cli.Command {
} }
func readyCommand(c *cli.Context) error { func readyCommand(c *cli.Context) error {
metricsOpts := c.String(flags.Metrics) metricsOpts := c.String("metrics")
if !c.IsSet(flags.Metrics) { if !c.IsSet("metrics") {
return errors.New("--metrics has to be provided") return fmt.Errorf("--metrics has to be provided")
} }
requestURL := fmt.Sprintf("http://%s/ready", metricsOpts) requestURL := fmt.Sprintf("http://%s/ready", metricsOpts)
req, err := http.NewRequest(http.MethodGet, requestURL, nil) res, err := http.Get(requestURL)
if err != nil { if err != nil {
return err return err
} }
res, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != 200 { if res.StatusCode != 200 {
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
@ -714,10 +648,8 @@ func buildRunCommand() *cli.Command {
selectProtocolFlag, selectProtocolFlag,
featuresFlag, featuresFlag,
tunnelTokenFlag, tunnelTokenFlag,
tunnelTokenFileFlag,
icmpv4SrcFlag, icmpv4SrcFlag,
icmpv6SrcFlag, icmpv6SrcFlag,
maxActiveFlowsFlag,
} }
flags = append(flags, configureProxyFlags(false)...) flags = append(flags, configureProxyFlags(false)...)
return &cli.Command{ return &cli.Command{
@ -755,22 +687,12 @@ func runCommand(c *cli.Context) error {
"your origin will not be reachable. You should remove the `hostname` property to avoid this warning.") "your origin will not be reachable. You should remove the `hostname` property to avoid this warning.")
} }
tokenStr := c.String(TunnelTokenFlag)
// Check if tokenStr is blank before checking for tokenFile
if tokenStr == "" {
if tokenFile := c.String(TunnelTokenFileFlag); tokenFile != "" {
data, err := os.ReadFile(tokenFile)
if err != nil {
return cliutil.UsageError("Failed to read token file: " + err.Error())
}
tokenStr = strings.TrimSpace(string(data))
}
}
// Check if token is provided and if not use default tunnelID flag method // Check if token is provided and if not use default tunnelID flag method
if tokenStr != "" { if tokenStr := c.String(TunnelTokenFlag); tokenStr != "" {
if token, err := ParseToken(tokenStr); err == nil { if token, err := ParseToken(tokenStr); err == nil {
return sc.runWithCredentials(token.Credentials()) return sc.runWithCredentials(token.Credentials())
} }
return cliutil.UsageError("Provided Tunnel token is not valid.") return cliutil.UsageError("Provided Tunnel token is not valid.")
} else { } else {
tunnelRef := c.Args().First() tunnelRef := c.Args().First()
@ -975,10 +897,8 @@ func lbRouteFromArg(c *cli.Context) (cfapi.HostnameRoute, error) {
return cfapi.NewLBRoute(lbName, lbPool), nil return cfapi.NewLBRoute(lbName, lbPool), nil
} }
var ( var nameRegex = regexp.MustCompile("^[_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
nameRegex = regexp.MustCompile("^[_a-zA-Z0-9][-_.a-zA-Z0-9]*$") var hostNameRegex = regexp.MustCompile("^[*_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
hostNameRegex = regexp.MustCompile("^[*_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
)
func validateName(s string, allowWildcardSubdomain bool) bool { func validateName(s string, allowWildcardSubdomain bool) bool {
if allowWildcardSubdomain { if allowWildcardSubdomain {
@ -1066,78 +986,3 @@ SUBCOMMAND OPTIONS:
` `
return fmt.Sprintf(template, parentFlagsHelp) return fmt.Sprintf(template, parentFlagsHelp)
} }
func buildDiagCommand() *cli.Command {
return &cli.Command{
Name: "diag",
Action: cliutil.ConfiguredAction(diagCommand),
Usage: "Creates a diagnostic report from a local cloudflared instance",
UsageText: "cloudflared tunnel [tunnel command options] diag [subcommand options]",
Description: "cloudflared tunnel diag will create a diagnostic report of a local cloudflared instance. The diagnostic procedure collects: logs, metrics, system information, traceroute to Cloudflare Edge, and runtime information. Since there may be multiple instances of cloudflared running the --metrics option may be provided to target a specific instance.",
Flags: []cli.Flag{
metricsFlag,
diagContainerFlag,
diagPodFlag,
noDiagLogsFlag,
noDiagMetricsFlag,
noDiagSystemFlag,
noDiagRuntimeFlag,
noDiagNetworkFlag,
},
CustomHelpTemplate: commandHelpTemplate(),
}
}
func diagCommand(ctx *cli.Context) error {
sctx, err := newSubcommandContext(ctx)
if err != nil {
return err
}
log := sctx.log
options := diagnostic.Options{
KnownAddresses: metrics.GetMetricsKnownAddresses(metrics.Runtime),
Address: sctx.c.String(flags.Metrics),
ContainerID: sctx.c.String(diagContainerIDFlagName),
PodID: sctx.c.String(diagPodFlagName),
Toggles: diagnostic.Toggles{
NoDiagLogs: sctx.c.Bool(noDiagLogsFlagName),
NoDiagMetrics: sctx.c.Bool(noDiagMetricsFlagName),
NoDiagSystem: sctx.c.Bool(noDiagSystemFlagName),
NoDiagRuntime: sctx.c.Bool(noDiagRuntimeFlagName),
NoDiagNetwork: sctx.c.Bool(noDiagNetworkFlagName),
},
}
if options.Address == "" {
log.Info().Msg("If your instance is running in a Docker/Kubernetes environment you need to setup port forwarding for your application.")
}
states, err := diagnostic.RunDiagnostic(log, options)
if errors.Is(err, diagnostic.ErrMetricsServerNotFound) {
log.Warn().Msg("No instances found")
return nil
}
if errors.Is(err, diagnostic.ErrMultipleMetricsServerFound) {
if states != nil {
log.Info().Msgf("Found multiple instances running:")
for _, state := range states {
log.Info().Msgf("Instance: tunnel-id=%s connector-id=%s metrics-address=%s", state.TunnelID, state.ConnectorID, state.URL.String())
}
log.Info().Msgf("To select one instance use the option --metrics")
}
return nil
}
if errors.Is(err, diagnostic.ErrLogConfigurationIsInvalid) {
log.Info().Msg("Couldn't extract logs from the instance. If the instance is running in a containerized environment use the option --diag-container-id or --diag-pod-id. If there is no logging configuration use --no-diag-logs.")
}
if err != nil {
log.Warn().Msg("Diagnostic completed with one or more errors")
} else {
log.Info().Msg("Diagnostic completed")
}
return nil
}

View File

@ -22,7 +22,7 @@ var (
Usage: "The ID or name of the virtual network to which the route is associated to.", Usage: "The ID or name of the virtual network to which the route is associated to.",
} }
errAddRoute = errors.New("You must supply exactly one argument, the ID or CIDR of the route you want to delete") routeAddError = errors.New("You must supply exactly one argument, the ID or CIDR of the route you want to delete")
) )
func buildRouteIPSubcommand() *cli.Command { func buildRouteIPSubcommand() *cli.Command {
@ -32,7 +32,7 @@ func buildRouteIPSubcommand() *cli.Command {
UsageText: "cloudflared tunnel [--config FILEPATH] route COMMAND [arguments...]", UsageText: "cloudflared tunnel [--config FILEPATH] route COMMAND [arguments...]",
Description: `cloudflared can provision routes for any IP space in your corporate network. Users enrolled in Description: `cloudflared can provision routes for any IP space in your corporate network. Users enrolled in
your Cloudflare for Teams organization can reach those IPs through the Cloudflare WARP your Cloudflare for Teams organization can reach those IPs through the Cloudflare WARP
client. You can then configure L7/L4 filtering on https://one.dash.cloudflare.com to client. You can then configure L7/L4 filtering on https://dash.teams.cloudflare.com to
determine who can reach certain routes. determine who can reach certain routes.
By default IP routes all exist within a single virtual network. If you use the same IP By default IP routes all exist within a single virtual network. If you use the same IP
space(s) in different physical private networks, all meant to be reachable via IP routes, space(s) in different physical private networks, all meant to be reachable via IP routes,
@ -187,7 +187,7 @@ func deleteRouteCommand(c *cli.Context) error {
} }
if c.NArg() != 1 { if c.NArg() != 1 {
return errAddRoute return routeAddError
} }
var routeId uuid.UUID var routeId uuid.UUID
@ -195,7 +195,7 @@ func deleteRouteCommand(c *cli.Context) error {
if err != nil { if err != nil {
_, network, err := net.ParseCIDR(c.Args().First()) _, network, err := net.ParseCIDR(c.Args().First())
if err != nil || network == nil { if err != nil || network == nil {
return errAddRoute return routeAddError
} }
var vnetId *uuid.UUID var vnetId *uuid.UUID

View File

@ -15,14 +15,13 @@ import (
"golang.org/x/term" "golang.org/x/term"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/logger" "github.com/cloudflare/cloudflared/logger"
) )
const ( const (
DefaultCheckUpdateFreq = time.Hour * 24 DefaultCheckUpdateFreq = time.Hour * 24
noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/configure-tunnels/local-management/as-a-service/" noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/as-a-service/"
noUpdateOnWindowsMessage = "cloudflared will not automatically update on Windows systems." noUpdateOnWindowsMessage = "cloudflared will not automatically update on Windows systems."
noUpdateManagedPackageMessage = "cloudflared will not automatically update if installed by a package manager." noUpdateManagedPackageMessage = "cloudflared will not automatically update if installed by a package manager."
isManagedInstallFile = ".installedFromPackageManager" isManagedInstallFile = ".installedFromPackageManager"
@ -39,7 +38,6 @@ var (
// BinaryUpdated implements ExitCoder interface, the app will exit with status code 11 // BinaryUpdated implements ExitCoder interface, the app will exit with status code 11
// https://pkg.go.dev/github.com/urfave/cli/v2?tab=doc#ExitCoder // https://pkg.go.dev/github.com/urfave/cli/v2?tab=doc#ExitCoder
// nolint: errname
type statusSuccess struct { type statusSuccess struct {
newVersion string newVersion string
} }
@ -52,16 +50,16 @@ func (u *statusSuccess) ExitCode() int {
return 11 return 11
} }
// statusError implements ExitCoder interface, the app will exit with status code 10 // UpdateErr implements ExitCoder interface, the app will exit with status code 10
type statusError struct { type statusErr struct {
err error err error
} }
func (e *statusError) Error() string { func (e *statusErr) Error() string {
return fmt.Sprintf("failed to update cloudflared: %v", e.err) return fmt.Sprintf("failed to update cloudflared: %v", e.err)
} }
func (e *statusError) ExitCode() int { func (e *statusErr) ExitCode() int {
return 10 return 10
} }
@ -81,7 +79,7 @@ type UpdateOutcome struct {
} }
func (uo *UpdateOutcome) noUpdate() bool { func (uo *UpdateOutcome) noUpdate() bool {
return uo.Error == nil && !uo.Updated return uo.Error == nil && uo.Updated == false
} }
func Init(info *cliutil.BuildInfo) { func Init(info *cliutil.BuildInfo) {
@ -155,7 +153,7 @@ func Update(c *cli.Context) error {
log.Info().Msg("cloudflared is set to update from staging") log.Info().Msg("cloudflared is set to update from staging")
} }
isForced := c.Bool(cfdflags.Force) isForced := c.Bool("force")
if isForced { if isForced {
log.Info().Msg("cloudflared is set to upgrade to the latest publish version regardless of the current version") log.Info().Msg("cloudflared is set to upgrade to the latest publish version regardless of the current version")
} }
@ -168,7 +166,7 @@ func Update(c *cli.Context) error {
intendedVersion: c.String("version"), intendedVersion: c.String("version"),
}) })
if updateOutcome.Error != nil { if updateOutcome.Error != nil {
return &statusError{updateOutcome.Error} return &statusErr{updateOutcome.Error}
} }
if updateOutcome.noUpdate() { if updateOutcome.noUpdate() {
@ -254,7 +252,7 @@ func (a *AutoUpdater) Run(ctx context.Context) error {
pid, err := a.listeners.StartProcess() pid, err := a.listeners.StartProcess()
if err != nil { if err != nil {
a.log.Err(err).Msg("Unable to restart server automatically") a.log.Err(err).Msg("Unable to restart server automatically")
return &statusError{err: err} return &statusErr{err: err}
} }
// stop old process after autoupdate. Otherwise we create a new process // stop old process after autoupdate. Otherwise we create a new process
// after each update // after each update

View File

@ -10,9 +10,9 @@ import (
"net/url" "net/url"
"os" "os"
"os/exec" "os/exec"
"path"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings"
"text/template" "text/template"
"time" "time"
@ -134,7 +134,7 @@ func (v *WorkersVersion) Apply() error {
if err := os.Rename(newFilePath, v.targetPath); err != nil { if err := os.Rename(newFilePath, v.targetPath); err != nil {
//attempt rollback //attempt rollback
_ = os.Rename(oldFilePath, v.targetPath) os.Rename(oldFilePath, v.targetPath)
return err return err
} }
os.Remove(oldFilePath) os.Remove(oldFilePath)
@ -181,7 +181,7 @@ func download(url, filepath string, isCompressed bool) error {
tr := tar.NewReader(gr) tr := tar.NewReader(gr)
// advance the reader pass the header, which will be the single binary file // advance the reader pass the header, which will be the single binary file
_, _ = tr.Next() tr.Next()
r = tr r = tr
} }
@ -198,7 +198,7 @@ func download(url, filepath string, isCompressed bool) error {
// isCompressedFile is a really simple file extension check to see if this is a macos tar and gzipped // isCompressedFile is a really simple file extension check to see if this is a macos tar and gzipped
func isCompressedFile(urlstring string) bool { func isCompressedFile(urlstring string) bool {
if path.Ext(urlstring) == ".tgz" { if strings.HasSuffix(urlstring, ".tgz") {
return true return true
} }
@ -206,7 +206,7 @@ func isCompressedFile(urlstring string) bool {
if err != nil { if err != nil {
return false return false
} }
return path.Ext(u.Path) == ".tgz" return strings.HasSuffix(u.Path, ".tgz")
} }
// writeBatchFile writes a batch file out to disk // writeBatchFile writes a batch file out to disk
@ -249,6 +249,7 @@ func runWindowsBatch(batchFile string) error {
if exitError, ok := err.(*exec.ExitError); ok { if exitError, ok := err.(*exec.ExitError); ok {
return fmt.Errorf("Error during update : %s;", string(exitError.Stderr)) return fmt.Errorf("Error during update : %s;", string(exitError.Stderr))
} }
} }
return err return err
} }

View File

@ -26,7 +26,7 @@ import (
const ( const (
windowsServiceName = "Cloudflared" windowsServiceName = "Cloudflared"
windowsServiceDescription = "Cloudflared agent" windowsServiceDescription = "Cloudflared agent"
windowsServiceUrl = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/configure-tunnels/local-management/as-a-service/windows/" windowsServiceUrl = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/as-a-service/windows/"
recoverActionDelay = time.Second * 20 recoverActionDelay = time.Second * 20
failureCountResetPeriod = time.Hour * 24 failureCountResetPeriod = time.Hour * 24

View File

@ -1,6 +1,7 @@
from util import LOGGER, start_cloudflared, wait_tunnel_ready from util import LOGGER, nofips, start_cloudflared, wait_tunnel_ready
@nofips
class TestPostQuantum: class TestPostQuantum:
def _extra_config(self): def _extra_config(self):
config = { config = {
@ -11,11 +12,6 @@ class TestPostQuantum:
def test_post_quantum(self, tmp_path, component_tests_config): def test_post_quantum(self, tmp_path, component_tests_config):
config = component_tests_config(self._extra_config()) config = component_tests_config(self._extra_config())
LOGGER.debug(config) LOGGER.debug(config)
with start_cloudflared( with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], cfd_args=["run", "--post-quantum"], new_process=True):
tmp_path, wait_tunnel_ready(tunnel_url=config.get_url(),
config, require_min_connections=1)
cfd_pre_args=["tunnel", "--ha-connections", "1"],
cfd_args=["run", "--post-quantum"],
new_process=True,
):
wait_tunnel_ready(tunnel_url=config.get_url(), require_min_connections=1)

View File

@ -155,7 +155,7 @@ func FindOrCreateConfigPath() string {
// i.e. it fails if a user specifies both --url and --unix-socket // i.e. it fails if a user specifies both --url and --unix-socket
func ValidateUnixSocket(c *cli.Context) (string, error) { func ValidateUnixSocket(c *cli.Context) (string, error) {
if c.IsSet("unix-socket") && (c.IsSet("url") || c.NArg() > 0) { if c.IsSet("unix-socket") && (c.IsSet("url") || c.NArg() > 0) {
return "", errors.New("--unix-socket must be used exclusively.") return "", errors.New("--unix-socket must be used exclusivly.")
} }
return c.String("unix-socket"), nil return c.String("unix-socket"), nil
} }
@ -260,7 +260,6 @@ type Configuration struct {
type WarpRoutingConfig struct { type WarpRoutingConfig struct {
ConnectTimeout *CustomDuration `yaml:"connectTimeout" json:"connectTimeout,omitempty"` ConnectTimeout *CustomDuration `yaml:"connectTimeout" json:"connectTimeout,omitempty"`
MaxActiveFlows *uint64 `yaml:"maxActiveFlows" json:"maxActiveFlows,omitempty"`
TCPKeepAlive *CustomDuration `yaml:"tcpKeepAlive" json:"tcpKeepAlive,omitempty"` TCPKeepAlive *CustomDuration `yaml:"tcpKeepAlive" json:"tcpKeepAlive,omitempty"`
} }

View File

@ -60,7 +60,6 @@ type Credentials struct {
AccountTag string AccountTag string
TunnelSecret []byte TunnelSecret []byte
TunnelID uuid.UUID TunnelID uuid.UUID
Endpoint string
} }
func (c *Credentials) Auth() pogs.TunnelAuth { func (c *Credentials) Auth() pogs.TunnelAuth {
@ -75,16 +74,13 @@ type TunnelToken struct {
AccountTag string `json:"a"` AccountTag string `json:"a"`
TunnelSecret []byte `json:"s"` TunnelSecret []byte `json:"s"`
TunnelID uuid.UUID `json:"t"` TunnelID uuid.UUID `json:"t"`
Endpoint string `json:"e,omitempty"`
} }
func (t TunnelToken) Credentials() Credentials { func (t TunnelToken) Credentials() Credentials {
// nolint: gosimple
return Credentials{ return Credentials{
AccountTag: t.AccountTag, AccountTag: t.AccountTag,
TunnelSecret: t.TunnelSecret, TunnelSecret: t.TunnelSecret,
TunnelID: t.TunnelID, TunnelID: t.TunnelID,
Endpoint: t.Endpoint,
} }
} }

View File

@ -2,18 +2,14 @@ package connection
import ( import (
"context" "context"
"crypto/rand"
"fmt" "fmt"
"io" "io"
"math/big" "math/rand"
"net/http" "net/http"
"time" "time"
pkgerrors "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/stream" "github.com/cloudflare/cloudflared/stream"
"github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -81,7 +77,7 @@ func (moc *mockOriginProxy) ProxyHTTP(
return wsFlakyEndpoint(w, req) return wsFlakyEndpoint(w, req)
default: default:
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found")) originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
return fmt.Errorf("unknown websocket endpoint %s", req.URL.Path) return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
} }
} }
switch req.URL.Path { switch req.URL.Path {
@ -99,6 +95,7 @@ func (moc *mockOriginProxy) ProxyHTTP(
originRespEndpoint(w, http.StatusNotFound, []byte("page not found")) originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
} }
return nil return nil
} }
func (moc *mockOriginProxy) ProxyTCP( func (moc *mockOriginProxy) ProxyTCP(
@ -106,10 +103,6 @@ func (moc *mockOriginProxy) ProxyTCP(
rwa ReadWriteAcker, rwa ReadWriteAcker,
r *TCPRequest, r *TCPRequest,
) error { ) error {
if r.CfTraceID == "flow-rate-limited" {
return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "tcp flow rate limited")
}
return nil return nil
} }
@ -185,8 +178,7 @@ func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log) wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
rInt, _ := rand.Int(rand.Reader, big.NewInt(50)) closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
closedAfter := time.Millisecond * time.Duration(rInt.Int64())
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)} originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
stream.Pipe(wsConn, originConn, &log) stream.Pipe(wsConn, originConn, &log)
cancel() cancel()

View File

@ -102,7 +102,7 @@ func (c *controlStream) ServeControlStream(
c.observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc() c.observer.metrics.regSuccess.WithLabelValues("registerConnection").Inc()
c.observer.logConnected(registrationDetails.UUID, c.connIndex, registrationDetails.Location, c.edgeAddress, c.protocol) c.observer.logConnected(registrationDetails.UUID, c.connIndex, registrationDetails.Location, c.edgeAddress, c.protocol)
c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location, c.edgeAddress) c.observer.sendConnectedEvent(c.connIndex, c.protocol, registrationDetails.Location)
c.connectedFuse.Connected() c.connectedFuse.Connected()
// if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration // if conn index is 0 and tunnel is not remotely managed, then send local ingress rules configuration

View File

@ -1,15 +1,12 @@
package connection package connection
import "net"
// Event is something that happened to a connection, e.g. disconnection or registration. // Event is something that happened to a connection, e.g. disconnection or registration.
type Event struct { type Event struct {
Index uint8 Index uint8
EventType Status EventType Status
Location string Location string
Protocol Protocol Protocol Protocol
URL string URL string
EdgeAddress net.IP
} }
// Status is the status of a connection. // Status is the status of a connection.

View File

@ -22,9 +22,8 @@ var (
var ( var (
// pre-generate possible values for res // pre-generate possible values for res
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared", false) responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
responseMetaHeaderCfdFlowRateLimited = mustInitRespMetaHeader("cloudflared", true) responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin", false)
) )
// HTTPHeader is a custom header struct that expects only ever one value for the header. // HTTPHeader is a custom header struct that expects only ever one value for the header.
@ -35,12 +34,11 @@ type HTTPHeader struct {
} }
type responseMetaHeader struct { type responseMetaHeader struct {
Source string `json:"src"` Source string `json:"src"`
FlowRateLimited bool `json:"flow_rate_limited,omitempty"`
} }
func mustInitRespMetaHeader(src string, flowRateLimited bool) string { func mustInitRespMetaHeader(src string) string {
header, err := json.Marshal(responseMetaHeader{Source: src, FlowRateLimited: flowRateLimited}) header, err := json.Marshal(responseMetaHeader{Source: src})
if err != nil { if err != nil {
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err)) panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
} }
@ -114,7 +112,7 @@ func SerializeHeaders(h1Headers http.Header) string {
func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) { func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) {
const unableToDeserializeErr = "Unable to deserialize headers" const unableToDeserializeErr = "Unable to deserialize headers"
deserialized := make([]HTTPHeader, 0) var deserialized []HTTPHeader
for _, serializedPair := range strings.Split(serializedHeaders, ";") { for _, serializedPair := range strings.Split(serializedHeaders, ";") {
if len(serializedPair) == 0 { if len(serializedPair) == 0 {
continue continue

View File

@ -16,8 +16,6 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/net/http2" "golang.org/x/net/http2"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
@ -158,7 +156,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.log.Error().Err(requestErr).Msg("failed to serve incoming request") c.log.Error().Err(requestErr).Msg("failed to serve incoming request")
// WriteErrorResponse will return false if status was already written. we need to abort handler. // WriteErrorResponse will return false if status was already written. we need to abort handler.
if !respWriter.WriteErrorResponse(requestErr) { if !respWriter.WriteErrorResponse() {
c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent") c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent")
panic(http.ErrAbortHandler) panic(http.ErrAbortHandler)
} }
@ -211,9 +209,8 @@ func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, l
w: w, w: w,
log: log, log: log,
} }
err := fmt.Errorf("%T doesn't implement http.Flusher", w) respWriter.WriteErrorResponse()
respWriter.WriteErrorResponse(err) return nil, fmt.Errorf("%T doesn't implement http.Flusher", w)
return nil, err
} }
return &http2RespWriter{ return &http2RespWriter{
@ -298,7 +295,7 @@ func (rp *http2RespWriter) WriteHeader(status int) {
rp.log.Warn().Msg("WriteHeader after hijack") rp.log.Warn().Msg("WriteHeader after hijack")
return return
} }
_ = rp.WriteRespHeaders(status, rp.respHeaders) rp.WriteRespHeaders(status, rp.respHeaders)
} }
func (rp *http2RespWriter) hijacked() bool { func (rp *http2RespWriter) hijacked() bool {
@ -331,16 +328,12 @@ func (rp *http2RespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return conn, readWriter, nil return conn, readWriter, nil
} }
func (rp *http2RespWriter) WriteErrorResponse(err error) bool { func (rp *http2RespWriter) WriteErrorResponse() bool {
if rp.statusWritten { if rp.statusWritten {
return false return false
} }
if errors.Is(err, cfdflow.ErrTooManyActiveFlows) { rp.setResponseMetaHeader(responseMetaHeaderCfd)
rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited)
} else {
rp.setResponseMetaHeader(responseMetaHeaderCfd)
}
rp.w.WriteHeader(http.StatusBadGateway) rp.w.WriteHeader(http.StatusBadGateway)
rp.statusWritten = true rp.statusWritten = true

View File

@ -20,8 +20,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc" "github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
@ -67,18 +65,19 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
_ = http2Conn.Serve(ctx) http2Conn.Serve(ctx)
}() }()
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn) edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
require.NoError(t, err) require.NoError(t, err)
endpoint := fmt.Sprintf("http://localhost:8080/ok")
reqBody := []byte(`{ reqBody := []byte(`{
"version": 2, "version": 2,
"config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}} "config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
`) `)
reader := bytes.NewReader(reqBody) reader := bytes.NewReader(reqBody)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://localhost:8080/ok", reader) req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, reader)
require.NoError(t, err) require.NoError(t, err)
req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate) req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate)
@ -86,11 +85,11 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, http.StatusOK, resp.StatusCode)
bdy, err := io.ReadAll(resp.Body) bdy, err := io.ReadAll(resp.Body)
defer resp.Body.Close()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy)) assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy))
cancel() cancel()
wg.Wait() wg.Wait()
} }
func TestServeHTTP(t *testing.T) { func TestServeHTTP(t *testing.T) {
@ -135,7 +134,7 @@ func TestServeHTTP(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
_ = http2Conn.Serve(ctx) http2Conn.Serve(ctx)
}() }()
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn) edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
@ -154,7 +153,6 @@ func TestServeHTTP(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, test.expectedBody, respBody) require.Equal(t, test.expectedBody, respBody)
} }
_ = resp.Body.Close()
if test.isProxyError { if test.isProxyError {
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader)) require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
} else { } else {
@ -283,11 +281,10 @@ func TestServeWS(t *testing.T) {
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody()) respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, data, respBody, "expect %s, got %s", string(data), string(respBody)) require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
cancel() cancel()
resp := respWriter.Result() resp := respWriter.Result()
defer resp.Body.Close()
// http2RespWriter should rewrite status 101 to 200 // http2RespWriter should rewrite status 101 to 200
require.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader)) require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
@ -307,7 +304,7 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
serverDone := make(chan struct{}) serverDone := make(chan struct{})
go func() { go func() {
defer close(serverDone) defer close(serverDone)
_ = cfdHTTP2Conn.Serve(ctx) cfdHTTP2Conn.Serve(ctx)
}() }()
edgeTransport := http2.Transport{} edgeTransport := http2.Transport{}
@ -322,16 +319,13 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
readPipe, writePipe := io.Pipe() readPipe, writePipe := io.Pipe()
reqCtx, reqCancel := context.WithCancel(ctx) reqCtx, reqCancel := context.WithCancel(ctx)
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe) req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
assert.NoError(t, err) require.NoError(t, err)
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade) req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
resp, err := edgeHTTP2Conn.RoundTrip(req) resp, err := edgeHTTP2Conn.RoundTrip(req)
assert.NoError(t, err) require.NoError(t, err)
_ = resp.Body.Close()
// http2RespWriter should rewrite status 101 to 200 // http2RespWriter should rewrite status 101 to 200
assert.Equal(t, http.StatusOK, resp.StatusCode) require.Equal(t, http.StatusOK, resp.StatusCode)
wg.Add(1) wg.Add(1)
go func() { go func() {
@ -384,7 +378,7 @@ func TestServeControlStream(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
_ = http2Conn.Serve(ctx) http2Conn.Serve(ctx)
}() }()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
@ -397,8 +391,7 @@ func TestServeControlStream(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
// nolint: bodyclose edgeHTTP2Conn.RoundTrip(req)
_, _ = edgeHTTP2Conn.RoundTrip(req)
}() }()
<-rpcClientFactory.registered <-rpcClientFactory.registered
@ -438,7 +431,7 @@ func TestFailRegistration(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
_ = http2Conn.Serve(ctx) http2Conn.Serve(ctx)
}() }()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
@ -449,10 +442,9 @@ func TestFailRegistration(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
resp, err := edgeHTTP2Conn.RoundTrip(req) resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err) require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusBadGateway, resp.StatusCode) require.Equal(t, http.StatusBadGateway, resp.StatusCode)
require.Error(t, http2Conn.controlStreamErr) assert.NotNil(t, http2Conn.controlStreamErr)
cancel() cancel()
wg.Wait() wg.Wait()
} }
@ -489,7 +481,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
_ = http2Conn.Serve(ctx) http2Conn.Serve(ctx)
}() }()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
@ -502,7 +494,6 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
// nolint: bodyclose
_, _ = edgeHTTP2Conn.RoundTrip(req) _, _ = edgeHTTP2Conn.RoundTrip(req)
}() }()
@ -533,36 +524,6 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
}) })
} }
func TestServeTCP_RateLimited(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
http2Conn, edgeConn := newTestHTTP2Connection()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = http2Conn.Serve(ctx)
}()
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
require.NoError(t, err)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080", nil)
require.NoError(t, err)
req.Header.Set(InternalTCPProxySrcHeader, "tcp")
req.Header.Set(tracing.TracerContextName, "flow-rate-limited")
resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
require.Equal(t, responseMetaHeaderCfdFlowRateLimited, resp.Header.Get(ResponseMetaHeader))
cancel()
wg.Wait()
}
func benchmarkServeHTTP(b *testing.B, test testRequest) { func benchmarkServeHTTP(b *testing.B, test testRequest) {
http2Conn, edgeConn := newTestHTTP2Connection() http2Conn, edgeConn := newTestHTTP2Connection()
@ -571,7 +532,7 @@ func benchmarkServeHTTP(b *testing.B, test testRequest) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
_ = http2Conn.Serve(ctx) http2Conn.Serve(ctx)
}() }()
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint) endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)

View File

@ -47,6 +47,7 @@ func (o *Observer) RegisterSink(sink EventSink) {
} }
func (o *Observer) logConnected(connectionID uuid.UUID, connIndex uint8, location string, address net.IP, protocol Protocol) { func (o *Observer) logConnected(connectionID uuid.UUID, connIndex uint8, location string, address net.IP, protocol Protocol) {
o.sendEvent(Event{Index: connIndex, EventType: Connected, Location: location})
o.log.Info(). o.log.Info().
Int(management.EventTypeKey, int(management.Cloudflared)). Int(management.EventTypeKey, int(management.Cloudflared)).
Str(LogFieldConnectionID, connectionID.String()). Str(LogFieldConnectionID, connectionID.String()).
@ -62,8 +63,8 @@ func (o *Observer) sendRegisteringEvent(connIndex uint8) {
o.sendEvent(Event{Index: connIndex, EventType: RegisteringTunnel}) o.sendEvent(Event{Index: connIndex, EventType: RegisteringTunnel})
} }
func (o *Observer) sendConnectedEvent(connIndex uint8, protocol Protocol, location string, edgeAddress net.IP) { func (o *Observer) sendConnectedEvent(connIndex uint8, protocol Protocol, location string) {
o.sendEvent(Event{Index: connIndex, EventType: Connected, Protocol: protocol, Location: location, EdgeAddress: edgeAddress}) o.sendEvent(Event{Index: connIndex, EventType: Connected, Protocol: protocol, Location: location})
} }
func (o *Observer) SendURL(url string) { func (o *Observer) SendURL(url string) {

View File

@ -14,7 +14,7 @@ import (
const ( const (
AvailableProtocolFlagMessage = "Available protocols: 'auto' - automatically chooses the best protocol over time (the default; and also the recommended one); 'quic' - based on QUIC, relying on UDP egress to Cloudflare edge; 'http2' - using Go's HTTP2 library, relying on TCP egress to Cloudflare edge" AvailableProtocolFlagMessage = "Available protocols: 'auto' - automatically chooses the best protocol over time (the default; and also the recommended one); 'quic' - based on QUIC, relying on UDP egress to Cloudflare edge; 'http2' - using Go's HTTP2 library, relying on TCP egress to Cloudflare edge"
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge (unused, but kept for legacy reference). // edgeH2muxTLSServerName is the server name to establish h2mux connection with edge (unused, but kept for legacy reference).
_ = "cftunnel.com" edgeH2muxTLSServerName = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge // edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com" edgeH2TLSServerName = "h2.cftunnel.com"
// edgeQUICServerName is the server name to establish quic connection with edge. // edgeQUICServerName is the server name to establish quic connection with edge.
@ -24,9 +24,11 @@ const (
ResolveTTL = time.Hour ResolveTTL = time.Hour
) )
// ProtocolList represents a list of supported protocols for communication with the edge var (
// in order of precedence for remote percentage fetcher. // ProtocolList represents a list of supported protocols for communication with the edge
var ProtocolList = []Protocol{QUIC, HTTP2} // in order of precedence for remote percentage fetcher.
ProtocolList = []Protocol{QUIC, HTTP2}
)
type Protocol int64 type Protocol int64
@ -56,7 +58,7 @@ func (p Protocol) String() string {
case QUIC: case QUIC:
return "quic" return "quic"
default: default:
return "unknown protocol" return fmt.Sprintf("unknown protocol")
} }
} }
@ -244,11 +246,11 @@ func NewProtocolSelector(
return newRemoteProtocolSelector(fetchedProtocol, ProtocolList, threshold, protocolFetcher, resolveTTL, log), nil return newRemoteProtocolSelector(fetchedProtocol, ProtocolList, threshold, protocolFetcher, resolveTTL, log), nil
} }
return nil, fmt.Errorf("unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
} }
func switchThreshold(accountTag string) int32 { func switchThreshold(accountTag string) int32 {
h := fnv.New32a() h := fnv.New32a()
_, _ = h.Write([]byte(accountTag)) _, _ = h.Write([]byte(accountTag))
return int32(h.Sum32() % 100) // nolint: gosec return int32(h.Sum32() % 100)
} }

View File

@ -7,6 +7,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/netip"
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
@ -17,8 +18,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
cfdflow "github.com/cloudflare/cloudflared/flow" "github.com/cloudflare/cloudflared/packet"
cfdquic "github.com/cloudflare/cloudflared/quic" cfdquic "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -103,19 +103,14 @@ func (q *quicConnection) Serve(ctx context.Context) error {
// amount of the grace period, allowing requests to finish before we cancel the context, which will // amount of the grace period, allowing requests to finish before we cancel the context, which will
// make cloudflared exit. // make cloudflared exit.
if err := q.serveControlStream(ctx, controlStream); err == nil { if err := q.serveControlStream(ctx, controlStream); err == nil {
if q.gracePeriod > 0 { select {
// In Go1.23 this can be removed and replaced with time.Ticker case <-ctx.Done():
// see https://pkg.go.dev/time#Tick case <-time.Tick(q.gracePeriod):
ticker := time.NewTicker(q.gracePeriod)
defer ticker.Stop()
select {
case <-ctx.Done():
case <-ticker.C:
}
} }
} }
cancel() cancel()
return err return err
}) })
errGroup.Go(func() error { errGroup.Go(func() error {
defer cancel() defer cancel()
@ -136,7 +131,7 @@ func (q *quicConnection) serveControlStream(ctx context.Context, controlStream q
// Close the connection with no errors specified. // Close the connection with no errors specified.
func (q *quicConnection) Close() { func (q *quicConnection) Close() {
_ = q.conn.CloseWithError(0, "") q.conn.CloseWithError(0, "")
} }
func (q *quicConnection) acceptStream(ctx context.Context) error { func (q *quicConnection) acceptStream(ctx context.Context) error {
@ -189,13 +184,7 @@ func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.R
return err return err
} }
var metadata []pogs.Metadata if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil {
// Check the type of error that was throw and add metadata that will help identify it on OTD.
if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedMetadata)
}
if writeRespErr := stream.WriteConnectResponseData(err, metadata...); writeRespErr != nil {
return writeRespErr return writeRespErr
} }
} }
@ -291,7 +280,7 @@ func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header)
func (hrw *httpResponseAdapter) Write(p []byte) (int, error) { func (hrw *httpResponseAdapter) Write(p []byte) (int, error) {
// Make sure to send WriteHeader response if not called yet // Make sure to send WriteHeader response if not called yet
if !hrw.connectResponseSent { if !hrw.connectResponseSent {
_ = hrw.WriteRespHeaders(http.StatusOK, hrw.headers) hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
} }
return hrw.RequestServerStream.Write(p) return hrw.RequestServerStream.Write(p)
} }
@ -304,7 +293,7 @@ func (hrw *httpResponseAdapter) Header() http.Header {
func (hrw *httpResponseAdapter) Flush() {} func (hrw *httpResponseAdapter) Flush() {}
func (hrw *httpResponseAdapter) WriteHeader(status int) { func (hrw *httpResponseAdapter) WriteHeader(status int) {
_ = hrw.WriteRespHeaders(status, hrw.headers) hrw.WriteRespHeaders(status, hrw.headers)
} }
func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
@ -317,7 +306,7 @@ func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
} }
func (hrw *httpResponseAdapter) WriteErrorResponse(err error) { func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
_ = hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
} }
func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error { func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error {
@ -428,3 +417,28 @@ func (np *nopCloserReadWriter) Close() error {
return nil return nil
} }
// muxerWrapper wraps DatagramMuxerV2 to satisfy the packet.FunnelUniPipe interface
type muxerWrapper struct {
muxer *cfdquic.DatagramMuxerV2
}
func (rp *muxerWrapper) SendPacket(dst netip.Addr, pk packet.RawPacket) error {
return rp.muxer.SendPacket(cfdquic.RawPacket(pk))
}
func (rp *muxerWrapper) ReceivePacket(ctx context.Context) (packet.RawPacket, error) {
pk, err := rp.muxer.ReceivePacket(ctx)
if err != nil {
return packet.RawPacket{}, err
}
rawPacket, ok := pk.(cfdquic.RawPacket)
if ok {
return packet.RawPacket(rawPacket), nil
}
return packet.RawPacket{}, fmt.Errorf("unexpected packet type %+v", pk)
}
func (rp *muxerWrapper) Close() error {
return nil
}

View File

@ -8,7 +8,6 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"io" "io"
"math/big" "math/big"
@ -22,15 +21,13 @@ import (
"github.com/gobwas/ws/wsutil" "github.com/gobwas/ws/wsutil"
"github.com/google/uuid" "github.com/google/uuid"
pkgerrors "github.com/pkg/errors" "github.com/pkg/errors"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/net/nettest" "golang.org/x/net/nettest"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/datagramsession" "github.com/cloudflare/cloudflared/datagramsession"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/packet" "github.com/cloudflare/cloudflared/packet"
@ -56,8 +53,7 @@ var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
func TestQUICServer(t *testing.T) { func TestQUICServer(t *testing.T) {
// This is simply a sample websocket frame message. // This is simply a sample websocket frame message.
wsBuf := &bytes.Buffer{} wsBuf := &bytes.Buffer{}
err := wsutil.WriteClientBinary(wsBuf, []byte("Hello")) wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
require.NoError(t, err)
var tests = []struct { var tests = []struct {
desc string desc string
@ -162,19 +158,17 @@ func TestQUICServer(t *testing.T) {
serverDone := make(chan struct{}) serverDone := make(chan struct{})
go func() { go func() {
// nolint: testifylint
quicServer( quicServer(
ctx, t, quicListener, test.dest, test.connectionType, test.metadata, test.message, test.expectedResponse, ctx, t, quicListener, test.dest, test.connectionType, test.metadata, test.message, test.expectedResponse,
) )
close(serverDone) close(serverDone)
}() }()
// nolint: gosec
tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(i)) tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(i))
connDone := make(chan struct{}) connDone := make(chan struct{})
go func() { go func() {
_ = tunnelConn.Serve(ctx) tunnelConn.Serve(ctx)
close(connDone) close(connDone)
}() }()
@ -260,14 +254,14 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, tr *tracing.T
case "/ok": case "/ok":
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK))) originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
case "/slow_echo_body": case "/slow_echo_body":
time.Sleep(5 * time.Nanosecond) time.Sleep(5)
fallthrough fallthrough
case "/echo_body": case "/echo_body":
resp := &http.Response{ resp := &http.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
} }
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header) _ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
_, _ = io.Copy(w, r.Body) io.Copy(w, r.Body)
case "/error": case "/error":
return fmt.Errorf("Failed to proxy to origin") return fmt.Errorf("Failed to proxy to origin")
default: default:
@ -499,20 +493,16 @@ func TestBuildHTTPRequest(t *testing.T) {
test := test // capture range variable test := test // capture range variable
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
req, err := buildHTTPRequest(context.Background(), test.connectRequest, test.body, 0, &log) req, err := buildHTTPRequest(context.Background(), test.connectRequest, test.body, 0, &log)
require.NoError(t, err) assert.NoError(t, err)
test.req = test.req.WithContext(req.Context()) test.req = test.req.WithContext(req.Context())
require.Equal(t, test.req, req.Request) assert.Equal(t, test.req, req.Request)
}) })
} }
} }
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error { func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
if tcpRequest.Dest == "rate-limit-me" { rwa.AckConnection("")
return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "failed tcp stream") io.Copy(rwa, rwa)
}
_ = rwa.AckConnection("")
_, _ = io.Copy(rwa, rwa)
return nil return nil
} }
@ -530,19 +520,16 @@ func TestServeUDPSession(t *testing.T) {
edgeQUICSessionChan := make(chan quic.Connection) edgeQUICSessionChan := make(chan quic.Connection)
go func() { go func() {
earlyListener, err := quic.Listen(udpListener, testTLSServerConfig, testQUICConfig) earlyListener, err := quic.Listen(udpListener, testTLSServerConfig, testQUICConfig)
assert.NoError(t, err) require.NoError(t, err)
edgeQUICSession, err := earlyListener.Accept(ctx) edgeQUICSession, err := earlyListener.Accept(ctx)
assert.NoError(t, err) require.NoError(t, err)
edgeQUICSessionChan <- edgeQUICSession edgeQUICSessionChan <- edgeQUICSession
}() }()
// Random index to avoid reusing port // Random index to avoid reusing port
tunnelConn, datagramConn := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), 28) tunnelConn, datagramConn := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), 28)
go func() { go tunnelConn.Serve(ctx)
_ = tunnelConn.Serve(ctx)
}()
edgeQUICSession := <-edgeQUICSessionChan edgeQUICSession := <-edgeQUICSessionChan
@ -558,14 +545,14 @@ func TestNopCloserReadWriterCloseBeforeEOF(t *testing.T) {
n, err := readerWriter.Read(buffer) n, err := readerWriter.Read(buffer)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 5, n) require.Equal(t, n, 5)
// close // close
require.NoError(t, readerWriter.Close()) require.NoError(t, readerWriter.Close())
// read should get error // read should get error
n, err = readerWriter.Read(buffer) n, err = readerWriter.Read(buffer)
require.Equal(t, 0, n) require.Equal(t, n, 0)
require.Equal(t, err, fmt.Errorf("closed by handler")) require.Equal(t, err, fmt.Errorf("closed by handler"))
} }
@ -575,7 +562,7 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
n, err := readerWriter.Read(buffer) n, err := readerWriter.Read(buffer)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 9, n) require.Equal(t, n, 9)
// force another read to read eof // force another read to read eof
_, err = readerWriter.Read(buffer) _, err = readerWriter.Read(buffer)
@ -586,7 +573,7 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
// read should get EOF still // read should get EOF still
n, err = readerWriter.Read(buffer) n, err = readerWriter.Read(buffer)
require.Equal(t, 0, n) require.Equal(t, n, 0)
require.Equal(t, err, io.EOF) require.Equal(t, err, io.EOF)
} }
@ -602,59 +589,6 @@ func TestCreateUDPConnReuseSourcePort(t *testing.T) {
} }
} }
// TestTCPProxy_FlowRateLimited tests if the pogs.ConnectResponse returns the expected error and metadata, when a
// new flow is rate limited.
func TestTCPProxy_FlowRateLimited(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
// Start a UDP Listener for QUIC.
udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
require.NoError(t, err)
udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr)
require.NoError(t, err)
defer udpListener.Close()
quicTransport := &quic.Transport{Conn: udpListener, ConnectionIDLength: 16}
quicListener, err := quicTransport.Listen(testTLSServerConfig, testQUICConfig)
require.NoError(t, err)
serverDone := make(chan struct{})
go func() {
defer close(serverDone)
session, err := quicListener.Accept(ctx)
assert.NoError(t, err)
quicStream, err := session.OpenStreamSync(context.Background())
assert.NoError(t, err)
stream := cfdquic.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log)
reqClientStream := rpcquic.RequestClientStream{ReadWriteCloser: stream}
err = reqClientStream.WriteConnectRequestData("rate-limit-me", pogs.ConnectionTypeTCP)
assert.NoError(t, err)
response, err := reqClientStream.ReadConnectResponseData()
assert.NoError(t, err)
// Got Rate Limited
assert.NotEmpty(t, response.Error)
assert.Contains(t, response.Metadata, pogs.ErrorFlowConnectRateLimitedMetadata)
}()
tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(0))
connDone := make(chan struct{})
go func() {
defer close(connDone)
_ = tunnelConn.Serve(ctx)
}()
<-serverDone
cancel()
<-connDone
}
func testCreateUDPConnReuseSourcePortForEdgeIP(t *testing.T, edgeIP netip.AddrPort) { func testCreateUDPConnReuseSourcePortForEdgeIP(t *testing.T, edgeIP netip.AddrPort) {
logger := zerolog.Nop() logger := zerolog.Nop()
conn, err := createUDPConnForConnIndex(0, nil, edgeIP, &logger) conn, err := createUDPConnForConnIndex(0, nil, edgeIP, &logger)
@ -735,7 +669,6 @@ func serveSession(ctx context.Context, datagramConn *datagramV2Connection, edgeQ
unregisterReason: expectedReason, unregisterReason: expectedReason,
calledUnregisterChan: unregisterFromEdgeChan, calledUnregisterChan: unregisterFromEdgeChan,
} }
// nolint: testifylint
go runRPCServer(ctx, edgeQUICSession, sessionRPCServer, nil, t) go runRPCServer(ctx, edgeQUICSession, sessionRPCServer, nil, t)
<-unregisterFromEdgeChan <-unregisterFromEdgeChan
@ -796,7 +729,6 @@ func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionI
func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) (TunnelConnection, *datagramV2Connection) { func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) (TunnelConnection, *datagramV2Connection) {
tlsClientConfig := &tls.Config{ tlsClientConfig := &tls.Config{
// nolint: gosec
InsecureSkipVerify: true, InsecureSkipVerify: true,
NextProtos: []string{"argotunnel"}, NextProtos: []string{"argotunnel"},
} }
@ -815,20 +747,16 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
index, index,
&log, &log,
) )
require.NoError(t, err)
// Start a session manager for the connection // Start a session manager for the connection
sessionDemuxChan := make(chan *packet.Session, 4) sessionDemuxChan := make(chan *packet.Session, 4)
datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, &log, sessionDemuxChan) datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, &log, sessionDemuxChan)
sessionManager := datagramsession.NewManager(&log, datagramMuxer.SendToSession, sessionDemuxChan) sessionManager := datagramsession.NewManager(&log, datagramMuxer.SendToSession, sessionDemuxChan)
var connIndex uint8 = 0 packetRouter := ingress.NewPacketRouter(nil, datagramMuxer, &log)
packetRouter := ingress.NewPacketRouter(nil, datagramMuxer, connIndex, &log)
datagramConn := &datagramV2Connection{ datagramConn := &datagramV2Connection{
conn, conn,
index,
sessionManager, sessionManager,
cfdflow.NewLimiter(0),
datagramMuxer, datagramMuxer,
packetRouter, packetRouter,
15 * time.Second, 15 * time.Second,
@ -867,7 +795,6 @@ func (m *mockReaderNoopWriter) Close() error {
// GenerateTLSConfig sets up a bare-bones TLS config for a QUIC server // GenerateTLSConfig sets up a bare-bones TLS config for a QUIC server
func GenerateTLSConfig() *tls.Config { func GenerateTLSConfig() *tls.Config {
// nolint: gosec
key, err := rsa.GenerateKey(rand.Reader, 1024) key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil { if err != nil {
panic(err) panic(err)
@ -884,7 +811,6 @@ func GenerateTLSConfig() *tls.Config {
if err != nil { if err != nil {
panic(err) panic(err)
} }
// nolint: gosec
return &tls.Config{ return &tls.Config{
Certificates: []tls.Certificate{tlsCert}, Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{"argotunnel"}, NextProtos: []string{"argotunnel"},

View File

@ -7,15 +7,12 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
pkgerrors "github.com/pkg/errors"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace" "go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/datagramsession" "github.com/cloudflare/cloudflared/datagramsession"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/management"
@ -41,14 +38,10 @@ type DatagramSessionHandler interface {
} }
type datagramV2Connection struct { type datagramV2Connection struct {
conn quic.Connection conn quic.Connection
index uint8
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer // sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
sessionManager datagramsession.Manager sessionManager datagramsession.Manager
// flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
flowLimiter cfdflow.Limiter
// datagramMuxer mux/demux datagrams from quic connection // datagramMuxer mux/demux datagrams from quic connection
datagramMuxer *cfdquic.DatagramMuxerV2 datagramMuxer *cfdquic.DatagramMuxerV2
packetRouter *ingress.PacketRouter packetRouter *ingress.PacketRouter
@ -61,28 +54,24 @@ type datagramV2Connection struct {
func NewDatagramV2Connection(ctx context.Context, func NewDatagramV2Connection(ctx context.Context,
conn quic.Connection, conn quic.Connection,
icmpRouter ingress.ICMPRouter, packetConfig *ingress.GlobalRouterConfig,
index uint8,
rpcTimeout time.Duration, rpcTimeout time.Duration,
streamWriteTimeout time.Duration, streamWriteTimeout time.Duration,
flowLimiter cfdflow.Limiter,
logger *zerolog.Logger, logger *zerolog.Logger,
) DatagramSessionHandler { ) DatagramSessionHandler {
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity) sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, logger, sessionDemuxChan) datagramMuxer := cfdquic.NewDatagramMuxerV2(conn, logger, sessionDemuxChan)
sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan) sessionManager := datagramsession.NewManager(logger, datagramMuxer.SendToSession, sessionDemuxChan)
packetRouter := ingress.NewPacketRouter(icmpRouter, datagramMuxer, index, logger) packetRouter := ingress.NewPacketRouter(packetConfig, datagramMuxer, logger)
return &datagramV2Connection{ return &datagramV2Connection{
conn: conn, conn,
index: index, sessionManager,
sessionManager: sessionManager, datagramMuxer,
flowLimiter: flowLimiter, packetRouter,
datagramMuxer: datagramMuxer, rpcTimeout,
packetRouter: packetRouter, streamWriteTimeout,
rpcTimeout: rpcTimeout, logger,
streamWriteTimeout: streamWriteTimeout,
logger: logger,
} }
} }
@ -119,23 +108,12 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
attribute.String("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)), attribute.String("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)),
)) ))
log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger() log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger()
// Try to start a new session
if err := q.flowLimiter.Acquire(management.UDP.String()); err != nil {
log.Warn().Msgf("Too many concurrent sessions being handled, rejecting udp proxy to %s:%d", dstIP, dstPort)
err := pkgerrors.Wrap(err, "failed to start udp session due to rate limiting")
tracing.EndWithErrorStatus(registerSpan, err)
return nil, err
}
// Each session is a series of datagram from an eyeball to a dstIP:dstPort. // Each session is a series of datagram from an eyeball to a dstIP:dstPort.
// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket. // (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
originProxy, err := ingress.DialUDP(dstIP, dstPort) originProxy, err := ingress.DialUDP(dstIP, dstPort)
if err != nil { if err != nil {
log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort) log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
tracing.EndWithErrorStatus(registerSpan, err) tracing.EndWithErrorStatus(registerSpan, err)
q.flowLimiter.Release()
return nil, err return nil, err
} }
registerSpan.SetAttributes( registerSpan.SetAttributes(
@ -148,14 +126,10 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
originProxy.Close() originProxy.Close()
log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session") log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
tracing.EndWithErrorStatus(registerSpan, err) tracing.EndWithErrorStatus(registerSpan, err)
q.flowLimiter.Release()
return nil, err return nil, err
} }
go func() { go q.serveUDPSession(session, closeAfterIdleHint)
defer q.flowLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method.
q.serveUDPSession(session, closeAfterIdleHint)
}()
log.Debug(). log.Debug().
Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)). Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).
@ -195,7 +169,7 @@ func (q *datagramV2Connection) serveUDPSession(session *datagramsession.Session,
// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge // closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge
func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) { func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) {
_ = q.sessionManager.UnregisterSession(ctx, sessionID, message, false) q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
quicStream, err := q.conn.OpenStream() quicStream, err := q.conn.OpenStream()
if err != nil { if err != nil {
// Log this at debug because this is not an error if session was closed due to lost connection // Log this at debug because this is not an error if session was closed due to lost connection

View File

@ -1,96 +0,0 @@
package connection
import (
"context"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/mocks"
)
type mockQuicConnection struct {
}
func (m *mockQuicConnection) AcceptStream(_ context.Context) (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) AcceptUniStream(_ context.Context) (quic.ReceiveStream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenStream() (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenStreamSync(_ context.Context) (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenUniStream() (quic.SendStream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenUniStreamSync(_ context.Context) (quic.SendStream, error) {
return nil, nil
}
func (m *mockQuicConnection) LocalAddr() net.Addr {
return nil
}
func (m *mockQuicConnection) RemoteAddr() net.Addr {
return nil
}
func (m *mockQuicConnection) CloseWithError(_ quic.ApplicationErrorCode, s string) error {
return nil
}
func (m *mockQuicConnection) Context() context.Context {
return nil
}
func (m *mockQuicConnection) ConnectionState() quic.ConnectionState {
panic("not meant to be called")
}
func (m *mockQuicConnection) SendDatagram(_ []byte) error {
return nil
}
func (m *mockQuicConnection) ReceiveDatagram(_ context.Context) ([]byte, error) {
return nil, nil
}
func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
log := zerolog.Nop()
conn := &mockQuicConnection{}
ctrl := gomock.NewController(t)
flowLimiterMock := mocks.NewMockLimiter(ctrl)
datagramConn := NewDatagramV2Connection(
context.Background(),
conn,
nil,
0,
0*time.Second,
0*time.Second,
flowLimiterMock,
&log,
)
flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
flowLimiterMock.EXPECT().Release().Times(0)
_, err := datagramConn.RegisterUdpSession(context.Background(), uuid.New(), net.IPv4(0, 0, 0, 0), 1000, 1*time.Second, "")
require.ErrorIs(t, err, cfdflow.ErrTooManyActiveFlows)
}

View File

@ -10,7 +10,6 @@ import (
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management" "github.com/cloudflare/cloudflared/management"
cfdquic "github.com/cloudflare/cloudflared/quic/v3" cfdquic "github.com/cloudflare/cloudflared/quic/v3"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs" "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -26,7 +25,6 @@ type datagramV3Connection struct {
func NewDatagramV3Connection(ctx context.Context, func NewDatagramV3Connection(ctx context.Context,
conn quic.Connection, conn quic.Connection,
sessionManager cfdquic.SessionManager, sessionManager cfdquic.SessionManager,
icmpRouter ingress.ICMPRouter,
index uint8, index uint8,
metrics cfdquic.Metrics, metrics cfdquic.Metrics,
logger *zerolog.Logger, logger *zerolog.Logger,
@ -36,7 +34,7 @@ func NewDatagramV3Connection(ctx context.Context,
Int(management.EventTypeKey, int(management.UDP)). Int(management.EventTypeKey, int(management.UDP)).
Uint8(LogFieldConnIndex, index). Uint8(LogFieldConnIndex, index).
Logger() Logger()
datagramMuxer := cfdquic.NewDatagramConn(conn, sessionManager, icmpRouter, index, metrics, &log) datagramMuxer := cfdquic.NewDatagramConn(conn, sessionManager, index, metrics, &log)
return &datagramV3Connection{ return &datagramV3Connection{
conn, conn,

View File

@ -9,7 +9,6 @@ import (
const ( const (
logFieldOriginCertPath = "originCertPath" logFieldOriginCertPath = "originCertPath"
FedEndpoint = "fed"
) )
type User struct { type User struct {
@ -33,10 +32,6 @@ func (c User) CertPath() string {
return c.certPath return c.certPath
} }
func (c User) IsFEDEndpoint() bool {
return c.cert.Endpoint == FedEndpoint
}
// Client uses the user credentials to create a Cloudflare API client // Client uses the user credentials to create a Cloudflare API client
func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfapi.Client, error) { func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfapi.Client, error) {
if apiURL == "" { if apiURL == "" {
@ -50,6 +45,7 @@ func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfa
userAgent, userAgent,
log, log,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -3,7 +3,7 @@ package credentials
import ( import (
"io/fs" "io/fs"
"os" "os"
"path/filepath" "path"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -13,8 +13,8 @@ func TestCredentialsRead(t *testing.T) {
file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem") file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem")
require.NoError(t, err) require.NoError(t, err)
dir := t.TempDir() dir := t.TempDir()
certPath := filepath.Join(dir, originCertFile) certPath := path.Join(dir, originCertFile)
_ = os.WriteFile(certPath, file, fs.ModePerm) os.WriteFile(certPath, file, fs.ModePerm)
user, err := Read(certPath, &nopLog) user, err := Read(certPath, &nopLog)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, certPath, user.CertPath()) require.Equal(t, certPath, user.CertPath())

View File

@ -1,13 +1,11 @@
package credentials package credentials
import ( import (
"bytes"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"github.com/mitchellh/go-homedir" "github.com/mitchellh/go-homedir"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -17,30 +15,19 @@ import (
const ( const (
DefaultCredentialFile = "cert.pem" DefaultCredentialFile = "cert.pem"
OriginCertFlag = "origincert"
) )
type OriginCert struct { type namedTunnelToken struct {
ZoneID string `json:"zoneID"` ZoneID string `json:"zoneID"`
AccountID string `json:"accountID"` AccountID string `json:"accountID"`
APIToken string `json:"apiToken"` APIToken string `json:"apiToken"`
Endpoint string `json:"endpoint,omitempty"`
} }
func (oc *OriginCert) UnmarshalJSON(data []byte) error { type OriginCert struct {
var aux struct { ZoneID string
ZoneID string `json:"zoneID"` APIToken string
AccountID string `json:"accountID"` AccountID string
APIToken string `json:"apiToken"`
Endpoint string `json:"endpoint,omitempty"`
}
if err := json.Unmarshal(data, &aux); err != nil {
return fmt.Errorf("error parsing OriginCert: %v", err)
}
oc.ZoneID = aux.ZoneID
oc.AccountID = aux.AccountID
oc.APIToken = aux.APIToken
oc.Endpoint = strings.ToLower(aux.Endpoint)
return nil
} }
// FindDefaultOriginCertPath returns the first path that contains a cert.pem file. If none of the // FindDefaultOriginCertPath returns the first path that contains a cert.pem file. If none of the
@ -55,56 +42,40 @@ func FindDefaultOriginCertPath() string {
return "" return ""
} }
func DecodeOriginCert(blocks []byte) (*OriginCert, error) {
return decodeOriginCert(blocks)
}
func (cert *OriginCert) EncodeOriginCert() ([]byte, error) {
if cert == nil {
return nil, fmt.Errorf("originCert cannot be nil")
}
buffer, err := json.Marshal(cert)
if err != nil {
return nil, fmt.Errorf("originCert marshal failed: %v", err)
}
block := pem.Block{
Type: "ARGO TUNNEL TOKEN",
Headers: map[string]string{},
Bytes: buffer,
}
var out bytes.Buffer
err = pem.Encode(&out, &block)
if err != nil {
return nil, fmt.Errorf("pem encoding failed: %v", err)
}
return out.Bytes(), nil
}
func decodeOriginCert(blocks []byte) (*OriginCert, error) { func decodeOriginCert(blocks []byte) (*OriginCert, error) {
if len(blocks) == 0 { if len(blocks) == 0 {
return nil, fmt.Errorf("cannot decode empty certificate") return nil, fmt.Errorf("Cannot decode empty certificate")
} }
originCert := OriginCert{} originCert := OriginCert{}
block, rest := pem.Decode(blocks) block, rest := pem.Decode(blocks)
for block != nil { for {
if block == nil {
break
}
switch block.Type { switch block.Type {
case "PRIVATE KEY", "CERTIFICATE": case "PRIVATE KEY", "CERTIFICATE":
// this is for legacy purposes. // this is for legacy purposes.
break
case "ARGO TUNNEL TOKEN": case "ARGO TUNNEL TOKEN":
if originCert.ZoneID != "" || originCert.APIToken != "" { if originCert.ZoneID != "" || originCert.APIToken != "" {
return nil, fmt.Errorf("found multiple tokens in the certificate") return nil, fmt.Errorf("Found multiple tokens in the certificate")
} }
// The token is a string, // The token is a string,
// Try the newer JSON format // Try the newer JSON format
_ = json.Unmarshal(block.Bytes, &originCert) ntt := namedTunnelToken{}
if err := json.Unmarshal(block.Bytes, &ntt); err == nil {
originCert.ZoneID = ntt.ZoneID
originCert.APIToken = ntt.APIToken
originCert.AccountID = ntt.AccountID
}
default: default:
return nil, fmt.Errorf("unknown block %s in the certificate", block.Type) return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type)
} }
block, rest = pem.Decode(rest) block, rest = pem.Decode(rest)
} }
if originCert.ZoneID == "" || originCert.APIToken == "" { if originCert.ZoneID == "" || originCert.APIToken == "" {
return nil, fmt.Errorf("missing token in the certificate") return nil, fmt.Errorf("Missing token in the certificate")
} }
return &originCert, nil return &originCert, nil

View File

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"os" "os"
"path/filepath" "path"
"testing" "testing"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -16,25 +16,27 @@ const (
originCertFile = "cert.pem" originCertFile = "cert.pem"
) )
var nopLog = zerolog.Nop().With().Logger() var (
nopLog = zerolog.Nop().With().Logger()
)
func TestLoadOriginCert(t *testing.T) { func TestLoadOriginCert(t *testing.T) {
cert, err := decodeOriginCert([]byte{}) cert, err := decodeOriginCert([]byte{})
assert.Equal(t, fmt.Errorf("cannot decode empty certificate"), err) assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err)
assert.Nil(t, cert) assert.Nil(t, cert)
blocks, err := os.ReadFile("test-cert-unknown-block.pem") blocks, err := os.ReadFile("test-cert-unknown-block.pem")
require.NoError(t, err) assert.NoError(t, err)
cert, err = decodeOriginCert(blocks) cert, err = decodeOriginCert(blocks)
assert.Equal(t, fmt.Errorf("unknown block RSA PRIVATE KEY in the certificate"), err) assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err)
assert.Nil(t, cert) assert.Nil(t, cert)
} }
func TestJSONArgoTunnelTokenEmpty(t *testing.T) { func TestJSONArgoTunnelTokenEmpty(t *testing.T) {
blocks, err := os.ReadFile("test-cert-no-token.pem") blocks, err := os.ReadFile("test-cert-no-token.pem")
require.NoError(t, err) assert.NoError(t, err)
cert, err := decodeOriginCert(blocks) cert, err := decodeOriginCert(blocks)
assert.Equal(t, fmt.Errorf("missing token in the certificate"), err) assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err)
assert.Nil(t, cert) assert.Nil(t, cert)
} }
@ -50,21 +52,51 @@ func TestJSONArgoTunnelToken(t *testing.T) {
func CloudflareTunnelTokenTest(t *testing.T, path string) { func CloudflareTunnelTokenTest(t *testing.T, path string) {
blocks, err := os.ReadFile(path) blocks, err := os.ReadFile(path)
require.NoError(t, err) assert.NoError(t, err)
cert, err := decodeOriginCert(blocks) cert, err := decodeOriginCert(blocks)
require.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, cert) assert.NotNil(t, cert)
assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID) assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID)
key := "test-service-key" key := "test-service-key"
assert.Equal(t, key, cert.APIToken) assert.Equal(t, key, cert.APIToken)
} }
type mockFile struct {
path string
data []byte
err error
}
type mockFileSystem struct {
files map[string]mockFile
}
func newMockFileSystem(files ...mockFile) *mockFileSystem {
fs := mockFileSystem{map[string]mockFile{}}
for _, f := range files {
fs.files[f.path] = f
}
return &fs
}
func (fs *mockFileSystem) ReadFile(path string) ([]byte, error) {
if f, ok := fs.files[path]; ok {
return f.data, f.err
}
return nil, os.ErrNotExist
}
func (fs *mockFileSystem) ValidFilePath(path string) bool {
_, exists := fs.files[path]
return exists
}
func TestFindOriginCert_Valid(t *testing.T) { func TestFindOriginCert_Valid(t *testing.T) {
file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem") file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem")
require.NoError(t, err) require.NoError(t, err)
dir := t.TempDir() dir := t.TempDir()
certPath := filepath.Join(dir, originCertFile) certPath := path.Join(dir, originCertFile)
_ = os.WriteFile(certPath, file, fs.ModePerm) os.WriteFile(certPath, file, fs.ModePerm)
path, err := FindOriginCert(certPath, &nopLog) path, err := FindOriginCert(certPath, &nopLog)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, certPath, path) require.Equal(t, certPath, path)
@ -72,32 +104,7 @@ func TestFindOriginCert_Valid(t *testing.T) {
func TestFindOriginCert_Missing(t *testing.T) { func TestFindOriginCert_Missing(t *testing.T) {
dir := t.TempDir() dir := t.TempDir()
certPath := filepath.Join(dir, originCertFile) certPath := path.Join(dir, originCertFile)
_, err := FindOriginCert(certPath, &nopLog) _, err := FindOriginCert(certPath, &nopLog)
require.Error(t, err) require.Error(t, err)
} }
func TestEncodeDecodeOriginCert(t *testing.T) {
cert := OriginCert{
ZoneID: "zone",
AccountID: "account",
APIToken: "token",
Endpoint: "FED",
}
blocks, err := cert.EncodeOriginCert()
require.NoError(t, err)
decodedCert, err := DecodeOriginCert(blocks)
require.NoError(t, err)
assert.NotNil(t, cert)
assert.Equal(t, "zone", decodedCert.ZoneID)
assert.Equal(t, "account", decodedCert.AccountID)
assert.Equal(t, "token", decodedCert.APIToken)
assert.Equal(t, FedEndpoint, decodedCert.Endpoint)
}
func TestEncodeDecodeNilOriginCert(t *testing.T) {
var cert *OriginCert
blocks, err := cert.EncodeOriginCert()
assert.Equal(t, fmt.Errorf("originCert cannot be nil"), err)
require.Nil(t, blocks)
}

View File

@ -87,4 +87,3 @@ M2i4QoOFcSKIG+v4SuvgEJHgG8vGvxh2qlSxnMWuPV+7/1P5ATLqDj1PlKms+BNR
y7sc5AT9PclkL3Y9MNzOu0LXyBkGYcl8M0EQfLv9VPbWT+NXiMg/O2CHiT02pAAz y7sc5AT9PclkL3Y9MNzOu0LXyBkGYcl8M0EQfLv9VPbWT+NXiMg/O2CHiT02pAAz
uQicoQq3yzeQh20wtrtaXzTNmA== uQicoQq3yzeQh20wtrtaXzTNmA==
-----END RSA PRIVATE KEY----- -----END RSA PRIVATE KEY-----

View File

@ -1,6 +1,6 @@
FROM golang:1.22.10 as builder FROM golang:1.22.5 as builder
ENV GO111MODULE=on \ ENV GO111MODULE=on \
CGO_ENABLED=0 CGO_ENABLED=0
WORKDIR /go/src/github.com/cloudflare/cloudflared/ WORKDIR /go/src/github.com/cloudflare/cloudflared/
RUN apt-get update RUN apt-get update
COPY . . COPY . .

View File

@ -1,216 +0,0 @@
package diagnostic
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
)
type httpClient struct {
http.Client
baseURL *url.URL
}
func NewHTTPClient() *httpClient {
httpTransport := http.Transport{
TLSHandshakeTimeout: defaultTimeout,
ResponseHeaderTimeout: defaultTimeout,
}
return &httpClient{
http.Client{
Transport: &httpTransport,
Timeout: defaultTimeout,
},
nil,
}
}
func (client *httpClient) SetBaseURL(baseURL *url.URL) {
client.baseURL = baseURL
}
func (client *httpClient) GET(ctx context.Context, endpoint string) (*http.Response, error) {
if client.baseURL == nil {
return nil, ErrNoBaseURL
}
url := client.baseURL.JoinPath(endpoint)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
if err != nil {
return nil, fmt.Errorf("error creating GET request: %w", err)
}
req.Header.Add("Accept", "application/json;version=1")
response, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("error GET request: %w", err)
}
return response, nil
}
type LogConfiguration struct {
logFile string
logDirectory string
uid int // the uid of the user that started cloudflared
}
func (client *httpClient) GetLogConfiguration(ctx context.Context) (*LogConfiguration, error) {
response, err := client.GET(ctx, cliConfigurationEndpoint)
if err != nil {
return nil, err
}
defer response.Body.Close()
var data map[string]string
if err := json.NewDecoder(response.Body).Decode(&data); err != nil {
return nil, fmt.Errorf("failed to decode body: %w", err)
}
uidStr, exists := data[configurationKeyUID]
if !exists {
return nil, ErrKeyNotFound
}
uid, err := strconv.Atoi(uidStr)
if err != nil {
return nil, fmt.Errorf("error convertin pid to int: %w", err)
}
logFile, exists := data[cfdflags.LogFile]
if exists {
return &LogConfiguration{logFile, "", uid}, nil
}
logDirectory, exists := data[cfdflags.LogDirectory]
if exists {
return &LogConfiguration{"", logDirectory, uid}, nil
}
// No log configured may happen when cloudflared is executed as a managed service or
// when containerized
return &LogConfiguration{"", "", uid}, nil
}
func (client *httpClient) GetMemoryDump(ctx context.Context, writer io.Writer) error {
response, err := client.GET(ctx, memoryDumpEndpoint)
if err != nil {
return err
}
return copyToWriter(response, writer)
}
func (client *httpClient) GetGoroutineDump(ctx context.Context, writer io.Writer) error {
response, err := client.GET(ctx, goroutineDumpEndpoint)
if err != nil {
return err
}
return copyToWriter(response, writer)
}
func (client *httpClient) GetTunnelState(ctx context.Context) (*TunnelState, error) {
response, err := client.GET(ctx, tunnelStateEndpoint)
if err != nil {
return nil, err
}
defer response.Body.Close()
var state TunnelState
if err := json.NewDecoder(response.Body).Decode(&state); err != nil {
return nil, fmt.Errorf("failed to decode body: %w", err)
}
return &state, nil
}
func (client *httpClient) GetSystemInformation(ctx context.Context, writer io.Writer) error {
response, err := client.GET(ctx, systemInformationEndpoint)
if err != nil {
return err
}
return copyJSONToWriter(response, writer)
}
func (client *httpClient) GetMetrics(ctx context.Context, writer io.Writer) error {
response, err := client.GET(ctx, metricsEndpoint)
if err != nil {
return err
}
return copyToWriter(response, writer)
}
func (client *httpClient) GetTunnelConfiguration(ctx context.Context, writer io.Writer) error {
response, err := client.GET(ctx, tunnelConfigurationEndpoint)
if err != nil {
return err
}
return copyJSONToWriter(response, writer)
}
func (client *httpClient) GetCliConfiguration(ctx context.Context, writer io.Writer) error {
response, err := client.GET(ctx, cliConfigurationEndpoint)
if err != nil {
return err
}
return copyJSONToWriter(response, writer)
}
func copyToWriter(response *http.Response, writer io.Writer) error {
defer response.Body.Close()
_, err := io.Copy(writer, response.Body)
if err != nil {
return fmt.Errorf("error writing response: %w", err)
}
return nil
}
func copyJSONToWriter(response *http.Response, writer io.Writer) error {
defer response.Body.Close()
var data interface{}
decoder := json.NewDecoder(response.Body)
err := decoder.Decode(&data)
if err != nil {
return fmt.Errorf("diagnostic client error whilst reading response: %w", err)
}
encoder := newFormattedEncoder(writer)
err = encoder.Encode(data)
if err != nil {
return fmt.Errorf("diagnostic client error whilst writing json: %w", err)
}
return nil
}
type HTTPClient interface {
GetLogConfiguration(ctx context.Context) (*LogConfiguration, error)
GetMemoryDump(ctx context.Context, writer io.Writer) error
GetGoroutineDump(ctx context.Context, writer io.Writer) error
GetTunnelState(ctx context.Context) (*TunnelState, error)
GetSystemInformation(ctx context.Context, writer io.Writer) error
GetMetrics(ctx context.Context, writer io.Writer) error
GetCliConfiguration(ctx context.Context, writer io.Writer) error
GetTunnelConfiguration(ctx context.Context, writer io.Writer) error
}

View File

@ -1,37 +0,0 @@
package diagnostic
import "time"
const (
defaultCollectorTimeout = time.Second * 10 // This const define the timeout value of a collector operation.
collectorField = "collector" // used for logging purposes
systemCollectorName = "system" // used for logging purposes
tunnelStateCollectorName = "tunnelState" // used for logging purposes
configurationCollectorName = "configuration" // used for logging purposes
defaultTimeout = 15 * time.Second // timeout for the collectors
twoWeeksOffset = -14 * 24 * time.Hour // maximum offset for the logs
logFilename = "cloudflared_logs.txt" // name of the output log file
configurationKeyUID = "uid" // Key used to set and get the UID value from the configuration map
tailMaxNumberOfLines = "10000" // maximum number of log lines from a virtual runtime (docker or kubernetes)
// Endpoints used by the diagnostic HTTP Client.
cliConfigurationEndpoint = "/diag/configuration"
tunnelStateEndpoint = "/diag/tunnel"
systemInformationEndpoint = "/diag/system"
memoryDumpEndpoint = "debug/pprof/heap"
goroutineDumpEndpoint = "debug/pprof/goroutine"
metricsEndpoint = "metrics"
tunnelConfigurationEndpoint = "/config"
// Base for filenames of the diagnostic procedure
systemInformationBaseName = "systeminformation.json"
metricsBaseName = "metrics.txt"
zipName = "cloudflared-diag"
heapPprofBaseName = "heap.pprof"
goroutinePprofBaseName = "goroutine.pprof"
networkBaseName = "network.json"
rawNetworkBaseName = "raw-network.txt"
tunnelStateBaseName = "tunnelstate.json"
cliConfigurationBaseName = "cli-configuration.json"
configurationBaseName = "configuration.json"
taskResultBaseName = "task-result.json"
)

View File

@ -1,561 +0,0 @@
package diagnostic
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/rs/zerolog"
network "github.com/cloudflare/cloudflared/diagnostic/network"
)
const (
taskSuccess = "success"
taskFailure = "failure"
jobReportName = "job report"
tunnelStateJobName = "tunnel state"
systemInformationJobName = "system information"
goroutineJobName = "goroutine profile"
heapJobName = "heap profile"
metricsJobName = "metrics"
logInformationJobName = "log information"
rawNetworkInformationJobName = "raw network information"
networkInformationJobName = "network information"
cliConfigurationJobName = "cli configuration"
configurationJobName = "configuration"
)
// Struct used to hold the results of different routines executing the network collection.
type taskResult struct {
Result string `json:"result,omitempty"`
Err error `json:"error,omitempty"`
path string
}
func (result taskResult) MarshalJSON() ([]byte, error) {
s := map[string]string{
"result": result.Result,
}
if result.Err != nil {
s["error"] = result.Err.Error()
}
return json.Marshal(s)
}
// Struct used to hold the results of different routines executing the network collection.
type networkCollectionResult struct {
name string
info []*network.Hop
raw string
err error
}
// This type represents the most common functions from the diagnostic http client
// functions.
type collectToWriterFunc func(ctx context.Context, writer io.Writer) error
// This type represents the common denominator among all the collection procedures.
type collectFunc func(ctx context.Context) (string, error)
// collectJob is an internal struct that denotes holds the information necessary
// to run a collection job.
type collectJob struct {
jobName string
fn collectFunc
bypass bool
}
// The Toggles structure denotes the available toggles for the diagnostic procedure.
// Each toggle enables/disables tasks from the diagnostic.
type Toggles struct {
NoDiagLogs bool
NoDiagMetrics bool
NoDiagSystem bool
NoDiagRuntime bool
NoDiagNetwork bool
}
// The Options structure holds every option necessary for
// the diagnostic procedure to work.
type Options struct {
KnownAddresses []string
Address string
ContainerID string
PodID string
Toggles Toggles
}
func collectLogs(
ctx context.Context,
client HTTPClient,
diagContainer, diagPod string,
) (string, error) {
var collector LogCollector
if diagPod != "" {
collector = NewKubernetesLogCollector(diagContainer, diagPod)
} else if diagContainer != "" {
collector = NewDockerLogCollector(diagContainer)
} else {
collector = NewHostLogCollector(client)
}
logInformation, err := collector.Collect(ctx)
if err != nil {
return "", fmt.Errorf("error collecting logs: %w", err)
}
if logInformation.isDirectory {
return CopyFilesFromDirectory(logInformation.path)
}
if logInformation.wasCreated {
return logInformation.path, nil
}
logHandle, err := os.Open(logInformation.path)
if err != nil {
return "", fmt.Errorf("error opening log file while collecting logs: %w", err)
}
defer logHandle.Close()
outputLogHandle, err := os.Create(filepath.Join(os.TempDir(), logFilename))
if err != nil {
return "", ErrCreatingTemporaryFile
}
defer outputLogHandle.Close()
_, err = io.Copy(outputLogHandle, logHandle)
if err != nil {
return "", fmt.Errorf("error copying logs while collecting logs: %w", err)
}
return outputLogHandle.Name(), err
}
func collectNetworkResultRoutine(
ctx context.Context,
collector network.NetworkCollector,
hostname string,
useIPv4 bool,
results chan networkCollectionResult,
) {
const (
hopsNo = 5
timeout = time.Second * 5
)
name := hostname
if useIPv4 {
name += "-v4"
} else {
name += "-v6"
}
hops, raw, err := collector.Collect(ctx, network.NewTraceOptions(hopsNo, timeout, hostname, useIPv4))
results <- networkCollectionResult{name, hops, raw, err}
}
func gatherNetworkInformation(ctx context.Context) map[string]networkCollectionResult {
networkCollector := network.NetworkCollectorImpl{}
hostAndIPversionPairs := []struct {
host string
useV4 bool
}{
{"region1.v2.argotunnel.com", true},
{"region1.v2.argotunnel.com", false},
{"region2.v2.argotunnel.com", true},
{"region2.v2.argotunnel.com", false},
}
// the number of results is known thus use len to avoid footguns
results := make(chan networkCollectionResult, len(hostAndIPversionPairs))
var wgroup sync.WaitGroup
for _, item := range hostAndIPversionPairs {
wgroup.Add(1)
go func() {
defer wgroup.Done()
collectNetworkResultRoutine(ctx, &networkCollector, item.host, item.useV4, results)
}()
}
// Wait for routines to end.
wgroup.Wait()
resultMap := make(map[string]networkCollectionResult)
for range len(hostAndIPversionPairs) {
result := <-results
resultMap[result.name] = result
}
return resultMap
}
func networkInformationCollectors() (rawNetworkCollector, jsonNetworkCollector collectFunc) {
// The network collector is an operation that takes most of the diagnostic time, thus,
// the sync.Once is used to memoize the result of the collector and then create different
// outputs.
var once sync.Once
var resultMap map[string]networkCollectionResult
rawNetworkCollector = func(ctx context.Context) (string, error) {
once.Do(func() { resultMap = gatherNetworkInformation(ctx) })
return rawNetworkInformationWriter(resultMap)
}
jsonNetworkCollector = func(ctx context.Context) (string, error) {
once.Do(func() { resultMap = gatherNetworkInformation(ctx) })
return jsonNetworkInformationWriter(resultMap)
}
return rawNetworkCollector, jsonNetworkCollector
}
func rawNetworkInformationWriter(resultMap map[string]networkCollectionResult) (string, error) {
networkDumpHandle, err := os.Create(filepath.Join(os.TempDir(), rawNetworkBaseName))
if err != nil {
return "", ErrCreatingTemporaryFile
}
defer networkDumpHandle.Close()
var exitErr error
for k, v := range resultMap {
if v.err != nil {
if exitErr == nil {
exitErr = v.err
}
_, err := networkDumpHandle.WriteString(k + "\nno content\n")
if err != nil {
return networkDumpHandle.Name(), fmt.Errorf("error writing 'no content' to raw network file: %w", err)
}
} else {
_, err := networkDumpHandle.WriteString(k + "\n" + v.raw + "\n")
if err != nil {
return networkDumpHandle.Name(), fmt.Errorf("error writing raw network information: %w", err)
}
}
}
return networkDumpHandle.Name(), exitErr
}
func jsonNetworkInformationWriter(resultMap map[string]networkCollectionResult) (string, error) {
networkDumpHandle, err := os.Create(filepath.Join(os.TempDir(), networkBaseName))
if err != nil {
return "", ErrCreatingTemporaryFile
}
defer networkDumpHandle.Close()
encoder := newFormattedEncoder(networkDumpHandle)
var exitErr error
jsonMap := make(map[string][]*network.Hop, len(resultMap))
for k, v := range resultMap {
jsonMap[k] = v.info
if exitErr == nil && v.err != nil {
exitErr = v.err
}
}
err = encoder.Encode(jsonMap)
if err != nil {
return networkDumpHandle.Name(), fmt.Errorf("error encoding network information results: %w", err)
}
return networkDumpHandle.Name(), exitErr
}
func collectFromEndpointAdapter(collect collectToWriterFunc, fileName string) collectFunc {
return func(ctx context.Context) (string, error) {
dumpHandle, err := os.Create(filepath.Join(os.TempDir(), fileName))
if err != nil {
return "", ErrCreatingTemporaryFile
}
defer dumpHandle.Close()
err = collect(ctx, dumpHandle)
if err != nil {
return dumpHandle.Name(), fmt.Errorf("error running collector: %w", err)
}
return dumpHandle.Name(), nil
}
}
func tunnelStateCollectEndpointAdapter(client HTTPClient, tunnel *TunnelState, fileName string) collectFunc {
endpointFunc := func(ctx context.Context, writer io.Writer) error {
if tunnel == nil {
// When the metrics server is not passed the diagnostic will query all known hosts
// and get the tunnel state, however, when the metrics server is passed that won't
// happen hence the check for nil in this function.
tunnelResponse, err := client.GetTunnelState(ctx)
if err != nil {
return fmt.Errorf("error retrieving tunnel state: %w", err)
}
tunnel = tunnelResponse
}
encoder := newFormattedEncoder(writer)
err := encoder.Encode(tunnel)
if err != nil {
return fmt.Errorf("error encoding tunnel state: %w", err)
}
return nil
}
return collectFromEndpointAdapter(endpointFunc, fileName)
}
// resolveInstanceBaseURL is responsible to
// resolve the base URL of the instance that should be diagnosed.
// To resolve the instance it may be necessary to query the
// /diag/tunnel endpoint of the known instances, thus, if a single
// instance is found its state is also returned; if multiple instances
// are found then their states are returned in an array along with an
// error.
func resolveInstanceBaseURL(
metricsServerAddress string,
log *zerolog.Logger,
client *httpClient,
addresses []string,
) (*url.URL, *TunnelState, []*AddressableTunnelState, error) {
if metricsServerAddress != "" {
if !strings.HasPrefix(metricsServerAddress, "http://") {
metricsServerAddress = "http://" + metricsServerAddress
}
url, err := url.Parse(metricsServerAddress)
if err != nil {
return nil, nil, nil, fmt.Errorf("provided address is not valid: %w", err)
}
return url, nil, nil, nil
}
tunnelState, foundTunnelStates, err := FindMetricsServer(log, client, addresses)
if err != nil {
return nil, nil, foundTunnelStates, err
}
return tunnelState.URL, tunnelState.TunnelState, nil, nil
}
func createJobs(
client *httpClient,
tunnel *TunnelState,
diagContainer string,
diagPod string,
noDiagSystem bool,
noDiagRuntime bool,
noDiagMetrics bool,
noDiagLogs bool,
noDiagNetwork bool,
) []collectJob {
rawNetworkCollectorFunc, jsonNetworkCollectorFunc := networkInformationCollectors()
jobs := []collectJob{
{
jobName: tunnelStateJobName,
fn: tunnelStateCollectEndpointAdapter(client, tunnel, tunnelStateBaseName),
bypass: false,
},
{
jobName: systemInformationJobName,
fn: collectFromEndpointAdapter(client.GetSystemInformation, systemInformationBaseName),
bypass: noDiagSystem,
},
{
jobName: goroutineJobName,
fn: collectFromEndpointAdapter(client.GetGoroutineDump, goroutinePprofBaseName),
bypass: noDiagRuntime,
},
{
jobName: heapJobName,
fn: collectFromEndpointAdapter(client.GetMemoryDump, heapPprofBaseName),
bypass: noDiagRuntime,
},
{
jobName: metricsJobName,
fn: collectFromEndpointAdapter(client.GetMetrics, metricsBaseName),
bypass: noDiagMetrics,
},
{
jobName: logInformationJobName,
fn: func(ctx context.Context) (string, error) {
return collectLogs(ctx, client, diagContainer, diagPod)
},
bypass: noDiagLogs,
},
{
jobName: rawNetworkInformationJobName,
fn: rawNetworkCollectorFunc,
bypass: noDiagNetwork,
},
{
jobName: networkInformationJobName,
fn: jsonNetworkCollectorFunc,
bypass: noDiagNetwork,
},
{
jobName: cliConfigurationJobName,
fn: collectFromEndpointAdapter(client.GetCliConfiguration, cliConfigurationBaseName),
bypass: false,
},
{
jobName: configurationJobName,
fn: collectFromEndpointAdapter(client.GetTunnelConfiguration, configurationBaseName),
bypass: false,
},
}
return jobs
}
func createTaskReport(taskReport map[string]taskResult) (string, error) {
dumpHandle, err := os.Create(filepath.Join(os.TempDir(), taskResultBaseName))
if err != nil {
return "", ErrCreatingTemporaryFile
}
defer dumpHandle.Close()
encoder := newFormattedEncoder(dumpHandle)
err = encoder.Encode(taskReport)
if err != nil {
return "", fmt.Errorf("error encoding task results: %w", err)
}
return dumpHandle.Name(), nil
}
func runJobs(ctx context.Context, jobs []collectJob, log *zerolog.Logger) map[string]taskResult {
jobReport := make(map[string]taskResult, len(jobs))
for _, job := range jobs {
if job.bypass {
continue
}
log.Info().Msgf("Collecting %s...", job.jobName)
path, err := job.fn(ctx)
var result taskResult
if err != nil {
result = taskResult{Result: taskFailure, Err: err, path: path}
log.Error().Err(err).Msgf("Job: %s finished with error.", job.jobName)
} else {
result = taskResult{Result: taskSuccess, Err: nil, path: path}
log.Info().Msgf("Collected %s.", job.jobName)
}
jobReport[job.jobName] = result
}
taskReportName, err := createTaskReport(jobReport)
var result taskResult
if err != nil {
result = taskResult{
Result: taskFailure,
path: taskReportName,
Err: err,
}
} else {
result = taskResult{
Result: taskSuccess,
path: taskReportName,
Err: nil,
}
}
jobReport[jobReportName] = result
return jobReport
}
func RunDiagnostic(
log *zerolog.Logger,
options Options,
) ([]*AddressableTunnelState, error) {
client := NewHTTPClient()
baseURL, tunnel, foundTunnels, err := resolveInstanceBaseURL(options.Address, log, client, options.KnownAddresses)
if err != nil {
return foundTunnels, err
}
log.Info().Msgf("Selected server %s starting diagnostic...", baseURL.String())
client.SetBaseURL(baseURL)
const timeout = 45 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
jobs := createJobs(
client,
tunnel,
options.ContainerID,
options.PodID,
options.Toggles.NoDiagSystem,
options.Toggles.NoDiagRuntime,
options.Toggles.NoDiagMetrics,
options.Toggles.NoDiagLogs,
options.Toggles.NoDiagNetwork,
)
jobsReport := runJobs(ctx, jobs, log)
paths := make([]string, 0)
var gerr error
for _, v := range jobsReport {
paths = append(paths, v.path)
if gerr == nil && v.Err != nil {
gerr = v.Err
}
defer func() {
if !errors.Is(v.Err, ErrCreatingTemporaryFile) {
os.Remove(v.path)
}
}()
}
zipfile, err := CreateDiagnosticZipFile(zipName, paths)
if err != nil {
return nil, err
}
log.Info().Msgf("Diagnostic file written: %v", zipfile)
return nil, gerr
}

View File

@ -1,148 +0,0 @@
package diagnostic
import (
"archive/zip"
"context"
"encoding/json"
"fmt"
"io"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog"
)
// CreateDiagnosticZipFile create a zip file with the contents from the all
// files paths. The files will be written in the root of the zip file.
// In case of an error occurs after whilst writing to the zip file
// this will be removed.
func CreateDiagnosticZipFile(base string, paths []string) (zipFileName string, err error) {
// Create a zip file with all files from paths added to the root
suffix := time.Now().Format(time.RFC3339)
zipFileName = base + "-" + suffix + ".zip"
zipFileName = strings.ReplaceAll(zipFileName, ":", "-")
archive, cerr := os.Create(zipFileName)
if cerr != nil {
return "", fmt.Errorf("error creating file %s: %w", zipFileName, cerr)
}
archiveWriter := zip.NewWriter(archive)
defer func() {
archiveWriter.Close()
archive.Close()
if err != nil {
os.Remove(zipFileName)
}
}()
for _, file := range paths {
if file == "" {
continue
}
var handle *os.File
handle, err = os.Open(file)
if err != nil {
return "", fmt.Errorf("error opening file %s: %w", zipFileName, err)
}
defer handle.Close()
// Keep the base only to not create sub directories in the
// zip file.
var writer io.Writer
writer, err = archiveWriter.Create(filepath.Base(file))
if err != nil {
return "", fmt.Errorf("error creating archive writer from %s: %w", file, err)
}
if _, err = io.Copy(writer, handle); err != nil {
return "", fmt.Errorf("error copying file %s: %w", file, err)
}
}
zipFileName = archive.Name()
return zipFileName, nil
}
type AddressableTunnelState struct {
*TunnelState
URL *url.URL
}
func findMetricsServerPredicate(tunnelID, connectorID uuid.UUID) func(state *TunnelState) bool {
if tunnelID != uuid.Nil && connectorID != uuid.Nil {
return func(state *TunnelState) bool {
return state.ConnectorID == connectorID && state.TunnelID == tunnelID
}
} else if tunnelID == uuid.Nil && connectorID != uuid.Nil {
return func(state *TunnelState) bool {
return state.ConnectorID == connectorID
}
} else if tunnelID != uuid.Nil && connectorID == uuid.Nil {
return func(state *TunnelState) bool {
return state.TunnelID == tunnelID
}
}
return func(*TunnelState) bool {
return true
}
}
// The FindMetricsServer will try to find the metrics server url.
// There are two possible error scenarios:
// 1. No instance is found which will only return ErrMetricsServerNotFound
// 2. Multiple instances are found which will return an array of state and ErrMultipleMetricsServerFound
// In case of success, only the state for the instance is returned.
func FindMetricsServer(
log *zerolog.Logger,
client *httpClient,
addresses []string,
) (*AddressableTunnelState, []*AddressableTunnelState, error) {
instances := make([]*AddressableTunnelState, 0)
for _, address := range addresses {
url, err := url.Parse("http://" + address)
if err != nil {
log.Debug().Err(err).Msgf("error parsing address %s", address)
continue
}
client.SetBaseURL(url)
state, err := client.GetTunnelState(context.Background())
if err == nil {
instances = append(instances, &AddressableTunnelState{state, url})
} else {
log.Debug().Err(err).Msgf("error getting tunnel state from address %s", address)
}
}
if len(instances) == 0 {
return nil, nil, ErrMetricsServerNotFound
}
if len(instances) == 1 {
return instances[0], nil, nil
}
return nil, instances, ErrMultipleMetricsServerFound
}
// newFormattedEncoder return a JSON encoder with identation
func newFormattedEncoder(w io.Writer) *json.Encoder {
encoder := json.NewEncoder(w)
encoder.SetIndent("", " ")
return encoder
}

View File

@ -1,147 +0,0 @@
package diagnostic_test
import (
"context"
"net/http"
"net/url"
"sync"
"testing"
"time"
"github.com/facebookgo/grace/gracenet"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/diagnostic"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/tunnelstate"
)
func helperCreateServer(t *testing.T, listeners *gracenet.Net, tunnelID uuid.UUID, connectorID uuid.UUID) func() {
t.Helper()
listener, err := metrics.CreateMetricsListener(listeners, "localhost:0")
require.NoError(t, err)
log := zerolog.Nop()
tracker := tunnelstate.NewConnTracker(&log)
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, map[string]string{}, []string{})
router := http.NewServeMux()
router.HandleFunc("/diag/tunnel", handler.TunnelStateHandler)
server := &http.Server{
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
Handler: router,
}
var wgroup sync.WaitGroup
wgroup.Add(1)
go func() {
defer wgroup.Done()
_ = server.Serve(listener)
}()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
cleanUp := func() {
_ = server.Shutdown(ctx)
cancel()
wgroup.Wait()
}
return cleanUp
}
func TestFindMetricsServer_WhenSingleServerIsRunning_ReturnState(t *testing.T) {
listeners := gracenet.Net{}
tid1 := uuid.New()
cid1 := uuid.New()
cleanUp := helperCreateServer(t, &listeners, tid1, cid1)
defer cleanUp()
log := zerolog.Nop()
client := diagnostic.NewHTTPClient()
addresses := metrics.GetMetricsKnownAddresses("host")
url1, err := url.Parse("http://localhost:20241")
require.NoError(t, err)
tunnel1 := &diagnostic.AddressableTunnelState{
TunnelState: &diagnostic.TunnelState{
TunnelID: tid1,
ConnectorID: cid1,
Connections: nil,
},
URL: url1,
}
state, tunnels, err := diagnostic.FindMetricsServer(&log, client, addresses[:])
if err != nil {
require.ErrorIs(t, err, diagnostic.ErrMultipleMetricsServerFound)
}
assert.Equal(t, tunnel1, state)
assert.Nil(t, tunnels)
}
func TestFindMetricsServer_WhenMultipleServerAreRunning_ReturnError(t *testing.T) {
listeners := gracenet.Net{}
tid1 := uuid.New()
cid1 := uuid.New()
cid2 := uuid.New()
cleanUp := helperCreateServer(t, &listeners, tid1, cid1)
defer cleanUp()
cleanUp = helperCreateServer(t, &listeners, tid1, cid2)
defer cleanUp()
log := zerolog.Nop()
client := diagnostic.NewHTTPClient()
addresses := metrics.GetMetricsKnownAddresses("host")
url1, err := url.Parse("http://localhost:20241")
require.NoError(t, err)
url2, err := url.Parse("http://localhost:20242")
require.NoError(t, err)
tunnel1 := &diagnostic.AddressableTunnelState{
TunnelState: &diagnostic.TunnelState{
TunnelID: tid1,
ConnectorID: cid1,
Connections: nil,
},
URL: url1,
}
tunnel2 := &diagnostic.AddressableTunnelState{
TunnelState: &diagnostic.TunnelState{
TunnelID: tid1,
ConnectorID: cid2,
Connections: nil,
},
URL: url2,
}
state, tunnels, err := diagnostic.FindMetricsServer(&log, client, addresses[:])
if err != nil {
require.ErrorIs(t, err, diagnostic.ErrMultipleMetricsServerFound)
}
assert.Nil(t, state)
assert.Equal(t, []*diagnostic.AddressableTunnelState{tunnel1, tunnel2}, tunnels)
}
func TestFindMetricsServer_WhenNoInstanceIsRuning_ReturnError(t *testing.T) {
log := zerolog.Nop()
client := diagnostic.NewHTTPClient()
addresses := metrics.GetMetricsKnownAddresses("host")
state, tunnels, err := diagnostic.FindMetricsServer(&log, client, addresses[:])
require.ErrorIs(t, err, diagnostic.ErrMetricsServerNotFound)
assert.Nil(t, state)
assert.Nil(t, tunnels)
}

View File

@ -1,28 +0,0 @@
package diagnostic
import (
"errors"
)
var (
// Error used when there is no log directory available.
ErrManagedLogNotFound = errors.New("managed log directory not found")
// Error used when it is not possible to collect logs using the log configuration.
ErrLogConfigurationIsInvalid = errors.New("provided log configuration is invalid")
// Error used when parsing the fields of the output of collector.
ErrInsufficientLines = errors.New("insufficient lines")
// Error used when parsing the lines of the output of collector.
ErrInsuficientFields = errors.New("insufficient fields")
// Error used when given key is not found while parsing KV.
ErrKeyNotFound = errors.New("key not found")
// Error used when there is no disk volume information available.
ErrNoVolumeFound = errors.New("no disk volume information found")
// Error user when the base url of the diagnostic client is not provided.
ErrNoBaseURL = errors.New("no base url")
// Error used when no metrics server is found listening to the known addresses list (check [metrics.GetMetricsKnownAddresses]).
ErrMetricsServerNotFound = errors.New("metrics server not found")
// Error used when multiple metrics server are found listening to the known addresses list (check [metrics.GetMetricsKnownAddresses]).
ErrMultipleMetricsServerFound = errors.New("multiple metrics server found")
// Error used when a temporary file creation fails within the diagnostic procedure
ErrCreatingTemporaryFile = errors.New("temporary file creation failed")
)

View File

@ -1,144 +0,0 @@
package diagnostic
import (
"context"
"encoding/json"
"net/http"
"os"
"strconv"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/tunnelstate"
)
type Handler struct {
log *zerolog.Logger
timeout time.Duration
systemCollector SystemCollector
tunnelID uuid.UUID
connectorID uuid.UUID
tracker *tunnelstate.ConnTracker
cliFlags map[string]string
icmpSources []string
}
func NewDiagnosticHandler(
log *zerolog.Logger,
timeout time.Duration,
systemCollector SystemCollector,
tunnelID uuid.UUID,
connectorID uuid.UUID,
tracker *tunnelstate.ConnTracker,
cliFlags map[string]string,
icmpSources []string,
) *Handler {
logger := log.With().Logger()
if timeout == 0 {
timeout = defaultCollectorTimeout
}
cliFlags[configurationKeyUID] = strconv.Itoa(os.Getuid())
return &Handler{
log: &logger,
timeout: timeout,
systemCollector: systemCollector,
tunnelID: tunnelID,
connectorID: connectorID,
tracker: tracker,
cliFlags: cliFlags,
icmpSources: icmpSources,
}
}
func (handler *Handler) InstallEndpoints(router *http.ServeMux) {
router.HandleFunc(cliConfigurationEndpoint, handler.ConfigurationHandler)
router.HandleFunc(tunnelStateEndpoint, handler.TunnelStateHandler)
router.HandleFunc(systemInformationEndpoint, handler.SystemHandler)
}
type SystemInformationResponse struct {
Info *SystemInformation `json:"info"`
Err error `json:"errors"`
}
func (handler *Handler) SystemHandler(writer http.ResponseWriter, request *http.Request) {
logger := handler.log.With().Str(collectorField, systemCollectorName).Logger()
logger.Info().Msg("Collection started")
defer logger.Info().Msg("Collection finished")
ctx, cancel := context.WithTimeout(request.Context(), handler.timeout)
defer cancel()
info, err := handler.systemCollector.Collect(ctx)
response := SystemInformationResponse{
Info: info,
Err: err,
}
encoder := json.NewEncoder(writer)
err = encoder.Encode(response)
if err != nil {
logger.Error().Err(err).Msgf("error occurred whilst serializing information")
writer.WriteHeader(http.StatusInternalServerError)
}
}
type TunnelState struct {
TunnelID uuid.UUID `json:"tunnelID,omitempty"`
ConnectorID uuid.UUID `json:"connectorID,omitempty"`
Connections []tunnelstate.IndexedConnectionInfo `json:"connections,omitempty"`
ICMPSources []string `json:"icmp_sources,omitempty"`
}
func (handler *Handler) TunnelStateHandler(writer http.ResponseWriter, _ *http.Request) {
log := handler.log.With().Str(collectorField, tunnelStateCollectorName).Logger()
log.Info().Msg("Collection started")
defer log.Info().Msg("Collection finished")
body := TunnelState{
handler.tunnelID,
handler.connectorID,
handler.tracker.GetActiveConnections(),
handler.icmpSources,
}
encoder := json.NewEncoder(writer)
err := encoder.Encode(body)
if err != nil {
handler.log.Error().Err(err).Msgf("error occurred whilst serializing information")
writer.WriteHeader(http.StatusInternalServerError)
}
}
func (handler *Handler) ConfigurationHandler(writer http.ResponseWriter, _ *http.Request) {
log := handler.log.With().Str(collectorField, configurationCollectorName).Logger()
log.Info().Msg("Collection started")
defer func() {
log.Info().Msg("Collection finished")
}()
encoder := json.NewEncoder(writer)
err := encoder.Encode(handler.cliFlags)
if err != nil {
handler.log.Error().Err(err).Msgf("error occurred whilst serializing response")
writer.WriteHeader(http.StatusInternalServerError)
}
}
func writeResponse(w http.ResponseWriter, bytes []byte, logger *zerolog.Logger) {
bytesWritten, err := w.Write(bytes)
if err != nil {
logger.Error().Err(err).Msg("error occurred writing response")
} else if bytesWritten != len(bytes) {
logger.Error().Msgf("error incomplete write response %d/%d", bytesWritten, len(bytes))
}
}

View File

@ -1,224 +0,0 @@
package diagnostic_test
import (
"context"
"encoding/json"
"errors"
"net"
"net/http"
"net/http/httptest"
"runtime"
"testing"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/diagnostic"
"github.com/cloudflare/cloudflared/tunnelstate"
)
type SystemCollectorMock struct {
systemInfo *diagnostic.SystemInformation
err error
}
const (
systemInformationKey = "sikey"
errorKey = "errkey"
)
func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker {
t.Helper()
log := zerolog.Nop()
tracker := tunnelstate.NewConnTracker(&log)
for _, conn := range connections {
tracker.OnTunnelEvent(connection.Event{
Index: conn.Index,
EventType: connection.Connected,
Protocol: conn.Protocol,
EdgeAddress: conn.EdgeAddress,
})
}
return tracker
}
func (collector *SystemCollectorMock) Collect(context.Context) (*diagnostic.SystemInformation, error) {
return collector.systemInfo, collector.err
}
func TestSystemHandler(t *testing.T) {
t.Parallel()
log := zerolog.Nop()
tests := []struct {
name string
systemInfo *diagnostic.SystemInformation
err error
statusCode int
}{
{
name: "happy path",
systemInfo: diagnostic.NewSystemInformation(
0, 0, 0, 0,
"string", "string", "string", "string",
"string", "string",
runtime.Version(), runtime.GOARCH, nil,
),
err: nil,
statusCode: http.StatusOK,
},
{
name: "on error and no raw info", systemInfo: nil,
err: errors.New("an error"), statusCode: http.StatusOK,
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{
systemInfo: tCase.systemInfo,
err: tCase.err,
}, uuid.New(), uuid.New(), nil, map[string]string{}, nil)
recorder := httptest.NewRecorder()
ctx := context.Background()
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/system", nil)
require.NoError(t, err)
handler.SystemHandler(recorder, request)
assert.Equal(t, tCase.statusCode, recorder.Code)
if tCase.statusCode == http.StatusOK && tCase.systemInfo != nil {
var response diagnostic.SystemInformationResponse
decoder := json.NewDecoder(recorder.Body)
err := decoder.Decode(&response)
require.NoError(t, err)
assert.Equal(t, tCase.systemInfo, response.Info)
}
})
}
}
func TestTunnelStateHandler(t *testing.T) {
t.Parallel()
log := zerolog.Nop()
tests := []struct {
name string
tunnelID uuid.UUID
clientID uuid.UUID
connections []tunnelstate.IndexedConnectionInfo
icmpSources []string
}{
{
name: "case1",
tunnelID: uuid.New(),
clientID: uuid.New(),
},
{
name: "case2",
tunnelID: uuid.New(),
clientID: uuid.New(),
icmpSources: []string{"172.17.0.3", "::1"},
connections: []tunnelstate.IndexedConnectionInfo{{
ConnectionInfo: tunnelstate.ConnectionInfo{
IsConnected: true,
Protocol: connection.QUIC,
EdgeAddress: net.IPv4(100, 100, 100, 100),
},
Index: 0,
}},
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
tracker := newTrackerFromConns(t, tCase.connections)
handler := diagnostic.NewDiagnosticHandler(
&log,
0,
nil,
tCase.tunnelID,
tCase.clientID,
tracker,
map[string]string{},
tCase.icmpSources,
)
recorder := httptest.NewRecorder()
handler.TunnelStateHandler(recorder, nil)
decoder := json.NewDecoder(recorder.Body)
var response diagnostic.TunnelState
err := decoder.Decode(&response)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, tCase.tunnelID, response.TunnelID)
assert.Equal(t, tCase.clientID, response.ConnectorID)
assert.Equal(t, tCase.connections, response.Connections)
assert.Equal(t, tCase.icmpSources, response.ICMPSources)
})
}
}
func TestConfigurationHandler(t *testing.T) {
t.Parallel()
log := zerolog.Nop()
tests := []struct {
name string
flags map[string]string
expected map[string]string
}{
{
name: "empty cli",
flags: make(map[string]string),
expected: map[string]string{
"uid": "0",
},
},
{
name: "cli with flags",
flags: map[string]string{
"b": "a",
"c": "a",
"d": "a",
"uid": "0",
},
expected: map[string]string{
"b": "a",
"c": "a",
"d": "a",
"uid": "0",
},
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
var response map[string]string
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, tCase.flags, nil)
recorder := httptest.NewRecorder()
handler.ConfigurationHandler(recorder, nil)
decoder := json.NewDecoder(recorder.Body)
err := decoder.Decode(&response)
require.NoError(t, err)
_, ok := response["uid"]
assert.True(t, ok)
delete(tCase.expected, "uid")
delete(response, "uid")
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, tCase.expected, response)
})
}
}

View File

@ -1,34 +0,0 @@
package diagnostic
import (
"context"
)
// Represents the path of the log file or log directory.
// This struct is meant to give some ergonimics regarding
// the logging information.
type LogInformation struct {
path string // path to a file or directory
wasCreated bool // denotes if `path` was created
isDirectory bool // denotes if `path` is a directory
}
func NewLogInformation(
path string,
wasCreated bool,
isDirectory bool,
) *LogInformation {
return &LogInformation{
path,
wasCreated,
isDirectory,
}
}
type LogCollector interface {
// This function is responsible for returning a path to a single file
// whose contents are the logs of a cloudflared instance.
// A new file may be create by a LogCollector, thus, its the caller
// responsibility to remove the newly create file.
Collect(ctx context.Context) (*LogInformation, error)
}

View File

@ -1,47 +0,0 @@
package diagnostic
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"time"
)
type DockerLogCollector struct {
containerID string // This member identifies the container by identifier or name
}
func NewDockerLogCollector(containerID string) *DockerLogCollector {
return &DockerLogCollector{
containerID,
}
}
func (collector *DockerLogCollector) Collect(ctx context.Context) (*LogInformation, error) {
tmp := os.TempDir()
outputHandle, err := os.Create(filepath.Join(tmp, logFilename))
if err != nil {
return nil, fmt.Errorf("error opening output file: %w", err)
}
defer outputHandle.Close()
// Calculate 2 weeks ago
since := time.Now().Add(twoWeeksOffset).Format(time.RFC3339)
command := exec.CommandContext(
ctx,
"docker",
"logs",
"--tail",
tailMaxNumberOfLines,
"--since",
since,
collector.containerID,
)
return PipeCommandOutputToFile(command, outputHandle)
}

View File

@ -1,105 +0,0 @@
package diagnostic
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
)
const (
linuxManagedLogsPath = "/var/log/cloudflared.err"
darwinManagedLogsPath = "/Library/Logs/com.cloudflare.cloudflared.err.log"
linuxServiceConfigurationPath = "/etc/systemd/system/cloudflared.service"
linuxSystemdPath = "/run/systemd/system"
)
type HostLogCollector struct {
client HTTPClient
}
func NewHostLogCollector(client HTTPClient) *HostLogCollector {
return &HostLogCollector{
client,
}
}
func extractLogsFromJournalCtl(ctx context.Context) (*LogInformation, error) {
tmp := os.TempDir()
outputHandle, err := os.Create(filepath.Join(tmp, logFilename))
if err != nil {
return nil, fmt.Errorf("error opening output file: %w", err)
}
defer outputHandle.Close()
command := exec.CommandContext(
ctx,
"journalctl",
"--since",
"2 weeks ago",
"-u",
"cloudflared.service",
)
return PipeCommandOutputToFile(command, outputHandle)
}
func getServiceLogPath() (string, error) {
switch runtime.GOOS {
case "darwin":
{
path := darwinManagedLogsPath
if _, err := os.Stat(path); err == nil {
return path, nil
}
userHomeDir, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("error getting user home: %w", err)
}
return filepath.Join(userHomeDir, darwinManagedLogsPath), nil
}
case "linux":
{
return linuxManagedLogsPath, nil
}
default:
return "", ErrManagedLogNotFound
}
}
func (collector *HostLogCollector) Collect(ctx context.Context) (*LogInformation, error) {
logConfiguration, err := collector.client.GetLogConfiguration(ctx)
if err != nil {
return nil, fmt.Errorf("error getting log configuration: %w", err)
}
if logConfiguration.uid == 0 {
_, statSystemdErr := os.Stat(linuxServiceConfigurationPath)
_, statServiceConfigurationErr := os.Stat(linuxServiceConfigurationPath)
if statSystemdErr == nil && statServiceConfigurationErr == nil && runtime.GOOS == "linux" {
return extractLogsFromJournalCtl(ctx)
}
path, err := getServiceLogPath()
if err != nil {
return nil, err
}
return NewLogInformation(path, false, false), nil
}
if logConfiguration.logFile != "" {
return NewLogInformation(logConfiguration.logFile, false, false), nil
} else if logConfiguration.logDirectory != "" {
return NewLogInformation(logConfiguration.logDirectory, false, true), nil
}
return nil, ErrLogConfigurationIsInvalid
}

View File

@ -1,63 +0,0 @@
package diagnostic
import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"time"
)
type KubernetesLogCollector struct {
containerID string // This member identifies the container by identifier or name
pod string // This member identifies the pod where the container is deployed
}
func NewKubernetesLogCollector(containerID, pod string) *KubernetesLogCollector {
return &KubernetesLogCollector{
containerID,
pod,
}
}
func (collector *KubernetesLogCollector) Collect(ctx context.Context) (*LogInformation, error) {
tmp := os.TempDir()
outputHandle, err := os.Create(filepath.Join(tmp, logFilename))
if err != nil {
return nil, fmt.Errorf("error opening output file: %w", err)
}
defer outputHandle.Close()
var command *exec.Cmd
// Calculate 2 weeks ago
since := time.Now().Add(twoWeeksOffset).Format(time.RFC3339)
if collector.containerID != "" {
command = exec.CommandContext(
ctx,
"kubectl",
"logs",
collector.pod,
"--since-time",
since,
"--tail",
tailMaxNumberOfLines,
"-c",
collector.containerID,
)
} else {
command = exec.CommandContext(
ctx,
"kubectl",
"logs",
collector.pod,
"--since-time",
since,
"--tail",
tailMaxNumberOfLines,
)
}
return PipeCommandOutputToFile(command, outputHandle)
}

View File

@ -1,109 +0,0 @@
package diagnostic
import (
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
)
func PipeCommandOutputToFile(command *exec.Cmd, outputHandle *os.File) (*LogInformation, error) {
stdoutReader, err := command.StdoutPipe()
if err != nil {
return nil, fmt.Errorf(
"error retrieving stdout from command '%s': %w",
command.String(),
err,
)
}
stderrReader, err := command.StderrPipe()
if err != nil {
return nil, fmt.Errorf(
"error retrieving stderr from command '%s': %w",
command.String(),
err,
)
}
if err := command.Start(); err != nil {
return nil, fmt.Errorf(
"error running command '%s': %w",
command.String(),
err,
)
}
_, err = io.Copy(outputHandle, stdoutReader)
if err != nil {
return nil, fmt.Errorf(
"error copying stdout from %s to file %s: %w",
command.String(),
outputHandle.Name(),
err,
)
}
_, err = io.Copy(outputHandle, stderrReader)
if err != nil {
return nil, fmt.Errorf(
"error copying stderr from %s to file %s: %w",
command.String(),
outputHandle.Name(),
err,
)
}
if err := command.Wait(); err != nil {
return nil, fmt.Errorf(
"error waiting from command '%s': %w",
command.String(),
err,
)
}
return NewLogInformation(outputHandle.Name(), true, false), nil
}
func CopyFilesFromDirectory(path string) (string, error) {
// rolling logs have as suffix the current date thus
// when iterating the path files they are already in
// chronological order
files, err := os.ReadDir(path)
if err != nil {
return "", fmt.Errorf("error reading directory %s: %w", path, err)
}
outputHandle, err := os.Create(filepath.Join(os.TempDir(), logFilename))
if err != nil {
return "", fmt.Errorf("creating file %s: %w", outputHandle.Name(), err)
}
defer outputHandle.Close()
for _, file := range files {
logHandle, err := os.Open(filepath.Join(path, file.Name()))
if err != nil {
return "", fmt.Errorf("error opening file %s:%w", file.Name(), err)
}
defer logHandle.Close()
_, err = io.Copy(outputHandle, logHandle)
if err != nil {
return "", fmt.Errorf("error copying file %s:%w", logHandle.Name(), err)
}
}
logHandle, err := os.Open(filepath.Join(path, "cloudflared.log"))
if err != nil {
return "", fmt.Errorf("error opening file %s:%w", logHandle.Name(), err)
}
defer logHandle.Close()
_, err = io.Copy(outputHandle, logHandle)
if err != nil {
return "", fmt.Errorf("error copying file %s:%w", logHandle.Name(), err)
}
return outputHandle.Name(), nil
}

View File

@ -1,77 +0,0 @@
package diagnostic
import (
"context"
"errors"
"time"
)
const MicrosecondsFactor = 1000.0
var ErrEmptyDomain = errors.New("domain must not be empty")
// For now only support ICMP is provided.
type IPVersion int
const (
V4 IPVersion = iota
V6 IPVersion = iota
)
type Hop struct {
Hop uint8 `json:"hop,omitempty"` // hop number along the route
Domain string `json:"domain,omitempty"` // domain and/or ip of the hop, this field will be '*' if the hop is a timeout
Rtts []time.Duration `json:"rtts,omitempty"` // RTT measurements in microseconds
}
type TraceOptions struct {
ttl uint64 // number of hops to perform
timeout time.Duration // wait timeout for each response
address string // address to trace
useV4 bool
}
func NewTimeoutHop(
hop uint8,
) *Hop {
// Whenever there is a hop in the format of 'N * * *'
// it means that the hop in the path didn't answer to
// any probe.
return NewHop(
hop,
"*",
nil,
)
}
func NewHop(hop uint8, domain string, rtts []time.Duration) *Hop {
return &Hop{
hop,
domain,
rtts,
}
}
func NewTraceOptions(
ttl uint64,
timeout time.Duration,
address string,
useV4 bool,
) TraceOptions {
return TraceOptions{
ttl,
timeout,
address,
useV4,
}
}
type NetworkCollector interface {
// Performs a trace route operation with the specified options.
// In case the trace fails, it will return a non-nil error and
// it may return a string which represents the raw information
// obtained.
// In case it is successful it will only return an array of Hops
// an empty string and a nil error.
Collect(ctx context.Context, options TraceOptions) ([]*Hop, string, error)
}

View File

@ -1,78 +0,0 @@
//go:build darwin || linux
package diagnostic
import (
"context"
"fmt"
"os/exec"
"strconv"
"strings"
"time"
)
type NetworkCollectorImpl struct{}
func (tracer *NetworkCollectorImpl) Collect(ctx context.Context, options TraceOptions) ([]*Hop, string, error) {
args := []string{
"-I",
"-w",
strconv.FormatInt(int64(options.timeout.Seconds()), 10),
"-m",
strconv.FormatUint(options.ttl, 10),
options.address,
}
var command string
switch options.useV4 {
case false:
command = "traceroute6"
default:
command = "traceroute"
}
process := exec.CommandContext(ctx, command, args...)
return decodeNetworkOutputToFile(process, DecodeLine)
}
func DecodeLine(text string) (*Hop, error) {
fields := strings.Fields(text)
parts := []string{}
filter := func(s string) bool { return s != "*" && s != "ms" }
for _, field := range fields {
if filter(field) {
parts = append(parts, field)
}
}
index, err := strconv.ParseUint(parts[0], 10, 8)
if err != nil {
return nil, fmt.Errorf("couldn't parse index from timeout hop: %w", err)
}
if len(parts) == 1 {
return NewTimeoutHop(uint8(index)), nil
}
domain := ""
rtts := []time.Duration{}
for _, part := range parts[1:] {
rtt, err := strconv.ParseFloat(part, 64)
if err != nil {
domain += part + " "
} else {
rtts = append(rtts, time.Duration(rtt*MicrosecondsFactor))
}
}
domain, _ = strings.CutSuffix(domain, " ")
if domain == "" {
return nil, ErrEmptyDomain
}
return NewHop(uint8(index), domain, rtts), nil
}

View File

@ -1,173 +0,0 @@
//go:build darwin || linux
package diagnostic_test
import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
diagnostic "github.com/cloudflare/cloudflared/diagnostic/network"
)
func TestDecode(t *testing.T) {
t.Parallel()
tests := []struct {
name string
text string
expectedHops []*diagnostic.Hop
}{
{
"repeated hop index parse failure",
`1 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
2 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
someletters * * *
4 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms `,
[]*diagnostic.Hop{
diagnostic.NewHop(
uint8(1),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
diagnostic.NewHop(
uint8(2),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
diagnostic.NewHop(
uint8(4),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
},
},
{
"hop index parse failure",
`1 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
2 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
someletters 8.8.8.8 8.8.8.9 abc ms 0.456 ms 0.789 ms`,
[]*diagnostic.Hop{
diagnostic.NewHop(
uint8(1),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
diagnostic.NewHop(
uint8(2),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
},
},
{
"missing rtt",
`1 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
2 * 8.8.8.8 8.8.8.9 0.456 ms 0.789 ms`,
[]*diagnostic.Hop{
diagnostic.NewHop(
uint8(1),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
diagnostic.NewHop(
uint8(2),
"8.8.8.8 8.8.8.9",
[]time.Duration{
time.Duration(456),
time.Duration(789),
},
),
},
},
{
"simple example ipv4",
`1 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
2 172.68.101.121 (172.68.101.121) 12.874 ms 15.517 ms 15.311 ms
3 * * *`,
[]*diagnostic.Hop{
diagnostic.NewHop(
uint8(1),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
diagnostic.NewHop(
uint8(2),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
diagnostic.NewTimeoutHop(uint8(3)),
},
},
{
"simple example ipv6",
` 1 2400:cb00:107:1024::ac44:6550 12.780 ms 9.118 ms 10.046 ms
2 2a09:bac1:: 9.945 ms 10.033 ms 11.562 ms`,
[]*diagnostic.Hop{
diagnostic.NewHop(
uint8(1),
"2400:cb00:107:1024::ac44:6550",
[]time.Duration{
time.Duration(12780),
time.Duration(9118),
time.Duration(10046),
},
),
diagnostic.NewHop(
uint8(2),
"2a09:bac1::",
[]time.Duration{
time.Duration(9945),
time.Duration(10033),
time.Duration(11562),
},
),
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
hops, err := diagnostic.Decode(strings.NewReader(test.text), diagnostic.DecodeLine)
require.NoError(t, err)
assert.Equal(t, test.expectedHops, hops)
})
}
}

View File

@ -1,74 +0,0 @@
package diagnostic
import (
"bufio"
"bytes"
"fmt"
"io"
"os/exec"
)
type DecodeLineFunc func(text string) (*Hop, error)
func decodeNetworkOutputToFile(command *exec.Cmd, decodeLine DecodeLineFunc) ([]*Hop, string, error) {
stdout, err := command.StdoutPipe()
if err != nil {
return nil, "", fmt.Errorf("error piping traceroute's output: %w", err)
}
if err := command.Start(); err != nil {
return nil, "", fmt.Errorf("error starting traceroute: %w", err)
}
// Tee the output to a string to have the raw information
// in case the decode call fails
// This error is handled only after the Wait call below returns
// otherwise the process can become a zombie
buf := bytes.NewBuffer([]byte{})
tee := io.TeeReader(stdout, buf)
hops, err := Decode(tee, decodeLine)
// regardless of success of the decoding
// consume all output to have available in buf
_, _ = io.ReadAll(tee)
if werr := command.Wait(); werr != nil {
return nil, "", fmt.Errorf("error finishing traceroute: %w", werr)
}
if err != nil {
return nil, buf.String(), err
}
return hops, buf.String(), nil
}
func Decode(reader io.Reader, decodeLine DecodeLineFunc) ([]*Hop, error) {
scanner := bufio.NewScanner(reader)
scanner.Split(bufio.ScanLines)
var hops []*Hop
for scanner.Scan() {
text := scanner.Text()
if text == "" {
continue
}
hop, err := decodeLine(text)
if err != nil {
// This continue is here on the error case because there are lines at the start and end
// that may not be parsable. (check windows tracert output)
// The skip is here because aside from the start and end lines the other lines should
// always be parsable without errors.
continue
}
hops = append(hops, hop)
}
if scanner.Err() != nil {
return nil, fmt.Errorf("scanner reported an error: %w", scanner.Err())
}
return hops, nil
}

View File

@ -1,81 +0,0 @@
//go:build windows
package diagnostic
import (
"context"
"fmt"
"os/exec"
"strconv"
"strings"
"time"
)
type NetworkCollectorImpl struct{}
func (tracer *NetworkCollectorImpl) Collect(ctx context.Context, options TraceOptions) ([]*Hop, string, error) {
ipversion := "-4"
if !options.useV4 {
ipversion = "-6"
}
args := []string{
ipversion,
"-w",
strconv.FormatInt(int64(options.timeout.Seconds()), 10),
"-h",
strconv.FormatUint(options.ttl, 10),
// Do not resolve host names (can add 30+ seconds to run time)
"-d",
options.address,
}
command := exec.CommandContext(ctx, "tracert.exe", args...)
return decodeNetworkOutputToFile(command, DecodeLine)
}
func DecodeLine(text string) (*Hop, error) {
const requestTimedOut = "Request timed out."
fields := strings.Fields(text)
parts := []string{}
filter := func(s string) bool { return s != "*" && s != "ms" }
for _, field := range fields {
if filter(field) {
parts = append(parts, field)
}
}
index, err := strconv.ParseUint(parts[0], 10, 8)
if err != nil {
return nil, fmt.Errorf("couldn't parse index from timeout hop: %w", err)
}
domain := ""
rtts := []time.Duration{}
for _, part := range parts[1:] {
rtt, err := strconv.ParseFloat(strings.TrimLeft(part, "<"), 64)
if err != nil {
domain += part + " "
} else {
rtts = append(rtts, time.Duration(rtt*MicrosecondsFactor))
}
}
domain, _ = strings.CutSuffix(domain, " ")
// If the domain is equal to "Request timed out." then we build a
// timeout hop.
if domain == requestTimedOut {
return NewTimeoutHop(uint8(index)), nil
}
if domain == "" {
return nil, ErrEmptyDomain
}
return NewHop(uint8(index), domain, rtts), nil
}

View File

@ -1,210 +0,0 @@
//go:build windows
package diagnostic_test
import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
diagnostic "github.com/cloudflare/cloudflared/diagnostic/network"
)
func TestDecode(t *testing.T) {
t.Parallel()
tests := []struct {
name string
text string
expectedHops []*diagnostic.Hop
}{
{
"tracert output",
`
Tracing route to region2.v2.argotunnel.com [198.41.200.73]
over a maximum of 5 hops:
1 10 ms <1 ms 1 ms 192.168.64.1
2 27 ms 14 ms 5 ms 192.168.1.254
3 * * * Request timed out.
4 * * * Request timed out.
5 27 ms 5 ms 5 ms 195.8.30.245
Trace complete.
`,
[]*diagnostic.Hop{
diagnostic.NewHop(
uint8(1),
"192.168.64.1",
[]time.Duration{
time.Duration(10000),
time.Duration(1000),
time.Duration(1000),
},
),
diagnostic.NewHop(
uint8(2),
"192.168.1.254",
[]time.Duration{
time.Duration(27000),
time.Duration(14000),
time.Duration(5000),
},
),
diagnostic.NewTimeoutHop(uint8(3)),
diagnostic.NewTimeoutHop(uint8(4)),
diagnostic.NewHop(
uint8(5),
"195.8.30.245",
[]time.Duration{
time.Duration(27000),
time.Duration(5000),
time.Duration(5000),
},
),
},
},
{
"repeated hop index parse failure",
`1 12.874 ms 15.517 ms 15.311 ms 172.68.101.121 (172.68.101.121)
2 12.874 ms 15.517 ms 15.311 ms 172.68.101.121 (172.68.101.121)
someletters * * *`,
[]*diagnostic.Hop{
diagnostic.NewHop(
uint8(1),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
diagnostic.NewHop(
uint8(2),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
},
},
{
"hop index parse failure",
`1 12.874 ms 15.517 ms 15.311 ms 172.68.101.121 (172.68.101.121)
2 12.874 ms 15.517 ms 15.311 ms 172.68.101.121 (172.68.101.121)
someletters abc ms 0.456 ms 0.789 ms 8.8.8.8 8.8.8.9`,
[]*diagnostic.Hop{
diagnostic.NewHop(
uint8(1),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
diagnostic.NewHop(
uint8(2),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
},
},
{
"missing rtt",
`1 <12.874 ms <15.517 ms <15.311 ms 172.68.101.121 (172.68.101.121)
2 * 0.456 ms 0.789 ms 8.8.8.8 8.8.8.9`,
[]*diagnostic.Hop{
diagnostic.NewHop(
uint8(1),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
diagnostic.NewHop(
uint8(2),
"8.8.8.8 8.8.8.9",
[]time.Duration{
time.Duration(456),
time.Duration(789),
},
),
},
},
{
"simple example ipv4",
`1 12.874 ms 15.517 ms 15.311 ms 172.68.101.121 (172.68.101.121)
2 12.874 ms 15.517 ms 15.311 ms 172.68.101.121 (172.68.101.121)
3 * * * Request timed out.`,
[]*diagnostic.Hop{
diagnostic.NewHop(
uint8(1),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
diagnostic.NewHop(
uint8(2),
"172.68.101.121 (172.68.101.121)",
[]time.Duration{
time.Duration(12874),
time.Duration(15517),
time.Duration(15311),
},
),
diagnostic.NewTimeoutHop(uint8(3)),
},
},
{
"simple example ipv6",
` 1 12.780 ms 9.118 ms 10.046 ms 2400:cb00:107:1024::ac44:6550
2 9.945 ms 10.033 ms 11.562 ms 2a09:bac1::`,
[]*diagnostic.Hop{
diagnostic.NewHop(
uint8(1),
"2400:cb00:107:1024::ac44:6550",
[]time.Duration{
time.Duration(12780),
time.Duration(9118),
time.Duration(10046),
},
),
diagnostic.NewHop(
uint8(2),
"2a09:bac1::",
[]time.Duration{
time.Duration(9945),
time.Duration(10033),
time.Duration(11562),
},
),
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
hops, err := diagnostic.Decode(strings.NewReader(test.text), diagnostic.DecodeLine)
require.NoError(t, err)
assert.Equal(t, test.expectedHops, hops)
})
}
}

View File

@ -1,150 +0,0 @@
package diagnostic
import (
"context"
"encoding/json"
"errors"
"strings"
)
type SystemInformationError struct {
Err error `json:"error"`
RawInfo string `json:"rawInfo"`
}
func (err SystemInformationError) Error() string {
return err.Err.Error()
}
func (err SystemInformationError) MarshalJSON() ([]byte, error) {
s := map[string]string{
"error": err.Err.Error(),
"rawInfo": err.RawInfo,
}
return json.Marshal(s)
}
type SystemInformationGeneralError struct {
OperatingSystemInformationError error
MemoryInformationError error
FileDescriptorsInformationError error
DiskVolumeInformationError error
}
func (err SystemInformationGeneralError) Error() string {
builder := &strings.Builder{}
builder.WriteString("errors found:")
if err.OperatingSystemInformationError != nil {
builder.WriteString(err.OperatingSystemInformationError.Error() + ", ")
}
if err.MemoryInformationError != nil {
builder.WriteString(err.MemoryInformationError.Error() + ", ")
}
if err.FileDescriptorsInformationError != nil {
builder.WriteString(err.FileDescriptorsInformationError.Error() + ", ")
}
if err.DiskVolumeInformationError != nil {
builder.WriteString(err.DiskVolumeInformationError.Error() + ", ")
}
return builder.String()
}
func (err SystemInformationGeneralError) MarshalJSON() ([]byte, error) {
data := map[string]SystemInformationError{}
var sysErr SystemInformationError
if errors.As(err.OperatingSystemInformationError, &sysErr) {
data["operatingSystemInformationError"] = sysErr
}
if errors.As(err.MemoryInformationError, &sysErr) {
data["memoryInformationError"] = sysErr
}
if errors.As(err.FileDescriptorsInformationError, &sysErr) {
data["fileDescriptorsInformationError"] = sysErr
}
if errors.As(err.DiskVolumeInformationError, &sysErr) {
data["diskVolumeInformationError"] = sysErr
}
return json.Marshal(data)
}
type DiskVolumeInformation struct {
Name string `json:"name"` // represents the filesystem in linux/macos or device name in windows
SizeMaximum uint64 `json:"sizeMaximum"` // represents the maximum size of the disk in kilobytes
SizeCurrent uint64 `json:"sizeCurrent"` // represents the current size of the disk in kilobytes
}
func NewDiskVolumeInformation(name string, maximum, current uint64) *DiskVolumeInformation {
return &DiskVolumeInformation{
name,
maximum,
current,
}
}
type SystemInformation struct {
MemoryMaximum uint64 `json:"memoryMaximum,omitempty"` // represents the maximum memory of the system in kilobytes
MemoryCurrent uint64 `json:"memoryCurrent,omitempty"` // represents the system's memory in use in kilobytes
FileDescriptorMaximum uint64 `json:"fileDescriptorMaximum,omitempty"` // represents the maximum number of file descriptors of the system
FileDescriptorCurrent uint64 `json:"fileDescriptorCurrent,omitempty"` // represents the system's file descriptors in use
OsSystem string `json:"osSystem,omitempty"` // represents the operating system name i.e.: linux, windows, darwin
HostName string `json:"hostName,omitempty"` // represents the system host name
OsVersion string `json:"osVersion,omitempty"` // detailed information about the system's release version level
OsRelease string `json:"osRelease,omitempty"` // detailed information about the system's release
Architecture string `json:"architecture,omitempty"` // represents the system's hardware platform i.e: arm64/amd64
CloudflaredVersion string `json:"cloudflaredVersion,omitempty"` // the runtime version of cloudflared
GoVersion string `json:"goVersion,omitempty"`
GoArch string `json:"goArch,omitempty"`
Disk []*DiskVolumeInformation `json:"disk,omitempty"`
}
func NewSystemInformation(
memoryMaximum,
memoryCurrent,
filesMaximum,
filesCurrent uint64,
osystem,
name,
osVersion,
osRelease,
architecture,
cloudflaredVersion,
goVersion,
goArchitecture string,
disk []*DiskVolumeInformation,
) *SystemInformation {
return &SystemInformation{
memoryMaximum,
memoryCurrent,
filesMaximum,
filesCurrent,
osystem,
name,
osVersion,
osRelease,
architecture,
cloudflaredVersion,
goVersion,
goArchitecture,
disk,
}
}
type SystemCollector interface {
// If the collection is successful it will return `SystemInformation` struct,
// and a nil error.
//
// This function expects that the caller sets the context timeout to prevent
// long-lived collectors.
Collect(ctx context.Context) (*SystemInformation, error)
}

View File

@ -1,150 +0,0 @@
//go:build linux
package diagnostic
import (
"context"
"fmt"
"os/exec"
"runtime"
"strconv"
"strings"
)
type SystemCollectorImpl struct {
version string
}
func NewSystemCollectorImpl(
version string,
) *SystemCollectorImpl {
return &SystemCollectorImpl{
version,
}
}
func (collector *SystemCollectorImpl) Collect(ctx context.Context) (*SystemInformation, error) {
memoryInfo, memoryInfoRaw, memoryInfoErr := collectMemoryInformation(ctx)
fdInfo, fdInfoRaw, fdInfoErr := collectFileDescriptorInformation(ctx)
disks, disksRaw, diskErr := collectDiskVolumeInformationUnix(ctx)
osInfo, osInfoRaw, osInfoErr := collectOSInformationUnix(ctx)
var memoryMaximum, memoryCurrent, fileDescriptorMaximum, fileDescriptorCurrent uint64
var osSystem, name, osVersion, osRelease, architecture string
gerror := SystemInformationGeneralError{}
if memoryInfoErr != nil {
gerror.MemoryInformationError = SystemInformationError{
Err: memoryInfoErr,
RawInfo: memoryInfoRaw,
}
} else {
memoryMaximum = memoryInfo.MemoryMaximum
memoryCurrent = memoryInfo.MemoryCurrent
}
if fdInfoErr != nil {
gerror.FileDescriptorsInformationError = SystemInformationError{
Err: fdInfoErr,
RawInfo: fdInfoRaw,
}
} else {
fileDescriptorMaximum = fdInfo.FileDescriptorMaximum
fileDescriptorCurrent = fdInfo.FileDescriptorCurrent
}
if diskErr != nil {
gerror.DiskVolumeInformationError = SystemInformationError{
Err: diskErr,
RawInfo: disksRaw,
}
}
if osInfoErr != nil {
gerror.OperatingSystemInformationError = SystemInformationError{
Err: osInfoErr,
RawInfo: osInfoRaw,
}
} else {
osSystem = osInfo.OsSystem
name = osInfo.Name
osVersion = osInfo.OsVersion
osRelease = osInfo.OsRelease
architecture = osInfo.Architecture
}
cloudflaredVersion := collector.version
info := NewSystemInformation(
memoryMaximum,
memoryCurrent,
fileDescriptorMaximum,
fileDescriptorCurrent,
osSystem,
name,
osVersion,
osRelease,
architecture,
cloudflaredVersion,
runtime.Version(),
runtime.GOARCH,
disks,
)
return info, gerror
}
func collectMemoryInformation(ctx context.Context) (*MemoryInformation, string, error) {
// This function relies on the output of `cat /proc/meminfo` to retrieve
// memoryMax and memoryCurrent.
// The expected output is in the format of `KEY VALUE UNIT`.
const (
memTotalPrefix = "MemTotal"
memAvailablePrefix = "MemAvailable"
)
command := exec.CommandContext(ctx, "cat", "/proc/meminfo")
stdout, err := command.Output()
if err != nil {
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
}
output := string(stdout)
mapper := func(field string) (uint64, error) {
field = strings.TrimRight(field, " kB")
return strconv.ParseUint(field, 10, 64)
}
memoryInfo, err := ParseMemoryInformationFromKV(output, memTotalPrefix, memAvailablePrefix, mapper)
if err != nil {
return nil, output, err
}
// returning raw output in case other collected information
// resulted in errors
return memoryInfo, output, nil
}
func collectFileDescriptorInformation(ctx context.Context) (*FileDescriptorInformation, string, error) {
// Command retrieved from https://docs.kernel.org/admin-guide/sysctl/fs.html#file-max-file-nr.
// If the sysctl is not available the command with fail.
command := exec.CommandContext(ctx, "sysctl", "-n", "fs.file-nr")
stdout, err := command.Output()
if err != nil {
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
}
output := string(stdout)
fileDescriptorInfo, err := ParseSysctlFileDescriptorInformation(output)
if err != nil {
return nil, output, err
}
// returning raw output in case other collected information
// resulted in errors
return fileDescriptorInfo, output, nil
}

View File

@ -1,172 +0,0 @@
//go:build darwin
package diagnostic
import (
"context"
"fmt"
"os/exec"
"runtime"
"strconv"
)
type SystemCollectorImpl struct {
version string
}
func NewSystemCollectorImpl(
version string,
) *SystemCollectorImpl {
return &SystemCollectorImpl{
version,
}
}
func (collector *SystemCollectorImpl) Collect(ctx context.Context) (*SystemInformation, error) {
memoryInfo, memoryInfoRaw, memoryInfoErr := collectMemoryInformation(ctx)
fdInfo, fdInfoRaw, fdInfoErr := collectFileDescriptorInformation(ctx)
disks, disksRaw, diskErr := collectDiskVolumeInformationUnix(ctx)
osInfo, osInfoRaw, osInfoErr := collectOSInformationUnix(ctx)
var memoryMaximum, memoryCurrent, fileDescriptorMaximum, fileDescriptorCurrent uint64
var osSystem, name, osVersion, osRelease, architecture string
err := SystemInformationGeneralError{
OperatingSystemInformationError: nil,
MemoryInformationError: nil,
FileDescriptorsInformationError: nil,
DiskVolumeInformationError: nil,
}
if memoryInfoErr != nil {
err.MemoryInformationError = SystemInformationError{
Err: memoryInfoErr,
RawInfo: memoryInfoRaw,
}
} else {
memoryMaximum = memoryInfo.MemoryMaximum
memoryCurrent = memoryInfo.MemoryCurrent
}
if fdInfoErr != nil {
err.FileDescriptorsInformationError = SystemInformationError{
Err: fdInfoErr,
RawInfo: fdInfoRaw,
}
} else {
fileDescriptorMaximum = fdInfo.FileDescriptorMaximum
fileDescriptorCurrent = fdInfo.FileDescriptorCurrent
}
if diskErr != nil {
err.DiskVolumeInformationError = SystemInformationError{
Err: diskErr,
RawInfo: disksRaw,
}
}
if osInfoErr != nil {
err.OperatingSystemInformationError = SystemInformationError{
Err: osInfoErr,
RawInfo: osInfoRaw,
}
} else {
osSystem = osInfo.OsSystem
name = osInfo.Name
osVersion = osInfo.OsVersion
osRelease = osInfo.OsRelease
architecture = osInfo.Architecture
}
cloudflaredVersion := collector.version
info := NewSystemInformation(
memoryMaximum,
memoryCurrent,
fileDescriptorMaximum,
fileDescriptorCurrent,
osSystem,
name,
osVersion,
osRelease,
architecture,
cloudflaredVersion,
runtime.Version(),
runtime.GOARCH,
disks,
)
return info, err
}
func collectFileDescriptorInformation(ctx context.Context) (
*FileDescriptorInformation,
string,
error,
) {
const (
fileDescriptorMaximumKey = "kern.maxfiles"
fileDescriptorCurrentKey = "kern.num_files"
)
command := exec.CommandContext(ctx, "sysctl", fileDescriptorMaximumKey, fileDescriptorCurrentKey)
stdout, err := command.Output()
if err != nil {
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
}
output := string(stdout)
fileDescriptorInfo, err := ParseFileDescriptorInformationFromKV(
output,
fileDescriptorMaximumKey,
fileDescriptorCurrentKey,
)
if err != nil {
return nil, output, err
}
// returning raw output in case other collected information
// resulted in errors
return fileDescriptorInfo, output, nil
}
func collectMemoryInformation(ctx context.Context) (
*MemoryInformation,
string,
error,
) {
const (
memoryMaximumKey = "hw.memsize"
memoryAvailableKey = "hw.memsize_usable"
)
command := exec.CommandContext(
ctx,
"sysctl",
memoryMaximumKey,
memoryAvailableKey,
)
stdout, err := command.Output()
if err != nil {
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
}
output := string(stdout)
mapper := func(field string) (uint64, error) {
const kiloBytes = 1024
value, err := strconv.ParseUint(field, 10, 64)
return value / kiloBytes, err
}
memoryInfo, err := ParseMemoryInformationFromKV(output, memoryMaximumKey, memoryAvailableKey, mapper)
if err != nil {
return nil, output, err
}
// returning raw output in case other collected information
// resulted in errors
return memoryInfo, output, nil
}

View File

@ -1,466 +0,0 @@
package diagnostic_test
import (
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/diagnostic"
)
func TestParseMemoryInformationFromKV(t *testing.T) {
t.Parallel()
mapper := func(field string) (uint64, error) {
value, err := strconv.ParseUint(field, 10, 64)
return value, err
}
linuxMapper := func(field string) (uint64, error) {
field = strings.TrimRight(field, " kB")
return strconv.ParseUint(field, 10, 64)
}
windowsMemoryOutput := `
FreeVirtualMemory : 5350472
TotalVirtualMemorySize : 8903424
`
macosMemoryOutput := `hw.memsize: 38654705664
hw.memsize_usable: 38009012224`
memoryOutputWithMissingKey := `hw.memsize: 38654705664`
linuxMemoryOutput := `MemTotal: 8028860 kB
MemFree: 731396 kB
MemAvailable: 4678844 kB
Buffers: 472632 kB
Cached: 3186492 kB
SwapCached: 4196 kB
Active: 3088988 kB
Inactive: 3468560 kB`
tests := []struct {
name string
output string
memoryMaximumKey string
memoryAvailableKey string
expected *diagnostic.MemoryInformation
expectedErr bool
mapper func(string) (uint64, error)
}{
{
name: "parse linux memory values",
output: linuxMemoryOutput,
memoryMaximumKey: "MemTotal",
memoryAvailableKey: "MemAvailable",
expected: &diagnostic.MemoryInformation{
8028860,
8028860 - 4678844,
},
expectedErr: false,
mapper: linuxMapper,
},
{
name: "parse memory values with missing key",
output: memoryOutputWithMissingKey,
memoryMaximumKey: "hw.memsize",
memoryAvailableKey: "hw.memsize_usable",
expected: nil,
expectedErr: true,
mapper: mapper,
},
{
name: "parse macos memory values",
output: macosMemoryOutput,
memoryMaximumKey: "hw.memsize",
memoryAvailableKey: "hw.memsize_usable",
expected: &diagnostic.MemoryInformation{
38654705664,
38654705664 - 38009012224,
},
expectedErr: false,
mapper: mapper,
},
{
name: "parse windows memory values",
output: windowsMemoryOutput,
memoryMaximumKey: "TotalVirtualMemorySize",
memoryAvailableKey: "FreeVirtualMemory",
expected: &diagnostic.MemoryInformation{
8903424,
8903424 - 5350472,
},
expectedErr: false,
mapper: mapper,
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
memoryInfo, err := diagnostic.ParseMemoryInformationFromKV(
tCase.output,
tCase.memoryMaximumKey,
tCase.memoryAvailableKey,
tCase.mapper,
)
if tCase.expectedErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tCase.expected, memoryInfo)
}
})
}
}
func TestParseUnameOutput(t *testing.T) {
t.Parallel()
tests := []struct {
name string
output string
os string
expected *diagnostic.OsInfo
expectedErr bool
}{
{
name: "darwin machine",
output: "Darwin APC 23.6.0 Darwin Kernel Version 99.6.0: Wed Jul 31 20:48:04 PDT 1997; root:xnu-66666.666.6.666.6~1/RELEASE_ARM64_T6666 arm64",
os: "darwin",
expected: &diagnostic.OsInfo{
Architecture: "arm64",
Name: "APC",
OsSystem: "Darwin",
OsRelease: "Darwin Kernel Version 99.6.0: Wed Jul 31 20:48:04 PDT 1997; root:xnu-66666.666.6.666.6~1/RELEASE_ARM64_T6666",
OsVersion: "23.6.0",
},
expectedErr: false,
},
{
name: "linux machine",
output: "Linux dab00d565591 6.6.31-linuxkit #1 SMP Thu May 23 08:36:57 UTC 2024 aarch64 GNU/Linux",
os: "linux",
expected: &diagnostic.OsInfo{
Architecture: "aarch64",
Name: "dab00d565591",
OsSystem: "Linux",
OsRelease: "#1 SMP Thu May 23 08:36:57 UTC 2024",
OsVersion: "6.6.31-linuxkit",
},
expectedErr: false,
},
{
name: "not enough fields",
output: "Linux ",
os: "linux",
expected: nil,
expectedErr: true,
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
memoryInfo, err := diagnostic.ParseUnameOutput(
tCase.output,
tCase.os,
)
if tCase.expectedErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tCase.expected, memoryInfo)
}
})
}
}
func TestParseFileDescriptorInformationFromKV(t *testing.T) {
const (
fileDescriptorMaximumKey = "kern.maxfiles"
fileDescriptorCurrentKey = "kern.num_files"
)
t.Parallel()
memoryOutput := `kern.maxfiles: 276480
kern.num_files: 11787`
memoryOutputWithMissingKey := `kern.maxfiles: 276480`
tests := []struct {
name string
output string
expected *diagnostic.FileDescriptorInformation
expectedErr bool
}{
{
name: "parse memory values with missing key",
output: memoryOutputWithMissingKey,
expected: nil,
expectedErr: true,
},
{
name: "parse macos memory values",
output: memoryOutput,
expected: &diagnostic.FileDescriptorInformation{
276480,
11787,
},
expectedErr: false,
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
fdInfo, err := diagnostic.ParseFileDescriptorInformationFromKV(
tCase.output,
fileDescriptorMaximumKey,
fileDescriptorCurrentKey,
)
if tCase.expectedErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tCase.expected, fdInfo)
}
})
}
}
func TestParseSysctlFileDescriptorInformation(t *testing.T) {
t.Parallel()
tests := []struct {
name string
output string
expected *diagnostic.FileDescriptorInformation
expectedErr bool
}{
{
name: "expected output",
output: "111 0 1111111",
expected: &diagnostic.FileDescriptorInformation{
FileDescriptorMaximum: 1111111,
FileDescriptorCurrent: 111,
},
expectedErr: false,
},
{
name: "not enough fields",
output: "111 111 ",
expected: nil,
expectedErr: true,
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
fdsInfo, err := diagnostic.ParseSysctlFileDescriptorInformation(
tCase.output,
)
if tCase.expectedErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tCase.expected, fdsInfo)
}
})
}
}
func TestParseWinOperatingSystemInfo(t *testing.T) {
const (
architecturePrefix = "OSArchitecture"
osSystemPrefix = "Caption"
osVersionPrefix = "Version"
osReleasePrefix = "BuildNumber"
namePrefix = "CSName"
)
t.Parallel()
windowsIncompleteOsInfo := `
OSArchitecture : ARM 64 bits
Caption : Microsoft Windows 11 Home
Morekeys : 121314
CSName : UTILIZA-QO859QP
`
windowsCompleteOsInfo := `
OSArchitecture : ARM 64 bits
Caption : Microsoft Windows 11 Home
Version : 10.0.22631
BuildNumber : 22631
Morekeys : 121314
CSName : UTILIZA-QO859QP
`
tests := []struct {
name string
output string
expected *diagnostic.OsInfo
expectedErr bool
}{
{
name: "expected output",
output: windowsCompleteOsInfo,
expected: &diagnostic.OsInfo{
Architecture: "ARM 64 bits",
Name: "UTILIZA-QO859QP",
OsSystem: "Microsoft Windows 11 Home",
OsRelease: "22631",
OsVersion: "10.0.22631",
},
expectedErr: false,
},
{
name: "missing keys",
output: windowsIncompleteOsInfo,
expected: nil,
expectedErr: true,
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
osInfo, err := diagnostic.ParseWinOperatingSystemInfo(
tCase.output,
architecturePrefix,
osSystemPrefix,
osVersionPrefix,
osReleasePrefix,
namePrefix,
)
if tCase.expectedErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tCase.expected, osInfo)
}
})
}
}
func TestParseDiskVolumeInformationOutput(t *testing.T) {
t.Parallel()
invalidUnixDiskVolumeInfo := `Filesystem Size Used Avail Use% Mounted on
overlay 59G 19G 38G 33% /
tmpfs 64M 0 64M 0% /dev
shm 64M 0 64M 0% /dev/shm
/run/host_mark/Users 461G 266G 195G 58% /tmp/cloudflared
/dev/vda1 59G 19G 38G 33% /etc/hosts
tmpfs 3.9G 0 3.9G 0% /sys/firmware
`
unixDiskVolumeInfo := `Filesystem Size Used Avail Use% Mounted on
overlay 61202244 18881444 39179476 33% /
tmpfs 65536 0 65536 0% /dev
shm 65536 0 65536 0% /dev/shm
/run/host_mark/Users 482797652 278648468 204149184 58% /tmp/cloudflared
/dev/vda1 61202244 18881444 39179476 33% /etc/hosts
tmpfs 4014428 0 4014428 0% /sys/firmware`
missingFields := ` DeviceID Size
-------- ----
C: size
E: 235563008
Z: 67754782720
`
invalidTypeField := ` DeviceID Size FreeSpace
-------- ---- ---------
C: size 31318736896
D:
E: 235563008 0
Z: 67754782720 31318732800
`
windowsDiskVolumeInfo := `
DeviceID Size FreeSpace
-------- ---- ---------
C: 67754782720 31318736896
E: 235563008 0
Z: 67754782720 31318732800`
tests := []struct {
name string
output string
expected []*diagnostic.DiskVolumeInformation
skipLines int
expectedErr bool
}{
{
name: "invalid unix disk volume information (numbers have units)",
output: invalidUnixDiskVolumeInfo,
expected: []*diagnostic.DiskVolumeInformation{},
skipLines: 1,
expectedErr: true,
},
{
name: "unix disk volume information",
output: unixDiskVolumeInfo,
skipLines: 1,
expected: []*diagnostic.DiskVolumeInformation{
diagnostic.NewDiskVolumeInformation("overlay", 61202244, 18881444),
diagnostic.NewDiskVolumeInformation("tmpfs", 65536, 0),
diagnostic.NewDiskVolumeInformation("shm", 65536, 0),
diagnostic.NewDiskVolumeInformation("/run/host_mark/Users", 482797652, 278648468),
diagnostic.NewDiskVolumeInformation("/dev/vda1", 61202244, 18881444),
diagnostic.NewDiskVolumeInformation("tmpfs", 4014428, 0),
},
expectedErr: false,
},
{
name: "windows disk volume information",
output: windowsDiskVolumeInfo,
expected: []*diagnostic.DiskVolumeInformation{
diagnostic.NewDiskVolumeInformation("C:", 67754782720, 31318736896),
diagnostic.NewDiskVolumeInformation("E:", 235563008, 0),
diagnostic.NewDiskVolumeInformation("Z:", 67754782720, 31318732800),
},
skipLines: 4,
expectedErr: false,
},
{
name: "insuficient fields",
output: missingFields,
expected: nil,
skipLines: 2,
expectedErr: true,
},
{
name: "invalid field",
output: invalidTypeField,
expected: nil,
skipLines: 2,
expectedErr: true,
},
}
for _, tCase := range tests {
t.Run(tCase.name, func(t *testing.T) {
t.Parallel()
disks, err := diagnostic.ParseDiskVolumeInformationOutput(tCase.output, tCase.skipLines, 1)
if tCase.expectedErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tCase.expected, disks)
}
})
}
}

View File

@ -1,377 +0,0 @@
package diagnostic
import (
"context"
"fmt"
"os/exec"
"runtime"
"sort"
"strconv"
"strings"
)
func findColonSeparatedPairs[V any](output string, keys []string, mapper func(string) (V, error)) map[string]V {
const (
memoryField = 1
memoryInformationFields = 2
)
lines := strings.Split(output, "\n")
pairs := make(map[string]V, 0)
// sort keys and lines to allow incremental search
sort.Strings(lines)
sort.Strings(keys)
// keeps track of the last key found
lastIndex := 0
for _, line := range lines {
if lastIndex == len(keys) {
// already found all keys no need to continue iterating
// over the other values
break
}
for index, key := range keys[lastIndex:] {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, key) {
fields := strings.Split(line, ":")
if len(fields) < memoryInformationFields {
lastIndex = index + 1
break
}
field, err := mapper(strings.TrimSpace(fields[memoryField]))
if err != nil {
lastIndex = lastIndex + index + 1
break
}
pairs[key] = field
lastIndex = lastIndex + index + 1
break
}
}
}
return pairs
}
func ParseDiskVolumeInformationOutput(output string, skipLines int, scale float64) ([]*DiskVolumeInformation, error) {
const (
diskFieldsMinimum = 3
nameField = 0
sizeMaximumField = 1
sizeCurrentField = 2
)
disksRaw := strings.Split(output, "\n")
disks := make([]*DiskVolumeInformation, 0)
if skipLines > len(disksRaw) || skipLines < 0 {
skipLines = 0
}
for _, disk := range disksRaw[skipLines:] {
if disk == "" {
// skip empty line
continue
}
fields := strings.Fields(disk)
if len(fields) < diskFieldsMinimum {
return nil, fmt.Errorf("expected disk volume to have %d fields got %d: %w",
diskFieldsMinimum, len(fields), ErrInsuficientFields,
)
}
name := fields[nameField]
sizeMaximum, err := strconv.ParseUint(fields[sizeMaximumField], 10, 64)
if err != nil {
continue
}
sizeCurrent, err := strconv.ParseUint(fields[sizeCurrentField], 10, 64)
if err != nil {
continue
}
diskInfo := NewDiskVolumeInformation(
name, uint64(float64(sizeMaximum)*scale), uint64(float64(sizeCurrent)*scale),
)
disks = append(disks, diskInfo)
}
if len(disks) == 0 {
return nil, ErrNoVolumeFound
}
return disks, nil
}
type OsInfo struct {
OsSystem string
Name string
OsVersion string
OsRelease string
Architecture string
}
func ParseUnameOutput(output string, system string) (*OsInfo, error) {
const (
osystemField = 0
nameField = 1
osVersionField = 2
osReleaseStartField = 3
osInformationFieldsMinimum = 6
darwin = "darwin"
)
architectureOffset := 2
if system == darwin {
architectureOffset = 1
}
fields := strings.Fields(output)
if len(fields) < osInformationFieldsMinimum {
return nil, fmt.Errorf("expected system information to have %d fields got %d: %w",
osInformationFieldsMinimum, len(fields), ErrInsuficientFields,
)
}
architectureField := len(fields) - architectureOffset
osystem := fields[osystemField]
name := fields[nameField]
osVersion := fields[osVersionField]
osRelease := strings.Join(fields[osReleaseStartField:architectureField], " ")
architecture := fields[architectureField]
return &OsInfo{
osystem,
name,
osVersion,
osRelease,
architecture,
}, nil
}
func ParseWinOperatingSystemInfo(
output string,
architectureKey string,
osSystemKey string,
osVersionKey string,
osReleaseKey string,
nameKey string,
) (*OsInfo, error) {
identity := func(s string) (string, error) { return s, nil }
keys := []string{architectureKey, osSystemKey, osVersionKey, osReleaseKey, nameKey}
pairs := findColonSeparatedPairs(
output,
keys,
identity,
)
architecture, exists := pairs[architectureKey]
if !exists {
return nil, fmt.Errorf("parsing os information: %w, key=%s", ErrKeyNotFound, architectureKey)
}
osSystem, exists := pairs[osSystemKey]
if !exists {
return nil, fmt.Errorf("parsing os information: %w, key=%s", ErrKeyNotFound, osSystemKey)
}
osVersion, exists := pairs[osVersionKey]
if !exists {
return nil, fmt.Errorf("parsing os information: %w, key=%s", ErrKeyNotFound, osVersionKey)
}
osRelease, exists := pairs[osReleaseKey]
if !exists {
return nil, fmt.Errorf("parsing os information: %w, key=%s", ErrKeyNotFound, osReleaseKey)
}
name, exists := pairs[nameKey]
if !exists {
return nil, fmt.Errorf("parsing os information: %w, key=%s", ErrKeyNotFound, nameKey)
}
return &OsInfo{osSystem, name, osVersion, osRelease, architecture}, nil
}
type FileDescriptorInformation struct {
FileDescriptorMaximum uint64
FileDescriptorCurrent uint64
}
func ParseSysctlFileDescriptorInformation(output string) (*FileDescriptorInformation, error) {
const (
openFilesField = 0
maxFilesField = 2
fileDescriptorLimitsFields = 3
)
fields := strings.Fields(output)
if len(fields) != fileDescriptorLimitsFields {
return nil,
fmt.Errorf(
"expected file descriptor information to have %d fields got %d: %w",
fileDescriptorLimitsFields,
len(fields),
ErrInsuficientFields,
)
}
fileDescriptorCurrent, err := strconv.ParseUint(fields[openFilesField], 10, 64)
if err != nil {
return nil, fmt.Errorf(
"error parsing files current field '%s': %w",
fields[openFilesField],
err,
)
}
fileDescriptorMaximum, err := strconv.ParseUint(fields[maxFilesField], 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing files max field '%s': %w", fields[maxFilesField], err)
}
return &FileDescriptorInformation{fileDescriptorMaximum, fileDescriptorCurrent}, nil
}
func ParseFileDescriptorInformationFromKV(
output string,
fileDescriptorMaximumKey string,
fileDescriptorCurrentKey string,
) (*FileDescriptorInformation, error) {
mapper := func(field string) (uint64, error) {
return strconv.ParseUint(field, 10, 64)
}
pairs := findColonSeparatedPairs(output, []string{fileDescriptorMaximumKey, fileDescriptorCurrentKey}, mapper)
fileDescriptorMaximum, exists := pairs[fileDescriptorMaximumKey]
if !exists {
return nil, fmt.Errorf(
"parsing file descriptor information: %w, key=%s",
ErrKeyNotFound,
fileDescriptorMaximumKey,
)
}
fileDescriptorCurrent, exists := pairs[fileDescriptorCurrentKey]
if !exists {
return nil, fmt.Errorf(
"parsing file descriptor information: %w, key=%s",
ErrKeyNotFound,
fileDescriptorCurrentKey,
)
}
return &FileDescriptorInformation{fileDescriptorMaximum, fileDescriptorCurrent}, nil
}
type MemoryInformation struct {
MemoryMaximum uint64 // size in KB
MemoryCurrent uint64 // size in KB
}
func ParseMemoryInformationFromKV(
output string,
memoryMaximumKey string,
memoryAvailableKey string,
mapper func(field string) (uint64, error),
) (*MemoryInformation, error) {
pairs := findColonSeparatedPairs(output, []string{memoryMaximumKey, memoryAvailableKey}, mapper)
memoryMaximum, exists := pairs[memoryMaximumKey]
if !exists {
return nil, fmt.Errorf("parsing memory information: %w, key=%s", ErrKeyNotFound, memoryMaximumKey)
}
memoryAvailable, exists := pairs[memoryAvailableKey]
if !exists {
return nil, fmt.Errorf("parsing memory information: %w, key=%s", ErrKeyNotFound, memoryAvailableKey)
}
memoryCurrent := memoryMaximum - memoryAvailable
return &MemoryInformation{memoryMaximum, memoryCurrent}, nil
}
func RawSystemInformation(osInfoRaw string, memoryInfoRaw string, fdInfoRaw string, disksRaw string) string {
var builder strings.Builder
formatInfo := func(info string, builder *strings.Builder) {
if info == "" {
builder.WriteString("No information\n")
} else {
builder.WriteString(info)
builder.WriteString("\n")
}
}
builder.WriteString("---BEGIN Operating system information\n")
formatInfo(osInfoRaw, &builder)
builder.WriteString("---END Operating system information\n")
builder.WriteString("---BEGIN Memory information\n")
formatInfo(memoryInfoRaw, &builder)
builder.WriteString("---END Memory information\n")
builder.WriteString("---BEGIN File descriptors information\n")
formatInfo(fdInfoRaw, &builder)
builder.WriteString("---END File descriptors information\n")
builder.WriteString("---BEGIN Disks information\n")
formatInfo(disksRaw, &builder)
builder.WriteString("---END Disks information\n")
rawInformation := builder.String()
return rawInformation
}
func collectDiskVolumeInformationUnix(ctx context.Context) ([]*DiskVolumeInformation, string, error) {
command := exec.CommandContext(ctx, "df", "-k")
stdout, err := command.Output()
if err != nil {
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
}
output := string(stdout)
disks, err := ParseDiskVolumeInformationOutput(output, 1, 1)
if err != nil {
return nil, output, err
}
// returning raw output in case other collected information
// resulted in errors
return disks, output, nil
}
func collectOSInformationUnix(ctx context.Context) (*OsInfo, string, error) {
command := exec.CommandContext(ctx, "uname", "-a")
stdout, err := command.Output()
if err != nil {
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
}
output := string(stdout)
osInfo, err := ParseUnameOutput(output, runtime.GOOS)
if err != nil {
return nil, output, err
}
// returning raw output in case other collected information
// resulted in errors
return osInfo, output, nil
}

View File

@ -1,183 +0,0 @@
//go:build windows
package diagnostic
import (
"context"
"fmt"
"os/exec"
"runtime"
"strconv"
)
const kiloBytesScale = 1.0 / 1024
type SystemCollectorImpl struct {
version string
}
func NewSystemCollectorImpl(
version string,
) *SystemCollectorImpl {
return &SystemCollectorImpl{
version,
}
}
func (collector *SystemCollectorImpl) Collect(ctx context.Context) (*SystemInformation, error) {
memoryInfo, memoryInfoRaw, memoryInfoErr := collectMemoryInformation(ctx)
disks, disksRaw, diskErr := collectDiskVolumeInformation(ctx)
osInfo, osInfoRaw, osInfoErr := collectOSInformation(ctx)
var memoryMaximum, memoryCurrent, fileDescriptorMaximum, fileDescriptorCurrent uint64
var osSystem, name, osVersion, osRelease, architecture string
err := SystemInformationGeneralError{
OperatingSystemInformationError: nil,
MemoryInformationError: nil,
FileDescriptorsInformationError: nil,
DiskVolumeInformationError: nil,
}
if memoryInfoErr != nil {
err.MemoryInformationError = SystemInformationError{
Err: memoryInfoErr,
RawInfo: memoryInfoRaw,
}
} else {
memoryMaximum = memoryInfo.MemoryMaximum
memoryCurrent = memoryInfo.MemoryCurrent
}
if diskErr != nil {
err.DiskVolumeInformationError = SystemInformationError{
Err: diskErr,
RawInfo: disksRaw,
}
}
if osInfoErr != nil {
err.OperatingSystemInformationError = SystemInformationError{
Err: osInfoErr,
RawInfo: osInfoRaw,
}
} else {
osSystem = osInfo.OsSystem
name = osInfo.Name
osVersion = osInfo.OsVersion
osRelease = osInfo.OsRelease
architecture = osInfo.Architecture
}
cloudflaredVersion := collector.version
info := NewSystemInformation(
memoryMaximum,
memoryCurrent,
fileDescriptorMaximum,
fileDescriptorCurrent,
osSystem,
name,
osVersion,
osRelease,
architecture,
cloudflaredVersion,
runtime.Version(),
runtime.GOARCH,
disks,
)
return info, err
}
func collectMemoryInformation(ctx context.Context) (*MemoryInformation, string, error) {
const (
memoryTotalPrefix = "TotalVirtualMemorySize"
memoryAvailablePrefix = "FreeVirtualMemory"
)
command := exec.CommandContext(
ctx,
"powershell",
"-Command",
"Get-CimInstance -Class Win32_OperatingSystem | Select-Object FreeVirtualMemory, TotalVirtualMemorySize | Format-List",
)
stdout, err := command.Output()
if err != nil {
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
}
output := string(stdout)
// the result of the command above will return values in bytes hence
// they need to be converted to kilobytes
mapper := func(field string) (uint64, error) {
value, err := strconv.ParseUint(field, 10, 64)
return uint64(float64(value) * kiloBytesScale), err
}
memoryInfo, err := ParseMemoryInformationFromKV(output, memoryTotalPrefix, memoryAvailablePrefix, mapper)
if err != nil {
return nil, output, err
}
// returning raw output in case other collected information
// resulted in errors
return memoryInfo, output, nil
}
func collectDiskVolumeInformation(ctx context.Context) ([]*DiskVolumeInformation, string, error) {
command := exec.CommandContext(
ctx,
"powershell", "-Command", "Get-CimInstance -Class Win32_LogicalDisk | Select-Object DeviceID, Size, FreeSpace")
stdout, err := command.Output()
if err != nil {
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
}
output := string(stdout)
disks, err := ParseDiskVolumeInformationOutput(output, 2, kiloBytesScale)
if err != nil {
return nil, output, err
}
// returning raw output in case other collected information
// resulted in errors
return disks, output, nil
}
func collectOSInformation(ctx context.Context) (*OsInfo, string, error) {
const (
architecturePrefix = "OSArchitecture"
osSystemPrefix = "Caption"
osVersionPrefix = "Version"
osReleasePrefix = "BuildNumber"
namePrefix = "CSName"
)
command := exec.CommandContext(
ctx,
"powershell",
"-Command",
"Get-CimInstance -Class Win32_OperatingSystem | Select-Object OSArchitecture, Caption, Version, BuildNumber, CSName | Format-List",
)
stdout, err := command.Output()
if err != nil {
return nil, "", fmt.Errorf("error retrieving output from command '%s': %w", command.String(), err)
}
output := string(stdout)
osInfo, err := ParseWinOperatingSystemInfo(output, architecturePrefix, osSystemPrefix, osVersionPrefix, osReleasePrefix, namePrefix)
if err != nil {
return nil, output, err
}
// returning raw output in case other collected information
// resulted in errors
return osInfo, output, nil
}

View File

@ -11,40 +11,28 @@ const (
FeatureDatagramV3 = "support_datagram_v3" FeatureDatagramV3 = "support_datagram_v3"
) )
var defaultFeatures = []string{ var (
FeatureAllowRemoteConfig, DefaultFeatures = []string{
FeatureSerializedHeaders, FeatureAllowRemoteConfig,
FeatureDatagramV2, FeatureSerializedHeaders,
FeatureQUICSupportEOF, FeatureDatagramV2,
FeatureManagementLogs, FeatureQUICSupportEOF,
} FeatureManagementLogs,
}
// Features set by user provided flags
type staticFeatures struct {
PostQuantumMode *PostQuantumMode
}
type PostQuantumMode uint8
const (
// Prefer post quantum, but fallback if connection cannot be established
PostQuantumPrefer PostQuantumMode = iota
// If the user passes the --post-quantum flag, we override
// CurvePreferences to only support hybrid post-quantum key agreements.
PostQuantumStrict
) )
type DatagramVersion string func Contains(feature string) bool {
for _, f := range DefaultFeatures {
const ( if f == feature {
// DatagramV2 is the currently supported datagram protocol for UDP and ICMP packets return true
DatagramV2 DatagramVersion = FeatureDatagramV2 }
// DatagramV3 is a new datagram protocol for UDP and ICMP packets. It is not backwards compatible with datagram v2. }
DatagramV3 DatagramVersion = FeatureDatagramV3 return false
) }
// Remove any duplicates from the slice // Remove any duplicates from the slice
func Dedup(slice []string) []string { func Dedup(slice []string) []string {
// Convert the slice into a set // Convert the slice into a set
set := make(map[string]bool, 0) set := make(map[string]bool, 0)
for _, str := range slice { for _, str := range slice {

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"net" "net"
"slices"
"sync" "sync"
"time" "time"
@ -19,67 +18,61 @@ const (
lookupTimeout = time.Second * 10 lookupTimeout = time.Second * 10
) )
type PostQuantumMode uint8
const (
// Prefer post quantum, but fallback if connection cannot be established
PostQuantumPrefer PostQuantumMode = iota
// If the user passes the --post-quantum flag, we override
// CurvePreferences to only support hybrid post-quantum key agreements.
PostQuantumStrict
)
// If the TXT record adds other fields, the umarshal logic will ignore those keys // If the TXT record adds other fields, the umarshal logic will ignore those keys
// If the TXT record is missing a key, the field will unmarshal to the default Go value // If the TXT record is missing a key, the field will unmarshal to the default Go value
// pq was removed in TUN-7970
type featuresRecord struct{}
type featuresRecord struct { func NewFeatureSelector(ctx context.Context, accountTag string, staticFeatures StaticFeatures, logger *zerolog.Logger) (*FeatureSelector, error) {
// support_datagram_v3 return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), staticFeatures, defaultRefreshFreq)
DatagramV3Percentage int32 `json:"dv3"`
// PostQuantumPercentage int32 `json:"pq"` // Removed in TUN-7970
} }
func NewFeatureSelector(ctx context.Context, accountTag string, cliFeatures []string, pq bool, logger *zerolog.Logger) (*FeatureSelector, error) { // FeatureSelector determines if this account will try new features. It preiodically queries a DNS TXT record
return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq, defaultRefreshFreq) // to see which features are turned on
}
// FeatureSelector determines if this account will try new features. It periodically queries a DNS TXT record
// to see which features are turned on.
type FeatureSelector struct { type FeatureSelector struct {
accountHash int32 accountHash int32
logger *zerolog.Logger logger *zerolog.Logger
resolver resolver resolver resolver
staticFeatures staticFeatures staticFeatures StaticFeatures
cliFeatures []string
// lock protects concurrent access to dynamic features // lock protects concurrent access to dynamic features
lock sync.RWMutex lock sync.RWMutex
features featuresRecord features featuresRecord
} }
func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool, refreshFreq time.Duration) (*FeatureSelector, error) { // Features set by user provided flags
// Combine default features and user-provided features type StaticFeatures struct {
var pqMode *PostQuantumMode PostQuantumMode *PostQuantumMode
if pq { }
mode := PostQuantumStrict
pqMode = &mode func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, staticFeatures StaticFeatures, refreshFreq time.Duration) (*FeatureSelector, error) {
cliFeatures = append(cliFeatures, FeaturePostQuantum)
}
staticFeatures := staticFeatures{
PostQuantumMode: pqMode,
}
selector := &FeatureSelector{ selector := &FeatureSelector{
accountHash: switchThreshold(accountTag), accountHash: switchThreshold(accountTag),
logger: logger, logger: logger,
resolver: resolver, resolver: resolver,
staticFeatures: staticFeatures, staticFeatures: staticFeatures,
cliFeatures: Dedup(cliFeatures),
} }
if err := selector.refresh(ctx); err != nil { if err := selector.refresh(ctx); err != nil {
logger.Err(err).Msg("Failed to fetch features, default to disable") logger.Err(err).Msg("Failed to fetch features, default to disable")
} }
go selector.refreshLoop(ctx, refreshFreq) // Run refreshLoop next time we have a new feature to rollout
return selector, nil return selector, nil
} }
func (fs *FeatureSelector) accountEnabled(percentage int32) bool {
return percentage > fs.accountHash
}
func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode { func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode {
if fs.staticFeatures.PostQuantumMode != nil { if fs.staticFeatures.PostQuantumMode != nil {
return *fs.staticFeatures.PostQuantumMode return *fs.staticFeatures.PostQuantumMode
@ -88,33 +81,6 @@ func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode {
return PostQuantumPrefer return PostQuantumPrefer
} }
func (fs *FeatureSelector) DatagramVersion() DatagramVersion {
fs.lock.RLock()
defer fs.lock.RUnlock()
// If user provides the feature via the cli, we take it as priority over remote feature evaluation
if slices.Contains(fs.cliFeatures, FeatureDatagramV3) {
return DatagramV3
}
// If the user specifies DatagramV2, we also take that over remote
if slices.Contains(fs.cliFeatures, FeatureDatagramV2) {
return DatagramV2
}
if fs.accountEnabled(fs.features.DatagramV3Percentage) {
return DatagramV3
}
return DatagramV2
}
// ClientFeatures will return the list of currently available features that cloudflared should provide to the edge.
//
// This list is dynamic and can change in-between returns.
func (fs *FeatureSelector) ClientFeatures() []string {
// Evaluate any remote features along with static feature list to construct the list of features
return Dedup(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.DatagramVersion())}))
}
func (fs *FeatureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) { func (fs *FeatureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) {
ticker := time.NewTicker(refreshFreq) ticker := time.NewTicker(refreshFreq)
for { for {

View File

@ -13,20 +13,16 @@ import (
func TestUnmarshalFeaturesRecord(t *testing.T) { func TestUnmarshalFeaturesRecord(t *testing.T) {
tests := []struct { tests := []struct {
record []byte record []byte
expectedPercentage int32
}{ }{
{ {
record: []byte(`{"dv3":0}`), record: []byte(`{"pq":0}`),
expectedPercentage: 0,
}, },
{ {
record: []byte(`{"dv3":39}`), record: []byte(`{"pq":39}`),
expectedPercentage: 39,
}, },
{ {
record: []byte(`{"dv3":100}`), record: []byte(`{"pq":100}`),
expectedPercentage: 100,
}, },
{ {
record: []byte(`{}`), // Unmarshal to default struct if key is not present record: []byte(`{}`), // Unmarshal to default struct if key is not present
@ -40,186 +36,37 @@ func TestUnmarshalFeaturesRecord(t *testing.T) {
var features featuresRecord var features featuresRecord
err := json.Unmarshal(test.record, &features) err := json.Unmarshal(test.record, &features)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, test.expectedPercentage, features.DatagramV3Percentage, test) require.Equal(t, featuresRecord{}, features)
} }
} }
func TestFeaturePrecedenceEvaluationPostQuantum(t *testing.T) {
logger := zerolog.Nop()
tests := []struct {
name string
cli bool
expectedFeatures []string
expectedVersion PostQuantumMode
}{
{
name: "default",
cli: false,
expectedFeatures: defaultFeatures,
expectedVersion: PostQuantumPrefer,
},
{
name: "user_specified",
cli: true,
expectedFeatures: Dedup(append(defaultFeatures, FeaturePostQuantum)),
expectedVersion: PostQuantumStrict,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
resolver := &staticResolver{record: featuresRecord{}}
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, []string{}, test.cli, time.Second)
require.NoError(t, err)
require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
require.Equal(t, test.expectedVersion, selector.PostQuantumMode())
})
}
}
func TestFeaturePrecedenceEvaluationDatagramVersion(t *testing.T) {
logger := zerolog.Nop()
tests := []struct {
name string
cli []string
remote featuresRecord
expectedFeatures []string
expectedVersion DatagramVersion
}{
{
name: "default",
cli: []string{},
remote: featuresRecord{},
expectedFeatures: defaultFeatures,
expectedVersion: DatagramV2,
},
{
name: "user_specified_v2",
cli: []string{FeatureDatagramV2},
remote: featuresRecord{},
expectedFeatures: defaultFeatures,
expectedVersion: DatagramV2,
},
{
name: "user_specified_v3",
cli: []string{FeatureDatagramV3},
remote: featuresRecord{},
expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
expectedVersion: FeatureDatagramV3,
},
{
name: "remote_specified_v3",
cli: []string{},
remote: featuresRecord{
DatagramV3Percentage: 100,
},
expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
expectedVersion: FeatureDatagramV3,
},
{
name: "remote_and_user_specified_v3",
cli: []string{FeatureDatagramV3},
remote: featuresRecord{
DatagramV3Percentage: 100,
},
expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
expectedVersion: FeatureDatagramV3,
},
{
name: "remote_v3_and_user_specified_v2",
cli: []string{FeatureDatagramV2},
remote: featuresRecord{
DatagramV3Percentage: 100,
},
expectedFeatures: defaultFeatures,
expectedVersion: DatagramV2,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
resolver := &staticResolver{record: test.remote}
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false, time.Second)
require.NoError(t, err)
require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
require.Equal(t, test.expectedVersion, selector.DatagramVersion())
})
}
}
func TestRefreshFeaturesRecord(t *testing.T) {
// The hash of the accountTag is 82
accountTag := t.Name()
threshold := switchThreshold(accountTag)
percentages := []int32{0, 10, 81, 82, 83, 100, 101, 1000}
refreshFreq := time.Millisecond * 10
selector := newTestSelector(t, percentages, false, refreshFreq)
// Starting out should default to DatagramV2
require.Equal(t, DatagramV2, selector.DatagramVersion())
for _, percentage := range percentages {
if percentage > threshold {
require.Equal(t, DatagramV3, selector.DatagramVersion())
} else {
require.Equal(t, DatagramV2, selector.DatagramVersion())
}
time.Sleep(refreshFreq + time.Millisecond)
}
// Make sure error doesn't override the last fetched features
require.Equal(t, DatagramV3, selector.DatagramVersion())
}
func TestStaticFeatures(t *testing.T) { func TestStaticFeatures(t *testing.T) {
percentages := []int32{0} pqMode := PostQuantumStrict
// PostQuantum Enabled from user flag selector := newTestSelector(t, &pqMode, time.Millisecond*10)
selector := newTestSelector(t, percentages, true, time.Millisecond*10)
require.Equal(t, PostQuantumStrict, selector.PostQuantumMode()) require.Equal(t, PostQuantumStrict, selector.PostQuantumMode())
// PostQuantum Disabled (or not set) // No StaticFeatures configured
selector = newTestSelector(t, percentages, false, time.Millisecond*10) selector = newTestSelector(t, nil, time.Millisecond*10)
require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode()) require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode())
} }
func newTestSelector(t *testing.T, percentages []int32, pq bool, refreshFreq time.Duration) *FeatureSelector { func newTestSelector(t *testing.T, pqMode *PostQuantumMode, refreshFreq time.Duration) *FeatureSelector {
accountTag := t.Name() accountTag := t.Name()
logger := zerolog.Nop() logger := zerolog.Nop()
resolver := &mockResolver{ resolver := &mockResolver{}
percentages: percentages,
}
selector, err := newFeatureSelector(context.Background(), accountTag, &logger, resolver, []string{}, pq, refreshFreq) staticFeatures := StaticFeatures{
PostQuantumMode: pqMode,
}
selector, err := newFeatureSelector(context.Background(), accountTag, &logger, resolver, staticFeatures, refreshFreq)
require.NoError(t, err) require.NoError(t, err)
return selector return selector
} }
type mockResolver struct { type mockResolver struct{}
nextIndex int
percentages []int32
}
func (mr *mockResolver) lookupRecord(ctx context.Context) ([]byte, error) { func (mr *mockResolver) lookupRecord(ctx context.Context) ([]byte, error) {
if mr.nextIndex >= len(mr.percentages) { return nil, fmt.Errorf("mockResolver hasn't implement lookupRecord")
return nil, fmt.Errorf("no more record to lookup")
}
record, err := json.Marshal(featuresRecord{
DatagramV3Percentage: mr.percentages[mr.nextIndex],
})
mr.nextIndex++
return record, err
}
type staticResolver struct {
record featuresRecord
}
func (r *staticResolver) lookupRecord(ctx context.Context) ([]byte, error) {
return json.Marshal(r.record)
} }

View File

@ -1,11 +0,0 @@
//go:build fips
package fips
import (
_ "crypto/tls/fipsonly"
)
func IsFipsEnabled() bool {
return true
}

12
fips/fips.go.linux-amd64 Normal file
View File

@ -0,0 +1,12 @@
// +build fips
package main
import (
_ "crypto/tls/fipsonly"
"github.com/cloudflare/cloudflared/cmd/cloudflared/tunnel"
)
func init () {
tunnel.FipsEnabled = true
}

View File

@ -1,7 +0,0 @@
//go:build !fips
package fips
func IsFipsEnabled() bool {
return false
}

View File

@ -1,77 +0,0 @@
package flow
import (
"errors"
"sync"
)
const (
unlimitedActiveFlows = 0
)
var (
ErrTooManyActiveFlows = errors.New("too many active flows")
)
type Limiter interface {
// Acquire tries to acquire a free slot for a flow, if the value of flows is already above
// the maximum it returns ErrTooManyActiveFlows.
Acquire(flowType string) error
// Release releases a slot for a flow.
Release()
// SetLimit allows to hot swap the limit value of the limiter.
SetLimit(uint64)
}
type flowLimiter struct {
limiterLock sync.Mutex
activeFlowsCounter uint64
maxActiveFlows uint64
unlimited bool
}
func NewLimiter(maxActiveFlows uint64) Limiter {
flowLimiter := &flowLimiter{
maxActiveFlows: maxActiveFlows,
unlimited: isUnlimited(maxActiveFlows),
}
return flowLimiter
}
func (s *flowLimiter) Acquire(flowType string) error {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
if !s.unlimited && s.activeFlowsCounter >= s.maxActiveFlows {
flowRegistrationsDropped.WithLabelValues(flowType).Inc()
return ErrTooManyActiveFlows
}
s.activeFlowsCounter++
return nil
}
func (s *flowLimiter) Release() {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
if s.activeFlowsCounter <= 0 {
return
}
s.activeFlowsCounter--
}
func (s *flowLimiter) SetLimit(newMaxActiveFlows uint64) {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
s.maxActiveFlows = newMaxActiveFlows
s.unlimited = isUnlimited(newMaxActiveFlows)
}
// isUnlimited checks if the value received matches the configuration for the unlimited flow limiter.
func isUnlimited(value uint64) bool {
return value == unlimitedActiveFlows
}

View File

@ -1,119 +0,0 @@
package flow_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/flow"
)
func TestFlowLimiter_Unlimited(t *testing.T) {
unlimitedLimiter := flow.NewLimiter(0)
for i := 0; i < 1000; i++ {
err := unlimitedLimiter.Acquire("test")
require.NoError(t, err)
}
}
func TestFlowLimiter_Limited(t *testing.T) {
maxFlows := uint64(5)
limiter := flow.NewLimiter(maxFlows)
for i := uint64(0); i < maxFlows; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
err := limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
}
func TestFlowLimiter_AcquireAndReleaseFlow(t *testing.T) {
maxFlows := uint64(5)
limiter := flow.NewLimiter(maxFlows)
// Acquire the maximum number of flows
for i := uint64(0); i < maxFlows; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
// Validate acquire 1 more flows fails
err := limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
// Release the maximum number of flows
for i := uint64(0); i < maxFlows; i++ {
limiter.Release()
}
// Validate acquire 1 more flows works
err = limiter.Acquire("shouldn't fail")
require.NoError(t, err)
// Release a 10x the number of max flows
for i := uint64(0); i < 10*maxFlows; i++ {
limiter.Release()
}
// Validate it still can only acquire a value = number max flows.
for i := uint64(0); i < maxFlows; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
err = limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
}
func TestFlowLimiter_SetLimit(t *testing.T) {
maxFlows := uint64(5)
limiter := flow.NewLimiter(maxFlows)
// Acquire the maximum number of flows
for i := uint64(0); i < maxFlows; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
// Validate acquire 1 more flows fails
err := limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
// Set the flow limiter to support one more request
limiter.SetLimit(maxFlows + 1)
// Validate acquire 1 more flows now works
err = limiter.Acquire("shouldn't fail")
require.NoError(t, err)
// Validate acquire 1 more flows doesn't work because we already reached the limit
err = limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
// Release all flows
for i := uint64(0); i < maxFlows+1; i++ {
limiter.Release()
}
// Validate 1 flow works again
err = limiter.Acquire("shouldn't fail")
require.NoError(t, err)
// Set the flow limit to 1
limiter.SetLimit(1)
// Validate acquire 1 more flows doesn't work
err = limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
// Set the flow limit to unlimited
limiter.SetLimit(0)
// Validate it can acquire a lot of flows because it is now unlimited.
for i := uint64(0); i < 10*maxFlows; i++ {
err := limiter.Acquire("shouldn't fail")
require.NoError(t, err)
}
}

View File

@ -1,23 +0,0 @@
package flow
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
const (
namespace = "flow"
)
var (
labels = []string{"flow_type"}
flowRegistrationsDropped = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "registrations_rate_limited_total",
Help: "Count registrations dropped due to high number of concurrent flows being handled",
},
labels,
)
)

21
go.mod
View File

@ -35,12 +35,11 @@ require (
go.opentelemetry.io/otel/trace v1.26.0 go.opentelemetry.io/otel/trace v1.26.0
go.opentelemetry.io/proto/otlp v1.2.0 go.opentelemetry.io/proto/otlp v1.2.0
go.uber.org/automaxprocs v1.4.0 go.uber.org/automaxprocs v1.4.0
go.uber.org/mock v0.5.0 golang.org/x/crypto v0.23.0
golang.org/x/crypto v0.31.0 golang.org/x/net v0.25.0
golang.org/x/net v0.26.0 golang.org/x/sync v0.7.0
golang.org/x/sync v0.10.0 golang.org/x/sys v0.20.0
golang.org/x/sys v0.28.0 golang.org/x/term v0.20.0
golang.org/x/term v0.27.0
google.golang.org/protobuf v1.34.1 google.golang.org/protobuf v1.34.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
@ -84,11 +83,12 @@ require (
github.com/prometheus/procfs v0.12.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect
go.opentelemetry.io/otel/metric v1.26.0 // indirect go.opentelemetry.io/otel/metric v1.26.0 // indirect
go.uber.org/mock v0.4.0 // indirect
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
golang.org/x/mod v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect
golang.org/x/oauth2 v0.18.0 // indirect golang.org/x/oauth2 v0.18.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/text v0.15.0 // indirect
golang.org/x/tools v0.22.0 // indirect golang.org/x/tools v0.21.0 // indirect
google.golang.org/appengine v1.6.8 // indirect google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240311132316-a219d84964c2 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240311132316-a219d84964c2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
@ -102,6 +102,3 @@ replace github.com/urfave/cli/v2 => github.com/ipostelnik/cli/v2 v2.3.1-0.202103
replace github.com/prometheus/golang_client => github.com/prometheus/golang_client v1.12.1 replace github.com/prometheus/golang_client => github.com/prometheus/golang_client v1.12.1
replace gopkg.in/yaml.v3 => gopkg.in/yaml.v3 v3.0.1 replace gopkg.in/yaml.v3 => gopkg.in/yaml.v3 v3.0.1
// This fork is based on quic-go v0.45
replace github.com/quic-go/quic-go => github.com/chungthuang/quic-go v0.45.1-0.20250128102735-2687bd175910

40
go.sum
View File

@ -7,8 +7,6 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chungthuang/quic-go v0.45.1-0.20250128102735-2687bd175910 h1:/hTvBpxBDj/3NIzTodi1oEOyNBpirvgDSPKSV7VqAZU=
github.com/chungthuang/quic-go v0.45.1-0.20250128102735-2687bd175910/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
github.com/coredns/caddy v1.1.1 h1:2eYKZT7i6yxIfGP3qLJoJ7HAsDJqYB+X68g4NYjSrE0= github.com/coredns/caddy v1.1.1 h1:2eYKZT7i6yxIfGP3qLJoJ7HAsDJqYB+X68g4NYjSrE0=
github.com/coredns/caddy v1.1.1/go.mod h1:A6ntJQlAWuQfFlsd9hvigKbo2WS0VUs2l1e2F+BawD4= github.com/coredns/caddy v1.1.1/go.mod h1:A6ntJQlAWuQfFlsd9hvigKbo2WS0VUs2l1e2F+BawD4=
github.com/coredns/coredns v1.11.3 h1:8RjnpZc42db5th84/QJKH2i137ecJdzZK1HJwhetSPk= github.com/coredns/coredns v1.11.3 h1:8RjnpZc42db5th84/QJKH2i137ecJdzZK1HJwhetSPk=
@ -175,6 +173,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
github.com/quic-go/quic-go v0.45.0 h1:OHmkQGM37luZITyTSu6ff03HP/2IrwDX1ZFiNEhSFUE=
github.com/quic-go/quic-go v0.45.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
@ -217,33 +217,33 @@ go.opentelemetry.io/proto/otlp v1.2.0 h1:pVeZGk7nXDC9O2hncA6nHldxEjm6LByfA2aN8IO
go.opentelemetry.io/proto/otlp v1.2.0/go.mod h1:gGpR8txAl5M03pDhMC79G6SdqNV26naRm/KDsgaHD8A= go.opentelemetry.io/proto/otlp v1.2.0/go.mod h1:gGpR8txAl5M03pDhMC79G6SdqNV26naRm/KDsgaHD8A=
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0= go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q= go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI= golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI=
golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8= golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -254,19 +254,19 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
@ -275,8 +275,8 @@ golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -22,7 +22,6 @@ var (
const ( const (
defaultProxyAddress = "127.0.0.1" defaultProxyAddress = "127.0.0.1"
defaultKeepAliveConnections = 100 defaultKeepAliveConnections = 100
defaultMaxActiveFlows = 0 // unlimited
SSHServerFlag = "ssh-server" SSHServerFlag = "ssh-server"
Socks5Flag = "socks5" Socks5Flag = "socks5"
ProxyConnectTimeoutFlag = "proxy-connect-timeout" ProxyConnectTimeoutFlag = "proxy-connect-timeout"
@ -47,22 +46,17 @@ const (
type WarpRoutingConfig struct { type WarpRoutingConfig struct {
ConnectTimeout config.CustomDuration `yaml:"connectTimeout" json:"connectTimeout,omitempty"` ConnectTimeout config.CustomDuration `yaml:"connectTimeout" json:"connectTimeout,omitempty"`
MaxActiveFlows uint64 `yaml:"maxActiveFlows" json:"MaxActiveFlows,omitempty"`
TCPKeepAlive config.CustomDuration `yaml:"tcpKeepAlive" json:"tcpKeepAlive,omitempty"` TCPKeepAlive config.CustomDuration `yaml:"tcpKeepAlive" json:"tcpKeepAlive,omitempty"`
} }
func NewWarpRoutingConfig(raw *config.WarpRoutingConfig) WarpRoutingConfig { func NewWarpRoutingConfig(raw *config.WarpRoutingConfig) WarpRoutingConfig {
cfg := WarpRoutingConfig{ cfg := WarpRoutingConfig{
ConnectTimeout: defaultWarpRoutingConnectTimeout, ConnectTimeout: defaultWarpRoutingConnectTimeout,
MaxActiveFlows: defaultMaxActiveFlows,
TCPKeepAlive: defaultTCPKeepAlive, TCPKeepAlive: defaultTCPKeepAlive,
} }
if raw.ConnectTimeout != nil { if raw.ConnectTimeout != nil {
cfg.ConnectTimeout = *raw.ConnectTimeout cfg.ConnectTimeout = *raw.ConnectTimeout
} }
if raw.MaxActiveFlows != nil {
cfg.MaxActiveFlows = *raw.MaxActiveFlows
}
if raw.TCPKeepAlive != nil { if raw.TCPKeepAlive != nil {
cfg.TCPKeepAlive = *raw.TCPKeepAlive cfg.TCPKeepAlive = *raw.TCPKeepAlive
} }
@ -74,9 +68,6 @@ func (c *WarpRoutingConfig) RawConfig() config.WarpRoutingConfig {
if c.ConnectTimeout.Duration != defaultWarpRoutingConnectTimeout.Duration { if c.ConnectTimeout.Duration != defaultWarpRoutingConnectTimeout.Duration {
raw.ConnectTimeout = &c.ConnectTimeout raw.ConnectTimeout = &c.ConnectTimeout
} }
if c.MaxActiveFlows != defaultMaxActiveFlows {
raw.MaxActiveFlows = &c.MaxActiveFlows
}
if c.TCPKeepAlive.Duration != defaultTCPKeepAlive.Duration { if c.TCPKeepAlive.Duration != defaultTCPKeepAlive.Duration {
raw.TCPKeepAlive = &c.TCPKeepAlive raw.TCPKeepAlive = &c.TCPKeepAlive
} }
@ -181,7 +172,6 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
} }
if flag := ProxyPortFlag; c.IsSet(flag) { if flag := ProxyPortFlag; c.IsSet(flag) {
// Note TUN-3758 , we use Int because UInt is not supported with altsrc // Note TUN-3758 , we use Int because UInt is not supported with altsrc
// nolint: gosec
proxyPort = uint(c.Int(flag)) proxyPort = uint(c.Int(flag))
} }
if flag := Http2OriginFlag; c.IsSet(flag) { if flag := Http2OriginFlag; c.IsSet(flag) {
@ -561,7 +551,7 @@ func convertToRawIPRules(ipRules []ipaccess.Rule) []config.IngressIPRule {
} }
func defaultBoolToNil(b bool) *bool { func defaultBoolToNil(b bool) *bool {
if !b { if b == false {
return nil return nil
} }

View File

@ -28,8 +28,10 @@ type icmpProxy struct {
srcFunnelTracker *packet.FunnelTracker srcFunnelTracker *packet.FunnelTracker
echoIDTracker *echoIDTracker echoIDTracker *echoIDTracker
conn *icmp.PacketConn conn *icmp.PacketConn
logger *zerolog.Logger // Response is handled in one-by-one, so encoder can be shared between funnels
idleTimeout time.Duration encoder *packet.Encoder
logger *zerolog.Logger
idleTimeout time.Duration
} }
// echoIDTracker tracks which ID has been assigned. It first loops through assignment from lastAssignment to then end, // echoIDTracker tracks which ID has been assigned. It first loops through assignment from lastAssignment to then end,
@ -112,8 +114,8 @@ func (snf echoFunnelID) String() string {
return strconv.FormatUint(uint64(snf), 10) return strconv.FormatUint(uint64(snf), 10)
} }
func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) {
conn, err := newICMPConn(listenIP) conn, err := newICMPConn(listenIP, zone)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -121,15 +123,16 @@ func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.
return &icmpProxy{ return &icmpProxy{
srcFunnelTracker: packet.NewFunnelTracker(), srcFunnelTracker: packet.NewFunnelTracker(),
echoIDTracker: newEchoIDTracker(), echoIDTracker: newEchoIDTracker(),
encoder: packet.NewEncoder(),
conn: conn, conn: conn,
logger: logger, logger: logger,
idleTimeout: idleTimeout, idleTimeout: idleTimeout,
}, nil }, nil
} }
func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICMPResponder) error { func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *packetResponder) error {
_, span := responder.RequestSpan(ctx, pk) _, span := responder.requestSpan(ctx, pk)
defer responder.ExportSpan() defer responder.exportSpan()
originalEcho, err := getICMPEcho(pk.Message) originalEcho, err := getICMPEcho(pk.Message)
if err != nil { if err != nil {
@ -151,7 +154,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICM
} }
span.SetAttributes(attribute.Int("assignedEchoID", int(assignedEchoID))) span.SetAttributes(attribute.Int("assignedEchoID", int(assignedEchoID)))
shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder, pk, originalEcho.ID) shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder.datagramMuxer, pk, originalEcho.ID)
newFunnelFunc := func() (packet.Funnel, error) { newFunnelFunc := func() (packet.Funnel, error) {
originalEcho, err := getICMPEcho(pk.Message) originalEcho, err := getICMPEcho(pk.Message)
if err != nil { if err != nil {
@ -161,7 +164,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICM
ip.echoIDTracker.release(echoIDTrackerKey, assignedEchoID) ip.echoIDTracker.release(echoIDTrackerKey, assignedEchoID)
return nil return nil
} }
icmpFlow := newICMPEchoFlow(pk.Src, closeCallback, ip.conn, responder, int(assignedEchoID), originalEcho.ID) icmpFlow := newICMPEchoFlow(pk.Src, closeCallback, ip.conn, responder, int(assignedEchoID), originalEcho.ID, ip.encoder)
return icmpFlow, nil return icmpFlow, nil
} }
funnelID := echoFunnelID(assignedEchoID) funnelID := echoFunnelID(assignedEchoID)
@ -262,8 +265,8 @@ func (ip *icmpProxy) sendReply(ctx context.Context, reply *echoReply) error {
return err return err
} }
_, span := icmpFlow.responder.ReplySpan(ctx, ip.logger) _, span := icmpFlow.responder.replySpan(ctx, ip.logger)
defer icmpFlow.responder.ExportSpan() defer icmpFlow.responder.exportSpan()
if err := icmpFlow.returnToSrc(reply); err != nil { if err := icmpFlow.returnToSrc(reply); err != nil {
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)

View File

@ -18,7 +18,7 @@ var errICMPProxyNotImplemented = fmt.Errorf("ICMP proxy is not implemented on %s
type icmpProxy struct{} type icmpProxy struct{}
func (ip icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICMPResponder) error { func (ip icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *packetResponder) error {
return errICMPProxyNotImplemented return errICMPProxyNotImplemented
} }
@ -26,6 +26,6 @@ func (ip *icmpProxy) Serve(ctx context.Context) error {
return errICMPProxyNotImplemented return errICMPProxyNotImplemented
} }
func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) {
return nil, errICMPProxyNotImplemented return nil, errICMPProxyNotImplemented
} }

View File

@ -37,23 +37,25 @@ var (
type icmpProxy struct { type icmpProxy struct {
srcFunnelTracker *packet.FunnelTracker srcFunnelTracker *packet.FunnelTracker
listenIP netip.Addr listenIP netip.Addr
ipv6Zone string
logger *zerolog.Logger logger *zerolog.Logger
idleTimeout time.Duration idleTimeout time.Duration
} }
func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) {
if err := testPermission(listenIP, logger); err != nil { if err := testPermission(listenIP, zone, logger); err != nil {
return nil, err return nil, err
} }
return &icmpProxy{ return &icmpProxy{
srcFunnelTracker: packet.NewFunnelTracker(), srcFunnelTracker: packet.NewFunnelTracker(),
listenIP: listenIP, listenIP: listenIP,
ipv6Zone: zone,
logger: logger, logger: logger,
idleTimeout: idleTimeout, idleTimeout: idleTimeout,
}, nil }, nil
} }
func testPermission(listenIP netip.Addr, logger *zerolog.Logger) error { func testPermission(listenIP netip.Addr, zone string, logger *zerolog.Logger) error {
// Opens a non-privileged ICMP socket. On Linux the group ID of the process needs to be in ping_group_range // Opens a non-privileged ICMP socket. On Linux the group ID of the process needs to be in ping_group_range
// Only check ping_group_range once for IPv4 // Only check ping_group_range once for IPv4
if listenIP.Is4() { if listenIP.Is4() {
@ -62,7 +64,7 @@ func testPermission(listenIP netip.Addr, logger *zerolog.Logger) error {
return err return err
} }
} }
conn, err := newICMPConn(listenIP) conn, err := newICMPConn(listenIP, zone)
if err != nil { if err != nil {
return err return err
} }
@ -96,9 +98,9 @@ func checkInPingGroup() error {
return fmt.Errorf("did not find group range in %s", pingGroupPath) return fmt.Errorf("did not find group range in %s", pingGroupPath)
} }
func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICMPResponder) error { func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *packetResponder) error {
ctx, span := responder.RequestSpan(ctx, pk) ctx, span := responder.requestSpan(ctx, pk)
defer responder.ExportSpan() defer responder.exportSpan()
originalEcho, err := getICMPEcho(pk.Message) originalEcho, err := getICMPEcho(pk.Message)
if err != nil { if err != nil {
@ -107,9 +109,9 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICM
} }
observeICMPRequest(ip.logger, span, pk.Src.String(), pk.Dst.String(), originalEcho.ID, originalEcho.Seq) observeICMPRequest(ip.logger, span, pk.Src.String(), pk.Dst.String(), originalEcho.ID, originalEcho.Seq)
shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder, pk, originalEcho.ID) shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder.datagramMuxer, pk, originalEcho.ID)
newFunnelFunc := func() (packet.Funnel, error) { newFunnelFunc := func() (packet.Funnel, error) {
conn, err := newICMPConn(ip.listenIP) conn, err := newICMPConn(ip.listenIP, ip.ipv6Zone)
if err != nil { if err != nil {
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)
return nil, errors.Wrap(err, "failed to open ICMP socket") return nil, errors.Wrap(err, "failed to open ICMP socket")
@ -125,7 +127,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICM
span.SetAttributes(attribute.Int("port", localUDPAddr.Port)) span.SetAttributes(attribute.Int("port", localUDPAddr.Port))
echoID := localUDPAddr.Port echoID := localUDPAddr.Port
icmpFlow := newICMPEchoFlow(pk.Src, closeCallback, conn, responder, echoID, originalEcho.ID) icmpFlow := newICMPEchoFlow(pk.Src, closeCallback, conn, responder, echoID, originalEcho.ID, packet.NewEncoder())
return icmpFlow, nil return icmpFlow, nil
} }
funnelID := flow3Tuple{ funnelID := flow3Tuple{
@ -179,8 +181,8 @@ func (ip *icmpProxy) listenResponse(ctx context.Context, flow *icmpEchoFlow) {
// Listens for ICMP response and handles error logging // Listens for ICMP response and handles error logging
func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf []byte) (done bool) { func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf []byte) (done bool) {
_, span := flow.responder.ReplySpan(ctx, ip.logger) _, span := flow.responder.replySpan(ctx, ip.logger)
defer flow.responder.ExportSpan() defer flow.responder.exportSpan()
span.SetAttributes( span.SetAttributes(
attribute.Int("originalEchoID", flow.originalEchoID), attribute.Int("originalEchoID", flow.originalEchoID),

View File

@ -18,11 +18,15 @@ import (
) )
// Opens a non-privileged ICMP socket on Linux and Darwin // Opens a non-privileged ICMP socket on Linux and Darwin
func newICMPConn(listenIP netip.Addr) (*icmp.PacketConn, error) { func newICMPConn(listenIP netip.Addr, zone string) (*icmp.PacketConn, error) {
if listenIP.Is4() { if listenIP.Is4() {
return icmp.ListenPacket("udp4", listenIP.String()) return icmp.ListenPacket("udp4", listenIP.String())
} }
return icmp.ListenPacket("udp6", listenIP.String()) listenAddr := listenIP.String()
if zone != "" {
listenAddr = listenAddr + "%" + zone
}
return icmp.ListenPacket("udp6", listenAddr)
} }
func netipAddr(addr net.Addr) (netip.Addr, bool) { func netipAddr(addr net.Addr) (netip.Addr, bool) {
@ -30,8 +34,7 @@ func netipAddr(addr net.Addr) (netip.Addr, bool) {
if !ok { if !ok {
return netip.Addr{}, false return netip.Addr{}, false
} }
return netip.AddrFromSlice(udpAddr.IP)
return udpAddr.AddrPort().Addr(), true
} }
type flow3Tuple struct { type flow3Tuple struct {
@ -47,12 +50,14 @@ type icmpEchoFlow struct {
closed *atomic.Bool closed *atomic.Bool
src netip.Addr src netip.Addr
originConn *icmp.PacketConn originConn *icmp.PacketConn
responder ICMPResponder responder *packetResponder
assignedEchoID int assignedEchoID int
originalEchoID int originalEchoID int
// it's up to the user to ensure respEncoder is not used concurrently
respEncoder *packet.Encoder
} }
func newICMPEchoFlow(src netip.Addr, closeCallback func() error, originConn *icmp.PacketConn, responder ICMPResponder, assignedEchoID, originalEchoID int) *icmpEchoFlow { func newICMPEchoFlow(src netip.Addr, closeCallback func() error, originConn *icmp.PacketConn, responder *packetResponder, assignedEchoID, originalEchoID int, respEncoder *packet.Encoder) *icmpEchoFlow {
return &icmpEchoFlow{ return &icmpEchoFlow{
ActivityTracker: packet.NewActivityTracker(), ActivityTracker: packet.NewActivityTracker(),
closeCallback: closeCallback, closeCallback: closeCallback,
@ -62,6 +67,7 @@ func newICMPEchoFlow(src netip.Addr, closeCallback func() error, originConn *icm
responder: responder, responder: responder,
assignedEchoID: assignedEchoID, assignedEchoID: assignedEchoID,
originalEchoID: originalEchoID, originalEchoID: originalEchoID,
respEncoder: respEncoder,
} }
} }
@ -133,7 +139,11 @@ func (ief *icmpEchoFlow) returnToSrc(reply *echoReply) error {
}, },
Message: reply.msg, Message: reply.msg,
} }
return ief.responder.ReturnPacket(&pk) serializedPacket, err := ief.respEncoder.Encode(&pk)
if err != nil {
return err
}
return ief.responder.returnPacket(serializedPacket)
} }
type echoReply struct { type echoReply struct {
@ -174,7 +184,7 @@ func toICMPEchoFlow(funnel packet.Funnel) (*icmpEchoFlow, error) {
return icmpFlow, nil return icmpFlow, nil
} }
func createShouldReplaceFunnelFunc(logger *zerolog.Logger, responder ICMPResponder, pk *packet.ICMP, originalEchoID int) func(packet.Funnel) bool { func createShouldReplaceFunnelFunc(logger *zerolog.Logger, muxer muxer, pk *packet.ICMP, originalEchoID int) func(packet.Funnel) bool {
return func(existing packet.Funnel) bool { return func(existing packet.Funnel) bool {
existingFlow, err := toICMPEchoFlow(existing) existingFlow, err := toICMPEchoFlow(existing)
if err != nil { if err != nil {
@ -189,7 +199,7 @@ func createShouldReplaceFunnelFunc(logger *zerolog.Logger, responder ICMPRespond
// If the existing flow has a different muxer, there's a new quic connection where return packets should be // If the existing flow has a different muxer, there's a new quic connection where return packets should be
// routed. Otherwise, return packets will be send to the first observed incoming connection, rather than the // routed. Otherwise, return packets will be send to the first observed incoming connection, rather than the
// most recently observed connection. // most recently observed connection.
if existingFlow.responder.ConnectionIndex() != responder.ConnectionIndex() { if existingFlow.responder.datagramMuxer != muxer {
logger.Debug(). logger.Debug().
Str("src", pk.Src.String()). Str("src", pk.Src.String()).
Str("dst", pk.Dst.String()). Str("dst", pk.Dst.String()).

View File

@ -27,7 +27,7 @@ func TestFunnelIdleTimeout(t *testing.T) {
startSeq = 8129 startSeq = 8129
) )
logger := zerolog.New(os.Stderr) logger := zerolog.New(os.Stderr)
proxy, err := newICMPProxy(localhostIP, &logger, idleTimeout) proxy, err := newICMPProxy(localhostIP, "", &logger, idleTimeout)
require.NoError(t, err) require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -56,19 +56,24 @@ func TestFunnelIdleTimeout(t *testing.T) {
}, },
} }
muxer := newMockMuxer(0) muxer := newMockMuxer(0)
responder := newPacketResponder(muxer, 0, packet.NewEncoder()) responder := packetResponder{
require.NoError(t, proxy.Request(ctx, &pk, responder)) datagramMuxer: muxer,
}
require.NoError(t, proxy.Request(ctx, &pk, &responder))
validateEchoFlow(t, <-muxer.cfdToEdge, &pk) validateEchoFlow(t, <-muxer.cfdToEdge, &pk)
// Send second request, should reuse the funnel // Send second request, should reuse the funnel
require.NoError(t, proxy.Request(ctx, &pk, responder)) require.NoError(t, proxy.Request(ctx, &pk, &packetResponder{
datagramMuxer: muxer,
}))
validateEchoFlow(t, <-muxer.cfdToEdge, &pk) validateEchoFlow(t, <-muxer.cfdToEdge, &pk)
// New muxer on a different connection should use a new flow
time.Sleep(idleTimeout * 2) time.Sleep(idleTimeout * 2)
newMuxer := newMockMuxer(0) newMuxer := newMockMuxer(0)
newResponder := newPacketResponder(newMuxer, 1, packet.NewEncoder()) newResponder := packetResponder{
require.NoError(t, proxy.Request(ctx, &pk, newResponder)) datagramMuxer: newMuxer,
}
require.NoError(t, proxy.Request(ctx, &pk, &newResponder))
validateEchoFlow(t, <-newMuxer.cfdToEdge, &pk) validateEchoFlow(t, <-newMuxer.cfdToEdge, &pk)
time.Sleep(idleTimeout * 2) time.Sleep(idleTimeout * 2)
@ -85,7 +90,7 @@ func TestReuseFunnel(t *testing.T) {
startSeq = 8129 startSeq = 8129
) )
logger := zerolog.New(os.Stderr) logger := zerolog.New(os.Stderr)
proxy, err := newICMPProxy(localhostIP, &logger, idleTimeout) proxy, err := newICMPProxy(localhostIP, "", &logger, idleTimeout)
require.NoError(t, err) require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -119,14 +124,18 @@ func TestReuseFunnel(t *testing.T) {
originalEchoID: echoID, originalEchoID: echoID,
} }
muxer := newMockMuxer(0) muxer := newMockMuxer(0)
responder := newPacketResponder(muxer, 0, packet.NewEncoder()) responder := packetResponder{
require.NoError(t, proxy.Request(ctx, &pk, responder)) datagramMuxer: muxer,
}
require.NoError(t, proxy.Request(ctx, &pk, &responder))
validateEchoFlow(t, <-muxer.cfdToEdge, &pk) validateEchoFlow(t, <-muxer.cfdToEdge, &pk)
funnel1, found := getFunnel(t, proxy, tuple) funnel1, found := getFunnel(t, proxy, tuple)
require.True(t, found) require.True(t, found)
// Send second request, should reuse the funnel // Send second request, should reuse the funnel
require.NoError(t, proxy.Request(ctx, &pk, responder)) require.NoError(t, proxy.Request(ctx, &pk, &packetResponder{
datagramMuxer: muxer,
}))
validateEchoFlow(t, <-muxer.cfdToEdge, &pk) validateEchoFlow(t, <-muxer.cfdToEdge, &pk)
funnel2, found := getFunnel(t, proxy, tuple) funnel2, found := getFunnel(t, proxy, tuple)
require.True(t, found) require.True(t, found)

View File

@ -13,6 +13,7 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"runtime/debug" "runtime/debug"
"sync"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
@ -221,9 +222,11 @@ type icmpProxy struct {
// This is a ICMPv6 if srcSocketAddr is not nil // This is a ICMPv6 if srcSocketAddr is not nil
srcSocketAddr *sockAddrIn6 srcSocketAddr *sockAddrIn6
logger *zerolog.Logger logger *zerolog.Logger
// A pool of reusable *packet.Encoder
encoderPool sync.Pool
} }
func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) { func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idleTimeout time.Duration) (*icmpProxy, error) {
var ( var (
srcSocketAddr *sockAddrIn6 srcSocketAddr *sockAddrIn6
handle uintptr handle uintptr
@ -247,6 +250,11 @@ func newICMPProxy(listenIP netip.Addr, logger *zerolog.Logger, idleTimeout time.
handle: handle, handle: handle,
srcSocketAddr: srcSocketAddr, srcSocketAddr: srcSocketAddr,
logger: logger, logger: logger,
encoderPool: sync.Pool{
New: func() any {
return packet.NewEncoder()
},
},
}, nil }, nil
} }
@ -259,15 +267,15 @@ func (ip *icmpProxy) Serve(ctx context.Context) error {
// Request sends an ICMP echo request and wait for a reply or timeout. // Request sends an ICMP echo request and wait for a reply or timeout.
// The async version of Win32 APIs take a callback whose memory is not garbage collected, so we use the synchronous version. // The async version of Win32 APIs take a callback whose memory is not garbage collected, so we use the synchronous version.
// It's possible that a slow request will block other requests, so we set the timeout to only 1s. // It's possible that a slow request will block other requests, so we set the timeout to only 1s.
func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICMPResponder) error { func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *packetResponder) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
ip.logger.Error().Interface("error", r).Msgf("Recover panic from sending icmp request/response, error %s", debug.Stack()) ip.logger.Error().Interface("error", r).Msgf("Recover panic from sending icmp request/response, error %s", debug.Stack())
} }
}() }()
_, requestSpan := responder.RequestSpan(ctx, pk) _, requestSpan := responder.requestSpan(ctx, pk)
defer responder.ExportSpan() defer responder.exportSpan()
echo, err := getICMPEcho(pk.Message) echo, err := getICMPEcho(pk.Message)
if err != nil { if err != nil {
@ -282,9 +290,9 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICM
return err return err
} }
tracing.End(requestSpan) tracing.End(requestSpan)
responder.ExportSpan() responder.exportSpan()
_, replySpan := responder.ReplySpan(ctx, ip.logger) _, replySpan := responder.replySpan(ctx, ip.logger)
err = ip.handleEchoReply(pk, echo, resp, responder) err = ip.handleEchoReply(pk, echo, resp, responder)
if err != nil { if err != nil {
ip.logger.Err(err).Msg("Failed to send ICMP reply") ip.logger.Err(err).Msg("Failed to send ICMP reply")
@ -300,7 +308,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder ICM
return nil return nil
} }
func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, resp echoResp, responder ICMPResponder) error { func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, resp echoResp, responder *packetResponder) error {
var replyType icmp.Type var replyType icmp.Type
if request.Dst.Is4() { if request.Dst.Is4() {
replyType = ipv4.ICMPTypeEchoReply replyType = ipv4.ICMPTypeEchoReply
@ -325,7 +333,21 @@ func (ip *icmpProxy) handleEchoReply(request *packet.ICMP, echoReq *icmp.Echo, r
}, },
}, },
} }
return responder.ReturnPacket(&pk)
cachedEncoder := ip.encoderPool.Get()
// The encoded packet is a slice to of the encoder, so we shouldn't return the encoder back to the pool until
// the encoded packet is sent.
defer ip.encoderPool.Put(cachedEncoder)
encoder, ok := cachedEncoder.(*packet.Encoder)
if !ok {
return fmt.Errorf("encoderPool returned %T, expect *packet.Encoder", cachedEncoder)
}
serializedPacket, err := encoder.Encode(&pk)
if err != nil {
return err
}
return responder.returnPacket(serializedPacket)
} }
func (ip *icmpProxy) icmpEchoRoundtrip(dst netip.Addr, echo *icmp.Echo) (echoResp, error) { func (ip *icmpProxy) icmpEchoRoundtrip(dst netip.Addr, echo *icmp.Echo) (echoResp, error) {

Some files were not shown because too many files have changed in this diff Show More