Compare commits
55 Commits
Author | SHA1 | Date |
---|---|---|
|
d8a066628b | |
|
553e77e061 | |
|
8f94f54ec7 | |
|
2827b2fe8f | |
|
6dc8ed710e | |
|
e0b1ac0d05 | |
|
e7c5eb54af | |
|
cfec602fa7 | |
|
6fceb94998 | |
|
cf817f7036 | |
|
c8724a290a | |
|
e7586153be | |
|
11777db304 | |
|
3f6b1f24d0 | |
|
a4105e8708 | |
|
6496322bee | |
|
906452a9c9 | |
|
d969fdec3e | |
|
7336a1a4d6 | |
|
df5dafa6d7 | |
|
c19f919428 | |
|
b187879e69 | |
|
2feccd772c | |
|
90176a79b4 | |
|
9695829e5b | |
|
31a870b291 | |
|
bfdb0c76dc | |
|
45f67c23fd | |
|
0f1bfe99ce | |
|
18eecaf151 | |
|
4eb0f8ce5f | |
|
8c2eda16c1 | |
|
8bfe111cab | |
|
bf4954e96a | |
|
8918b6729e | |
|
25c3f676f4 | |
|
a1963aed80 | |
|
ac34f94d42 | |
|
d8c7f1c1ec | |
|
3b522a27cf | |
|
5cfe9bef79 | |
|
2714d10d62 | |
|
ac57ed9709 | |
|
c6901551e7 | |
|
9bc6cbd06d | |
|
bc9c5d2e6e | |
|
1859d742a8 | |
|
8ed19222b9 | |
|
02e7ffd5b7 | |
|
ba9f28ef43 | |
|
77b99cf5fe | |
|
d74ca97b51 | |
|
29f0cf354c | |
|
e7dcb6edca | |
|
14cf0eff1d |
|
@ -0,0 +1,89 @@
|
|||
linters:
|
||||
enable:
|
||||
# Some of the linters below are commented out. We should uncomment and start running them, but they return
|
||||
# too many problems to fix in one commit. Something for later.
|
||||
- asasalint # Check for pass []any as any in variadic func(...any).
|
||||
- asciicheck # Checks that all code identifiers does not have non-ASCII symbols in the name.
|
||||
- bidichk # Checks for dangerous unicode character sequences.
|
||||
- bodyclose # Checks whether HTTP response body is closed successfully.
|
||||
- decorder # Check declaration order and count of types, constants, variables and functions.
|
||||
- dogsled # Checks assignments with too many blank identifiers (e.g. x, , , _, := f()).
|
||||
- dupl # Tool for code clone detection.
|
||||
- dupword # Checks for duplicate words in the source code.
|
||||
- durationcheck # Check for two durations multiplied together.
|
||||
- errcheck # Errcheck is a program for checking for unchecked errors in Go code. These unchecked errors can be critical bugs in some cases.
|
||||
- errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error.
|
||||
- exhaustive # Check exhaustiveness of enum switch statements.
|
||||
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification.
|
||||
- goimports # Check import statements are formatted according to the 'goimport' command. Reformat imports in autofix mode.
|
||||
- gosec # Inspects source code for security problems.
|
||||
- gosimple # Linter for Go source code that specializes in simplifying code.
|
||||
- govet # Vet examines Go source code and reports suspicious constructs. It is roughly the same as 'go vet' and uses its passes.
|
||||
- ineffassign # Detects when assignments to existing variables are not used.
|
||||
- importas # Enforces consistent import aliases.
|
||||
- misspell # Finds commonly misspelled English words.
|
||||
- prealloc # Finds slice declarations that could potentially be pre-allocated.
|
||||
- promlinter # Check Prometheus metrics naming via promlint.
|
||||
- sloglint # Ensure consistent code style when using log/slog.
|
||||
- sqlclosecheck # Checks that sql.Rows, sql.Stmt, sqlx.NamedStmt, pgx.Query are closed.
|
||||
- staticcheck # It's a set of rules from staticcheck. It's not the same thing as the staticcheck binary.
|
||||
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
|
||||
- testableexamples # Linter checks if examples are testable (have an expected output).
|
||||
- testifylint # Checks usage of github.com/stretchr/testify.
|
||||
- tparallel # Tparallel detects inappropriate usage of t.Parallel() method in your Go test codes.
|
||||
- unconvert # Remove unnecessary type conversions.
|
||||
- unused # Checks Go code for unused constants, variables, functions and types.
|
||||
- wastedassign # Finds wasted assignment statements.
|
||||
- whitespace # Whitespace is a linter that checks for unnecessary newlines at the start and end of functions, if, for, etc.
|
||||
- zerologlint # Detects the wrong usage of zerolog that a user forgets to dispatch with Send or Msg.
|
||||
# Other linters are disabled, list of all is here: https://golangci-lint.run/usage/linters/
|
||||
run:
|
||||
timeout: 5m
|
||||
modules-download-mode: vendor
|
||||
|
||||
# output configuration options
|
||||
output:
|
||||
formats:
|
||||
- format: 'colored-line-number'
|
||||
print-issued-lines: true
|
||||
print-linter-name: true
|
||||
|
||||
issues:
|
||||
# Maximum issues count per one linter.
|
||||
# Set to 0 to disable.
|
||||
# Default: 50
|
||||
max-issues-per-linter: 50
|
||||
# Maximum count of issues with the same text.
|
||||
# Set to 0 to disable.
|
||||
# Default: 3
|
||||
max-same-issues: 15
|
||||
# Show only new issues: if there are unstaged changes or untracked files,
|
||||
# only those changes are analyzed, else only changes in HEAD~ are analyzed.
|
||||
# It's a super-useful option for integration of golangci-lint into existing large codebase.
|
||||
# It's not practical to fix all existing issues at the moment of integration:
|
||||
# much better don't allow issues in new code.
|
||||
#
|
||||
# Default: false
|
||||
new: true
|
||||
# Show only new issues created after git revision `REV`.
|
||||
# Default: ""
|
||||
new-from-rev: ac34f94d423273c8fa8fdbb5f2ac60e55f2c77d5
|
||||
# Show issues in any part of update files (requires new-from-rev or new-from-patch).
|
||||
# Default: false
|
||||
whole-files: true
|
||||
# Which dirs to exclude: issues from them won't be reported.
|
||||
# Can use regexp here: `generated.*`, regexp is applied on full path,
|
||||
# including the path prefix if one is set.
|
||||
# Default dirs are skipped independently of this option's value (see exclude-dirs-use-default).
|
||||
# "/" will be replaced by current OS file path separator to properly work on Windows.
|
||||
# Default: []
|
||||
exclude-dirs:
|
||||
- vendor
|
||||
|
||||
linters-settings:
|
||||
# Check exhaustiveness of enum switch statements.
|
||||
exhaustive:
|
||||
# Presence of "default" case in switch statements satisfies exhaustiveness,
|
||||
# even if all enum members are not listed.
|
||||
# Default: false
|
||||
default-signifies-exhaustive: true
|
|
@ -3,6 +3,6 @@
|
|||
cd /tmp
|
||||
git clone -q https://github.com/cloudflare/go
|
||||
cd go/src
|
||||
# https://github.com/cloudflare/go/tree/f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38 is version go1.22.5-devel-cf
|
||||
git checkout -q f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38
|
||||
# https://github.com/cloudflare/go/tree/af19da5605ca11f85776ef7af3384a02a315a52b is version go1.22.5-devel-cf
|
||||
git checkout -q af19da5605ca11f85776ef7af3384a02a315a52b
|
||||
./make.bash
|
||||
|
|
|
@ -22,6 +22,7 @@ TARGET_DIRECTORY=".build"
|
|||
BINARY_NAME="cloudflared"
|
||||
VERSION=$(git describe --tags --always --dirty="-dev")
|
||||
PRODUCT="cloudflared"
|
||||
APPLE_CA_CERT="apple_dev_ca.cert"
|
||||
CODE_SIGN_PRIV="code_sign.p12"
|
||||
CODE_SIGN_CERT="code_sign.cer"
|
||||
INSTALLER_PRIV="installer.p12"
|
||||
|
@ -35,15 +36,56 @@ mkdir -p ../src/github.com/cloudflare/
|
|||
cp -r . ../src/github.com/cloudflare/cloudflared
|
||||
cd ../src/github.com/cloudflare/cloudflared
|
||||
|
||||
# Add code signing private key to the key chain
|
||||
if [[ ! -z "$CFD_CODE_SIGN_KEY" ]]; then
|
||||
if [[ ! -z "$CFD_CODE_SIGN_PASS" ]]; then
|
||||
# write private key to disk and then import it keychain
|
||||
echo -n -e ${CFD_CODE_SIGN_KEY} | base64 -D > ${CODE_SIGN_PRIV}
|
||||
# Imports certificates to the Apple KeyChain
|
||||
import_certificate() {
|
||||
local CERTIFICATE_NAME=$1
|
||||
local CERTIFICATE_ENV_VAR=$2
|
||||
local CERTIFICATE_FILE_NAME=$3
|
||||
|
||||
echo "Importing $CERTIFICATE_NAME"
|
||||
|
||||
if [[ ! -z "$CERTIFICATE_ENV_VAR" ]]; then
|
||||
# write certificate to disk and then import it keychain
|
||||
echo -n -e ${CERTIFICATE_ENV_VAR} | base64 -D > ${CERTIFICATE_FILE_NAME}
|
||||
# we set || true here and for every `security import invoke` because the "duplicate SecKeychainItemImport" error
|
||||
# will cause set -e to exit 1. It is okay we do this because we deliberately handle this error in the lines below.
|
||||
out=$(security import ${CODE_SIGN_PRIV} -A -P "${CFD_CODE_SIGN_PASS}" 2>&1) || true
|
||||
exitcode=$?
|
||||
local out=$(security import ${CERTIFICATE_FILE_NAME} -A 2>&1) || true
|
||||
local exitcode=$?
|
||||
# delete the certificate from disk
|
||||
rm -rf ${CERTIFICATE_FILE_NAME}
|
||||
if [ -n "$out" ]; then
|
||||
if [ $exitcode -eq 0 ]; then
|
||||
echo "$out"
|
||||
else
|
||||
if [ "$out" != "${SEC_DUP_MSG}" ]; then
|
||||
echo "$out" >&2
|
||||
exit $exitcode
|
||||
else
|
||||
echo "already imported code signing certificate"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# Imports private keys to the Apple KeyChain
|
||||
import_private_keys() {
|
||||
local PRIVATE_KEY_NAME=$1
|
||||
local PRIVATE_KEY_ENV_VAR=$2
|
||||
local PRIVATE_KEY_FILE_NAME=$3
|
||||
local PRIVATE_KEY_PASS=$4
|
||||
|
||||
echo "Importing $PRIVATE_KEY_NAME"
|
||||
|
||||
if [[ ! -z "$PRIVATE_KEY_ENV_VAR" ]]; then
|
||||
if [[ ! -z "$PRIVATE_KEY_PASS" ]]; then
|
||||
# write private key to disk and then import it keychain
|
||||
echo -n -e ${PRIVATE_KEY_ENV_VAR} | base64 -D > ${PRIVATE_KEY_FILE_NAME}
|
||||
# we set || true here and for every `security import invoke` because the "duplicate SecKeychainItemImport" error
|
||||
# will cause set -e to exit 1. It is okay we do this because we deliberately handle this error in the lines below.
|
||||
local out=$(security import ${PRIVATE_KEY_FILE_NAME} -A -P "${PRIVATE_KEY_PASS}" 2>&1) || true
|
||||
local exitcode=$?
|
||||
rm -rf ${PRIVATE_KEY_FILE_NAME}
|
||||
if [ -n "$out" ]; then
|
||||
if [ $exitcode -eq 0 ]; then
|
||||
echo "$out"
|
||||
|
@ -54,72 +96,24 @@ if [[ ! -z "$CFD_CODE_SIGN_KEY" ]]; then
|
|||
fi
|
||||
fi
|
||||
fi
|
||||
rm ${CODE_SIGN_PRIV}
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# Add Apple Root Developer certificate to the key chain
|
||||
import_certificate "Apple Developer CA" "${APPLE_DEV_CA_CERT}" "${APPLE_CA_CERT}"
|
||||
|
||||
# Add code signing private key to the key chain
|
||||
import_private_keys "Developer ID Application" "${CFD_CODE_SIGN_KEY}" "${CODE_SIGN_PRIV}" "${CFD_CODE_SIGN_PASS}"
|
||||
|
||||
# Add code signing certificate to the key chain
|
||||
if [[ ! -z "$CFD_CODE_SIGN_CERT" ]]; then
|
||||
# write certificate to disk and then import it keychain
|
||||
echo -n -e ${CFD_CODE_SIGN_CERT} | base64 -D > ${CODE_SIGN_CERT}
|
||||
out1=$(security import ${CODE_SIGN_CERT} -A 2>&1) || true
|
||||
exitcode1=$?
|
||||
if [ -n "$out1" ]; then
|
||||
if [ $exitcode1 -eq 0 ]; then
|
||||
echo "$out1"
|
||||
else
|
||||
if [ "$out1" != "${SEC_DUP_MSG}" ]; then
|
||||
echo "$out1" >&2
|
||||
exit $exitcode1
|
||||
else
|
||||
echo "already imported code signing certificate"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
rm ${CODE_SIGN_CERT}
|
||||
fi
|
||||
import_certificate "Developer ID Application" "${CFD_CODE_SIGN_CERT}" "${CODE_SIGN_CERT}"
|
||||
|
||||
# Add package signing private key to the key chain
|
||||
if [[ ! -z "$CFD_INSTALLER_KEY" ]]; then
|
||||
if [[ ! -z "$CFD_INSTALLER_PASS" ]]; then
|
||||
# write private key to disk and then import it into the keychain
|
||||
echo -n -e ${CFD_INSTALLER_KEY} | base64 -D > ${INSTALLER_PRIV}
|
||||
out2=$(security import ${INSTALLER_PRIV} -A -P "${CFD_INSTALLER_PASS}" 2>&1) || true
|
||||
exitcode2=$?
|
||||
if [ -n "$out2" ]; then
|
||||
if [ $exitcode2 -eq 0 ]; then
|
||||
echo "$out2"
|
||||
else
|
||||
if [ "$out2" != "${SEC_DUP_MSG}" ]; then
|
||||
echo "$out2" >&2
|
||||
exit $exitcode2
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
rm ${INSTALLER_PRIV}
|
||||
fi
|
||||
fi
|
||||
import_private_keys "Developer ID Installer" "${CFD_INSTALLER_KEY}" "${INSTALLER_PRIV}" "${CFD_INSTALLER_PASS}"
|
||||
|
||||
# Add package signing certificate to the key chain
|
||||
if [[ ! -z "$CFD_INSTALLER_CERT" ]]; then
|
||||
# write certificate to disk and then import it keychain
|
||||
echo -n -e ${CFD_INSTALLER_CERT} | base64 -D > ${INSTALLER_CERT}
|
||||
out3=$(security import ${INSTALLER_CERT} -A 2>&1) || true
|
||||
exitcode3=$?
|
||||
if [ -n "$out3" ]; then
|
||||
if [ $exitcode3 -eq 0 ]; then
|
||||
echo "$out3"
|
||||
else
|
||||
if [ "$out3" != "${SEC_DUP_MSG}" ]; then
|
||||
echo "$out3" >&2
|
||||
exit $exitcode3
|
||||
else
|
||||
echo "already imported installer certificate"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
rm ${INSTALLER_CERT}
|
||||
fi
|
||||
import_certificate "Developer ID Installer" "${CFD_INSTALLER_CERT}" "${INSTALLER_CERT}"
|
||||
|
||||
# get the code signing certificate name
|
||||
if [[ ! -z "$CFD_CODE_SIGN_NAME" ]]; then
|
||||
|
|
|
@ -9,8 +9,8 @@ Set-Location "$Env:Temp"
|
|||
git clone -q https://github.com/cloudflare/go
|
||||
Write-Output "Building go..."
|
||||
cd go/src
|
||||
# https://github.com/cloudflare/go/tree/f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38 is version go1.22.5-devel-cf
|
||||
git checkout -q f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38
|
||||
# https://github.com/cloudflare/go/tree/af19da5605ca11f85776ef7af3384a02a315a52b is version go1.22.5-devel-cf
|
||||
git checkout -q af19da5605ca11f85776ef7af3384a02a315a52b
|
||||
& ./make.bat
|
||||
|
||||
Write-Output "Installed"
|
||||
|
|
12
CHANGES.md
12
CHANGES.md
|
@ -1,3 +1,15 @@
|
|||
## 2025.1.1
|
||||
### New Features
|
||||
- This release introduces the use of new Post Quantum curves and the ability to use Post Quantum curves when running tunnels with the QUIC protocol this applies to non-FIPS and FIPS builds.
|
||||
|
||||
## 2024.12.2
|
||||
### New Features
|
||||
- This release introduces the ability to collect troubleshooting information from one instance of cloudflared running on the local machine. The command can be executed as `cloudflared tunnel diag`.
|
||||
|
||||
## 2024.12.1
|
||||
### Notices
|
||||
- The use of the `--metrics` is still honoured meaning that if this flag is set the metrics server will try to bind it, however, this version includes a change that makes the metrics server bind to a port with a semi-deterministic approach. If the metrics flag is not present the server will bind to the first available port of the range 20241 to 20245. In case of all ports being unavailable then the fallback is to bind to a random port.
|
||||
|
||||
## 2024.10.0
|
||||
### Bug Fixes
|
||||
- We fixed a bug related to `--grace-period`. Tunnels that use QUIC as transport weren't abiding by this waiting period before forcefully closing the connections to the edge. From now on, both QUIC and HTTP2 tunnels will wait for either the grace period to end (defaults to 30 seconds) or until the last in-flight request is handled. Users that wish to maintain the previous behavior should set `--grace-period` to 0 if `--protocol` is set to `quic`. This will force `cloudflared` to shutdown as soon as either SIGTERM or SIGINT is received.
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
# use a builder image for building cloudflare
|
||||
ARG TARGET_GOOS
|
||||
ARG TARGET_GOARCH
|
||||
FROM golang:1.22.5 as builder
|
||||
FROM golang:1.22.10 as builder
|
||||
ENV GO111MODULE=on \
|
||||
CGO_ENABLED=0 \
|
||||
TARGET_GOOS=${TARGET_GOOS} \
|
||||
TARGET_GOARCH=${TARGET_GOARCH} \
|
||||
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
|
||||
# which changes how cloudflared binds the metrics server
|
||||
CONTAINER_BUILD=1
|
||||
|
||||
|
||||
|
@ -20,7 +22,7 @@ RUN .teamcity/install-cloudflare-go.sh
|
|||
RUN PATH="/tmp/go/bin:$PATH" make cloudflared
|
||||
|
||||
# use a distroless base image with glibc
|
||||
FROM gcr.io/distroless/base-debian11:nonroot
|
||||
FROM gcr.io/distroless/base-debian12:nonroot
|
||||
|
||||
LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared"
|
||||
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
# use a builder image for building cloudflare
|
||||
FROM golang:1.22.5 as builder
|
||||
FROM golang:1.22.10 as builder
|
||||
ENV GO111MODULE=on \
|
||||
CGO_ENABLED=0
|
||||
CGO_ENABLED=0 \
|
||||
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
|
||||
# which changes how cloudflared binds the metrics server
|
||||
CONTAINER_BUILD=1
|
||||
|
||||
WORKDIR /go/src/github.com/cloudflare/cloudflared/
|
||||
|
||||
|
@ -14,7 +17,7 @@ RUN .teamcity/install-cloudflare-go.sh
|
|||
RUN GOOS=linux GOARCH=amd64 PATH="/tmp/go/bin:$PATH" make cloudflared
|
||||
|
||||
# use a distroless base image with glibc
|
||||
FROM gcr.io/distroless/base-debian11:nonroot
|
||||
FROM gcr.io/distroless/base-debian12:nonroot
|
||||
|
||||
LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared"
|
||||
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
# use a builder image for building cloudflare
|
||||
FROM golang:1.22.5 as builder
|
||||
FROM golang:1.22.10 as builder
|
||||
ENV GO111MODULE=on \
|
||||
CGO_ENABLED=0
|
||||
CGO_ENABLED=0 \
|
||||
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
|
||||
# which changes how cloudflared binds the metrics server
|
||||
CONTAINER_BUILD=1
|
||||
|
||||
WORKDIR /go/src/github.com/cloudflare/cloudflared/
|
||||
|
||||
|
@ -14,7 +17,7 @@ RUN .teamcity/install-cloudflare-go.sh
|
|||
RUN GOOS=linux GOARCH=arm64 PATH="/tmp/go/bin:$PATH" make cloudflared
|
||||
|
||||
# use a distroless base image with glibc
|
||||
FROM gcr.io/distroless/base-debian11:nonroot-arm64
|
||||
FROM gcr.io/distroless/base-debian12:nonroot-arm64
|
||||
|
||||
LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared"
|
||||
|
||||
|
|
19
Makefile
19
Makefile
|
@ -24,7 +24,7 @@ else
|
|||
DEB_PACKAGE_NAME := $(BINARY_NAME)
|
||||
endif
|
||||
|
||||
DATE := $(shell date -u '+%Y-%m-%d-%H%M UTC')
|
||||
DATE := $(shell date -u -r RELEASE_NOTES '+%Y-%m-%d-%H%M UTC')
|
||||
VERSION_FLAGS := -X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)"
|
||||
ifdef PACKAGE_MANAGER
|
||||
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)"
|
||||
|
@ -133,11 +133,9 @@ clean:
|
|||
cloudflared:
|
||||
ifeq ($(FIPS), true)
|
||||
$(info Building cloudflared with go-fips)
|
||||
cp -f fips/fips.go.linux-amd64 cmd/cloudflared/fips.go
|
||||
endif
|
||||
GOOS=$(TARGET_OS) GOARCH=$(TARGET_ARCH) $(ARM_COMMAND) go build -mod=vendor $(GO_BUILD_TAGS) $(LDFLAGS) $(IMPORT_PATH)/cmd/cloudflared
|
||||
ifeq ($(FIPS), true)
|
||||
rm -f cmd/cloudflared/fips.go
|
||||
./check-fips.sh cloudflared
|
||||
endif
|
||||
|
||||
|
@ -255,4 +253,17 @@ vet:
|
|||
|
||||
.PHONY: fmt
|
||||
fmt:
|
||||
goimports -l -w -local github.com/cloudflare/cloudflared $$(go list -mod=vendor -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc/proto)
|
||||
@goimports -l -w -local github.com/cloudflare/cloudflared $$(go list -mod=vendor -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc/proto)
|
||||
@go fmt $$(go list -mod=vendor -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc/proto)
|
||||
|
||||
.PHONY: fmt-check
|
||||
fmt-check:
|
||||
@./fmt-check.sh
|
||||
|
||||
.PHONY: lint
|
||||
lint:
|
||||
@golangci-lint run
|
||||
|
||||
.PHONY: mocks
|
||||
mocks:
|
||||
go generate mocks/mockgen.go
|
||||
|
|
26
README.md
26
README.md
|
@ -40,7 +40,7 @@ User documentation for Cloudflare Tunnel can be found at https://developers.clou
|
|||
|
||||
Once installed, you can authenticate `cloudflared` into your Cloudflare account and begin creating Tunnels to serve traffic to your origins.
|
||||
|
||||
* Create a Tunnel with [these instructions](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/create-tunnel)
|
||||
* Create a Tunnel with [these instructions](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/get-started/)
|
||||
* Route traffic to that Tunnel:
|
||||
* Via public [DNS records in Cloudflare](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/dns)
|
||||
* Or via a public hostname guided by a [Cloudflare Load Balancer](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/lb)
|
||||
|
@ -56,3 +56,27 @@ Want to test Cloudflare Tunnel before adding a website to Cloudflare? You can do
|
|||
Cloudflare currently supports versions of cloudflared that are **within one year** of the most recent release. Breaking changes unrelated to feature availability may be introduced that will impact versions released more than one year ago. You can read more about upgrading cloudflared in our [developer documentation](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/downloads/#updating-cloudflared).
|
||||
|
||||
For example, as of January 2023 Cloudflare will support cloudflared version 2023.1.1 to cloudflared 2022.1.1.
|
||||
|
||||
## Development
|
||||
|
||||
### Requirements
|
||||
- [GNU Make](https://www.gnu.org/software/make/)
|
||||
- [capnp](https://capnproto.org/install.html)
|
||||
- [cloudflare go toolchain](https://github.com/cloudflare/go)
|
||||
- Optional tools:
|
||||
- [capnpc-go](https://pkg.go.dev/zombiezen.com/go/capnproto2/capnpc-go)
|
||||
- [goimports](https://pkg.go.dev/golang.org/x/tools/cmd/goimports)
|
||||
- [golangci-lint](https://github.com/golangci/golangci-lint)
|
||||
- [gomocks](https://pkg.go.dev/go.uber.org/mock)
|
||||
|
||||
### Build
|
||||
To build cloudflared locally run `make cloudflared`
|
||||
|
||||
### Test
|
||||
To locally run the tests run `make test`
|
||||
|
||||
### Linting
|
||||
To format the code and keep a good code quality use `make fmt` and `make lint`
|
||||
|
||||
### Mocks
|
||||
After changes on interfaces you might need to regenerate the mocks, so run `make mock`
|
||||
|
|
|
@ -1,3 +1,65 @@
|
|||
2025.4.0
|
||||
- 2025-04-02 Fix broken links in `cmd/cloudflared/*.go` related to running tunnel as a service
|
||||
- 2025-04-02 chore: remove repetitive words
|
||||
- 2025-04-01 Fix messages to point to one.dash.cloudflare.com
|
||||
- 2025-04-01 feat: emit explicit errors for the `service` command on unsupported OSes
|
||||
- 2025-04-01 Use RELEASE_NOTES date instead of build date
|
||||
- 2025-04-01 chore: Update tunnel configuration link in the readme
|
||||
- 2025-04-01 fix: expand home directory for credentials file
|
||||
- 2025-04-01 fix: Use path and filepath operation appropriately
|
||||
- 2025-04-01 feat: Adds a new command line for tunnel run for token file
|
||||
- 2025-04-01 chore: fix linter rules
|
||||
- 2025-03-17 TUN-9101: Don't ignore errors on `cloudflared access ssh`
|
||||
- 2025-03-06 TUN-9089: Pin go import to v0.30.0, v0.31.0 requires go 1.23
|
||||
|
||||
2025.2.1
|
||||
- 2025-02-26 TUN-9016: update base-debian to v12
|
||||
- 2025-02-25 TUN-8960: Connect to FED API GW based on the OriginCert's endpoint
|
||||
- 2025-02-25 TUN-9007: modify logic to resolve region when the tunnel token has an endpoint field
|
||||
- 2025-02-13 SDLC-3762: Remove backstage.io/source-location from catalog-info.yaml
|
||||
- 2025-02-06 TUN-8914: Create a flags module to group all cloudflared cli flags
|
||||
|
||||
2025.2.0
|
||||
- 2025-02-03 TUN-8914: Add a new configuration to locally override the max-active-flows
|
||||
- 2025-02-03 Bump x/crypto to 0.31.0
|
||||
|
||||
2025.1.1
|
||||
- 2025-01-30 TUN-8858: update go to 1.22.10 and include quic-go FIPS changes
|
||||
- 2025-01-30 TUN-8855: fix lint issues
|
||||
- 2025-01-30 TUN-8855: Update PQ curve preferences
|
||||
- 2025-01-30 TUN-8857: remove restriction for using FIPS and PQ
|
||||
- 2025-01-30 TUN-8894: report FIPS+PQ error to Sentry when dialling to the edge
|
||||
- 2025-01-22 TUN-8904: Rename Connect Response Flow Rate Limited metadata
|
||||
- 2025-01-21 AUTH-6633 Fix cloudflared access login + warp as auth
|
||||
- 2025-01-20 TUN-8861: Add session limiter to UDP session manager
|
||||
- 2025-01-20 TUN-8861: Rename Session Limiter to Flow Limiter
|
||||
- 2025-01-17 TUN-8900: Add import of Apple Developer Certificate Authority to macOS Pipeline
|
||||
- 2025-01-17 TUN-8871: Accept login flag to authenticate with Fedramp environment
|
||||
- 2025-01-16 TUN-8866: Add linter to cloudflared repository
|
||||
- 2025-01-14 TUN-8861: Add session limiter to TCP session manager
|
||||
- 2025-01-13 TUN-8861: Add configuration for active sessions limiter
|
||||
- 2025-01-09 TUN-8848: Don't treat connection shutdown as an error condition when RPC server is done
|
||||
|
||||
2025.1.0
|
||||
- 2025-01-06 TUN-8842: Add Ubuntu Noble and 'any' debian distributions to release script
|
||||
- 2025-01-06 TUN-8807: Add support_datagram_v3 to remote feature rollout
|
||||
- 2024-12-20 TUN-8829: add CONTAINER_BUILD to dockerfiles
|
||||
|
||||
2024.12.2
|
||||
- 2024-12-19 TUN-8822: Prevent concurrent usage of ICMPDecoder
|
||||
- 2024-12-18 TUN-8818: update changes document to reflect newly added diag subcommand
|
||||
- 2024-12-17 TUN-8817: Increase close session channel by one since there are two writers
|
||||
- 2024-12-13 TUN-8797: update CHANGES.md with note about semi-deterministic approach used to bind metrics server
|
||||
- 2024-12-13 TUN-8724: Add CLI command for diagnostic procedure
|
||||
- 2024-12-11 TUN-8786: calculate cli flags once for the diagnostic procedure
|
||||
- 2024-12-11 TUN-8792: Make diag/system endpoint always return a JSON
|
||||
- 2024-12-10 TUN-8783: fix log collectors for the diagnostic procedure
|
||||
- 2024-12-10 TUN-8785: include the icmp sources in the diag's tunnel state
|
||||
- 2024-12-10 TUN-8784: Set JSON encoder options to print formatted JSON when writing diag files
|
||||
|
||||
2024.12.1
|
||||
- 2024-12-10 TUN-8795: update createrepo to createrepo_c to fix the release_pkgs.py script
|
||||
|
||||
2024.12.0
|
||||
- 2024-12-09 TUN-8640: Add ICMP support for datagram V3
|
||||
- 2024-12-09 TUN-8789: make python package installation consistent
|
||||
|
|
|
@ -17,7 +17,7 @@ make cloudflared-deb
|
|||
mv cloudflared-fips\_$VERSION\_$arch.deb $ARTIFACT_DIR/cloudflared-fips-linux-$arch.deb
|
||||
|
||||
# rpm packages invert the - and _ and use x86_64 instead of amd64.
|
||||
RPMVERSION=$(echo $VERSION|sed -r 's/-/_/g')
|
||||
RPMVERSION=$(echo $VERSION | sed -r 's/-/_/g')
|
||||
RPMARCH="x86_64"
|
||||
make cloudflared-rpm
|
||||
mv cloudflared-fips-$RPMVERSION-1.$RPMARCH.rpm $ARTIFACT_DIR/cloudflared-fips-linux-$RPMARCH.rpm
|
||||
|
|
|
@ -4,7 +4,6 @@ metadata:
|
|||
name: cloudflared
|
||||
description: Client for Cloudflare Tunnels
|
||||
annotations:
|
||||
backstage.io/source-location: url:https://bitbucket.cfdata.org/projects/TUN/repos/cloudflared/browse
|
||||
cloudflare.com/software-excellence-opt-in: "true"
|
||||
cloudflare.com/jira-project-key: "TUN"
|
||||
cloudflare.com/jira-project-component: "Cloudflare Tunnel"
|
||||
|
|
12
cfsetup.yaml
12
cfsetup.yaml
|
@ -1,4 +1,4 @@
|
|||
pinned_go: &pinned_go go-boring=1.22.5-1
|
||||
pinned_go: &pinned_go go-boring=1.22.10-1
|
||||
|
||||
build_dir: &build_dir /cfsetup_build
|
||||
default-flavor: bookworm
|
||||
|
@ -13,10 +13,14 @@ bullseye: &bullseye
|
|||
- rubygem-fpm
|
||||
- rpm
|
||||
- libffi-dev
|
||||
- golangci-lint
|
||||
pre-cache: &build_pre_cache
|
||||
- export GOCACHE=/cfsetup_build/.cache/go-build
|
||||
- go install golang.org/x/tools/cmd/goimports@latest
|
||||
- go install golang.org/x/tools/cmd/goimports@v0.30.0
|
||||
post-cache:
|
||||
# Linting
|
||||
- make lint
|
||||
- make fmt-check
|
||||
# Build binary for component test
|
||||
- GOOS=linux GOARCH=amd64 make cloudflared
|
||||
build-linux-fips:
|
||||
|
@ -156,7 +160,6 @@ bullseye: &bullseye
|
|||
- export GOOS=linux
|
||||
- export GOARCH=amd64
|
||||
- export PATH="$HOME/go/bin:$PATH"
|
||||
- ./fmt-check.sh
|
||||
- make test | gotest-to-teamcity
|
||||
test-fips:
|
||||
build_dir: *build_dir
|
||||
|
@ -167,7 +170,6 @@ bullseye: &bullseye
|
|||
- export GOARCH=amd64
|
||||
- export FIPS=true
|
||||
- export PATH="$HOME/go/bin:$PATH"
|
||||
- ./fmt-check.sh
|
||||
- make test | gotest-to-teamcity
|
||||
component-test:
|
||||
build_dir: *build_dir
|
||||
|
@ -243,7 +245,7 @@ bullseye: &bullseye
|
|||
- python3-setuptools
|
||||
- python3-pip
|
||||
- reprepro
|
||||
- createrepo
|
||||
- createrepo-c
|
||||
- python3-venv
|
||||
post-cache:
|
||||
- python3 -m venv env
|
||||
|
|
|
@ -104,7 +104,7 @@ func ssh(c *cli.Context) error {
|
|||
case 3:
|
||||
options.OriginURL = fmt.Sprintf("https://%s:%s", parts[2], parts[1])
|
||||
options.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
InsecureSkipVerify: true, // #nosec G402
|
||||
ServerName: parts[0],
|
||||
}
|
||||
log.Warn().Msgf("Using insecure SSL connection because SNI overridden to %s", parts[0])
|
||||
|
@ -141,6 +141,5 @@ func ssh(c *cli.Context) error {
|
|||
logger := log.With().Str("host", url.Host).Logger()
|
||||
s = stream.NewDebugStream(s, &logger, maxMessages)
|
||||
}
|
||||
carrier.StartClient(wsConn, s, options)
|
||||
return nil
|
||||
return carrier.StartClient(wsConn, s, options)
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
|
||||
"github.com/cloudflare/cloudflared/carrier"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/sshgen"
|
||||
"github.com/cloudflare/cloudflared/token"
|
||||
|
@ -172,15 +173,15 @@ func Commands() []*cli.Command {
|
|||
EnvVars: []string{"TUNNEL_SERVICE_TOKEN_SECRET"},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: logger.LogFileFlag,
|
||||
Name: cfdflags.LogFile,
|
||||
Usage: "Save application log to this file for reporting issues.",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: logger.LogSSHDirectoryFlag,
|
||||
Name: cfdflags.LogDirectory,
|
||||
Usage: "Save application log to this directory for reporting issues.",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: logger.LogSSHLevelFlag,
|
||||
Name: cfdflags.LogLevelSSH,
|
||||
Aliases: []string{"loglevel"}, //added to match the tunnel side
|
||||
Usage: "Application logging level {debug, info, warn, error, fatal}. ",
|
||||
},
|
||||
|
@ -342,7 +343,7 @@ func run(cmd string, args ...string) error {
|
|||
return err
|
||||
}
|
||||
go func() {
|
||||
io.Copy(os.Stderr, stderr)
|
||||
_, _ = io.Copy(os.Stderr, stderr)
|
||||
}()
|
||||
|
||||
stdout, err := c.StdoutPipe()
|
||||
|
@ -350,7 +351,7 @@ func run(cmd string, args ...string) error {
|
|||
return err
|
||||
}
|
||||
go func() {
|
||||
io.Copy(os.Stdout, stdout)
|
||||
_, _ = io.Copy(os.Stdout, stdout)
|
||||
}()
|
||||
return c.Run()
|
||||
}
|
||||
|
@ -531,7 +532,7 @@ func isFileThere(candidate string) bool {
|
|||
}
|
||||
|
||||
// verifyTokenAtEdge checks for a token on disk, or generates a new one.
|
||||
// Then makes a request to to the origin with the token to ensure it is valid.
|
||||
// Then makes a request to the origin with the token to ensure it is valid.
|
||||
// Returns nil if token is valid.
|
||||
func verifyTokenAtEdge(appUrl *url.URL, appInfo *token.AppInfo, c *cli.Context, log *zerolog.Logger) error {
|
||||
headers := parseRequestHeaders(c.StringSlice(sshHeaderFlag))
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"github.com/urfave/cli/v2"
|
||||
"github.com/urfave/cli/v2/altsrc"
|
||||
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -15,14 +15,14 @@ var (
|
|||
func ConfigureLoggingFlags(shouldHide bool) []cli.Flag {
|
||||
return []cli.Flag{
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: logger.LogLevelFlag,
|
||||
Name: cfdflags.LogLevel,
|
||||
Value: "info",
|
||||
Usage: "Application logging level {debug, info, warn, error, fatal}. " + debugLevelWarning,
|
||||
EnvVars: []string{"TUNNEL_LOGLEVEL"},
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: logger.LogTransportLevelFlag,
|
||||
Name: cfdflags.TransportLogLevel,
|
||||
Aliases: []string{"proto-loglevel"}, // This flag used to be called proto-loglevel
|
||||
Value: "info",
|
||||
Usage: "Transport logging level(previously called protocol logging level) {debug, info, warn, error, fatal}",
|
||||
|
@ -30,19 +30,19 @@ func ConfigureLoggingFlags(shouldHide bool) []cli.Flag {
|
|||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: logger.LogFileFlag,
|
||||
Name: cfdflags.LogFile,
|
||||
Usage: "Save application log to this file for reporting issues.",
|
||||
EnvVars: []string{"TUNNEL_LOGFILE"},
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: logger.LogDirectoryFlag,
|
||||
Name: cfdflags.LogDirectory,
|
||||
Usage: "Save application log to this directory for reporting issues.",
|
||||
EnvVars: []string{"TUNNEL_LOGDIRECTORY"},
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "trace-output",
|
||||
Name: cfdflags.TraceOutput,
|
||||
Usage: "Name of trace output file, generated when cloudflared stops.",
|
||||
EnvVars: []string{"TUNNEL_TRACE_OUTPUT"},
|
||||
Hidden: shouldHide,
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
package flags
|
||||
|
||||
const (
|
||||
// HaConnections specifies how many connections to make to the edge
|
||||
HaConnections = "ha-connections"
|
||||
|
||||
// SshPort is the port on localhost the cloudflared ssh server will run on
|
||||
SshPort = "local-ssh-port"
|
||||
|
||||
// SshIdleTimeout defines the duration a SSH session can remain idle before being closed
|
||||
SshIdleTimeout = "ssh-idle-timeout"
|
||||
|
||||
// SshMaxTimeout defines the max duration a SSH session can remain open for
|
||||
SshMaxTimeout = "ssh-max-timeout"
|
||||
|
||||
// SshLogUploaderBucketName is the bucket name to use for the SSH log uploader
|
||||
SshLogUploaderBucketName = "bucket-name"
|
||||
|
||||
// SshLogUploaderRegionName is the AWS region name to use for the SSH log uploader
|
||||
SshLogUploaderRegionName = "region-name"
|
||||
|
||||
// SshLogUploaderSecretID is the Secret id of SSH log uploader
|
||||
SshLogUploaderSecretID = "secret-id"
|
||||
|
||||
// SshLogUploaderAccessKeyID is the Access key id of SSH log uploader
|
||||
SshLogUploaderAccessKeyID = "access-key-id"
|
||||
|
||||
// SshLogUploaderSessionTokenID is the Session token of SSH log uploader
|
||||
SshLogUploaderSessionTokenID = "session-token"
|
||||
|
||||
// SshLogUploaderS3URL is the S3 URL of SSH log uploader (e.g. don't use AWS s3 and use google storage bucket instead)
|
||||
SshLogUploaderS3URL = "s3-url-host"
|
||||
|
||||
// HostKeyPath is the path of the dir to save SSH host keys too
|
||||
HostKeyPath = "host-key-path"
|
||||
|
||||
// RpcTimeout is how long to wait for a Capnp RPC request to the edge
|
||||
RpcTimeout = "rpc-timeout"
|
||||
|
||||
// WriteStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
|
||||
WriteStreamTimeout = "write-stream-timeout"
|
||||
|
||||
// QuicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
|
||||
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
|
||||
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
|
||||
QuicDisablePathMTUDiscovery = "quic-disable-pmtu-discovery"
|
||||
|
||||
// QuicConnLevelFlowControlLimit controls the max flow control limit allocated for a QUIC connection. This controls how much data is the
|
||||
// receiver willing to buffer. Once the limit is reached, the sender will send a DATA_BLOCKED frame to indicate it has more data to write,
|
||||
// but it's blocked by flow control
|
||||
QuicConnLevelFlowControlLimit = "quic-connection-level-flow-control-limit"
|
||||
|
||||
// QuicStreamLevelFlowControlLimit is similar to quicConnLevelFlowControlLimit but for each QUIC stream. When the sender is blocked,
|
||||
// it will send a STREAM_DATA_BLOCKED frame
|
||||
QuicStreamLevelFlowControlLimit = "quic-stream-level-flow-control-limit"
|
||||
|
||||
// Ui is to enable launching cloudflared in interactive UI mode
|
||||
Ui = "ui"
|
||||
|
||||
// ConnectorLabel is the command line flag to give a meaningful label to a specific connector
|
||||
ConnectorLabel = "label"
|
||||
|
||||
// MaxActiveFlows is the command line flag to set the maximum number of flows that cloudflared can be processing at the same time
|
||||
MaxActiveFlows = "max-active-flows"
|
||||
|
||||
// Tag is the command line flag to set custom tags used to identify this tunnel via added HTTP request headers to the origin
|
||||
Tag = "tag"
|
||||
|
||||
// Protocol is the command line flag to set the protocol to use to connect to the Cloudflare Edge
|
||||
Protocol = "protocol"
|
||||
|
||||
// PostQuantum is the command line flag to force the connection to Cloudflare Edge to use Post Quantum cryptography
|
||||
PostQuantum = "post-quantum"
|
||||
|
||||
// Features is the command line flag to opt into various features that are still being developed or tested
|
||||
Features = "features"
|
||||
|
||||
// EdgeIpVersion is the command line flag to set the Cloudflare Edge IP address version to connect with
|
||||
EdgeIpVersion = "edge-ip-version"
|
||||
|
||||
// EdgeBindAddress is the command line flag to bind to IP address for outgoing connections to Cloudflare Edge
|
||||
EdgeBindAddress = "edge-bind-address"
|
||||
|
||||
// Force is the command line flag to specify if you wish to force an action
|
||||
Force = "force"
|
||||
|
||||
// Edge is the command line flag to set the address of the Cloudflare tunnel server. Only works in Cloudflare's internal testing environment
|
||||
Edge = "edge"
|
||||
|
||||
// Region is the command line flag to set the Cloudflare Edge region to connect to
|
||||
Region = "region"
|
||||
|
||||
// IsAutoUpdated is the command line flag to signal the new process that cloudflared has been autoupdated
|
||||
IsAutoUpdated = "is-autoupdated"
|
||||
|
||||
// LBPool is the command line flag to set the name of the load balancing pool to add this origin to
|
||||
LBPool = "lb-pool"
|
||||
|
||||
// Retries is the command line flag to set the maximum number of retries for connection/protocol errors
|
||||
Retries = "retries"
|
||||
|
||||
// MaxEdgeAddrRetries is the command line flag to set the maximum number of times to retry on edge addrs before falling back to a lower protocol
|
||||
MaxEdgeAddrRetries = "max-edge-addr-retries"
|
||||
|
||||
// GracePeriod is the command line flag to set the maximum amount of time that cloudflared waits to shut down if it is still serving requests
|
||||
GracePeriod = "grace-period"
|
||||
|
||||
// ICMPV4Src is the command line flag to set the source address and the interface name to send/receive ICMPv4 messages
|
||||
ICMPV4Src = "icmpv4-src"
|
||||
|
||||
// ICMPV6Src is the command line flag to set the source address and the interface name to send/receive ICMPv6 messages
|
||||
ICMPV6Src = "icmpv6-src"
|
||||
|
||||
// ProxyDns is the command line flag to run DNS server over HTTPS
|
||||
ProxyDns = "proxy-dns"
|
||||
|
||||
// Name is the command line to set the name of the tunnel
|
||||
Name = "name"
|
||||
|
||||
// AutoUpdateFreq is the command line for setting the frequency that cloudflared checks for updates
|
||||
AutoUpdateFreq = "autoupdate-freq"
|
||||
|
||||
// NoAutoUpdate is the command line flag to disable cloudflared from checking for updates
|
||||
NoAutoUpdate = "no-autoupdate"
|
||||
|
||||
// LogLevel is the command line flag for the cloudflared logging level
|
||||
LogLevel = "loglevel"
|
||||
|
||||
// LogLevelSSH is the command line flag for the cloudflared ssh logging level
|
||||
LogLevelSSH = "log-level"
|
||||
|
||||
// TransportLogLevel is the command line flag for the transport logging level
|
||||
TransportLogLevel = "transport-loglevel"
|
||||
|
||||
// LogFile is the command line flag to define the file where application logs will be stored
|
||||
LogFile = "logfile"
|
||||
|
||||
// LogDirectory is the command line flag to define the directory where application logs will be stored.
|
||||
LogDirectory = "log-directory"
|
||||
|
||||
// TraceOutput is the command line flag to set the name of trace output file
|
||||
TraceOutput = "trace-output"
|
||||
|
||||
// OriginCert is the command line flag to define the path for the origin certificate used by cloudflared
|
||||
OriginCert = "origincert"
|
||||
|
||||
// Metrics is the command line flag to define the address of the metrics server
|
||||
Metrics = "metrics"
|
||||
|
||||
// MetricsUpdateFreq is the command line flag to define how frequently tunnel metrics are updated
|
||||
MetricsUpdateFreq = "metrics-update-freq"
|
||||
|
||||
// ApiURL is the command line flag used to define the base URL of the API
|
||||
ApiURL = "api-url"
|
||||
)
|
|
@ -3,11 +3,38 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
cli "github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
)
|
||||
|
||||
func runApp(app *cli.App, graceShutdownC chan struct{}) {
|
||||
app.Commands = append(app.Commands, &cli.Command{
|
||||
Name: "service",
|
||||
Usage: "Manages the cloudflared system service (not supported on this operating system)",
|
||||
Subcommands: []*cli.Command{
|
||||
{
|
||||
Name: "install",
|
||||
Usage: "Install cloudflared as a system service (not supported on this operating system)",
|
||||
Action: cliutil.ConfiguredAction(installGenericService),
|
||||
},
|
||||
{
|
||||
Name: "uninstall",
|
||||
Usage: "Uninstall the cloudflared service (not supported on this operating system)",
|
||||
Action: cliutil.ConfiguredAction(uninstallGenericService),
|
||||
},
|
||||
},
|
||||
})
|
||||
app.Run(os.Args)
|
||||
}
|
||||
|
||||
func installGenericService(c *cli.Context) error {
|
||||
return fmt.Errorf("service installation is not supported on this operating system")
|
||||
}
|
||||
|
||||
func uninstallGenericService(c *cli.Context) error {
|
||||
return fmt.Errorf("service uninstallation is not supported on this operating system")
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"os"
|
||||
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
|
@ -17,7 +18,7 @@ const (
|
|||
launchdIdentifier = "com.cloudflare.cloudflared"
|
||||
)
|
||||
|
||||
func runApp(app *cli.App, graceShutdownC chan struct{}) {
|
||||
func runApp(app *cli.App, _ chan struct{}) {
|
||||
app.Commands = append(app.Commands, &cli.Command{
|
||||
Name: "service",
|
||||
Usage: "Manages the cloudflared launch agent",
|
||||
|
@ -119,7 +120,7 @@ func installLaunchd(c *cli.Context) error {
|
|||
log.Info().Msg("Installing cloudflared client as an user launch agent. " +
|
||||
"Note that cloudflared client will only run when the user is logged in. " +
|
||||
"If you want to run cloudflared client at boot, install with root permission. " +
|
||||
"For more information, visit https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/run-as-service")
|
||||
"For more information, visit https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/configure-tunnels/local-management/as-a-service/macos/")
|
||||
}
|
||||
etPath, err := os.Executable()
|
||||
if err != nil {
|
||||
|
@ -207,3 +208,15 @@ func uninstallLaunchd(c *cli.Context) error {
|
|||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func userHomeDir() (string, error) {
|
||||
// This returns the home dir of the executing user using OS-specific method
|
||||
// for discovering the home dir. It's not recommended to call this function
|
||||
// when the user has root permission as $HOME depends on what options the user
|
||||
// use with sudo.
|
||||
homeDir, err := homedir.Dir()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "Cannot determine home directory for the user")
|
||||
}
|
||||
return homeDir, nil
|
||||
}
|
||||
|
|
|
@ -2,19 +2,17 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/urfave/cli/v2"
|
||||
"go.uber.org/automaxprocs/maxprocs"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/access"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/proxydns"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/tail"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/tunnel"
|
||||
|
@ -52,10 +50,8 @@ var (
|
|||
func main() {
|
||||
// FIXME: TUN-8148: Disable QUIC_GO ECN due to bugs in proper detection if supported
|
||||
os.Setenv("QUIC_GO_DISABLE_ECN", "1")
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
metrics.RegisterBuildInfo(BuildType, BuildTime, Version)
|
||||
maxprocs.Set()
|
||||
_, _ = maxprocs.Set()
|
||||
bInfo := cliutil.GetBuildInfo(BuildType, Version)
|
||||
|
||||
// Graceful shutdown channel used by the app. When closed, app must terminate gracefully.
|
||||
|
@ -110,7 +106,7 @@ func commands(version func(c *cli.Context)) []*cli.Command {
|
|||
Usage: "specify if you wish to update to the latest beta version",
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "force",
|
||||
Name: cfdflags.Force,
|
||||
Usage: "specify if you wish to force an upgrade to the latest version regardless of the current version",
|
||||
Hidden: true,
|
||||
},
|
||||
|
@ -184,18 +180,6 @@ func action(graceShutdownC chan struct{}) cli.ActionFunc {
|
|||
})
|
||||
}
|
||||
|
||||
func userHomeDir() (string, error) {
|
||||
// This returns the home dir of the executing user using OS-specific method
|
||||
// for discovering the home dir. It's not recommended to call this function
|
||||
// when the user has root permission as $HOME depends on what options the user
|
||||
// use with sudo.
|
||||
homeDir, err := homedir.Dir()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "Cannot determine home directory for the user")
|
||||
}
|
||||
return homeDir, nil
|
||||
}
|
||||
|
||||
// In order to keep the amount of noise sent to Sentry low, typical network errors can be filtered out here by a substring match.
|
||||
func captureError(err error) {
|
||||
errorMessage := err.Error()
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"text/template"
|
||||
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
|
@ -44,7 +44,7 @@ func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
|
|||
return err
|
||||
}
|
||||
if _, err = os.Stat(resolvedPath); err == nil {
|
||||
return fmt.Errorf(serviceAlreadyExistsWarn(resolvedPath))
|
||||
return errors.New(serviceAlreadyExistsWarn(resolvedPath))
|
||||
}
|
||||
|
||||
var buffer bytes.Buffer
|
||||
|
@ -57,7 +57,7 @@ func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
|
|||
fileMode = st.FileMode
|
||||
}
|
||||
|
||||
plistFolder := path.Dir(resolvedPath)
|
||||
plistFolder := filepath.Dir(resolvedPath)
|
||||
err = os.MkdirAll(plistFolder, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating %s: %v", plistFolder, err)
|
||||
|
@ -118,49 +118,6 @@ func ensureConfigDirExists(configDir string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// openFile opens the file at path. If create is set and the file exists, returns nil, true, nil
|
||||
func openFile(path string, create bool) (file *os.File, exists bool, err error) {
|
||||
expandedPath, err := homedir.Expand(path)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if create {
|
||||
fileInfo, err := os.Stat(expandedPath)
|
||||
if err == nil && fileInfo.Size() > 0 {
|
||||
return nil, true, nil
|
||||
}
|
||||
file, err = os.OpenFile(expandedPath, os.O_RDWR|os.O_CREATE, 0600)
|
||||
} else {
|
||||
file, err = os.Open(expandedPath)
|
||||
}
|
||||
return file, false, err
|
||||
}
|
||||
|
||||
func copyCredential(srcCredentialPath, destCredentialPath string) error {
|
||||
destFile, exists, err := openFile(destCredentialPath, true)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if exists {
|
||||
// credentials already exist, do nothing
|
||||
return nil
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
srcFile, _, err := openFile(srcCredentialPath, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer srcFile.Close()
|
||||
|
||||
// Copy certificate
|
||||
_, err = io.Copy(destFile, srcFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to copy %s to %s: %v", srcCredentialPath, destCredentialPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyFile(src, dest string) error {
|
||||
srcFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
|
@ -187,36 +144,3 @@ func copyFile(src, dest string) error {
|
|||
ok = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyConfig(srcConfigPath, destConfigPath string) error {
|
||||
// Copy or create config
|
||||
destFile, exists, err := openFile(destConfigPath, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot open %s with error: %s", destConfigPath, err)
|
||||
} else if exists {
|
||||
// config already exists, do nothing
|
||||
return nil
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
srcFile, _, err := openFile(srcConfigPath, false)
|
||||
if err != nil {
|
||||
fmt.Println("Your service needs a config file that at least specifies the hostname option.")
|
||||
fmt.Println("Type in a hostname now, or leave it blank and create the config file later.")
|
||||
fmt.Print("Hostname: ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
input, _ := reader.ReadString('\n')
|
||||
if input == "" {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(destFile, "hostname: %s\n", input)
|
||||
} else {
|
||||
defer srcFile.Close()
|
||||
_, err = io.Copy(destFile, srcFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to copy %s to %s: %v", srcConfigPath, destConfigPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -18,14 +18,12 @@ import (
|
|||
"nhooyr.io/websocket"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
"github.com/cloudflare/cloudflared/credentials"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
)
|
||||
|
||||
var (
|
||||
buildInfo *cliutil.BuildInfo
|
||||
)
|
||||
var buildInfo *cliutil.BuildInfo
|
||||
|
||||
func Init(bi *cliutil.BuildInfo) {
|
||||
buildInfo = bi
|
||||
|
@ -56,7 +54,7 @@ func managementTokenCommand(c *cli.Context) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var tokenResponse = struct {
|
||||
tokenResponse := struct {
|
||||
Token string `json:"token"`
|
||||
}{Token: token}
|
||||
|
||||
|
@ -119,13 +117,13 @@ func buildTailCommand(subcommands []*cli.Command) *cli.Command {
|
|||
Value: "",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: logger.LogLevelFlag,
|
||||
Name: cfdflags.LogLevel,
|
||||
Value: "info",
|
||||
Usage: "Application logging level {debug, info, warn, error, fatal}",
|
||||
EnvVars: []string{"TUNNEL_LOGLEVEL"},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: credentials.OriginCertFlag,
|
||||
Name: cfdflags.OriginCert,
|
||||
Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
|
||||
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
|
||||
Value: credentials.FindDefaultOriginCertPath(),
|
||||
|
@ -169,7 +167,7 @@ func handleValidationError(resp *http.Response, log *zerolog.Logger) {
|
|||
// logger will be created to emit only against the os.Stderr as to not obstruct with normal output from
|
||||
// management requests
|
||||
func createLogger(c *cli.Context) *zerolog.Logger {
|
||||
level, levelErr := zerolog.ParseLevel(c.String(logger.LogLevelFlag))
|
||||
level, levelErr := zerolog.ParseLevel(c.String(cfdflags.LogLevel))
|
||||
if levelErr != nil {
|
||||
level = zerolog.InfoLevel
|
||||
}
|
||||
|
@ -183,9 +181,10 @@ func createLogger(c *cli.Context) *zerolog.Logger {
|
|||
// parseFilters will attempt to parse provided filters to send to with the EventStartStreaming
|
||||
func parseFilters(c *cli.Context) (*management.StreamingFilters, error) {
|
||||
var level *management.LogLevel
|
||||
var events []management.LogEventType
|
||||
var sample float64
|
||||
|
||||
events := make([]management.LogEventType, 0)
|
||||
|
||||
argLevel := c.String("level")
|
||||
argEvents := c.StringSlice("event")
|
||||
argSample := c.Float64("sample")
|
||||
|
@ -225,12 +224,12 @@ func parseFilters(c *cli.Context) (*management.StreamingFilters, error) {
|
|||
|
||||
// getManagementToken will make a call to the Cloudflare API to acquire a management token for the requested tunnel.
|
||||
func getManagementToken(c *cli.Context, log *zerolog.Logger) (string, error) {
|
||||
userCreds, err := credentials.Read(c.String(credentials.OriginCertFlag), log)
|
||||
userCreds, err := credentials.Read(c.String(cfdflags.OriginCert), log)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
client, err := userCreds.Client(c.String("api-url"), buildInfo.UserAgent(), log)
|
||||
client, err := userCreds.Client(c.String(cfdflags.ApiURL), buildInfo.UserAgent(), log)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -331,6 +330,7 @@ func Run(c *cli.Context) error {
|
|||
header["cf-trace-id"] = []string{trace}
|
||||
}
|
||||
ctx := c.Context
|
||||
// nolint: bodyclose
|
||||
conn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
|
||||
HTTPHeader: header,
|
||||
})
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/trace"
|
||||
"strings"
|
||||
"sync"
|
||||
|
@ -15,7 +16,7 @@ import (
|
|||
"github.com/facebookgo/grace/gracenet"
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/google/uuid"
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/urfave/cli/v2"
|
||||
|
@ -23,6 +24,7 @@ import (
|
|||
|
||||
"github.com/cloudflare/cloudflared/cfapi"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/proxydns"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
|
@ -30,7 +32,6 @@ import (
|
|||
"github.com/cloudflare/cloudflared/credentials"
|
||||
"github.com/cloudflare/cloudflared/diagnostic"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/features"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
|
@ -47,61 +48,6 @@ import (
|
|||
const (
|
||||
sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
|
||||
|
||||
// ha-Connections specifies how many connections to make to the edge
|
||||
haConnectionsFlag = "ha-connections"
|
||||
|
||||
// sshPortFlag is the port on localhost the cloudflared ssh server will run on
|
||||
sshPortFlag = "local-ssh-port"
|
||||
|
||||
// sshIdleTimeoutFlag defines the duration a SSH session can remain idle before being closed
|
||||
sshIdleTimeoutFlag = "ssh-idle-timeout"
|
||||
|
||||
// sshMaxTimeoutFlag defines the max duration a SSH session can remain open for
|
||||
sshMaxTimeoutFlag = "ssh-max-timeout"
|
||||
|
||||
// bucketNameFlag is the bucket name to use for the SSH log uploader
|
||||
bucketNameFlag = "bucket-name"
|
||||
|
||||
// regionNameFlag is the AWS region name to use for the SSH log uploader
|
||||
regionNameFlag = "region-name"
|
||||
|
||||
// secretIDFlag is the Secret id of SSH log uploader
|
||||
secretIDFlag = "secret-id"
|
||||
|
||||
// accessKeyIDFlag is the Access key id of SSH log uploader
|
||||
accessKeyIDFlag = "access-key-id"
|
||||
|
||||
// sessionTokenIDFlag is the Session token of SSH log uploader
|
||||
sessionTokenIDFlag = "session-token"
|
||||
|
||||
// s3URLFlag is the S3 URL of SSH log uploader (e.g. don't use AWS s3 and use google storage bucket instead)
|
||||
s3URLFlag = "s3-url-host"
|
||||
|
||||
// hostKeyPath is the path of the dir to save SSH host keys too
|
||||
hostKeyPath = "host-key-path"
|
||||
|
||||
// rpcTimeout is how long to wait for a Capnp RPC request to the edge
|
||||
rpcTimeout = "rpc-timeout"
|
||||
|
||||
// writeStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
|
||||
writeStreamTimeout = "write-stream-timeout"
|
||||
|
||||
// quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
|
||||
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
|
||||
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
|
||||
quicDisablePathMTUDiscovery = "quic-disable-pmtu-discovery"
|
||||
|
||||
// quicConnLevelFlowControlLimit controls the max flow control limit allocated for a QUIC connection. This controls how much data is the
|
||||
// receiver willing to buffer. Once the limit is reached, the sender will send a DATA_BLOCKED frame to indicate it has more data to write,
|
||||
// but it's blocked by flow control
|
||||
quicConnLevelFlowControlLimit = "quic-connection-level-flow-control-limit"
|
||||
// quicStreamLevelFlowControlLimit is similar to quicConnLevelFlowControlLimit but for each QUIC stream. When the sender is blocked,
|
||||
// it will send a STREAM_DATA_BLOCKED frame
|
||||
quicStreamLevelFlowControlLimit = "quic-stream-level-flow-control-limit"
|
||||
|
||||
// uiFlag is to enable launching cloudflared in interactive UI mode
|
||||
uiFlag = "ui"
|
||||
|
||||
LogFieldCommand = "command"
|
||||
LogFieldExpandedPath = "expandedPath"
|
||||
LogFieldPIDPathname = "pidPathname"
|
||||
|
@ -116,7 +62,6 @@ Eg. cloudflared tunnel --url localhost:8080/.
|
|||
Please note that Quick Tunnels are meant to be ephemeral and should only be used for testing purposes.
|
||||
For production usage, we recommend creating Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)
|
||||
`
|
||||
connectorLabelFlag = "label"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -126,14 +71,14 @@ var (
|
|||
routeFailMsg = fmt.Sprintf("failed to provision routing, please create it manually via Cloudflare dashboard or UI; "+
|
||||
"most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+
|
||||
"any existing DNS records for this hostname.", overwriteDNSFlag)
|
||||
deprecatedClassicTunnelErr = fmt.Errorf("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)")
|
||||
errDeprecatedClassicTunnel = errors.New("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)")
|
||||
// TODO: TUN-8756 the list below denotes the flags that do not possess any kind of sensitive information
|
||||
// however this approach is not maintainble in the long-term.
|
||||
nonSecretFlagsList = []string{
|
||||
"config",
|
||||
"autoupdate-freq",
|
||||
"no-autoupdate",
|
||||
"metrics",
|
||||
cfdflags.AutoUpdateFreq,
|
||||
cfdflags.NoAutoUpdate,
|
||||
cfdflags.Metrics,
|
||||
"pidfile",
|
||||
"url",
|
||||
"hello-world",
|
||||
|
@ -166,54 +111,55 @@ var (
|
|||
"bastion",
|
||||
"proxy-address",
|
||||
"proxy-port",
|
||||
"loglevel",
|
||||
"transport-loglevel",
|
||||
"logfile",
|
||||
"log-directory",
|
||||
"trace-output",
|
||||
"proxy-dns",
|
||||
cfdflags.LogLevel,
|
||||
cfdflags.TransportLogLevel,
|
||||
cfdflags.LogFile,
|
||||
cfdflags.LogDirectory,
|
||||
cfdflags.TraceOutput,
|
||||
cfdflags.ProxyDns,
|
||||
"proxy-dns-port",
|
||||
"proxy-dns-address",
|
||||
"proxy-dns-upstream",
|
||||
"proxy-dns-max-upstream-conns",
|
||||
"proxy-dns-bootstrap",
|
||||
"is-autoupdated",
|
||||
"edge",
|
||||
"region",
|
||||
"edge-ip-version",
|
||||
"edge-bind-address",
|
||||
cfdflags.IsAutoUpdated,
|
||||
cfdflags.Edge,
|
||||
cfdflags.Region,
|
||||
cfdflags.EdgeIpVersion,
|
||||
cfdflags.EdgeBindAddress,
|
||||
"cacert",
|
||||
"hostname",
|
||||
"id",
|
||||
"lb-pool",
|
||||
"api-url",
|
||||
"metrics-update-freq",
|
||||
"tag",
|
||||
cfdflags.LBPool,
|
||||
cfdflags.ApiURL,
|
||||
cfdflags.MetricsUpdateFreq,
|
||||
cfdflags.Tag,
|
||||
"heartbeat-interval",
|
||||
"heartbeat-count",
|
||||
"max-edge-addr-retries",
|
||||
"retries",
|
||||
cfdflags.MaxEdgeAddrRetries,
|
||||
cfdflags.Retries,
|
||||
"ha-connections",
|
||||
"rpc-timeout",
|
||||
"write-stream-timeout",
|
||||
"quic-disable-pmtu-discovery",
|
||||
"quic-connection-level-flow-control-limit",
|
||||
"quic-stream-level-flow-control-limit",
|
||||
"label",
|
||||
"grace-period",
|
||||
cfdflags.ConnectorLabel,
|
||||
cfdflags.GracePeriod,
|
||||
"compression-quality",
|
||||
"use-reconnect-token",
|
||||
"dial-edge-timeout",
|
||||
"stdin-control",
|
||||
"name",
|
||||
"ui",
|
||||
cfdflags.Name,
|
||||
cfdflags.Ui,
|
||||
"quick-service",
|
||||
"max-fetch-size",
|
||||
"post-quantum",
|
||||
cfdflags.PostQuantum,
|
||||
"management-diagnostics",
|
||||
"protocol",
|
||||
cfdflags.Protocol,
|
||||
"overwrite-dns",
|
||||
"help",
|
||||
cfdflags.MaxActiveFlows,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -235,6 +181,7 @@ func Commands() []*cli.Command {
|
|||
buildDeleteCommand(),
|
||||
buildCleanupCommand(),
|
||||
buildTokenCommand(),
|
||||
buildDiagCommand(),
|
||||
// for compatibility, allow following as tunnel subcommands
|
||||
proxydns.Command(true),
|
||||
cliutil.RemovedCommand("db-connect"),
|
||||
|
@ -261,7 +208,7 @@ then protect with Cloudflare Access).
|
|||
B) Locally reachable TCP/UDP-based private services to Cloudflare connected private users in the same account, e.g.,
|
||||
those enrolled to a Zero Trust WARP Client.
|
||||
|
||||
You can manage your Tunnels via dash.teams.cloudflare.com. This approach will only require you to run a single command
|
||||
You can manage your Tunnels via one.dash.cloudflare.com. This approach will only require you to run a single command
|
||||
later in each machine where you wish to run a Tunnel.
|
||||
|
||||
Alternatively, you can manage your Tunnels via the command line. Begin by obtaining a certificate to be able to do so:
|
||||
|
@ -297,7 +244,7 @@ func TunnelCommand(c *cli.Context) error {
|
|||
// --name required
|
||||
// --url or --hello-world required
|
||||
// --hostname optional
|
||||
if name := c.String("name"); name != "" {
|
||||
if name := c.String(cfdflags.Name); name != "" {
|
||||
hostname, err := validation.ValidateHostname(c.String("hostname"))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Invalid hostname provided")
|
||||
|
@ -314,7 +261,7 @@ func TunnelCommand(c *cli.Context) error {
|
|||
// A unauthenticated named tunnel hosted on <random>.<quick-tunnels-service>.com
|
||||
// We don't support running proxy-dns and a quick tunnel at the same time as the same process
|
||||
shouldRunQuickTunnel := c.IsSet("url") || c.IsSet(ingress.HelloWorldFlag)
|
||||
if !c.IsSet("proxy-dns") && c.String("quick-service") != "" && shouldRunQuickTunnel {
|
||||
if !c.IsSet(cfdflags.ProxyDns) && c.String("quick-service") != "" && shouldRunQuickTunnel {
|
||||
return RunQuickTunnel(sc)
|
||||
}
|
||||
|
||||
|
@ -325,10 +272,10 @@ func TunnelCommand(c *cli.Context) error {
|
|||
|
||||
// Classic tunnel usage is no longer supported
|
||||
if c.String("hostname") != "" {
|
||||
return deprecatedClassicTunnelErr
|
||||
return errDeprecatedClassicTunnel
|
||||
}
|
||||
|
||||
if c.IsSet("proxy-dns") {
|
||||
if c.IsSet(cfdflags.ProxyDns) {
|
||||
if shouldRunQuickTunnel {
|
||||
return fmt.Errorf("running a quick tunnel with `proxy-dns` is not supported")
|
||||
}
|
||||
|
@ -375,7 +322,7 @@ func runAdhocNamedTunnel(sc *subcommandContext, name, credentialsOutputPath stri
|
|||
|
||||
func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) {
|
||||
if hostname := c.String("hostname"); hostname != "" {
|
||||
if lbPool := c.String("lb-pool"); lbPool != "" {
|
||||
if lbPool := c.String(cfdflags.LBPool); lbPool != "" {
|
||||
return cfapi.NewLBRoute(hostname, lbPool), true
|
||||
}
|
||||
return cfapi.NewDNSRoute(hostname, c.Bool(overwriteDNSFlagName)), true
|
||||
|
@ -405,7 +352,7 @@ func StartServer(
|
|||
log.Info().Msg(config.ErrNoConfigFile.Error())
|
||||
}
|
||||
|
||||
if c.IsSet("trace-output") {
|
||||
if c.IsSet(cfdflags.TraceOutput) {
|
||||
tmpTraceFile, err := os.CreateTemp("", "trace")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to create new temporary file to save trace output")
|
||||
|
@ -417,7 +364,7 @@ func StartServer(
|
|||
if err := tmpTraceFile.Close(); err != nil {
|
||||
traceLog.Err(err).Msg("Failed to close temporary trace output file")
|
||||
}
|
||||
traceOutputFilepath := c.String("trace-output")
|
||||
traceOutputFilepath := c.String(cfdflags.TraceOutput)
|
||||
if err := os.Rename(tmpTraceFile.Name(), traceOutputFilepath); err != nil {
|
||||
traceLog.
|
||||
Err(err).
|
||||
|
@ -447,7 +394,7 @@ func StartServer(
|
|||
|
||||
go waitForSignal(graceShutdownC, log)
|
||||
|
||||
if c.IsSet("proxy-dns") {
|
||||
if c.IsSet(cfdflags.ProxyDns) {
|
||||
dnsReadySignal := make(chan struct{})
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
|
@ -469,7 +416,7 @@ func StartServer(
|
|||
go func() {
|
||||
defer wg.Done()
|
||||
autoupdater := updater.NewAutoUpdater(
|
||||
c.Bool("no-autoupdate"), c.Duration("autoupdate-freq"), &listeners, log,
|
||||
c.Bool(cfdflags.NoAutoUpdate), c.Duration(cfdflags.AutoUpdateFreq), &listeners, log,
|
||||
)
|
||||
errC <- autoupdater.Run(ctx)
|
||||
}()
|
||||
|
@ -513,8 +460,6 @@ func StartServer(
|
|||
tunnelConfig.ICMPRouterServer = nil
|
||||
}
|
||||
|
||||
internalRules := []ingress.Rule{}
|
||||
if features.Contains(features.FeatureManagementLogs) {
|
||||
serviceIP := c.String("service-op-ip")
|
||||
if edgeAddrs, err := edgediscovery.ResolveEdge(log, tunnelConfig.Region, tunnelConfig.EdgeIPVersion); err == nil {
|
||||
if serviceAddr, err := edgeAddrs.GetAddrForRPC(); err == nil {
|
||||
|
@ -527,12 +472,11 @@ func StartServer(
|
|||
c.Bool("management-diagnostics"),
|
||||
serviceIP,
|
||||
clientID,
|
||||
c.String(connectorLabelFlag),
|
||||
c.String(cfdflags.ConnectorLabel),
|
||||
logger.ManagementLogger.Log,
|
||||
logger.ManagementLogger,
|
||||
)
|
||||
internalRules = []ingress.Rule{ingress.NewManagementRule(mgmt)}
|
||||
}
|
||||
internalRules := []ingress.Rule{ingress.NewManagementRule(mgmt)}
|
||||
orchestrator, err := orchestration.NewOrchestrator(ctx, orchestratorConfig, tunnelConfig.Tags, internalRules, tunnelConfig.Log)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -552,7 +496,15 @@ func StartServer(
|
|||
tracker := tunnelstate.NewConnTracker(log)
|
||||
observer.RegisterSink(tracker)
|
||||
|
||||
ipv4, ipv6, err := determineICMPSources(c, log)
|
||||
sources := make([]string, 0)
|
||||
if err == nil {
|
||||
sources = append(sources, ipv4.String())
|
||||
sources = append(sources, ipv6.String())
|
||||
}
|
||||
|
||||
readinessServer := metrics.NewReadyServer(clientID, tracker)
|
||||
cliFlags := nonSecretCliFlags(log, c, nonSecretFlagsList)
|
||||
diagnosticHandler := diagnostic.NewDiagnosticHandler(
|
||||
log,
|
||||
0,
|
||||
|
@ -560,8 +512,8 @@ func StartServer(
|
|||
tunnelConfig.NamedTunnel.Credentials.TunnelID,
|
||||
clientID,
|
||||
tracker,
|
||||
c,
|
||||
nonSecretFlagsList,
|
||||
cliFlags,
|
||||
sources,
|
||||
)
|
||||
metricsConfig := metrics.Config{
|
||||
ReadyServer: readinessServer,
|
||||
|
@ -572,7 +524,7 @@ func StartServer(
|
|||
errC <- metrics.ServeMetrics(metricsListener, ctx, metricsConfig, log)
|
||||
}()
|
||||
|
||||
reconnectCh := make(chan supervisor.ReconnectSignal, c.Int(haConnectionsFlag))
|
||||
reconnectCh := make(chan supervisor.ReconnectSignal, c.Int(cfdflags.HaConnections))
|
||||
if c.IsSet("stdin-control") {
|
||||
log.Info().Msg("Enabling control through stdin")
|
||||
go stdinControl(reconnectCh, log)
|
||||
|
@ -609,8 +561,10 @@ func waitToShutdown(wg *sync.WaitGroup,
|
|||
log.Debug().Msg("Graceful shutdown signalled")
|
||||
if gracePeriod > 0 {
|
||||
// wait for either grace period or service termination
|
||||
ticker := time.NewTicker(gracePeriod)
|
||||
defer ticker.Stop()
|
||||
select {
|
||||
case <-time.Tick(gracePeriod):
|
||||
case <-ticker.C:
|
||||
case <-errC:
|
||||
}
|
||||
}
|
||||
|
@ -638,7 +592,7 @@ func waitToShutdown(wg *sync.WaitGroup,
|
|||
|
||||
func notifySystemd(waitForSignal *signal.Signal) {
|
||||
<-waitForSignal.Wait()
|
||||
daemon.SdNotify(false, "READY=1")
|
||||
_, _ = daemon.SdNotify(false, "READY=1")
|
||||
}
|
||||
|
||||
func writePidFile(waitForSignal *signal.Signal, pidPathname string, log *zerolog.Logger) {
|
||||
|
@ -690,31 +644,31 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
flags = append(flags, []cli.Flag{
|
||||
credentialsFileFlag,
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "is-autoupdated",
|
||||
Name: cfdflags.IsAutoUpdated,
|
||||
Usage: "Signal the new process that Cloudflare Tunnel connector has been autoupdated",
|
||||
Value: false,
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
|
||||
Name: "edge",
|
||||
Name: cfdflags.Edge,
|
||||
Usage: "Address of the Cloudflare tunnel server. Only works in Cloudflare's internal testing environment.",
|
||||
EnvVars: []string{"TUNNEL_EDGE"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "region",
|
||||
Name: cfdflags.Region,
|
||||
Usage: "Cloudflare Edge region to connect to. Omit or set to empty to connect to the global region.",
|
||||
EnvVars: []string{"TUNNEL_REGION"},
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "edge-ip-version",
|
||||
Name: cfdflags.EdgeIpVersion,
|
||||
Usage: "Cloudflare Edge IP address version to connect with. {4, 6, auto}",
|
||||
EnvVars: []string{"TUNNEL_EDGE_IP_VERSION"},
|
||||
Value: "4",
|
||||
Hidden: false,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "edge-bind-address",
|
||||
Name: cfdflags.EdgeBindAddress,
|
||||
Usage: "Bind to IP address for outgoing connections to Cloudflare Edge.",
|
||||
EnvVars: []string{"TUNNEL_EDGE_BIND_ADDRESS"},
|
||||
Hidden: false,
|
||||
|
@ -738,7 +692,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "lb-pool",
|
||||
Name: cfdflags.LBPool,
|
||||
Usage: "The name of a (new/existing) load balancing pool to add this origin to.",
|
||||
EnvVars: []string{"TUNNEL_LB_POOL"},
|
||||
Hidden: shouldHide,
|
||||
|
@ -762,21 +716,21 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "api-url",
|
||||
Name: cfdflags.ApiURL,
|
||||
Usage: "Base URL for Cloudflare API v4",
|
||||
EnvVars: []string{"TUNNEL_API_URL"},
|
||||
Value: "https://api.cloudflare.com/client/v4",
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||
Name: "metrics-update-freq",
|
||||
Name: cfdflags.MetricsUpdateFreq,
|
||||
Usage: "Frequency to update tunnel metrics",
|
||||
Value: time.Second * 5,
|
||||
EnvVars: []string{"TUNNEL_METRICS_UPDATE_FREQ"},
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
|
||||
Name: "tag",
|
||||
Name: cfdflags.Tag,
|
||||
Usage: "Custom tags used to identify this tunnel via added HTTP request headers to the origin, in format `KEY=VALUE`. Multiple tags may be specified.",
|
||||
EnvVars: []string{"TUNNEL_TAG"},
|
||||
Hidden: true,
|
||||
|
@ -795,64 +749,64 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{
|
||||
Name: "max-edge-addr-retries",
|
||||
Name: cfdflags.MaxEdgeAddrRetries,
|
||||
Usage: "Maximum number of times to retry on edge addrs before falling back to a lower protocol",
|
||||
Value: 8,
|
||||
Hidden: true,
|
||||
}),
|
||||
// Note TUN-3758 , we use Int because UInt is not supported with altsrc
|
||||
altsrc.NewIntFlag(&cli.IntFlag{
|
||||
Name: "retries",
|
||||
Name: cfdflags.Retries,
|
||||
Value: 5,
|
||||
Usage: "Maximum number of retries for connection/protocol errors.",
|
||||
EnvVars: []string{"TUNNEL_RETRIES"},
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{
|
||||
Name: haConnectionsFlag,
|
||||
Name: cfdflags.HaConnections,
|
||||
Value: 4,
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||
Name: rpcTimeout,
|
||||
Name: cfdflags.RpcTimeout,
|
||||
Value: 5 * time.Second,
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||
Name: writeStreamTimeout,
|
||||
Name: cfdflags.WriteStreamTimeout,
|
||||
EnvVars: []string{"TUNNEL_STREAM_WRITE_TIMEOUT"},
|
||||
Usage: "Use this option to add a stream write timeout for connections when writing towards the origin or edge. Default is 0 which disables the write timeout.",
|
||||
Value: 0 * time.Second,
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: quicDisablePathMTUDiscovery,
|
||||
Name: cfdflags.QuicDisablePathMTUDiscovery,
|
||||
EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"},
|
||||
Usage: "Use this option to disable PTMU discovery for QUIC connections. This will result in lower packet sizes. Not however, that this may cause instability for UDP proxying.",
|
||||
Value: false,
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{
|
||||
Name: quicConnLevelFlowControlLimit,
|
||||
Name: cfdflags.QuicConnLevelFlowControlLimit,
|
||||
EnvVars: []string{"TUNNEL_QUIC_CONN_LEVEL_FLOW_CONTROL_LIMIT"},
|
||||
Usage: "Use this option to change the connection-level flow control limit for QUIC transport.",
|
||||
Value: 30 * (1 << 20), // 30 MB
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewIntFlag(&cli.IntFlag{
|
||||
Name: quicStreamLevelFlowControlLimit,
|
||||
Name: cfdflags.QuicStreamLevelFlowControlLimit,
|
||||
EnvVars: []string{"TUNNEL_QUIC_STREAM_LEVEL_FLOW_CONTROL_LIMIT"},
|
||||
Usage: "Use this option to change the connection-level flow control limit for QUIC transport.",
|
||||
Value: 6 * (1 << 20), // 6 MB
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: connectorLabelFlag,
|
||||
Name: cfdflags.ConnectorLabel,
|
||||
Usage: "Use this option to give a meaningful label to a specific connector. When a tunnel starts up, a connector id unique to the tunnel is generated. This is a uuid. To make it easier to identify a connector, we will use the hostname of the machine the tunnel is running on along with the connector ID. This option exists if one wants to have more control over what their individual connectors are called.",
|
||||
Value: "",
|
||||
}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||
Name: "grace-period",
|
||||
Name: cfdflags.GracePeriod,
|
||||
Usage: "When cloudflared receives SIGINT/SIGTERM it will stop accepting new requests, wait for in-progress requests to terminate, then shutdown. Waiting for in-progress requests will timeout after this grace period, or when a second SIGTERM/SIGINT is received.",
|
||||
Value: time.Second * 30,
|
||||
EnvVars: []string{"TUNNEL_GRACE_PERIOD"},
|
||||
|
@ -888,14 +842,14 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
Value: false,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "name",
|
||||
Name: cfdflags.Name,
|
||||
Aliases: []string{"n"},
|
||||
EnvVars: []string{"TUNNEL_NAME"},
|
||||
Usage: "Stable name to identify the tunnel. Using this flag will create, route and run a tunnel. For production usage, execute each command separately",
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: uiFlag,
|
||||
Name: cfdflags.Ui,
|
||||
Usage: "(depreciated) Launch tunnel UI. Tunnel logs are scrollable via 'j', 'k', or arrow keys.",
|
||||
Value: false,
|
||||
Hidden: true,
|
||||
|
@ -913,11 +867,10 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
|
|||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "post-quantum",
|
||||
Name: cfdflags.PostQuantum,
|
||||
Usage: "When given creates an experimental post-quantum secure tunnel",
|
||||
Aliases: []string{"pq"},
|
||||
EnvVars: []string{"TUNNEL_POST_QUANTUM"},
|
||||
Hidden: FipsEnabled,
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "management-diagnostics",
|
||||
|
@ -942,27 +895,27 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag {
|
|||
Hidden: shouldHide,
|
||||
},
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: credentials.OriginCertFlag,
|
||||
Name: cfdflags.OriginCert,
|
||||
Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
|
||||
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
|
||||
Value: credentials.FindDefaultOriginCertPath(),
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||
Name: "autoupdate-freq",
|
||||
Name: cfdflags.AutoUpdateFreq,
|
||||
Usage: fmt.Sprintf("Autoupdate frequency. Default is %v.", updater.DefaultCheckUpdateFreq),
|
||||
Value: updater.DefaultCheckUpdateFreq,
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "no-autoupdate",
|
||||
Name: cfdflags.NoAutoUpdate,
|
||||
Usage: "Disable periodic check for updates, restarting the server with the new version.",
|
||||
EnvVars: []string{"NO_AUTOUPDATE"},
|
||||
Value: false,
|
||||
Hidden: shouldHide,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "metrics",
|
||||
Name: cfdflags.Metrics,
|
||||
Value: metrics.GetMetricsDefaultAddress(metrics.Runtime),
|
||||
Usage: fmt.Sprintf(
|
||||
`Listen address for metrics reporting. If no address is passed cloudflared will try to bind to %v.
|
||||
|
@ -1126,62 +1079,62 @@ func legacyTunnelFlag(msg string) string {
|
|||
func sshFlags(shouldHide bool) []cli.Flag {
|
||||
return []cli.Flag{
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: sshPortFlag,
|
||||
Name: cfdflags.SshPort,
|
||||
Usage: "Localhost port that cloudflared SSH server will run on",
|
||||
Value: "2222",
|
||||
EnvVars: []string{"LOCAL_SSH_PORT"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||
Name: sshIdleTimeoutFlag,
|
||||
Name: cfdflags.SshIdleTimeout,
|
||||
Usage: "Connection timeout after no activity",
|
||||
EnvVars: []string{"SSH_IDLE_TIMEOUT"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewDurationFlag(&cli.DurationFlag{
|
||||
Name: sshMaxTimeoutFlag,
|
||||
Name: cfdflags.SshMaxTimeout,
|
||||
Usage: "Absolute connection timeout",
|
||||
EnvVars: []string{"SSH_MAX_TIMEOUT"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: bucketNameFlag,
|
||||
Name: cfdflags.SshLogUploaderBucketName,
|
||||
Usage: "Bucket name of where to upload SSH logs",
|
||||
EnvVars: []string{"BUCKET_ID"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: regionNameFlag,
|
||||
Name: cfdflags.SshLogUploaderRegionName,
|
||||
Usage: "Region name of where to upload SSH logs",
|
||||
EnvVars: []string{"REGION_ID"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: secretIDFlag,
|
||||
Name: cfdflags.SshLogUploaderSecretID,
|
||||
Usage: "Secret ID of where to upload SSH logs",
|
||||
EnvVars: []string{"SECRET_ID"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: accessKeyIDFlag,
|
||||
Name: cfdflags.SshLogUploaderAccessKeyID,
|
||||
Usage: "Access Key ID of where to upload SSH logs",
|
||||
EnvVars: []string{"ACCESS_CLIENT_ID"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: sessionTokenIDFlag,
|
||||
Name: cfdflags.SshLogUploaderSessionTokenID,
|
||||
Usage: "Session Token to use in the configuration of SSH logs uploading",
|
||||
EnvVars: []string{"SESSION_TOKEN_ID"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: s3URLFlag,
|
||||
Name: cfdflags.SshLogUploaderS3URL,
|
||||
Usage: "S3 url of where to upload SSH logs",
|
||||
EnvVars: []string{"S3_URL"},
|
||||
Hidden: true,
|
||||
}),
|
||||
altsrc.NewPathFlag(&cli.PathFlag{
|
||||
Name: hostKeyPath,
|
||||
Name: cfdflags.HostKeyPath,
|
||||
Usage: "Absolute path of directory to save SSH host keys in",
|
||||
EnvVars: []string{"HOST_KEY_PATH"},
|
||||
Hidden: true,
|
||||
|
@ -1221,7 +1174,7 @@ func sshFlags(shouldHide bool) []cli.Flag {
|
|||
func configureProxyDNSFlags(shouldHide bool) []cli.Flag {
|
||||
return []cli.Flag{
|
||||
altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "proxy-dns",
|
||||
Name: cfdflags.ProxyDns,
|
||||
Usage: "Run a DNS over HTTPS proxy server.",
|
||||
EnvVars: []string{"TUNNEL_DNS"},
|
||||
Hidden: shouldHide,
|
||||
|
@ -1301,3 +1254,46 @@ reconnect [delay]
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func nonSecretCliFlags(log *zerolog.Logger, cli *cli.Context, flagInclusionList []string) map[string]string {
|
||||
flagsNames := cli.FlagNames()
|
||||
flags := make(map[string]string, len(flagsNames))
|
||||
|
||||
for _, flag := range flagsNames {
|
||||
value := cli.String(flag)
|
||||
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
isIncluded := isFlagIncluded(flagInclusionList, flag)
|
||||
if !isIncluded {
|
||||
continue
|
||||
}
|
||||
|
||||
switch flag {
|
||||
case cfdflags.LogDirectory, cfdflags.LogFile:
|
||||
{
|
||||
absolute, err := filepath.Abs(value)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("could not convert %s path to absolute", flag)
|
||||
} else {
|
||||
flags[flag] = absolute
|
||||
}
|
||||
}
|
||||
default:
|
||||
flags[flag] = value
|
||||
}
|
||||
}
|
||||
return flags
|
||||
}
|
||||
|
||||
func isFlagIncluded(flagInclusionList []string, flag string) bool {
|
||||
for _, include := range flagInclusionList {
|
||||
if include == flag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"golang.org/x/term"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||
|
@ -33,26 +34,27 @@ import (
|
|||
const (
|
||||
secretValue = "*****"
|
||||
icmpFunnelTimeout = time.Second * 10
|
||||
fedRampRegion = "fed" // const string denoting the region used to connect to FEDRamp servers
|
||||
)
|
||||
|
||||
var (
|
||||
developerPortal = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup"
|
||||
serviceUrl = developerPortal + "/tunnel-guide/local/as-a-service/"
|
||||
argumentsUrl = developerPortal + "/tunnel-guide/local/local-management/arguments/"
|
||||
|
||||
secretFlags = [2]*altsrc.StringFlag{credentialsContentsFlag, tunnelTokenFlag}
|
||||
|
||||
configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"}
|
||||
)
|
||||
|
||||
func generateRandomClientID(log *zerolog.Logger) (string, error) {
|
||||
u, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
log.Error().Msgf("couldn't create UUID for client ID %s", err)
|
||||
return "", err
|
||||
configFlags = []string{
|
||||
flags.AutoUpdateFreq,
|
||||
flags.NoAutoUpdate,
|
||||
flags.Retries,
|
||||
flags.Protocol,
|
||||
flags.LogLevel,
|
||||
flags.TransportLogLevel,
|
||||
flags.OriginCert,
|
||||
flags.Metrics,
|
||||
flags.MetricsUpdateFreq,
|
||||
flags.EdgeIpVersion,
|
||||
flags.EdgeBindAddress,
|
||||
flags.MaxActiveFlows,
|
||||
}
|
||||
return u.String(), nil
|
||||
}
|
||||
)
|
||||
|
||||
func logClientOptions(c *cli.Context, log *zerolog.Logger) {
|
||||
flags := make(map[string]interface{})
|
||||
|
@ -109,8 +111,8 @@ func isSecretEnvVar(key string) bool {
|
|||
}
|
||||
|
||||
func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.TunnelProperties) bool {
|
||||
return c.IsSet("proxy-dns") &&
|
||||
!(c.IsSet("name") || // adhoc-named tunnel
|
||||
return c.IsSet(flags.ProxyDns) &&
|
||||
!(c.IsSet(flags.Name) || // adhoc-named tunnel
|
||||
c.IsSet(ingress.HelloWorldFlag) || // quick or named tunnel
|
||||
namedTunnel != nil) // named tunnel
|
||||
}
|
||||
|
@ -128,29 +130,21 @@ func prepareTunnelConfig(
|
|||
return nil, nil, errors.Wrap(err, "can't generate connector UUID")
|
||||
}
|
||||
log.Info().Msgf("Generated Connector ID: %s", clientID)
|
||||
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
|
||||
tags, err := NewTagSliceFromCLI(c.StringSlice(flags.Tag))
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Tag parse failure")
|
||||
return nil, nil, errors.Wrap(err, "Tag parse failure")
|
||||
}
|
||||
tags = append(tags, pogs.Tag{Name: "ID", Value: clientID.String()})
|
||||
|
||||
transportProtocol := c.String("protocol")
|
||||
transportProtocol := c.String(flags.Protocol)
|
||||
isPostQuantumEnforced := c.Bool(flags.PostQuantum)
|
||||
|
||||
clientFeatures := features.Dedup(append(c.StringSlice("features"), features.DefaultFeatures...))
|
||||
|
||||
staticFeatures := features.StaticFeatures{}
|
||||
if c.Bool("post-quantum") {
|
||||
if FipsEnabled {
|
||||
return nil, nil, fmt.Errorf("post-quantum not supported in FIPS mode")
|
||||
}
|
||||
pqMode := features.PostQuantumStrict
|
||||
staticFeatures.PostQuantumMode = &pqMode
|
||||
}
|
||||
featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, staticFeatures, log)
|
||||
featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, c.StringSlice("features"), c.Bool("post-quantum"), log)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "Failed to create feature selector")
|
||||
}
|
||||
clientFeatures := featureSelector.ClientFeatures()
|
||||
pqMode := featureSelector.PostQuantumMode()
|
||||
if pqMode == features.PostQuantumStrict {
|
||||
// Error if the user tries to force a non-quic transport protocol
|
||||
|
@ -158,12 +152,6 @@ func prepareTunnelConfig(
|
|||
return nil, nil, fmt.Errorf("post-quantum is only supported with the quic transport")
|
||||
}
|
||||
transportProtocol = connection.QUIC.String()
|
||||
clientFeatures = append(clientFeatures, features.FeaturePostQuantum)
|
||||
|
||||
log.Info().Msgf(
|
||||
"Using hybrid post-quantum key agreement %s",
|
||||
supervisor.PQKexName,
|
||||
)
|
||||
}
|
||||
|
||||
namedTunnel.Client = pogs.ClientInfo{
|
||||
|
@ -178,7 +166,7 @@ func prepareTunnelConfig(
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
protocolSelector, err := connection.NewProtocolSelector(transportProtocol, namedTunnel.Credentials.AccountTag, c.IsSet(TunnelTokenFlag), c.Bool("post-quantum"), edgediscovery.ProtocolPercentage, connection.ResolveTTL, log)
|
||||
protocolSelector, err := connection.NewProtocolSelector(transportProtocol, namedTunnel.Credentials.AccountTag, c.IsSet(TunnelTokenFlag), isPostQuantumEnforced, edgediscovery.ProtocolPercentage, connection.ResolveTTL, log)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -204,11 +192,11 @@ func prepareTunnelConfig(
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
edgeIPVersion, err := parseConfigIPVersion(c.String("edge-ip-version"))
|
||||
edgeIPVersion, err := parseConfigIPVersion(c.String(flags.EdgeIpVersion))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
edgeBindAddr, err := parseConfigBindAddress(c.String("edge-bind-address"))
|
||||
edgeBindAddr, err := parseConfigBindAddress(c.String(flags.EdgeBindAddress))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -221,36 +209,50 @@ func prepareTunnelConfig(
|
|||
log.Warn().Str("edgeIPVersion", edgeIPVersion.String()).Err(err).Msg("Overriding edge-ip-version")
|
||||
}
|
||||
|
||||
region := c.String(flags.Region)
|
||||
endpoint := namedTunnel.Credentials.Endpoint
|
||||
var resolvedRegion string
|
||||
// set resolvedRegion to either the region passed as argument
|
||||
// or to the endpoint in the credentials.
|
||||
// Region and endpoint are interchangeable
|
||||
if region != "" && endpoint != "" {
|
||||
return nil, nil, fmt.Errorf("region provided with a token that has an endpoint")
|
||||
} else if region != "" {
|
||||
resolvedRegion = region
|
||||
} else if endpoint != "" {
|
||||
resolvedRegion = endpoint
|
||||
}
|
||||
|
||||
tunnelConfig := &supervisor.TunnelConfig{
|
||||
GracePeriod: gracePeriod,
|
||||
ReplaceExisting: c.Bool("force"),
|
||||
ReplaceExisting: c.Bool(flags.Force),
|
||||
OSArch: info.OSArch(),
|
||||
ClientID: clientID.String(),
|
||||
EdgeAddrs: c.StringSlice("edge"),
|
||||
Region: c.String("region"),
|
||||
EdgeAddrs: c.StringSlice(flags.Edge),
|
||||
Region: resolvedRegion,
|
||||
EdgeIPVersion: edgeIPVersion,
|
||||
EdgeBindAddr: edgeBindAddr,
|
||||
HAConnections: c.Int(haConnectionsFlag),
|
||||
IsAutoupdated: c.Bool("is-autoupdated"),
|
||||
LBPool: c.String("lb-pool"),
|
||||
HAConnections: c.Int(flags.HaConnections),
|
||||
IsAutoupdated: c.Bool(flags.IsAutoUpdated),
|
||||
LBPool: c.String(flags.LBPool),
|
||||
Tags: tags,
|
||||
Log: log,
|
||||
LogTransport: logTransport,
|
||||
Observer: observer,
|
||||
ReportedVersion: info.Version(),
|
||||
// Note TUN-3758 , we use Int because UInt is not supported with altsrc
|
||||
Retries: uint(c.Int("retries")),
|
||||
Retries: uint(c.Int(flags.Retries)), // nolint: gosec
|
||||
RunFromTerminal: isRunningFromTerminal(),
|
||||
NamedTunnel: namedTunnel,
|
||||
ProtocolSelector: protocolSelector,
|
||||
EdgeTLSConfigs: edgeTLSConfigs,
|
||||
FeatureSelector: featureSelector,
|
||||
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
|
||||
RPCTimeout: c.Duration(rpcTimeout),
|
||||
WriteStreamTimeout: c.Duration(writeStreamTimeout),
|
||||
DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
|
||||
QUICConnectionLevelFlowControlLimit: c.Uint64(quicConnLevelFlowControlLimit),
|
||||
QUICStreamLevelFlowControlLimit: c.Uint64(quicStreamLevelFlowControlLimit),
|
||||
MaxEdgeAddrRetries: uint8(c.Int(flags.MaxEdgeAddrRetries)), // nolint: gosec
|
||||
RPCTimeout: c.Duration(flags.RpcTimeout),
|
||||
WriteStreamTimeout: c.Duration(flags.WriteStreamTimeout),
|
||||
DisableQUICPathMTUDiscovery: c.Bool(flags.QuicDisablePathMTUDiscovery),
|
||||
QUICConnectionLevelFlowControlLimit: c.Uint64(flags.QuicConnLevelFlowControlLimit),
|
||||
QUICStreamLevelFlowControlLimit: c.Uint64(flags.QuicStreamLevelFlowControlLimit),
|
||||
}
|
||||
icmpRouter, err := newICMPRouter(c, log)
|
||||
if err != nil {
|
||||
|
@ -262,7 +264,7 @@ func prepareTunnelConfig(
|
|||
Ingress: &ingressRules,
|
||||
WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting),
|
||||
ConfigurationFlags: parseConfigFlags(c),
|
||||
WriteTimeout: c.Duration(writeStreamTimeout),
|
||||
WriteTimeout: tunnelConfig.WriteStreamTimeout,
|
||||
}
|
||||
return tunnelConfig, orchestratorConfig, nil
|
||||
}
|
||||
|
@ -280,9 +282,9 @@ func parseConfigFlags(c *cli.Context) map[string]string {
|
|||
}
|
||||
|
||||
func gracePeriod(c *cli.Context) (time.Duration, error) {
|
||||
period := c.Duration("grace-period")
|
||||
period := c.Duration(flags.GracePeriod)
|
||||
if period > connection.MaxGracePeriod {
|
||||
return time.Duration(0), fmt.Errorf("grace-period must be equal or less than %v", connection.MaxGracePeriod)
|
||||
return time.Duration(0), fmt.Errorf("%s must be equal or less than %v", flags.GracePeriod, connection.MaxGracePeriod)
|
||||
}
|
||||
return period, nil
|
||||
}
|
||||
|
@ -352,20 +354,9 @@ func adjustIPVersionByBindAddress(ipVersion allregions.ConfigIPVersion, ip net.I
|
|||
}
|
||||
|
||||
func newICMPRouter(c *cli.Context, logger *zerolog.Logger) (ingress.ICMPRouterServer, error) {
|
||||
ipv4Src, err := determineICMPv4Src(c.String("icmpv4-src"), logger)
|
||||
ipv4Src, ipv6Src, err := determineICMPSources(c, logger)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to determine IPv4 source address for ICMP proxy")
|
||||
}
|
||||
logger.Info().Msgf("ICMP proxy will use %s as source for IPv4", ipv4Src)
|
||||
|
||||
ipv6Src, zone, err := determineICMPv6Src(c.String("icmpv6-src"), logger, ipv4Src)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to determine IPv6 source address for ICMP proxy")
|
||||
}
|
||||
if zone != "" {
|
||||
logger.Info().Msgf("ICMP proxy will use %s in zone %s as source for IPv6", ipv6Src, zone)
|
||||
} else {
|
||||
logger.Info().Msgf("ICMP proxy will use %s as source for IPv6", ipv6Src)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
icmpRouter, err := ingress.NewICMPRouter(ipv4Src, ipv6Src, logger, icmpFunnelTimeout)
|
||||
|
@ -375,6 +366,28 @@ func newICMPRouter(c *cli.Context, logger *zerolog.Logger) (ingress.ICMPRouterSe
|
|||
return icmpRouter, nil
|
||||
}
|
||||
|
||||
func determineICMPSources(c *cli.Context, logger *zerolog.Logger) (netip.Addr, netip.Addr, error) {
|
||||
ipv4Src, err := determineICMPv4Src(c.String(flags.ICMPV4Src), logger)
|
||||
if err != nil {
|
||||
return netip.Addr{}, netip.Addr{}, errors.Wrap(err, "failed to determine IPv4 source address for ICMP proxy")
|
||||
}
|
||||
|
||||
logger.Info().Msgf("ICMP proxy will use %s as source for IPv4", ipv4Src)
|
||||
|
||||
ipv6Src, zone, err := determineICMPv6Src(c.String(flags.ICMPV6Src), logger, ipv4Src)
|
||||
if err != nil {
|
||||
return netip.Addr{}, netip.Addr{}, errors.Wrap(err, "failed to determine IPv6 source address for ICMP proxy")
|
||||
}
|
||||
|
||||
if zone != "" {
|
||||
logger.Info().Msgf("ICMP proxy will use %s in zone %s as source for IPv6", ipv6Src, zone)
|
||||
} else {
|
||||
logger.Info().Msgf("ICMP proxy will use %s as source for IPv6", ipv6Src)
|
||||
}
|
||||
|
||||
return ipv4Src, ipv6Src, nil
|
||||
}
|
||||
|
||||
func determineICMPv4Src(userDefinedSrc string, logger *zerolog.Logger) (netip.Addr, error) {
|
||||
if userDefinedSrc != "" {
|
||||
addr, err := netip.ParseAddr(userDefinedSrc)
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/credentials"
|
||||
|
||||
|
@ -57,7 +58,7 @@ func newSearchByID(id uuid.UUID, c *cli.Context, log *zerolog.Logger, fs fileSys
|
|||
}
|
||||
|
||||
func (s searchByID) Path() (string, error) {
|
||||
originCertPath := s.c.String(credentials.OriginCertFlag)
|
||||
originCertPath := s.c.String(cfdflags.OriginCert)
|
||||
originCertLog := s.log.With().
|
||||
Str("originCertPath", originCertPath).
|
||||
Logger()
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
package tunnel
|
||||
|
||||
var FipsEnabled bool
|
|
@ -20,7 +20,31 @@ import (
|
|||
|
||||
const (
|
||||
baseLoginURL = "https://dash.cloudflare.com/argotunnel"
|
||||
callbackStoreURL = "https://login.cloudflareaccess.org/"
|
||||
callbackURL = "https://login.cloudflareaccess.org/"
|
||||
// For now these are the same but will change in the future once we know which URLs to use (TUN-8872)
|
||||
fedBaseLoginURL = "https://dash.cloudflare.com/argotunnel"
|
||||
fedCallbackStoreURL = "https://login.cloudflareaccess.org/"
|
||||
fedRAMPParamName = "fedramp"
|
||||
loginURLParamName = "loginURL"
|
||||
callbackURLParamName = "callbackURL"
|
||||
)
|
||||
|
||||
var (
|
||||
loginURL = &cli.StringFlag{
|
||||
Name: loginURLParamName,
|
||||
Value: baseLoginURL,
|
||||
Usage: "The URL used to login (default is https://dash.cloudflare.com/argotunnel)",
|
||||
}
|
||||
callbackStore = &cli.StringFlag{
|
||||
Name: callbackURLParamName,
|
||||
Value: callbackURL,
|
||||
Usage: "The URL used for the callback (default is https://login.cloudflareaccess.org/)",
|
||||
}
|
||||
fedramp = &cli.BoolFlag{
|
||||
Name: fedRAMPParamName,
|
||||
Aliases: []string{"f"},
|
||||
Usage: "Login with FedRAMP High environment.",
|
||||
}
|
||||
)
|
||||
|
||||
func buildLoginSubcommand(hidden bool) *cli.Command {
|
||||
|
@ -30,6 +54,11 @@ func buildLoginSubcommand(hidden bool) *cli.Command {
|
|||
Usage: "Generate a configuration file with your login details",
|
||||
ArgsUsage: " ",
|
||||
Hidden: hidden,
|
||||
Flags: []cli.Flag{
|
||||
loginURL,
|
||||
callbackStore,
|
||||
fedramp,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -38,15 +67,25 @@ func login(c *cli.Context) error {
|
|||
|
||||
path, ok, err := checkForExistingCert()
|
||||
if ok {
|
||||
fmt.Fprintf(os.Stdout, "You have an existing certificate at %s which login would overwrite.\nIf this is intentional, please move or delete that file then run this command again.\n", path)
|
||||
log.Error().Err(err).Msgf("You have an existing certificate at %s which login would overwrite.\nIf this is intentional, please move or delete that file then run this command again.\n", path)
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loginURL, err := url.Parse(baseLoginURL)
|
||||
var (
|
||||
baseloginURL = c.String(loginURLParamName)
|
||||
callbackStoreURL = c.String(callbackURLParamName)
|
||||
)
|
||||
|
||||
isFEDRamp := c.Bool(fedRAMPParamName)
|
||||
if isFEDRamp {
|
||||
baseloginURL = fedBaseLoginURL
|
||||
callbackStoreURL = fedCallbackStoreURL
|
||||
}
|
||||
|
||||
loginURL, err := url.Parse(baseloginURL)
|
||||
if err != nil {
|
||||
// shouldn't happen, URL is hardcoded
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -61,7 +100,23 @@ func login(c *cli.Context) error {
|
|||
log,
|
||||
)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to write the certificate due to the following error:\n%v\n\nYour browser will download the certificate instead. You will have to manually\ncopy it to the following path:\n\n%s\n", err, path)
|
||||
log.Error().Err(err).Msgf("Failed to write the certificate.\n\nYour browser will download the certificate instead. You will have to manually\ncopy it to the following path:\n\n%s\n", path)
|
||||
return err
|
||||
}
|
||||
|
||||
cert, err := credentials.DecodeOriginCert(resourceData)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to decode origin certificate")
|
||||
return err
|
||||
}
|
||||
|
||||
if isFEDRamp {
|
||||
cert.Endpoint = credentials.FedEndpoint
|
||||
}
|
||||
|
||||
resourceData, err = cert.EncodeOriginCert()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to encode origin certificate")
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -69,7 +124,7 @@ func login(c *cli.Context) error {
|
|||
return errors.Wrap(err, fmt.Sprintf("error writing cert to %s", path))
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stdout, "You have successfully logged in.\nIf you wish to copy your credentials to a server, they have been saved to:\n%s\n", path)
|
||||
log.Info().Msgf("You have successfully logged in.\nIf you wish to copy your credentials to a server, they have been saved to:\n%s\n", path)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
)
|
||||
|
||||
|
@ -82,13 +83,13 @@ func RunQuickTunnel(sc *subcommandContext) error {
|
|||
sc.log.Info().Msg(line)
|
||||
}
|
||||
|
||||
if !sc.c.IsSet("protocol") {
|
||||
sc.c.Set("protocol", "quic")
|
||||
if !sc.c.IsSet(flags.Protocol) {
|
||||
_ = sc.c.Set(flags.Protocol, "quic")
|
||||
}
|
||||
|
||||
// Override the number of connections used. Quick tunnels shouldn't be used for production usage,
|
||||
// so, use a single connection instead.
|
||||
sc.c.Set(haConnectionsFlag, "1")
|
||||
_ = sc.c.Set(flags.HaConnections, "1")
|
||||
return StartServer(
|
||||
sc.c,
|
||||
buildInfo,
|
||||
|
|
|
@ -9,22 +9,26 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cfapi"
|
||||
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/credentials"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
)
|
||||
|
||||
type errInvalidJSONCredential struct {
|
||||
const fedRampBaseApiURL = "https://api.fed.cloudflare.com/client/v4"
|
||||
|
||||
type invalidJSONCredentialError struct {
|
||||
err error
|
||||
path string
|
||||
}
|
||||
|
||||
func (e errInvalidJSONCredential) Error() string {
|
||||
func (e invalidJSONCredentialError) Error() string {
|
||||
return "Invalid JSON when parsing tunnel credentials file"
|
||||
}
|
||||
|
||||
|
@ -51,8 +55,13 @@ func newSubcommandContext(c *cli.Context) (*subcommandContext, error) {
|
|||
// Returns something that can find the given tunnel's credentials file.
|
||||
func (sc *subcommandContext) credentialFinder(tunnelID uuid.UUID) CredFinder {
|
||||
if path := sc.c.String(CredFileFlag); path != "" {
|
||||
// Expand path if CredFileFlag contains `~`
|
||||
absPath, err := homedir.Expand(path)
|
||||
if err != nil {
|
||||
return newStaticPath(path, sc.fs)
|
||||
}
|
||||
return newStaticPath(absPath, sc.fs)
|
||||
}
|
||||
return newSearchByID(tunnelID, sc.c, sc.log, sc.fs)
|
||||
}
|
||||
|
||||
|
@ -64,7 +73,16 @@ func (sc *subcommandContext) client() (cfapi.Client, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sc.tunnelstoreClient, err = cred.Client(sc.c.String("api-url"), buildInfo.UserAgent(), sc.log)
|
||||
|
||||
var apiURL string
|
||||
if cred.IsFEDEndpoint() {
|
||||
sc.log.Info().Str("api-url", fedRampBaseApiURL).Msg("using fedramp base api")
|
||||
apiURL = fedRampBaseApiURL
|
||||
} else {
|
||||
apiURL = sc.c.String(cfdflags.ApiURL)
|
||||
}
|
||||
|
||||
sc.tunnelstoreClient, err = cred.Client(apiURL, buildInfo.UserAgent(), sc.log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -73,7 +91,7 @@ func (sc *subcommandContext) client() (cfapi.Client, error) {
|
|||
|
||||
func (sc *subcommandContext) credential() (*credentials.User, error) {
|
||||
if sc.userCredential == nil {
|
||||
uc, err := credentials.Read(sc.c.String(credentials.OriginCertFlag), sc.log)
|
||||
uc, err := credentials.Read(sc.c.String(cfdflags.OriginCert), sc.log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -94,13 +112,13 @@ func (sc *subcommandContext) readTunnelCredentials(credFinder CredFinder) (conne
|
|||
|
||||
var credentials connection.Credentials
|
||||
if err = json.Unmarshal(body, &credentials); err != nil {
|
||||
if strings.HasSuffix(filePath, ".pem") {
|
||||
if filepath.Ext(filePath) == ".pem" {
|
||||
return connection.Credentials{}, fmt.Errorf("The tunnel credentials file should be .json but you gave a .pem. " +
|
||||
"The tunnel credentials file was originally created by `cloudflared tunnel create`. " +
|
||||
"You may have accidentally used the filepath to cert.pem, which is generated by `cloudflared tunnel " +
|
||||
"login`.")
|
||||
}
|
||||
return connection.Credentials{}, errInvalidJSONCredential{path: filePath, err: err}
|
||||
return connection.Credentials{}, invalidJSONCredentialError{path: filePath, err: err}
|
||||
}
|
||||
return credentials, nil
|
||||
}
|
||||
|
@ -122,7 +140,7 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Couldn't decode tunnel secret from base64")
|
||||
}
|
||||
tunnelSecret = []byte(decodedSecret)
|
||||
tunnelSecret = decodedSecret
|
||||
if len(tunnelSecret) < 32 {
|
||||
return nil, errors.New("Decoded tunnel secret must be at least 32 bytes long")
|
||||
}
|
||||
|
@ -160,7 +178,7 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec
|
|||
errorLines = append(errorLines, fmt.Sprintf("Cloudflared tried to delete the tunnel for you, but encountered an error. You should use `cloudflared tunnel delete %v` to delete the tunnel yourself, because the tunnel can't be run without the tunnelfile.", tunnel.ID))
|
||||
errorLines = append(errorLines, fmt.Sprintf("The delete tunnel error is: %v", deleteErr))
|
||||
} else {
|
||||
errorLines = append(errorLines, fmt.Sprintf("The tunnel was deleted, because the tunnel can't be run without the credentials file"))
|
||||
errorLines = append(errorLines, "The tunnel was deleted, because the tunnel can't be run without the credentials file")
|
||||
}
|
||||
errorMsg := strings.Join(errorLines, "\n")
|
||||
return nil, errors.New(errorMsg)
|
||||
|
@ -189,7 +207,7 @@ func (sc *subcommandContext) list(filter *cfapi.TunnelFilter) ([]*cfapi.Tunnel,
|
|||
}
|
||||
|
||||
func (sc *subcommandContext) delete(tunnelIDs []uuid.UUID) error {
|
||||
forceFlagSet := sc.c.Bool("force")
|
||||
forceFlagSet := sc.c.Bool(cfdflags.Force)
|
||||
|
||||
client, err := sc.client()
|
||||
if err != nil {
|
||||
|
@ -229,7 +247,7 @@ func (sc *subcommandContext) findCredentials(tunnelID uuid.UUID) (connection.Cre
|
|||
var err error
|
||||
if credentialsContents := sc.c.String(CredContentsFlag); credentialsContents != "" {
|
||||
if err = json.Unmarshal([]byte(credentialsContents), &credentials); err != nil {
|
||||
err = errInvalidJSONCredential{path: "TUNNEL_CRED_CONTENTS", err: err}
|
||||
err = invalidJSONCredentialError{path: "TUNNEL_CRED_CONTENTS", err: err}
|
||||
}
|
||||
} else {
|
||||
credFinder := sc.credentialFinder(tunnelID)
|
||||
|
@ -245,7 +263,7 @@ func (sc *subcommandContext) findCredentials(tunnelID uuid.UUID) (connection.Cre
|
|||
func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
|
||||
credentials, err := sc.findCredentials(tunnelID)
|
||||
if err != nil {
|
||||
if e, ok := err.(errInvalidJSONCredential); ok {
|
||||
if e, ok := err.(invalidJSONCredentialError); ok {
|
||||
sc.log.Error().Msgf("The credentials file at %s contained invalid JSON. This is probably caused by passing the wrong filepath. Reminder: the credentials file is a .json file created via `cloudflared tunnel create`.", e.path)
|
||||
sc.log.Error().Msgf("Invalid JSON when parsing credentials file: %s", e.err.Error())
|
||||
}
|
||||
|
|
|
@ -16,18 +16,22 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/urfave/cli/v2"
|
||||
"github.com/urfave/cli/v2/altsrc"
|
||||
"golang.org/x/net/idna"
|
||||
yaml "gopkg.in/yaml.v3"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cfapi"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/diagnostic"
|
||||
"github.com/cloudflare/cloudflared/fips"
|
||||
"github.com/cloudflare/cloudflared/metrics"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -37,7 +41,15 @@ const (
|
|||
CredFileFlag = "credentials-file"
|
||||
CredContentsFlag = "credentials-contents"
|
||||
TunnelTokenFlag = "token"
|
||||
TunnelTokenFileFlag = "token-file"
|
||||
overwriteDNSFlagName = "overwrite-dns"
|
||||
noDiagLogsFlagName = "no-diag-logs"
|
||||
noDiagMetricsFlagName = "no-diag-metrics"
|
||||
noDiagSystemFlagName = "no-diag-system"
|
||||
noDiagRuntimeFlagName = "no-diag-runtime"
|
||||
noDiagNetworkFlagName = "no-diag-network"
|
||||
diagContainerIDFlagName = "diag-container-id"
|
||||
diagPodFlagName = "diag-pod-id"
|
||||
|
||||
LogFieldTunnelID = "tunnelID"
|
||||
)
|
||||
|
@ -49,7 +61,7 @@ var (
|
|||
Usage: "Include deleted tunnels in the list",
|
||||
}
|
||||
listNameFlag = &cli.StringFlag{
|
||||
Name: "name",
|
||||
Name: flags.Name,
|
||||
Aliases: []string{"n"},
|
||||
Usage: "List tunnels with the given `NAME`",
|
||||
}
|
||||
|
@ -97,7 +109,7 @@ var (
|
|||
EnvVars: []string{"TUNNEL_LIST_INVERT_SORT"},
|
||||
}
|
||||
featuresFlag = altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
|
||||
Name: "features",
|
||||
Name: flags.Features,
|
||||
Aliases: []string{"F"},
|
||||
Usage: "Opt into various features that are still being developed or tested.",
|
||||
})
|
||||
|
@ -115,18 +127,23 @@ var (
|
|||
})
|
||||
tunnelTokenFlag = altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: TunnelTokenFlag,
|
||||
Usage: "The Tunnel token. When provided along with credentials, this will take precedence.",
|
||||
Usage: "The Tunnel token. When provided along with credentials, this will take precedence. Also takes precedence over token-file",
|
||||
EnvVars: []string{"TUNNEL_TOKEN"},
|
||||
})
|
||||
tunnelTokenFileFlag = altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: TunnelTokenFileFlag,
|
||||
Usage: "Filepath at which to read the tunnel token. When provided along with credentials, this will take precedence.",
|
||||
EnvVars: []string{"TUNNEL_TOKEN_FILE"},
|
||||
})
|
||||
forceDeleteFlag = &cli.BoolFlag{
|
||||
Name: "force",
|
||||
Name: flags.Force,
|
||||
Aliases: []string{"f"},
|
||||
Usage: "Deletes a tunnel even if tunnel is connected and it has dependencies associated to it. (eg. IP routes)." +
|
||||
" It is not possible to delete tunnels that have connections or non-deleted dependencies, without this flag.",
|
||||
EnvVars: []string{"TUNNEL_RUN_FORCE_OVERWRITE"},
|
||||
}
|
||||
selectProtocolFlag = altsrc.NewStringFlag(&cli.StringFlag{
|
||||
Name: "protocol",
|
||||
Name: flags.Protocol,
|
||||
Value: connection.AutoSelectFlag,
|
||||
Aliases: []string{"p"},
|
||||
Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", connection.AvailableProtocolFlagMessage),
|
||||
|
@ -134,11 +151,11 @@ var (
|
|||
Hidden: true,
|
||||
})
|
||||
postQuantumFlag = altsrc.NewBoolFlag(&cli.BoolFlag{
|
||||
Name: "post-quantum",
|
||||
Name: flags.PostQuantum,
|
||||
Usage: "When given creates an experimental post-quantum secure tunnel",
|
||||
Aliases: []string{"pq"},
|
||||
EnvVars: []string{"TUNNEL_POST_QUANTUM"},
|
||||
Hidden: FipsEnabled,
|
||||
Hidden: fips.IsFipsEnabled(),
|
||||
})
|
||||
sortInfoByFlag = &cli.StringFlag{
|
||||
Name: "sort-by",
|
||||
|
@ -170,15 +187,60 @@ var (
|
|||
EnvVars: []string{"TUNNEL_CREATE_SECRET"},
|
||||
}
|
||||
icmpv4SrcFlag = &cli.StringFlag{
|
||||
Name: "icmpv4-src",
|
||||
Name: flags.ICMPV4Src,
|
||||
Usage: "Source address to send/receive ICMPv4 messages. If not provided cloudflared will dial a local address to determine the source IP or fallback to 0.0.0.0.",
|
||||
EnvVars: []string{"TUNNEL_ICMPV4_SRC"},
|
||||
}
|
||||
icmpv6SrcFlag = &cli.StringFlag{
|
||||
Name: "icmpv6-src",
|
||||
Name: flags.ICMPV6Src,
|
||||
Usage: "Source address and the interface name to send/receive ICMPv6 messages. If not provided cloudflared will dial a local address to determine the source IP or fallback to ::.",
|
||||
EnvVars: []string{"TUNNEL_ICMPV6_SRC"},
|
||||
}
|
||||
metricsFlag = &cli.StringFlag{
|
||||
Name: flags.Metrics,
|
||||
Usage: "The metrics server address i.e.: 127.0.0.1:12345. If your instance is running in a Docker/Kubernetes environment you need to setup port forwarding for your application.",
|
||||
Value: "",
|
||||
}
|
||||
diagContainerFlag = &cli.StringFlag{
|
||||
Name: diagContainerIDFlagName,
|
||||
Usage: "Container ID or Name to collect logs from",
|
||||
Value: "",
|
||||
}
|
||||
diagPodFlag = &cli.StringFlag{
|
||||
Name: diagPodFlagName,
|
||||
Usage: "Kubernetes POD to collect logs from",
|
||||
Value: "",
|
||||
}
|
||||
noDiagLogsFlag = &cli.BoolFlag{
|
||||
Name: noDiagLogsFlagName,
|
||||
Usage: "Log collection will not be performed",
|
||||
Value: false,
|
||||
}
|
||||
noDiagMetricsFlag = &cli.BoolFlag{
|
||||
Name: noDiagMetricsFlagName,
|
||||
Usage: "Metric collection will not be performed",
|
||||
Value: false,
|
||||
}
|
||||
noDiagSystemFlag = &cli.BoolFlag{
|
||||
Name: noDiagSystemFlagName,
|
||||
Usage: "System information collection will not be performed",
|
||||
Value: false,
|
||||
}
|
||||
noDiagRuntimeFlag = &cli.BoolFlag{
|
||||
Name: noDiagRuntimeFlagName,
|
||||
Usage: "Runtime information collection will not be performed",
|
||||
Value: false,
|
||||
}
|
||||
noDiagNetworkFlag = &cli.BoolFlag{
|
||||
Name: noDiagNetworkFlagName,
|
||||
Usage: "Network diagnostics won't be performed",
|
||||
Value: false,
|
||||
}
|
||||
maxActiveFlowsFlag = &cli.Uint64Flag{
|
||||
Name: flags.MaxActiveFlows,
|
||||
Usage: "Overrides the remote configuration for max active private network flows (TCP/UDP) that this cloudflared instance supports",
|
||||
EnvVars: []string{"TUNNEL_MAX_ACTIVE_FLOWS"},
|
||||
}
|
||||
)
|
||||
|
||||
func buildCreateCommand() *cli.Command {
|
||||
|
@ -281,7 +343,7 @@ func listCommand(c *cli.Context) error {
|
|||
if !c.Bool("show-deleted") {
|
||||
filter.NoDeleted()
|
||||
}
|
||||
if name := c.String("name"); name != "" {
|
||||
if name := c.String(flags.Name); name != "" {
|
||||
filter.ByName(name)
|
||||
}
|
||||
if namePrefix := c.String("name-prefix"); namePrefix != "" {
|
||||
|
@ -375,7 +437,6 @@ func formatAndPrintTunnelList(tunnels []*cfapi.Tunnel, showRecentlyDisconnected
|
|||
}
|
||||
|
||||
func fmtConnections(connections []cfapi.Connection, showRecentlyDisconnected bool) string {
|
||||
|
||||
// Count connections per colo
|
||||
numConnsPerColo := make(map[string]uint, len(connections))
|
||||
for _, connection := range connections {
|
||||
|
@ -392,7 +453,7 @@ func fmtConnections(connections []cfapi.Connection, showRecentlyDisconnected boo
|
|||
sort.Strings(sortedColos)
|
||||
|
||||
// Map each colo to its frequency, combine into output string.
|
||||
var output []string
|
||||
output := make([]string, 0, len(sortedColos))
|
||||
for _, coloName := range sortedColos {
|
||||
output = append(output, fmt.Sprintf("%dx%s", numConnsPerColo[coloName], coloName))
|
||||
}
|
||||
|
@ -412,16 +473,21 @@ func buildReadyCommand() *cli.Command {
|
|||
}
|
||||
|
||||
func readyCommand(c *cli.Context) error {
|
||||
metricsOpts := c.String("metrics")
|
||||
if !c.IsSet("metrics") {
|
||||
return fmt.Errorf("--metrics has to be provided")
|
||||
metricsOpts := c.String(flags.Metrics)
|
||||
if !c.IsSet(flags.Metrics) {
|
||||
return errors.New("--metrics has to be provided")
|
||||
}
|
||||
|
||||
requestURL := fmt.Sprintf("http://%s/ready", metricsOpts)
|
||||
res, err := http.Get(requestURL)
|
||||
req, err := http.NewRequest(http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
|
@ -648,8 +714,10 @@ func buildRunCommand() *cli.Command {
|
|||
selectProtocolFlag,
|
||||
featuresFlag,
|
||||
tunnelTokenFlag,
|
||||
tunnelTokenFileFlag,
|
||||
icmpv4SrcFlag,
|
||||
icmpv6SrcFlag,
|
||||
maxActiveFlowsFlag,
|
||||
}
|
||||
flags = append(flags, configureProxyFlags(false)...)
|
||||
return &cli.Command{
|
||||
|
@ -687,12 +755,22 @@ func runCommand(c *cli.Context) error {
|
|||
"your origin will not be reachable. You should remove the `hostname` property to avoid this warning.")
|
||||
}
|
||||
|
||||
tokenStr := c.String(TunnelTokenFlag)
|
||||
// Check if tokenStr is blank before checking for tokenFile
|
||||
if tokenStr == "" {
|
||||
if tokenFile := c.String(TunnelTokenFileFlag); tokenFile != "" {
|
||||
data, err := os.ReadFile(tokenFile)
|
||||
if err != nil {
|
||||
return cliutil.UsageError("Failed to read token file: " + err.Error())
|
||||
}
|
||||
tokenStr = strings.TrimSpace(string(data))
|
||||
}
|
||||
}
|
||||
// Check if token is provided and if not use default tunnelID flag method
|
||||
if tokenStr := c.String(TunnelTokenFlag); tokenStr != "" {
|
||||
if tokenStr != "" {
|
||||
if token, err := ParseToken(tokenStr); err == nil {
|
||||
return sc.runWithCredentials(token.Credentials())
|
||||
}
|
||||
|
||||
return cliutil.UsageError("Provided Tunnel token is not valid.")
|
||||
} else {
|
||||
tunnelRef := c.Args().First()
|
||||
|
@ -897,8 +975,10 @@ func lbRouteFromArg(c *cli.Context) (cfapi.HostnameRoute, error) {
|
|||
return cfapi.NewLBRoute(lbName, lbPool), nil
|
||||
}
|
||||
|
||||
var nameRegex = regexp.MustCompile("^[_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
|
||||
var hostNameRegex = regexp.MustCompile("^[*_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
|
||||
var (
|
||||
nameRegex = regexp.MustCompile("^[_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
|
||||
hostNameRegex = regexp.MustCompile("^[*_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
|
||||
)
|
||||
|
||||
func validateName(s string, allowWildcardSubdomain bool) bool {
|
||||
if allowWildcardSubdomain {
|
||||
|
@ -986,3 +1066,78 @@ SUBCOMMAND OPTIONS:
|
|||
`
|
||||
return fmt.Sprintf(template, parentFlagsHelp)
|
||||
}
|
||||
|
||||
func buildDiagCommand() *cli.Command {
|
||||
return &cli.Command{
|
||||
Name: "diag",
|
||||
Action: cliutil.ConfiguredAction(diagCommand),
|
||||
Usage: "Creates a diagnostic report from a local cloudflared instance",
|
||||
UsageText: "cloudflared tunnel [tunnel command options] diag [subcommand options]",
|
||||
Description: "cloudflared tunnel diag will create a diagnostic report of a local cloudflared instance. The diagnostic procedure collects: logs, metrics, system information, traceroute to Cloudflare Edge, and runtime information. Since there may be multiple instances of cloudflared running the --metrics option may be provided to target a specific instance.",
|
||||
Flags: []cli.Flag{
|
||||
metricsFlag,
|
||||
diagContainerFlag,
|
||||
diagPodFlag,
|
||||
noDiagLogsFlag,
|
||||
noDiagMetricsFlag,
|
||||
noDiagSystemFlag,
|
||||
noDiagRuntimeFlag,
|
||||
noDiagNetworkFlag,
|
||||
},
|
||||
CustomHelpTemplate: commandHelpTemplate(),
|
||||
}
|
||||
}
|
||||
|
||||
func diagCommand(ctx *cli.Context) error {
|
||||
sctx, err := newSubcommandContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log := sctx.log
|
||||
options := diagnostic.Options{
|
||||
KnownAddresses: metrics.GetMetricsKnownAddresses(metrics.Runtime),
|
||||
Address: sctx.c.String(flags.Metrics),
|
||||
ContainerID: sctx.c.String(diagContainerIDFlagName),
|
||||
PodID: sctx.c.String(diagPodFlagName),
|
||||
Toggles: diagnostic.Toggles{
|
||||
NoDiagLogs: sctx.c.Bool(noDiagLogsFlagName),
|
||||
NoDiagMetrics: sctx.c.Bool(noDiagMetricsFlagName),
|
||||
NoDiagSystem: sctx.c.Bool(noDiagSystemFlagName),
|
||||
NoDiagRuntime: sctx.c.Bool(noDiagRuntimeFlagName),
|
||||
NoDiagNetwork: sctx.c.Bool(noDiagNetworkFlagName),
|
||||
},
|
||||
}
|
||||
|
||||
if options.Address == "" {
|
||||
log.Info().Msg("If your instance is running in a Docker/Kubernetes environment you need to setup port forwarding for your application.")
|
||||
}
|
||||
|
||||
states, err := diagnostic.RunDiagnostic(log, options)
|
||||
|
||||
if errors.Is(err, diagnostic.ErrMetricsServerNotFound) {
|
||||
log.Warn().Msg("No instances found")
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, diagnostic.ErrMultipleMetricsServerFound) {
|
||||
if states != nil {
|
||||
log.Info().Msgf("Found multiple instances running:")
|
||||
for _, state := range states {
|
||||
log.Info().Msgf("Instance: tunnel-id=%s connector-id=%s metrics-address=%s", state.TunnelID, state.ConnectorID, state.URL.String())
|
||||
}
|
||||
log.Info().Msgf("To select one instance use the option --metrics")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, diagnostic.ErrLogConfigurationIsInvalid) {
|
||||
log.Info().Msg("Couldn't extract logs from the instance. If the instance is running in a containerized environment use the option --diag-container-id or --diag-pod-id. If there is no logging configuration use --no-diag-logs.")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warn().Msg("Diagnostic completed with one or more errors")
|
||||
} else {
|
||||
log.Info().Msg("Diagnostic completed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ var (
|
|||
Usage: "The ID or name of the virtual network to which the route is associated to.",
|
||||
}
|
||||
|
||||
routeAddError = errors.New("You must supply exactly one argument, the ID or CIDR of the route you want to delete")
|
||||
errAddRoute = errors.New("You must supply exactly one argument, the ID or CIDR of the route you want to delete")
|
||||
)
|
||||
|
||||
func buildRouteIPSubcommand() *cli.Command {
|
||||
|
@ -32,7 +32,7 @@ func buildRouteIPSubcommand() *cli.Command {
|
|||
UsageText: "cloudflared tunnel [--config FILEPATH] route COMMAND [arguments...]",
|
||||
Description: `cloudflared can provision routes for any IP space in your corporate network. Users enrolled in
|
||||
your Cloudflare for Teams organization can reach those IPs through the Cloudflare WARP
|
||||
client. You can then configure L7/L4 filtering on https://dash.teams.cloudflare.com to
|
||||
client. You can then configure L7/L4 filtering on https://one.dash.cloudflare.com to
|
||||
determine who can reach certain routes.
|
||||
By default IP routes all exist within a single virtual network. If you use the same IP
|
||||
space(s) in different physical private networks, all meant to be reachable via IP routes,
|
||||
|
@ -187,7 +187,7 @@ func deleteRouteCommand(c *cli.Context) error {
|
|||
}
|
||||
|
||||
if c.NArg() != 1 {
|
||||
return routeAddError
|
||||
return errAddRoute
|
||||
}
|
||||
|
||||
var routeId uuid.UUID
|
||||
|
@ -195,7 +195,7 @@ func deleteRouteCommand(c *cli.Context) error {
|
|||
if err != nil {
|
||||
_, network, err := net.ParseCIDR(c.Args().First())
|
||||
if err != nil || network == nil {
|
||||
return routeAddError
|
||||
return errAddRoute
|
||||
}
|
||||
|
||||
var vnetId *uuid.UUID
|
||||
|
|
|
@ -15,13 +15,14 @@ import (
|
|||
"golang.org/x/term"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCheckUpdateFreq = time.Hour * 24
|
||||
noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/as-a-service/"
|
||||
noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/configure-tunnels/local-management/as-a-service/"
|
||||
noUpdateOnWindowsMessage = "cloudflared will not automatically update on Windows systems."
|
||||
noUpdateManagedPackageMessage = "cloudflared will not automatically update if installed by a package manager."
|
||||
isManagedInstallFile = ".installedFromPackageManager"
|
||||
|
@ -38,6 +39,7 @@ var (
|
|||
|
||||
// BinaryUpdated implements ExitCoder interface, the app will exit with status code 11
|
||||
// https://pkg.go.dev/github.com/urfave/cli/v2?tab=doc#ExitCoder
|
||||
// nolint: errname
|
||||
type statusSuccess struct {
|
||||
newVersion string
|
||||
}
|
||||
|
@ -50,16 +52,16 @@ func (u *statusSuccess) ExitCode() int {
|
|||
return 11
|
||||
}
|
||||
|
||||
// UpdateErr implements ExitCoder interface, the app will exit with status code 10
|
||||
type statusErr struct {
|
||||
// statusError implements ExitCoder interface, the app will exit with status code 10
|
||||
type statusError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *statusErr) Error() string {
|
||||
func (e *statusError) Error() string {
|
||||
return fmt.Sprintf("failed to update cloudflared: %v", e.err)
|
||||
}
|
||||
|
||||
func (e *statusErr) ExitCode() int {
|
||||
func (e *statusError) ExitCode() int {
|
||||
return 10
|
||||
}
|
||||
|
||||
|
@ -79,7 +81,7 @@ type UpdateOutcome struct {
|
|||
}
|
||||
|
||||
func (uo *UpdateOutcome) noUpdate() bool {
|
||||
return uo.Error == nil && uo.Updated == false
|
||||
return uo.Error == nil && !uo.Updated
|
||||
}
|
||||
|
||||
func Init(info *cliutil.BuildInfo) {
|
||||
|
@ -153,7 +155,7 @@ func Update(c *cli.Context) error {
|
|||
log.Info().Msg("cloudflared is set to update from staging")
|
||||
}
|
||||
|
||||
isForced := c.Bool("force")
|
||||
isForced := c.Bool(cfdflags.Force)
|
||||
if isForced {
|
||||
log.Info().Msg("cloudflared is set to upgrade to the latest publish version regardless of the current version")
|
||||
}
|
||||
|
@ -166,7 +168,7 @@ func Update(c *cli.Context) error {
|
|||
intendedVersion: c.String("version"),
|
||||
})
|
||||
if updateOutcome.Error != nil {
|
||||
return &statusErr{updateOutcome.Error}
|
||||
return &statusError{updateOutcome.Error}
|
||||
}
|
||||
|
||||
if updateOutcome.noUpdate() {
|
||||
|
@ -252,7 +254,7 @@ func (a *AutoUpdater) Run(ctx context.Context) error {
|
|||
pid, err := a.listeners.StartProcess()
|
||||
if err != nil {
|
||||
a.log.Err(err).Msg("Unable to restart server automatically")
|
||||
return &statusErr{err: err}
|
||||
return &statusError{err: err}
|
||||
}
|
||||
// stop old process after autoupdate. Otherwise we create a new process
|
||||
// after each update
|
||||
|
|
|
@ -10,9 +10,9 @@ import (
|
|||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
|
@ -134,7 +134,7 @@ func (v *WorkersVersion) Apply() error {
|
|||
|
||||
if err := os.Rename(newFilePath, v.targetPath); err != nil {
|
||||
//attempt rollback
|
||||
os.Rename(oldFilePath, v.targetPath)
|
||||
_ = os.Rename(oldFilePath, v.targetPath)
|
||||
return err
|
||||
}
|
||||
os.Remove(oldFilePath)
|
||||
|
@ -181,7 +181,7 @@ func download(url, filepath string, isCompressed bool) error {
|
|||
tr := tar.NewReader(gr)
|
||||
|
||||
// advance the reader pass the header, which will be the single binary file
|
||||
tr.Next()
|
||||
_, _ = tr.Next()
|
||||
|
||||
r = tr
|
||||
}
|
||||
|
@ -198,7 +198,7 @@ func download(url, filepath string, isCompressed bool) error {
|
|||
|
||||
// isCompressedFile is a really simple file extension check to see if this is a macos tar and gzipped
|
||||
func isCompressedFile(urlstring string) bool {
|
||||
if strings.HasSuffix(urlstring, ".tgz") {
|
||||
if path.Ext(urlstring) == ".tgz" {
|
||||
return true
|
||||
}
|
||||
|
||||
|
@ -206,7 +206,7 @@ func isCompressedFile(urlstring string) bool {
|
|||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.HasSuffix(u.Path, ".tgz")
|
||||
return path.Ext(u.Path) == ".tgz"
|
||||
}
|
||||
|
||||
// writeBatchFile writes a batch file out to disk
|
||||
|
@ -249,7 +249,6 @@ func runWindowsBatch(batchFile string) error {
|
|||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
return fmt.Errorf("Error during update : %s;", string(exitError.Stderr))
|
||||
}
|
||||
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ import (
|
|||
const (
|
||||
windowsServiceName = "Cloudflared"
|
||||
windowsServiceDescription = "Cloudflared agent"
|
||||
windowsServiceUrl = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/as-a-service/windows/"
|
||||
windowsServiceUrl = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/configure-tunnels/local-management/as-a-service/windows/"
|
||||
|
||||
recoverActionDelay = time.Second * 20
|
||||
failureCountResetPeriod = time.Hour * 24
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from util import LOGGER, nofips, start_cloudflared, wait_tunnel_ready
|
||||
from util import LOGGER, start_cloudflared, wait_tunnel_ready
|
||||
|
||||
|
||||
@nofips
|
||||
class TestPostQuantum:
|
||||
def _extra_config(self):
|
||||
config = {
|
||||
|
@ -12,6 +11,11 @@ class TestPostQuantum:
|
|||
def test_post_quantum(self, tmp_path, component_tests_config):
|
||||
config = component_tests_config(self._extra_config())
|
||||
LOGGER.debug(config)
|
||||
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], cfd_args=["run", "--post-quantum"], new_process=True):
|
||||
wait_tunnel_ready(tunnel_url=config.get_url(),
|
||||
require_min_connections=1)
|
||||
with start_cloudflared(
|
||||
tmp_path,
|
||||
config,
|
||||
cfd_pre_args=["tunnel", "--ha-connections", "1"],
|
||||
cfd_args=["run", "--post-quantum"],
|
||||
new_process=True,
|
||||
):
|
||||
wait_tunnel_ready(tunnel_url=config.get_url(), require_min_connections=1)
|
||||
|
|
|
@ -155,7 +155,7 @@ func FindOrCreateConfigPath() string {
|
|||
// i.e. it fails if a user specifies both --url and --unix-socket
|
||||
func ValidateUnixSocket(c *cli.Context) (string, error) {
|
||||
if c.IsSet("unix-socket") && (c.IsSet("url") || c.NArg() > 0) {
|
||||
return "", errors.New("--unix-socket must be used exclusivly.")
|
||||
return "", errors.New("--unix-socket must be used exclusively.")
|
||||
}
|
||||
return c.String("unix-socket"), nil
|
||||
}
|
||||
|
@ -260,6 +260,7 @@ type Configuration struct {
|
|||
|
||||
type WarpRoutingConfig struct {
|
||||
ConnectTimeout *CustomDuration `yaml:"connectTimeout" json:"connectTimeout,omitempty"`
|
||||
MaxActiveFlows *uint64 `yaml:"maxActiveFlows" json:"maxActiveFlows,omitempty"`
|
||||
TCPKeepAlive *CustomDuration `yaml:"tcpKeepAlive" json:"tcpKeepAlive,omitempty"`
|
||||
}
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ type Credentials struct {
|
|||
AccountTag string
|
||||
TunnelSecret []byte
|
||||
TunnelID uuid.UUID
|
||||
Endpoint string
|
||||
}
|
||||
|
||||
func (c *Credentials) Auth() pogs.TunnelAuth {
|
||||
|
@ -74,13 +75,16 @@ type TunnelToken struct {
|
|||
AccountTag string `json:"a"`
|
||||
TunnelSecret []byte `json:"s"`
|
||||
TunnelID uuid.UUID `json:"t"`
|
||||
Endpoint string `json:"e,omitempty"`
|
||||
}
|
||||
|
||||
func (t TunnelToken) Credentials() Credentials {
|
||||
// nolint: gosimple
|
||||
return Credentials{
|
||||
AccountTag: t.AccountTag,
|
||||
TunnelSecret: t.TunnelSecret,
|
||||
TunnelID: t.TunnelID,
|
||||
Endpoint: t.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,14 +2,18 @@ package connection
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||
|
||||
"github.com/cloudflare/cloudflared/stream"
|
||||
"github.com/cloudflare/cloudflared/tracing"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
|
@ -77,7 +81,7 @@ func (moc *mockOriginProxy) ProxyHTTP(
|
|||
return wsFlakyEndpoint(w, req)
|
||||
default:
|
||||
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
|
||||
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
|
||||
return fmt.Errorf("unknown websocket endpoint %s", req.URL.Path)
|
||||
}
|
||||
}
|
||||
switch req.URL.Path {
|
||||
|
@ -95,7 +99,6 @@ func (moc *mockOriginProxy) ProxyHTTP(
|
|||
originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (moc *mockOriginProxy) ProxyTCP(
|
||||
|
@ -103,6 +106,10 @@ func (moc *mockOriginProxy) ProxyTCP(
|
|||
rwa ReadWriteAcker,
|
||||
r *TCPRequest,
|
||||
) error {
|
||||
if r.CfTraceID == "flow-rate-limited" {
|
||||
return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "tcp flow rate limited")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -178,7 +185,8 @@ func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
|
|||
|
||||
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
|
||||
|
||||
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
|
||||
rInt, _ := rand.Int(rand.Reader, big.NewInt(50))
|
||||
closedAfter := time.Millisecond * time.Duration(rInt.Int64())
|
||||
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
|
||||
stream.Pipe(wsConn, originConn, &log)
|
||||
cancel()
|
||||
|
|
|
@ -22,8 +22,9 @@ var (
|
|||
|
||||
var (
|
||||
// pre-generate possible values for res
|
||||
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
|
||||
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
|
||||
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared", false)
|
||||
responseMetaHeaderCfdFlowRateLimited = mustInitRespMetaHeader("cloudflared", true)
|
||||
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin", false)
|
||||
)
|
||||
|
||||
// HTTPHeader is a custom header struct that expects only ever one value for the header.
|
||||
|
@ -35,10 +36,11 @@ type HTTPHeader struct {
|
|||
|
||||
type responseMetaHeader struct {
|
||||
Source string `json:"src"`
|
||||
FlowRateLimited bool `json:"flow_rate_limited,omitempty"`
|
||||
}
|
||||
|
||||
func mustInitRespMetaHeader(src string) string {
|
||||
header, err := json.Marshal(responseMetaHeader{Source: src})
|
||||
func mustInitRespMetaHeader(src string, flowRateLimited bool) string {
|
||||
header, err := json.Marshal(responseMetaHeader{Source: src, FlowRateLimited: flowRateLimited})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
|
||||
}
|
||||
|
@ -112,7 +114,7 @@ func SerializeHeaders(h1Headers http.Header) string {
|
|||
func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) {
|
||||
const unableToDeserializeErr = "Unable to deserialize headers"
|
||||
|
||||
var deserialized []HTTPHeader
|
||||
deserialized := make([]HTTPHeader, 0)
|
||||
for _, serializedPair := range strings.Split(serializedHeaders, ";") {
|
||||
if len(serializedPair) == 0 {
|
||||
continue
|
||||
|
|
|
@ -16,6 +16,8 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tracing"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
@ -156,7 +158,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
c.log.Error().Err(requestErr).Msg("failed to serve incoming request")
|
||||
|
||||
// WriteErrorResponse will return false if status was already written. we need to abort handler.
|
||||
if !respWriter.WriteErrorResponse() {
|
||||
if !respWriter.WriteErrorResponse(requestErr) {
|
||||
c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent")
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
|
@ -209,8 +211,9 @@ func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, l
|
|||
w: w,
|
||||
log: log,
|
||||
}
|
||||
respWriter.WriteErrorResponse()
|
||||
return nil, fmt.Errorf("%T doesn't implement http.Flusher", w)
|
||||
err := fmt.Errorf("%T doesn't implement http.Flusher", w)
|
||||
respWriter.WriteErrorResponse(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &http2RespWriter{
|
||||
|
@ -295,7 +298,7 @@ func (rp *http2RespWriter) WriteHeader(status int) {
|
|||
rp.log.Warn().Msg("WriteHeader after hijack")
|
||||
return
|
||||
}
|
||||
rp.WriteRespHeaders(status, rp.respHeaders)
|
||||
_ = rp.WriteRespHeaders(status, rp.respHeaders)
|
||||
}
|
||||
|
||||
func (rp *http2RespWriter) hijacked() bool {
|
||||
|
@ -328,12 +331,16 @@ func (rp *http2RespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|||
return conn, readWriter, nil
|
||||
}
|
||||
|
||||
func (rp *http2RespWriter) WriteErrorResponse() bool {
|
||||
func (rp *http2RespWriter) WriteErrorResponse(err error) bool {
|
||||
if rp.statusWritten {
|
||||
return false
|
||||
}
|
||||
|
||||
if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
|
||||
rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited)
|
||||
} else {
|
||||
rp.setResponseMetaHeader(responseMetaHeaderCfd)
|
||||
}
|
||||
rp.w.WriteHeader(http.StatusBadGateway)
|
||||
rp.statusWritten = true
|
||||
|
||||
|
|
|
@ -20,6 +20,8 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tracing"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
@ -65,19 +67,18 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
|
|||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.Serve(ctx)
|
||||
_ = http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||
require.NoError(t, err)
|
||||
|
||||
endpoint := fmt.Sprintf("http://localhost:8080/ok")
|
||||
reqBody := []byte(`{
|
||||
"version": 2,
|
||||
"config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
|
||||
`)
|
||||
reader := bytes.NewReader(reqBody)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, reader)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://localhost:8080/ok", reader)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate)
|
||||
|
||||
|
@ -85,11 +86,11 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
bdy, err := io.ReadAll(resp.Body)
|
||||
defer resp.Body.Close()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy))
|
||||
cancel()
|
||||
wg.Wait()
|
||||
|
||||
}
|
||||
|
||||
func TestServeHTTP(t *testing.T) {
|
||||
|
@ -134,7 +135,7 @@ func TestServeHTTP(t *testing.T) {
|
|||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.Serve(ctx)
|
||||
_ = http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||
|
@ -153,6 +154,7 @@ func TestServeHTTP(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, test.expectedBody, respBody)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
if test.isProxyError {
|
||||
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
|
||||
} else {
|
||||
|
@ -281,10 +283,11 @@ func TestServeWS(t *testing.T) {
|
|||
|
||||
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
||||
require.Equal(t, data, respBody, "expect %s, got %s", string(data), string(respBody))
|
||||
|
||||
cancel()
|
||||
resp := respWriter.Result()
|
||||
defer resp.Body.Close()
|
||||
// http2RespWriter should rewrite status 101 to 200
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
||||
|
@ -304,7 +307,7 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
|
|||
serverDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
cfdHTTP2Conn.Serve(ctx)
|
||||
_ = cfdHTTP2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
edgeTransport := http2.Transport{}
|
||||
|
@ -319,13 +322,16 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
|
|||
readPipe, writePipe := io.Pipe()
|
||||
reqCtx, reqCancel := context.WithCancel(ctx)
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
||||
|
||||
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// http2RespWriter should rewrite status 101 to 200
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
|
@ -378,7 +384,7 @@ func TestServeControlStream(t *testing.T) {
|
|||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.Serve(ctx)
|
||||
_ = http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||
|
@ -391,7 +397,8 @@ func TestServeControlStream(t *testing.T) {
|
|||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
edgeHTTP2Conn.RoundTrip(req)
|
||||
// nolint: bodyclose
|
||||
_, _ = edgeHTTP2Conn.RoundTrip(req)
|
||||
}()
|
||||
|
||||
<-rpcClientFactory.registered
|
||||
|
@ -431,7 +438,7 @@ func TestFailRegistration(t *testing.T) {
|
|||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.Serve(ctx)
|
||||
_ = http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||
|
@ -442,9 +449,10 @@ func TestFailRegistration(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
|
||||
|
||||
assert.NotNil(t, http2Conn.controlStreamErr)
|
||||
require.Error(t, http2Conn.controlStreamErr)
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
@ -481,7 +489,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
|
|||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.Serve(ctx)
|
||||
_ = http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||
|
@ -494,6 +502,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
|
|||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// nolint: bodyclose
|
||||
_, _ = edgeHTTP2Conn.RoundTrip(req)
|
||||
}()
|
||||
|
||||
|
@ -524,6 +533,36 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestServeTCP_RateLimited(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
http2Conn, edgeConn := newTestHTTP2Connection()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(InternalTCPProxySrcHeader, "tcp")
|
||||
req.Header.Set(tracing.TracerContextName, "flow-rate-limited")
|
||||
|
||||
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
|
||||
require.Equal(t, responseMetaHeaderCfdFlowRateLimited, resp.Header.Get(ResponseMetaHeader))
|
||||
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func benchmarkServeHTTP(b *testing.B, test testRequest) {
|
||||
http2Conn, edgeConn := newTestHTTP2Connection()
|
||||
|
||||
|
@ -532,7 +571,7 @@ func benchmarkServeHTTP(b *testing.B, test testRequest) {
|
|||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.Serve(ctx)
|
||||
_ = http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
const (
|
||||
AvailableProtocolFlagMessage = "Available protocols: 'auto' - automatically chooses the best protocol over time (the default; and also the recommended one); 'quic' - based on QUIC, relying on UDP egress to Cloudflare edge; 'http2' - using Go's HTTP2 library, relying on TCP egress to Cloudflare edge"
|
||||
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge (unused, but kept for legacy reference).
|
||||
edgeH2muxTLSServerName = "cftunnel.com"
|
||||
_ = "cftunnel.com"
|
||||
// edgeH2TLSServerName is the server name to establish http2 connection with edge
|
||||
edgeH2TLSServerName = "h2.cftunnel.com"
|
||||
// edgeQUICServerName is the server name to establish quic connection with edge.
|
||||
|
@ -24,11 +24,9 @@ const (
|
|||
ResolveTTL = time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
// ProtocolList represents a list of supported protocols for communication with the edge
|
||||
// in order of precedence for remote percentage fetcher.
|
||||
ProtocolList = []Protocol{QUIC, HTTP2}
|
||||
)
|
||||
// ProtocolList represents a list of supported protocols for communication with the edge
|
||||
// in order of precedence for remote percentage fetcher.
|
||||
var ProtocolList = []Protocol{QUIC, HTTP2}
|
||||
|
||||
type Protocol int64
|
||||
|
||||
|
@ -58,7 +56,7 @@ func (p Protocol) String() string {
|
|||
case QUIC:
|
||||
return "quic"
|
||||
default:
|
||||
return fmt.Sprintf("unknown protocol")
|
||||
return "unknown protocol"
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -246,11 +244,11 @@ func NewProtocolSelector(
|
|||
return newRemoteProtocolSelector(fetchedProtocol, ProtocolList, threshold, protocolFetcher, resolveTTL, log), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
|
||||
return nil, fmt.Errorf("unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
|
||||
}
|
||||
|
||||
func switchThreshold(accountTag string) int32 {
|
||||
h := fnv.New32a()
|
||||
_, _ = h.Write([]byte(accountTag))
|
||||
return int32(h.Sum32() % 100)
|
||||
return int32(h.Sum32() % 100) // nolint: gosec
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||
|
||||
cfdquic "github.com/cloudflare/cloudflared/quic"
|
||||
"github.com/cloudflare/cloudflared/tracing"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
|
@ -101,14 +103,19 @@ func (q *quicConnection) Serve(ctx context.Context) error {
|
|||
// amount of the grace period, allowing requests to finish before we cancel the context, which will
|
||||
// make cloudflared exit.
|
||||
if err := q.serveControlStream(ctx, controlStream); err == nil {
|
||||
if q.gracePeriod > 0 {
|
||||
// In Go1.23 this can be removed and replaced with time.Ticker
|
||||
// see https://pkg.go.dev/time#Tick
|
||||
ticker := time.NewTicker(q.gracePeriod)
|
||||
defer ticker.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.Tick(q.gracePeriod):
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
return err
|
||||
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
defer cancel()
|
||||
|
@ -129,7 +136,7 @@ func (q *quicConnection) serveControlStream(ctx context.Context, controlStream q
|
|||
|
||||
// Close the connection with no errors specified.
|
||||
func (q *quicConnection) Close() {
|
||||
q.conn.CloseWithError(0, "")
|
||||
_ = q.conn.CloseWithError(0, "")
|
||||
}
|
||||
|
||||
func (q *quicConnection) acceptStream(ctx context.Context) error {
|
||||
|
@ -182,7 +189,13 @@ func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.R
|
|||
return err
|
||||
}
|
||||
|
||||
if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil {
|
||||
var metadata []pogs.Metadata
|
||||
// Check the type of error that was throw and add metadata that will help identify it on OTD.
|
||||
if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
|
||||
metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedMetadata)
|
||||
}
|
||||
|
||||
if writeRespErr := stream.WriteConnectResponseData(err, metadata...); writeRespErr != nil {
|
||||
return writeRespErr
|
||||
}
|
||||
}
|
||||
|
@ -278,7 +291,7 @@ func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header)
|
|||
func (hrw *httpResponseAdapter) Write(p []byte) (int, error) {
|
||||
// Make sure to send WriteHeader response if not called yet
|
||||
if !hrw.connectResponseSent {
|
||||
hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
|
||||
_ = hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
|
||||
}
|
||||
return hrw.RequestServerStream.Write(p)
|
||||
}
|
||||
|
@ -291,7 +304,7 @@ func (hrw *httpResponseAdapter) Header() http.Header {
|
|||
func (hrw *httpResponseAdapter) Flush() {}
|
||||
|
||||
func (hrw *httpResponseAdapter) WriteHeader(status int) {
|
||||
hrw.WriteRespHeaders(status, hrw.headers)
|
||||
_ = hrw.WriteRespHeaders(status, hrw.headers)
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
|
@ -304,7 +317,7 @@ func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
|
||||
hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
|
||||
_ = hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
|
||||
}
|
||||
|
||||
func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error {
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
|
@ -21,13 +22,15 @@ import (
|
|||
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/nettest"
|
||||
|
||||
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||
|
||||
"github.com/cloudflare/cloudflared/datagramsession"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/packet"
|
||||
|
@ -53,7 +56,8 @@ var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
|
|||
func TestQUICServer(t *testing.T) {
|
||||
// This is simply a sample websocket frame message.
|
||||
wsBuf := &bytes.Buffer{}
|
||||
wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
|
||||
err := wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
var tests = []struct {
|
||||
desc string
|
||||
|
@ -158,17 +162,19 @@ func TestQUICServer(t *testing.T) {
|
|||
|
||||
serverDone := make(chan struct{})
|
||||
go func() {
|
||||
// nolint: testifylint
|
||||
quicServer(
|
||||
ctx, t, quicListener, test.dest, test.connectionType, test.metadata, test.message, test.expectedResponse,
|
||||
)
|
||||
close(serverDone)
|
||||
}()
|
||||
|
||||
// nolint: gosec
|
||||
tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(i))
|
||||
|
||||
connDone := make(chan struct{})
|
||||
go func() {
|
||||
tunnelConn.Serve(ctx)
|
||||
_ = tunnelConn.Serve(ctx)
|
||||
close(connDone)
|
||||
}()
|
||||
|
||||
|
@ -254,14 +260,14 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, tr *tracing.T
|
|||
case "/ok":
|
||||
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
|
||||
case "/slow_echo_body":
|
||||
time.Sleep(5)
|
||||
time.Sleep(5 * time.Nanosecond)
|
||||
fallthrough
|
||||
case "/echo_body":
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
}
|
||||
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
|
||||
io.Copy(w, r.Body)
|
||||
_, _ = io.Copy(w, r.Body)
|
||||
case "/error":
|
||||
return fmt.Errorf("Failed to proxy to origin")
|
||||
default:
|
||||
|
@ -493,16 +499,20 @@ func TestBuildHTTPRequest(t *testing.T) {
|
|||
test := test // capture range variable
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
req, err := buildHTTPRequest(context.Background(), test.connectRequest, test.body, 0, &log)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
test.req = test.req.WithContext(req.Context())
|
||||
assert.Equal(t, test.req, req.Request)
|
||||
require.Equal(t, test.req, req.Request)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
|
||||
rwa.AckConnection("")
|
||||
io.Copy(rwa, rwa)
|
||||
if tcpRequest.Dest == "rate-limit-me" {
|
||||
return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "failed tcp stream")
|
||||
}
|
||||
|
||||
_ = rwa.AckConnection("")
|
||||
_, _ = io.Copy(rwa, rwa)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -520,16 +530,19 @@ func TestServeUDPSession(t *testing.T) {
|
|||
edgeQUICSessionChan := make(chan quic.Connection)
|
||||
go func() {
|
||||
earlyListener, err := quic.Listen(udpListener, testTLSServerConfig, testQUICConfig)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
edgeQUICSession, err := earlyListener.Accept(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
edgeQUICSessionChan <- edgeQUICSession
|
||||
}()
|
||||
|
||||
// Random index to avoid reusing port
|
||||
tunnelConn, datagramConn := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), 28)
|
||||
go tunnelConn.Serve(ctx)
|
||||
go func() {
|
||||
_ = tunnelConn.Serve(ctx)
|
||||
}()
|
||||
|
||||
edgeQUICSession := <-edgeQUICSessionChan
|
||||
|
||||
|
@ -545,14 +558,14 @@ func TestNopCloserReadWriterCloseBeforeEOF(t *testing.T) {
|
|||
|
||||
n, err := readerWriter.Read(buffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, 5)
|
||||
require.Equal(t, 5, n)
|
||||
|
||||
// close
|
||||
require.NoError(t, readerWriter.Close())
|
||||
|
||||
// read should get error
|
||||
n, err = readerWriter.Read(buffer)
|
||||
require.Equal(t, n, 0)
|
||||
require.Equal(t, 0, n)
|
||||
require.Equal(t, err, fmt.Errorf("closed by handler"))
|
||||
}
|
||||
|
||||
|
@ -562,7 +575,7 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
|
|||
|
||||
n, err := readerWriter.Read(buffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n, 9)
|
||||
require.Equal(t, 9, n)
|
||||
|
||||
// force another read to read eof
|
||||
_, err = readerWriter.Read(buffer)
|
||||
|
@ -573,7 +586,7 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
|
|||
|
||||
// read should get EOF still
|
||||
n, err = readerWriter.Read(buffer)
|
||||
require.Equal(t, n, 0)
|
||||
require.Equal(t, 0, n)
|
||||
require.Equal(t, err, io.EOF)
|
||||
}
|
||||
|
||||
|
@ -589,6 +602,59 @@ func TestCreateUDPConnReuseSourcePort(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestTCPProxy_FlowRateLimited tests if the pogs.ConnectResponse returns the expected error and metadata, when a
|
||||
// new flow is rate limited.
|
||||
func TestTCPProxy_FlowRateLimited(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Start a UDP Listener for QUIC.
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr)
|
||||
require.NoError(t, err)
|
||||
defer udpListener.Close()
|
||||
|
||||
quicTransport := &quic.Transport{Conn: udpListener, ConnectionIDLength: 16}
|
||||
quicListener, err := quicTransport.Listen(testTLSServerConfig, testQUICConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
|
||||
session, err := quicListener.Accept(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
quicStream, err := session.OpenStreamSync(context.Background())
|
||||
assert.NoError(t, err)
|
||||
stream := cfdquic.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log)
|
||||
|
||||
reqClientStream := rpcquic.RequestClientStream{ReadWriteCloser: stream}
|
||||
err = reqClientStream.WriteConnectRequestData("rate-limit-me", pogs.ConnectionTypeTCP)
|
||||
assert.NoError(t, err)
|
||||
|
||||
response, err := reqClientStream.ReadConnectResponseData()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Got Rate Limited
|
||||
assert.NotEmpty(t, response.Error)
|
||||
assert.Contains(t, response.Metadata, pogs.ErrorFlowConnectRateLimitedMetadata)
|
||||
}()
|
||||
|
||||
tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(0))
|
||||
|
||||
connDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(connDone)
|
||||
_ = tunnelConn.Serve(ctx)
|
||||
}()
|
||||
|
||||
<-serverDone
|
||||
cancel()
|
||||
<-connDone
|
||||
}
|
||||
|
||||
func testCreateUDPConnReuseSourcePortForEdgeIP(t *testing.T, edgeIP netip.AddrPort) {
|
||||
logger := zerolog.Nop()
|
||||
conn, err := createUDPConnForConnIndex(0, nil, edgeIP, &logger)
|
||||
|
@ -669,6 +735,7 @@ func serveSession(ctx context.Context, datagramConn *datagramV2Connection, edgeQ
|
|||
unregisterReason: expectedReason,
|
||||
calledUnregisterChan: unregisterFromEdgeChan,
|
||||
}
|
||||
// nolint: testifylint
|
||||
go runRPCServer(ctx, edgeQUICSession, sessionRPCServer, nil, t)
|
||||
|
||||
<-unregisterFromEdgeChan
|
||||
|
@ -729,6 +796,7 @@ func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionI
|
|||
|
||||
func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) (TunnelConnection, *datagramV2Connection) {
|
||||
tlsClientConfig := &tls.Config{
|
||||
// nolint: gosec
|
||||
InsecureSkipVerify: true,
|
||||
NextProtos: []string{"argotunnel"},
|
||||
}
|
||||
|
@ -747,6 +815,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
|
|||
index,
|
||||
&log,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start a session manager for the connection
|
||||
sessionDemuxChan := make(chan *packet.Session, 4)
|
||||
|
@ -757,7 +826,9 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
|
|||
|
||||
datagramConn := &datagramV2Connection{
|
||||
conn,
|
||||
index,
|
||||
sessionManager,
|
||||
cfdflow.NewLimiter(0),
|
||||
datagramMuxer,
|
||||
packetRouter,
|
||||
15 * time.Second,
|
||||
|
@ -796,6 +867,7 @@ func (m *mockReaderNoopWriter) Close() error {
|
|||
|
||||
// GenerateTLSConfig sets up a bare-bones TLS config for a QUIC server
|
||||
func GenerateTLSConfig() *tls.Config {
|
||||
// nolint: gosec
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -812,6 +884,7 @@ func GenerateTLSConfig() *tls.Config {
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// nolint: gosec
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
NextProtos: []string{"argotunnel"},
|
||||
|
|
|
@ -7,12 +7,15 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/rs/zerolog"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||
|
||||
"github.com/cloudflare/cloudflared/datagramsession"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
|
@ -39,9 +42,13 @@ type DatagramSessionHandler interface {
|
|||
|
||||
type datagramV2Connection struct {
|
||||
conn quic.Connection
|
||||
index uint8
|
||||
|
||||
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
|
||||
sessionManager datagramsession.Manager
|
||||
// flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
|
||||
flowLimiter cfdflow.Limiter
|
||||
|
||||
// datagramMuxer mux/demux datagrams from quic connection
|
||||
datagramMuxer *cfdquic.DatagramMuxerV2
|
||||
packetRouter *ingress.PacketRouter
|
||||
|
@ -58,6 +65,7 @@ func NewDatagramV2Connection(ctx context.Context,
|
|||
index uint8,
|
||||
rpcTimeout time.Duration,
|
||||
streamWriteTimeout time.Duration,
|
||||
flowLimiter cfdflow.Limiter,
|
||||
logger *zerolog.Logger,
|
||||
) DatagramSessionHandler {
|
||||
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
|
||||
|
@ -66,13 +74,15 @@ func NewDatagramV2Connection(ctx context.Context,
|
|||
packetRouter := ingress.NewPacketRouter(icmpRouter, datagramMuxer, index, logger)
|
||||
|
||||
return &datagramV2Connection{
|
||||
conn,
|
||||
sessionManager,
|
||||
datagramMuxer,
|
||||
packetRouter,
|
||||
rpcTimeout,
|
||||
streamWriteTimeout,
|
||||
logger,
|
||||
conn: conn,
|
||||
index: index,
|
||||
sessionManager: sessionManager,
|
||||
flowLimiter: flowLimiter,
|
||||
datagramMuxer: datagramMuxer,
|
||||
packetRouter: packetRouter,
|
||||
rpcTimeout: rpcTimeout,
|
||||
streamWriteTimeout: streamWriteTimeout,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -109,12 +119,23 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
|
|||
attribute.String("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)),
|
||||
))
|
||||
log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger()
|
||||
|
||||
// Try to start a new session
|
||||
if err := q.flowLimiter.Acquire(management.UDP.String()); err != nil {
|
||||
log.Warn().Msgf("Too many concurrent sessions being handled, rejecting udp proxy to %s:%d", dstIP, dstPort)
|
||||
|
||||
err := pkgerrors.Wrap(err, "failed to start udp session due to rate limiting")
|
||||
tracing.EndWithErrorStatus(registerSpan, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Each session is a series of datagram from an eyeball to a dstIP:dstPort.
|
||||
// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
|
||||
originProxy, err := ingress.DialUDP(dstIP, dstPort)
|
||||
if err != nil {
|
||||
log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
|
||||
tracing.EndWithErrorStatus(registerSpan, err)
|
||||
q.flowLimiter.Release()
|
||||
return nil, err
|
||||
}
|
||||
registerSpan.SetAttributes(
|
||||
|
@ -127,10 +148,14 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
|
|||
originProxy.Close()
|
||||
log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
|
||||
tracing.EndWithErrorStatus(registerSpan, err)
|
||||
q.flowLimiter.Release()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go q.serveUDPSession(session, closeAfterIdleHint)
|
||||
go func() {
|
||||
defer q.flowLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method.
|
||||
q.serveUDPSession(session, closeAfterIdleHint)
|
||||
}()
|
||||
|
||||
log.Debug().
|
||||
Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).
|
||||
|
@ -170,7 +195,7 @@ func (q *datagramV2Connection) serveUDPSession(session *datagramsession.Session,
|
|||
|
||||
// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge
|
||||
func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) {
|
||||
q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
|
||||
_ = q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
|
||||
quicStream, err := q.conn.OpenStream()
|
||||
if err != nil {
|
||||
// Log this at debug because this is not an error if session was closed due to lost connection
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||
"github.com/cloudflare/cloudflared/mocks"
|
||||
)
|
||||
|
||||
type mockQuicConnection struct {
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) AcceptStream(_ context.Context) (quic.Stream, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) AcceptUniStream(_ context.Context) (quic.ReceiveStream, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) OpenStream() (quic.Stream, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) OpenStreamSync(_ context.Context) (quic.Stream, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) OpenUniStream() (quic.SendStream, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) OpenUniStreamSync(_ context.Context) (quic.SendStream, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) RemoteAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) CloseWithError(_ quic.ApplicationErrorCode, s string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) Context() context.Context {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) ConnectionState() quic.ConnectionState {
|
||||
panic("not meant to be called")
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) SendDatagram(_ []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockQuicConnection) ReceiveDatagram(_ context.Context) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
conn := &mockQuicConnection{}
|
||||
ctrl := gomock.NewController(t)
|
||||
flowLimiterMock := mocks.NewMockLimiter(ctrl)
|
||||
|
||||
datagramConn := NewDatagramV2Connection(
|
||||
context.Background(),
|
||||
conn,
|
||||
nil,
|
||||
0,
|
||||
0*time.Second,
|
||||
0*time.Second,
|
||||
flowLimiterMock,
|
||||
&log,
|
||||
)
|
||||
|
||||
flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
|
||||
flowLimiterMock.EXPECT().Release().Times(0)
|
||||
|
||||
_, err := datagramConn.RegisterUdpSession(context.Background(), uuid.New(), net.IPv4(0, 0, 0, 0), 1000, 1*time.Second, "")
|
||||
require.ErrorIs(t, err, cfdflow.ErrTooManyActiveFlows)
|
||||
}
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
const (
|
||||
logFieldOriginCertPath = "originCertPath"
|
||||
FedEndpoint = "fed"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
|
@ -32,6 +33,10 @@ func (c User) CertPath() string {
|
|||
return c.certPath
|
||||
}
|
||||
|
||||
func (c User) IsFEDEndpoint() bool {
|
||||
return c.cert.Endpoint == FedEndpoint
|
||||
}
|
||||
|
||||
// Client uses the user credentials to create a Cloudflare API client
|
||||
func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfapi.Client, error) {
|
||||
if apiURL == "" {
|
||||
|
@ -45,7 +50,6 @@ func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfa
|
|||
userAgent,
|
||||
log,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ package credentials
|
|||
import (
|
||||
"io/fs"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -13,8 +13,8 @@ func TestCredentialsRead(t *testing.T) {
|
|||
file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem")
|
||||
require.NoError(t, err)
|
||||
dir := t.TempDir()
|
||||
certPath := path.Join(dir, originCertFile)
|
||||
os.WriteFile(certPath, file, fs.ModePerm)
|
||||
certPath := filepath.Join(dir, originCertFile)
|
||||
_ = os.WriteFile(certPath, file, fs.ModePerm)
|
||||
user, err := Read(certPath, &nopLog)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, certPath, user.CertPath())
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
package credentials
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/rs/zerolog"
|
||||
|
@ -15,19 +17,30 @@ import (
|
|||
|
||||
const (
|
||||
DefaultCredentialFile = "cert.pem"
|
||||
OriginCertFlag = "origincert"
|
||||
)
|
||||
|
||||
type namedTunnelToken struct {
|
||||
type OriginCert struct {
|
||||
ZoneID string `json:"zoneID"`
|
||||
AccountID string `json:"accountID"`
|
||||
APIToken string `json:"apiToken"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
}
|
||||
|
||||
type OriginCert struct {
|
||||
ZoneID string
|
||||
APIToken string
|
||||
AccountID string
|
||||
func (oc *OriginCert) UnmarshalJSON(data []byte) error {
|
||||
var aux struct {
|
||||
ZoneID string `json:"zoneID"`
|
||||
AccountID string `json:"accountID"`
|
||||
APIToken string `json:"apiToken"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &aux); err != nil {
|
||||
return fmt.Errorf("error parsing OriginCert: %v", err)
|
||||
}
|
||||
oc.ZoneID = aux.ZoneID
|
||||
oc.AccountID = aux.AccountID
|
||||
oc.APIToken = aux.APIToken
|
||||
oc.Endpoint = strings.ToLower(aux.Endpoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindDefaultOriginCertPath returns the first path that contains a cert.pem file. If none of the
|
||||
|
@ -42,40 +55,56 @@ func FindDefaultOriginCertPath() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func DecodeOriginCert(blocks []byte) (*OriginCert, error) {
|
||||
return decodeOriginCert(blocks)
|
||||
}
|
||||
|
||||
func (cert *OriginCert) EncodeOriginCert() ([]byte, error) {
|
||||
if cert == nil {
|
||||
return nil, fmt.Errorf("originCert cannot be nil")
|
||||
}
|
||||
buffer, err := json.Marshal(cert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("originCert marshal failed: %v", err)
|
||||
}
|
||||
block := pem.Block{
|
||||
Type: "ARGO TUNNEL TOKEN",
|
||||
Headers: map[string]string{},
|
||||
Bytes: buffer,
|
||||
}
|
||||
var out bytes.Buffer
|
||||
err = pem.Encode(&out, &block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pem encoding failed: %v", err)
|
||||
}
|
||||
return out.Bytes(), nil
|
||||
}
|
||||
|
||||
func decodeOriginCert(blocks []byte) (*OriginCert, error) {
|
||||
if len(blocks) == 0 {
|
||||
return nil, fmt.Errorf("Cannot decode empty certificate")
|
||||
return nil, fmt.Errorf("cannot decode empty certificate")
|
||||
}
|
||||
originCert := OriginCert{}
|
||||
block, rest := pem.Decode(blocks)
|
||||
for {
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
for block != nil {
|
||||
switch block.Type {
|
||||
case "PRIVATE KEY", "CERTIFICATE":
|
||||
// this is for legacy purposes.
|
||||
break
|
||||
case "ARGO TUNNEL TOKEN":
|
||||
if originCert.ZoneID != "" || originCert.APIToken != "" {
|
||||
return nil, fmt.Errorf("Found multiple tokens in the certificate")
|
||||
return nil, fmt.Errorf("found multiple tokens in the certificate")
|
||||
}
|
||||
// The token is a string,
|
||||
// Try the newer JSON format
|
||||
ntt := namedTunnelToken{}
|
||||
if err := json.Unmarshal(block.Bytes, &ntt); err == nil {
|
||||
originCert.ZoneID = ntt.ZoneID
|
||||
originCert.APIToken = ntt.APIToken
|
||||
originCert.AccountID = ntt.AccountID
|
||||
}
|
||||
_ = json.Unmarshal(block.Bytes, &originCert)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type)
|
||||
return nil, fmt.Errorf("unknown block %s in the certificate", block.Type)
|
||||
}
|
||||
block, rest = pem.Decode(rest)
|
||||
}
|
||||
|
||||
if originCert.ZoneID == "" || originCert.APIToken == "" {
|
||||
return nil, fmt.Errorf("Missing token in the certificate")
|
||||
return nil, fmt.Errorf("missing token in the certificate")
|
||||
}
|
||||
|
||||
return &originCert, nil
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
@ -16,27 +16,25 @@ const (
|
|||
originCertFile = "cert.pem"
|
||||
)
|
||||
|
||||
var (
|
||||
nopLog = zerolog.Nop().With().Logger()
|
||||
)
|
||||
var nopLog = zerolog.Nop().With().Logger()
|
||||
|
||||
func TestLoadOriginCert(t *testing.T) {
|
||||
cert, err := decodeOriginCert([]byte{})
|
||||
assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err)
|
||||
assert.Equal(t, fmt.Errorf("cannot decode empty certificate"), err)
|
||||
assert.Nil(t, cert)
|
||||
|
||||
blocks, err := os.ReadFile("test-cert-unknown-block.pem")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
cert, err = decodeOriginCert(blocks)
|
||||
assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err)
|
||||
assert.Equal(t, fmt.Errorf("unknown block RSA PRIVATE KEY in the certificate"), err)
|
||||
assert.Nil(t, cert)
|
||||
}
|
||||
|
||||
func TestJSONArgoTunnelTokenEmpty(t *testing.T) {
|
||||
blocks, err := os.ReadFile("test-cert-no-token.pem")
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
cert, err := decodeOriginCert(blocks)
|
||||
assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err)
|
||||
assert.Equal(t, fmt.Errorf("missing token in the certificate"), err)
|
||||
assert.Nil(t, cert)
|
||||
}
|
||||
|
||||
|
@ -52,51 +50,21 @@ func TestJSONArgoTunnelToken(t *testing.T) {
|
|||
|
||||
func CloudflareTunnelTokenTest(t *testing.T, path string) {
|
||||
blocks, err := os.ReadFile(path)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
cert, err := decodeOriginCert(blocks)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, cert)
|
||||
assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID)
|
||||
key := "test-service-key"
|
||||
assert.Equal(t, key, cert.APIToken)
|
||||
}
|
||||
|
||||
type mockFile struct {
|
||||
path string
|
||||
data []byte
|
||||
err error
|
||||
}
|
||||
|
||||
type mockFileSystem struct {
|
||||
files map[string]mockFile
|
||||
}
|
||||
|
||||
func newMockFileSystem(files ...mockFile) *mockFileSystem {
|
||||
fs := mockFileSystem{map[string]mockFile{}}
|
||||
for _, f := range files {
|
||||
fs.files[f.path] = f
|
||||
}
|
||||
return &fs
|
||||
}
|
||||
|
||||
func (fs *mockFileSystem) ReadFile(path string) ([]byte, error) {
|
||||
if f, ok := fs.files[path]; ok {
|
||||
return f.data, f.err
|
||||
}
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
func (fs *mockFileSystem) ValidFilePath(path string) bool {
|
||||
_, exists := fs.files[path]
|
||||
return exists
|
||||
}
|
||||
|
||||
func TestFindOriginCert_Valid(t *testing.T) {
|
||||
file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem")
|
||||
require.NoError(t, err)
|
||||
dir := t.TempDir()
|
||||
certPath := path.Join(dir, originCertFile)
|
||||
os.WriteFile(certPath, file, fs.ModePerm)
|
||||
certPath := filepath.Join(dir, originCertFile)
|
||||
_ = os.WriteFile(certPath, file, fs.ModePerm)
|
||||
path, err := FindOriginCert(certPath, &nopLog)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, certPath, path)
|
||||
|
@ -104,7 +72,32 @@ func TestFindOriginCert_Valid(t *testing.T) {
|
|||
|
||||
func TestFindOriginCert_Missing(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPath := path.Join(dir, originCertFile)
|
||||
certPath := filepath.Join(dir, originCertFile)
|
||||
_, err := FindOriginCert(certPath, &nopLog)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestEncodeDecodeOriginCert(t *testing.T) {
|
||||
cert := OriginCert{
|
||||
ZoneID: "zone",
|
||||
AccountID: "account",
|
||||
APIToken: "token",
|
||||
Endpoint: "FED",
|
||||
}
|
||||
blocks, err := cert.EncodeOriginCert()
|
||||
require.NoError(t, err)
|
||||
decodedCert, err := DecodeOriginCert(blocks)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, cert)
|
||||
assert.Equal(t, "zone", decodedCert.ZoneID)
|
||||
assert.Equal(t, "account", decodedCert.AccountID)
|
||||
assert.Equal(t, "token", decodedCert.APIToken)
|
||||
assert.Equal(t, FedEndpoint, decodedCert.Endpoint)
|
||||
}
|
||||
|
||||
func TestEncodeDecodeNilOriginCert(t *testing.T) {
|
||||
var cert *OriginCert
|
||||
blocks, err := cert.EncodeOriginCert()
|
||||
assert.Equal(t, fmt.Errorf("originCert cannot be nil"), err)
|
||||
require.Nil(t, blocks)
|
||||
}
|
||||
|
|
|
@ -87,3 +87,4 @@ M2i4QoOFcSKIG+v4SuvgEJHgG8vGvxh2qlSxnMWuPV+7/1P5ATLqDj1PlKms+BNR
|
|||
y7sc5AT9PclkL3Y9MNzOu0LXyBkGYcl8M0EQfLv9VPbWT+NXiMg/O2CHiT02pAAz
|
||||
uQicoQq3yzeQh20wtrtaXzTNmA==
|
||||
-----END RSA PRIVATE KEY-----
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
FROM golang:1.22.5 as builder
|
||||
FROM golang:1.22.10 as builder
|
||||
ENV GO111MODULE=on \
|
||||
CGO_ENABLED=0
|
||||
WORKDIR /go/src/github.com/cloudflare/cloudflared/
|
||||
|
|
|
@ -9,7 +9,7 @@ import (
|
|||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
)
|
||||
|
||||
type httpClient struct {
|
||||
|
@ -86,12 +86,12 @@ func (client *httpClient) GetLogConfiguration(ctx context.Context) (*LogConfigur
|
|||
return nil, fmt.Errorf("error convertin pid to int: %w", err)
|
||||
}
|
||||
|
||||
logFile, exists := data[logger.LogFileFlag]
|
||||
logFile, exists := data[cfdflags.LogFile]
|
||||
if exists {
|
||||
return &LogConfiguration{logFile, "", uid}, nil
|
||||
}
|
||||
|
||||
logDirectory, exists := data[logger.LogDirectoryFlag]
|
||||
logDirectory, exists := data[cfdflags.LogDirectory]
|
||||
if exists {
|
||||
return &LogConfiguration{"", logDirectory, uid}, nil
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ func (client *httpClient) GetSystemInformation(ctx context.Context, writer io.Wr
|
|||
return err
|
||||
}
|
||||
|
||||
return copyToWriter(response, writer)
|
||||
return copyJSONToWriter(response, writer)
|
||||
}
|
||||
|
||||
func (client *httpClient) GetMetrics(ctx context.Context, writer io.Writer) error {
|
||||
|
@ -159,7 +159,7 @@ func (client *httpClient) GetTunnelConfiguration(ctx context.Context, writer io.
|
|||
return err
|
||||
}
|
||||
|
||||
return copyToWriter(response, writer)
|
||||
return copyJSONToWriter(response, writer)
|
||||
}
|
||||
|
||||
func (client *httpClient) GetCliConfiguration(ctx context.Context, writer io.Writer) error {
|
||||
|
@ -168,7 +168,7 @@ func (client *httpClient) GetCliConfiguration(ctx context.Context, writer io.Wri
|
|||
return err
|
||||
}
|
||||
|
||||
return copyToWriter(response, writer)
|
||||
return copyJSONToWriter(response, writer)
|
||||
}
|
||||
|
||||
func copyToWriter(response *http.Response, writer io.Writer) error {
|
||||
|
@ -176,7 +176,29 @@ func copyToWriter(response *http.Response, writer io.Writer) error {
|
|||
|
||||
_, err := io.Copy(writer, response.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing metrics: %w", err)
|
||||
return fmt.Errorf("error writing response: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyJSONToWriter(response *http.Response, writer io.Writer) error {
|
||||
defer response.Body.Close()
|
||||
|
||||
var data interface{}
|
||||
|
||||
decoder := json.NewDecoder(response.Body)
|
||||
|
||||
err := decoder.Decode(&data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("diagnostic client error whilst reading response: %w", err)
|
||||
}
|
||||
|
||||
encoder := newFormattedEncoder(writer)
|
||||
|
||||
err = encoder.Encode(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("diagnostic client error whilst writing json: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -40,6 +41,17 @@ type taskResult struct {
|
|||
path string
|
||||
}
|
||||
|
||||
func (result taskResult) MarshalJSON() ([]byte, error) {
|
||||
s := map[string]string{
|
||||
"result": result.Result,
|
||||
}
|
||||
if result.Err != nil {
|
||||
s["error"] = result.Err.Error()
|
||||
}
|
||||
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
// Struct used to hold the results of different routines executing the network collection.
|
||||
type networkCollectionResult struct {
|
||||
name string
|
||||
|
@ -151,17 +163,7 @@ func collectNetworkResultRoutine(
|
|||
}
|
||||
|
||||
hops, raw, err := collector.Collect(ctx, network.NewTraceOptions(hopsNo, timeout, hostname, useIPv4))
|
||||
if err != nil {
|
||||
if raw == "" {
|
||||
// An error happened and there is no raw output
|
||||
results <- networkCollectionResult{name, nil, "", err}
|
||||
} else {
|
||||
// An error happened and there is raw output then write to file
|
||||
results <- networkCollectionResult{name, nil, raw, nil}
|
||||
}
|
||||
} else {
|
||||
results <- networkCollectionResult{name, hops, raw, nil}
|
||||
}
|
||||
results <- networkCollectionResult{name, hops, raw, err}
|
||||
}
|
||||
|
||||
func gatherNetworkInformation(ctx context.Context) map[string]networkCollectionResult {
|
||||
|
@ -198,10 +200,6 @@ func gatherNetworkInformation(ctx context.Context) map[string]networkCollectionR
|
|||
|
||||
for range len(hostAndIPversionPairs) {
|
||||
result := <-results
|
||||
if result.err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
resultMap[result.name] = result
|
||||
}
|
||||
|
||||
|
@ -238,22 +236,30 @@ func rawNetworkInformationWriter(resultMap map[string]networkCollectionResult) (
|
|||
|
||||
defer networkDumpHandle.Close()
|
||||
|
||||
var exitErr error
|
||||
|
||||
for k, v := range resultMap {
|
||||
if v.err != nil {
|
||||
if exitErr == nil {
|
||||
exitErr = v.err
|
||||
}
|
||||
|
||||
_, err := networkDumpHandle.WriteString(k + "\nno content\n")
|
||||
if err != nil {
|
||||
return networkDumpHandle.Name(), fmt.Errorf("error writing 'no content' to raw network file: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, err := networkDumpHandle.WriteString(k + "\n" + v.raw + "\n")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error writing raw network information: %w", err)
|
||||
return networkDumpHandle.Name(), fmt.Errorf("error writing raw network information: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return networkDumpHandle.Name(), nil
|
||||
return networkDumpHandle.Name(), exitErr
|
||||
}
|
||||
|
||||
func jsonNetworkInformationWriter(resultMap map[string]networkCollectionResult) (string, error) {
|
||||
jsonMap := make(map[string][]*network.Hop, len(resultMap))
|
||||
for k, v := range resultMap {
|
||||
jsonMap[k] = v.info
|
||||
}
|
||||
|
||||
networkDumpHandle, err := os.Create(filepath.Join(os.TempDir(), networkBaseName))
|
||||
if err != nil {
|
||||
return "", ErrCreatingTemporaryFile
|
||||
|
@ -261,12 +267,25 @@ func jsonNetworkInformationWriter(resultMap map[string]networkCollectionResult)
|
|||
|
||||
defer networkDumpHandle.Close()
|
||||
|
||||
err = json.NewEncoder(networkDumpHandle).Encode(jsonMap)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error encoding network information results: %w", err)
|
||||
encoder := newFormattedEncoder(networkDumpHandle)
|
||||
|
||||
var exitErr error
|
||||
|
||||
jsonMap := make(map[string][]*network.Hop, len(resultMap))
|
||||
for k, v := range resultMap {
|
||||
jsonMap[k] = v.info
|
||||
|
||||
if exitErr == nil && v.err != nil {
|
||||
exitErr = v.err
|
||||
}
|
||||
}
|
||||
|
||||
return networkDumpHandle.Name(), nil
|
||||
err = encoder.Encode(jsonMap)
|
||||
if err != nil {
|
||||
return networkDumpHandle.Name(), fmt.Errorf("error encoding network information results: %w", err)
|
||||
}
|
||||
|
||||
return networkDumpHandle.Name(), exitErr
|
||||
}
|
||||
|
||||
func collectFromEndpointAdapter(collect collectToWriterFunc, fileName string) collectFunc {
|
||||
|
@ -279,7 +298,7 @@ func collectFromEndpointAdapter(collect collectToWriterFunc, fileName string) co
|
|||
|
||||
err = collect(ctx, dumpHandle)
|
||||
if err != nil {
|
||||
return "", ErrCreatingTemporaryFile
|
||||
return dumpHandle.Name(), fmt.Errorf("error running collector: %w", err)
|
||||
}
|
||||
|
||||
return dumpHandle.Name(), nil
|
||||
|
@ -300,13 +319,16 @@ func tunnelStateCollectEndpointAdapter(client HTTPClient, tunnel *TunnelState, f
|
|||
tunnel = tunnelResponse
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(writer)
|
||||
encoder := newFormattedEncoder(writer)
|
||||
|
||||
err := encoder.Encode(tunnel)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("error encoding tunnel state: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return collectFromEndpointAdapter(endpointFunc, fileName)
|
||||
}
|
||||
|
||||
|
@ -324,15 +346,14 @@ func resolveInstanceBaseURL(
|
|||
addresses []string,
|
||||
) (*url.URL, *TunnelState, []*AddressableTunnelState, error) {
|
||||
if metricsServerAddress != "" {
|
||||
if !strings.HasPrefix(metricsServerAddress, "http://") {
|
||||
metricsServerAddress = "http://" + metricsServerAddress
|
||||
}
|
||||
url, err := url.Parse(metricsServerAddress)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("provided address is not valid: %w", err)
|
||||
}
|
||||
|
||||
if url.Scheme == "" {
|
||||
url.Scheme = "http://"
|
||||
}
|
||||
|
||||
return url, nil, nil, nil
|
||||
}
|
||||
|
||||
|
@ -421,7 +442,9 @@ func createTaskReport(taskReport map[string]taskResult) (string, error) {
|
|||
}
|
||||
defer dumpHandle.Close()
|
||||
|
||||
err = json.NewEncoder(dumpHandle).Encode(taskReport)
|
||||
encoder := newFormattedEncoder(dumpHandle)
|
||||
|
||||
err = encoder.Encode(taskReport)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error encoding task results: %w", err)
|
||||
}
|
||||
|
@ -511,9 +534,15 @@ func RunDiagnostic(
|
|||
jobsReport := runJobs(ctx, jobs, log)
|
||||
paths := make([]string, 0)
|
||||
|
||||
var gerr error
|
||||
|
||||
for _, v := range jobsReport {
|
||||
paths = append(paths, v.path)
|
||||
|
||||
if gerr == nil && v.Err != nil {
|
||||
gerr = v.Err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if !errors.Is(v.Err, ErrCreatingTemporaryFile) {
|
||||
os.Remove(v.path)
|
||||
|
@ -523,14 +552,10 @@ func RunDiagnostic(
|
|||
|
||||
zipfile, err := CreateDiagnosticZipFile(zipName, paths)
|
||||
if err != nil {
|
||||
if zipfile != "" {
|
||||
os.Remove(zipfile)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info().Msgf("Diagnostic file written: %v", zipfile)
|
||||
|
||||
return nil, nil
|
||||
return nil, gerr
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package diagnostic
|
|||
import (
|
||||
"archive/zip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
|
@ -138,3 +139,10 @@ func FindMetricsServer(
|
|||
|
||||
return nil, instances, ErrMultipleMetricsServerFound
|
||||
}
|
||||
|
||||
// newFormattedEncoder return a JSON encoder with identation
|
||||
func newFormattedEncoder(w io.Writer) *json.Encoder {
|
||||
encoder := json.NewEncoder(w)
|
||||
encoder.SetIndent("", " ")
|
||||
return encoder
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ func helperCreateServer(t *testing.T, listeners *gracenet.Net, tunnelID uuid.UUI
|
|||
require.NoError(t, err)
|
||||
log := zerolog.Nop()
|
||||
tracker := tunnelstate.NewConnTracker(&log)
|
||||
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, nil, []string{})
|
||||
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, map[string]string{}, []string{})
|
||||
router := http.NewServeMux()
|
||||
router.HandleFunc("/diag/tunnel", handler.TunnelStateHandler)
|
||||
server := &http.Server{
|
||||
|
|
|
@ -5,15 +5,12 @@ import (
|
|||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/tunnelstate"
|
||||
)
|
||||
|
||||
|
@ -24,8 +21,8 @@ type Handler struct {
|
|||
tunnelID uuid.UUID
|
||||
connectorID uuid.UUID
|
||||
tracker *tunnelstate.ConnTracker
|
||||
cli *cli.Context
|
||||
flagInclusionList []string
|
||||
cliFlags map[string]string
|
||||
icmpSources []string
|
||||
}
|
||||
|
||||
func NewDiagnosticHandler(
|
||||
|
@ -35,14 +32,15 @@ func NewDiagnosticHandler(
|
|||
tunnelID uuid.UUID,
|
||||
connectorID uuid.UUID,
|
||||
tracker *tunnelstate.ConnTracker,
|
||||
cli *cli.Context,
|
||||
flagInclusionList []string,
|
||||
cliFlags map[string]string,
|
||||
icmpSources []string,
|
||||
) *Handler {
|
||||
logger := log.With().Logger()
|
||||
if timeout == 0 {
|
||||
timeout = defaultCollectorTimeout
|
||||
}
|
||||
|
||||
cliFlags[configurationKeyUID] = strconv.Itoa(os.Getuid())
|
||||
return &Handler{
|
||||
log: &logger,
|
||||
timeout: timeout,
|
||||
|
@ -50,8 +48,8 @@ func NewDiagnosticHandler(
|
|||
tunnelID: tunnelID,
|
||||
connectorID: connectorID,
|
||||
tracker: tracker,
|
||||
cli: cli,
|
||||
flagInclusionList: flagInclusionList,
|
||||
cliFlags: cliFlags,
|
||||
icmpSources: icmpSources,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -61,6 +59,11 @@ func (handler *Handler) InstallEndpoints(router *http.ServeMux) {
|
|||
router.HandleFunc(systemInformationEndpoint, handler.SystemHandler)
|
||||
}
|
||||
|
||||
type SystemInformationResponse struct {
|
||||
Info *SystemInformation `json:"info"`
|
||||
Err error `json:"errors"`
|
||||
}
|
||||
|
||||
func (handler *Handler) SystemHandler(writer http.ResponseWriter, request *http.Request) {
|
||||
logger := handler.log.With().Str(collectorField, systemCollectorName).Logger()
|
||||
logger.Info().Msg("Collection started")
|
||||
|
@ -71,30 +74,15 @@ func (handler *Handler) SystemHandler(writer http.ResponseWriter, request *http.
|
|||
|
||||
defer cancel()
|
||||
|
||||
info, rawInfo, err := handler.systemCollector.Collect(ctx)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("error occurred whilst collecting system information")
|
||||
info, err := handler.systemCollector.Collect(ctx)
|
||||
|
||||
if rawInfo != "" {
|
||||
logger.Info().Msg("using raw information fallback")
|
||||
bytes := []byte(rawInfo)
|
||||
writeResponse(writer, bytes, &logger)
|
||||
} else {
|
||||
logger.Error().Msg("no raw information available")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if info == nil {
|
||||
logger.Error().Msgf("system information collection is nil")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
response := SystemInformationResponse{
|
||||
Info: info,
|
||||
Err: err,
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(writer)
|
||||
|
||||
err = encoder.Encode(info)
|
||||
err = encoder.Encode(response)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msgf("error occurred whilst serializing information")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
|
@ -105,6 +93,7 @@ type TunnelState struct {
|
|||
TunnelID uuid.UUID `json:"tunnelID,omitempty"`
|
||||
ConnectorID uuid.UUID `json:"connectorID,omitempty"`
|
||||
Connections []tunnelstate.IndexedConnectionInfo `json:"connections,omitempty"`
|
||||
ICMPSources []string `json:"icmp_sources,omitempty"`
|
||||
}
|
||||
|
||||
func (handler *Handler) TunnelStateHandler(writer http.ResponseWriter, _ *http.Request) {
|
||||
|
@ -117,6 +106,7 @@ func (handler *Handler) TunnelStateHandler(writer http.ResponseWriter, _ *http.R
|
|||
handler.tunnelID,
|
||||
handler.connectorID,
|
||||
handler.tracker.GetActiveConnections(),
|
||||
handler.icmpSources,
|
||||
}
|
||||
encoder := json.NewEncoder(writer)
|
||||
|
||||
|
@ -135,68 +125,15 @@ func (handler *Handler) ConfigurationHandler(writer http.ResponseWriter, _ *http
|
|||
log.Info().Msg("Collection finished")
|
||||
}()
|
||||
|
||||
flagsNames := handler.cli.FlagNames()
|
||||
flags := make(map[string]string, len(flagsNames))
|
||||
|
||||
for _, flag := range flagsNames {
|
||||
value := handler.cli.String(flag)
|
||||
|
||||
// empty values are not relevant
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// exclude flags that are sensitive
|
||||
isIncluded := handler.isFlagIncluded(flag)
|
||||
if !isIncluded {
|
||||
continue
|
||||
}
|
||||
|
||||
switch flag {
|
||||
case logger.LogDirectoryFlag:
|
||||
fallthrough
|
||||
case logger.LogFileFlag:
|
||||
{
|
||||
// the log directory may be relative to the instance thus it must be resolved
|
||||
absolute, err := filepath.Abs(value)
|
||||
if err != nil {
|
||||
handler.log.Error().Err(err).Msgf("could not convert %s path to absolute", flag)
|
||||
} else {
|
||||
flags[flag] = absolute
|
||||
}
|
||||
}
|
||||
default:
|
||||
flags[flag] = value
|
||||
}
|
||||
}
|
||||
|
||||
// The UID is included to help the
|
||||
// diagnostic tool to understand
|
||||
// if this instance is managed or not.
|
||||
flags[configurationKeyUID] = strconv.Itoa(os.Getuid())
|
||||
encoder := json.NewEncoder(writer)
|
||||
|
||||
err := encoder.Encode(flags)
|
||||
err := encoder.Encode(handler.cliFlags)
|
||||
if err != nil {
|
||||
handler.log.Error().Err(err).Msgf("error occurred whilst serializing response")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func (handler *Handler) isFlagIncluded(flag string) bool {
|
||||
isIncluded := false
|
||||
|
||||
for _, include := range handler.flagInclusionList {
|
||||
if include == flag {
|
||||
isIncluded = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return isIncluded
|
||||
}
|
||||
|
||||
func writeResponse(w http.ResponseWriter, bytes []byte, logger *zerolog.Logger) {
|
||||
bytesWritten, err := w.Write(bytes)
|
||||
if err != nil {
|
||||
|
|
|
@ -4,47 +4,32 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/diagnostic"
|
||||
"github.com/cloudflare/cloudflared/tunnelstate"
|
||||
)
|
||||
|
||||
type SystemCollectorMock struct{}
|
||||
type SystemCollectorMock struct {
|
||||
systemInfo *diagnostic.SystemInformation
|
||||
err error
|
||||
}
|
||||
|
||||
const (
|
||||
systemInformationKey = "sikey"
|
||||
rawInformationKey = "rikey"
|
||||
errorKey = "errkey"
|
||||
)
|
||||
|
||||
func buildCliContext(t *testing.T, flags map[string]string) *cli.Context {
|
||||
t.Helper()
|
||||
|
||||
flagSet := flag.NewFlagSet("", flag.PanicOnError)
|
||||
ctx := cli.NewContext(cli.NewApp(), flagSet, nil)
|
||||
|
||||
for k, v := range flags {
|
||||
flagSet.String(k, v, "")
|
||||
err := ctx.Set(k, v)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker {
|
||||
t.Helper()
|
||||
|
||||
|
@ -63,25 +48,8 @@ func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnecti
|
|||
return tracker
|
||||
}
|
||||
|
||||
func setCtxValuesForSystemCollector(
|
||||
systemInfo *diagnostic.SystemInformation,
|
||||
rawInfo string,
|
||||
err error,
|
||||
) context.Context {
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, systemInformationKey, systemInfo)
|
||||
ctx = context.WithValue(ctx, rawInformationKey, rawInfo)
|
||||
ctx = context.WithValue(ctx, errorKey, err)
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (*SystemCollectorMock) Collect(ctx context.Context) (*diagnostic.SystemInformation, string, error) {
|
||||
si, _ := ctx.Value(systemInformationKey).(*diagnostic.SystemInformation)
|
||||
ri, _ := ctx.Value(rawInformationKey).(string)
|
||||
err, _ := ctx.Value(errorKey).(error)
|
||||
|
||||
return si, ri, err
|
||||
func (collector *SystemCollectorMock) Collect(context.Context) (*diagnostic.SystemInformation, error) {
|
||||
return collector.systemInfo, collector.err
|
||||
}
|
||||
|
||||
func TestSystemHandler(t *testing.T) {
|
||||
|
@ -91,7 +59,6 @@ func TestSystemHandler(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
systemInfo *diagnostic.SystemInformation
|
||||
rawInfo string
|
||||
err error
|
||||
statusCode int
|
||||
}{
|
||||
|
@ -100,48 +67,39 @@ func TestSystemHandler(t *testing.T) {
|
|||
systemInfo: diagnostic.NewSystemInformation(
|
||||
0, 0, 0, 0,
|
||||
"string", "string", "string", "string",
|
||||
"string", "string", nil,
|
||||
"string", "string",
|
||||
runtime.Version(), runtime.GOARCH, nil,
|
||||
),
|
||||
rawInfo: "",
|
||||
|
||||
err: nil,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "on error and raw info", systemInfo: nil,
|
||||
rawInfo: "raw info", err: errors.New("an error"), statusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "on error and no raw info", systemInfo: nil,
|
||||
rawInfo: "", err: errors.New("an error"), statusCode: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "malformed response", systemInfo: nil, rawInfo: "", err: nil, statusCode: http.StatusInternalServerError,
|
||||
err: errors.New("an error"), statusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tCase := range tests {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil, nil, nil)
|
||||
handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{
|
||||
systemInfo: tCase.systemInfo,
|
||||
err: tCase.err,
|
||||
}, uuid.New(), uuid.New(), nil, map[string]string{}, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx := setCtxValuesForSystemCollector(tCase.systemInfo, tCase.rawInfo, tCase.err)
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/syste,", nil)
|
||||
ctx := context.Background()
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/system", nil)
|
||||
require.NoError(t, err)
|
||||
handler.SystemHandler(recorder, request)
|
||||
|
||||
assert.Equal(t, tCase.statusCode, recorder.Code)
|
||||
if tCase.statusCode == http.StatusOK && tCase.systemInfo != nil {
|
||||
var response diagnostic.SystemInformation
|
||||
|
||||
var response diagnostic.SystemInformationResponse
|
||||
decoder := json.NewDecoder(recorder.Body)
|
||||
err = decoder.Decode(&response)
|
||||
err := decoder.Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tCase.systemInfo, &response)
|
||||
} else if tCase.statusCode == http.StatusOK && tCase.rawInfo != "" {
|
||||
rawBytes, err := io.ReadAll(recorder.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tCase.rawInfo, string(rawBytes))
|
||||
assert.Equal(t, tCase.systemInfo, response.Info)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -156,6 +114,7 @@ func TestTunnelStateHandler(t *testing.T) {
|
|||
tunnelID uuid.UUID
|
||||
clientID uuid.UUID
|
||||
connections []tunnelstate.IndexedConnectionInfo
|
||||
icmpSources []string
|
||||
}{
|
||||
{
|
||||
name: "case1",
|
||||
|
@ -166,6 +125,7 @@ func TestTunnelStateHandler(t *testing.T) {
|
|||
name: "case2",
|
||||
tunnelID: uuid.New(),
|
||||
clientID: uuid.New(),
|
||||
icmpSources: []string{"172.17.0.3", "::1"},
|
||||
connections: []tunnelstate.IndexedConnectionInfo{{
|
||||
ConnectionInfo: tunnelstate.ConnectionInfo{
|
||||
IsConnected: true,
|
||||
|
@ -181,7 +141,16 @@ func TestTunnelStateHandler(t *testing.T) {
|
|||
t.Run(tCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tracker := newTrackerFromConns(t, tCase.connections)
|
||||
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tCase.tunnelID, tCase.clientID, tracker, nil, nil)
|
||||
handler := diagnostic.NewDiagnosticHandler(
|
||||
&log,
|
||||
0,
|
||||
nil,
|
||||
tCase.tunnelID,
|
||||
tCase.clientID,
|
||||
tracker,
|
||||
map[string]string{},
|
||||
tCase.icmpSources,
|
||||
)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.TunnelStateHandler(recorder, nil)
|
||||
decoder := json.NewDecoder(recorder.Body)
|
||||
|
@ -193,6 +162,7 @@ func TestTunnelStateHandler(t *testing.T) {
|
|||
assert.Equal(t, tCase.tunnelID, response.TunnelID)
|
||||
assert.Equal(t, tCase.clientID, response.ConnectorID)
|
||||
assert.Equal(t, tCase.connections, response.Connections)
|
||||
assert.Equal(t, tCase.icmpSources, response.ICMPSources)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -217,10 +187,10 @@ func TestConfigurationHandler(t *testing.T) {
|
|||
{
|
||||
name: "cli with flags",
|
||||
flags: map[string]string{
|
||||
"a": "a",
|
||||
"b": "a",
|
||||
"c": "a",
|
||||
"d": "a",
|
||||
"uid": "0",
|
||||
},
|
||||
expected: map[string]string{
|
||||
"b": "a",
|
||||
|
@ -233,11 +203,11 @@ func TestConfigurationHandler(t *testing.T) {
|
|||
|
||||
for _, tCase := range tests {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var response map[string]string
|
||||
|
||||
t.Parallel()
|
||||
ctx := buildCliContext(t, tCase.flags)
|
||||
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, ctx, []string{"b", "c", "d"})
|
||||
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, tCase.flags, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ConfigurationHandler(recorder, nil)
|
||||
decoder := json.NewDecoder(recorder.Body)
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
@ -11,6 +12,8 @@ import (
|
|||
const (
|
||||
linuxManagedLogsPath = "/var/log/cloudflared.err"
|
||||
darwinManagedLogsPath = "/Library/Logs/com.cloudflare.cloudflared.err.log"
|
||||
linuxServiceConfigurationPath = "/etc/systemd/system/cloudflared.service"
|
||||
linuxSystemdPath = "/run/systemd/system"
|
||||
)
|
||||
|
||||
type HostLogCollector struct {
|
||||
|
@ -23,6 +26,28 @@ func NewHostLogCollector(client HTTPClient) *HostLogCollector {
|
|||
}
|
||||
}
|
||||
|
||||
func extractLogsFromJournalCtl(ctx context.Context) (*LogInformation, error) {
|
||||
tmp := os.TempDir()
|
||||
|
||||
outputHandle, err := os.Create(filepath.Join(tmp, logFilename))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening output file: %w", err)
|
||||
}
|
||||
|
||||
defer outputHandle.Close()
|
||||
|
||||
command := exec.CommandContext(
|
||||
ctx,
|
||||
"journalctl",
|
||||
"--since",
|
||||
"2 weeks ago",
|
||||
"-u",
|
||||
"cloudflared.service",
|
||||
)
|
||||
|
||||
return PipeCommandOutputToFile(command, outputHandle)
|
||||
}
|
||||
|
||||
func getServiceLogPath() (string, error) {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
|
@ -55,6 +80,13 @@ func (collector *HostLogCollector) Collect(ctx context.Context) (*LogInformation
|
|||
}
|
||||
|
||||
if logConfiguration.uid == 0 {
|
||||
_, statSystemdErr := os.Stat(linuxServiceConfigurationPath)
|
||||
|
||||
_, statServiceConfigurationErr := os.Stat(linuxServiceConfigurationPath)
|
||||
if statSystemdErr == nil && statServiceConfigurationErr == nil && runtime.GOOS == "linux" {
|
||||
return extractLogsFromJournalCtl(ctx)
|
||||
}
|
||||
|
||||
path, err := getServiceLogPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -12,7 +12,16 @@ func PipeCommandOutputToFile(command *exec.Cmd, outputHandle *os.File) (*LogInfo
|
|||
stdoutReader, err := command.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error retrieving output from command '%s': %w",
|
||||
"error retrieving stdout from command '%s': %w",
|
||||
command.String(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
stderrReader, err := command.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error retrieving stderr from command '%s': %w",
|
||||
command.String(),
|
||||
err,
|
||||
)
|
||||
|
@ -29,7 +38,17 @@ func PipeCommandOutputToFile(command *exec.Cmd, outputHandle *os.File) (*LogInfo
|
|||
_, err = io.Copy(outputHandle, stdoutReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error copying output from %s to file %s: %w",
|
||||
"error copying stdout from %s to file %s: %w",
|
||||
command.String(),
|
||||
outputHandle.Name(),
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
_, err = io.Copy(outputHandle, stderrReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"error copying stderr from %s to file %s: %w",
|
||||
command.String(),
|
||||
outputHandle.Name(),
|
||||
err,
|
||||
|
|
|
@ -1,6 +1,82 @@
|
|||
package diagnostic
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SystemInformationError struct {
|
||||
Err error `json:"error"`
|
||||
RawInfo string `json:"rawInfo"`
|
||||
}
|
||||
|
||||
func (err SystemInformationError) Error() string {
|
||||
return err.Err.Error()
|
||||
}
|
||||
|
||||
func (err SystemInformationError) MarshalJSON() ([]byte, error) {
|
||||
s := map[string]string{
|
||||
"error": err.Err.Error(),
|
||||
"rawInfo": err.RawInfo,
|
||||
}
|
||||
|
||||
return json.Marshal(s)
|
||||
}
|
||||
|
||||
type SystemInformationGeneralError struct {
|
||||
OperatingSystemInformationError error
|
||||
MemoryInformationError error
|
||||
FileDescriptorsInformationError error
|
||||
DiskVolumeInformationError error
|
||||
}
|
||||
|
||||
func (err SystemInformationGeneralError) Error() string {
|
||||
builder := &strings.Builder{}
|
||||
builder.WriteString("errors found:")
|
||||
|
||||
if err.OperatingSystemInformationError != nil {
|
||||
builder.WriteString(err.OperatingSystemInformationError.Error() + ", ")
|
||||
}
|
||||
|
||||
if err.MemoryInformationError != nil {
|
||||
builder.WriteString(err.MemoryInformationError.Error() + ", ")
|
||||
}
|
||||
|
||||
if err.FileDescriptorsInformationError != nil {
|
||||
builder.WriteString(err.FileDescriptorsInformationError.Error() + ", ")
|
||||
}
|
||||
|
||||
if err.DiskVolumeInformationError != nil {
|
||||
builder.WriteString(err.DiskVolumeInformationError.Error() + ", ")
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (err SystemInformationGeneralError) MarshalJSON() ([]byte, error) {
|
||||
data := map[string]SystemInformationError{}
|
||||
|
||||
var sysErr SystemInformationError
|
||||
if errors.As(err.OperatingSystemInformationError, &sysErr) {
|
||||
data["operatingSystemInformationError"] = sysErr
|
||||
}
|
||||
|
||||
if errors.As(err.MemoryInformationError, &sysErr) {
|
||||
data["memoryInformationError"] = sysErr
|
||||
}
|
||||
|
||||
if errors.As(err.FileDescriptorsInformationError, &sysErr) {
|
||||
data["fileDescriptorsInformationError"] = sysErr
|
||||
}
|
||||
|
||||
if errors.As(err.DiskVolumeInformationError, &sysErr) {
|
||||
data["diskVolumeInformationError"] = sysErr
|
||||
}
|
||||
|
||||
return json.Marshal(data)
|
||||
}
|
||||
|
||||
type DiskVolumeInformation struct {
|
||||
Name string `json:"name"` // represents the filesystem in linux/macos or device name in windows
|
||||
|
@ -17,17 +93,19 @@ func NewDiskVolumeInformation(name string, maximum, current uint64) *DiskVolumeI
|
|||
}
|
||||
|
||||
type SystemInformation struct {
|
||||
MemoryMaximum uint64 `json:"memoryMaximum"` // represents the maximum memory of the system in kilobytes
|
||||
MemoryCurrent uint64 `json:"memoryCurrent"` // represents the system's memory in use in kilobytes
|
||||
FileDescriptorMaximum uint64 `json:"fileDescriptorMaximum"` // represents the maximum number of file descriptors of the system
|
||||
FileDescriptorCurrent uint64 `json:"fileDescriptorCurrent"` // represents the system's file descriptors in use
|
||||
OsSystem string `json:"osSystem"` // represents the operating system name i.e.: linux, windows, darwin
|
||||
HostName string `json:"hostName"` // represents the system host name
|
||||
OsVersion string `json:"osVersion"` // detailed information about the system's release version level
|
||||
OsRelease string `json:"osRelease"` // detailed information about the system's release
|
||||
Architecture string `json:"architecture"` // represents the system's hardware platform i.e: arm64/amd64
|
||||
CloudflaredVersion string `json:"cloudflaredVersion"` // the runtime version of cloudflared
|
||||
Disk []*DiskVolumeInformation `json:"disk"`
|
||||
MemoryMaximum uint64 `json:"memoryMaximum,omitempty"` // represents the maximum memory of the system in kilobytes
|
||||
MemoryCurrent uint64 `json:"memoryCurrent,omitempty"` // represents the system's memory in use in kilobytes
|
||||
FileDescriptorMaximum uint64 `json:"fileDescriptorMaximum,omitempty"` // represents the maximum number of file descriptors of the system
|
||||
FileDescriptorCurrent uint64 `json:"fileDescriptorCurrent,omitempty"` // represents the system's file descriptors in use
|
||||
OsSystem string `json:"osSystem,omitempty"` // represents the operating system name i.e.: linux, windows, darwin
|
||||
HostName string `json:"hostName,omitempty"` // represents the system host name
|
||||
OsVersion string `json:"osVersion,omitempty"` // detailed information about the system's release version level
|
||||
OsRelease string `json:"osRelease,omitempty"` // detailed information about the system's release
|
||||
Architecture string `json:"architecture,omitempty"` // represents the system's hardware platform i.e: arm64/amd64
|
||||
CloudflaredVersion string `json:"cloudflaredVersion,omitempty"` // the runtime version of cloudflared
|
||||
GoVersion string `json:"goVersion,omitempty"`
|
||||
GoArch string `json:"goArch,omitempty"`
|
||||
Disk []*DiskVolumeInformation `json:"disk,omitempty"`
|
||||
}
|
||||
|
||||
func NewSystemInformation(
|
||||
|
@ -40,7 +118,9 @@ func NewSystemInformation(
|
|||
osVersion,
|
||||
osRelease,
|
||||
architecture,
|
||||
cloudflaredVersion string,
|
||||
cloudflaredVersion,
|
||||
goVersion,
|
||||
goArchitecture string,
|
||||
disk []*DiskVolumeInformation,
|
||||
) *SystemInformation {
|
||||
return &SystemInformation{
|
||||
|
@ -54,17 +134,17 @@ func NewSystemInformation(
|
|||
osRelease,
|
||||
architecture,
|
||||
cloudflaredVersion,
|
||||
goVersion,
|
||||
goArchitecture,
|
||||
disk,
|
||||
}
|
||||
}
|
||||
|
||||
type SystemCollector interface {
|
||||
// If the collection is successful it will return `SystemInformation` struct,
|
||||
// an empty string, and a nil error.
|
||||
// In case there is an error a string with the raw data will be returned
|
||||
// however the returned string not contain all the data points.
|
||||
// and a nil error.
|
||||
//
|
||||
// This function expects that the caller sets the context timeout to prevent
|
||||
// long-lived collectors.
|
||||
Collect(ctx context.Context) (*SystemInformation, string, error)
|
||||
Collect(ctx context.Context) (*SystemInformation, error)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
@ -22,45 +23,74 @@ func NewSystemCollectorImpl(
|
|||
}
|
||||
}
|
||||
|
||||
func (collector *SystemCollectorImpl) Collect(ctx context.Context) (*SystemInformation, string, error) {
|
||||
func (collector *SystemCollectorImpl) Collect(ctx context.Context) (*SystemInformation, error) {
|
||||
memoryInfo, memoryInfoRaw, memoryInfoErr := collectMemoryInformation(ctx)
|
||||
fdInfo, fdInfoRaw, fdInfoErr := collectFileDescriptorInformation(ctx)
|
||||
disks, disksRaw, diskErr := collectDiskVolumeInformationUnix(ctx)
|
||||
osInfo, osInfoRaw, osInfoErr := collectOSInformationUnix(ctx)
|
||||
|
||||
var memoryMaximum, memoryCurrent, fileDescriptorMaximum, fileDescriptorCurrent uint64
|
||||
var osSystem, name, osVersion, osRelease, architecture string
|
||||
gerror := SystemInformationGeneralError{}
|
||||
|
||||
if memoryInfoErr != nil {
|
||||
raw := RawSystemInformation(osInfoRaw, memoryInfoRaw, fdInfoRaw, disksRaw)
|
||||
return nil, raw, memoryInfoErr
|
||||
gerror.MemoryInformationError = SystemInformationError{
|
||||
Err: memoryInfoErr,
|
||||
RawInfo: memoryInfoRaw,
|
||||
}
|
||||
} else {
|
||||
memoryMaximum = memoryInfo.MemoryMaximum
|
||||
memoryCurrent = memoryInfo.MemoryCurrent
|
||||
}
|
||||
|
||||
if fdInfoErr != nil {
|
||||
raw := RawSystemInformation(osInfoRaw, memoryInfoRaw, fdInfoRaw, disksRaw)
|
||||
return nil, raw, fdInfoErr
|
||||
gerror.FileDescriptorsInformationError = SystemInformationError{
|
||||
Err: fdInfoErr,
|
||||
RawInfo: fdInfoRaw,
|
||||
}
|
||||
} else {
|
||||
fileDescriptorMaximum = fdInfo.FileDescriptorMaximum
|
||||
fileDescriptorCurrent = fdInfo.FileDescriptorCurrent
|
||||
}
|
||||
|
||||
if diskErr != nil {
|
||||
raw := RawSystemInformation(osInfoRaw, memoryInfoRaw, fdInfoRaw, disksRaw)
|
||||
return nil, raw, diskErr
|
||||
gerror.DiskVolumeInformationError = SystemInformationError{
|
||||
Err: diskErr,
|
||||
RawInfo: disksRaw,
|
||||
}
|
||||
}
|
||||
|
||||
if osInfoErr != nil {
|
||||
raw := RawSystemInformation(osInfoRaw, memoryInfoRaw, fdInfoRaw, disksRaw)
|
||||
return nil, raw, osInfoErr
|
||||
gerror.OperatingSystemInformationError = SystemInformationError{
|
||||
Err: osInfoErr,
|
||||
RawInfo: osInfoRaw,
|
||||
}
|
||||
} else {
|
||||
osSystem = osInfo.OsSystem
|
||||
name = osInfo.Name
|
||||
osVersion = osInfo.OsVersion
|
||||
osRelease = osInfo.OsRelease
|
||||
architecture = osInfo.Architecture
|
||||
}
|
||||
|
||||
return NewSystemInformation(
|
||||
memoryInfo.MemoryMaximum,
|
||||
memoryInfo.MemoryCurrent,
|
||||
fdInfo.FileDescriptorMaximum,
|
||||
fdInfo.FileDescriptorCurrent,
|
||||
osInfo.OsSystem,
|
||||
osInfo.Name,
|
||||
osInfo.OsVersion,
|
||||
osInfo.OsRelease,
|
||||
osInfo.Architecture,
|
||||
collector.version,
|
||||
cloudflaredVersion := collector.version
|
||||
info := NewSystemInformation(
|
||||
memoryMaximum,
|
||||
memoryCurrent,
|
||||
fileDescriptorMaximum,
|
||||
fileDescriptorCurrent,
|
||||
osSystem,
|
||||
name,
|
||||
osVersion,
|
||||
osRelease,
|
||||
architecture,
|
||||
cloudflaredVersion,
|
||||
runtime.Version(),
|
||||
runtime.GOARCH,
|
||||
disks,
|
||||
), "", nil
|
||||
)
|
||||
|
||||
return info, gerror
|
||||
}
|
||||
|
||||
func collectMemoryInformation(ctx context.Context) (*MemoryInformation, string, error) {
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
|
@ -21,41 +22,80 @@ func NewSystemCollectorImpl(
|
|||
}
|
||||
}
|
||||
|
||||
func (collector *SystemCollectorImpl) Collect(ctx context.Context) (*SystemInformation, string, error) {
|
||||
func (collector *SystemCollectorImpl) Collect(ctx context.Context) (*SystemInformation, error) {
|
||||
memoryInfo, memoryInfoRaw, memoryInfoErr := collectMemoryInformation(ctx)
|
||||
fdInfo, fdInfoRaw, fdInfoErr := collectFileDescriptorInformation(ctx)
|
||||
disks, disksRaw, diskErr := collectDiskVolumeInformationUnix(ctx)
|
||||
osInfo, osInfoRaw, osInfoErr := collectOSInformationUnix(ctx)
|
||||
|
||||
var memoryMaximum, memoryCurrent, fileDescriptorMaximum, fileDescriptorCurrent uint64
|
||||
var osSystem, name, osVersion, osRelease, architecture string
|
||||
|
||||
err := SystemInformationGeneralError{
|
||||
OperatingSystemInformationError: nil,
|
||||
MemoryInformationError: nil,
|
||||
FileDescriptorsInformationError: nil,
|
||||
DiskVolumeInformationError: nil,
|
||||
}
|
||||
|
||||
if memoryInfoErr != nil {
|
||||
return nil, RawSystemInformation(osInfoRaw, memoryInfoRaw, fdInfoRaw, disksRaw), memoryInfoErr
|
||||
err.MemoryInformationError = SystemInformationError{
|
||||
Err: memoryInfoErr,
|
||||
RawInfo: memoryInfoRaw,
|
||||
}
|
||||
} else {
|
||||
memoryMaximum = memoryInfo.MemoryMaximum
|
||||
memoryCurrent = memoryInfo.MemoryCurrent
|
||||
}
|
||||
|
||||
if fdInfoErr != nil {
|
||||
return nil, RawSystemInformation(osInfoRaw, memoryInfoRaw, fdInfoRaw, disksRaw), fdInfoErr
|
||||
err.FileDescriptorsInformationError = SystemInformationError{
|
||||
Err: fdInfoErr,
|
||||
RawInfo: fdInfoRaw,
|
||||
}
|
||||
} else {
|
||||
fileDescriptorMaximum = fdInfo.FileDescriptorMaximum
|
||||
fileDescriptorCurrent = fdInfo.FileDescriptorCurrent
|
||||
}
|
||||
|
||||
if diskErr != nil {
|
||||
return nil, RawSystemInformation(osInfoRaw, memoryInfoRaw, fdInfoRaw, disksRaw), diskErr
|
||||
err.DiskVolumeInformationError = SystemInformationError{
|
||||
Err: diskErr,
|
||||
RawInfo: disksRaw,
|
||||
}
|
||||
}
|
||||
|
||||
if osInfoErr != nil {
|
||||
return nil, RawSystemInformation(osInfoRaw, memoryInfoRaw, fdInfoRaw, disksRaw), osInfoErr
|
||||
err.OperatingSystemInformationError = SystemInformationError{
|
||||
Err: osInfoErr,
|
||||
RawInfo: osInfoRaw,
|
||||
}
|
||||
} else {
|
||||
osSystem = osInfo.OsSystem
|
||||
name = osInfo.Name
|
||||
osVersion = osInfo.OsVersion
|
||||
osRelease = osInfo.OsRelease
|
||||
architecture = osInfo.Architecture
|
||||
}
|
||||
|
||||
return NewSystemInformation(
|
||||
memoryInfo.MemoryMaximum,
|
||||
memoryInfo.MemoryCurrent,
|
||||
fdInfo.FileDescriptorMaximum,
|
||||
fdInfo.FileDescriptorCurrent,
|
||||
osInfo.OsSystem,
|
||||
osInfo.Name,
|
||||
osInfo.OsVersion,
|
||||
osInfo.OsRelease,
|
||||
osInfo.Architecture,
|
||||
collector.version,
|
||||
cloudflaredVersion := collector.version
|
||||
info := NewSystemInformation(
|
||||
memoryMaximum,
|
||||
memoryCurrent,
|
||||
fileDescriptorMaximum,
|
||||
fileDescriptorCurrent,
|
||||
osSystem,
|
||||
name,
|
||||
osVersion,
|
||||
osRelease,
|
||||
architecture,
|
||||
cloudflaredVersion,
|
||||
runtime.Version(),
|
||||
runtime.GOARCH,
|
||||
disks,
|
||||
), "", nil
|
||||
)
|
||||
|
||||
return info, err
|
||||
}
|
||||
|
||||
func collectFileDescriptorInformation(ctx context.Context) (
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
|
@ -22,41 +23,70 @@ func NewSystemCollectorImpl(
|
|||
version,
|
||||
}
|
||||
}
|
||||
func (collector *SystemCollectorImpl) Collect(ctx context.Context) (*SystemInformation, string, error) {
|
||||
|
||||
func (collector *SystemCollectorImpl) Collect(ctx context.Context) (*SystemInformation, error) {
|
||||
memoryInfo, memoryInfoRaw, memoryInfoErr := collectMemoryInformation(ctx)
|
||||
disks, disksRaw, diskErr := collectDiskVolumeInformation(ctx)
|
||||
osInfo, osInfoRaw, osInfoErr := collectOSInformation(ctx)
|
||||
|
||||
var memoryMaximum, memoryCurrent, fileDescriptorMaximum, fileDescriptorCurrent uint64
|
||||
var osSystem, name, osVersion, osRelease, architecture string
|
||||
|
||||
err := SystemInformationGeneralError{
|
||||
OperatingSystemInformationError: nil,
|
||||
MemoryInformationError: nil,
|
||||
FileDescriptorsInformationError: nil,
|
||||
DiskVolumeInformationError: nil,
|
||||
}
|
||||
|
||||
if memoryInfoErr != nil {
|
||||
raw := RawSystemInformation(osInfoRaw, memoryInfoRaw, "", disksRaw)
|
||||
return nil, raw, memoryInfoErr
|
||||
err.MemoryInformationError = SystemInformationError{
|
||||
Err: memoryInfoErr,
|
||||
RawInfo: memoryInfoRaw,
|
||||
}
|
||||
} else {
|
||||
memoryMaximum = memoryInfo.MemoryMaximum
|
||||
memoryCurrent = memoryInfo.MemoryCurrent
|
||||
}
|
||||
|
||||
if diskErr != nil {
|
||||
raw := RawSystemInformation(osInfoRaw, memoryInfoRaw, "", disksRaw)
|
||||
return nil, raw, diskErr
|
||||
err.DiskVolumeInformationError = SystemInformationError{
|
||||
Err: diskErr,
|
||||
RawInfo: disksRaw,
|
||||
}
|
||||
}
|
||||
|
||||
if osInfoErr != nil {
|
||||
raw := RawSystemInformation(osInfoRaw, memoryInfoRaw, "", disksRaw)
|
||||
return nil, raw, osInfoErr
|
||||
err.OperatingSystemInformationError = SystemInformationError{
|
||||
Err: osInfoErr,
|
||||
RawInfo: osInfoRaw,
|
||||
}
|
||||
} else {
|
||||
osSystem = osInfo.OsSystem
|
||||
name = osInfo.Name
|
||||
osVersion = osInfo.OsVersion
|
||||
osRelease = osInfo.OsRelease
|
||||
architecture = osInfo.Architecture
|
||||
}
|
||||
|
||||
return NewSystemInformation(
|
||||
memoryInfo.MemoryMaximum,
|
||||
memoryInfo.MemoryCurrent,
|
||||
// For windows we leave both the fileDescriptorMaximum and fileDescriptorCurrent with zero
|
||||
// since there is no obvious way to get this information.
|
||||
0,
|
||||
0,
|
||||
osInfo.OsSystem,
|
||||
osInfo.Name,
|
||||
osInfo.OsVersion,
|
||||
osInfo.OsRelease,
|
||||
osInfo.Architecture,
|
||||
collector.version,
|
||||
cloudflaredVersion := collector.version
|
||||
info := NewSystemInformation(
|
||||
memoryMaximum,
|
||||
memoryCurrent,
|
||||
fileDescriptorMaximum,
|
||||
fileDescriptorCurrent,
|
||||
osSystem,
|
||||
name,
|
||||
osVersion,
|
||||
osRelease,
|
||||
architecture,
|
||||
cloudflaredVersion,
|
||||
runtime.Version(),
|
||||
runtime.GOARCH,
|
||||
disks,
|
||||
), "", nil
|
||||
)
|
||||
|
||||
return info, err
|
||||
}
|
||||
|
||||
func collectMemoryInformation(ctx context.Context) (*MemoryInformation, string, error) {
|
||||
|
|
|
@ -11,28 +11,40 @@ const (
|
|||
FeatureDatagramV3 = "support_datagram_v3"
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultFeatures = []string{
|
||||
var defaultFeatures = []string{
|
||||
FeatureAllowRemoteConfig,
|
||||
FeatureSerializedHeaders,
|
||||
FeatureDatagramV2,
|
||||
FeatureQUICSupportEOF,
|
||||
FeatureManagementLogs,
|
||||
}
|
||||
}
|
||||
|
||||
// Features set by user provided flags
|
||||
type staticFeatures struct {
|
||||
PostQuantumMode *PostQuantumMode
|
||||
}
|
||||
|
||||
type PostQuantumMode uint8
|
||||
|
||||
const (
|
||||
// Prefer post quantum, but fallback if connection cannot be established
|
||||
PostQuantumPrefer PostQuantumMode = iota
|
||||
// If the user passes the --post-quantum flag, we override
|
||||
// CurvePreferences to only support hybrid post-quantum key agreements.
|
||||
PostQuantumStrict
|
||||
)
|
||||
|
||||
func Contains(feature string) bool {
|
||||
for _, f := range DefaultFeatures {
|
||||
if f == feature {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
type DatagramVersion string
|
||||
|
||||
const (
|
||||
// DatagramV2 is the currently supported datagram protocol for UDP and ICMP packets
|
||||
DatagramV2 DatagramVersion = FeatureDatagramV2
|
||||
// DatagramV3 is a new datagram protocol for UDP and ICMP packets. It is not backwards compatible with datagram v2.
|
||||
DatagramV3 DatagramVersion = FeatureDatagramV3
|
||||
)
|
||||
|
||||
// Remove any duplicates from the slice
|
||||
func Dedup(slice []string) []string {
|
||||
|
||||
// Convert the slice into a set
|
||||
set := make(map[string]bool, 0)
|
||||
for _, str := range slice {
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -18,61 +19,67 @@ const (
|
|||
lookupTimeout = time.Second * 10
|
||||
)
|
||||
|
||||
type PostQuantumMode uint8
|
||||
|
||||
const (
|
||||
// Prefer post quantum, but fallback if connection cannot be established
|
||||
PostQuantumPrefer PostQuantumMode = iota
|
||||
// If the user passes the --post-quantum flag, we override
|
||||
// CurvePreferences to only support hybrid post-quantum key agreements.
|
||||
PostQuantumStrict
|
||||
)
|
||||
|
||||
// If the TXT record adds other fields, the umarshal logic will ignore those keys
|
||||
// If the TXT record is missing a key, the field will unmarshal to the default Go value
|
||||
// pq was removed in TUN-7970
|
||||
type featuresRecord struct{}
|
||||
|
||||
func NewFeatureSelector(ctx context.Context, accountTag string, staticFeatures StaticFeatures, logger *zerolog.Logger) (*FeatureSelector, error) {
|
||||
return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), staticFeatures, defaultRefreshFreq)
|
||||
type featuresRecord struct {
|
||||
// support_datagram_v3
|
||||
DatagramV3Percentage int32 `json:"dv3"`
|
||||
|
||||
// PostQuantumPercentage int32 `json:"pq"` // Removed in TUN-7970
|
||||
}
|
||||
|
||||
// FeatureSelector determines if this account will try new features. It preiodically queries a DNS TXT record
|
||||
// to see which features are turned on
|
||||
func NewFeatureSelector(ctx context.Context, accountTag string, cliFeatures []string, pq bool, logger *zerolog.Logger) (*FeatureSelector, error) {
|
||||
return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq, defaultRefreshFreq)
|
||||
}
|
||||
|
||||
// FeatureSelector determines if this account will try new features. It periodically queries a DNS TXT record
|
||||
// to see which features are turned on.
|
||||
type FeatureSelector struct {
|
||||
accountHash int32
|
||||
logger *zerolog.Logger
|
||||
resolver resolver
|
||||
|
||||
staticFeatures StaticFeatures
|
||||
staticFeatures staticFeatures
|
||||
cliFeatures []string
|
||||
|
||||
// lock protects concurrent access to dynamic features
|
||||
lock sync.RWMutex
|
||||
features featuresRecord
|
||||
}
|
||||
|
||||
// Features set by user provided flags
|
||||
type StaticFeatures struct {
|
||||
PostQuantumMode *PostQuantumMode
|
||||
}
|
||||
|
||||
func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, staticFeatures StaticFeatures, refreshFreq time.Duration) (*FeatureSelector, error) {
|
||||
func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool, refreshFreq time.Duration) (*FeatureSelector, error) {
|
||||
// Combine default features and user-provided features
|
||||
var pqMode *PostQuantumMode
|
||||
if pq {
|
||||
mode := PostQuantumStrict
|
||||
pqMode = &mode
|
||||
cliFeatures = append(cliFeatures, FeaturePostQuantum)
|
||||
}
|
||||
staticFeatures := staticFeatures{
|
||||
PostQuantumMode: pqMode,
|
||||
}
|
||||
selector := &FeatureSelector{
|
||||
accountHash: switchThreshold(accountTag),
|
||||
logger: logger,
|
||||
resolver: resolver,
|
||||
staticFeatures: staticFeatures,
|
||||
cliFeatures: Dedup(cliFeatures),
|
||||
}
|
||||
|
||||
if err := selector.refresh(ctx); err != nil {
|
||||
logger.Err(err).Msg("Failed to fetch features, default to disable")
|
||||
}
|
||||
|
||||
// Run refreshLoop next time we have a new feature to rollout
|
||||
go selector.refreshLoop(ctx, refreshFreq)
|
||||
|
||||
return selector, nil
|
||||
}
|
||||
|
||||
func (fs *FeatureSelector) accountEnabled(percentage int32) bool {
|
||||
return percentage > fs.accountHash
|
||||
}
|
||||
|
||||
func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode {
|
||||
if fs.staticFeatures.PostQuantumMode != nil {
|
||||
return *fs.staticFeatures.PostQuantumMode
|
||||
|
@ -81,6 +88,33 @@ func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode {
|
|||
return PostQuantumPrefer
|
||||
}
|
||||
|
||||
func (fs *FeatureSelector) DatagramVersion() DatagramVersion {
|
||||
fs.lock.RLock()
|
||||
defer fs.lock.RUnlock()
|
||||
|
||||
// If user provides the feature via the cli, we take it as priority over remote feature evaluation
|
||||
if slices.Contains(fs.cliFeatures, FeatureDatagramV3) {
|
||||
return DatagramV3
|
||||
}
|
||||
// If the user specifies DatagramV2, we also take that over remote
|
||||
if slices.Contains(fs.cliFeatures, FeatureDatagramV2) {
|
||||
return DatagramV2
|
||||
}
|
||||
|
||||
if fs.accountEnabled(fs.features.DatagramV3Percentage) {
|
||||
return DatagramV3
|
||||
}
|
||||
return DatagramV2
|
||||
}
|
||||
|
||||
// ClientFeatures will return the list of currently available features that cloudflared should provide to the edge.
|
||||
//
|
||||
// This list is dynamic and can change in-between returns.
|
||||
func (fs *FeatureSelector) ClientFeatures() []string {
|
||||
// Evaluate any remote features along with static feature list to construct the list of features
|
||||
return Dedup(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.DatagramVersion())}))
|
||||
}
|
||||
|
||||
func (fs *FeatureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) {
|
||||
ticker := time.NewTicker(refreshFreq)
|
||||
for {
|
||||
|
|
|
@ -14,15 +14,19 @@ import (
|
|||
func TestUnmarshalFeaturesRecord(t *testing.T) {
|
||||
tests := []struct {
|
||||
record []byte
|
||||
expectedPercentage int32
|
||||
}{
|
||||
{
|
||||
record: []byte(`{"pq":0}`),
|
||||
record: []byte(`{"dv3":0}`),
|
||||
expectedPercentage: 0,
|
||||
},
|
||||
{
|
||||
record: []byte(`{"pq":39}`),
|
||||
record: []byte(`{"dv3":39}`),
|
||||
expectedPercentage: 39,
|
||||
},
|
||||
{
|
||||
record: []byte(`{"pq":100}`),
|
||||
record: []byte(`{"dv3":100}`),
|
||||
expectedPercentage: 100,
|
||||
},
|
||||
{
|
||||
record: []byte(`{}`), // Unmarshal to default struct if key is not present
|
||||
|
@ -36,37 +40,186 @@ func TestUnmarshalFeaturesRecord(t *testing.T) {
|
|||
var features featuresRecord
|
||||
err := json.Unmarshal(test.record, &features)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, featuresRecord{}, features)
|
||||
require.Equal(t, test.expectedPercentage, features.DatagramV3Percentage, test)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeaturePrecedenceEvaluationPostQuantum(t *testing.T) {
|
||||
logger := zerolog.Nop()
|
||||
tests := []struct {
|
||||
name string
|
||||
cli bool
|
||||
expectedFeatures []string
|
||||
expectedVersion PostQuantumMode
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
cli: false,
|
||||
expectedFeatures: defaultFeatures,
|
||||
expectedVersion: PostQuantumPrefer,
|
||||
},
|
||||
{
|
||||
name: "user_specified",
|
||||
cli: true,
|
||||
expectedFeatures: Dedup(append(defaultFeatures, FeaturePostQuantum)),
|
||||
expectedVersion: PostQuantumStrict,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
resolver := &staticResolver{record: featuresRecord{}}
|
||||
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, []string{}, test.cli, time.Second)
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
|
||||
require.Equal(t, test.expectedVersion, selector.PostQuantumMode())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeaturePrecedenceEvaluationDatagramVersion(t *testing.T) {
|
||||
logger := zerolog.Nop()
|
||||
tests := []struct {
|
||||
name string
|
||||
cli []string
|
||||
remote featuresRecord
|
||||
expectedFeatures []string
|
||||
expectedVersion DatagramVersion
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
cli: []string{},
|
||||
remote: featuresRecord{},
|
||||
expectedFeatures: defaultFeatures,
|
||||
expectedVersion: DatagramV2,
|
||||
},
|
||||
{
|
||||
name: "user_specified_v2",
|
||||
cli: []string{FeatureDatagramV2},
|
||||
remote: featuresRecord{},
|
||||
expectedFeatures: defaultFeatures,
|
||||
expectedVersion: DatagramV2,
|
||||
},
|
||||
{
|
||||
name: "user_specified_v3",
|
||||
cli: []string{FeatureDatagramV3},
|
||||
remote: featuresRecord{},
|
||||
expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
|
||||
expectedVersion: FeatureDatagramV3,
|
||||
},
|
||||
{
|
||||
name: "remote_specified_v3",
|
||||
cli: []string{},
|
||||
remote: featuresRecord{
|
||||
DatagramV3Percentage: 100,
|
||||
},
|
||||
expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
|
||||
expectedVersion: FeatureDatagramV3,
|
||||
},
|
||||
{
|
||||
name: "remote_and_user_specified_v3",
|
||||
cli: []string{FeatureDatagramV3},
|
||||
remote: featuresRecord{
|
||||
DatagramV3Percentage: 100,
|
||||
},
|
||||
expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
|
||||
expectedVersion: FeatureDatagramV3,
|
||||
},
|
||||
{
|
||||
name: "remote_v3_and_user_specified_v2",
|
||||
cli: []string{FeatureDatagramV2},
|
||||
remote: featuresRecord{
|
||||
DatagramV3Percentage: 100,
|
||||
},
|
||||
expectedFeatures: defaultFeatures,
|
||||
expectedVersion: DatagramV2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
resolver := &staticResolver{record: test.remote}
|
||||
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false, time.Second)
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
|
||||
require.Equal(t, test.expectedVersion, selector.DatagramVersion())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshFeaturesRecord(t *testing.T) {
|
||||
// The hash of the accountTag is 82
|
||||
accountTag := t.Name()
|
||||
threshold := switchThreshold(accountTag)
|
||||
|
||||
percentages := []int32{0, 10, 81, 82, 83, 100, 101, 1000}
|
||||
refreshFreq := time.Millisecond * 10
|
||||
selector := newTestSelector(t, percentages, false, refreshFreq)
|
||||
|
||||
// Starting out should default to DatagramV2
|
||||
require.Equal(t, DatagramV2, selector.DatagramVersion())
|
||||
|
||||
for _, percentage := range percentages {
|
||||
if percentage > threshold {
|
||||
require.Equal(t, DatagramV3, selector.DatagramVersion())
|
||||
} else {
|
||||
require.Equal(t, DatagramV2, selector.DatagramVersion())
|
||||
}
|
||||
|
||||
time.Sleep(refreshFreq + time.Millisecond)
|
||||
}
|
||||
|
||||
// Make sure error doesn't override the last fetched features
|
||||
require.Equal(t, DatagramV3, selector.DatagramVersion())
|
||||
}
|
||||
|
||||
func TestStaticFeatures(t *testing.T) {
|
||||
pqMode := PostQuantumStrict
|
||||
selector := newTestSelector(t, &pqMode, time.Millisecond*10)
|
||||
percentages := []int32{0}
|
||||
// PostQuantum Enabled from user flag
|
||||
selector := newTestSelector(t, percentages, true, time.Millisecond*10)
|
||||
require.Equal(t, PostQuantumStrict, selector.PostQuantumMode())
|
||||
|
||||
// No StaticFeatures configured
|
||||
selector = newTestSelector(t, nil, time.Millisecond*10)
|
||||
// PostQuantum Disabled (or not set)
|
||||
selector = newTestSelector(t, percentages, false, time.Millisecond*10)
|
||||
require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode())
|
||||
}
|
||||
|
||||
func newTestSelector(t *testing.T, pqMode *PostQuantumMode, refreshFreq time.Duration) *FeatureSelector {
|
||||
func newTestSelector(t *testing.T, percentages []int32, pq bool, refreshFreq time.Duration) *FeatureSelector {
|
||||
accountTag := t.Name()
|
||||
logger := zerolog.Nop()
|
||||
|
||||
resolver := &mockResolver{}
|
||||
|
||||
staticFeatures := StaticFeatures{
|
||||
PostQuantumMode: pqMode,
|
||||
resolver := &mockResolver{
|
||||
percentages: percentages,
|
||||
}
|
||||
selector, err := newFeatureSelector(context.Background(), accountTag, &logger, resolver, staticFeatures, refreshFreq)
|
||||
|
||||
selector, err := newFeatureSelector(context.Background(), accountTag, &logger, resolver, []string{}, pq, refreshFreq)
|
||||
require.NoError(t, err)
|
||||
|
||||
return selector
|
||||
}
|
||||
|
||||
type mockResolver struct{}
|
||||
type mockResolver struct {
|
||||
nextIndex int
|
||||
percentages []int32
|
||||
}
|
||||
|
||||
func (mr *mockResolver) lookupRecord(ctx context.Context) ([]byte, error) {
|
||||
return nil, fmt.Errorf("mockResolver hasn't implement lookupRecord")
|
||||
if mr.nextIndex >= len(mr.percentages) {
|
||||
return nil, fmt.Errorf("no more record to lookup")
|
||||
}
|
||||
|
||||
record, err := json.Marshal(featuresRecord{
|
||||
DatagramV3Percentage: mr.percentages[mr.nextIndex],
|
||||
})
|
||||
mr.nextIndex++
|
||||
|
||||
return record, err
|
||||
}
|
||||
|
||||
type staticResolver struct {
|
||||
record featuresRecord
|
||||
}
|
||||
|
||||
func (r *staticResolver) lookupRecord(ctx context.Context) ([]byte, error) {
|
||||
return json.Marshal(r.record)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
//go:build fips
|
||||
|
||||
package fips
|
||||
|
||||
import (
|
||||
_ "crypto/tls/fipsonly"
|
||||
)
|
||||
|
||||
func IsFipsEnabled() bool {
|
||||
return true
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
// +build fips
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "crypto/tls/fipsonly"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/tunnel"
|
||||
)
|
||||
|
||||
func init () {
|
||||
tunnel.FipsEnabled = true
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
//go:build !fips
|
||||
|
||||
package fips
|
||||
|
||||
func IsFipsEnabled() bool {
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
package flow
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
unlimitedActiveFlows = 0
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTooManyActiveFlows = errors.New("too many active flows")
|
||||
)
|
||||
|
||||
type Limiter interface {
|
||||
// Acquire tries to acquire a free slot for a flow, if the value of flows is already above
|
||||
// the maximum it returns ErrTooManyActiveFlows.
|
||||
Acquire(flowType string) error
|
||||
// Release releases a slot for a flow.
|
||||
Release()
|
||||
// SetLimit allows to hot swap the limit value of the limiter.
|
||||
SetLimit(uint64)
|
||||
}
|
||||
|
||||
type flowLimiter struct {
|
||||
limiterLock sync.Mutex
|
||||
activeFlowsCounter uint64
|
||||
maxActiveFlows uint64
|
||||
unlimited bool
|
||||
}
|
||||
|
||||
func NewLimiter(maxActiveFlows uint64) Limiter {
|
||||
flowLimiter := &flowLimiter{
|
||||
maxActiveFlows: maxActiveFlows,
|
||||
unlimited: isUnlimited(maxActiveFlows),
|
||||
}
|
||||
|
||||
return flowLimiter
|
||||
}
|
||||
|
||||
func (s *flowLimiter) Acquire(flowType string) error {
|
||||
s.limiterLock.Lock()
|
||||
defer s.limiterLock.Unlock()
|
||||
|
||||
if !s.unlimited && s.activeFlowsCounter >= s.maxActiveFlows {
|
||||
flowRegistrationsDropped.WithLabelValues(flowType).Inc()
|
||||
return ErrTooManyActiveFlows
|
||||
}
|
||||
|
||||
s.activeFlowsCounter++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *flowLimiter) Release() {
|
||||
s.limiterLock.Lock()
|
||||
defer s.limiterLock.Unlock()
|
||||
|
||||
if s.activeFlowsCounter <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
s.activeFlowsCounter--
|
||||
}
|
||||
|
||||
func (s *flowLimiter) SetLimit(newMaxActiveFlows uint64) {
|
||||
s.limiterLock.Lock()
|
||||
defer s.limiterLock.Unlock()
|
||||
|
||||
s.maxActiveFlows = newMaxActiveFlows
|
||||
s.unlimited = isUnlimited(newMaxActiveFlows)
|
||||
}
|
||||
|
||||
// isUnlimited checks if the value received matches the configuration for the unlimited flow limiter.
|
||||
func isUnlimited(value uint64) bool {
|
||||
return value == unlimitedActiveFlows
|
||||
}
|
|
@ -0,0 +1,119 @@
|
|||
package flow_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cloudflare/cloudflared/flow"
|
||||
)
|
||||
|
||||
func TestFlowLimiter_Unlimited(t *testing.T) {
|
||||
unlimitedLimiter := flow.NewLimiter(0)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
err := unlimitedLimiter.Acquire("test")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlowLimiter_Limited(t *testing.T) {
|
||||
maxFlows := uint64(5)
|
||||
limiter := flow.NewLimiter(maxFlows)
|
||||
|
||||
for i := uint64(0); i < maxFlows; i++ {
|
||||
err := limiter.Acquire("test")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := limiter.Acquire("should fail")
|
||||
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
|
||||
}
|
||||
|
||||
func TestFlowLimiter_AcquireAndReleaseFlow(t *testing.T) {
|
||||
maxFlows := uint64(5)
|
||||
limiter := flow.NewLimiter(maxFlows)
|
||||
|
||||
// Acquire the maximum number of flows
|
||||
for i := uint64(0); i < maxFlows; i++ {
|
||||
err := limiter.Acquire("test")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Validate acquire 1 more flows fails
|
||||
err := limiter.Acquire("should fail")
|
||||
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
|
||||
|
||||
// Release the maximum number of flows
|
||||
for i := uint64(0); i < maxFlows; i++ {
|
||||
limiter.Release()
|
||||
}
|
||||
|
||||
// Validate acquire 1 more flows works
|
||||
err = limiter.Acquire("shouldn't fail")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Release a 10x the number of max flows
|
||||
for i := uint64(0); i < 10*maxFlows; i++ {
|
||||
limiter.Release()
|
||||
}
|
||||
|
||||
// Validate it still can only acquire a value = number max flows.
|
||||
for i := uint64(0); i < maxFlows; i++ {
|
||||
err := limiter.Acquire("test")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
err = limiter.Acquire("should fail")
|
||||
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
|
||||
}
|
||||
|
||||
func TestFlowLimiter_SetLimit(t *testing.T) {
|
||||
maxFlows := uint64(5)
|
||||
limiter := flow.NewLimiter(maxFlows)
|
||||
|
||||
// Acquire the maximum number of flows
|
||||
for i := uint64(0); i < maxFlows; i++ {
|
||||
err := limiter.Acquire("test")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Validate acquire 1 more flows fails
|
||||
err := limiter.Acquire("should fail")
|
||||
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
|
||||
|
||||
// Set the flow limiter to support one more request
|
||||
limiter.SetLimit(maxFlows + 1)
|
||||
|
||||
// Validate acquire 1 more flows now works
|
||||
err = limiter.Acquire("shouldn't fail")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate acquire 1 more flows doesn't work because we already reached the limit
|
||||
err = limiter.Acquire("should fail")
|
||||
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
|
||||
|
||||
// Release all flows
|
||||
for i := uint64(0); i < maxFlows+1; i++ {
|
||||
limiter.Release()
|
||||
}
|
||||
|
||||
// Validate 1 flow works again
|
||||
err = limiter.Acquire("shouldn't fail")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set the flow limit to 1
|
||||
limiter.SetLimit(1)
|
||||
|
||||
// Validate acquire 1 more flows doesn't work
|
||||
err = limiter.Acquire("should fail")
|
||||
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
|
||||
|
||||
// Set the flow limit to unlimited
|
||||
limiter.SetLimit(0)
|
||||
|
||||
// Validate it can acquire a lot of flows because it is now unlimited.
|
||||
for i := uint64(0); i < 10*maxFlows; i++ {
|
||||
err := limiter.Acquire("shouldn't fail")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package flow
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
|
||||
const (
|
||||
namespace = "flow"
|
||||
)
|
||||
|
||||
var (
|
||||
labels = []string{"flow_type"}
|
||||
|
||||
flowRegistrationsDropped = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: namespace,
|
||||
Subsystem: "client",
|
||||
Name: "registrations_rate_limited_total",
|
||||
Help: "Count registrations dropped due to high number of concurrent flows being handled",
|
||||
},
|
||||
labels,
|
||||
)
|
||||
)
|
21
go.mod
21
go.mod
|
@ -35,11 +35,12 @@ require (
|
|||
go.opentelemetry.io/otel/trace v1.26.0
|
||||
go.opentelemetry.io/proto/otlp v1.2.0
|
||||
go.uber.org/automaxprocs v1.4.0
|
||||
golang.org/x/crypto v0.23.0
|
||||
golang.org/x/net v0.25.0
|
||||
golang.org/x/sync v0.7.0
|
||||
golang.org/x/sys v0.20.0
|
||||
golang.org/x/term v0.20.0
|
||||
go.uber.org/mock v0.5.0
|
||||
golang.org/x/crypto v0.31.0
|
||||
golang.org/x/net v0.26.0
|
||||
golang.org/x/sync v0.10.0
|
||||
golang.org/x/sys v0.28.0
|
||||
golang.org/x/term v0.27.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
|
@ -83,12 +84,11 @@ require (
|
|||
github.com/prometheus/procfs v0.12.0 // indirect
|
||||
github.com/russross/blackfriday/v2 v2.1.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.26.0 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
golang.org/x/mod v0.18.0 // indirect
|
||||
golang.org/x/oauth2 v0.18.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
golang.org/x/tools v0.21.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/tools v0.22.0 // indirect
|
||||
google.golang.org/appengine v1.6.8 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240311132316-a219d84964c2 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
|
||||
|
@ -102,3 +102,6 @@ replace github.com/urfave/cli/v2 => github.com/ipostelnik/cli/v2 v2.3.1-0.202103
|
|||
replace github.com/prometheus/golang_client => github.com/prometheus/golang_client v1.12.1
|
||||
|
||||
replace gopkg.in/yaml.v3 => gopkg.in/yaml.v3 v3.0.1
|
||||
|
||||
// This fork is based on quic-go v0.45
|
||||
replace github.com/quic-go/quic-go => github.com/chungthuang/quic-go v0.45.1-0.20250128102735-2687bd175910
|
||||
|
|
40
go.sum
40
go.sum
|
@ -7,6 +7,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
|||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/chungthuang/quic-go v0.45.1-0.20250128102735-2687bd175910 h1:/hTvBpxBDj/3NIzTodi1oEOyNBpirvgDSPKSV7VqAZU=
|
||||
github.com/chungthuang/quic-go v0.45.1-0.20250128102735-2687bd175910/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
|
||||
github.com/coredns/caddy v1.1.1 h1:2eYKZT7i6yxIfGP3qLJoJ7HAsDJqYB+X68g4NYjSrE0=
|
||||
github.com/coredns/caddy v1.1.1/go.mod h1:A6ntJQlAWuQfFlsd9hvigKbo2WS0VUs2l1e2F+BawD4=
|
||||
github.com/coredns/coredns v1.11.3 h1:8RjnpZc42db5th84/QJKH2i137ecJdzZK1HJwhetSPk=
|
||||
|
@ -173,8 +175,6 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
|
|||
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
|
||||
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
|
||||
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
|
||||
github.com/quic-go/quic-go v0.45.0 h1:OHmkQGM37luZITyTSu6ff03HP/2IrwDX1ZFiNEhSFUE=
|
||||
github.com/quic-go/quic-go v0.45.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
||||
|
@ -217,33 +217,33 @@ go.opentelemetry.io/proto/otlp v1.2.0 h1:pVeZGk7nXDC9O2hncA6nHldxEjm6LByfA2aN8IO
|
|||
go.opentelemetry.io/proto/otlp v1.2.0/go.mod h1:gGpR8txAl5M03pDhMC79G6SdqNV26naRm/KDsgaHD8A=
|
||||
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
|
||||
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
|
||||
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
|
||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||
golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI=
|
||||
golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
|
@ -254,19 +254,19 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
|
@ -275,8 +275,8 @@ golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtn
|
|||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
|
||||
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
|
||||
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
|
|
@ -22,6 +22,7 @@ var (
|
|||
const (
|
||||
defaultProxyAddress = "127.0.0.1"
|
||||
defaultKeepAliveConnections = 100
|
||||
defaultMaxActiveFlows = 0 // unlimited
|
||||
SSHServerFlag = "ssh-server"
|
||||
Socks5Flag = "socks5"
|
||||
ProxyConnectTimeoutFlag = "proxy-connect-timeout"
|
||||
|
@ -46,17 +47,22 @@ const (
|
|||
|
||||
type WarpRoutingConfig struct {
|
||||
ConnectTimeout config.CustomDuration `yaml:"connectTimeout" json:"connectTimeout,omitempty"`
|
||||
MaxActiveFlows uint64 `yaml:"maxActiveFlows" json:"MaxActiveFlows,omitempty"`
|
||||
TCPKeepAlive config.CustomDuration `yaml:"tcpKeepAlive" json:"tcpKeepAlive,omitempty"`
|
||||
}
|
||||
|
||||
func NewWarpRoutingConfig(raw *config.WarpRoutingConfig) WarpRoutingConfig {
|
||||
cfg := WarpRoutingConfig{
|
||||
ConnectTimeout: defaultWarpRoutingConnectTimeout,
|
||||
MaxActiveFlows: defaultMaxActiveFlows,
|
||||
TCPKeepAlive: defaultTCPKeepAlive,
|
||||
}
|
||||
if raw.ConnectTimeout != nil {
|
||||
cfg.ConnectTimeout = *raw.ConnectTimeout
|
||||
}
|
||||
if raw.MaxActiveFlows != nil {
|
||||
cfg.MaxActiveFlows = *raw.MaxActiveFlows
|
||||
}
|
||||
if raw.TCPKeepAlive != nil {
|
||||
cfg.TCPKeepAlive = *raw.TCPKeepAlive
|
||||
}
|
||||
|
@ -68,6 +74,9 @@ func (c *WarpRoutingConfig) RawConfig() config.WarpRoutingConfig {
|
|||
if c.ConnectTimeout.Duration != defaultWarpRoutingConnectTimeout.Duration {
|
||||
raw.ConnectTimeout = &c.ConnectTimeout
|
||||
}
|
||||
if c.MaxActiveFlows != defaultMaxActiveFlows {
|
||||
raw.MaxActiveFlows = &c.MaxActiveFlows
|
||||
}
|
||||
if c.TCPKeepAlive.Duration != defaultTCPKeepAlive.Duration {
|
||||
raw.TCPKeepAlive = &c.TCPKeepAlive
|
||||
}
|
||||
|
@ -172,6 +181,7 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
|
|||
}
|
||||
if flag := ProxyPortFlag; c.IsSet(flag) {
|
||||
// Note TUN-3758 , we use Int because UInt is not supported with altsrc
|
||||
// nolint: gosec
|
||||
proxyPort = uint(c.Int(flag))
|
||||
}
|
||||
if flag := Http2OriginFlag; c.IsSet(flag) {
|
||||
|
@ -551,7 +561,7 @@ func convertToRawIPRules(ipRules []ipaccess.Rule) []config.IngressIPRule {
|
|||
}
|
||||
|
||||
func defaultBoolToNil(b bool) *bool {
|
||||
if b == false {
|
||||
if !b {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -16,7 +15,7 @@ import (
|
|||
"golang.org/x/term"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/features"
|
||||
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
)
|
||||
|
||||
|
@ -24,14 +23,6 @@ const (
|
|||
EnableTerminalLog = false
|
||||
DisableTerminalLog = true
|
||||
|
||||
LogLevelFlag = "loglevel"
|
||||
LogFileFlag = "logfile"
|
||||
LogDirectoryFlag = "log-directory"
|
||||
LogTransportLevelFlag = "transport-loglevel"
|
||||
|
||||
LogSSHDirectoryFlag = "log-directory"
|
||||
LogSSHLevelFlag = "log-level"
|
||||
|
||||
dirPermMode = 0744 // rwxr--r--
|
||||
filePermMode = 0644 // rw-r--r--
|
||||
|
||||
|
@ -46,11 +37,7 @@ func init() {
|
|||
zerolog.TimeFieldFormat = time.RFC3339
|
||||
zerolog.TimestampFunc = utcNow
|
||||
|
||||
if features.Contains(features.FeatureManagementLogs) {
|
||||
// Management logger needs to be initialized before any of the other loggers as to not capture
|
||||
// it's own logging events.
|
||||
ManagementLogger = management.NewLogger()
|
||||
}
|
||||
}
|
||||
|
||||
func utcNow() time.Time {
|
||||
|
@ -124,10 +111,7 @@ func newZerolog(loggerConfig *Config) *zerolog.Logger {
|
|||
writers = append(writers, rollingLogger)
|
||||
}
|
||||
|
||||
var managementWriter zerolog.LevelWriter
|
||||
if features.Contains(features.FeatureManagementLogs) {
|
||||
managementWriter = ManagementLogger
|
||||
}
|
||||
managementWriter := ManagementLogger
|
||||
|
||||
level, levelErr := zerolog.ParseLevel(loggerConfig.MinLevel)
|
||||
if levelErr != nil {
|
||||
|
@ -145,15 +129,15 @@ func newZerolog(loggerConfig *Config) *zerolog.Logger {
|
|||
}
|
||||
|
||||
func CreateTransportLoggerFromContext(c *cli.Context, disableTerminal bool) *zerolog.Logger {
|
||||
return createFromContext(c, LogTransportLevelFlag, LogDirectoryFlag, disableTerminal)
|
||||
return createFromContext(c, cfdflags.TransportLogLevel, cfdflags.LogDirectory, disableTerminal)
|
||||
}
|
||||
|
||||
func CreateLoggerFromContext(c *cli.Context, disableTerminal bool) *zerolog.Logger {
|
||||
return createFromContext(c, LogLevelFlag, LogDirectoryFlag, disableTerminal)
|
||||
return createFromContext(c, cfdflags.LogLevel, cfdflags.LogDirectory, disableTerminal)
|
||||
}
|
||||
|
||||
func CreateSSHLoggerFromContext(c *cli.Context, disableTerminal bool) *zerolog.Logger {
|
||||
return createFromContext(c, LogSSHLevelFlag, LogSSHDirectoryFlag, disableTerminal)
|
||||
return createFromContext(c, cfdflags.LogLevelSSH, cfdflags.LogDirectory, disableTerminal)
|
||||
}
|
||||
|
||||
func createFromContext(
|
||||
|
@ -163,7 +147,7 @@ func createFromContext(
|
|||
disableTerminal bool,
|
||||
) *zerolog.Logger {
|
||||
logLevel := c.String(logLevelFlagName)
|
||||
logFile := c.String(LogFileFlag)
|
||||
logFile := c.String(cfdflags.LogFile)
|
||||
logDirectory := c.String(logDirectoryFlagName)
|
||||
|
||||
loggerConfig := CreateConfig(
|
||||
|
@ -175,7 +159,7 @@ func createFromContext(
|
|||
|
||||
log := newZerolog(loggerConfig)
|
||||
if incompatibleFlagsSet := logFile != "" && logDirectory != ""; incompatibleFlagsSet {
|
||||
log.Error().Msgf("Your config includes values for both %s (%s) and %s (%s), but they are incompatible. %s takes precedence.", LogFileFlag, logFile, logDirectoryFlagName, logDirectory, LogFileFlag)
|
||||
log.Error().Msgf("Your config includes values for both %s (%s) and %s (%s), but they are incompatible. %s takes precedence.", cfdflags.LogFile, logFile, logDirectoryFlagName, logDirectory, cfdflags.LogFile)
|
||||
}
|
||||
return log
|
||||
}
|
||||
|
@ -214,7 +198,6 @@ var (
|
|||
|
||||
func createFileWriter(config FileConfig) (io.Writer, error) {
|
||||
singleFileInit.once.Do(func() {
|
||||
|
||||
var logFile io.Writer
|
||||
fullpath := config.Fullpath()
|
||||
|
||||
|
@ -265,7 +248,7 @@ func createRollingLogger(config RollingConfig) (io.Writer, error) {
|
|||
}
|
||||
|
||||
rotatingFileInit.writer = &lumberjack.Logger{
|
||||
Filename: path.Join(config.Dirname, config.Filename),
|
||||
Filename: filepath.Join(config.Dirname, config.Filename),
|
||||
MaxBackups: config.maxBackups,
|
||||
MaxSize: config.maxSize,
|
||||
MaxAge: config.maxAge,
|
||||
|
|
|
@ -74,7 +74,7 @@ type EventLog struct {
|
|||
type LogEventType int8
|
||||
|
||||
const (
|
||||
// Cloudflared events are signficant to cloudflared operations like connection state changes.
|
||||
// Cloudflared events are significant to cloudflared operations like connection state changes.
|
||||
// Cloudflared is also the default event type for any events that haven't been separated into a proper event type.
|
||||
Cloudflared LogEventType = iota
|
||||
HTTP
|
||||
|
@ -129,7 +129,7 @@ func (e *LogEventType) UnmarshalJSON(data []byte) error {
|
|||
|
||||
// LogLevel corresponds to the zerolog logging levels
|
||||
// "panic", "fatal", and "trace" are exempt from this list as they are rarely used and, at least
|
||||
// the the first two are limited to failure conditions that lead to cloudflared shutting down.
|
||||
// the first two are limited to failure conditions that lead to cloudflared shutting down.
|
||||
type LogLevel int8
|
||||
|
||||
const (
|
||||
|
|
|
@ -29,7 +29,7 @@ var Runtime = "host"
|
|||
|
||||
func GetMetricsDefaultAddress(runtimeType string) string {
|
||||
// When issuing the diagnostic command we may have to reach a server that is
|
||||
// running in a virtual enviroment and in that case we must bind to 0.0.0.0
|
||||
// running in a virtual environment and in that case we must bind to 0.0.0.0
|
||||
// otherwise the server won't be reachable.
|
||||
switch runtimeType {
|
||||
case "virtual":
|
||||
|
|
|
@ -0,0 +1,150 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ../flow/limiter.go
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter
|
||||
//
|
||||
|
||||
// Package mocks is a generated GoMock package.
|
||||
package mocks
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockLimiter is a mock of Limiter interface.
|
||||
type MockLimiter struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLimiterMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockLimiterMockRecorder is the mock recorder for MockLimiter.
|
||||
type MockLimiterMockRecorder struct {
|
||||
mock *MockLimiter
|
||||
}
|
||||
|
||||
// NewMockLimiter creates a new mock instance.
|
||||
func NewMockLimiter(ctrl *gomock.Controller) *MockLimiter {
|
||||
mock := &MockLimiter{ctrl: ctrl}
|
||||
mock.recorder = &MockLimiterMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLimiter) EXPECT() *MockLimiterMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Acquire mocks base method.
|
||||
func (m *MockLimiter) Acquire(flowType string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Acquire", flowType)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Acquire indicates an expected call of Acquire.
|
||||
func (mr *MockLimiterMockRecorder) Acquire(flowType any) *MockLimiterAcquireCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockLimiter)(nil).Acquire), flowType)
|
||||
return &MockLimiterAcquireCall{Call: call}
|
||||
}
|
||||
|
||||
// MockLimiterAcquireCall wrap *gomock.Call
|
||||
type MockLimiterAcquireCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockLimiterAcquireCall) Return(arg0 error) *MockLimiterAcquireCall {
|
||||
c.Call = c.Call.Return(arg0)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockLimiterAcquireCall) Do(f func(string) error) *MockLimiterAcquireCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockLimiterAcquireCall) DoAndReturn(f func(string) error) *MockLimiterAcquireCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Release mocks base method.
|
||||
func (m *MockLimiter) Release() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Release")
|
||||
}
|
||||
|
||||
// Release indicates an expected call of Release.
|
||||
func (mr *MockLimiterMockRecorder) Release() *MockLimiterReleaseCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockLimiter)(nil).Release))
|
||||
return &MockLimiterReleaseCall{Call: call}
|
||||
}
|
||||
|
||||
// MockLimiterReleaseCall wrap *gomock.Call
|
||||
type MockLimiterReleaseCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockLimiterReleaseCall) Return() *MockLimiterReleaseCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockLimiterReleaseCall) Do(f func()) *MockLimiterReleaseCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockLimiterReleaseCall) DoAndReturn(f func()) *MockLimiterReleaseCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// SetLimit mocks base method.
|
||||
func (m *MockLimiter) SetLimit(arg0 uint64) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetLimit", arg0)
|
||||
}
|
||||
|
||||
// SetLimit indicates an expected call of SetLimit.
|
||||
func (mr *MockLimiterMockRecorder) SetLimit(arg0 any) *MockLimiterSetLimitCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockLimiter)(nil).SetLimit), arg0)
|
||||
return &MockLimiterSetLimitCall{Call: call}
|
||||
}
|
||||
|
||||
// MockLimiterSetLimitCall wrap *gomock.Call
|
||||
type MockLimiterSetLimitCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockLimiterSetLimitCall) Return() *MockLimiterSetLimitCall {
|
||||
c.Call = c.Call.Return()
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockLimiterSetLimitCall) Do(f func(uint64)) *MockLimiterSetLimitCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockLimiterSetLimitCall) DoAndReturn(f func(uint64)) *MockLimiterSetLimitCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
//go:build gomock || generate
|
||||
|
||||
package mocks
|
||||
|
||||
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter"
|
|
@ -4,14 +4,17 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/proxy"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
|
@ -33,6 +36,8 @@ type Orchestrator struct {
|
|||
// cloudflared Configuration
|
||||
config *Config
|
||||
tags []pogs.Tag
|
||||
// flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
|
||||
flowLimiter cfdflow.Limiter
|
||||
log *zerolog.Logger
|
||||
|
||||
// orchestrator must not handle any more updates after shutdownC is closed
|
||||
|
@ -54,6 +59,7 @@ func NewOrchestrator(ctx context.Context,
|
|||
internalRules: internalRules,
|
||||
config: config,
|
||||
tags: tags,
|
||||
flowLimiter: cfdflow.NewLimiter(config.WarpRouting.MaxActiveFlows),
|
||||
log: log,
|
||||
shutdownC: ctx.Done(),
|
||||
}
|
||||
|
@ -112,6 +118,30 @@ func (o *Orchestrator) UpdateConfig(version int32, config []byte) *pogs.UpdateCo
|
|||
}
|
||||
}
|
||||
|
||||
// overrideRemoteWarpRoutingWithLocalValues overrides the ingress.WarpRoutingConfig that comes from the remote with
|
||||
// the local values if there is any.
|
||||
func (o *Orchestrator) overrideRemoteWarpRoutingWithLocalValues(remoteWarpRouting *ingress.WarpRoutingConfig) error {
|
||||
return o.overrideMaxActiveFlows(o.config.ConfigurationFlags[flags.MaxActiveFlows], remoteWarpRouting)
|
||||
}
|
||||
|
||||
// overrideMaxActiveFlows checks the local configuration flags, and if a value is found for the flags.MaxActiveFlows
|
||||
// overrides the value that comes on the remote ingress.WarpRoutingConfig with the local value.
|
||||
func (o *Orchestrator) overrideMaxActiveFlows(maxActiveFlowsLocalConfig string, remoteWarpRouting *ingress.WarpRoutingConfig) error {
|
||||
// If max active flows isn't defined locally just use the remote value
|
||||
if maxActiveFlowsLocalConfig == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
maxActiveFlowsLocalOverride, err := strconv.ParseUint(maxActiveFlowsLocalConfig, 10, 64)
|
||||
if err != nil {
|
||||
return pkgerrors.Wrapf(err, "failed to parse %s", flags.MaxActiveFlows)
|
||||
}
|
||||
|
||||
// Override the value that comes from the remote with the local value
|
||||
remoteWarpRouting.MaxActiveFlows = maxActiveFlowsLocalOverride
|
||||
return nil
|
||||
}
|
||||
|
||||
// The caller is responsible to make sure there is no concurrent access
|
||||
func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting ingress.WarpRoutingConfig) error {
|
||||
select {
|
||||
|
@ -120,6 +150,11 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
|
|||
default:
|
||||
}
|
||||
|
||||
// Overrides the local values, onto the remote values of the warp routing configuration
|
||||
if err := o.overrideRemoteWarpRoutingWithLocalValues(&warpRouting); err != nil {
|
||||
return pkgerrors.Wrap(err, "failed to merge local overrides into warp routing configuration")
|
||||
}
|
||||
|
||||
// Assign the internal ingress rules to the parsed ingress
|
||||
ingressRules.InternalRules = o.internalRules
|
||||
|
||||
|
@ -134,9 +169,13 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
|
|||
// The downside is minimized because none of the ingress.OriginService implementation have that requirement
|
||||
proxyShutdownC := make(chan struct{})
|
||||
if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil {
|
||||
return errors.Wrap(err, "failed to start origin")
|
||||
return pkgerrors.Wrap(err, "failed to start origin")
|
||||
}
|
||||
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.config.WriteTimeout, o.log)
|
||||
|
||||
// Update the flow limit since the configuration might have changed
|
||||
o.flowLimiter.SetLimit(warpRouting.MaxActiveFlows)
|
||||
|
||||
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.flowLimiter, o.config.WriteTimeout, o.log)
|
||||
o.proxy.Store(proxy)
|
||||
o.config.Ingress = &ingressRules
|
||||
o.config.WarpRouting = warpRouting
|
||||
|
@ -208,6 +247,12 @@ func (o *Orchestrator) GetOriginProxy() (connection.OriginProxy, error) {
|
|||
return proxy, nil
|
||||
}
|
||||
|
||||
// GetFlowLimiter returns the flow limiter used across cloudflared, that can be hot reload when
|
||||
// the configuration changes.
|
||||
func (o *Orchestrator) GetFlowLimiter() cfdflow.Limiter {
|
||||
return o.flowLimiter
|
||||
}
|
||||
|
||||
func (o *Orchestrator) waitToCloseLastProxy() {
|
||||
<-o.shutdownC
|
||||
o.lock.Lock()
|
||||
|
|
|
@ -16,8 +16,11 @@ import (
|
|||
"github.com/google/uuid"
|
||||
gows "github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
|
||||
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
|
@ -106,25 +109,25 @@ func TestUpdateConfiguration(t *testing.T) {
|
|||
require.Len(t, configV2.Ingress.Rules, 3)
|
||||
// originRequest of this ingress rule overrides global default
|
||||
require.Equal(t, config.CustomDuration{Duration: time.Second * 10}, configV2.Ingress.Rules[0].Config.ConnectTimeout)
|
||||
require.Equal(t, true, configV2.Ingress.Rules[0].Config.NoTLSVerify)
|
||||
require.True(t, configV2.Ingress.Rules[0].Config.NoTLSVerify)
|
||||
// Inherited from global default
|
||||
require.Equal(t, true, configV2.Ingress.Rules[0].Config.NoHappyEyeballs)
|
||||
require.True(t, configV2.Ingress.Rules[0].Config.NoHappyEyeballs)
|
||||
// Validate ingress rule 1
|
||||
require.Equal(t, "jira.tunnel.org", configV2.Ingress.Rules[1].Hostname)
|
||||
require.True(t, configV2.Ingress.Rules[1].Matches("jira.tunnel.org", "/users"))
|
||||
require.Equal(t, "http://172.32.20.6:80", configV2.Ingress.Rules[1].Service.String())
|
||||
// originRequest of this ingress rule overrides global default
|
||||
require.Equal(t, config.CustomDuration{Duration: time.Second * 30}, configV2.Ingress.Rules[1].Config.ConnectTimeout)
|
||||
require.Equal(t, true, configV2.Ingress.Rules[1].Config.NoTLSVerify)
|
||||
require.True(t, configV2.Ingress.Rules[1].Config.NoTLSVerify)
|
||||
// Inherited from global default
|
||||
require.Equal(t, true, configV2.Ingress.Rules[1].Config.NoHappyEyeballs)
|
||||
require.True(t, configV2.Ingress.Rules[1].Config.NoHappyEyeballs)
|
||||
// Validate ingress rule 2, it's the catch-all rule
|
||||
require.True(t, configV2.Ingress.Rules[2].Matches("blogs.tunnel.io", "/2022/02/10"))
|
||||
// Inherited from global default
|
||||
require.Equal(t, config.CustomDuration{Duration: time.Second * 90}, configV2.Ingress.Rules[2].Config.ConnectTimeout)
|
||||
require.Equal(t, false, configV2.Ingress.Rules[2].Config.NoTLSVerify)
|
||||
require.Equal(t, true, configV2.Ingress.Rules[2].Config.NoHappyEyeballs)
|
||||
require.Equal(t, configV2.WarpRouting.ConnectTimeout.Duration, 10*time.Second)
|
||||
require.False(t, configV2.Ingress.Rules[2].Config.NoTLSVerify)
|
||||
require.True(t, configV2.Ingress.Rules[2].Config.NoHappyEyeballs)
|
||||
require.Equal(t, 10*time.Second, configV2.WarpRouting.ConnectTimeout.Duration)
|
||||
|
||||
originProxyV2, err := orchestrator.GetOriginProxy()
|
||||
require.NoError(t, err)
|
||||
|
@ -317,7 +320,7 @@ func TestConcurrentUpdateAndRead(t *testing.T) {
|
|||
go func(i int, originProxy connection.OriginProxy) {
|
||||
defer wg.Done()
|
||||
resp, err := proxyHTTP(originProxy, hostname)
|
||||
require.NoError(t, err, "proxyHTTP %d failed %v", i, err)
|
||||
assert.NoError(t, err, "proxyHTTP %d failed %v", i, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
var warpRoutingDisabled bool
|
||||
|
@ -326,16 +329,16 @@ func TestConcurrentUpdateAndRead(t *testing.T) {
|
|||
// v1 proxy, warp enabled
|
||||
case 200:
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, t.Name(), string(body))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, t.Name(), string(body))
|
||||
warpRoutingDisabled = false
|
||||
// v2 proxy, warp disabled
|
||||
case 204:
|
||||
require.Greater(t, i, concurrentRequests/4)
|
||||
assert.Greater(t, i, concurrentRequests/4)
|
||||
warpRoutingDisabled = true
|
||||
// v3 proxy, warp enabled
|
||||
case 418:
|
||||
require.Greater(t, i, concurrentRequests/2)
|
||||
assert.Greater(t, i, concurrentRequests/2)
|
||||
warpRoutingDisabled = false
|
||||
}
|
||||
|
||||
|
@ -358,11 +361,10 @@ func TestConcurrentUpdateAndRead(t *testing.T) {
|
|||
|
||||
err = proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), w, pr)
|
||||
if warpRoutingDisabled {
|
||||
require.Error(t, err, "expect proxyTCP %d to return error", i)
|
||||
assert.Error(t, err, "expect proxyTCP %d to return error", i)
|
||||
} else {
|
||||
require.NoError(t, err, "proxyTCP %d failed %v", i, err)
|
||||
assert.NoError(t, err, "proxyTCP %d failed %v", i, err)
|
||||
}
|
||||
|
||||
}(i, originProxy)
|
||||
|
||||
if i == concurrentRequests/4 {
|
||||
|
@ -388,6 +390,57 @@ func TestConcurrentUpdateAndRead(t *testing.T) {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestOverrideWarpRoutingConfigWithLocalValues tests that if a value is defined in the Config.ConfigurationFlags,
|
||||
// it will override the value that comes from the remote result.
|
||||
func TestOverrideWarpRoutingConfigWithLocalValues(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
assertMaxActiveFlows := func(orchestrator *Orchestrator, expectedValue uint64) {
|
||||
configJson, err := orchestrator.GetConfigJSON()
|
||||
require.NoError(t, err)
|
||||
var result map[string]interface{}
|
||||
err = json.Unmarshal(configJson, &result)
|
||||
require.NoError(t, err)
|
||||
warpRouting := result["warp-routing"].(map[string]interface{})
|
||||
require.EqualValues(t, expectedValue, warpRouting["maxActiveFlows"])
|
||||
}
|
||||
|
||||
remoteValue := uint64(100)
|
||||
remoteIngress := ingress.Ingress{}
|
||||
remoteWarpConfig := ingress.WarpRoutingConfig{
|
||||
MaxActiveFlows: remoteValue,
|
||||
}
|
||||
remoteConfig := &Config{
|
||||
Ingress: &remoteIngress,
|
||||
WarpRouting: remoteWarpConfig,
|
||||
ConfigurationFlags: map[string]string{},
|
||||
}
|
||||
orchestrator, err := NewOrchestrator(ctx, remoteConfig, testTags, []ingress.Rule{}, &testLogger)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertMaxActiveFlows(orchestrator, remoteValue)
|
||||
|
||||
// Add a local override for the maxActiveFlows
|
||||
localValue := uint64(500)
|
||||
remoteConfig.ConfigurationFlags[flags.MaxActiveFlows] = fmt.Sprintf("%d", localValue)
|
||||
// Force a configuration refresh
|
||||
err = orchestrator.updateIngress(remoteIngress, remoteWarpConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the value being used is the local one
|
||||
assertMaxActiveFlows(orchestrator, localValue)
|
||||
|
||||
// Remove local override for the maxActiveFlows
|
||||
delete(remoteConfig.ConfigurationFlags, flags.MaxActiveFlows)
|
||||
// Force a configuration refresh
|
||||
err = orchestrator.updateIngress(remoteIngress, remoteWarpConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check the value being used is now the remote again
|
||||
assertMaxActiveFlows(orchestrator, remoteValue)
|
||||
}
|
||||
|
||||
func proxyHTTP(originProxy connection.OriginProxy, hostname string) (*http.Response, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", hostname), nil)
|
||||
if err != nil {
|
||||
|
@ -409,15 +462,16 @@ func proxyHTTP(originProxy connection.OriginProxy, hostname string) (*http.Respo
|
|||
return w.Result(), nil
|
||||
}
|
||||
|
||||
// nolint: testifylint // this is used inside go routines so it can't use `require.`
|
||||
func tcpEyeball(t *testing.T, reqWriter io.WriteCloser, body string, respReadWriter *respReadWriteFlusher) {
|
||||
writeN, err := reqWriter.Write([]byte(body))
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
readBuffer := make([]byte, writeN)
|
||||
n, err := respReadWriter.Read(readBuffer)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, body, string(readBuffer[:n]))
|
||||
require.Equal(t, writeN, n)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, body, string(readBuffer[:n]))
|
||||
assert.Equal(t, writeN, n)
|
||||
}
|
||||
|
||||
func proxyTCP(ctx context.Context, originProxy connection.OriginProxy, originAddr string, w http.ResponseWriter, reqBody io.ReadCloser) error {
|
||||
|
@ -458,14 +512,15 @@ func serveTCPOrigin(t *testing.T, tcpOrigin net.Listener, wg *sync.WaitGroup) {
|
|||
}
|
||||
}
|
||||
|
||||
// nolint: testifylint // this is used inside go routines so it can't use `require.`
|
||||
func echoTCP(t *testing.T, conn net.Conn) {
|
||||
readBuf := make([]byte, 1000)
|
||||
readN, err := conn.Read(readBuf)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
writeN, err := conn.Write(readBuf[:readN])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, readN, writeN)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, readN, writeN)
|
||||
}
|
||||
|
||||
type validateHostHandler struct {
|
||||
|
@ -479,16 +534,17 @@ func (vhh *validateHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
|
|||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(vhh.body))
|
||||
_, _ = w.Write([]byte(vhh.body))
|
||||
}
|
||||
|
||||
// nolint: testifylint // this is used inside go routines so it can't use `require.`
|
||||
func updateWithValidation(t *testing.T, orchestrator *Orchestrator, version int32, config []byte) {
|
||||
resp := orchestrator.UpdateConfig(version, config)
|
||||
require.NoError(t, resp.Err)
|
||||
require.Equal(t, version, resp.LastAppliedVersion)
|
||||
assert.NoError(t, resp.Err)
|
||||
assert.Equal(t, version, resp.LastAppliedVersion)
|
||||
}
|
||||
|
||||
// TestClosePreviousProxies makes sure proxies started in the pervious configuration version are shutdown
|
||||
// TestClosePreviousProxies makes sure proxies started in the previous configuration version are shutdown
|
||||
func TestClosePreviousProxies(t *testing.T) {
|
||||
var (
|
||||
hostname = "hello.tunnel1.org"
|
||||
|
@ -532,6 +588,7 @@ func TestClosePreviousProxies(t *testing.T) {
|
|||
|
||||
originProxyV1, err := orchestrator.GetOriginProxy()
|
||||
require.NoError(t, err)
|
||||
// nolint: bodyclose
|
||||
resp, err := proxyHTTP(originProxyV1, hostname)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
@ -540,12 +597,14 @@ func TestClosePreviousProxies(t *testing.T) {
|
|||
|
||||
originProxyV2, err := orchestrator.GetOriginProxy()
|
||||
require.NoError(t, err)
|
||||
// nolint: bodyclose
|
||||
resp, err = proxyHTTP(originProxyV2, hostname)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusTeapot, resp.StatusCode)
|
||||
|
||||
// The hello-world server in config v1 should have been stopped. We wait a bit since it's closed asynchronously.
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
// nolint: bodyclose
|
||||
resp, err = proxyHTTP(originProxyV1, hostname)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, resp)
|
||||
|
@ -557,6 +616,7 @@ func TestClosePreviousProxies(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.NotEqual(t, originProxyV1, originProxyV3)
|
||||
|
||||
// nolint: bodyclose
|
||||
resp, err = proxyHTTP(originProxyV3, hostname)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
@ -566,6 +626,7 @@ func TestClosePreviousProxies(t *testing.T) {
|
|||
// Wait for proxies to shutdown
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
// nolint: bodyclose
|
||||
resp, err = proxyHTTP(originProxyV3, hostname)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, resp)
|
||||
|
@ -622,7 +683,7 @@ func TestPersistentConnection(t *testing.T) {
|
|||
go func() {
|
||||
defer wg.Done()
|
||||
conn, err := tcpOrigin.Accept()
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Expect 3 TCP messages
|
||||
|
@ -630,26 +691,26 @@ func TestPersistentConnection(t *testing.T) {
|
|||
echoTCP(t, conn)
|
||||
}
|
||||
}()
|
||||
// Simulate cloudflared recieving a TCP connection
|
||||
// Simulate cloudflared receiving a TCP connection
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
require.NoError(t, proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), tcpRespReadWriter, tcpReqReader))
|
||||
assert.NoError(t, proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), tcpRespReadWriter, tcpReqReader))
|
||||
}()
|
||||
// Simulate cloudflared recieving a WS connection
|
||||
// Simulate cloudflared receiving a WS connection
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, hostname, wsReqReader)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
// ProxyHTTP will add Connection, Upgrade and Sec-Websocket-Version headers
|
||||
req.Header.Add("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
|
||||
log := zerolog.Nop()
|
||||
respWriter, err := connection.NewHTTP2RespWriter(req, wsRespReadWriter, connection.TypeWebsocket, &log)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = originProxy.ProxyHTTP(respWriter, tracing.NewTracedHTTPRequest(req, 0, &log), true)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Simulate eyeball WS and TCP connections
|
||||
|
|
|
@ -9,10 +9,14 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
|
||||
"github.com/cloudflare/cloudflared/carrier"
|
||||
"github.com/cloudflare/cloudflared/cfio"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
|
@ -32,8 +36,8 @@ const (
|
|||
type Proxy struct {
|
||||
ingressRules ingress.Ingress
|
||||
warpRouting *ingress.WarpRoutingService
|
||||
management *ingress.ManagementService
|
||||
tags []pogs.Tag
|
||||
flowLimiter cfdflow.Limiter
|
||||
log *zerolog.Logger
|
||||
}
|
||||
|
||||
|
@ -42,12 +46,14 @@ func NewOriginProxy(
|
|||
ingressRules ingress.Ingress,
|
||||
warpRouting ingress.WarpRoutingConfig,
|
||||
tags []pogs.Tag,
|
||||
flowLimiter cfdflow.Limiter,
|
||||
writeTimeout time.Duration,
|
||||
log *zerolog.Logger,
|
||||
) *Proxy {
|
||||
proxy := &Proxy{
|
||||
ingressRules: ingressRules,
|
||||
tags: tags,
|
||||
flowLimiter: flowLimiter,
|
||||
log: log,
|
||||
}
|
||||
|
||||
|
@ -64,7 +70,7 @@ func (p *Proxy) applyIngressMiddleware(rule *ingress.Rule, r *http.Request, w co
|
|||
}
|
||||
|
||||
if result.ShouldFilterRequest {
|
||||
w.WriteRespHeaders(result.StatusCode, nil)
|
||||
_ = w.WriteRespHeaders(result.StatusCode, nil)
|
||||
return fmt.Errorf("request filtered by middleware handler (%s) due to: %s", handler.Name(), result.Reason), true
|
||||
}
|
||||
}
|
||||
|
@ -152,10 +158,18 @@ func (p *Proxy) ProxyTCP(
|
|||
return err
|
||||
}
|
||||
|
||||
logger := newTCPLogger(p.log, req)
|
||||
|
||||
// Try to start a new flow
|
||||
if err := p.flowLimiter.Acquire(management.TCP.String()); err != nil {
|
||||
logger.Warn().Msg("Too many concurrent flows being handled, rejecting tcp proxy")
|
||||
return pkgerrors.Wrap(err, "failed to start tcp flow due to rate limiting")
|
||||
}
|
||||
defer p.flowLimiter.Release()
|
||||
|
||||
serveCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
logger := newTCPLogger(p.log, req)
|
||||
tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, &logger)
|
||||
logger.Debug().Msg("tcp proxy stream started")
|
||||
|
||||
|
|
|
@ -21,8 +21,13 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/cli/v2"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/cloudflare/cloudflared/mocks"
|
||||
|
||||
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cfio"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
|
@ -71,11 +76,6 @@ func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
|
|||
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
|
||||
}
|
||||
|
||||
// respHeaders is a test function to read respHeaders
|
||||
func (w *mockHTTPRespWriter) headers() http.Header {
|
||||
return w.Header()
|
||||
}
|
||||
|
||||
func (m *mockHTTPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
panic("Hijack not implemented")
|
||||
}
|
||||
|
@ -113,7 +113,7 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) {
|
|||
return w.reader.Read(data)
|
||||
}
|
||||
|
||||
func (m *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
func (w *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
panic("Hijack not implemented")
|
||||
}
|
||||
|
||||
|
@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
|
|||
|
||||
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
|
||||
|
||||
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, time.Duration(0), &log)
|
||||
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
|
||||
t.Run("testProxyHTTP", testProxyHTTP(proxy))
|
||||
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
|
||||
t.Run("testProxySSE", testProxySSE(proxy))
|
||||
|
@ -246,7 +246,7 @@ func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
|||
_ = responseWriter.Close()
|
||||
|
||||
close(finished)
|
||||
errGroup.Wait()
|
||||
_ = errGroup.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -267,7 +267,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
|||
defer wg.Done()
|
||||
log := zerolog.Nop()
|
||||
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false)
|
||||
require.Equal(t, err.Error(), "context canceled")
|
||||
require.Equal(t, "context canceled", err.Error())
|
||||
|
||||
require.Equal(t, http.StatusOK, responseWriter.Code)
|
||||
}()
|
||||
|
@ -275,7 +275,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
|||
for i := 0; i < pushCount; i++ {
|
||||
line := responseWriter.ReadBytes()
|
||||
expect := fmt.Sprintf("%d\n\n", i)
|
||||
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
|
||||
require.Equal(t, []byte(expect), line, "Expect to read %v, got %v", expect, line)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
@ -290,7 +290,9 @@ func TestProxySSEAllData(t *testing.T) {
|
|||
responseWriter := newMockSSERespWriter()
|
||||
|
||||
// responseWriter uses an unbuffered channel, so we call in a different go-routine
|
||||
go cfio.Copy(responseWriter, eyeballReader)
|
||||
go func() {
|
||||
_, _ = cfio.Copy(responseWriter, eyeballReader)
|
||||
}()
|
||||
|
||||
result := string(<-responseWriter.writeNotification)
|
||||
require.Equal(t, "data\r\r", result)
|
||||
|
@ -366,7 +368,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
|
||||
|
||||
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, time.Duration(0), &log)
|
||||
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
|
||||
|
||||
for _, test := range tests {
|
||||
responseWriter := newMockHTTPRespWriter()
|
||||
|
@ -414,25 +416,20 @@ func TestProxyError(t *testing.T) {
|
|||
|
||||
log := zerolog.Nop()
|
||||
|
||||
proxy := NewOriginProxy(ing, noWarpRouting, testTags, time.Duration(0), &log)
|
||||
proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
|
||||
|
||||
responseWriter := newMockHTTPRespWriter()
|
||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
|
||||
require.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
|
||||
}
|
||||
|
||||
type replayer struct {
|
||||
sync.RWMutex
|
||||
writeDone chan struct{}
|
||||
rw *bytes.Buffer
|
||||
}
|
||||
|
||||
func newReplayer(buffer *bytes.Buffer) {
|
||||
|
||||
}
|
||||
|
||||
func (r *replayer) Read(p []byte) (int, error) {
|
||||
r.RLock()
|
||||
defer r.RUnlock()
|
||||
|
@ -471,7 +468,7 @@ func (r *replayer) Bytes() []byte {
|
|||
// eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access).
|
||||
func TestConnections(t *testing.T) {
|
||||
logger := logger.Create(nil)
|
||||
replayer := &replayer{rw: &bytes.Buffer{}}
|
||||
replayer := &replayer{rw: bytes.NewBuffer([]byte{})}
|
||||
type args struct {
|
||||
ingressServiceScheme string
|
||||
originService func(*testing.T, net.Listener)
|
||||
|
@ -486,6 +483,9 @@ func TestConnections(t *testing.T) {
|
|||
|
||||
// requestheaders to be sent in the call to proxy.Proxy
|
||||
requestHeaders http.Header
|
||||
|
||||
// flowLimiterResponse is the response of the cfdflow.Limiter#Acquire method call
|
||||
flowLimiterResponse error
|
||||
}
|
||||
|
||||
type want struct {
|
||||
|
@ -663,6 +663,25 @@ func TestConnections(t *testing.T) {
|
|||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tcp-* proxy rate limited flow",
|
||||
args: args{
|
||||
ingressServiceScheme: "tcp://",
|
||||
originService: runEchoTCPService,
|
||||
eyeballResponseWriter: newTCPRespWriter(replayer),
|
||||
eyeballRequestBody: newTCPRequestBody([]byte("rate-limited")),
|
||||
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
|
||||
connectionType: connection.TypeTCP,
|
||||
requestHeaders: map[string][]string{
|
||||
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
|
||||
},
|
||||
flowLimiterResponse: cfdflow.ErrTooManyActiveFlows,
|
||||
},
|
||||
want: want{
|
||||
message: []byte{},
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
@ -674,8 +693,16 @@ func TestConnections(t *testing.T) {
|
|||
test.args.originService(t, ln)
|
||||
|
||||
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
|
||||
ingressRule.StartOrigins(logger, ctx.Done())
|
||||
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, time.Duration(0), logger)
|
||||
_ = ingressRule.StartOrigins(logger, ctx.Done())
|
||||
|
||||
// Mock flow limiter
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
flowLimiter := mocks.NewMockLimiter(ctrl)
|
||||
flowLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.flowLimiterResponse)
|
||||
flowLimiter.EXPECT().Release().AnyTimes()
|
||||
|
||||
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, flowLimiter, time.Duration(0), logger)
|
||||
proxy.warpRouting = test.args.warpRoutingService
|
||||
|
||||
dest := ln.Addr().String()
|
||||
|
@ -693,7 +720,7 @@ func TestConnections(t *testing.T) {
|
|||
respWriter = newTCPRespWriter(pipedReqBody.pipedConn)
|
||||
go func() {
|
||||
resp := pipedReqBody.roundtrip(test.args.ingressServiceScheme + ln.Addr().String())
|
||||
replayer.Write(resp)
|
||||
_, _ = replayer.Write(resp)
|
||||
}()
|
||||
}
|
||||
if test.args.connectionType == connection.TypeTCP {
|
||||
|
@ -705,9 +732,9 @@ func TestConnections(t *testing.T) {
|
|||
}
|
||||
|
||||
cancel()
|
||||
assert.Equal(t, test.want.err, err != nil)
|
||||
assert.Equal(t, test.want.message, replayer.Bytes())
|
||||
assert.Equal(t, test.want.headers, respWriter.Header())
|
||||
require.Equal(t, test.want.err, err != nil)
|
||||
require.Equal(t, test.want.message, replayer.Bytes())
|
||||
require.Equal(t, test.want.headers, respWriter.Header())
|
||||
replayer.rw.Reset()
|
||||
})
|
||||
}
|
||||
|
@ -720,7 +747,9 @@ type requestBody struct {
|
|||
|
||||
func newWSRequestBody(data []byte) *requestBody {
|
||||
pr, pw := io.Pipe()
|
||||
go wsutil.WriteClientBinary(pw, data)
|
||||
go func() {
|
||||
_ = wsutil.WriteClientBinary(pw, data)
|
||||
}()
|
||||
return &requestBody{
|
||||
pr: pr,
|
||||
pw: pw,
|
||||
|
@ -728,7 +757,9 @@ func newWSRequestBody(data []byte) *requestBody {
|
|||
}
|
||||
func newTCPRequestBody(data []byte) *requestBody {
|
||||
pr, pw := io.Pipe()
|
||||
go pw.Write(data)
|
||||
go func() {
|
||||
_, _ = pw.Write(data)
|
||||
}()
|
||||
return &requestBody{
|
||||
pr: pr,
|
||||
pw: pw,
|
||||
|
@ -740,8 +771,8 @@ func (r *requestBody) Read(p []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
func (r *requestBody) Close() error {
|
||||
r.pw.Close()
|
||||
r.pr.Close()
|
||||
_ = r.pw.Close()
|
||||
_ = r.pr.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -774,6 +805,7 @@ func (p *pipedRequestBody) roundtrip(addr string) []byte {
|
|||
panic(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
panic(fmt.Errorf("resp returned status code: %d", resp.StatusCode))
|
||||
|
@ -917,7 +949,9 @@ func runEchoTCPService(t *testing.T, l net.Listener) {
|
|||
go func() {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
|
@ -971,12 +1005,15 @@ func runEchoWSService(t *testing.T, l net.Listener) {
|
|||
}
|
||||
}
|
||||
|
||||
// nolint: gosec
|
||||
server := http.Server{
|
||||
Handler: http.HandlerFunc(ws),
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve(l)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
|
|
@ -116,7 +116,7 @@ func (s *UDPSessionRegistrationDatagram) MarshalBinary() (data []byte, err error
|
|||
data = make([]byte, sessionRegistrationIPv4DatagramHeaderLen+len(s.Payload))
|
||||
}
|
||||
data[0] = byte(UDPSessionRegistrationType)
|
||||
data[1] = byte(flags)
|
||||
data[1] = flags
|
||||
binary.BigEndian.PutUint16(data[2:4], s.Dest.Port())
|
||||
binary.BigEndian.PutUint16(data[4:6], uint16(s.IdleDurationHint.Seconds()))
|
||||
err = s.RequestID.MarshalBinaryTo(data[6:22])
|
||||
|
@ -284,6 +284,8 @@ const (
|
|||
ResponseDestinationUnreachable SessionRegistrationResp = 0x01
|
||||
// Session registration was unable to bind to a local UDP socket.
|
||||
ResponseUnableToBindSocket SessionRegistrationResp = 0x02
|
||||
// Session registration failed due to the number of flows being higher than the limit.
|
||||
ResponseTooManyActiveFlows SessionRegistrationResp = 0x03
|
||||
// Session registration failed with an unexpected error but provided a message.
|
||||
ResponseErrorWithMsg SessionRegistrationResp = 0xff
|
||||
)
|
||||
|
@ -311,6 +313,7 @@ func (s *UDPSessionRegistrationResponseDatagram) MarshalBinary() (data []byte, e
|
|||
if len(s.ErrorMsg) > maxResponseErrorMessageLen {
|
||||
return nil, wrapMarshalErr(ErrDatagramResponseMsgInvalidSize)
|
||||
}
|
||||
// nolint: gosec
|
||||
errMsgLen := uint16(len(s.ErrorMsg))
|
||||
|
||||
data = make([]byte, datagramSessionRegistrationResponseLen+errMsgLen)
|
||||
|
|
|
@ -7,6 +7,10 @@ import (
|
|||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
|
||||
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -16,6 +20,8 @@ var (
|
|||
ErrSessionBoundToOtherConn = errors.New("flow is in use by another connection")
|
||||
// ErrSessionAlreadyRegistered is returned when a registration already exists for this connection.
|
||||
ErrSessionAlreadyRegistered = errors.New("flow is already registered for this connection")
|
||||
// ErrSessionRegistrationRateLimited is returned when a registration fails due to rate limiting on the number of active flows.
|
||||
ErrSessionRegistrationRateLimited = errors.New("flow registration rate limited")
|
||||
)
|
||||
|
||||
type SessionManager interface {
|
||||
|
@ -38,14 +44,16 @@ type sessionManager struct {
|
|||
sessions map[RequestID]Session
|
||||
mutex sync.RWMutex
|
||||
originDialer DialUDP
|
||||
limiter cfdflow.Limiter
|
||||
metrics Metrics
|
||||
log *zerolog.Logger
|
||||
}
|
||||
|
||||
func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP) SessionManager {
|
||||
func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP, limiter cfdflow.Limiter) SessionManager {
|
||||
return &sessionManager{
|
||||
sessions: make(map[RequestID]Session),
|
||||
originDialer: originDialer,
|
||||
limiter: limiter,
|
||||
metrics: metrics,
|
||||
log: log,
|
||||
}
|
||||
|
@ -61,6 +69,12 @@ func (s *sessionManager) RegisterSession(request *UDPSessionRegistrationDatagram
|
|||
}
|
||||
return nil, ErrSessionBoundToOtherConn
|
||||
}
|
||||
|
||||
// Try to start a new session
|
||||
if err := s.limiter.Acquire(management.UDP.String()); err != nil {
|
||||
return nil, ErrSessionRegistrationRateLimited
|
||||
}
|
||||
|
||||
// Attempt to bind the UDP socket for the new session
|
||||
origin, err := s.originDialer(request.Dest)
|
||||
if err != nil {
|
||||
|
@ -100,4 +114,5 @@ func (s *sessionManager) UnregisterSession(requestID RequestID) {
|
|||
_ = session.Close()
|
||||
}
|
||||
delete(s.sessions, requestID)
|
||||
s.limiter.Release()
|
||||
}
|
||||
|
|
|
@ -8,14 +8,19 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/cloudflare/cloudflared/mocks"
|
||||
|
||||
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
v3 "github.com/cloudflare/cloudflared/quic/v3"
|
||||
)
|
||||
|
||||
func TestRegisterSession(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort)
|
||||
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0))
|
||||
|
||||
request := v3.UDPSessionRegistrationDatagram{
|
||||
RequestID: testRequestID,
|
||||
|
@ -71,10 +76,32 @@ func TestRegisterSession(t *testing.T) {
|
|||
|
||||
func TestGetSession_Empty(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort)
|
||||
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0))
|
||||
|
||||
_, err := manager.GetSession(testRequestID)
|
||||
if !errors.Is(err, v3.ErrSessionNotFound) {
|
||||
t.Fatalf("get session find no session: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterSessionRateLimit(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
flowLimiterMock := mocks.NewMockLimiter(ctrl)
|
||||
|
||||
flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
|
||||
flowLimiterMock.EXPECT().Release().Times(0)
|
||||
|
||||
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, flowLimiterMock)
|
||||
|
||||
request := v3.UDPSessionRegistrationDatagram{
|
||||
RequestID: testRequestID,
|
||||
Dest: netip.MustParseAddrPort("127.0.0.1:5000"),
|
||||
Traced: false,
|
||||
IdleDurationHint: 5 * time.Second,
|
||||
Payload: nil,
|
||||
}
|
||||
_, err := manager.RegisterSession(&request, &noopEyeball{})
|
||||
require.ErrorIs(t, err, v3.ErrSessionRegistrationRateLimited)
|
||||
}
|
||||
|
|
|
@ -65,12 +65,11 @@ type datagramConn struct {
|
|||
icmpRouter ingress.ICMPRouter
|
||||
metrics Metrics
|
||||
logger *zerolog.Logger
|
||||
|
||||
datagrams chan []byte
|
||||
readErrors chan error
|
||||
|
||||
icmpEncoderPool sync.Pool // a pool of *packet.Encoder
|
||||
icmpDecoder *packet.ICMPDecoder
|
||||
icmpDecoderPool sync.Pool
|
||||
}
|
||||
|
||||
func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRouter ingress.ICMPRouter, index uint8, metrics Metrics, logger *zerolog.Logger) DatagramConn {
|
||||
|
@ -89,7 +88,11 @@ func NewDatagramConn(conn QuicConnection, sessionManager SessionManager, icmpRou
|
|||
return packet.NewEncoder()
|
||||
},
|
||||
},
|
||||
icmpDecoder: packet.NewICMPDecoder(),
|
||||
icmpDecoderPool: sync.Pool{
|
||||
New: func() any {
|
||||
return packet.NewICMPDecoder()
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -140,8 +143,6 @@ func (c *datagramConn) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.Raw
|
|||
return c.SendICMPPacket(c.icmpRouter.ConvertToTTLExceeded(icmp, rawPacket))
|
||||
}
|
||||
|
||||
var errReadTimeout error = errors.New("receive datagram timeout")
|
||||
|
||||
// pollDatagrams will read datagrams from the underlying connection until the provided context is done.
|
||||
func (c *datagramConn) pollDatagrams(ctx context.Context) {
|
||||
for ctx.Err() == nil {
|
||||
|
@ -253,8 +254,12 @@ func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, da
|
|||
// Session is already registered but to a different connection
|
||||
c.handleSessionMigration(datagram.RequestID, &log)
|
||||
return
|
||||
case ErrSessionRegistrationRateLimited:
|
||||
// There are too many concurrent sessions so we return an error to force a retry later
|
||||
c.handleSessionRegistrationRateLimited(datagram, &log)
|
||||
return
|
||||
default:
|
||||
log.Err(err).Msgf("flow registration failure")
|
||||
log.Err(err).Msg("flow registration failure")
|
||||
c.handleSessionRegistrationFailure(datagram.RequestID, &log)
|
||||
return
|
||||
}
|
||||
|
@ -275,7 +280,7 @@ func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, da
|
|||
// [Session.Serve] is blocking and will continue this go routine till the end of the session lifetime.
|
||||
start := time.Now()
|
||||
err = session.Serve(ctx)
|
||||
elapsedMS := time.Now().Sub(start).Milliseconds()
|
||||
elapsedMS := time.Since(start).Milliseconds()
|
||||
log = log.With().Int64(logDurationKey, elapsedMS).Logger()
|
||||
if err == nil {
|
||||
// We typically don't expect a session to close without some error response. [SessionIdleErr] is the typical
|
||||
|
@ -343,6 +348,16 @@ func (c *datagramConn) handleSessionRegistrationFailure(requestID RequestID, log
|
|||
}
|
||||
}
|
||||
|
||||
func (c *datagramConn) handleSessionRegistrationRateLimited(datagram *UDPSessionRegistrationDatagram, logger *zerolog.Logger) {
|
||||
c.logger.Warn().Msg("Too many concurrent sessions being handled, rejecting udp proxy")
|
||||
|
||||
rateLimitResponse := ResponseTooManyActiveFlows
|
||||
err := c.SendUDPSessionResponse(datagram.RequestID, rateLimitResponse)
|
||||
if err != nil {
|
||||
logger.Err(err).Msgf("unable to send flow registration error response (%d)", rateLimitResponse)
|
||||
}
|
||||
}
|
||||
|
||||
// Handles incoming datagrams that need to be sent to a registered session.
|
||||
func (c *datagramConn) handleSessionPayloadDatagram(datagram *UDPSessionPayloadDatagram, logger *zerolog.Logger) {
|
||||
s, err := c.sessionManager.GetSession(datagram.RequestID)
|
||||
|
@ -367,7 +382,16 @@ func (c *datagramConn) handleICMPPacket(datagram *ICMPDatagram) {
|
|||
|
||||
// Decode the provided ICMPDatagram as an ICMP packet
|
||||
rawPacket := packet.RawPacket{Data: datagram.Payload}
|
||||
icmp, err := c.icmpDecoder.Decode(rawPacket)
|
||||
cachedDecoder := c.icmpDecoderPool.Get()
|
||||
defer c.icmpDecoderPool.Put(cachedDecoder)
|
||||
decoder, ok := cachedDecoder.(*packet.ICMPDecoder)
|
||||
if !ok {
|
||||
c.logger.Error().Msg("Could not get ICMPDecoder from the pool. Dropping packet")
|
||||
return
|
||||
}
|
||||
|
||||
icmp, err := decoder.Decode(rawPacket)
|
||||
|
||||
if err != nil {
|
||||
c.logger.Err(err).Msgf("unable to marshal icmp packet")
|
||||
return
|
||||
|
|
|
@ -4,18 +4,23 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/icmp"
|
||||
"golang.org/x/net/ipv4"
|
||||
|
||||
cfdflow "github.com/cloudflare/cloudflared/flow"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/packet"
|
||||
v3 "github.com/cloudflare/cloudflared/quic/v3"
|
||||
|
@ -83,7 +88,7 @@ func (m *mockEyeball) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawP
|
|||
|
||||
func TestDatagramConn_New(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
if conn == nil {
|
||||
t.Fatal("expected valid connection")
|
||||
}
|
||||
|
@ -92,10 +97,12 @@ func TestDatagramConn_New(t *testing.T) {
|
|||
func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
||||
payload := []byte{0xef, 0xef}
|
||||
conn.SendUDPSessionDatagram(payload)
|
||||
err := conn.SendUDPSessionDatagram(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
p := <-quic.recv
|
||||
if !slices.Equal(p, payload) {
|
||||
t.Fatal("datagram sent does not match datagram received on quic side")
|
||||
|
@ -105,15 +112,16 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
|
|||
func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
||||
err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
|
||||
resp := <-quic.recv
|
||||
var response v3.UDPSessionRegistrationResponseDatagram
|
||||
err := response.UnmarshalBinary(resp)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = response.UnmarshalBinary(resp)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := v3.UDPSessionRegistrationResponseDatagram{
|
||||
RequestID: testRequestID,
|
||||
ResponseType: v3.ResponseDestinationUnreachable,
|
||||
|
@ -126,7 +134,7 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
|
|||
func TestDatagramConnServe_ApplicationClosed(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
@ -142,7 +150,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
|
|||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
quic.ctx = ctx
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
||||
err := conn.Serve(context.Background())
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
|
@ -153,7 +161,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
|
|||
func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
quic := &mockQuicConnReadError{err: net.ErrClosed}
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
||||
err := conn.Serve(context.Background())
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
|
@ -161,6 +169,38 @@ func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
sessionManager := &mockSessionManager{
|
||||
expectedRegErr: v3.ErrSessionRegistrationRateLimited,
|
||||
}
|
||||
conn := v3.NewDatagramConn(quic, sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
|
||||
|
||||
// Setup the muxer
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
// Send new session registration
|
||||
datagram := newRegisterSessionDatagram(testRequestID)
|
||||
quic.send <- datagram
|
||||
|
||||
// Wait for session registration response with failure
|
||||
datagram = <-quic.recv
|
||||
var resp v3.UDPSessionRegistrationResponseDatagram
|
||||
err := resp.UnmarshalBinary(datagram)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
require.EqualValues(t, testRequestID, resp.RequestID)
|
||||
require.EqualValues(t, v3.ResponseTooManyActiveFlows, resp.ResponseType)
|
||||
}
|
||||
|
||||
func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
|
@ -304,6 +344,89 @@ func TestDatagramConnServe(t *testing.T) {
|
|||
assertContextClosed(t, ctx, done, cancel)
|
||||
}
|
||||
|
||||
// This test exists because decoding multiple packets in parallel with the same decoder
|
||||
// instances causes inteference resulting in multiple different raw packets being decoded
|
||||
// as the same decoded packet.
|
||||
func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
session := newMockSession()
|
||||
sessionManager := mockSessionManager{session: &session}
|
||||
router := newMockICMPRouter()
|
||||
conn := v3.NewDatagramConn(quic, &sessionManager, router, 0, &noopMetrics{}, &log)
|
||||
|
||||
// Setup the muxer
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
defer cancel(errors.New("other error"))
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
packetCount := 100
|
||||
packets := make([]*packet.ICMP, 100)
|
||||
ipTemplate := "10.0.0.%d"
|
||||
for i := 1; i <= packetCount; i++ {
|
||||
packets[i-1] = &packet.ICMP{
|
||||
IP: &packet.IP{
|
||||
Src: netip.MustParseAddr("192.168.1.1"),
|
||||
Dst: netip.MustParseAddr(fmt.Sprintf(ipTemplate, i)),
|
||||
Protocol: layers.IPProtocolICMPv4,
|
||||
TTL: 20,
|
||||
},
|
||||
Message: &icmp.Message{
|
||||
Type: ipv4.ICMPTypeEcho,
|
||||
Code: 0,
|
||||
Body: &icmp.Echo{
|
||||
ID: 25821,
|
||||
Seq: 58129,
|
||||
Data: []byte("test"),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
var receivedPackets []*packet.ICMP
|
||||
go func() {
|
||||
for ctx.Err() == nil {
|
||||
icmpPacket := <-router.recv
|
||||
receivedPackets = append(receivedPackets, icmpPacket)
|
||||
wg.Done()
|
||||
}
|
||||
}()
|
||||
|
||||
for _, p := range packets {
|
||||
// We increment here but only decrement when receiving the packet
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
datagram := newICMPDatagram(p)
|
||||
quic.send <- datagram
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// If there were duplicates then we won't have the same number of IPs
|
||||
packetSet := make(map[netip.Addr]*packet.ICMP, 0)
|
||||
for _, p := range receivedPackets {
|
||||
packetSet[p.Dst] = p
|
||||
}
|
||||
assert.Equal(t, len(packetSet), len(packets))
|
||||
|
||||
// Sort the slice by last byte of IP address (the one we increment for each destination)
|
||||
// and then check that we have one match for each packet sent
|
||||
sort.Slice(receivedPackets, func(i, j int) bool {
|
||||
return receivedPackets[i].Dst.As4()[3] < receivedPackets[j].Dst.As4()[3]
|
||||
})
|
||||
for i, p := range receivedPackets {
|
||||
assert.Equal(t, p.Dst, packets[i].Dst)
|
||||
}
|
||||
|
||||
// Cancel the muxer Serve context and make sure it closes with the expected error
|
||||
assertContextClosed(t, ctx, done, cancel)
|
||||
}
|
||||
|
||||
func TestDatagramConnServe_RegisterTwice(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
quic := newMockQuicConn()
|
||||
|
@ -588,7 +711,7 @@ func TestDatagramConnServe_ICMPDatagram_TTLExceeded(t *testing.T) {
|
|||
datagram := newICMPDatagram(expectedICMP)
|
||||
quic.send <- datagram
|
||||
|
||||
// Origin should not recieve a packet
|
||||
// Origin should not receive a packet
|
||||
select {
|
||||
case <-router.recv:
|
||||
t.Fatalf("TTL should be expired and no origin ICMP sent")
|
||||
|
@ -630,18 +753,6 @@ func newRegisterSessionDatagram(id v3.RequestID) []byte {
|
|||
return payload
|
||||
}
|
||||
|
||||
func newRegisterResponseSessionDatagram(id v3.RequestID, resp v3.SessionRegistrationResp) []byte {
|
||||
datagram := v3.UDPSessionRegistrationResponseDatagram{
|
||||
RequestID: id,
|
||||
ResponseType: resp,
|
||||
}
|
||||
payload, err := datagram.MarshalBinary()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func newSessionPayloadDatagram(id v3.RequestID, payload []byte) []byte {
|
||||
datagram := make([]byte, len(payload)+17)
|
||||
err := v3.MarshalPayloadHeaderTo(id, datagram[:])
|
||||
|
|
|
@ -89,7 +89,10 @@ func NewSession(
|
|||
log *zerolog.Logger,
|
||||
) Session {
|
||||
logger := log.With().Str(logFlowID, id.String()).Logger()
|
||||
closeChan := make(chan error, 1)
|
||||
// closeChan has two slots to allow for both writers (the closeFn and the Serve routine) to both be able to
|
||||
// write to the channel without blocking since there is only ever one value read from the closeChan by the
|
||||
// waitForCloseCondition.
|
||||
closeChan := make(chan error, 2)
|
||||
session := &session{
|
||||
id: id,
|
||||
closeAfterIdle: closeAfterIdle,
|
||||
|
|
|
@ -3,13 +3,14 @@ package v3_test
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/fortytw2/leaktest"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
v3 "github.com/cloudflare/cloudflared/quic/v3"
|
||||
|
@ -32,45 +33,64 @@ func TestSessionNew(t *testing.T) {
|
|||
|
||||
func testSessionWrite(t *testing.T, payload []byte) {
|
||||
log := zerolog.Nop()
|
||||
origin := newTestOrigin(makePayload(1280))
|
||||
session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||
origin, server := net.Pipe()
|
||||
defer origin.Close()
|
||||
defer server.Close()
|
||||
// Start origin server read
|
||||
serverRead := make(chan []byte, 1)
|
||||
go func() {
|
||||
read := make([]byte, 1500)
|
||||
server.Read(read[:])
|
||||
serverRead <- read
|
||||
}()
|
||||
// Create session and write to origin
|
||||
session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||
n, err := session.Write(payload)
|
||||
defer session.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
t.Fatal("unable to write the whole payload")
|
||||
}
|
||||
if !slices.Equal(payload, origin.write[:len(payload)]) {
|
||||
|
||||
read := <-serverRead
|
||||
if !slices.Equal(payload, read[:len(payload)]) {
|
||||
t.Fatal("payload provided from origin and read value are not the same")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionWrite_Max(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
payload := makePayload(1280)
|
||||
testSessionWrite(t, payload)
|
||||
}
|
||||
|
||||
func TestSessionWrite_Min(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
payload := makePayload(0)
|
||||
testSessionWrite(t, payload)
|
||||
}
|
||||
|
||||
func TestSessionServe_OriginMax(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
payload := makePayload(1280)
|
||||
testSessionServe_Origin(t, payload)
|
||||
}
|
||||
|
||||
func TestSessionServe_OriginMin(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
payload := makePayload(0)
|
||||
testSessionServe_Origin(t, payload)
|
||||
}
|
||||
|
||||
func testSessionServe_Origin(t *testing.T, payload []byte) {
|
||||
log := zerolog.Nop()
|
||||
origin, server := net.Pipe()
|
||||
defer origin.Close()
|
||||
defer server.Close()
|
||||
eyeball := newMockEyeball()
|
||||
origin := newTestOrigin(payload)
|
||||
session := v3.NewSession(testRequestID, 3*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
|
||||
session := v3.NewSession(testRequestID, 3*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
|
||||
defer session.Close()
|
||||
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
|
@ -80,13 +100,19 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
|
|||
done <- session.Serve(ctx)
|
||||
}()
|
||||
|
||||
// Write from the origin server
|
||||
_, err := server.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
select {
|
||||
case data := <-eyeball.recvData:
|
||||
// check received data matches provided from origin
|
||||
expectedData := makePayload(1500)
|
||||
v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
|
||||
copy(expectedData[17:], payload)
|
||||
if !slices.Equal(expectedData[:17+len(payload)], data) {
|
||||
if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
|
||||
t.Fatal("expected datagram did not equal expected")
|
||||
}
|
||||
cancel(expectedContextCanceled)
|
||||
|
@ -95,7 +121,7 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err := <-done
|
||||
err = <-done
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -105,11 +131,14 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
|
|||
}
|
||||
|
||||
func TestSessionServe_OriginTooLarge(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
eyeball := newMockEyeball()
|
||||
payload := makePayload(1281)
|
||||
origin := newTestOrigin(payload)
|
||||
session := v3.NewSession(testRequestID, 2*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
|
||||
origin, server := net.Pipe()
|
||||
defer origin.Close()
|
||||
defer server.Close()
|
||||
session := v3.NewSession(testRequestID, 2*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
|
||||
defer session.Close()
|
||||
|
||||
done := make(chan error)
|
||||
|
@ -117,6 +146,12 @@ func TestSessionServe_OriginTooLarge(t *testing.T) {
|
|||
done <- session.Serve(context.Background())
|
||||
}()
|
||||
|
||||
// Attempt to write a payload too large from the origin
|
||||
_, err := server.Write(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
select {
|
||||
case data := <-eyeball.recvData:
|
||||
// we never expect a read to make it here because the origin provided a payload that is too large
|
||||
|
@ -130,6 +165,7 @@ func TestSessionServe_OriginTooLarge(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionServe_Migrate(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
eyeball := newMockEyeball()
|
||||
pipe1, pipe2 := net.Pipe()
|
||||
|
@ -186,6 +222,7 @@ func TestSessionServe_Migrate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
eyeball := newMockEyeball()
|
||||
pipe1, pipe2 := net.Pipe()
|
||||
|
@ -245,39 +282,48 @@ func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionClose_Multiple(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
origin := newTestOrigin(makePayload(128))
|
||||
session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||
origin, server := net.Pipe()
|
||||
defer origin.Close()
|
||||
defer server.Close()
|
||||
session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||
err := session.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !origin.closed.Load() {
|
||||
t.Fatal("origin wasn't closed")
|
||||
b := [1500]byte{}
|
||||
_, err = server.Read(b[:])
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatalf("origin server connection should be closed: %s", err)
|
||||
}
|
||||
// Reset the closed status to make sure it isn't closed again
|
||||
origin.closed.Store(false)
|
||||
// subsequent closes shouldn't call close again or cause any errors
|
||||
err = session.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if origin.closed.Load() {
|
||||
t.Fatal("origin was incorrectly closed twice")
|
||||
_, err = server.Read(b[:])
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatalf("origin server connection should still be closed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionServe_IdleTimeout(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
origin := newTestIdleOrigin(10 * time.Second) // Make idle time longer than closeAfterIdle
|
||||
origin, server := net.Pipe()
|
||||
defer origin.Close()
|
||||
defer server.Close()
|
||||
closeAfterIdle := 2 * time.Second
|
||||
session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||
session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||
err := session.Serve(context.Background())
|
||||
if !errors.Is(err, v3.SessionIdleErr{}) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// session should be closed
|
||||
if !origin.closed {
|
||||
b := [1500]byte{}
|
||||
_, err = server.Read(b[:])
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatalf("session should be closed after Serve returns")
|
||||
}
|
||||
// closing a session again should not return an error
|
||||
|
@ -288,12 +334,14 @@ func TestSessionServe_IdleTimeout(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionServe_ParentContextCanceled(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
// Make idle time and idle timeout longer than closeAfterIdle
|
||||
origin := newTestIdleOrigin(10 * time.Second)
|
||||
origin, server := net.Pipe()
|
||||
defer origin.Close()
|
||||
defer server.Close()
|
||||
closeAfterIdle := 10 * time.Second
|
||||
|
||||
session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||
session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
err := session.Serve(ctx)
|
||||
|
@ -301,7 +349,9 @@ func TestSessionServe_ParentContextCanceled(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
// session should be closed
|
||||
if !origin.closed {
|
||||
b := [1500]byte{}
|
||||
_, err = server.Read(b[:])
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatalf("session should be closed after Serve returns")
|
||||
}
|
||||
// closing a session again should not return an error
|
||||
|
@ -312,6 +362,7 @@ func TestSessionServe_ParentContextCanceled(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSessionServe_ReadErrors(t *testing.T) {
|
||||
defer leaktest.Check(t)()
|
||||
log := zerolog.Nop()
|
||||
origin := newTestErrOrigin(net.ErrClosed, nil)
|
||||
session := v3.NewSession(testRequestID, 30*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
|
||||
|
@ -321,72 +372,6 @@ func TestSessionServe_ReadErrors(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type testOrigin struct {
|
||||
// bytes from Write
|
||||
write []byte
|
||||
// bytes provided to Read
|
||||
read []byte
|
||||
readOnce atomic.Bool
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func newTestOrigin(payload []byte) testOrigin {
|
||||
return testOrigin{
|
||||
read: payload,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *testOrigin) Read(p []byte) (n int, err error) {
|
||||
if o.closed.Load() {
|
||||
return -1, net.ErrClosed
|
||||
}
|
||||
if o.readOnce.Load() {
|
||||
// We only want to provide one read so all other reads will be blocked
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
o.readOnce.Store(true)
|
||||
return copy(p, o.read), nil
|
||||
}
|
||||
|
||||
func (o *testOrigin) Write(p []byte) (n int, err error) {
|
||||
if o.closed.Load() {
|
||||
return -1, net.ErrClosed
|
||||
}
|
||||
o.write = make([]byte, len(p))
|
||||
copy(o.write, p)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (o *testOrigin) Close() error {
|
||||
o.closed.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
type testIdleOrigin struct {
|
||||
duration time.Duration
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newTestIdleOrigin(d time.Duration) testIdleOrigin {
|
||||
return testIdleOrigin{
|
||||
duration: d,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *testIdleOrigin) Read(p []byte) (n int, err error) {
|
||||
time.Sleep(o.duration)
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
func (o *testIdleOrigin) Write(p []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (o *testIdleOrigin) Close() error {
|
||||
o.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
type testErrOrigin struct {
|
||||
readErr error
|
||||
writeErr error
|
||||
|
|
|
@ -113,7 +113,7 @@ class PkgCreator:
|
|||
|
||||
def create_rpm_pkgs(self, artifacts_path, gpg_key_name):
|
||||
self._setup_rpm_pkg_directories(artifacts_path, gpg_key_name)
|
||||
p = Popen(["createrepo", "./rpm"], stdout=PIPE, stderr=PIPE)
|
||||
p = Popen(["createrepo_c", "./rpm"], stdout=PIPE, stderr=PIPE)
|
||||
out, err = p.communicate()
|
||||
if p.returncode != 0:
|
||||
print(f"create rpm_pkgs result => {out}, {err}")
|
||||
|
@ -346,7 +346,7 @@ def parse_args():
|
|||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--deb-based-releases", default=["bookworm", "bullseye", "buster", "jammy", "impish", "focal", "bionic",
|
||||
"--deb-based-releases", default=["any", "bookworm", "bullseye", "buster", "noble", "jammy", "impish", "focal", "bionic",
|
||||
"xenial", "trusty"],
|
||||
help="list of debian based releases that need to be packaged for"
|
||||
)
|
||||
|
|
|
@ -79,8 +79,8 @@ func (b *BackoffHandler) BackoffTimer() <-chan time.Time {
|
|||
} else {
|
||||
b.retries++
|
||||
}
|
||||
maxTimeToWait := time.Duration(b.GetBaseTime() * 1 << (b.retries))
|
||||
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds()))
|
||||
maxTimeToWait := b.GetBaseTime() * (1 << b.retries)
|
||||
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds())) // #nosec G404
|
||||
return b.Clock.After(timeToWait)
|
||||
}
|
||||
|
||||
|
@ -99,11 +99,11 @@ func (b *BackoffHandler) Backoff(ctx context.Context) bool {
|
|||
}
|
||||
}
|
||||
|
||||
// Sets a grace period within which the the backoff timer is maintained. After the grace
|
||||
// Sets a grace period within which the backoff timer is maintained. After the grace
|
||||
// period expires, the number of retries & backoff duration is reset.
|
||||
func (b *BackoffHandler) SetGracePeriod() time.Duration {
|
||||
maxTimeToWait := b.GetBaseTime() * 2 << (b.retries + 1)
|
||||
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds()))
|
||||
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds())) // #nosec G404
|
||||
b.resetDeadline = b.Clock.Now().Add(timeToWait)
|
||||
|
||||
return timeToWait
|
||||
|
@ -118,7 +118,7 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
|
|||
|
||||
// Retries returns the number of retries consumed so far.
|
||||
func (b *BackoffHandler) Retries() int {
|
||||
return int(b.retries)
|
||||
return int(b.retries) // #nosec G115
|
||||
}
|
||||
|
||||
func (b *BackoffHandler) ReachedMaxRetries() bool {
|
||||
|
|
|
@ -7,30 +7,53 @@ import (
|
|||
"github.com/cloudflare/cloudflared/features"
|
||||
)
|
||||
|
||||
// When experimental post-quantum tunnels are enabled, and we're hitting an
|
||||
// issue creating the tunnel, we'll report the first error
|
||||
// to https://pqtunnels.cloudflareresearch.com.
|
||||
|
||||
const (
|
||||
PQKex = tls.CurveID(0x6399) // X25519Kyber768Draft00
|
||||
PQKexName = "X25519Kyber768Draft00"
|
||||
X25519Kyber768Draft00PQKex = tls.CurveID(0x6399) // X25519Kyber768Draft00
|
||||
X25519Kyber768Draft00PQKexName = "X25519Kyber768Draft00"
|
||||
P256Kyber768Draft00PQKex = tls.CurveID(0xfe32) // P256Kyber768Draft00
|
||||
P256Kyber768Draft00PQKexName = "P256Kyber768Draft00"
|
||||
X25519MLKEM768PQKex = tls.CurveID(0x11ec) // X25519MLKEM768
|
||||
X25519MLKEM768PQKexName = "X25519MLKEM768"
|
||||
)
|
||||
|
||||
func curvePreference(pqMode features.PostQuantumMode, currentCurve []tls.CurveID) ([]tls.CurveID, error) {
|
||||
var (
|
||||
nonFipsPostQuantumStrictPKex []tls.CurveID = []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex}
|
||||
nonFipsPostQuantumPreferPKex []tls.CurveID = []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex}
|
||||
fipsPostQuantumStrictPKex []tls.CurveID = []tls.CurveID{P256Kyber768Draft00PQKex}
|
||||
fipsPostQuantumPreferPKex []tls.CurveID = []tls.CurveID{P256Kyber768Draft00PQKex, tls.CurveP256}
|
||||
)
|
||||
|
||||
func removeDuplicates(curves []tls.CurveID) []tls.CurveID {
|
||||
bucket := make(map[tls.CurveID]bool)
|
||||
var result []tls.CurveID
|
||||
for _, curve := range curves {
|
||||
if _, ok := bucket[curve]; !ok {
|
||||
bucket[curve] = true
|
||||
result = append(result, curve)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func curvePreference(pqMode features.PostQuantumMode, fipsEnabled bool, currentCurve []tls.CurveID) ([]tls.CurveID, error) {
|
||||
switch pqMode {
|
||||
case features.PostQuantumStrict:
|
||||
// If the user passes the -post-quantum flag, we override
|
||||
// CurvePreferences to only support hybrid post-quantum key agreements.
|
||||
return []tls.CurveID{PQKex}, nil
|
||||
if fipsEnabled {
|
||||
return fipsPostQuantumStrictPKex, nil
|
||||
}
|
||||
return nonFipsPostQuantumStrictPKex, nil
|
||||
case features.PostQuantumPrefer:
|
||||
if len(currentCurve) == 0 {
|
||||
return []tls.CurveID{PQKex}, nil
|
||||
if fipsEnabled {
|
||||
// Ensure that all curves returned are FIPS compliant.
|
||||
// Moreover the first curves are post-quantum and then the
|
||||
// non post-quantum.
|
||||
return fipsPostQuantumPreferPKex, nil
|
||||
}
|
||||
|
||||
if currentCurve[0] != PQKex {
|
||||
return append([]tls.CurveID{PQKex}, currentCurve...), nil
|
||||
}
|
||||
return currentCurve, nil
|
||||
curves := append(nonFipsPostQuantumPreferPKex, currentCurve...)
|
||||
curves = removeDuplicates(curves)
|
||||
return curves, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Unexpected post quantum mode")
|
||||
}
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
package supervisor
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cloudflare/cloudflared/features"
|
||||
)
|
||||
|
||||
func TestCurvePreferences(t *testing.T) {
|
||||
// This tests if the correct curves are returned
|
||||
// given a PostQuantumMode and a FIPS enabled bool
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
currentCurves []tls.CurveID
|
||||
expectedCurves []tls.CurveID
|
||||
pqMode features.PostQuantumMode
|
||||
fipsEnabled bool
|
||||
}{
|
||||
{
|
||||
name: "FIPS with Prefer PQ",
|
||||
pqMode: features.PostQuantumPrefer,
|
||||
fipsEnabled: true,
|
||||
currentCurves: []tls.CurveID{tls.CurveP384},
|
||||
expectedCurves: []tls.CurveID{P256Kyber768Draft00PQKex, tls.CurveP256},
|
||||
},
|
||||
{
|
||||
name: "FIPS with Strict PQ",
|
||||
pqMode: features.PostQuantumStrict,
|
||||
fipsEnabled: true,
|
||||
currentCurves: []tls.CurveID{tls.CurveP256, tls.CurveP384},
|
||||
expectedCurves: []tls.CurveID{P256Kyber768Draft00PQKex},
|
||||
},
|
||||
{
|
||||
name: "FIPS with Prefer PQ - no duplicates",
|
||||
pqMode: features.PostQuantumPrefer,
|
||||
fipsEnabled: true,
|
||||
currentCurves: []tls.CurveID{tls.CurveP256},
|
||||
expectedCurves: []tls.CurveID{P256Kyber768Draft00PQKex, tls.CurveP256},
|
||||
},
|
||||
{
|
||||
name: "Non FIPS with Prefer PQ",
|
||||
pqMode: features.PostQuantumPrefer,
|
||||
fipsEnabled: false,
|
||||
currentCurves: []tls.CurveID{tls.CurveP256},
|
||||
expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256},
|
||||
},
|
||||
{
|
||||
name: "Non FIPS with Prefer PQ - no duplicates",
|
||||
pqMode: features.PostQuantumPrefer,
|
||||
fipsEnabled: false,
|
||||
currentCurves: []tls.CurveID{X25519Kyber768Draft00PQKex, tls.CurveP256},
|
||||
expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256},
|
||||
},
|
||||
{
|
||||
name: "Non FIPS with Prefer PQ - correct preference order",
|
||||
pqMode: features.PostQuantumPrefer,
|
||||
fipsEnabled: false,
|
||||
currentCurves: []tls.CurveID{tls.CurveP256, X25519Kyber768Draft00PQKex},
|
||||
expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256},
|
||||
},
|
||||
{
|
||||
name: "Non FIPS with Strict PQ",
|
||||
pqMode: features.PostQuantumStrict,
|
||||
fipsEnabled: false,
|
||||
currentCurves: []tls.CurveID{tls.CurveP256},
|
||||
expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tcase := range tests {
|
||||
t.Run(tcase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
curves, err := curvePreference(tcase.pqMode, tcase.fipsEnabled, tcase.currentCurves)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tcase.expectedCurves, curves)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -26,12 +26,6 @@ const (
|
|||
tunnelRetryDuration = time.Second * 10
|
||||
// Interval between registering new tunnels
|
||||
registrationInterval = time.Second
|
||||
|
||||
subsystemRefreshAuth = "refresh_auth"
|
||||
// Maximum exponent for 'Authenticate' exponential backoff
|
||||
refreshAuthMaxBackoff = 10
|
||||
// Waiting time before retrying a failed 'Authenticate' connection
|
||||
refreshAuthRetryDuration = time.Second * 10
|
||||
)
|
||||
|
||||
// Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and
|
||||
|
@ -84,7 +78,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
|
|||
edgeBindAddr := config.EdgeBindAddr
|
||||
|
||||
datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
|
||||
sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort)
|
||||
sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetFlowLimiter())
|
||||
|
||||
edgeTunnelServer := EdgeTunnelServer{
|
||||
config: config,
|
||||
|
@ -253,9 +247,7 @@ func (s *Supervisor) startFirstTunnel(
|
|||
ctx context.Context,
|
||||
connectedSignal *signal.Signal,
|
||||
) {
|
||||
var (
|
||||
err error
|
||||
)
|
||||
var err error
|
||||
const firstConnIndex = 0
|
||||
isStaticEdge := len(s.config.EdgeAddrs) > 0
|
||||
defer func() {
|
||||
|
@ -306,13 +298,12 @@ func (s *Supervisor) startTunnel(
|
|||
index int,
|
||||
connectedSignal *signal.Signal,
|
||||
) {
|
||||
var (
|
||||
err error
|
||||
)
|
||||
var err error
|
||||
defer func() {
|
||||
s.tunnelErrors <- tunnelError{index: index, err: err}
|
||||
}()
|
||||
|
||||
// nolint: gosec
|
||||
err = s.edgeTunnelServer.Serve(ctx, uint8(index), s.tunnelsProtocolFallback[index], connectedSignal)
|
||||
}
|
||||
|
||||
|
@ -334,7 +325,3 @@ func (s *Supervisor) waitForNextTunnel(index int) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Supervisor) unusedIPs() bool {
|
||||
return s.edgeIPs.AvailableAddrs() > s.config.HAConnections
|
||||
}
|
||||
|
|
|
@ -7,11 +7,11 @@ import (
|
|||
"net"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/getsentry/sentry-go"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/rs/zerolog"
|
||||
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
|
||||
"github.com/cloudflare/cloudflared/features"
|
||||
"github.com/cloudflare/cloudflared/fips"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
"github.com/cloudflare/cloudflared/orchestration"
|
||||
|
@ -460,6 +461,7 @@ func (e *EdgeTunnelServer) serveConnection(
|
|||
|
||||
switch protocol {
|
||||
case connection.QUIC:
|
||||
// nolint: gosec
|
||||
connOptions := e.config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
|
||||
return e.serveQUIC(ctx,
|
||||
addr.UDP.AddrPort(),
|
||||
|
@ -475,6 +477,7 @@ func (e *EdgeTunnelServer) serveConnection(
|
|||
return err, true
|
||||
}
|
||||
|
||||
// nolint: gosec
|
||||
connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
|
||||
if err := e.serveHTTP2(
|
||||
ctx,
|
||||
|
@ -554,15 +557,13 @@ func (e *EdgeTunnelServer) serveQUIC(
|
|||
tlsConfig := e.config.EdgeTLSConfigs[connection.QUIC]
|
||||
|
||||
pqMode := e.config.FeatureSelector.PostQuantumMode()
|
||||
if pqMode == features.PostQuantumStrict || pqMode == features.PostQuantumPrefer {
|
||||
connOptions.Client.Features = features.Dedup(append(connOptions.Client.Features, features.FeaturePostQuantum))
|
||||
}
|
||||
|
||||
curvePref, err := curvePreference(pqMode, tlsConfig.CurvePreferences)
|
||||
curvePref, err := curvePreference(pqMode, fips.IsFipsEnabled(), tlsConfig.CurvePreferences)
|
||||
if err != nil {
|
||||
return err, true
|
||||
}
|
||||
|
||||
connLogger.Logger().Info().Msgf("Using %v as curve preferences", curvePref)
|
||||
|
||||
tlsConfig.CurvePreferences = curvePref
|
||||
|
||||
// quic-go 0.44 increases the initial packet size to 1280 by default. That breaks anyone running tunnel through WARP
|
||||
|
@ -598,11 +599,13 @@ func (e *EdgeTunnelServer) serveQUIC(
|
|||
)
|
||||
if err != nil {
|
||||
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to dial a quic connection")
|
||||
|
||||
e.reportErrorToSentry(err)
|
||||
return err, true
|
||||
}
|
||||
|
||||
var datagramSessionManager connection.DatagramSessionHandler
|
||||
if slices.Contains(connOptions.Client.Features, features.FeatureDatagramV3) {
|
||||
if e.config.FeatureSelector.DatagramVersion() == features.DatagramV3 {
|
||||
datagramSessionManager = connection.NewDatagramV3Connection(
|
||||
ctx,
|
||||
conn,
|
||||
|
@ -620,6 +623,7 @@ func (e *EdgeTunnelServer) serveQUIC(
|
|||
connIndex,
|
||||
e.config.RPCTimeout,
|
||||
e.config.WriteStreamTimeout,
|
||||
e.orchestrator.GetFlowLimiter(),
|
||||
connLogger.Logger(),
|
||||
)
|
||||
}
|
||||
|
@ -666,6 +670,26 @@ func (e *EdgeTunnelServer) serveQUIC(
|
|||
return errGroup.Wait(), false
|
||||
}
|
||||
|
||||
// The reportErrorToSentry is an helper function that handles
|
||||
// verifies if an error should be reported to Sentry.
|
||||
func (e *EdgeTunnelServer) reportErrorToSentry(err error) {
|
||||
dialErr, ok := err.(*connection.EdgeQuicDialError)
|
||||
if ok {
|
||||
// The TransportError provides an Unwrap function however
|
||||
// the err MAY not always be set
|
||||
transportErr, ok := dialErr.Cause.(*quic.TransportError)
|
||||
if ok &&
|
||||
transportErr.ErrorCode.IsCryptoError() &&
|
||||
fips.IsFipsEnabled() &&
|
||||
e.config.FeatureSelector.PostQuantumMode() == features.PostQuantumStrict {
|
||||
// Only report to Sentry when using FIPS, PQ,
|
||||
// and the error is a Crypto error reported by
|
||||
// an EdgeQuicDialError
|
||||
sentry.CaptureException(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error {
|
||||
select {
|
||||
case reconnect := <-reconnectCh:
|
||||
|
|
|
@ -53,7 +53,7 @@ type signalHandler struct {
|
|||
}
|
||||
|
||||
type jwtPayload struct {
|
||||
Aud []string `json:"aud"`
|
||||
Aud []string `json:"-"`
|
||||
Email string `json:"email"`
|
||||
Exp int `json:"exp"`
|
||||
Iat int `json:"iat"`
|
||||
|
@ -68,6 +68,34 @@ type transferServiceResponse struct {
|
|||
OrgToken string `json:"org_token"`
|
||||
}
|
||||
|
||||
func (p *jwtPayload) UnmarshalJSON(data []byte) error {
|
||||
type Alias jwtPayload
|
||||
if err := json.Unmarshal(data, (*Alias)(p)); err != nil {
|
||||
return err
|
||||
}
|
||||
var audParser struct {
|
||||
Aud any `json:"aud"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &audParser); err != nil {
|
||||
return err
|
||||
}
|
||||
switch aud := audParser.Aud.(type) {
|
||||
case string:
|
||||
p.Aud = []string{aud}
|
||||
case []any:
|
||||
for _, a := range aud {
|
||||
s, ok := a.(string)
|
||||
if !ok {
|
||||
return errors.New("aud array contains non-string elements")
|
||||
}
|
||||
p.Aud = append(p.Aud, s)
|
||||
}
|
||||
default:
|
||||
return errors.New("aud field is not a string or an array of strings")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p jwtPayload) isExpired() bool {
|
||||
return int(time.Now().Unix()) > p.Exp
|
||||
}
|
||||
|
@ -182,7 +210,9 @@ func getToken(appURL *url.URL, appInfo *AppInfo, useHostOnly bool, log *zerolog.
|
|||
if err = fileLockAppToken.Acquire(); err != nil {
|
||||
return "", errors.Wrap(err, "failed to acquire app token lock")
|
||||
}
|
||||
defer fileLockAppToken.Release()
|
||||
defer func() {
|
||||
_ = fileLockAppToken.Release()
|
||||
}()
|
||||
|
||||
// check to see if another process has gotten a token while we waited for the lock
|
||||
if token, err := GetAppTokenIfExists(appInfo); token != "" && err == nil {
|
||||
|
@ -202,7 +232,9 @@ func getToken(appURL *url.URL, appInfo *AppInfo, useHostOnly bool, log *zerolog.
|
|||
if err = fileLockOrgToken.Acquire(); err != nil {
|
||||
return "", errors.Wrap(err, "failed to acquire org token lock")
|
||||
}
|
||||
defer fileLockOrgToken.Release()
|
||||
defer func() {
|
||||
_ = fileLockOrgToken.Release()
|
||||
}()
|
||||
// check if an org token has been created since the lock was acquired
|
||||
orgToken, err = GetOrgTokenIfExists(appInfo.AuthDomain)
|
||||
}
|
||||
|
@ -218,7 +250,6 @@ func getToken(appURL *url.URL, appInfo *AppInfo, useHostOnly bool, log *zerolog.
|
|||
}
|
||||
}
|
||||
return getTokensFromEdge(appURL, appInfo.AppAUD, appTokenPath, orgTokenPath, useHostOnly, log)
|
||||
|
||||
}
|
||||
|
||||
// getTokensFromEdge will attempt to use the transfer service to retrieve an app and org token, save them to disk,
|
||||
|
@ -250,7 +281,6 @@ func getTokensFromEdge(appURL *url.URL, appAUD, appTokenPath, orgTokenPath strin
|
|||
}
|
||||
|
||||
return resp.AppToken, nil
|
||||
|
||||
}
|
||||
|
||||
// GetAppInfo makes a request to the appURL and stops at the first redirect. The 302 location header will contain the
|
||||
|
@ -320,7 +350,6 @@ func handleRedirects(req *http.Request, via []*http.Request, orgToken string) er
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// stop after hitting authorized endpoint since it will contain the app token
|
||||
|
@ -408,7 +437,6 @@ func GetAppTokenIfExists(appInfo *AppInfo) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
return token.CompactSerialize()
|
||||
|
||||
}
|
||||
|
||||
// GetTokenIfExists will return the token from local storage if it exists and not expired
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue