Compare commits

...

42 Commits

Author SHA1 Message Date
João "Pisco" Fernandes d8a066628b Release 2025.4.0 2025-04-01 20:23:54 +01:00
João "Pisco" Fernandes 553e77e061 chore: fix linter rules 2025-04-01 18:57:55 +01:00
Cyb3r Jak3 8f94f54ec7
feat: Adds a new command line for tunnel run for token file
Adds a new command line flag for `tunnel run` which allows a file to be
read for the token. I've left the token command line argument with
priority.
2025-04-01 18:23:22 +01:00
gofastasf 2827b2fe8f
fix: Use path and filepath operation appropriately
Using path package methods can cause errors on windows machines.

path methods are used for url operations and unix specific operation.

filepath methods are used for file system paths and its cross platform. 

Remove strings.HasSuffix and use filepath.Ext and path.Ext for file and
url extenstions respectively.
2025-04-01 17:59:43 +01:00
Rohan Mukherjee 6dc8ed710e
fix: expand home directory for credentials file
## Issue

The [documentation for creating a tunnel's configuration
file](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/get-started/create-local-tunnel/#4-create-a-configuration-file)
does not specify that the `credentials-file` field in `config.yml` needs
to be an absolute path.

A user (E.G. me 🤦) might add a path like `~/.cloudflared/<uuid>.json`
and wonder why the `cloudflared tunnel run` command is throwing a
credentials file not found error. Although one might consider it
intuitive, it's not a fair assumption as a lot of CLI tools allow file
paths with `~` for specifying files.

P.S. The tunnel ID in the following snippet is not a real tunnel ID, I
just generated it.
```
url: http://localhost:8000
tunnel: 958a1ef6-ff8c-4455-825a-5aed91242135
credentials-file: ~/.cloudflared/958a1ef6-ff8c-4455-825a-5aed91242135.json
```

Furthermore, the error has a confusing message for the user as the file
at the logged path actually exists, it is just that `os.Stat` failed
because it could not expand the `~`.

## Solution

This commit fixes the above issue by running a `homedir.Expand` on the
`credentials-file` path in the `credentialFinder` function.
2025-04-01 17:54:57 +01:00
Shereef Marzouk e0b1ac0d05
chore: Update tunnel configuration link in the readme 2025-04-01 17:53:29 +01:00
Bernhard M. Wiedemann e7c5eb54af
Use RELEASE_NOTES date instead of build date
Use `RELEASE_NOTES` date instead of build date
to make builds reproducible.
See https://reproducible-builds.org/ for why this is good
and https://reproducible-builds.org/specs/source-date-epoch/
for the definition of this variable.
This date call only works with GNU date and BSD date.

Alternatively,
https://reproducible-builds.org/docs/source-date-epoch/#makefile could
be implemented.

This patch was done while working on reproducible builds for openSUSE,
sponsored by the NLnet NGI0 fund.
2025-04-01 17:52:50 +01:00
teslaedison cfec602fa7
chore: remove repetitive words 2025-04-01 17:51:57 +01:00
Micah Yeager 6fceb94998
feat: emit explicit errors for the `service` command on unsupported OSes
Per the contribution guidelines, this seemed to me like a small enough
change to not warrant an issue before creating this pull request. Let me
know if you'd like me to create one anyway.

## Background

While working with `cloudflared` on FreeBSD recently, I noticed that
there's an inconsistency with the available CLI commands on that OS
versus others — namely that the `service` command doesn't exist at all
for operating systems other than Linux, macOS, and Windows.

Contrast `cloudflared --help` output on macOS versus FreeBSD (truncated
to focus on the `COMMANDS` section):

- Current help output on macOS:

  ```text
  COMMANDS:
     update     Update the agent if a new version exists
     version    Print the version
     proxy-dns  Run a DNS over HTTPS proxy server.
     tail       Stream logs from a remote cloudflared
     service    Manages the cloudflared launch agent
     help, h    Shows a list of commands or help for one command
     Access:
       access, forward  access <subcommand>
     Tunnel:
tunnel Use Cloudflare Tunnel to expose private services to the Internet
or to Cloudflare connected private users.
  ```
- Current help output on FreeBSD:
  ```text
  COMMANDS:
     update     Update the agent if a new version exists
     version    Print the version
     proxy-dns  Run a DNS over HTTPS proxy server.
     tail       Stream logs from a remote cloudflared
     help, h    Shows a list of commands or help for one command
     Access:
       access, forward  access <subcommand>
     Tunnel:
tunnel Use Cloudflare Tunnel to expose private services to the Internet
or to Cloudflare connected private users.
  ```

This omission has caused confusion for users (including me), especially
since the provided command in the Cloudflare Zero Trust dashboard
returns a seemingly-unrelated error message:

```console
$ sudo cloudflared service install ...
You did not specify any valid additional argument to the cloudflared tunnel command.

If you are trying to run a Quick Tunnel then you need to explicitly pass the --url flag.
Eg. cloudflared tunnel --url localhost:8080/.

Please note that Quick Tunnels are meant to be ephemeral and should only be used for testing purposes.
For production usage, we recommend creating Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)
```

## Contribution

This pull request adds a "stub" `service` command (including the usual
subcommands available on other OSes) to explicitly declare it as
unsupported on the operating system.

New help output on FreeBSD (and other operating systems where service
management is unsupported):

```text
COMMANDS:
   update     Update the agent if a new version exists
   version    Print the version
   proxy-dns  Run a DNS over HTTPS proxy server.
   tail       Stream logs from a remote cloudflared
   service    Manages the cloudflared system service (not supported on this operating system)
   help, h    Shows a list of commands or help for one command
   Access:
     access, forward  access <subcommand>
   Tunnel:
     tunnel  Use Cloudflare Tunnel to expose private services to the Internet or to   Cloudflare connected private users.
```

New outputs when running the service management subcommands:

```console
$ sudo cloudflared service install ...
service installation is not supported on this operating system
```

```console
$ sudo cloudflared service uninstall ...
service uninstallation is not supported on this operating system
```

This keeps the available commands consistent until proper service
management support can be added for these otherwise-supported operating
systems.
2025-04-01 17:48:20 +01:00
Roman cf817f7036
Fix messages to point to one.dash.cloudflare.com 2025-04-01 17:47:23 +01:00
VFLC c8724a290a
Fix broken links in `cmd/cloudflared/*.go` related to running tunnel as a service
This PR updates 3 broken links to document [run tunnel as a
service](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/configure-tunnels/local-management/as-a-service/).
2025-04-01 17:45:59 +01:00
João "Pisco" Fernandes e7586153be TUN-9101: Don't ignore errors on `cloudflared access ssh`
## Summary

This change ensures that errors resulting from the `cloudflared access ssh` call are no longer ignored. By returning the error from `carrier.StartClient` to the upstream, we ensure that these errors are properly logged on stdout, providing better visibility and debugging capabilities.

Relates to TUN-9101
2025-03-17 18:42:19 +00:00
Chung-Ting Huang 11777db304 TUN-9089: Pin go import to v0.30.0, v0.31.0 requires go 1.23
Closes TUN-9089
2025-03-06 12:05:24 +00:00
lneto 3f6b1f24d0 Release 2025.2.1 2025-02-26 16:44:32 +00:00
Luis Neto a4105e8708 TUN-9016: update base-debian to v12
## Summary

Fixes vulnerability ([CVE -2024-4741](https://github.com/advisories/GHSA-6vgq-8qjq-h578))

 Closes TUN-9016
2025-02-26 15:54:10 +00:00
Luis Neto 6496322bee TUN-9007: modify logic to resolve region when the tunnel token has an endpoint field
## Summary

Within the work of FEDRamp it is necessary to change the HA SD lookup to use as srv `fed-v2-origintunneld`

This work assumes that the tunnel token has an optional endpoint field which will be used to modify the behaviour of the HA SD lookup.

Finally, the presence of the endpoint will override region to _fed_ and fail if any value is passed for the flag region.

Closes TUN-9007
2025-02-25 19:03:41 +00:00
Luis Neto 906452a9c9 TUN-8960: Connect to FED API GW based on the OriginCert's endpoint
## Summary

Within the scope of the FEDRamp High RM, it is necessary to detect if an user should connect to a FEDRamp colo.

At first, it was considered to add the --fedramp as global flag however this could be a footgun for the user or even an hindrance, thus, the proposal is to save in the token (during login) if the user authenticated using the FEDRamp Dashboard. This solution makes it easier to the user as they will only be required to pass the flag in login and nothing else.

* Introduces the new field, endpoint, in OriginCert
* Refactors login to remove the private key and certificate which are no longer used
* Login will only store the Argo Tunnel Token
* Remove namedTunnelToken as it was only used to for serialization

Closes TUN-8960
2025-02-25 17:13:33 +00:00
Jingqi Huang d969fdec3e SDLC-3762: Remove backstage.io/source-location from catalog-info.yaml 2025-02-13 13:02:50 -08:00
João "Pisco" Fernandes 7336a1a4d6 TUN-8914: Create a flags module to group all cloudflared cli flags
## Summary

This commit refactors some of the flags of cloudflared to their own module, so that they can be used across the code without requiring to literal strings which are much more error prone.

 Closes TUN-8914
2025-02-06 03:30:27 -08:00
João "Pisco" Fernandes df5dafa6d7 Release 2025.2.0 2025-02-03 18:39:00 +00:00
Bas Westerbaan c19f919428 Bump x/crypto to 0.31.0 2025-02-03 16:08:02 +01:00
João "Pisco" Fernandes b187879e69 TUN-8914: Add a new configuration to locally override the max-active-flows
## Summary

This commit introduces a new command line flag, `--max-active-flows`, which allows overriding the remote configuration for the maximum number of active flows.

The flag can be used with the `run` command, like `cloudflared tunnel --no-autoupdate run --token <TUNNEL_TOKEN> --max-active-flows 50000`, or set via an environment variable `TUNNEL_MAX_ACTIVE_FLOWS`.

Note that locally-set values always take precedence over remote settings, even if the tunnel is remotely managed.

Closes TUN-8914
2025-02-03 03:42:50 -08:00
lneto 2feccd772c Release 2025.1.1 2025-01-30 14:48:47 +00:00
Luis Neto 90176a79b4 TUN-8894: report FIPS+PQ error to Sentry when dialling to the edge
## Summary

Since we will enable PQ + FIPS it is necessary to add observability so that we can understand if issues are happening.

 Closes TUN-8894
2025-01-30 06:26:53 -08:00
Luis Neto 9695829e5b TUN-8857: remove restriction for using FIPS and PQ
## Summary

When the FIPS compliance was achieved with HTTP/2 Transport the technology at the time wasn't available or certified to be used in tandem with Post-Quantum encryption. Nowadays, that is possible, thus, we can also remove this restriction from Cloudflared.

 Closes TUN-8857
2025-01-30 05:47:07 -08:00
Luis Neto 31a870b291 TUN-8855: Update PQ curve preferences
## Summary

Nowadays, Cloudflared only supports X25519Kyber768Draft00 (0x6399,25497) but older versions may use different preferences.

For FIPS compliance we are required to use P256Kyber768Draft00 (0xfe32,65074) which is supported in our internal fork of [Go-Boring-1.22.10](https://bitbucket.cfdata.org/projects/PLAT/repos/goboring/browse?at=refs/heads/go-boring/1.22.10 "Follow link").

In the near future, Go will support by default the X25519MLKEM768 (0x11ec,4588) given this we may drop the usage of our public fork of GO.

To summarise:

* Cloudflared FIPS: QUIC_CURVE_PREFERENCES=65074
* Cloudflared non-FIPS: QUIC_CURVE_PREFERENCES=4588

Closes TUN-8855
2025-01-30 05:02:47 -08:00
Luis Neto bfdb0c76dc TUN-8855: fix lint issues
## Summary

Fix lint issues necessary for a subsequent PR. This is only separate to allow a better code review of the actual changes.

Closes TUN-8855
2025-01-30 03:53:24 -08:00
Luis Neto 45f67c23fd TUN-8858: update go to 1.22.10 and include quic-go FIPS changes
## Summary

To have support for new curves and to achieve FIPS compliance Cloudflared must be released with [Go-Boring-1.22.10](https://bitbucket.cfdata.org/projects/PLAT/repos/goboring/browse?at=refs/heads/go-boring/1.22.10 "Follow link") along with the quic-go patches. 

 Closes TUN-8858
2025-01-30 03:11:54 -08:00
João "Pisco" Fernandes 0f1bfe99ce TUN-8904: Rename Connect Response Flow Rate Limited metadata
## Summary

This commit renames the public variable that identifies the metadata key and value for the ConnectResponse structure when the flow was rate limited.

 Closes TUN-8904
2025-01-22 07:23:46 -08:00
Eduardo Gomes 18eecaf151 AUTH-6633 Fix cloudflared access login + warp as auth
## Summary
cloudflared access login and cloudflared access curl fails when the Access application has warp_as_auth enabled.

This bug originates from a 4 year old inconsistency where tokens signed by the nginx-fl-access module include 'aud' as a string, while tokens signed by the access authentication worker include 'aud' as an array of strings.
When the new(ish) feature warp_as_auth is enabled for the app, the fl module signs the token as opposed to the worker like usually.


I'm going to bring this up to the Access team, and try to figure out a way to consolidate this discrepancy without breaking behaviour.

Meanwhile we have this [CUSTESC ](https://jira.cfdata.org/browse/CUSTESC-47987), so I'm making cloudflared more lenient by accepting both []string and string in the token 'aud' field.



Tested this by compiling and running cloudflared access curls to my domains


Closes AUTH-6633
2025-01-21 04:00:28 -08:00
João "Pisco" Fernandes 4eb0f8ce5f TUN-8861: Rename Session Limiter to Flow Limiter
## Summary
Session is the concept used for UDP flows. Therefore, to make
the session limiter ambiguous for both TCP and UDP, this commit
renames it to flow limiter.

Closes TUN-8861
2025-01-20 06:33:40 -08:00
João "Pisco" Fernandes 8c2eda16c1 TUN-8861: Add configuration for active sessions limiter
## Summary
This commit adds a new configuration in the warp routing
config to allow users to define the active sessions limit
value.
2025-01-20 11:39:42 +00:00
João "Pisco" Fernandes 8bfe111cab TUN-8861: Add session limiter to TCP session manager
## Summary
In order to make cloudflared behavior more predictable and
prevent an exhaustion of resources, we have decided to add
session limits that can be configured by the user. This commit
adds the session limiter to the HTTP/TCP handling path.
For now the limiter is set to run only in unlimited mode.
2025-01-20 10:53:53 +00:00
João "Pisco" Fernandes bf4954e96a TUN-8861: Add session limiter to UDP session manager
## Summary
In order to make cloudflared behavior more predictable and
prevent an exhaustion of resources, we have decided to add
session limits that can be configured by the user. This first
commit introduces the session limiter and adds it to the UDP
handling path. For now the limiter is set to run only in
unlimited mode.
2025-01-20 02:52:32 -08:00
Gonçalo Garcia 8918b6729e TUN-8871: Accept login flag to authenticate with Fedramp environment
## Summary
Some description...

Closes TUN-8871
2025-01-17 08:16:36 -08:00
João "Pisco" Fernandes 25c3f676f4 TUN-8900: Add import of Apple Developer Certificate Authority to macOS Pipeline
## Summary
During the renewal of the certificates used to sign the macOS binaries and package,
we faced an issue with the new certificates requiring a specific certification authority
that wasn't available in the keychain of the mac agents. Therefore, this commit adds
an import step that will ensure that the Certificate Authority, usually fetched from
https://www.apple.com/certificateauthority/ is imported into the keychain to validate
the Developer Certificates.

Closes TUN-8900
2025-01-17 07:10:16 -08:00
João "Pisco" Fernandes a1963aed80 TUN-8866: Add linter to cloudflared repository
## Summary
To improve our code, this commit adds a linter that will start
checking for issues from this commit onwards, also forcing
issues to be fixed on the file changed and not only on the changes
themselves. This should help improve our code quality overtime.

Closes TUN-8866
2025-01-16 07:02:54 -08:00
chungthuang ac34f94d42 TUN-8848: Don't treat connection shutdown as an error condition when RPC server is done 2025-01-09 10:07:12 -06:00
João "Pisco" Fernandes d8c7f1c1ec Release 2025.1.0 2025-01-07 11:33:38 +00:00
Devin Carr 3b522a27cf TUN-8807: Add support_datagram_v3 to remote feature rollout
Support rolling out the `support_datagram_v3` feature via remote feature rollout (DNS TXT record) with `dv3` key.

Consolidated some of the feature evaluation code into the features module to simplify the lookup of available features at runtime.

Reduced complexity for management logs feature lookup since it's a default feature.

Closes TUN-8807
2025-01-06 09:15:18 -08:00
João "Pisco" Fernandes 5cfe9bef79 TUN-8842: Add Ubuntu Noble and 'any' debian distributions to release script
## Summary
Ubuntu has released a new LTS version, and there are people starting to use it, this makes
our installation recommendation, that automatically detecs the release flavor, to fail for
Noble users. Therefore, this commit adds this new version to our release packages.
It also adds an `any` package so that we can update our documentation to use it since
we are using the same binaries across all debian flavors, so there is no reason to keep
adding more release flavors when we can just take advantage of the `any` release flavor
like other repositories do.
2025-01-06 12:09:13 +00:00
Luis Neto 2714d10d62 TUN-8829: add CONTAINER_BUILD to dockerfiles
Closes TUN-8829
2024-12-20 08:24:12 -08:00
276 changed files with 24324 additions and 7709 deletions

89
.golangci.yaml Normal file
View File

@ -0,0 +1,89 @@
linters:
enable:
# Some of the linters below are commented out. We should uncomment and start running them, but they return
# too many problems to fix in one commit. Something for later.
- asasalint # Check for pass []any as any in variadic func(...any).
- asciicheck # Checks that all code identifiers does not have non-ASCII symbols in the name.
- bidichk # Checks for dangerous unicode character sequences.
- bodyclose # Checks whether HTTP response body is closed successfully.
- decorder # Check declaration order and count of types, constants, variables and functions.
- dogsled # Checks assignments with too many blank identifiers (e.g. x, , , _, := f()).
- dupl # Tool for code clone detection.
- dupword # Checks for duplicate words in the source code.
- durationcheck # Check for two durations multiplied together.
- errcheck # Errcheck is a program for checking for unchecked errors in Go code. These unchecked errors can be critical bugs in some cases.
- errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error.
- exhaustive # Check exhaustiveness of enum switch statements.
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification.
- goimports # Check import statements are formatted according to the 'goimport' command. Reformat imports in autofix mode.
- gosec # Inspects source code for security problems.
- gosimple # Linter for Go source code that specializes in simplifying code.
- govet # Vet examines Go source code and reports suspicious constructs. It is roughly the same as 'go vet' and uses its passes.
- ineffassign # Detects when assignments to existing variables are not used.
- importas # Enforces consistent import aliases.
- misspell # Finds commonly misspelled English words.
- prealloc # Finds slice declarations that could potentially be pre-allocated.
- promlinter # Check Prometheus metrics naming via promlint.
- sloglint # Ensure consistent code style when using log/slog.
- sqlclosecheck # Checks that sql.Rows, sql.Stmt, sqlx.NamedStmt, pgx.Query are closed.
- staticcheck # It's a set of rules from staticcheck. It's not the same thing as the staticcheck binary.
- tenv # Tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17.
- testableexamples # Linter checks if examples are testable (have an expected output).
- testifylint # Checks usage of github.com/stretchr/testify.
- tparallel # Tparallel detects inappropriate usage of t.Parallel() method in your Go test codes.
- unconvert # Remove unnecessary type conversions.
- unused # Checks Go code for unused constants, variables, functions and types.
- wastedassign # Finds wasted assignment statements.
- whitespace # Whitespace is a linter that checks for unnecessary newlines at the start and end of functions, if, for, etc.
- zerologlint # Detects the wrong usage of zerolog that a user forgets to dispatch with Send or Msg.
# Other linters are disabled, list of all is here: https://golangci-lint.run/usage/linters/
run:
timeout: 5m
modules-download-mode: vendor
# output configuration options
output:
formats:
- format: 'colored-line-number'
print-issued-lines: true
print-linter-name: true
issues:
# Maximum issues count per one linter.
# Set to 0 to disable.
# Default: 50
max-issues-per-linter: 50
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 15
# Show only new issues: if there are unstaged changes or untracked files,
# only those changes are analyzed, else only changes in HEAD~ are analyzed.
# It's a super-useful option for integration of golangci-lint into existing large codebase.
# It's not practical to fix all existing issues at the moment of integration:
# much better don't allow issues in new code.
#
# Default: false
new: true
# Show only new issues created after git revision `REV`.
# Default: ""
new-from-rev: ac34f94d423273c8fa8fdbb5f2ac60e55f2c77d5
# Show issues in any part of update files (requires new-from-rev or new-from-patch).
# Default: false
whole-files: true
# Which dirs to exclude: issues from them won't be reported.
# Can use regexp here: `generated.*`, regexp is applied on full path,
# including the path prefix if one is set.
# Default dirs are skipped independently of this option's value (see exclude-dirs-use-default).
# "/" will be replaced by current OS file path separator to properly work on Windows.
# Default: []
exclude-dirs:
- vendor
linters-settings:
# Check exhaustiveness of enum switch statements.
exhaustive:
# Presence of "default" case in switch statements satisfies exhaustiveness,
# even if all enum members are not listed.
# Default: false
default-signifies-exhaustive: true

View File

@ -3,6 +3,6 @@
cd /tmp
git clone -q https://github.com/cloudflare/go
cd go/src
# https://github.com/cloudflare/go/tree/f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38 is version go1.22.5-devel-cf
git checkout -q f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38
# https://github.com/cloudflare/go/tree/af19da5605ca11f85776ef7af3384a02a315a52b is version go1.22.5-devel-cf
git checkout -q af19da5605ca11f85776ef7af3384a02a315a52b
./make.bash

126
.teamcity/mac/build.sh vendored
View File

@ -22,6 +22,7 @@ TARGET_DIRECTORY=".build"
BINARY_NAME="cloudflared"
VERSION=$(git describe --tags --always --dirty="-dev")
PRODUCT="cloudflared"
APPLE_CA_CERT="apple_dev_ca.cert"
CODE_SIGN_PRIV="code_sign.p12"
CODE_SIGN_CERT="code_sign.cer"
INSTALLER_PRIV="installer.p12"
@ -35,15 +36,56 @@ mkdir -p ../src/github.com/cloudflare/
cp -r . ../src/github.com/cloudflare/cloudflared
cd ../src/github.com/cloudflare/cloudflared
# Add code signing private key to the key chain
if [[ ! -z "$CFD_CODE_SIGN_KEY" ]]; then
if [[ ! -z "$CFD_CODE_SIGN_PASS" ]]; then
# write private key to disk and then import it keychain
echo -n -e ${CFD_CODE_SIGN_KEY} | base64 -D > ${CODE_SIGN_PRIV}
# Imports certificates to the Apple KeyChain
import_certificate() {
local CERTIFICATE_NAME=$1
local CERTIFICATE_ENV_VAR=$2
local CERTIFICATE_FILE_NAME=$3
echo "Importing $CERTIFICATE_NAME"
if [[ ! -z "$CERTIFICATE_ENV_VAR" ]]; then
# write certificate to disk and then import it keychain
echo -n -e ${CERTIFICATE_ENV_VAR} | base64 -D > ${CERTIFICATE_FILE_NAME}
# we set || true here and for every `security import invoke` because the "duplicate SecKeychainItemImport" error
# will cause set -e to exit 1. It is okay we do this because we deliberately handle this error in the lines below.
out=$(security import ${CODE_SIGN_PRIV} -A -P "${CFD_CODE_SIGN_PASS}" 2>&1) || true
exitcode=$?
local out=$(security import ${CERTIFICATE_FILE_NAME} -A 2>&1) || true
local exitcode=$?
# delete the certificate from disk
rm -rf ${CERTIFICATE_FILE_NAME}
if [ -n "$out" ]; then
if [ $exitcode -eq 0 ]; then
echo "$out"
else
if [ "$out" != "${SEC_DUP_MSG}" ]; then
echo "$out" >&2
exit $exitcode
else
echo "already imported code signing certificate"
fi
fi
fi
fi
}
# Imports private keys to the Apple KeyChain
import_private_keys() {
local PRIVATE_KEY_NAME=$1
local PRIVATE_KEY_ENV_VAR=$2
local PRIVATE_KEY_FILE_NAME=$3
local PRIVATE_KEY_PASS=$4
echo "Importing $PRIVATE_KEY_NAME"
if [[ ! -z "$PRIVATE_KEY_ENV_VAR" ]]; then
if [[ ! -z "$PRIVATE_KEY_PASS" ]]; then
# write private key to disk and then import it keychain
echo -n -e ${PRIVATE_KEY_ENV_VAR} | base64 -D > ${PRIVATE_KEY_FILE_NAME}
# we set || true here and for every `security import invoke` because the "duplicate SecKeychainItemImport" error
# will cause set -e to exit 1. It is okay we do this because we deliberately handle this error in the lines below.
local out=$(security import ${PRIVATE_KEY_FILE_NAME} -A -P "${PRIVATE_KEY_PASS}" 2>&1) || true
local exitcode=$?
rm -rf ${PRIVATE_KEY_FILE_NAME}
if [ -n "$out" ]; then
if [ $exitcode -eq 0 ]; then
echo "$out"
@ -54,72 +96,24 @@ if [[ ! -z "$CFD_CODE_SIGN_KEY" ]]; then
fi
fi
fi
rm ${CODE_SIGN_PRIV}
fi
fi
fi
}
# Add Apple Root Developer certificate to the key chain
import_certificate "Apple Developer CA" "${APPLE_DEV_CA_CERT}" "${APPLE_CA_CERT}"
# Add code signing private key to the key chain
import_private_keys "Developer ID Application" "${CFD_CODE_SIGN_KEY}" "${CODE_SIGN_PRIV}" "${CFD_CODE_SIGN_PASS}"
# Add code signing certificate to the key chain
if [[ ! -z "$CFD_CODE_SIGN_CERT" ]]; then
# write certificate to disk and then import it keychain
echo -n -e ${CFD_CODE_SIGN_CERT} | base64 -D > ${CODE_SIGN_CERT}
out1=$(security import ${CODE_SIGN_CERT} -A 2>&1) || true
exitcode1=$?
if [ -n "$out1" ]; then
if [ $exitcode1 -eq 0 ]; then
echo "$out1"
else
if [ "$out1" != "${SEC_DUP_MSG}" ]; then
echo "$out1" >&2
exit $exitcode1
else
echo "already imported code signing certificate"
fi
fi
fi
rm ${CODE_SIGN_CERT}
fi
import_certificate "Developer ID Application" "${CFD_CODE_SIGN_CERT}" "${CODE_SIGN_CERT}"
# Add package signing private key to the key chain
if [[ ! -z "$CFD_INSTALLER_KEY" ]]; then
if [[ ! -z "$CFD_INSTALLER_PASS" ]]; then
# write private key to disk and then import it into the keychain
echo -n -e ${CFD_INSTALLER_KEY} | base64 -D > ${INSTALLER_PRIV}
out2=$(security import ${INSTALLER_PRIV} -A -P "${CFD_INSTALLER_PASS}" 2>&1) || true
exitcode2=$?
if [ -n "$out2" ]; then
if [ $exitcode2 -eq 0 ]; then
echo "$out2"
else
if [ "$out2" != "${SEC_DUP_MSG}" ]; then
echo "$out2" >&2
exit $exitcode2
fi
fi
fi
rm ${INSTALLER_PRIV}
fi
fi
import_private_keys "Developer ID Installer" "${CFD_INSTALLER_KEY}" "${INSTALLER_PRIV}" "${CFD_INSTALLER_PASS}"
# Add package signing certificate to the key chain
if [[ ! -z "$CFD_INSTALLER_CERT" ]]; then
# write certificate to disk and then import it keychain
echo -n -e ${CFD_INSTALLER_CERT} | base64 -D > ${INSTALLER_CERT}
out3=$(security import ${INSTALLER_CERT} -A 2>&1) || true
exitcode3=$?
if [ -n "$out3" ]; then
if [ $exitcode3 -eq 0 ]; then
echo "$out3"
else
if [ "$out3" != "${SEC_DUP_MSG}" ]; then
echo "$out3" >&2
exit $exitcode3
else
echo "already imported installer certificate"
fi
fi
fi
rm ${INSTALLER_CERT}
fi
import_certificate "Developer ID Installer" "${CFD_INSTALLER_CERT}" "${INSTALLER_CERT}"
# get the code signing certificate name
if [[ ! -z "$CFD_CODE_SIGN_NAME" ]]; then

View File

@ -9,8 +9,8 @@ Set-Location "$Env:Temp"
git clone -q https://github.com/cloudflare/go
Write-Output "Building go..."
cd go/src
# https://github.com/cloudflare/go/tree/f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38 is version go1.22.5-devel-cf
git checkout -q f4334cdc0c3f22a3bfdd7e66f387e3ffc65a5c38
# https://github.com/cloudflare/go/tree/af19da5605ca11f85776ef7af3384a02a315a52b is version go1.22.5-devel-cf
git checkout -q af19da5605ca11f85776ef7af3384a02a315a52b
& ./make.bat
Write-Output "Installed"

View File

@ -1,3 +1,7 @@
## 2025.1.1
### New Features
- This release introduces the use of new Post Quantum curves and the ability to use Post Quantum curves when running tunnels with the QUIC protocol this applies to non-FIPS and FIPS builds.
## 2024.12.2
### New Features
- This release introduces the ability to collect troubleshooting information from one instance of cloudflared running on the local machine. The command can be executed as `cloudflared tunnel diag`.

View File

@ -1,11 +1,13 @@
# use a builder image for building cloudflare
ARG TARGET_GOOS
ARG TARGET_GOARCH
FROM golang:1.22.5 as builder
FROM golang:1.22.10 as builder
ENV GO111MODULE=on \
CGO_ENABLED=0 \
TARGET_GOOS=${TARGET_GOOS} \
TARGET_GOARCH=${TARGET_GOARCH} \
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
# which changes how cloudflared binds the metrics server
CONTAINER_BUILD=1
@ -20,7 +22,7 @@ RUN .teamcity/install-cloudflare-go.sh
RUN PATH="/tmp/go/bin:$PATH" make cloudflared
# use a distroless base image with glibc
FROM gcr.io/distroless/base-debian11:nonroot
FROM gcr.io/distroless/base-debian12:nonroot
LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared"

View File

@ -1,7 +1,10 @@
# use a builder image for building cloudflare
FROM golang:1.22.5 as builder
FROM golang:1.22.10 as builder
ENV GO111MODULE=on \
CGO_ENABLED=0
CGO_ENABLED=0 \
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
# which changes how cloudflared binds the metrics server
CONTAINER_BUILD=1
WORKDIR /go/src/github.com/cloudflare/cloudflared/
@ -14,7 +17,7 @@ RUN .teamcity/install-cloudflare-go.sh
RUN GOOS=linux GOARCH=amd64 PATH="/tmp/go/bin:$PATH" make cloudflared
# use a distroless base image with glibc
FROM gcr.io/distroless/base-debian11:nonroot
FROM gcr.io/distroless/base-debian12:nonroot
LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared"

View File

@ -1,7 +1,10 @@
# use a builder image for building cloudflare
FROM golang:1.22.5 as builder
FROM golang:1.22.10 as builder
ENV GO111MODULE=on \
CGO_ENABLED=0
CGO_ENABLED=0 \
# the CONTAINER_BUILD envvar is used set github.com/cloudflare/cloudflared/metrics.Runtime=virtual
# which changes how cloudflared binds the metrics server
CONTAINER_BUILD=1
WORKDIR /go/src/github.com/cloudflare/cloudflared/
@ -14,7 +17,7 @@ RUN .teamcity/install-cloudflare-go.sh
RUN GOOS=linux GOARCH=arm64 PATH="/tmp/go/bin:$PATH" make cloudflared
# use a distroless base image with glibc
FROM gcr.io/distroless/base-debian11:nonroot-arm64
FROM gcr.io/distroless/base-debian12:nonroot-arm64
LABEL org.opencontainers.image.source="https://github.com/cloudflare/cloudflared"

View File

@ -24,7 +24,7 @@ else
DEB_PACKAGE_NAME := $(BINARY_NAME)
endif
DATE := $(shell date -u '+%Y-%m-%d-%H%M UTC')
DATE := $(shell date -u -r RELEASE_NOTES '+%Y-%m-%d-%H%M UTC')
VERSION_FLAGS := -X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)"
ifdef PACKAGE_MANAGER
VERSION_FLAGS := $(VERSION_FLAGS) -X "github.com/cloudflare/cloudflared/cmd/cloudflared/updater.BuiltForPackageManager=$(PACKAGE_MANAGER)"
@ -133,11 +133,9 @@ clean:
cloudflared:
ifeq ($(FIPS), true)
$(info Building cloudflared with go-fips)
cp -f fips/fips.go.linux-amd64 cmd/cloudflared/fips.go
endif
GOOS=$(TARGET_OS) GOARCH=$(TARGET_ARCH) $(ARM_COMMAND) go build -mod=vendor $(GO_BUILD_TAGS) $(LDFLAGS) $(IMPORT_PATH)/cmd/cloudflared
ifeq ($(FIPS), true)
rm -f cmd/cloudflared/fips.go
./check-fips.sh cloudflared
endif
@ -255,4 +253,17 @@ vet:
.PHONY: fmt
fmt:
goimports -l -w -local github.com/cloudflare/cloudflared $$(go list -mod=vendor -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc/proto)
@goimports -l -w -local github.com/cloudflare/cloudflared $$(go list -mod=vendor -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc/proto)
@go fmt $$(go list -mod=vendor -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc/proto)
.PHONY: fmt-check
fmt-check:
@./fmt-check.sh
.PHONY: lint
lint:
@golangci-lint run
.PHONY: mocks
mocks:
go generate mocks/mockgen.go

View File

@ -40,7 +40,7 @@ User documentation for Cloudflare Tunnel can be found at https://developers.clou
Once installed, you can authenticate `cloudflared` into your Cloudflare account and begin creating Tunnels to serve traffic to your origins.
* Create a Tunnel with [these instructions](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/create-tunnel)
* Create a Tunnel with [these instructions](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/get-started/)
* Route traffic to that Tunnel:
* Via public [DNS records in Cloudflare](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/dns)
* Or via a public hostname guided by a [Cloudflare Load Balancer](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/lb)
@ -56,3 +56,27 @@ Want to test Cloudflare Tunnel before adding a website to Cloudflare? You can do
Cloudflare currently supports versions of cloudflared that are **within one year** of the most recent release. Breaking changes unrelated to feature availability may be introduced that will impact versions released more than one year ago. You can read more about upgrading cloudflared in our [developer documentation](https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/downloads/#updating-cloudflared).
For example, as of January 2023 Cloudflare will support cloudflared version 2023.1.1 to cloudflared 2022.1.1.
## Development
### Requirements
- [GNU Make](https://www.gnu.org/software/make/)
- [capnp](https://capnproto.org/install.html)
- [cloudflare go toolchain](https://github.com/cloudflare/go)
- Optional tools:
- [capnpc-go](https://pkg.go.dev/zombiezen.com/go/capnproto2/capnpc-go)
- [goimports](https://pkg.go.dev/golang.org/x/tools/cmd/goimports)
- [golangci-lint](https://github.com/golangci/golangci-lint)
- [gomocks](https://pkg.go.dev/go.uber.org/mock)
### Build
To build cloudflared locally run `make cloudflared`
### Test
To locally run the tests run `make test`
### Linting
To format the code and keep a good code quality use `make fmt` and `make lint`
### Mocks
After changes on interfaces you might need to regenerate the mocks, so run `make mock`

View File

@ -1,3 +1,50 @@
2025.4.0
- 2025-04-02 Fix broken links in `cmd/cloudflared/*.go` related to running tunnel as a service
- 2025-04-02 chore: remove repetitive words
- 2025-04-01 Fix messages to point to one.dash.cloudflare.com
- 2025-04-01 feat: emit explicit errors for the `service` command on unsupported OSes
- 2025-04-01 Use RELEASE_NOTES date instead of build date
- 2025-04-01 chore: Update tunnel configuration link in the readme
- 2025-04-01 fix: expand home directory for credentials file
- 2025-04-01 fix: Use path and filepath operation appropriately
- 2025-04-01 feat: Adds a new command line for tunnel run for token file
- 2025-04-01 chore: fix linter rules
- 2025-03-17 TUN-9101: Don't ignore errors on `cloudflared access ssh`
- 2025-03-06 TUN-9089: Pin go import to v0.30.0, v0.31.0 requires go 1.23
2025.2.1
- 2025-02-26 TUN-9016: update base-debian to v12
- 2025-02-25 TUN-8960: Connect to FED API GW based on the OriginCert's endpoint
- 2025-02-25 TUN-9007: modify logic to resolve region when the tunnel token has an endpoint field
- 2025-02-13 SDLC-3762: Remove backstage.io/source-location from catalog-info.yaml
- 2025-02-06 TUN-8914: Create a flags module to group all cloudflared cli flags
2025.2.0
- 2025-02-03 TUN-8914: Add a new configuration to locally override the max-active-flows
- 2025-02-03 Bump x/crypto to 0.31.0
2025.1.1
- 2025-01-30 TUN-8858: update go to 1.22.10 and include quic-go FIPS changes
- 2025-01-30 TUN-8855: fix lint issues
- 2025-01-30 TUN-8855: Update PQ curve preferences
- 2025-01-30 TUN-8857: remove restriction for using FIPS and PQ
- 2025-01-30 TUN-8894: report FIPS+PQ error to Sentry when dialling to the edge
- 2025-01-22 TUN-8904: Rename Connect Response Flow Rate Limited metadata
- 2025-01-21 AUTH-6633 Fix cloudflared access login + warp as auth
- 2025-01-20 TUN-8861: Add session limiter to UDP session manager
- 2025-01-20 TUN-8861: Rename Session Limiter to Flow Limiter
- 2025-01-17 TUN-8900: Add import of Apple Developer Certificate Authority to macOS Pipeline
- 2025-01-17 TUN-8871: Accept login flag to authenticate with Fedramp environment
- 2025-01-16 TUN-8866: Add linter to cloudflared repository
- 2025-01-14 TUN-8861: Add session limiter to TCP session manager
- 2025-01-13 TUN-8861: Add configuration for active sessions limiter
- 2025-01-09 TUN-8848: Don't treat connection shutdown as an error condition when RPC server is done
2025.1.0
- 2025-01-06 TUN-8842: Add Ubuntu Noble and 'any' debian distributions to release script
- 2025-01-06 TUN-8807: Add support_datagram_v3 to remote feature rollout
- 2024-12-20 TUN-8829: add CONTAINER_BUILD to dockerfiles
2024.12.2
- 2024-12-19 TUN-8822: Prevent concurrent usage of ICMPDecoder
- 2024-12-18 TUN-8818: update changes document to reflect newly added diag subcommand

View File

@ -17,7 +17,7 @@ make cloudflared-deb
mv cloudflared-fips\_$VERSION\_$arch.deb $ARTIFACT_DIR/cloudflared-fips-linux-$arch.deb
# rpm packages invert the - and _ and use x86_64 instead of amd64.
RPMVERSION=$(echo $VERSION|sed -r 's/-/_/g')
RPMVERSION=$(echo $VERSION | sed -r 's/-/_/g')
RPMARCH="x86_64"
make cloudflared-rpm
mv cloudflared-fips-$RPMVERSION-1.$RPMARCH.rpm $ARTIFACT_DIR/cloudflared-fips-linux-$RPMARCH.rpm

View File

@ -4,7 +4,6 @@ metadata:
name: cloudflared
description: Client for Cloudflare Tunnels
annotations:
backstage.io/source-location: url:https://bitbucket.cfdata.org/projects/TUN/repos/cloudflared/browse
cloudflare.com/software-excellence-opt-in: "true"
cloudflare.com/jira-project-key: "TUN"
cloudflare.com/jira-project-component: "Cloudflare Tunnel"

View File

@ -1,4 +1,4 @@
pinned_go: &pinned_go go-boring=1.22.5-1
pinned_go: &pinned_go go-boring=1.22.10-1
build_dir: &build_dir /cfsetup_build
default-flavor: bookworm
@ -13,10 +13,14 @@ bullseye: &bullseye
- rubygem-fpm
- rpm
- libffi-dev
- golangci-lint
pre-cache: &build_pre_cache
- export GOCACHE=/cfsetup_build/.cache/go-build
- go install golang.org/x/tools/cmd/goimports@latest
- go install golang.org/x/tools/cmd/goimports@v0.30.0
post-cache:
# Linting
- make lint
- make fmt-check
# Build binary for component test
- GOOS=linux GOARCH=amd64 make cloudflared
build-linux-fips:
@ -156,7 +160,6 @@ bullseye: &bullseye
- export GOOS=linux
- export GOARCH=amd64
- export PATH="$HOME/go/bin:$PATH"
- ./fmt-check.sh
- make test | gotest-to-teamcity
test-fips:
build_dir: *build_dir
@ -167,7 +170,6 @@ bullseye: &bullseye
- export GOARCH=amd64
- export FIPS=true
- export PATH="$HOME/go/bin:$PATH"
- ./fmt-check.sh
- make test | gotest-to-teamcity
component-test:
build_dir: *build_dir

View File

@ -104,7 +104,7 @@ func ssh(c *cli.Context) error {
case 3:
options.OriginURL = fmt.Sprintf("https://%s:%s", parts[2], parts[1])
options.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
InsecureSkipVerify: true, // #nosec G402
ServerName: parts[0],
}
log.Warn().Msgf("Using insecure SSL connection because SNI overridden to %s", parts[0])
@ -141,6 +141,5 @@ func ssh(c *cli.Context) error {
logger := log.With().Str("host", url.Host).Logger()
s = stream.NewDebugStream(s, &logger, maxMessages)
}
carrier.StartClient(wsConn, s, options)
return nil
return carrier.StartClient(wsConn, s, options)
}

View File

@ -19,6 +19,7 @@ import (
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/sshgen"
"github.com/cloudflare/cloudflared/token"
@ -172,15 +173,15 @@ func Commands() []*cli.Command {
EnvVars: []string{"TUNNEL_SERVICE_TOKEN_SECRET"},
},
&cli.StringFlag{
Name: logger.LogFileFlag,
Name: cfdflags.LogFile,
Usage: "Save application log to this file for reporting issues.",
},
&cli.StringFlag{
Name: logger.LogSSHDirectoryFlag,
Name: cfdflags.LogDirectory,
Usage: "Save application log to this directory for reporting issues.",
},
&cli.StringFlag{
Name: logger.LogSSHLevelFlag,
Name: cfdflags.LogLevelSSH,
Aliases: []string{"loglevel"}, //added to match the tunnel side
Usage: "Application logging level {debug, info, warn, error, fatal}. ",
},
@ -342,7 +343,7 @@ func run(cmd string, args ...string) error {
return err
}
go func() {
io.Copy(os.Stderr, stderr)
_, _ = io.Copy(os.Stderr, stderr)
}()
stdout, err := c.StdoutPipe()
@ -350,7 +351,7 @@ func run(cmd string, args ...string) error {
return err
}
go func() {
io.Copy(os.Stdout, stdout)
_, _ = io.Copy(os.Stdout, stdout)
}()
return c.Run()
}
@ -531,7 +532,7 @@ func isFileThere(candidate string) bool {
}
// verifyTokenAtEdge checks for a token on disk, or generates a new one.
// Then makes a request to to the origin with the token to ensure it is valid.
// Then makes a request to the origin with the token to ensure it is valid.
// Returns nil if token is valid.
func verifyTokenAtEdge(appUrl *url.URL, appInfo *token.AppInfo, c *cli.Context, log *zerolog.Logger) error {
headers := parseRequestHeaders(c.StringSlice(sshHeaderFlag))

View File

@ -4,7 +4,7 @@ import (
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"github.com/cloudflare/cloudflared/logger"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
)
var (
@ -15,14 +15,14 @@ var (
func ConfigureLoggingFlags(shouldHide bool) []cli.Flag {
return []cli.Flag{
altsrc.NewStringFlag(&cli.StringFlag{
Name: logger.LogLevelFlag,
Name: cfdflags.LogLevel,
Value: "info",
Usage: "Application logging level {debug, info, warn, error, fatal}. " + debugLevelWarning,
EnvVars: []string{"TUNNEL_LOGLEVEL"},
Hidden: shouldHide,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: logger.LogTransportLevelFlag,
Name: cfdflags.TransportLogLevel,
Aliases: []string{"proto-loglevel"}, // This flag used to be called proto-loglevel
Value: "info",
Usage: "Transport logging level(previously called protocol logging level) {debug, info, warn, error, fatal}",
@ -30,19 +30,19 @@ func ConfigureLoggingFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: logger.LogFileFlag,
Name: cfdflags.LogFile,
Usage: "Save application log to this file for reporting issues.",
EnvVars: []string{"TUNNEL_LOGFILE"},
Hidden: shouldHide,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: logger.LogDirectoryFlag,
Name: cfdflags.LogDirectory,
Usage: "Save application log to this directory for reporting issues.",
EnvVars: []string{"TUNNEL_LOGDIRECTORY"},
Hidden: shouldHide,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "trace-output",
Name: cfdflags.TraceOutput,
Usage: "Name of trace output file, generated when cloudflared stops.",
EnvVars: []string{"TUNNEL_TRACE_OUTPUT"},
Hidden: shouldHide,

View File

@ -0,0 +1,155 @@
package flags
const (
// HaConnections specifies how many connections to make to the edge
HaConnections = "ha-connections"
// SshPort is the port on localhost the cloudflared ssh server will run on
SshPort = "local-ssh-port"
// SshIdleTimeout defines the duration a SSH session can remain idle before being closed
SshIdleTimeout = "ssh-idle-timeout"
// SshMaxTimeout defines the max duration a SSH session can remain open for
SshMaxTimeout = "ssh-max-timeout"
// SshLogUploaderBucketName is the bucket name to use for the SSH log uploader
SshLogUploaderBucketName = "bucket-name"
// SshLogUploaderRegionName is the AWS region name to use for the SSH log uploader
SshLogUploaderRegionName = "region-name"
// SshLogUploaderSecretID is the Secret id of SSH log uploader
SshLogUploaderSecretID = "secret-id"
// SshLogUploaderAccessKeyID is the Access key id of SSH log uploader
SshLogUploaderAccessKeyID = "access-key-id"
// SshLogUploaderSessionTokenID is the Session token of SSH log uploader
SshLogUploaderSessionTokenID = "session-token"
// SshLogUploaderS3URL is the S3 URL of SSH log uploader (e.g. don't use AWS s3 and use google storage bucket instead)
SshLogUploaderS3URL = "s3-url-host"
// HostKeyPath is the path of the dir to save SSH host keys too
HostKeyPath = "host-key-path"
// RpcTimeout is how long to wait for a Capnp RPC request to the edge
RpcTimeout = "rpc-timeout"
// WriteStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
WriteStreamTimeout = "write-stream-timeout"
// QuicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
QuicDisablePathMTUDiscovery = "quic-disable-pmtu-discovery"
// QuicConnLevelFlowControlLimit controls the max flow control limit allocated for a QUIC connection. This controls how much data is the
// receiver willing to buffer. Once the limit is reached, the sender will send a DATA_BLOCKED frame to indicate it has more data to write,
// but it's blocked by flow control
QuicConnLevelFlowControlLimit = "quic-connection-level-flow-control-limit"
// QuicStreamLevelFlowControlLimit is similar to quicConnLevelFlowControlLimit but for each QUIC stream. When the sender is blocked,
// it will send a STREAM_DATA_BLOCKED frame
QuicStreamLevelFlowControlLimit = "quic-stream-level-flow-control-limit"
// Ui is to enable launching cloudflared in interactive UI mode
Ui = "ui"
// ConnectorLabel is the command line flag to give a meaningful label to a specific connector
ConnectorLabel = "label"
// MaxActiveFlows is the command line flag to set the maximum number of flows that cloudflared can be processing at the same time
MaxActiveFlows = "max-active-flows"
// Tag is the command line flag to set custom tags used to identify this tunnel via added HTTP request headers to the origin
Tag = "tag"
// Protocol is the command line flag to set the protocol to use to connect to the Cloudflare Edge
Protocol = "protocol"
// PostQuantum is the command line flag to force the connection to Cloudflare Edge to use Post Quantum cryptography
PostQuantum = "post-quantum"
// Features is the command line flag to opt into various features that are still being developed or tested
Features = "features"
// EdgeIpVersion is the command line flag to set the Cloudflare Edge IP address version to connect with
EdgeIpVersion = "edge-ip-version"
// EdgeBindAddress is the command line flag to bind to IP address for outgoing connections to Cloudflare Edge
EdgeBindAddress = "edge-bind-address"
// Force is the command line flag to specify if you wish to force an action
Force = "force"
// Edge is the command line flag to set the address of the Cloudflare tunnel server. Only works in Cloudflare's internal testing environment
Edge = "edge"
// Region is the command line flag to set the Cloudflare Edge region to connect to
Region = "region"
// IsAutoUpdated is the command line flag to signal the new process that cloudflared has been autoupdated
IsAutoUpdated = "is-autoupdated"
// LBPool is the command line flag to set the name of the load balancing pool to add this origin to
LBPool = "lb-pool"
// Retries is the command line flag to set the maximum number of retries for connection/protocol errors
Retries = "retries"
// MaxEdgeAddrRetries is the command line flag to set the maximum number of times to retry on edge addrs before falling back to a lower protocol
MaxEdgeAddrRetries = "max-edge-addr-retries"
// GracePeriod is the command line flag to set the maximum amount of time that cloudflared waits to shut down if it is still serving requests
GracePeriod = "grace-period"
// ICMPV4Src is the command line flag to set the source address and the interface name to send/receive ICMPv4 messages
ICMPV4Src = "icmpv4-src"
// ICMPV6Src is the command line flag to set the source address and the interface name to send/receive ICMPv6 messages
ICMPV6Src = "icmpv6-src"
// ProxyDns is the command line flag to run DNS server over HTTPS
ProxyDns = "proxy-dns"
// Name is the command line to set the name of the tunnel
Name = "name"
// AutoUpdateFreq is the command line for setting the frequency that cloudflared checks for updates
AutoUpdateFreq = "autoupdate-freq"
// NoAutoUpdate is the command line flag to disable cloudflared from checking for updates
NoAutoUpdate = "no-autoupdate"
// LogLevel is the command line flag for the cloudflared logging level
LogLevel = "loglevel"
// LogLevelSSH is the command line flag for the cloudflared ssh logging level
LogLevelSSH = "log-level"
// TransportLogLevel is the command line flag for the transport logging level
TransportLogLevel = "transport-loglevel"
// LogFile is the command line flag to define the file where application logs will be stored
LogFile = "logfile"
// LogDirectory is the command line flag to define the directory where application logs will be stored.
LogDirectory = "log-directory"
// TraceOutput is the command line flag to set the name of trace output file
TraceOutput = "trace-output"
// OriginCert is the command line flag to define the path for the origin certificate used by cloudflared
OriginCert = "origincert"
// Metrics is the command line flag to define the address of the metrics server
Metrics = "metrics"
// MetricsUpdateFreq is the command line flag to define how frequently tunnel metrics are updated
MetricsUpdateFreq = "metrics-update-freq"
// ApiURL is the command line flag used to define the base URL of the API
ApiURL = "api-url"
)

View File

@ -3,11 +3,38 @@
package main
import (
"fmt"
"os"
cli "github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
)
func runApp(app *cli.App, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the cloudflared system service (not supported on this operating system)",
Subcommands: []*cli.Command{
{
Name: "install",
Usage: "Install cloudflared as a system service (not supported on this operating system)",
Action: cliutil.ConfiguredAction(installGenericService),
},
{
Name: "uninstall",
Usage: "Uninstall the cloudflared service (not supported on this operating system)",
Action: cliutil.ConfiguredAction(uninstallGenericService),
},
},
})
app.Run(os.Args)
}
func installGenericService(c *cli.Context) error {
return fmt.Errorf("service installation is not supported on this operating system")
}
func uninstallGenericService(c *cli.Context) error {
return fmt.Errorf("service uninstallation is not supported on this operating system")
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"os"
homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/urfave/cli/v2"
@ -17,7 +18,7 @@ const (
launchdIdentifier = "com.cloudflare.cloudflared"
)
func runApp(app *cli.App, graceShutdownC chan struct{}) {
func runApp(app *cli.App, _ chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the cloudflared launch agent",
@ -119,7 +120,7 @@ func installLaunchd(c *cli.Context) error {
log.Info().Msg("Installing cloudflared client as an user launch agent. " +
"Note that cloudflared client will only run when the user is logged in. " +
"If you want to run cloudflared client at boot, install with root permission. " +
"For more information, visit https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/run-as-service")
"For more information, visit https://developers.cloudflare.com/cloudflare-one/connections/connect-networks/configure-tunnels/local-management/as-a-service/macos/")
}
etPath, err := os.Executable()
if err != nil {
@ -207,3 +208,15 @@ func uninstallLaunchd(c *cli.Context) error {
}
return err
}
func userHomeDir() (string, error) {
// This returns the home dir of the executing user using OS-specific method
// for discovering the home dir. It's not recommended to call this function
// when the user has root permission as $HOME depends on what options the user
// use with sudo.
homeDir, err := homedir.Dir()
if err != nil {
return "", errors.Wrap(err, "Cannot determine home directory for the user")
}
return homeDir, nil
}

View File

@ -2,19 +2,17 @@ package main
import (
"fmt"
"math/rand"
"os"
"strings"
"time"
"github.com/getsentry/sentry-go"
homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/urfave/cli/v2"
"go.uber.org/automaxprocs/maxprocs"
"github.com/cloudflare/cloudflared/cmd/cloudflared/access"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/cmd/cloudflared/proxydns"
"github.com/cloudflare/cloudflared/cmd/cloudflared/tail"
"github.com/cloudflare/cloudflared/cmd/cloudflared/tunnel"
@ -52,10 +50,8 @@ var (
func main() {
// FIXME: TUN-8148: Disable QUIC_GO ECN due to bugs in proper detection if supported
os.Setenv("QUIC_GO_DISABLE_ECN", "1")
rand.Seed(time.Now().UnixNano())
metrics.RegisterBuildInfo(BuildType, BuildTime, Version)
maxprocs.Set()
_, _ = maxprocs.Set()
bInfo := cliutil.GetBuildInfo(BuildType, Version)
// Graceful shutdown channel used by the app. When closed, app must terminate gracefully.
@ -110,7 +106,7 @@ func commands(version func(c *cli.Context)) []*cli.Command {
Usage: "specify if you wish to update to the latest beta version",
},
&cli.BoolFlag{
Name: "force",
Name: cfdflags.Force,
Usage: "specify if you wish to force an upgrade to the latest version regardless of the current version",
Hidden: true,
},
@ -184,18 +180,6 @@ func action(graceShutdownC chan struct{}) cli.ActionFunc {
})
}
func userHomeDir() (string, error) {
// This returns the home dir of the executing user using OS-specific method
// for discovering the home dir. It's not recommended to call this function
// when the user has root permission as $HOME depends on what options the user
// use with sudo.
homeDir, err := homedir.Dir()
if err != nil {
return "", errors.Wrap(err, "Cannot determine home directory for the user")
}
return homeDir, nil
}
// In order to keep the amount of noise sent to Sentry low, typical network errors can be filtered out here by a substring match.
func captureError(err error) {
errorMessage := err.Error()

View File

@ -1,13 +1,13 @@
package main
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path"
"path/filepath"
"text/template"
homedir "github.com/mitchellh/go-homedir"
@ -44,7 +44,7 @@ func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
return err
}
if _, err = os.Stat(resolvedPath); err == nil {
return fmt.Errorf(serviceAlreadyExistsWarn(resolvedPath))
return errors.New(serviceAlreadyExistsWarn(resolvedPath))
}
var buffer bytes.Buffer
@ -57,7 +57,7 @@ func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
fileMode = st.FileMode
}
plistFolder := path.Dir(resolvedPath)
plistFolder := filepath.Dir(resolvedPath)
err = os.MkdirAll(plistFolder, 0o755)
if err != nil {
return fmt.Errorf("error creating %s: %v", plistFolder, err)
@ -118,49 +118,6 @@ func ensureConfigDirExists(configDir string) error {
return err
}
// openFile opens the file at path. If create is set and the file exists, returns nil, true, nil
func openFile(path string, create bool) (file *os.File, exists bool, err error) {
expandedPath, err := homedir.Expand(path)
if err != nil {
return nil, false, err
}
if create {
fileInfo, err := os.Stat(expandedPath)
if err == nil && fileInfo.Size() > 0 {
return nil, true, nil
}
file, err = os.OpenFile(expandedPath, os.O_RDWR|os.O_CREATE, 0600)
} else {
file, err = os.Open(expandedPath)
}
return file, false, err
}
func copyCredential(srcCredentialPath, destCredentialPath string) error {
destFile, exists, err := openFile(destCredentialPath, true)
if err != nil {
return err
} else if exists {
// credentials already exist, do nothing
return nil
}
defer destFile.Close()
srcFile, _, err := openFile(srcCredentialPath, false)
if err != nil {
return err
}
defer srcFile.Close()
// Copy certificate
_, err = io.Copy(destFile, srcFile)
if err != nil {
return fmt.Errorf("unable to copy %s to %s: %v", srcCredentialPath, destCredentialPath, err)
}
return nil
}
func copyFile(src, dest string) error {
srcFile, err := os.Open(src)
if err != nil {
@ -187,36 +144,3 @@ func copyFile(src, dest string) error {
ok = true
return nil
}
func copyConfig(srcConfigPath, destConfigPath string) error {
// Copy or create config
destFile, exists, err := openFile(destConfigPath, true)
if err != nil {
return fmt.Errorf("cannot open %s with error: %s", destConfigPath, err)
} else if exists {
// config already exists, do nothing
return nil
}
defer destFile.Close()
srcFile, _, err := openFile(srcConfigPath, false)
if err != nil {
fmt.Println("Your service needs a config file that at least specifies the hostname option.")
fmt.Println("Type in a hostname now, or leave it blank and create the config file later.")
fmt.Print("Hostname: ")
reader := bufio.NewReader(os.Stdin)
input, _ := reader.ReadString('\n')
if input == "" {
return err
}
fmt.Fprintf(destFile, "hostname: %s\n", input)
} else {
defer srcFile.Close()
_, err = io.Copy(destFile, srcFile)
if err != nil {
return fmt.Errorf("unable to copy %s to %s: %v", srcConfigPath, destConfigPath, err)
}
}
return nil
}

View File

@ -18,14 +18,12 @@ import (
"nhooyr.io/websocket"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/management"
)
var (
buildInfo *cliutil.BuildInfo
)
var buildInfo *cliutil.BuildInfo
func Init(bi *cliutil.BuildInfo) {
buildInfo = bi
@ -56,7 +54,7 @@ func managementTokenCommand(c *cli.Context) error {
if err != nil {
return err
}
var tokenResponse = struct {
tokenResponse := struct {
Token string `json:"token"`
}{Token: token}
@ -119,13 +117,13 @@ func buildTailCommand(subcommands []*cli.Command) *cli.Command {
Value: "",
},
&cli.StringFlag{
Name: logger.LogLevelFlag,
Name: cfdflags.LogLevel,
Value: "info",
Usage: "Application logging level {debug, info, warn, error, fatal}",
EnvVars: []string{"TUNNEL_LOGLEVEL"},
},
&cli.StringFlag{
Name: credentials.OriginCertFlag,
Name: cfdflags.OriginCert,
Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: credentials.FindDefaultOriginCertPath(),
@ -169,7 +167,7 @@ func handleValidationError(resp *http.Response, log *zerolog.Logger) {
// logger will be created to emit only against the os.Stderr as to not obstruct with normal output from
// management requests
func createLogger(c *cli.Context) *zerolog.Logger {
level, levelErr := zerolog.ParseLevel(c.String(logger.LogLevelFlag))
level, levelErr := zerolog.ParseLevel(c.String(cfdflags.LogLevel))
if levelErr != nil {
level = zerolog.InfoLevel
}
@ -183,9 +181,10 @@ func createLogger(c *cli.Context) *zerolog.Logger {
// parseFilters will attempt to parse provided filters to send to with the EventStartStreaming
func parseFilters(c *cli.Context) (*management.StreamingFilters, error) {
var level *management.LogLevel
var events []management.LogEventType
var sample float64
events := make([]management.LogEventType, 0)
argLevel := c.String("level")
argEvents := c.StringSlice("event")
argSample := c.Float64("sample")
@ -225,12 +224,12 @@ func parseFilters(c *cli.Context) (*management.StreamingFilters, error) {
// getManagementToken will make a call to the Cloudflare API to acquire a management token for the requested tunnel.
func getManagementToken(c *cli.Context, log *zerolog.Logger) (string, error) {
userCreds, err := credentials.Read(c.String(credentials.OriginCertFlag), log)
userCreds, err := credentials.Read(c.String(cfdflags.OriginCert), log)
if err != nil {
return "", err
}
client, err := userCreds.Client(c.String("api-url"), buildInfo.UserAgent(), log)
client, err := userCreds.Client(c.String(cfdflags.ApiURL), buildInfo.UserAgent(), log)
if err != nil {
return "", err
}
@ -331,6 +330,7 @@ func Run(c *cli.Context) error {
header["cf-trace-id"] = []string{trace}
}
ctx := c.Context
// nolint: bodyclose
conn, resp, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
HTTPHeader: header,
})

View File

@ -16,7 +16,7 @@ import (
"github.com/facebookgo/grace/gracenet"
"github.com/getsentry/sentry-go"
"github.com/google/uuid"
homedir "github.com/mitchellh/go-homedir"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/urfave/cli/v2"
@ -24,6 +24,7 @@ import (
"github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/cmd/cloudflared/proxydns"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/config"
@ -31,7 +32,6 @@ import (
"github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/diagnostic"
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/features"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/management"
@ -48,61 +48,6 @@ import (
const (
sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
// ha-Connections specifies how many connections to make to the edge
haConnectionsFlag = "ha-connections"
// sshPortFlag is the port on localhost the cloudflared ssh server will run on
sshPortFlag = "local-ssh-port"
// sshIdleTimeoutFlag defines the duration a SSH session can remain idle before being closed
sshIdleTimeoutFlag = "ssh-idle-timeout"
// sshMaxTimeoutFlag defines the max duration a SSH session can remain open for
sshMaxTimeoutFlag = "ssh-max-timeout"
// bucketNameFlag is the bucket name to use for the SSH log uploader
bucketNameFlag = "bucket-name"
// regionNameFlag is the AWS region name to use for the SSH log uploader
regionNameFlag = "region-name"
// secretIDFlag is the Secret id of SSH log uploader
secretIDFlag = "secret-id"
// accessKeyIDFlag is the Access key id of SSH log uploader
accessKeyIDFlag = "access-key-id"
// sessionTokenIDFlag is the Session token of SSH log uploader
sessionTokenIDFlag = "session-token"
// s3URLFlag is the S3 URL of SSH log uploader (e.g. don't use AWS s3 and use google storage bucket instead)
s3URLFlag = "s3-url-host"
// hostKeyPath is the path of the dir to save SSH host keys too
hostKeyPath = "host-key-path"
// rpcTimeout is how long to wait for a Capnp RPC request to the edge
rpcTimeout = "rpc-timeout"
// writeStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
writeStreamTimeout = "write-stream-timeout"
// quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
quicDisablePathMTUDiscovery = "quic-disable-pmtu-discovery"
// quicConnLevelFlowControlLimit controls the max flow control limit allocated for a QUIC connection. This controls how much data is the
// receiver willing to buffer. Once the limit is reached, the sender will send a DATA_BLOCKED frame to indicate it has more data to write,
// but it's blocked by flow control
quicConnLevelFlowControlLimit = "quic-connection-level-flow-control-limit"
// quicStreamLevelFlowControlLimit is similar to quicConnLevelFlowControlLimit but for each QUIC stream. When the sender is blocked,
// it will send a STREAM_DATA_BLOCKED frame
quicStreamLevelFlowControlLimit = "quic-stream-level-flow-control-limit"
// uiFlag is to enable launching cloudflared in interactive UI mode
uiFlag = "ui"
LogFieldCommand = "command"
LogFieldExpandedPath = "expandedPath"
LogFieldPIDPathname = "pidPathname"
@ -117,7 +62,6 @@ Eg. cloudflared tunnel --url localhost:8080/.
Please note that Quick Tunnels are meant to be ephemeral and should only be used for testing purposes.
For production usage, we recommend creating Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)
`
connectorLabelFlag = "label"
)
var (
@ -127,14 +71,14 @@ var (
routeFailMsg = fmt.Sprintf("failed to provision routing, please create it manually via Cloudflare dashboard or UI; "+
"most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+
"any existing DNS records for this hostname.", overwriteDNSFlag)
deprecatedClassicTunnelErr = fmt.Errorf("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)")
errDeprecatedClassicTunnel = errors.New("Classic tunnels have been deprecated, please use Named Tunnels. (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/tunnel-guide/)")
// TODO: TUN-8756 the list below denotes the flags that do not possess any kind of sensitive information
// however this approach is not maintainble in the long-term.
nonSecretFlagsList = []string{
"config",
"autoupdate-freq",
"no-autoupdate",
"metrics",
cfdflags.AutoUpdateFreq,
cfdflags.NoAutoUpdate,
cfdflags.Metrics,
"pidfile",
"url",
"hello-world",
@ -167,54 +111,55 @@ var (
"bastion",
"proxy-address",
"proxy-port",
"loglevel",
"transport-loglevel",
"logfile",
"log-directory",
"trace-output",
"proxy-dns",
cfdflags.LogLevel,
cfdflags.TransportLogLevel,
cfdflags.LogFile,
cfdflags.LogDirectory,
cfdflags.TraceOutput,
cfdflags.ProxyDns,
"proxy-dns-port",
"proxy-dns-address",
"proxy-dns-upstream",
"proxy-dns-max-upstream-conns",
"proxy-dns-bootstrap",
"is-autoupdated",
"edge",
"region",
"edge-ip-version",
"edge-bind-address",
cfdflags.IsAutoUpdated,
cfdflags.Edge,
cfdflags.Region,
cfdflags.EdgeIpVersion,
cfdflags.EdgeBindAddress,
"cacert",
"hostname",
"id",
"lb-pool",
"api-url",
"metrics-update-freq",
"tag",
cfdflags.LBPool,
cfdflags.ApiURL,
cfdflags.MetricsUpdateFreq,
cfdflags.Tag,
"heartbeat-interval",
"heartbeat-count",
"max-edge-addr-retries",
"retries",
cfdflags.MaxEdgeAddrRetries,
cfdflags.Retries,
"ha-connections",
"rpc-timeout",
"write-stream-timeout",
"quic-disable-pmtu-discovery",
"quic-connection-level-flow-control-limit",
"quic-stream-level-flow-control-limit",
"label",
"grace-period",
cfdflags.ConnectorLabel,
cfdflags.GracePeriod,
"compression-quality",
"use-reconnect-token",
"dial-edge-timeout",
"stdin-control",
"name",
"ui",
cfdflags.Name,
cfdflags.Ui,
"quick-service",
"max-fetch-size",
"post-quantum",
cfdflags.PostQuantum,
"management-diagnostics",
"protocol",
cfdflags.Protocol,
"overwrite-dns",
"help",
cfdflags.MaxActiveFlows,
}
)
@ -263,7 +208,7 @@ then protect with Cloudflare Access).
B) Locally reachable TCP/UDP-based private services to Cloudflare connected private users in the same account, e.g.,
those enrolled to a Zero Trust WARP Client.
You can manage your Tunnels via dash.teams.cloudflare.com. This approach will only require you to run a single command
You can manage your Tunnels via one.dash.cloudflare.com. This approach will only require you to run a single command
later in each machine where you wish to run a Tunnel.
Alternatively, you can manage your Tunnels via the command line. Begin by obtaining a certificate to be able to do so:
@ -299,7 +244,7 @@ func TunnelCommand(c *cli.Context) error {
// --name required
// --url or --hello-world required
// --hostname optional
if name := c.String("name"); name != "" {
if name := c.String(cfdflags.Name); name != "" {
hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil {
return errors.Wrap(err, "Invalid hostname provided")
@ -316,7 +261,7 @@ func TunnelCommand(c *cli.Context) error {
// A unauthenticated named tunnel hosted on <random>.<quick-tunnels-service>.com
// We don't support running proxy-dns and a quick tunnel at the same time as the same process
shouldRunQuickTunnel := c.IsSet("url") || c.IsSet(ingress.HelloWorldFlag)
if !c.IsSet("proxy-dns") && c.String("quick-service") != "" && shouldRunQuickTunnel {
if !c.IsSet(cfdflags.ProxyDns) && c.String("quick-service") != "" && shouldRunQuickTunnel {
return RunQuickTunnel(sc)
}
@ -327,10 +272,10 @@ func TunnelCommand(c *cli.Context) error {
// Classic tunnel usage is no longer supported
if c.String("hostname") != "" {
return deprecatedClassicTunnelErr
return errDeprecatedClassicTunnel
}
if c.IsSet("proxy-dns") {
if c.IsSet(cfdflags.ProxyDns) {
if shouldRunQuickTunnel {
return fmt.Errorf("running a quick tunnel with `proxy-dns` is not supported")
}
@ -377,7 +322,7 @@ func runAdhocNamedTunnel(sc *subcommandContext, name, credentialsOutputPath stri
func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) {
if hostname := c.String("hostname"); hostname != "" {
if lbPool := c.String("lb-pool"); lbPool != "" {
if lbPool := c.String(cfdflags.LBPool); lbPool != "" {
return cfapi.NewLBRoute(hostname, lbPool), true
}
return cfapi.NewDNSRoute(hostname, c.Bool(overwriteDNSFlagName)), true
@ -407,7 +352,7 @@ func StartServer(
log.Info().Msg(config.ErrNoConfigFile.Error())
}
if c.IsSet("trace-output") {
if c.IsSet(cfdflags.TraceOutput) {
tmpTraceFile, err := os.CreateTemp("", "trace")
if err != nil {
log.Err(err).Msg("Failed to create new temporary file to save trace output")
@ -419,7 +364,7 @@ func StartServer(
if err := tmpTraceFile.Close(); err != nil {
traceLog.Err(err).Msg("Failed to close temporary trace output file")
}
traceOutputFilepath := c.String("trace-output")
traceOutputFilepath := c.String(cfdflags.TraceOutput)
if err := os.Rename(tmpTraceFile.Name(), traceOutputFilepath); err != nil {
traceLog.
Err(err).
@ -449,7 +394,7 @@ func StartServer(
go waitForSignal(graceShutdownC, log)
if c.IsSet("proxy-dns") {
if c.IsSet(cfdflags.ProxyDns) {
dnsReadySignal := make(chan struct{})
wg.Add(1)
go func() {
@ -471,7 +416,7 @@ func StartServer(
go func() {
defer wg.Done()
autoupdater := updater.NewAutoUpdater(
c.Bool("no-autoupdate"), c.Duration("autoupdate-freq"), &listeners, log,
c.Bool(cfdflags.NoAutoUpdate), c.Duration(cfdflags.AutoUpdateFreq), &listeners, log,
)
errC <- autoupdater.Run(ctx)
}()
@ -515,8 +460,6 @@ func StartServer(
tunnelConfig.ICMPRouterServer = nil
}
internalRules := []ingress.Rule{}
if features.Contains(features.FeatureManagementLogs) {
serviceIP := c.String("service-op-ip")
if edgeAddrs, err := edgediscovery.ResolveEdge(log, tunnelConfig.Region, tunnelConfig.EdgeIPVersion); err == nil {
if serviceAddr, err := edgeAddrs.GetAddrForRPC(); err == nil {
@ -529,12 +472,11 @@ func StartServer(
c.Bool("management-diagnostics"),
serviceIP,
clientID,
c.String(connectorLabelFlag),
c.String(cfdflags.ConnectorLabel),
logger.ManagementLogger.Log,
logger.ManagementLogger,
)
internalRules = []ingress.Rule{ingress.NewManagementRule(mgmt)}
}
internalRules := []ingress.Rule{ingress.NewManagementRule(mgmt)}
orchestrator, err := orchestration.NewOrchestrator(ctx, orchestratorConfig, tunnelConfig.Tags, internalRules, tunnelConfig.Log)
if err != nil {
return err
@ -582,7 +524,7 @@ func StartServer(
errC <- metrics.ServeMetrics(metricsListener, ctx, metricsConfig, log)
}()
reconnectCh := make(chan supervisor.ReconnectSignal, c.Int(haConnectionsFlag))
reconnectCh := make(chan supervisor.ReconnectSignal, c.Int(cfdflags.HaConnections))
if c.IsSet("stdin-control") {
log.Info().Msg("Enabling control through stdin")
go stdinControl(reconnectCh, log)
@ -619,8 +561,10 @@ func waitToShutdown(wg *sync.WaitGroup,
log.Debug().Msg("Graceful shutdown signalled")
if gracePeriod > 0 {
// wait for either grace period or service termination
ticker := time.NewTicker(gracePeriod)
defer ticker.Stop()
select {
case <-time.Tick(gracePeriod):
case <-ticker.C:
case <-errC:
}
}
@ -648,7 +592,7 @@ func waitToShutdown(wg *sync.WaitGroup,
func notifySystemd(waitForSignal *signal.Signal) {
<-waitForSignal.Wait()
daemon.SdNotify(false, "READY=1")
_, _ = daemon.SdNotify(false, "READY=1")
}
func writePidFile(waitForSignal *signal.Signal, pidPathname string, log *zerolog.Logger) {
@ -700,31 +644,31 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
flags = append(flags, []cli.Flag{
credentialsFileFlag,
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "is-autoupdated",
Name: cfdflags.IsAutoUpdated,
Usage: "Signal the new process that Cloudflare Tunnel connector has been autoupdated",
Value: false,
Hidden: true,
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "edge",
Name: cfdflags.Edge,
Usage: "Address of the Cloudflare tunnel server. Only works in Cloudflare's internal testing environment.",
EnvVars: []string{"TUNNEL_EDGE"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "region",
Name: cfdflags.Region,
Usage: "Cloudflare Edge region to connect to. Omit or set to empty to connect to the global region.",
EnvVars: []string{"TUNNEL_REGION"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "edge-ip-version",
Name: cfdflags.EdgeIpVersion,
Usage: "Cloudflare Edge IP address version to connect with. {4, 6, auto}",
EnvVars: []string{"TUNNEL_EDGE_IP_VERSION"},
Value: "4",
Hidden: false,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "edge-bind-address",
Name: cfdflags.EdgeBindAddress,
Usage: "Bind to IP address for outgoing connections to Cloudflare Edge.",
EnvVars: []string{"TUNNEL_EDGE_BIND_ADDRESS"},
Hidden: false,
@ -748,7 +692,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "lb-pool",
Name: cfdflags.LBPool,
Usage: "The name of a (new/existing) load balancing pool to add this origin to.",
EnvVars: []string{"TUNNEL_LB_POOL"},
Hidden: shouldHide,
@ -772,21 +716,21 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "api-url",
Name: cfdflags.ApiURL,
Usage: "Base URL for Cloudflare API v4",
EnvVars: []string{"TUNNEL_API_URL"},
Value: "https://api.cloudflare.com/client/v4",
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "metrics-update-freq",
Name: cfdflags.MetricsUpdateFreq,
Usage: "Frequency to update tunnel metrics",
Value: time.Second * 5,
EnvVars: []string{"TUNNEL_METRICS_UPDATE_FREQ"},
Hidden: shouldHide,
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "tag",
Name: cfdflags.Tag,
Usage: "Custom tags used to identify this tunnel via added HTTP request headers to the origin, in format `KEY=VALUE`. Multiple tags may be specified.",
EnvVars: []string{"TUNNEL_TAG"},
Hidden: true,
@ -805,64 +749,64 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: true,
}),
altsrc.NewIntFlag(&cli.IntFlag{
Name: "max-edge-addr-retries",
Name: cfdflags.MaxEdgeAddrRetries,
Usage: "Maximum number of times to retry on edge addrs before falling back to a lower protocol",
Value: 8,
Hidden: true,
}),
// Note TUN-3758 , we use Int because UInt is not supported with altsrc
altsrc.NewIntFlag(&cli.IntFlag{
Name: "retries",
Name: cfdflags.Retries,
Value: 5,
Usage: "Maximum number of retries for connection/protocol errors.",
EnvVars: []string{"TUNNEL_RETRIES"},
Hidden: shouldHide,
}),
altsrc.NewIntFlag(&cli.IntFlag{
Name: haConnectionsFlag,
Name: cfdflags.HaConnections,
Value: 4,
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: rpcTimeout,
Name: cfdflags.RpcTimeout,
Value: 5 * time.Second,
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: writeStreamTimeout,
Name: cfdflags.WriteStreamTimeout,
EnvVars: []string{"TUNNEL_STREAM_WRITE_TIMEOUT"},
Usage: "Use this option to add a stream write timeout for connections when writing towards the origin or edge. Default is 0 which disables the write timeout.",
Value: 0 * time.Second,
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: quicDisablePathMTUDiscovery,
Name: cfdflags.QuicDisablePathMTUDiscovery,
EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"},
Usage: "Use this option to disable PTMU discovery for QUIC connections. This will result in lower packet sizes. Not however, that this may cause instability for UDP proxying.",
Value: false,
Hidden: true,
}),
altsrc.NewIntFlag(&cli.IntFlag{
Name: quicConnLevelFlowControlLimit,
Name: cfdflags.QuicConnLevelFlowControlLimit,
EnvVars: []string{"TUNNEL_QUIC_CONN_LEVEL_FLOW_CONTROL_LIMIT"},
Usage: "Use this option to change the connection-level flow control limit for QUIC transport.",
Value: 30 * (1 << 20), // 30 MB
Hidden: true,
}),
altsrc.NewIntFlag(&cli.IntFlag{
Name: quicStreamLevelFlowControlLimit,
Name: cfdflags.QuicStreamLevelFlowControlLimit,
EnvVars: []string{"TUNNEL_QUIC_STREAM_LEVEL_FLOW_CONTROL_LIMIT"},
Usage: "Use this option to change the connection-level flow control limit for QUIC transport.",
Value: 6 * (1 << 20), // 6 MB
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: connectorLabelFlag,
Name: cfdflags.ConnectorLabel,
Usage: "Use this option to give a meaningful label to a specific connector. When a tunnel starts up, a connector id unique to the tunnel is generated. This is a uuid. To make it easier to identify a connector, we will use the hostname of the machine the tunnel is running on along with the connector ID. This option exists if one wants to have more control over what their individual connectors are called.",
Value: "",
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "grace-period",
Name: cfdflags.GracePeriod,
Usage: "When cloudflared receives SIGINT/SIGTERM it will stop accepting new requests, wait for in-progress requests to terminate, then shutdown. Waiting for in-progress requests will timeout after this grace period, or when a second SIGTERM/SIGINT is received.",
Value: time.Second * 30,
EnvVars: []string{"TUNNEL_GRACE_PERIOD"},
@ -898,14 +842,14 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Value: false,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "name",
Name: cfdflags.Name,
Aliases: []string{"n"},
EnvVars: []string{"TUNNEL_NAME"},
Usage: "Stable name to identify the tunnel. Using this flag will create, route and run a tunnel. For production usage, execute each command separately",
Hidden: shouldHide,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: uiFlag,
Name: cfdflags.Ui,
Usage: "(depreciated) Launch tunnel UI. Tunnel logs are scrollable via 'j', 'k', or arrow keys.",
Value: false,
Hidden: true,
@ -923,11 +867,10 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "post-quantum",
Name: cfdflags.PostQuantum,
Usage: "When given creates an experimental post-quantum secure tunnel",
Aliases: []string{"pq"},
EnvVars: []string{"TUNNEL_POST_QUANTUM"},
Hidden: FipsEnabled,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "management-diagnostics",
@ -952,27 +895,27 @@ func configureCloudflaredFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide,
},
altsrc.NewStringFlag(&cli.StringFlag{
Name: credentials.OriginCertFlag,
Name: cfdflags.OriginCert,
Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: credentials.FindDefaultOriginCertPath(),
Hidden: shouldHide,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "autoupdate-freq",
Name: cfdflags.AutoUpdateFreq,
Usage: fmt.Sprintf("Autoupdate frequency. Default is %v.", updater.DefaultCheckUpdateFreq),
Value: updater.DefaultCheckUpdateFreq,
Hidden: shouldHide,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "no-autoupdate",
Name: cfdflags.NoAutoUpdate,
Usage: "Disable periodic check for updates, restarting the server with the new version.",
EnvVars: []string{"NO_AUTOUPDATE"},
Value: false,
Hidden: shouldHide,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "metrics",
Name: cfdflags.Metrics,
Value: metrics.GetMetricsDefaultAddress(metrics.Runtime),
Usage: fmt.Sprintf(
`Listen address for metrics reporting. If no address is passed cloudflared will try to bind to %v.
@ -1136,62 +1079,62 @@ func legacyTunnelFlag(msg string) string {
func sshFlags(shouldHide bool) []cli.Flag {
return []cli.Flag{
altsrc.NewStringFlag(&cli.StringFlag{
Name: sshPortFlag,
Name: cfdflags.SshPort,
Usage: "Localhost port that cloudflared SSH server will run on",
Value: "2222",
EnvVars: []string{"LOCAL_SSH_PORT"},
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: sshIdleTimeoutFlag,
Name: cfdflags.SshIdleTimeout,
Usage: "Connection timeout after no activity",
EnvVars: []string{"SSH_IDLE_TIMEOUT"},
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: sshMaxTimeoutFlag,
Name: cfdflags.SshMaxTimeout,
Usage: "Absolute connection timeout",
EnvVars: []string{"SSH_MAX_TIMEOUT"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: bucketNameFlag,
Name: cfdflags.SshLogUploaderBucketName,
Usage: "Bucket name of where to upload SSH logs",
EnvVars: []string{"BUCKET_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: regionNameFlag,
Name: cfdflags.SshLogUploaderRegionName,
Usage: "Region name of where to upload SSH logs",
EnvVars: []string{"REGION_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: secretIDFlag,
Name: cfdflags.SshLogUploaderSecretID,
Usage: "Secret ID of where to upload SSH logs",
EnvVars: []string{"SECRET_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: accessKeyIDFlag,
Name: cfdflags.SshLogUploaderAccessKeyID,
Usage: "Access Key ID of where to upload SSH logs",
EnvVars: []string{"ACCESS_CLIENT_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: sessionTokenIDFlag,
Name: cfdflags.SshLogUploaderSessionTokenID,
Usage: "Session Token to use in the configuration of SSH logs uploading",
EnvVars: []string{"SESSION_TOKEN_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: s3URLFlag,
Name: cfdflags.SshLogUploaderS3URL,
Usage: "S3 url of where to upload SSH logs",
EnvVars: []string{"S3_URL"},
Hidden: true,
}),
altsrc.NewPathFlag(&cli.PathFlag{
Name: hostKeyPath,
Name: cfdflags.HostKeyPath,
Usage: "Absolute path of directory to save SSH host keys in",
EnvVars: []string{"HOST_KEY_PATH"},
Hidden: true,
@ -1231,7 +1174,7 @@ func sshFlags(shouldHide bool) []cli.Flag {
func configureProxyDNSFlags(shouldHide bool) []cli.Flag {
return []cli.Flag{
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "proxy-dns",
Name: cfdflags.ProxyDns,
Usage: "Run a DNS over HTTPS proxy server.",
EnvVars: []string{"TUNNEL_DNS"},
Hidden: shouldHide,
@ -1329,7 +1272,7 @@ func nonSecretCliFlags(log *zerolog.Logger, cli *cli.Context, flagInclusionList
}
switch flag {
case logger.LogDirectoryFlag, logger.LogFileFlag:
case cfdflags.LogDirectory, cfdflags.LogFile:
{
absolute, err := filepath.Abs(value)
if err != nil {

View File

@ -18,6 +18,7 @@ import (
"golang.org/x/term"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery"
@ -33,26 +34,27 @@ import (
const (
secretValue = "*****"
icmpFunnelTimeout = time.Second * 10
fedRampRegion = "fed" // const string denoting the region used to connect to FEDRamp servers
)
var (
developerPortal = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup"
serviceUrl = developerPortal + "/tunnel-guide/local/as-a-service/"
argumentsUrl = developerPortal + "/tunnel-guide/local/local-management/arguments/"
secretFlags = [2]*altsrc.StringFlag{credentialsContentsFlag, tunnelTokenFlag}
configFlags = []string{"autoupdate-freq", "no-autoupdate", "retries", "protocol", "loglevel", "transport-loglevel", "origincert", "metrics", "metrics-update-freq", "edge-ip-version", "edge-bind-address"}
)
func generateRandomClientID(log *zerolog.Logger) (string, error) {
u, err := uuid.NewRandom()
if err != nil {
log.Error().Msgf("couldn't create UUID for client ID %s", err)
return "", err
configFlags = []string{
flags.AutoUpdateFreq,
flags.NoAutoUpdate,
flags.Retries,
flags.Protocol,
flags.LogLevel,
flags.TransportLogLevel,
flags.OriginCert,
flags.Metrics,
flags.MetricsUpdateFreq,
flags.EdgeIpVersion,
flags.EdgeBindAddress,
flags.MaxActiveFlows,
}
return u.String(), nil
}
)
func logClientOptions(c *cli.Context, log *zerolog.Logger) {
flags := make(map[string]interface{})
@ -109,8 +111,8 @@ func isSecretEnvVar(key string) bool {
}
func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.TunnelProperties) bool {
return c.IsSet("proxy-dns") &&
!(c.IsSet("name") || // adhoc-named tunnel
return c.IsSet(flags.ProxyDns) &&
!(c.IsSet(flags.Name) || // adhoc-named tunnel
c.IsSet(ingress.HelloWorldFlag) || // quick or named tunnel
namedTunnel != nil) // named tunnel
}
@ -128,29 +130,21 @@ func prepareTunnelConfig(
return nil, nil, errors.Wrap(err, "can't generate connector UUID")
}
log.Info().Msgf("Generated Connector ID: %s", clientID)
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
tags, err := NewTagSliceFromCLI(c.StringSlice(flags.Tag))
if err != nil {
log.Err(err).Msg("Tag parse failure")
return nil, nil, errors.Wrap(err, "Tag parse failure")
}
tags = append(tags, pogs.Tag{Name: "ID", Value: clientID.String()})
transportProtocol := c.String("protocol")
transportProtocol := c.String(flags.Protocol)
isPostQuantumEnforced := c.Bool(flags.PostQuantum)
clientFeatures := features.Dedup(append(c.StringSlice("features"), features.DefaultFeatures...))
staticFeatures := features.StaticFeatures{}
if c.Bool("post-quantum") {
if FipsEnabled {
return nil, nil, fmt.Errorf("post-quantum not supported in FIPS mode")
}
pqMode := features.PostQuantumStrict
staticFeatures.PostQuantumMode = &pqMode
}
featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, staticFeatures, log)
featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, c.StringSlice("features"), c.Bool("post-quantum"), log)
if err != nil {
return nil, nil, errors.Wrap(err, "Failed to create feature selector")
}
clientFeatures := featureSelector.ClientFeatures()
pqMode := featureSelector.PostQuantumMode()
if pqMode == features.PostQuantumStrict {
// Error if the user tries to force a non-quic transport protocol
@ -158,12 +152,6 @@ func prepareTunnelConfig(
return nil, nil, fmt.Errorf("post-quantum is only supported with the quic transport")
}
transportProtocol = connection.QUIC.String()
clientFeatures = append(clientFeatures, features.FeaturePostQuantum)
log.Info().Msgf(
"Using hybrid post-quantum key agreement %s",
supervisor.PQKexName,
)
}
namedTunnel.Client = pogs.ClientInfo{
@ -178,7 +166,7 @@ func prepareTunnelConfig(
return nil, nil, err
}
protocolSelector, err := connection.NewProtocolSelector(transportProtocol, namedTunnel.Credentials.AccountTag, c.IsSet(TunnelTokenFlag), c.Bool("post-quantum"), edgediscovery.ProtocolPercentage, connection.ResolveTTL, log)
protocolSelector, err := connection.NewProtocolSelector(transportProtocol, namedTunnel.Credentials.AccountTag, c.IsSet(TunnelTokenFlag), isPostQuantumEnforced, edgediscovery.ProtocolPercentage, connection.ResolveTTL, log)
if err != nil {
return nil, nil, err
}
@ -204,11 +192,11 @@ func prepareTunnelConfig(
if err != nil {
return nil, nil, err
}
edgeIPVersion, err := parseConfigIPVersion(c.String("edge-ip-version"))
edgeIPVersion, err := parseConfigIPVersion(c.String(flags.EdgeIpVersion))
if err != nil {
return nil, nil, err
}
edgeBindAddr, err := parseConfigBindAddress(c.String("edge-bind-address"))
edgeBindAddr, err := parseConfigBindAddress(c.String(flags.EdgeBindAddress))
if err != nil {
return nil, nil, err
}
@ -221,36 +209,50 @@ func prepareTunnelConfig(
log.Warn().Str("edgeIPVersion", edgeIPVersion.String()).Err(err).Msg("Overriding edge-ip-version")
}
region := c.String(flags.Region)
endpoint := namedTunnel.Credentials.Endpoint
var resolvedRegion string
// set resolvedRegion to either the region passed as argument
// or to the endpoint in the credentials.
// Region and endpoint are interchangeable
if region != "" && endpoint != "" {
return nil, nil, fmt.Errorf("region provided with a token that has an endpoint")
} else if region != "" {
resolvedRegion = region
} else if endpoint != "" {
resolvedRegion = endpoint
}
tunnelConfig := &supervisor.TunnelConfig{
GracePeriod: gracePeriod,
ReplaceExisting: c.Bool("force"),
ReplaceExisting: c.Bool(flags.Force),
OSArch: info.OSArch(),
ClientID: clientID.String(),
EdgeAddrs: c.StringSlice("edge"),
Region: c.String("region"),
EdgeAddrs: c.StringSlice(flags.Edge),
Region: resolvedRegion,
EdgeIPVersion: edgeIPVersion,
EdgeBindAddr: edgeBindAddr,
HAConnections: c.Int(haConnectionsFlag),
IsAutoupdated: c.Bool("is-autoupdated"),
LBPool: c.String("lb-pool"),
HAConnections: c.Int(flags.HaConnections),
IsAutoupdated: c.Bool(flags.IsAutoUpdated),
LBPool: c.String(flags.LBPool),
Tags: tags,
Log: log,
LogTransport: logTransport,
Observer: observer,
ReportedVersion: info.Version(),
// Note TUN-3758 , we use Int because UInt is not supported with altsrc
Retries: uint(c.Int("retries")),
Retries: uint(c.Int(flags.Retries)), // nolint: gosec
RunFromTerminal: isRunningFromTerminal(),
NamedTunnel: namedTunnel,
ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs,
FeatureSelector: featureSelector,
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
RPCTimeout: c.Duration(rpcTimeout),
WriteStreamTimeout: c.Duration(writeStreamTimeout),
DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
QUICConnectionLevelFlowControlLimit: c.Uint64(quicConnLevelFlowControlLimit),
QUICStreamLevelFlowControlLimit: c.Uint64(quicStreamLevelFlowControlLimit),
MaxEdgeAddrRetries: uint8(c.Int(flags.MaxEdgeAddrRetries)), // nolint: gosec
RPCTimeout: c.Duration(flags.RpcTimeout),
WriteStreamTimeout: c.Duration(flags.WriteStreamTimeout),
DisableQUICPathMTUDiscovery: c.Bool(flags.QuicDisablePathMTUDiscovery),
QUICConnectionLevelFlowControlLimit: c.Uint64(flags.QuicConnLevelFlowControlLimit),
QUICStreamLevelFlowControlLimit: c.Uint64(flags.QuicStreamLevelFlowControlLimit),
}
icmpRouter, err := newICMPRouter(c, log)
if err != nil {
@ -262,7 +264,7 @@ func prepareTunnelConfig(
Ingress: &ingressRules,
WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting),
ConfigurationFlags: parseConfigFlags(c),
WriteTimeout: c.Duration(writeStreamTimeout),
WriteTimeout: tunnelConfig.WriteStreamTimeout,
}
return tunnelConfig, orchestratorConfig, nil
}
@ -280,9 +282,9 @@ func parseConfigFlags(c *cli.Context) map[string]string {
}
func gracePeriod(c *cli.Context) (time.Duration, error) {
period := c.Duration("grace-period")
period := c.Duration(flags.GracePeriod)
if period > connection.MaxGracePeriod {
return time.Duration(0), fmt.Errorf("grace-period must be equal or less than %v", connection.MaxGracePeriod)
return time.Duration(0), fmt.Errorf("%s must be equal or less than %v", flags.GracePeriod, connection.MaxGracePeriod)
}
return period, nil
}
@ -365,14 +367,14 @@ func newICMPRouter(c *cli.Context, logger *zerolog.Logger) (ingress.ICMPRouterSe
}
func determineICMPSources(c *cli.Context, logger *zerolog.Logger) (netip.Addr, netip.Addr, error) {
ipv4Src, err := determineICMPv4Src(c.String("icmpv4-src"), logger)
ipv4Src, err := determineICMPv4Src(c.String(flags.ICMPV4Src), logger)
if err != nil {
return netip.Addr{}, netip.Addr{}, errors.Wrap(err, "failed to determine IPv4 source address for ICMP proxy")
}
logger.Info().Msgf("ICMP proxy will use %s as source for IPv4", ipv4Src)
ipv6Src, zone, err := determineICMPv6Src(c.String("icmpv6-src"), logger, ipv4Src)
ipv6Src, zone, err := determineICMPv6Src(c.String(flags.ICMPV6Src), logger, ipv4Src)
if err != nil {
return netip.Addr{}, netip.Addr{}, errors.Wrap(err, "failed to determine IPv6 source address for ICMP proxy")
}

View File

@ -4,6 +4,7 @@ import (
"fmt"
"path/filepath"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/credentials"
@ -57,7 +58,7 @@ func newSearchByID(id uuid.UUID, c *cli.Context, log *zerolog.Logger, fs fileSys
}
func (s searchByID) Path() (string, error) {
originCertPath := s.c.String(credentials.OriginCertFlag)
originCertPath := s.c.String(cfdflags.OriginCert)
originCertLog := s.log.With().
Str("originCertPath", originCertPath).
Logger()

View File

@ -1,3 +0,0 @@
package tunnel
var FipsEnabled bool

View File

@ -20,7 +20,31 @@ import (
const (
baseLoginURL = "https://dash.cloudflare.com/argotunnel"
callbackStoreURL = "https://login.cloudflareaccess.org/"
callbackURL = "https://login.cloudflareaccess.org/"
// For now these are the same but will change in the future once we know which URLs to use (TUN-8872)
fedBaseLoginURL = "https://dash.cloudflare.com/argotunnel"
fedCallbackStoreURL = "https://login.cloudflareaccess.org/"
fedRAMPParamName = "fedramp"
loginURLParamName = "loginURL"
callbackURLParamName = "callbackURL"
)
var (
loginURL = &cli.StringFlag{
Name: loginURLParamName,
Value: baseLoginURL,
Usage: "The URL used to login (default is https://dash.cloudflare.com/argotunnel)",
}
callbackStore = &cli.StringFlag{
Name: callbackURLParamName,
Value: callbackURL,
Usage: "The URL used for the callback (default is https://login.cloudflareaccess.org/)",
}
fedramp = &cli.BoolFlag{
Name: fedRAMPParamName,
Aliases: []string{"f"},
Usage: "Login with FedRAMP High environment.",
}
)
func buildLoginSubcommand(hidden bool) *cli.Command {
@ -30,6 +54,11 @@ func buildLoginSubcommand(hidden bool) *cli.Command {
Usage: "Generate a configuration file with your login details",
ArgsUsage: " ",
Hidden: hidden,
Flags: []cli.Flag{
loginURL,
callbackStore,
fedramp,
},
}
}
@ -38,15 +67,25 @@ func login(c *cli.Context) error {
path, ok, err := checkForExistingCert()
if ok {
fmt.Fprintf(os.Stdout, "You have an existing certificate at %s which login would overwrite.\nIf this is intentional, please move or delete that file then run this command again.\n", path)
log.Error().Err(err).Msgf("You have an existing certificate at %s which login would overwrite.\nIf this is intentional, please move or delete that file then run this command again.\n", path)
return nil
} else if err != nil {
return err
}
loginURL, err := url.Parse(baseLoginURL)
var (
baseloginURL = c.String(loginURLParamName)
callbackStoreURL = c.String(callbackURLParamName)
)
isFEDRamp := c.Bool(fedRAMPParamName)
if isFEDRamp {
baseloginURL = fedBaseLoginURL
callbackStoreURL = fedCallbackStoreURL
}
loginURL, err := url.Parse(baseloginURL)
if err != nil {
// shouldn't happen, URL is hardcoded
return err
}
@ -61,7 +100,23 @@ func login(c *cli.Context) error {
log,
)
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to write the certificate due to the following error:\n%v\n\nYour browser will download the certificate instead. You will have to manually\ncopy it to the following path:\n\n%s\n", err, path)
log.Error().Err(err).Msgf("Failed to write the certificate.\n\nYour browser will download the certificate instead. You will have to manually\ncopy it to the following path:\n\n%s\n", path)
return err
}
cert, err := credentials.DecodeOriginCert(resourceData)
if err != nil {
log.Error().Err(err).Msg("failed to decode origin certificate")
return err
}
if isFEDRamp {
cert.Endpoint = credentials.FedEndpoint
}
resourceData, err = cert.EncodeOriginCert()
if err != nil {
log.Error().Err(err).Msg("failed to encode origin certificate")
return err
}
@ -69,7 +124,7 @@ func login(c *cli.Context) error {
return errors.Wrap(err, fmt.Sprintf("error writing cert to %s", path))
}
fmt.Fprintf(os.Stdout, "You have successfully logged in.\nIf you wish to copy your credentials to a server, they have been saved to:\n%s\n", path)
log.Info().Msgf("You have successfully logged in.\nIf you wish to copy your credentials to a server, they have been saved to:\n%s\n", path)
return nil
}

View File

@ -11,6 +11,7 @@ import (
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/connection"
)
@ -82,13 +83,13 @@ func RunQuickTunnel(sc *subcommandContext) error {
sc.log.Info().Msg(line)
}
if !sc.c.IsSet("protocol") {
sc.c.Set("protocol", "quic")
if !sc.c.IsSet(flags.Protocol) {
_ = sc.c.Set(flags.Protocol, "quic")
}
// Override the number of connections used. Quick tunnels shouldn't be used for production usage,
// so, use a single connection instead.
sc.c.Set(haConnectionsFlag, "1")
_ = sc.c.Set(flags.HaConnections, "1")
return StartServer(
sc.c,
buildInfo,

View File

@ -9,22 +9,26 @@ import (
"strings"
"github.com/google/uuid"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/cfapi"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/credentials"
"github.com/cloudflare/cloudflared/logger"
)
type errInvalidJSONCredential struct {
const fedRampBaseApiURL = "https://api.fed.cloudflare.com/client/v4"
type invalidJSONCredentialError struct {
err error
path string
}
func (e errInvalidJSONCredential) Error() string {
func (e invalidJSONCredentialError) Error() string {
return "Invalid JSON when parsing tunnel credentials file"
}
@ -51,8 +55,13 @@ func newSubcommandContext(c *cli.Context) (*subcommandContext, error) {
// Returns something that can find the given tunnel's credentials file.
func (sc *subcommandContext) credentialFinder(tunnelID uuid.UUID) CredFinder {
if path := sc.c.String(CredFileFlag); path != "" {
// Expand path if CredFileFlag contains `~`
absPath, err := homedir.Expand(path)
if err != nil {
return newStaticPath(path, sc.fs)
}
return newStaticPath(absPath, sc.fs)
}
return newSearchByID(tunnelID, sc.c, sc.log, sc.fs)
}
@ -64,7 +73,16 @@ func (sc *subcommandContext) client() (cfapi.Client, error) {
if err != nil {
return nil, err
}
sc.tunnelstoreClient, err = cred.Client(sc.c.String("api-url"), buildInfo.UserAgent(), sc.log)
var apiURL string
if cred.IsFEDEndpoint() {
sc.log.Info().Str("api-url", fedRampBaseApiURL).Msg("using fedramp base api")
apiURL = fedRampBaseApiURL
} else {
apiURL = sc.c.String(cfdflags.ApiURL)
}
sc.tunnelstoreClient, err = cred.Client(apiURL, buildInfo.UserAgent(), sc.log)
if err != nil {
return nil, err
}
@ -73,7 +91,7 @@ func (sc *subcommandContext) client() (cfapi.Client, error) {
func (sc *subcommandContext) credential() (*credentials.User, error) {
if sc.userCredential == nil {
uc, err := credentials.Read(sc.c.String(credentials.OriginCertFlag), sc.log)
uc, err := credentials.Read(sc.c.String(cfdflags.OriginCert), sc.log)
if err != nil {
return nil, err
}
@ -94,13 +112,13 @@ func (sc *subcommandContext) readTunnelCredentials(credFinder CredFinder) (conne
var credentials connection.Credentials
if err = json.Unmarshal(body, &credentials); err != nil {
if strings.HasSuffix(filePath, ".pem") {
if filepath.Ext(filePath) == ".pem" {
return connection.Credentials{}, fmt.Errorf("The tunnel credentials file should be .json but you gave a .pem. " +
"The tunnel credentials file was originally created by `cloudflared tunnel create`. " +
"You may have accidentally used the filepath to cert.pem, which is generated by `cloudflared tunnel " +
"login`.")
}
return connection.Credentials{}, errInvalidJSONCredential{path: filePath, err: err}
return connection.Credentials{}, invalidJSONCredentialError{path: filePath, err: err}
}
return credentials, nil
}
@ -122,7 +140,7 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec
if err != nil {
return nil, errors.Wrap(err, "Couldn't decode tunnel secret from base64")
}
tunnelSecret = []byte(decodedSecret)
tunnelSecret = decodedSecret
if len(tunnelSecret) < 32 {
return nil, errors.New("Decoded tunnel secret must be at least 32 bytes long")
}
@ -160,7 +178,7 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string, sec
errorLines = append(errorLines, fmt.Sprintf("Cloudflared tried to delete the tunnel for you, but encountered an error. You should use `cloudflared tunnel delete %v` to delete the tunnel yourself, because the tunnel can't be run without the tunnelfile.", tunnel.ID))
errorLines = append(errorLines, fmt.Sprintf("The delete tunnel error is: %v", deleteErr))
} else {
errorLines = append(errorLines, fmt.Sprintf("The tunnel was deleted, because the tunnel can't be run without the credentials file"))
errorLines = append(errorLines, "The tunnel was deleted, because the tunnel can't be run without the credentials file")
}
errorMsg := strings.Join(errorLines, "\n")
return nil, errors.New(errorMsg)
@ -189,7 +207,7 @@ func (sc *subcommandContext) list(filter *cfapi.TunnelFilter) ([]*cfapi.Tunnel,
}
func (sc *subcommandContext) delete(tunnelIDs []uuid.UUID) error {
forceFlagSet := sc.c.Bool("force")
forceFlagSet := sc.c.Bool(cfdflags.Force)
client, err := sc.client()
if err != nil {
@ -229,7 +247,7 @@ func (sc *subcommandContext) findCredentials(tunnelID uuid.UUID) (connection.Cre
var err error
if credentialsContents := sc.c.String(CredContentsFlag); credentialsContents != "" {
if err = json.Unmarshal([]byte(credentialsContents), &credentials); err != nil {
err = errInvalidJSONCredential{path: "TUNNEL_CRED_CONTENTS", err: err}
err = invalidJSONCredentialError{path: "TUNNEL_CRED_CONTENTS", err: err}
}
} else {
credFinder := sc.credentialFinder(tunnelID)
@ -245,7 +263,7 @@ func (sc *subcommandContext) findCredentials(tunnelID uuid.UUID) (connection.Cre
func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
credentials, err := sc.findCredentials(tunnelID)
if err != nil {
if e, ok := err.(errInvalidJSONCredential); ok {
if e, ok := err.(invalidJSONCredentialError); ok {
sc.log.Error().Msgf("The credentials file at %s contained invalid JSON. This is probably caused by passing the wrong filepath. Reminder: the credentials file is a .json file created via `cloudflared tunnel create`.", e.path)
sc.log.Error().Msgf("Invalid JSON when parsing credentials file: %s", e.err.Error())
}

View File

@ -16,19 +16,21 @@ import (
"time"
"github.com/google/uuid"
homedir "github.com/mitchellh/go-homedir"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"golang.org/x/net/idna"
yaml "gopkg.in/yaml.v3"
"gopkg.in/yaml.v3"
"github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/diagnostic"
"github.com/cloudflare/cloudflared/fips"
"github.com/cloudflare/cloudflared/metrics"
)
@ -39,6 +41,7 @@ const (
CredFileFlag = "credentials-file"
CredContentsFlag = "credentials-contents"
TunnelTokenFlag = "token"
TunnelTokenFileFlag = "token-file"
overwriteDNSFlagName = "overwrite-dns"
noDiagLogsFlagName = "no-diag-logs"
noDiagMetricsFlagName = "no-diag-metrics"
@ -47,7 +50,6 @@ const (
noDiagNetworkFlagName = "no-diag-network"
diagContainerIDFlagName = "diag-container-id"
diagPodFlagName = "diag-pod-id"
metricsFlagName = "metrics"
LogFieldTunnelID = "tunnelID"
)
@ -59,7 +61,7 @@ var (
Usage: "Include deleted tunnels in the list",
}
listNameFlag = &cli.StringFlag{
Name: "name",
Name: flags.Name,
Aliases: []string{"n"},
Usage: "List tunnels with the given `NAME`",
}
@ -107,7 +109,7 @@ var (
EnvVars: []string{"TUNNEL_LIST_INVERT_SORT"},
}
featuresFlag = altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "features",
Name: flags.Features,
Aliases: []string{"F"},
Usage: "Opt into various features that are still being developed or tested.",
})
@ -125,18 +127,23 @@ var (
})
tunnelTokenFlag = altsrc.NewStringFlag(&cli.StringFlag{
Name: TunnelTokenFlag,
Usage: "The Tunnel token. When provided along with credentials, this will take precedence.",
Usage: "The Tunnel token. When provided along with credentials, this will take precedence. Also takes precedence over token-file",
EnvVars: []string{"TUNNEL_TOKEN"},
})
tunnelTokenFileFlag = altsrc.NewStringFlag(&cli.StringFlag{
Name: TunnelTokenFileFlag,
Usage: "Filepath at which to read the tunnel token. When provided along with credentials, this will take precedence.",
EnvVars: []string{"TUNNEL_TOKEN_FILE"},
})
forceDeleteFlag = &cli.BoolFlag{
Name: "force",
Name: flags.Force,
Aliases: []string{"f"},
Usage: "Deletes a tunnel even if tunnel is connected and it has dependencies associated to it. (eg. IP routes)." +
" It is not possible to delete tunnels that have connections or non-deleted dependencies, without this flag.",
EnvVars: []string{"TUNNEL_RUN_FORCE_OVERWRITE"},
}
selectProtocolFlag = altsrc.NewStringFlag(&cli.StringFlag{
Name: "protocol",
Name: flags.Protocol,
Value: connection.AutoSelectFlag,
Aliases: []string{"p"},
Usage: fmt.Sprintf("Protocol implementation to connect with Cloudflare's edge network. %s", connection.AvailableProtocolFlagMessage),
@ -144,11 +151,11 @@ var (
Hidden: true,
})
postQuantumFlag = altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "post-quantum",
Name: flags.PostQuantum,
Usage: "When given creates an experimental post-quantum secure tunnel",
Aliases: []string{"pq"},
EnvVars: []string{"TUNNEL_POST_QUANTUM"},
Hidden: FipsEnabled,
Hidden: fips.IsFipsEnabled(),
})
sortInfoByFlag = &cli.StringFlag{
Name: "sort-by",
@ -180,17 +187,17 @@ var (
EnvVars: []string{"TUNNEL_CREATE_SECRET"},
}
icmpv4SrcFlag = &cli.StringFlag{
Name: "icmpv4-src",
Name: flags.ICMPV4Src,
Usage: "Source address to send/receive ICMPv4 messages. If not provided cloudflared will dial a local address to determine the source IP or fallback to 0.0.0.0.",
EnvVars: []string{"TUNNEL_ICMPV4_SRC"},
}
icmpv6SrcFlag = &cli.StringFlag{
Name: "icmpv6-src",
Name: flags.ICMPV6Src,
Usage: "Source address and the interface name to send/receive ICMPv6 messages. If not provided cloudflared will dial a local address to determine the source IP or fallback to ::.",
EnvVars: []string{"TUNNEL_ICMPV6_SRC"},
}
metricsFlag = &cli.StringFlag{
Name: metricsFlagName,
Name: flags.Metrics,
Usage: "The metrics server address i.e.: 127.0.0.1:12345. If your instance is running in a Docker/Kubernetes environment you need to setup port forwarding for your application.",
Value: "",
}
@ -229,6 +236,11 @@ var (
Usage: "Network diagnostics won't be performed",
Value: false,
}
maxActiveFlowsFlag = &cli.Uint64Flag{
Name: flags.MaxActiveFlows,
Usage: "Overrides the remote configuration for max active private network flows (TCP/UDP) that this cloudflared instance supports",
EnvVars: []string{"TUNNEL_MAX_ACTIVE_FLOWS"},
}
)
func buildCreateCommand() *cli.Command {
@ -331,7 +343,7 @@ func listCommand(c *cli.Context) error {
if !c.Bool("show-deleted") {
filter.NoDeleted()
}
if name := c.String("name"); name != "" {
if name := c.String(flags.Name); name != "" {
filter.ByName(name)
}
if namePrefix := c.String("name-prefix"); namePrefix != "" {
@ -441,7 +453,7 @@ func fmtConnections(connections []cfapi.Connection, showRecentlyDisconnected boo
sort.Strings(sortedColos)
// Map each colo to its frequency, combine into output string.
var output []string
output := make([]string, 0, len(sortedColos))
for _, coloName := range sortedColos {
output = append(output, fmt.Sprintf("%dx%s", numConnsPerColo[coloName], coloName))
}
@ -461,16 +473,21 @@ func buildReadyCommand() *cli.Command {
}
func readyCommand(c *cli.Context) error {
metricsOpts := c.String("metrics")
if !c.IsSet("metrics") {
return fmt.Errorf("--metrics has to be provided")
metricsOpts := c.String(flags.Metrics)
if !c.IsSet(flags.Metrics) {
return errors.New("--metrics has to be provided")
}
requestURL := fmt.Sprintf("http://%s/ready", metricsOpts)
res, err := http.Get(requestURL)
req, err := http.NewRequest(http.MethodGet, requestURL, nil)
if err != nil {
return err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != 200 {
body, err := io.ReadAll(res.Body)
if err != nil {
@ -697,8 +714,10 @@ func buildRunCommand() *cli.Command {
selectProtocolFlag,
featuresFlag,
tunnelTokenFlag,
tunnelTokenFileFlag,
icmpv4SrcFlag,
icmpv6SrcFlag,
maxActiveFlowsFlag,
}
flags = append(flags, configureProxyFlags(false)...)
return &cli.Command{
@ -736,12 +755,22 @@ func runCommand(c *cli.Context) error {
"your origin will not be reachable. You should remove the `hostname` property to avoid this warning.")
}
tokenStr := c.String(TunnelTokenFlag)
// Check if tokenStr is blank before checking for tokenFile
if tokenStr == "" {
if tokenFile := c.String(TunnelTokenFileFlag); tokenFile != "" {
data, err := os.ReadFile(tokenFile)
if err != nil {
return cliutil.UsageError("Failed to read token file: " + err.Error())
}
tokenStr = strings.TrimSpace(string(data))
}
}
// Check if token is provided and if not use default tunnelID flag method
if tokenStr := c.String(TunnelTokenFlag); tokenStr != "" {
if tokenStr != "" {
if token, err := ParseToken(tokenStr); err == nil {
return sc.runWithCredentials(token.Credentials())
}
return cliutil.UsageError("Provided Tunnel token is not valid.")
} else {
tunnelRef := c.Args().First()
@ -1067,7 +1096,7 @@ func diagCommand(ctx *cli.Context) error {
log := sctx.log
options := diagnostic.Options{
KnownAddresses: metrics.GetMetricsKnownAddresses(metrics.Runtime),
Address: sctx.c.String(metricsFlagName),
Address: sctx.c.String(flags.Metrics),
ContainerID: sctx.c.String(diagContainerIDFlagName),
PodID: sctx.c.String(diagPodFlagName),
Toggles: diagnostic.Toggles{

View File

@ -22,7 +22,7 @@ var (
Usage: "The ID or name of the virtual network to which the route is associated to.",
}
routeAddError = errors.New("You must supply exactly one argument, the ID or CIDR of the route you want to delete")
errAddRoute = errors.New("You must supply exactly one argument, the ID or CIDR of the route you want to delete")
)
func buildRouteIPSubcommand() *cli.Command {
@ -32,7 +32,7 @@ func buildRouteIPSubcommand() *cli.Command {
UsageText: "cloudflared tunnel [--config FILEPATH] route COMMAND [arguments...]",
Description: `cloudflared can provision routes for any IP space in your corporate network. Users enrolled in
your Cloudflare for Teams organization can reach those IPs through the Cloudflare WARP
client. You can then configure L7/L4 filtering on https://dash.teams.cloudflare.com to
client. You can then configure L7/L4 filtering on https://one.dash.cloudflare.com to
determine who can reach certain routes.
By default IP routes all exist within a single virtual network. If you use the same IP
space(s) in different physical private networks, all meant to be reachable via IP routes,
@ -187,7 +187,7 @@ func deleteRouteCommand(c *cli.Context) error {
}
if c.NArg() != 1 {
return routeAddError
return errAddRoute
}
var routeId uuid.UUID
@ -195,7 +195,7 @@ func deleteRouteCommand(c *cli.Context) error {
if err != nil {
_, network, err := net.ParseCIDR(c.Args().First())
if err != nil || network == nil {
return routeAddError
return errAddRoute
}
var vnetId *uuid.UUID

View File

@ -15,13 +15,14 @@ import (
"golang.org/x/term"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/logger"
)
const (
DefaultCheckUpdateFreq = time.Hour * 24
noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/as-a-service/"
noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/configure-tunnels/local-management/as-a-service/"
noUpdateOnWindowsMessage = "cloudflared will not automatically update on Windows systems."
noUpdateManagedPackageMessage = "cloudflared will not automatically update if installed by a package manager."
isManagedInstallFile = ".installedFromPackageManager"
@ -38,6 +39,7 @@ var (
// BinaryUpdated implements ExitCoder interface, the app will exit with status code 11
// https://pkg.go.dev/github.com/urfave/cli/v2?tab=doc#ExitCoder
// nolint: errname
type statusSuccess struct {
newVersion string
}
@ -50,16 +52,16 @@ func (u *statusSuccess) ExitCode() int {
return 11
}
// UpdateErr implements ExitCoder interface, the app will exit with status code 10
type statusErr struct {
// statusError implements ExitCoder interface, the app will exit with status code 10
type statusError struct {
err error
}
func (e *statusErr) Error() string {
func (e *statusError) Error() string {
return fmt.Sprintf("failed to update cloudflared: %v", e.err)
}
func (e *statusErr) ExitCode() int {
func (e *statusError) ExitCode() int {
return 10
}
@ -79,7 +81,7 @@ type UpdateOutcome struct {
}
func (uo *UpdateOutcome) noUpdate() bool {
return uo.Error == nil && uo.Updated == false
return uo.Error == nil && !uo.Updated
}
func Init(info *cliutil.BuildInfo) {
@ -153,7 +155,7 @@ func Update(c *cli.Context) error {
log.Info().Msg("cloudflared is set to update from staging")
}
isForced := c.Bool("force")
isForced := c.Bool(cfdflags.Force)
if isForced {
log.Info().Msg("cloudflared is set to upgrade to the latest publish version regardless of the current version")
}
@ -166,7 +168,7 @@ func Update(c *cli.Context) error {
intendedVersion: c.String("version"),
})
if updateOutcome.Error != nil {
return &statusErr{updateOutcome.Error}
return &statusError{updateOutcome.Error}
}
if updateOutcome.noUpdate() {
@ -252,7 +254,7 @@ func (a *AutoUpdater) Run(ctx context.Context) error {
pid, err := a.listeners.StartProcess()
if err != nil {
a.log.Err(err).Msg("Unable to restart server automatically")
return &statusErr{err: err}
return &statusError{err: err}
}
// stop old process after autoupdate. Otherwise we create a new process
// after each update

View File

@ -10,9 +10,9 @@ import (
"net/url"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"strings"
"text/template"
"time"
@ -134,7 +134,7 @@ func (v *WorkersVersion) Apply() error {
if err := os.Rename(newFilePath, v.targetPath); err != nil {
//attempt rollback
os.Rename(oldFilePath, v.targetPath)
_ = os.Rename(oldFilePath, v.targetPath)
return err
}
os.Remove(oldFilePath)
@ -181,7 +181,7 @@ func download(url, filepath string, isCompressed bool) error {
tr := tar.NewReader(gr)
// advance the reader pass the header, which will be the single binary file
tr.Next()
_, _ = tr.Next()
r = tr
}
@ -198,7 +198,7 @@ func download(url, filepath string, isCompressed bool) error {
// isCompressedFile is a really simple file extension check to see if this is a macos tar and gzipped
func isCompressedFile(urlstring string) bool {
if strings.HasSuffix(urlstring, ".tgz") {
if path.Ext(urlstring) == ".tgz" {
return true
}
@ -206,7 +206,7 @@ func isCompressedFile(urlstring string) bool {
if err != nil {
return false
}
return strings.HasSuffix(u.Path, ".tgz")
return path.Ext(u.Path) == ".tgz"
}
// writeBatchFile writes a batch file out to disk
@ -249,7 +249,6 @@ func runWindowsBatch(batchFile string) error {
if exitError, ok := err.(*exec.ExitError); ok {
return fmt.Errorf("Error during update : %s;", string(exitError.Stderr))
}
}
return err
}

View File

@ -26,7 +26,7 @@ import (
const (
windowsServiceName = "Cloudflared"
windowsServiceDescription = "Cloudflared agent"
windowsServiceUrl = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/as-a-service/windows/"
windowsServiceUrl = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/configure-tunnels/local-management/as-a-service/windows/"
recoverActionDelay = time.Second * 20
failureCountResetPeriod = time.Hour * 24

View File

@ -1,7 +1,6 @@
from util import LOGGER, nofips, start_cloudflared, wait_tunnel_ready
from util import LOGGER, start_cloudflared, wait_tunnel_ready
@nofips
class TestPostQuantum:
def _extra_config(self):
config = {
@ -12,6 +11,11 @@ class TestPostQuantum:
def test_post_quantum(self, tmp_path, component_tests_config):
config = component_tests_config(self._extra_config())
LOGGER.debug(config)
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], cfd_args=["run", "--post-quantum"], new_process=True):
wait_tunnel_ready(tunnel_url=config.get_url(),
require_min_connections=1)
with start_cloudflared(
tmp_path,
config,
cfd_pre_args=["tunnel", "--ha-connections", "1"],
cfd_args=["run", "--post-quantum"],
new_process=True,
):
wait_tunnel_ready(tunnel_url=config.get_url(), require_min_connections=1)

View File

@ -155,7 +155,7 @@ func FindOrCreateConfigPath() string {
// i.e. it fails if a user specifies both --url and --unix-socket
func ValidateUnixSocket(c *cli.Context) (string, error) {
if c.IsSet("unix-socket") && (c.IsSet("url") || c.NArg() > 0) {
return "", errors.New("--unix-socket must be used exclusivly.")
return "", errors.New("--unix-socket must be used exclusively.")
}
return c.String("unix-socket"), nil
}
@ -260,6 +260,7 @@ type Configuration struct {
type WarpRoutingConfig struct {
ConnectTimeout *CustomDuration `yaml:"connectTimeout" json:"connectTimeout,omitempty"`
MaxActiveFlows *uint64 `yaml:"maxActiveFlows" json:"maxActiveFlows,omitempty"`
TCPKeepAlive *CustomDuration `yaml:"tcpKeepAlive" json:"tcpKeepAlive,omitempty"`
}

View File

@ -60,6 +60,7 @@ type Credentials struct {
AccountTag string
TunnelSecret []byte
TunnelID uuid.UUID
Endpoint string
}
func (c *Credentials) Auth() pogs.TunnelAuth {
@ -74,13 +75,16 @@ type TunnelToken struct {
AccountTag string `json:"a"`
TunnelSecret []byte `json:"s"`
TunnelID uuid.UUID `json:"t"`
Endpoint string `json:"e,omitempty"`
}
func (t TunnelToken) Credentials() Credentials {
// nolint: gosimple
return Credentials{
AccountTag: t.AccountTag,
TunnelSecret: t.TunnelSecret,
TunnelID: t.TunnelID,
Endpoint: t.Endpoint,
}
}

View File

@ -2,14 +2,18 @@ package connection
import (
"context"
"crypto/rand"
"fmt"
"io"
"math/rand"
"math/big"
"net/http"
"time"
pkgerrors "github.com/pkg/errors"
"github.com/rs/zerolog"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/stream"
"github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -77,7 +81,7 @@ func (moc *mockOriginProxy) ProxyHTTP(
return wsFlakyEndpoint(w, req)
default:
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
return fmt.Errorf("unknown websocket endpoint %s", req.URL.Path)
}
}
switch req.URL.Path {
@ -95,7 +99,6 @@ func (moc *mockOriginProxy) ProxyHTTP(
originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
}
return nil
}
func (moc *mockOriginProxy) ProxyTCP(
@ -103,6 +106,10 @@ func (moc *mockOriginProxy) ProxyTCP(
rwa ReadWriteAcker,
r *TCPRequest,
) error {
if r.CfTraceID == "flow-rate-limited" {
return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "tcp flow rate limited")
}
return nil
}
@ -178,7 +185,8 @@ func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
rInt, _ := rand.Int(rand.Reader, big.NewInt(50))
closedAfter := time.Millisecond * time.Duration(rInt.Int64())
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
stream.Pipe(wsConn, originConn, &log)
cancel()

View File

@ -22,8 +22,9 @@ var (
var (
// pre-generate possible values for res
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared", false)
responseMetaHeaderCfdFlowRateLimited = mustInitRespMetaHeader("cloudflared", true)
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin", false)
)
// HTTPHeader is a custom header struct that expects only ever one value for the header.
@ -35,10 +36,11 @@ type HTTPHeader struct {
type responseMetaHeader struct {
Source string `json:"src"`
FlowRateLimited bool `json:"flow_rate_limited,omitempty"`
}
func mustInitRespMetaHeader(src string) string {
header, err := json.Marshal(responseMetaHeader{Source: src})
func mustInitRespMetaHeader(src string, flowRateLimited bool) string {
header, err := json.Marshal(responseMetaHeader{Source: src, FlowRateLimited: flowRateLimited})
if err != nil {
panic(fmt.Sprintf("Failed to serialize response meta header = %s, err: %v", src, err))
}
@ -112,7 +114,7 @@ func SerializeHeaders(h1Headers http.Header) string {
func DeserializeHeaders(serializedHeaders string) ([]HTTPHeader, error) {
const unableToDeserializeErr = "Unable to deserialize headers"
var deserialized []HTTPHeader
deserialized := make([]HTTPHeader, 0)
for _, serializedPair := range strings.Split(serializedHeaders, ";") {
if len(serializedPair) == 0 {
continue

View File

@ -16,6 +16,8 @@ import (
"github.com/rs/zerolog"
"golang.org/x/net/http2"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
@ -156,7 +158,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.log.Error().Err(requestErr).Msg("failed to serve incoming request")
// WriteErrorResponse will return false if status was already written. we need to abort handler.
if !respWriter.WriteErrorResponse() {
if !respWriter.WriteErrorResponse(requestErr) {
c.log.Debug().Msg("Handler aborted due to failure to write error response after status already sent")
panic(http.ErrAbortHandler)
}
@ -209,8 +211,9 @@ func NewHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type, l
w: w,
log: log,
}
respWriter.WriteErrorResponse()
return nil, fmt.Errorf("%T doesn't implement http.Flusher", w)
err := fmt.Errorf("%T doesn't implement http.Flusher", w)
respWriter.WriteErrorResponse(err)
return nil, err
}
return &http2RespWriter{
@ -295,7 +298,7 @@ func (rp *http2RespWriter) WriteHeader(status int) {
rp.log.Warn().Msg("WriteHeader after hijack")
return
}
rp.WriteRespHeaders(status, rp.respHeaders)
_ = rp.WriteRespHeaders(status, rp.respHeaders)
}
func (rp *http2RespWriter) hijacked() bool {
@ -328,12 +331,16 @@ func (rp *http2RespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return conn, readWriter, nil
}
func (rp *http2RespWriter) WriteErrorResponse() bool {
func (rp *http2RespWriter) WriteErrorResponse(err error) bool {
if rp.statusWritten {
return false
}
if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
rp.setResponseMetaHeader(responseMetaHeaderCfdFlowRateLimited)
} else {
rp.setResponseMetaHeader(responseMetaHeaderCfd)
}
rp.w.WriteHeader(http.StatusBadGateway)
rp.statusWritten = true

View File

@ -20,6 +20,8 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"
"github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
@ -65,19 +67,18 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
require.NoError(t, err)
endpoint := fmt.Sprintf("http://localhost:8080/ok")
reqBody := []byte(`{
"version": 2,
"config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
`)
reader := bytes.NewReader(reqBody)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, reader)
req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://localhost:8080/ok", reader)
require.NoError(t, err)
req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate)
@ -85,11 +86,11 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
bdy, err := io.ReadAll(resp.Body)
defer resp.Body.Close()
require.NoError(t, err)
assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy))
cancel()
wg.Wait()
}
func TestServeHTTP(t *testing.T) {
@ -134,7 +135,7 @@ func TestServeHTTP(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
@ -153,6 +154,7 @@ func TestServeHTTP(t *testing.T) {
require.NoError(t, err)
require.Equal(t, test.expectedBody, respBody)
}
_ = resp.Body.Close()
if test.isProxyError {
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
} else {
@ -281,10 +283,11 @@ func TestServeWS(t *testing.T) {
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
require.NoError(t, err)
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
require.Equal(t, data, respBody, "expect %s, got %s", string(data), string(respBody))
cancel()
resp := respWriter.Result()
defer resp.Body.Close()
// http2RespWriter should rewrite status 101 to 200
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
@ -304,7 +307,7 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
serverDone := make(chan struct{})
go func() {
defer close(serverDone)
cfdHTTP2Conn.Serve(ctx)
_ = cfdHTTP2Conn.Serve(ctx)
}()
edgeTransport := http2.Transport{}
@ -319,13 +322,16 @@ func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
readPipe, writePipe := io.Pipe()
reqCtx, reqCancel := context.WithCancel(ctx)
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
require.NoError(t, err)
assert.NoError(t, err)
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
assert.NoError(t, err)
_ = resp.Body.Close()
// http2RespWriter should rewrite status 101 to 200
require.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, http.StatusOK, resp.StatusCode)
wg.Add(1)
go func() {
@ -378,7 +384,7 @@ func TestServeControlStream(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
@ -391,7 +397,8 @@ func TestServeControlStream(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
edgeHTTP2Conn.RoundTrip(req)
// nolint: bodyclose
_, _ = edgeHTTP2Conn.RoundTrip(req)
}()
<-rpcClientFactory.registered
@ -431,7 +438,7 @@ func TestFailRegistration(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
@ -442,9 +449,10 @@ func TestFailRegistration(t *testing.T) {
require.NoError(t, err)
resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
assert.NotNil(t, http2Conn.controlStreamErr)
require.Error(t, http2Conn.controlStreamErr)
cancel()
wg.Wait()
}
@ -481,7 +489,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
@ -494,6 +502,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
// nolint: bodyclose
_, _ = edgeHTTP2Conn.RoundTrip(req)
}()
@ -524,6 +533,36 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
})
}
func TestServeTCP_RateLimited(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
http2Conn, edgeConn := newTestHTTP2Connection()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = http2Conn.Serve(ctx)
}()
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
require.NoError(t, err)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080", nil)
require.NoError(t, err)
req.Header.Set(InternalTCPProxySrcHeader, "tcp")
req.Header.Set(tracing.TracerContextName, "flow-rate-limited")
resp, err := edgeHTTP2Conn.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
require.Equal(t, responseMetaHeaderCfdFlowRateLimited, resp.Header.Get(ResponseMetaHeader))
cancel()
wg.Wait()
}
func benchmarkServeHTTP(b *testing.B, test testRequest) {
http2Conn, edgeConn := newTestHTTP2Connection()
@ -532,7 +571,7 @@ func benchmarkServeHTTP(b *testing.B, test testRequest) {
wg.Add(1)
go func() {
defer wg.Done()
http2Conn.Serve(ctx)
_ = http2Conn.Serve(ctx)
}()
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)

View File

@ -14,7 +14,7 @@ import (
const (
AvailableProtocolFlagMessage = "Available protocols: 'auto' - automatically chooses the best protocol over time (the default; and also the recommended one); 'quic' - based on QUIC, relying on UDP egress to Cloudflare edge; 'http2' - using Go's HTTP2 library, relying on TCP egress to Cloudflare edge"
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge (unused, but kept for legacy reference).
edgeH2muxTLSServerName = "cftunnel.com"
_ = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com"
// edgeQUICServerName is the server name to establish quic connection with edge.
@ -24,11 +24,9 @@ const (
ResolveTTL = time.Hour
)
var (
// ProtocolList represents a list of supported protocols for communication with the edge
// in order of precedence for remote percentage fetcher.
ProtocolList = []Protocol{QUIC, HTTP2}
)
// ProtocolList represents a list of supported protocols for communication with the edge
// in order of precedence for remote percentage fetcher.
var ProtocolList = []Protocol{QUIC, HTTP2}
type Protocol int64
@ -58,7 +56,7 @@ func (p Protocol) String() string {
case QUIC:
return "quic"
default:
return fmt.Sprintf("unknown protocol")
return "unknown protocol"
}
}
@ -246,11 +244,11 @@ func NewProtocolSelector(
return newRemoteProtocolSelector(fetchedProtocol, ProtocolList, threshold, protocolFetcher, resolveTTL, log), nil
}
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
return nil, fmt.Errorf("unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
}
func switchThreshold(accountTag string) int32 {
h := fnv.New32a()
_, _ = h.Write([]byte(accountTag))
return int32(h.Sum32() % 100)
return int32(h.Sum32() % 100) // nolint: gosec
}

View File

@ -17,6 +17,8 @@ import (
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
cfdflow "github.com/cloudflare/cloudflared/flow"
cfdquic "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -101,14 +103,19 @@ func (q *quicConnection) Serve(ctx context.Context) error {
// amount of the grace period, allowing requests to finish before we cancel the context, which will
// make cloudflared exit.
if err := q.serveControlStream(ctx, controlStream); err == nil {
if q.gracePeriod > 0 {
// In Go1.23 this can be removed and replaced with time.Ticker
// see https://pkg.go.dev/time#Tick
ticker := time.NewTicker(q.gracePeriod)
defer ticker.Stop()
select {
case <-ctx.Done():
case <-time.Tick(q.gracePeriod):
case <-ticker.C:
}
}
}
cancel()
return err
})
errGroup.Go(func() error {
defer cancel()
@ -129,7 +136,7 @@ func (q *quicConnection) serveControlStream(ctx context.Context, controlStream q
// Close the connection with no errors specified.
func (q *quicConnection) Close() {
q.conn.CloseWithError(0, "")
_ = q.conn.CloseWithError(0, "")
}
func (q *quicConnection) acceptStream(ctx context.Context) error {
@ -182,7 +189,13 @@ func (q *quicConnection) handleDataStream(ctx context.Context, stream *rpcquic.R
return err
}
if writeRespErr := stream.WriteConnectResponseData(err); writeRespErr != nil {
var metadata []pogs.Metadata
// Check the type of error that was throw and add metadata that will help identify it on OTD.
if errors.Is(err, cfdflow.ErrTooManyActiveFlows) {
metadata = append(metadata, pogs.ErrorFlowConnectRateLimitedMetadata)
}
if writeRespErr := stream.WriteConnectResponseData(err, metadata...); writeRespErr != nil {
return writeRespErr
}
}
@ -278,7 +291,7 @@ func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header)
func (hrw *httpResponseAdapter) Write(p []byte) (int, error) {
// Make sure to send WriteHeader response if not called yet
if !hrw.connectResponseSent {
hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
_ = hrw.WriteRespHeaders(http.StatusOK, hrw.headers)
}
return hrw.RequestServerStream.Write(p)
}
@ -291,7 +304,7 @@ func (hrw *httpResponseAdapter) Header() http.Header {
func (hrw *httpResponseAdapter) Flush() {}
func (hrw *httpResponseAdapter) WriteHeader(status int) {
hrw.WriteRespHeaders(status, hrw.headers)
_ = hrw.WriteRespHeaders(status, hrw.headers)
}
func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
@ -304,7 +317,7 @@ func (hrw *httpResponseAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
}
func (hrw *httpResponseAdapter) WriteErrorResponse(err error) {
hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
_ = hrw.WriteConnectResponseData(err, pogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
}
func (hrw *httpResponseAdapter) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error {

View File

@ -8,6 +8,7 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
@ -21,13 +22,15 @@ import (
"github.com/gobwas/ws/wsutil"
"github.com/google/uuid"
"github.com/pkg/errors"
pkgerrors "github.com/pkg/errors"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/nettest"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/datagramsession"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/packet"
@ -53,7 +56,8 @@ var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
func TestQUICServer(t *testing.T) {
// This is simply a sample websocket frame message.
wsBuf := &bytes.Buffer{}
wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
err := wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
require.NoError(t, err)
var tests = []struct {
desc string
@ -158,17 +162,19 @@ func TestQUICServer(t *testing.T) {
serverDone := make(chan struct{})
go func() {
// nolint: testifylint
quicServer(
ctx, t, quicListener, test.dest, test.connectionType, test.metadata, test.message, test.expectedResponse,
)
close(serverDone)
}()
// nolint: gosec
tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(i))
connDone := make(chan struct{})
go func() {
tunnelConn.Serve(ctx)
_ = tunnelConn.Serve(ctx)
close(connDone)
}()
@ -254,14 +260,14 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, tr *tracing.T
case "/ok":
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
case "/slow_echo_body":
time.Sleep(5)
time.Sleep(5 * time.Nanosecond)
fallthrough
case "/echo_body":
resp := &http.Response{
StatusCode: http.StatusOK,
}
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
io.Copy(w, r.Body)
_, _ = io.Copy(w, r.Body)
case "/error":
return fmt.Errorf("Failed to proxy to origin")
default:
@ -493,16 +499,20 @@ func TestBuildHTTPRequest(t *testing.T) {
test := test // capture range variable
t.Run(test.name, func(t *testing.T) {
req, err := buildHTTPRequest(context.Background(), test.connectRequest, test.body, 0, &log)
assert.NoError(t, err)
require.NoError(t, err)
test.req = test.req.WithContext(req.Context())
assert.Equal(t, test.req, req.Request)
require.Equal(t, test.req, req.Request)
})
}
}
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
rwa.AckConnection("")
io.Copy(rwa, rwa)
if tcpRequest.Dest == "rate-limit-me" {
return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "failed tcp stream")
}
_ = rwa.AckConnection("")
_, _ = io.Copy(rwa, rwa)
return nil
}
@ -520,16 +530,19 @@ func TestServeUDPSession(t *testing.T) {
edgeQUICSessionChan := make(chan quic.Connection)
go func() {
earlyListener, err := quic.Listen(udpListener, testTLSServerConfig, testQUICConfig)
require.NoError(t, err)
assert.NoError(t, err)
edgeQUICSession, err := earlyListener.Accept(ctx)
require.NoError(t, err)
assert.NoError(t, err)
edgeQUICSessionChan <- edgeQUICSession
}()
// Random index to avoid reusing port
tunnelConn, datagramConn := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), 28)
go tunnelConn.Serve(ctx)
go func() {
_ = tunnelConn.Serve(ctx)
}()
edgeQUICSession := <-edgeQUICSessionChan
@ -545,14 +558,14 @@ func TestNopCloserReadWriterCloseBeforeEOF(t *testing.T) {
n, err := readerWriter.Read(buffer)
require.NoError(t, err)
require.Equal(t, n, 5)
require.Equal(t, 5, n)
// close
require.NoError(t, readerWriter.Close())
// read should get error
n, err = readerWriter.Read(buffer)
require.Equal(t, n, 0)
require.Equal(t, 0, n)
require.Equal(t, err, fmt.Errorf("closed by handler"))
}
@ -562,7 +575,7 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
n, err := readerWriter.Read(buffer)
require.NoError(t, err)
require.Equal(t, n, 9)
require.Equal(t, 9, n)
// force another read to read eof
_, err = readerWriter.Read(buffer)
@ -573,7 +586,7 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
// read should get EOF still
n, err = readerWriter.Read(buffer)
require.Equal(t, n, 0)
require.Equal(t, 0, n)
require.Equal(t, err, io.EOF)
}
@ -589,6 +602,59 @@ func TestCreateUDPConnReuseSourcePort(t *testing.T) {
}
}
// TestTCPProxy_FlowRateLimited tests if the pogs.ConnectResponse returns the expected error and metadata, when a
// new flow is rate limited.
func TestTCPProxy_FlowRateLimited(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
// Start a UDP Listener for QUIC.
udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
require.NoError(t, err)
udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr)
require.NoError(t, err)
defer udpListener.Close()
quicTransport := &quic.Transport{Conn: udpListener, ConnectionIDLength: 16}
quicListener, err := quicTransport.Listen(testTLSServerConfig, testQUICConfig)
require.NoError(t, err)
serverDone := make(chan struct{})
go func() {
defer close(serverDone)
session, err := quicListener.Accept(ctx)
assert.NoError(t, err)
quicStream, err := session.OpenStreamSync(context.Background())
assert.NoError(t, err)
stream := cfdquic.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log)
reqClientStream := rpcquic.RequestClientStream{ReadWriteCloser: stream}
err = reqClientStream.WriteConnectRequestData("rate-limit-me", pogs.ConnectionTypeTCP)
assert.NoError(t, err)
response, err := reqClientStream.ReadConnectResponseData()
assert.NoError(t, err)
// Got Rate Limited
assert.NotEmpty(t, response.Error)
assert.Contains(t, response.Metadata, pogs.ErrorFlowConnectRateLimitedMetadata)
}()
tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(0))
connDone := make(chan struct{})
go func() {
defer close(connDone)
_ = tunnelConn.Serve(ctx)
}()
<-serverDone
cancel()
<-connDone
}
func testCreateUDPConnReuseSourcePortForEdgeIP(t *testing.T, edgeIP netip.AddrPort) {
logger := zerolog.Nop()
conn, err := createUDPConnForConnIndex(0, nil, edgeIP, &logger)
@ -669,6 +735,7 @@ func serveSession(ctx context.Context, datagramConn *datagramV2Connection, edgeQ
unregisterReason: expectedReason,
calledUnregisterChan: unregisterFromEdgeChan,
}
// nolint: testifylint
go runRPCServer(ctx, edgeQUICSession, sessionRPCServer, nil, t)
<-unregisterFromEdgeChan
@ -729,6 +796,7 @@ func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionI
func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) (TunnelConnection, *datagramV2Connection) {
tlsClientConfig := &tls.Config{
// nolint: gosec
InsecureSkipVerify: true,
NextProtos: []string{"argotunnel"},
}
@ -747,6 +815,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
index,
&log,
)
require.NoError(t, err)
// Start a session manager for the connection
sessionDemuxChan := make(chan *packet.Session, 4)
@ -757,7 +826,9 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
datagramConn := &datagramV2Connection{
conn,
index,
sessionManager,
cfdflow.NewLimiter(0),
datagramMuxer,
packetRouter,
15 * time.Second,
@ -796,6 +867,7 @@ func (m *mockReaderNoopWriter) Close() error {
// GenerateTLSConfig sets up a bare-bones TLS config for a QUIC server
func GenerateTLSConfig() *tls.Config {
// nolint: gosec
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
@ -812,6 +884,7 @@ func GenerateTLSConfig() *tls.Config {
if err != nil {
panic(err)
}
// nolint: gosec
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{"argotunnel"},

View File

@ -7,12 +7,15 @@ import (
"time"
"github.com/google/uuid"
pkgerrors "github.com/pkg/errors"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/datagramsession"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management"
@ -39,9 +42,13 @@ type DatagramSessionHandler interface {
type datagramV2Connection struct {
conn quic.Connection
index uint8
// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
sessionManager datagramsession.Manager
// flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
flowLimiter cfdflow.Limiter
// datagramMuxer mux/demux datagrams from quic connection
datagramMuxer *cfdquic.DatagramMuxerV2
packetRouter *ingress.PacketRouter
@ -58,6 +65,7 @@ func NewDatagramV2Connection(ctx context.Context,
index uint8,
rpcTimeout time.Duration,
streamWriteTimeout time.Duration,
flowLimiter cfdflow.Limiter,
logger *zerolog.Logger,
) DatagramSessionHandler {
sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
@ -66,13 +74,15 @@ func NewDatagramV2Connection(ctx context.Context,
packetRouter := ingress.NewPacketRouter(icmpRouter, datagramMuxer, index, logger)
return &datagramV2Connection{
conn,
sessionManager,
datagramMuxer,
packetRouter,
rpcTimeout,
streamWriteTimeout,
logger,
conn: conn,
index: index,
sessionManager: sessionManager,
flowLimiter: flowLimiter,
datagramMuxer: datagramMuxer,
packetRouter: packetRouter,
rpcTimeout: rpcTimeout,
streamWriteTimeout: streamWriteTimeout,
logger: logger,
}
}
@ -109,12 +119,23 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
attribute.String("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)),
))
log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger()
// Try to start a new session
if err := q.flowLimiter.Acquire(management.UDP.String()); err != nil {
log.Warn().Msgf("Too many concurrent sessions being handled, rejecting udp proxy to %s:%d", dstIP, dstPort)
err := pkgerrors.Wrap(err, "failed to start udp session due to rate limiting")
tracing.EndWithErrorStatus(registerSpan, err)
return nil, err
}
// Each session is a series of datagram from an eyeball to a dstIP:dstPort.
// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
originProxy, err := ingress.DialUDP(dstIP, dstPort)
if err != nil {
log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
tracing.EndWithErrorStatus(registerSpan, err)
q.flowLimiter.Release()
return nil, err
}
registerSpan.SetAttributes(
@ -127,10 +148,14 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
originProxy.Close()
log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
tracing.EndWithErrorStatus(registerSpan, err)
q.flowLimiter.Release()
return nil, err
}
go q.serveUDPSession(session, closeAfterIdleHint)
go func() {
defer q.flowLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method.
q.serveUDPSession(session, closeAfterIdleHint)
}()
log.Debug().
Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).
@ -170,7 +195,7 @@ func (q *datagramV2Connection) serveUDPSession(session *datagramsession.Session,
// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge
func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) {
q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
_ = q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
quicStream, err := q.conn.OpenStream()
if err != nil {
// Log this at debug because this is not an error if session was closed due to lost connection

View File

@ -0,0 +1,96 @@
package connection
import (
"context"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/mocks"
)
type mockQuicConnection struct {
}
func (m *mockQuicConnection) AcceptStream(_ context.Context) (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) AcceptUniStream(_ context.Context) (quic.ReceiveStream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenStream() (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenStreamSync(_ context.Context) (quic.Stream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenUniStream() (quic.SendStream, error) {
return nil, nil
}
func (m *mockQuicConnection) OpenUniStreamSync(_ context.Context) (quic.SendStream, error) {
return nil, nil
}
func (m *mockQuicConnection) LocalAddr() net.Addr {
return nil
}
func (m *mockQuicConnection) RemoteAddr() net.Addr {
return nil
}
func (m *mockQuicConnection) CloseWithError(_ quic.ApplicationErrorCode, s string) error {
return nil
}
func (m *mockQuicConnection) Context() context.Context {
return nil
}
func (m *mockQuicConnection) ConnectionState() quic.ConnectionState {
panic("not meant to be called")
}
func (m *mockQuicConnection) SendDatagram(_ []byte) error {
return nil
}
func (m *mockQuicConnection) ReceiveDatagram(_ context.Context) ([]byte, error) {
return nil, nil
}
func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
log := zerolog.Nop()
conn := &mockQuicConnection{}
ctrl := gomock.NewController(t)
flowLimiterMock := mocks.NewMockLimiter(ctrl)
datagramConn := NewDatagramV2Connection(
context.Background(),
conn,
nil,
0,
0*time.Second,
0*time.Second,
flowLimiterMock,
&log,
)
flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
flowLimiterMock.EXPECT().Release().Times(0)
_, err := datagramConn.RegisterUdpSession(context.Background(), uuid.New(), net.IPv4(0, 0, 0, 0), 1000, 1*time.Second, "")
require.ErrorIs(t, err, cfdflow.ErrTooManyActiveFlows)
}

View File

@ -9,6 +9,7 @@ import (
const (
logFieldOriginCertPath = "originCertPath"
FedEndpoint = "fed"
)
type User struct {
@ -32,6 +33,10 @@ func (c User) CertPath() string {
return c.certPath
}
func (c User) IsFEDEndpoint() bool {
return c.cert.Endpoint == FedEndpoint
}
// Client uses the user credentials to create a Cloudflare API client
func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfapi.Client, error) {
if apiURL == "" {
@ -45,7 +50,6 @@ func (c *User) Client(apiURL string, userAgent string, log *zerolog.Logger) (cfa
userAgent,
log,
)
if err != nil {
return nil, err
}

View File

@ -3,7 +3,7 @@ package credentials
import (
"io/fs"
"os"
"path"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
@ -13,8 +13,8 @@ func TestCredentialsRead(t *testing.T) {
file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem")
require.NoError(t, err)
dir := t.TempDir()
certPath := path.Join(dir, originCertFile)
os.WriteFile(certPath, file, fs.ModePerm)
certPath := filepath.Join(dir, originCertFile)
_ = os.WriteFile(certPath, file, fs.ModePerm)
user, err := Read(certPath, &nopLog)
require.NoError(t, err)
require.Equal(t, certPath, user.CertPath())

View File

@ -1,11 +1,13 @@
package credentials
import (
"bytes"
"encoding/json"
"encoding/pem"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/mitchellh/go-homedir"
"github.com/rs/zerolog"
@ -15,19 +17,30 @@ import (
const (
DefaultCredentialFile = "cert.pem"
OriginCertFlag = "origincert"
)
type namedTunnelToken struct {
type OriginCert struct {
ZoneID string `json:"zoneID"`
AccountID string `json:"accountID"`
APIToken string `json:"apiToken"`
Endpoint string `json:"endpoint,omitempty"`
}
type OriginCert struct {
ZoneID string
APIToken string
AccountID string
func (oc *OriginCert) UnmarshalJSON(data []byte) error {
var aux struct {
ZoneID string `json:"zoneID"`
AccountID string `json:"accountID"`
APIToken string `json:"apiToken"`
Endpoint string `json:"endpoint,omitempty"`
}
if err := json.Unmarshal(data, &aux); err != nil {
return fmt.Errorf("error parsing OriginCert: %v", err)
}
oc.ZoneID = aux.ZoneID
oc.AccountID = aux.AccountID
oc.APIToken = aux.APIToken
oc.Endpoint = strings.ToLower(aux.Endpoint)
return nil
}
// FindDefaultOriginCertPath returns the first path that contains a cert.pem file. If none of the
@ -42,40 +55,56 @@ func FindDefaultOriginCertPath() string {
return ""
}
func DecodeOriginCert(blocks []byte) (*OriginCert, error) {
return decodeOriginCert(blocks)
}
func (cert *OriginCert) EncodeOriginCert() ([]byte, error) {
if cert == nil {
return nil, fmt.Errorf("originCert cannot be nil")
}
buffer, err := json.Marshal(cert)
if err != nil {
return nil, fmt.Errorf("originCert marshal failed: %v", err)
}
block := pem.Block{
Type: "ARGO TUNNEL TOKEN",
Headers: map[string]string{},
Bytes: buffer,
}
var out bytes.Buffer
err = pem.Encode(&out, &block)
if err != nil {
return nil, fmt.Errorf("pem encoding failed: %v", err)
}
return out.Bytes(), nil
}
func decodeOriginCert(blocks []byte) (*OriginCert, error) {
if len(blocks) == 0 {
return nil, fmt.Errorf("Cannot decode empty certificate")
return nil, fmt.Errorf("cannot decode empty certificate")
}
originCert := OriginCert{}
block, rest := pem.Decode(blocks)
for {
if block == nil {
break
}
for block != nil {
switch block.Type {
case "PRIVATE KEY", "CERTIFICATE":
// this is for legacy purposes.
break
case "ARGO TUNNEL TOKEN":
if originCert.ZoneID != "" || originCert.APIToken != "" {
return nil, fmt.Errorf("Found multiple tokens in the certificate")
return nil, fmt.Errorf("found multiple tokens in the certificate")
}
// The token is a string,
// Try the newer JSON format
ntt := namedTunnelToken{}
if err := json.Unmarshal(block.Bytes, &ntt); err == nil {
originCert.ZoneID = ntt.ZoneID
originCert.APIToken = ntt.APIToken
originCert.AccountID = ntt.AccountID
}
_ = json.Unmarshal(block.Bytes, &originCert)
default:
return nil, fmt.Errorf("Unknown block %s in the certificate", block.Type)
return nil, fmt.Errorf("unknown block %s in the certificate", block.Type)
}
block, rest = pem.Decode(rest)
}
if originCert.ZoneID == "" || originCert.APIToken == "" {
return nil, fmt.Errorf("Missing token in the certificate")
return nil, fmt.Errorf("missing token in the certificate")
}
return &originCert, nil

View File

@ -4,7 +4,7 @@ import (
"fmt"
"io/fs"
"os"
"path"
"path/filepath"
"testing"
"github.com/rs/zerolog"
@ -16,27 +16,25 @@ const (
originCertFile = "cert.pem"
)
var (
nopLog = zerolog.Nop().With().Logger()
)
var nopLog = zerolog.Nop().With().Logger()
func TestLoadOriginCert(t *testing.T) {
cert, err := decodeOriginCert([]byte{})
assert.Equal(t, fmt.Errorf("Cannot decode empty certificate"), err)
assert.Equal(t, fmt.Errorf("cannot decode empty certificate"), err)
assert.Nil(t, cert)
blocks, err := os.ReadFile("test-cert-unknown-block.pem")
assert.NoError(t, err)
require.NoError(t, err)
cert, err = decodeOriginCert(blocks)
assert.Equal(t, fmt.Errorf("Unknown block RSA PRIVATE KEY in the certificate"), err)
assert.Equal(t, fmt.Errorf("unknown block RSA PRIVATE KEY in the certificate"), err)
assert.Nil(t, cert)
}
func TestJSONArgoTunnelTokenEmpty(t *testing.T) {
blocks, err := os.ReadFile("test-cert-no-token.pem")
assert.NoError(t, err)
require.NoError(t, err)
cert, err := decodeOriginCert(blocks)
assert.Equal(t, fmt.Errorf("Missing token in the certificate"), err)
assert.Equal(t, fmt.Errorf("missing token in the certificate"), err)
assert.Nil(t, cert)
}
@ -52,51 +50,21 @@ func TestJSONArgoTunnelToken(t *testing.T) {
func CloudflareTunnelTokenTest(t *testing.T, path string) {
blocks, err := os.ReadFile(path)
assert.NoError(t, err)
require.NoError(t, err)
cert, err := decodeOriginCert(blocks)
assert.NoError(t, err)
require.NoError(t, err)
assert.NotNil(t, cert)
assert.Equal(t, "7b0a4d77dfb881c1a3b7d61ea9443e19", cert.ZoneID)
key := "test-service-key"
assert.Equal(t, key, cert.APIToken)
}
type mockFile struct {
path string
data []byte
err error
}
type mockFileSystem struct {
files map[string]mockFile
}
func newMockFileSystem(files ...mockFile) *mockFileSystem {
fs := mockFileSystem{map[string]mockFile{}}
for _, f := range files {
fs.files[f.path] = f
}
return &fs
}
func (fs *mockFileSystem) ReadFile(path string) ([]byte, error) {
if f, ok := fs.files[path]; ok {
return f.data, f.err
}
return nil, os.ErrNotExist
}
func (fs *mockFileSystem) ValidFilePath(path string) bool {
_, exists := fs.files[path]
return exists
}
func TestFindOriginCert_Valid(t *testing.T) {
file, err := os.ReadFile("test-cloudflare-tunnel-cert-json.pem")
require.NoError(t, err)
dir := t.TempDir()
certPath := path.Join(dir, originCertFile)
os.WriteFile(certPath, file, fs.ModePerm)
certPath := filepath.Join(dir, originCertFile)
_ = os.WriteFile(certPath, file, fs.ModePerm)
path, err := FindOriginCert(certPath, &nopLog)
require.NoError(t, err)
require.Equal(t, certPath, path)
@ -104,7 +72,32 @@ func TestFindOriginCert_Valid(t *testing.T) {
func TestFindOriginCert_Missing(t *testing.T) {
dir := t.TempDir()
certPath := path.Join(dir, originCertFile)
certPath := filepath.Join(dir, originCertFile)
_, err := FindOriginCert(certPath, &nopLog)
require.Error(t, err)
}
func TestEncodeDecodeOriginCert(t *testing.T) {
cert := OriginCert{
ZoneID: "zone",
AccountID: "account",
APIToken: "token",
Endpoint: "FED",
}
blocks, err := cert.EncodeOriginCert()
require.NoError(t, err)
decodedCert, err := DecodeOriginCert(blocks)
require.NoError(t, err)
assert.NotNil(t, cert)
assert.Equal(t, "zone", decodedCert.ZoneID)
assert.Equal(t, "account", decodedCert.AccountID)
assert.Equal(t, "token", decodedCert.APIToken)
assert.Equal(t, FedEndpoint, decodedCert.Endpoint)
}
func TestEncodeDecodeNilOriginCert(t *testing.T) {
var cert *OriginCert
blocks, err := cert.EncodeOriginCert()
assert.Equal(t, fmt.Errorf("originCert cannot be nil"), err)
require.Nil(t, blocks)
}

View File

@ -87,3 +87,4 @@ M2i4QoOFcSKIG+v4SuvgEJHgG8vGvxh2qlSxnMWuPV+7/1P5ATLqDj1PlKms+BNR
y7sc5AT9PclkL3Y9MNzOu0LXyBkGYcl8M0EQfLv9VPbWT+NXiMg/O2CHiT02pAAz
uQicoQq3yzeQh20wtrtaXzTNmA==
-----END RSA PRIVATE KEY-----

View File

@ -1,4 +1,4 @@
FROM golang:1.22.5 as builder
FROM golang:1.22.10 as builder
ENV GO111MODULE=on \
CGO_ENABLED=0
WORKDIR /go/src/github.com/cloudflare/cloudflared/

View File

@ -9,7 +9,7 @@ import (
"net/url"
"strconv"
"github.com/cloudflare/cloudflared/logger"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
)
type httpClient struct {
@ -86,12 +86,12 @@ func (client *httpClient) GetLogConfiguration(ctx context.Context) (*LogConfigur
return nil, fmt.Errorf("error convertin pid to int: %w", err)
}
logFile, exists := data[logger.LogFileFlag]
logFile, exists := data[cfdflags.LogFile]
if exists {
return &LogConfiguration{logFile, "", uid}, nil
}
logDirectory, exists := data[logger.LogDirectoryFlag]
logDirectory, exists := data[cfdflags.LogDirectory]
if exists {
return &LogConfiguration{"", logDirectory, uid}, nil
}

View File

@ -11,28 +11,40 @@ const (
FeatureDatagramV3 = "support_datagram_v3"
)
var (
DefaultFeatures = []string{
var defaultFeatures = []string{
FeatureAllowRemoteConfig,
FeatureSerializedHeaders,
FeatureDatagramV2,
FeatureQUICSupportEOF,
FeatureManagementLogs,
}
}
// Features set by user provided flags
type staticFeatures struct {
PostQuantumMode *PostQuantumMode
}
type PostQuantumMode uint8
const (
// Prefer post quantum, but fallback if connection cannot be established
PostQuantumPrefer PostQuantumMode = iota
// If the user passes the --post-quantum flag, we override
// CurvePreferences to only support hybrid post-quantum key agreements.
PostQuantumStrict
)
func Contains(feature string) bool {
for _, f := range DefaultFeatures {
if f == feature {
return true
}
}
return false
}
type DatagramVersion string
const (
// DatagramV2 is the currently supported datagram protocol for UDP and ICMP packets
DatagramV2 DatagramVersion = FeatureDatagramV2
// DatagramV3 is a new datagram protocol for UDP and ICMP packets. It is not backwards compatible with datagram v2.
DatagramV3 DatagramVersion = FeatureDatagramV3
)
// Remove any duplicates from the slice
func Dedup(slice []string) []string {
// Convert the slice into a set
set := make(map[string]bool, 0)
for _, str := range slice {

View File

@ -6,6 +6,7 @@ import (
"fmt"
"hash/fnv"
"net"
"slices"
"sync"
"time"
@ -18,61 +19,67 @@ const (
lookupTimeout = time.Second * 10
)
type PostQuantumMode uint8
const (
// Prefer post quantum, but fallback if connection cannot be established
PostQuantumPrefer PostQuantumMode = iota
// If the user passes the --post-quantum flag, we override
// CurvePreferences to only support hybrid post-quantum key agreements.
PostQuantumStrict
)
// If the TXT record adds other fields, the umarshal logic will ignore those keys
// If the TXT record is missing a key, the field will unmarshal to the default Go value
// pq was removed in TUN-7970
type featuresRecord struct{}
func NewFeatureSelector(ctx context.Context, accountTag string, staticFeatures StaticFeatures, logger *zerolog.Logger) (*FeatureSelector, error) {
return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), staticFeatures, defaultRefreshFreq)
type featuresRecord struct {
// support_datagram_v3
DatagramV3Percentage int32 `json:"dv3"`
// PostQuantumPercentage int32 `json:"pq"` // Removed in TUN-7970
}
// FeatureSelector determines if this account will try new features. It preiodically queries a DNS TXT record
// to see which features are turned on
func NewFeatureSelector(ctx context.Context, accountTag string, cliFeatures []string, pq bool, logger *zerolog.Logger) (*FeatureSelector, error) {
return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq, defaultRefreshFreq)
}
// FeatureSelector determines if this account will try new features. It periodically queries a DNS TXT record
// to see which features are turned on.
type FeatureSelector struct {
accountHash int32
logger *zerolog.Logger
resolver resolver
staticFeatures StaticFeatures
staticFeatures staticFeatures
cliFeatures []string
// lock protects concurrent access to dynamic features
lock sync.RWMutex
features featuresRecord
}
// Features set by user provided flags
type StaticFeatures struct {
PostQuantumMode *PostQuantumMode
}
func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, staticFeatures StaticFeatures, refreshFreq time.Duration) (*FeatureSelector, error) {
func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool, refreshFreq time.Duration) (*FeatureSelector, error) {
// Combine default features and user-provided features
var pqMode *PostQuantumMode
if pq {
mode := PostQuantumStrict
pqMode = &mode
cliFeatures = append(cliFeatures, FeaturePostQuantum)
}
staticFeatures := staticFeatures{
PostQuantumMode: pqMode,
}
selector := &FeatureSelector{
accountHash: switchThreshold(accountTag),
logger: logger,
resolver: resolver,
staticFeatures: staticFeatures,
cliFeatures: Dedup(cliFeatures),
}
if err := selector.refresh(ctx); err != nil {
logger.Err(err).Msg("Failed to fetch features, default to disable")
}
// Run refreshLoop next time we have a new feature to rollout
go selector.refreshLoop(ctx, refreshFreq)
return selector, nil
}
func (fs *FeatureSelector) accountEnabled(percentage int32) bool {
return percentage > fs.accountHash
}
func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode {
if fs.staticFeatures.PostQuantumMode != nil {
return *fs.staticFeatures.PostQuantumMode
@ -81,6 +88,33 @@ func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode {
return PostQuantumPrefer
}
func (fs *FeatureSelector) DatagramVersion() DatagramVersion {
fs.lock.RLock()
defer fs.lock.RUnlock()
// If user provides the feature via the cli, we take it as priority over remote feature evaluation
if slices.Contains(fs.cliFeatures, FeatureDatagramV3) {
return DatagramV3
}
// If the user specifies DatagramV2, we also take that over remote
if slices.Contains(fs.cliFeatures, FeatureDatagramV2) {
return DatagramV2
}
if fs.accountEnabled(fs.features.DatagramV3Percentage) {
return DatagramV3
}
return DatagramV2
}
// ClientFeatures will return the list of currently available features that cloudflared should provide to the edge.
//
// This list is dynamic and can change in-between returns.
func (fs *FeatureSelector) ClientFeatures() []string {
// Evaluate any remote features along with static feature list to construct the list of features
return Dedup(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.DatagramVersion())}))
}
func (fs *FeatureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) {
ticker := time.NewTicker(refreshFreq)
for {

View File

@ -14,15 +14,19 @@ import (
func TestUnmarshalFeaturesRecord(t *testing.T) {
tests := []struct {
record []byte
expectedPercentage int32
}{
{
record: []byte(`{"pq":0}`),
record: []byte(`{"dv3":0}`),
expectedPercentage: 0,
},
{
record: []byte(`{"pq":39}`),
record: []byte(`{"dv3":39}`),
expectedPercentage: 39,
},
{
record: []byte(`{"pq":100}`),
record: []byte(`{"dv3":100}`),
expectedPercentage: 100,
},
{
record: []byte(`{}`), // Unmarshal to default struct if key is not present
@ -36,37 +40,186 @@ func TestUnmarshalFeaturesRecord(t *testing.T) {
var features featuresRecord
err := json.Unmarshal(test.record, &features)
require.NoError(t, err)
require.Equal(t, featuresRecord{}, features)
require.Equal(t, test.expectedPercentage, features.DatagramV3Percentage, test)
}
}
func TestFeaturePrecedenceEvaluationPostQuantum(t *testing.T) {
logger := zerolog.Nop()
tests := []struct {
name string
cli bool
expectedFeatures []string
expectedVersion PostQuantumMode
}{
{
name: "default",
cli: false,
expectedFeatures: defaultFeatures,
expectedVersion: PostQuantumPrefer,
},
{
name: "user_specified",
cli: true,
expectedFeatures: Dedup(append(defaultFeatures, FeaturePostQuantum)),
expectedVersion: PostQuantumStrict,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
resolver := &staticResolver{record: featuresRecord{}}
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, []string{}, test.cli, time.Second)
require.NoError(t, err)
require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
require.Equal(t, test.expectedVersion, selector.PostQuantumMode())
})
}
}
func TestFeaturePrecedenceEvaluationDatagramVersion(t *testing.T) {
logger := zerolog.Nop()
tests := []struct {
name string
cli []string
remote featuresRecord
expectedFeatures []string
expectedVersion DatagramVersion
}{
{
name: "default",
cli: []string{},
remote: featuresRecord{},
expectedFeatures: defaultFeatures,
expectedVersion: DatagramV2,
},
{
name: "user_specified_v2",
cli: []string{FeatureDatagramV2},
remote: featuresRecord{},
expectedFeatures: defaultFeatures,
expectedVersion: DatagramV2,
},
{
name: "user_specified_v3",
cli: []string{FeatureDatagramV3},
remote: featuresRecord{},
expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
expectedVersion: FeatureDatagramV3,
},
{
name: "remote_specified_v3",
cli: []string{},
remote: featuresRecord{
DatagramV3Percentage: 100,
},
expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
expectedVersion: FeatureDatagramV3,
},
{
name: "remote_and_user_specified_v3",
cli: []string{FeatureDatagramV3},
remote: featuresRecord{
DatagramV3Percentage: 100,
},
expectedFeatures: Dedup(append(defaultFeatures, FeatureDatagramV3)),
expectedVersion: FeatureDatagramV3,
},
{
name: "remote_v3_and_user_specified_v2",
cli: []string{FeatureDatagramV2},
remote: featuresRecord{
DatagramV3Percentage: 100,
},
expectedFeatures: defaultFeatures,
expectedVersion: DatagramV2,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
resolver := &staticResolver{record: test.remote}
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false, time.Second)
require.NoError(t, err)
require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
require.Equal(t, test.expectedVersion, selector.DatagramVersion())
})
}
}
func TestRefreshFeaturesRecord(t *testing.T) {
// The hash of the accountTag is 82
accountTag := t.Name()
threshold := switchThreshold(accountTag)
percentages := []int32{0, 10, 81, 82, 83, 100, 101, 1000}
refreshFreq := time.Millisecond * 10
selector := newTestSelector(t, percentages, false, refreshFreq)
// Starting out should default to DatagramV2
require.Equal(t, DatagramV2, selector.DatagramVersion())
for _, percentage := range percentages {
if percentage > threshold {
require.Equal(t, DatagramV3, selector.DatagramVersion())
} else {
require.Equal(t, DatagramV2, selector.DatagramVersion())
}
time.Sleep(refreshFreq + time.Millisecond)
}
// Make sure error doesn't override the last fetched features
require.Equal(t, DatagramV3, selector.DatagramVersion())
}
func TestStaticFeatures(t *testing.T) {
pqMode := PostQuantumStrict
selector := newTestSelector(t, &pqMode, time.Millisecond*10)
percentages := []int32{0}
// PostQuantum Enabled from user flag
selector := newTestSelector(t, percentages, true, time.Millisecond*10)
require.Equal(t, PostQuantumStrict, selector.PostQuantumMode())
// No StaticFeatures configured
selector = newTestSelector(t, nil, time.Millisecond*10)
// PostQuantum Disabled (or not set)
selector = newTestSelector(t, percentages, false, time.Millisecond*10)
require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode())
}
func newTestSelector(t *testing.T, pqMode *PostQuantumMode, refreshFreq time.Duration) *FeatureSelector {
func newTestSelector(t *testing.T, percentages []int32, pq bool, refreshFreq time.Duration) *FeatureSelector {
accountTag := t.Name()
logger := zerolog.Nop()
resolver := &mockResolver{}
staticFeatures := StaticFeatures{
PostQuantumMode: pqMode,
resolver := &mockResolver{
percentages: percentages,
}
selector, err := newFeatureSelector(context.Background(), accountTag, &logger, resolver, staticFeatures, refreshFreq)
selector, err := newFeatureSelector(context.Background(), accountTag, &logger, resolver, []string{}, pq, refreshFreq)
require.NoError(t, err)
return selector
}
type mockResolver struct{}
type mockResolver struct {
nextIndex int
percentages []int32
}
func (mr *mockResolver) lookupRecord(ctx context.Context) ([]byte, error) {
return nil, fmt.Errorf("mockResolver hasn't implement lookupRecord")
if mr.nextIndex >= len(mr.percentages) {
return nil, fmt.Errorf("no more record to lookup")
}
record, err := json.Marshal(featuresRecord{
DatagramV3Percentage: mr.percentages[mr.nextIndex],
})
mr.nextIndex++
return record, err
}
type staticResolver struct {
record featuresRecord
}
func (r *staticResolver) lookupRecord(ctx context.Context) ([]byte, error) {
return json.Marshal(r.record)
}

11
fips/fips.go Normal file
View File

@ -0,0 +1,11 @@
//go:build fips
package fips
import (
_ "crypto/tls/fipsonly"
)
func IsFipsEnabled() bool {
return true
}

View File

@ -1,12 +0,0 @@
// +build fips
package main
import (
_ "crypto/tls/fipsonly"
"github.com/cloudflare/cloudflared/cmd/cloudflared/tunnel"
)
func init () {
tunnel.FipsEnabled = true
}

7
fips/nofips.go Normal file
View File

@ -0,0 +1,7 @@
//go:build !fips
package fips
func IsFipsEnabled() bool {
return false
}

77
flow/limiter.go Normal file
View File

@ -0,0 +1,77 @@
package flow
import (
"errors"
"sync"
)
const (
unlimitedActiveFlows = 0
)
var (
ErrTooManyActiveFlows = errors.New("too many active flows")
)
type Limiter interface {
// Acquire tries to acquire a free slot for a flow, if the value of flows is already above
// the maximum it returns ErrTooManyActiveFlows.
Acquire(flowType string) error
// Release releases a slot for a flow.
Release()
// SetLimit allows to hot swap the limit value of the limiter.
SetLimit(uint64)
}
type flowLimiter struct {
limiterLock sync.Mutex
activeFlowsCounter uint64
maxActiveFlows uint64
unlimited bool
}
func NewLimiter(maxActiveFlows uint64) Limiter {
flowLimiter := &flowLimiter{
maxActiveFlows: maxActiveFlows,
unlimited: isUnlimited(maxActiveFlows),
}
return flowLimiter
}
func (s *flowLimiter) Acquire(flowType string) error {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
if !s.unlimited && s.activeFlowsCounter >= s.maxActiveFlows {
flowRegistrationsDropped.WithLabelValues(flowType).Inc()
return ErrTooManyActiveFlows
}
s.activeFlowsCounter++
return nil
}
func (s *flowLimiter) Release() {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
if s.activeFlowsCounter <= 0 {
return
}
s.activeFlowsCounter--
}
func (s *flowLimiter) SetLimit(newMaxActiveFlows uint64) {
s.limiterLock.Lock()
defer s.limiterLock.Unlock()
s.maxActiveFlows = newMaxActiveFlows
s.unlimited = isUnlimited(newMaxActiveFlows)
}
// isUnlimited checks if the value received matches the configuration for the unlimited flow limiter.
func isUnlimited(value uint64) bool {
return value == unlimitedActiveFlows
}

119
flow/limiter_test.go Normal file
View File

@ -0,0 +1,119 @@
package flow_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/flow"
)
func TestFlowLimiter_Unlimited(t *testing.T) {
unlimitedLimiter := flow.NewLimiter(0)
for i := 0; i < 1000; i++ {
err := unlimitedLimiter.Acquire("test")
require.NoError(t, err)
}
}
func TestFlowLimiter_Limited(t *testing.T) {
maxFlows := uint64(5)
limiter := flow.NewLimiter(maxFlows)
for i := uint64(0); i < maxFlows; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
err := limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
}
func TestFlowLimiter_AcquireAndReleaseFlow(t *testing.T) {
maxFlows := uint64(5)
limiter := flow.NewLimiter(maxFlows)
// Acquire the maximum number of flows
for i := uint64(0); i < maxFlows; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
// Validate acquire 1 more flows fails
err := limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
// Release the maximum number of flows
for i := uint64(0); i < maxFlows; i++ {
limiter.Release()
}
// Validate acquire 1 more flows works
err = limiter.Acquire("shouldn't fail")
require.NoError(t, err)
// Release a 10x the number of max flows
for i := uint64(0); i < 10*maxFlows; i++ {
limiter.Release()
}
// Validate it still can only acquire a value = number max flows.
for i := uint64(0); i < maxFlows; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
err = limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
}
func TestFlowLimiter_SetLimit(t *testing.T) {
maxFlows := uint64(5)
limiter := flow.NewLimiter(maxFlows)
// Acquire the maximum number of flows
for i := uint64(0); i < maxFlows; i++ {
err := limiter.Acquire("test")
require.NoError(t, err)
}
// Validate acquire 1 more flows fails
err := limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
// Set the flow limiter to support one more request
limiter.SetLimit(maxFlows + 1)
// Validate acquire 1 more flows now works
err = limiter.Acquire("shouldn't fail")
require.NoError(t, err)
// Validate acquire 1 more flows doesn't work because we already reached the limit
err = limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
// Release all flows
for i := uint64(0); i < maxFlows+1; i++ {
limiter.Release()
}
// Validate 1 flow works again
err = limiter.Acquire("shouldn't fail")
require.NoError(t, err)
// Set the flow limit to 1
limiter.SetLimit(1)
// Validate acquire 1 more flows doesn't work
err = limiter.Acquire("should fail")
require.ErrorIs(t, err, flow.ErrTooManyActiveFlows)
// Set the flow limit to unlimited
limiter.SetLimit(0)
// Validate it can acquire a lot of flows because it is now unlimited.
for i := uint64(0); i < 10*maxFlows; i++ {
err := limiter.Acquire("shouldn't fail")
require.NoError(t, err)
}
}

23
flow/metrics.go Normal file
View File

@ -0,0 +1,23 @@
package flow
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
const (
namespace = "flow"
)
var (
labels = []string{"flow_type"}
flowRegistrationsDropped = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "registrations_rate_limited_total",
Help: "Count registrations dropped due to high number of concurrent flows being handled",
},
labels,
)
)

21
go.mod
View File

@ -35,11 +35,12 @@ require (
go.opentelemetry.io/otel/trace v1.26.0
go.opentelemetry.io/proto/otlp v1.2.0
go.uber.org/automaxprocs v1.4.0
golang.org/x/crypto v0.23.0
golang.org/x/net v0.25.0
golang.org/x/sync v0.7.0
golang.org/x/sys v0.20.0
golang.org/x/term v0.20.0
go.uber.org/mock v0.5.0
golang.org/x/crypto v0.31.0
golang.org/x/net v0.26.0
golang.org/x/sync v0.10.0
golang.org/x/sys v0.28.0
golang.org/x/term v0.27.0
google.golang.org/protobuf v1.34.1
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1
@ -83,12 +84,11 @@ require (
github.com/prometheus/procfs v0.12.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
go.opentelemetry.io/otel/metric v1.26.0 // indirect
go.uber.org/mock v0.4.0 // indirect
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/mod v0.18.0 // indirect
golang.org/x/oauth2 v0.18.0 // indirect
golang.org/x/text v0.15.0 // indirect
golang.org/x/tools v0.21.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/tools v0.22.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240311132316-a219d84964c2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
@ -102,3 +102,6 @@ replace github.com/urfave/cli/v2 => github.com/ipostelnik/cli/v2 v2.3.1-0.202103
replace github.com/prometheus/golang_client => github.com/prometheus/golang_client v1.12.1
replace gopkg.in/yaml.v3 => gopkg.in/yaml.v3 v3.0.1
// This fork is based on quic-go v0.45
replace github.com/quic-go/quic-go => github.com/chungthuang/quic-go v0.45.1-0.20250128102735-2687bd175910

40
go.sum
View File

@ -7,6 +7,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chungthuang/quic-go v0.45.1-0.20250128102735-2687bd175910 h1:/hTvBpxBDj/3NIzTodi1oEOyNBpirvgDSPKSV7VqAZU=
github.com/chungthuang/quic-go v0.45.1-0.20250128102735-2687bd175910/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
github.com/coredns/caddy v1.1.1 h1:2eYKZT7i6yxIfGP3qLJoJ7HAsDJqYB+X68g4NYjSrE0=
github.com/coredns/caddy v1.1.1/go.mod h1:A6ntJQlAWuQfFlsd9hvigKbo2WS0VUs2l1e2F+BawD4=
github.com/coredns/coredns v1.11.3 h1:8RjnpZc42db5th84/QJKH2i137ecJdzZK1HJwhetSPk=
@ -173,8 +175,6 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
github.com/quic-go/quic-go v0.45.0 h1:OHmkQGM37luZITyTSu6ff03HP/2IrwDX1ZFiNEhSFUE=
github.com/quic-go/quic-go v0.45.0/go.mod h1:1dLehS7TIR64+vxGR70GDcatWTOtMX2PUtnKsjbTurI=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
@ -217,33 +217,33 @@ go.opentelemetry.io/proto/otlp v1.2.0 h1:pVeZGk7nXDC9O2hncA6nHldxEjm6LByfA2aN8IO
go.opentelemetry.io/proto/otlp v1.2.0/go.mod h1:gGpR8txAl5M03pDhMC79G6SdqNV26naRm/KDsgaHD8A=
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=
golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI=
golang.org/x/oauth2 v0.18.0/go.mod h1:Wf7knwG0MPoWIMMBgFlEaSUDaKskp0dCfrlJRJXbBi8=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -254,19 +254,19 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
@ -275,8 +275,8 @@ golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.21.0 h1:qc0xYgIbsSDt9EyWz05J5wfa7LOVW0YTLOXrqdLAWIw=
golang.org/x/tools v0.21.0/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA=
golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -22,6 +22,7 @@ var (
const (
defaultProxyAddress = "127.0.0.1"
defaultKeepAliveConnections = 100
defaultMaxActiveFlows = 0 // unlimited
SSHServerFlag = "ssh-server"
Socks5Flag = "socks5"
ProxyConnectTimeoutFlag = "proxy-connect-timeout"
@ -46,17 +47,22 @@ const (
type WarpRoutingConfig struct {
ConnectTimeout config.CustomDuration `yaml:"connectTimeout" json:"connectTimeout,omitempty"`
MaxActiveFlows uint64 `yaml:"maxActiveFlows" json:"MaxActiveFlows,omitempty"`
TCPKeepAlive config.CustomDuration `yaml:"tcpKeepAlive" json:"tcpKeepAlive,omitempty"`
}
func NewWarpRoutingConfig(raw *config.WarpRoutingConfig) WarpRoutingConfig {
cfg := WarpRoutingConfig{
ConnectTimeout: defaultWarpRoutingConnectTimeout,
MaxActiveFlows: defaultMaxActiveFlows,
TCPKeepAlive: defaultTCPKeepAlive,
}
if raw.ConnectTimeout != nil {
cfg.ConnectTimeout = *raw.ConnectTimeout
}
if raw.MaxActiveFlows != nil {
cfg.MaxActiveFlows = *raw.MaxActiveFlows
}
if raw.TCPKeepAlive != nil {
cfg.TCPKeepAlive = *raw.TCPKeepAlive
}
@ -68,6 +74,9 @@ func (c *WarpRoutingConfig) RawConfig() config.WarpRoutingConfig {
if c.ConnectTimeout.Duration != defaultWarpRoutingConnectTimeout.Duration {
raw.ConnectTimeout = &c.ConnectTimeout
}
if c.MaxActiveFlows != defaultMaxActiveFlows {
raw.MaxActiveFlows = &c.MaxActiveFlows
}
if c.TCPKeepAlive.Duration != defaultTCPKeepAlive.Duration {
raw.TCPKeepAlive = &c.TCPKeepAlive
}
@ -172,6 +181,7 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
}
if flag := ProxyPortFlag; c.IsSet(flag) {
// Note TUN-3758 , we use Int because UInt is not supported with altsrc
// nolint: gosec
proxyPort = uint(c.Int(flag))
}
if flag := Http2OriginFlag; c.IsSet(flag) {
@ -551,7 +561,7 @@ func convertToRawIPRules(ipRules []ipaccess.Rule) []config.IngressIPRule {
}
func defaultBoolToNil(b bool) *bool {
if b == false {
if !b {
return nil
}

View File

@ -4,7 +4,6 @@ import (
"fmt"
"io"
"os"
"path"
"path/filepath"
"sync"
"time"
@ -16,7 +15,7 @@ import (
"golang.org/x/term"
"gopkg.in/natefinch/lumberjack.v2"
"github.com/cloudflare/cloudflared/features"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/management"
)
@ -24,14 +23,6 @@ const (
EnableTerminalLog = false
DisableTerminalLog = true
LogLevelFlag = "loglevel"
LogFileFlag = "logfile"
LogDirectoryFlag = "log-directory"
LogTransportLevelFlag = "transport-loglevel"
LogSSHDirectoryFlag = "log-directory"
LogSSHLevelFlag = "log-level"
dirPermMode = 0744 // rwxr--r--
filePermMode = 0644 // rw-r--r--
@ -46,11 +37,7 @@ func init() {
zerolog.TimeFieldFormat = time.RFC3339
zerolog.TimestampFunc = utcNow
if features.Contains(features.FeatureManagementLogs) {
// Management logger needs to be initialized before any of the other loggers as to not capture
// it's own logging events.
ManagementLogger = management.NewLogger()
}
}
func utcNow() time.Time {
@ -124,10 +111,7 @@ func newZerolog(loggerConfig *Config) *zerolog.Logger {
writers = append(writers, rollingLogger)
}
var managementWriter zerolog.LevelWriter
if features.Contains(features.FeatureManagementLogs) {
managementWriter = ManagementLogger
}
managementWriter := ManagementLogger
level, levelErr := zerolog.ParseLevel(loggerConfig.MinLevel)
if levelErr != nil {
@ -145,15 +129,15 @@ func newZerolog(loggerConfig *Config) *zerolog.Logger {
}
func CreateTransportLoggerFromContext(c *cli.Context, disableTerminal bool) *zerolog.Logger {
return createFromContext(c, LogTransportLevelFlag, LogDirectoryFlag, disableTerminal)
return createFromContext(c, cfdflags.TransportLogLevel, cfdflags.LogDirectory, disableTerminal)
}
func CreateLoggerFromContext(c *cli.Context, disableTerminal bool) *zerolog.Logger {
return createFromContext(c, LogLevelFlag, LogDirectoryFlag, disableTerminal)
return createFromContext(c, cfdflags.LogLevel, cfdflags.LogDirectory, disableTerminal)
}
func CreateSSHLoggerFromContext(c *cli.Context, disableTerminal bool) *zerolog.Logger {
return createFromContext(c, LogSSHLevelFlag, LogSSHDirectoryFlag, disableTerminal)
return createFromContext(c, cfdflags.LogLevelSSH, cfdflags.LogDirectory, disableTerminal)
}
func createFromContext(
@ -163,7 +147,7 @@ func createFromContext(
disableTerminal bool,
) *zerolog.Logger {
logLevel := c.String(logLevelFlagName)
logFile := c.String(LogFileFlag)
logFile := c.String(cfdflags.LogFile)
logDirectory := c.String(logDirectoryFlagName)
loggerConfig := CreateConfig(
@ -175,7 +159,7 @@ func createFromContext(
log := newZerolog(loggerConfig)
if incompatibleFlagsSet := logFile != "" && logDirectory != ""; incompatibleFlagsSet {
log.Error().Msgf("Your config includes values for both %s (%s) and %s (%s), but they are incompatible. %s takes precedence.", LogFileFlag, logFile, logDirectoryFlagName, logDirectory, LogFileFlag)
log.Error().Msgf("Your config includes values for both %s (%s) and %s (%s), but they are incompatible. %s takes precedence.", cfdflags.LogFile, logFile, logDirectoryFlagName, logDirectory, cfdflags.LogFile)
}
return log
}
@ -214,7 +198,6 @@ var (
func createFileWriter(config FileConfig) (io.Writer, error) {
singleFileInit.once.Do(func() {
var logFile io.Writer
fullpath := config.Fullpath()
@ -265,7 +248,7 @@ func createRollingLogger(config RollingConfig) (io.Writer, error) {
}
rotatingFileInit.writer = &lumberjack.Logger{
Filename: path.Join(config.Dirname, config.Filename),
Filename: filepath.Join(config.Dirname, config.Filename),
MaxBackups: config.maxBackups,
MaxSize: config.maxSize,
MaxAge: config.maxAge,

View File

@ -74,7 +74,7 @@ type EventLog struct {
type LogEventType int8
const (
// Cloudflared events are signficant to cloudflared operations like connection state changes.
// Cloudflared events are significant to cloudflared operations like connection state changes.
// Cloudflared is also the default event type for any events that haven't been separated into a proper event type.
Cloudflared LogEventType = iota
HTTP
@ -129,7 +129,7 @@ func (e *LogEventType) UnmarshalJSON(data []byte) error {
// LogLevel corresponds to the zerolog logging levels
// "panic", "fatal", and "trace" are exempt from this list as they are rarely used and, at least
// the the first two are limited to failure conditions that lead to cloudflared shutting down.
// the first two are limited to failure conditions that lead to cloudflared shutting down.
type LogLevel int8
const (

150
mocks/mock_limiter.go Normal file
View File

@ -0,0 +1,150 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ../flow/limiter.go
//
// Generated by this command:
//
// mockgen -typed -build_flags=-tags=gomock -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter
//
// Package mocks is a generated GoMock package.
package mocks
import (
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockLimiter is a mock of Limiter interface.
type MockLimiter struct {
ctrl *gomock.Controller
recorder *MockLimiterMockRecorder
isgomock struct{}
}
// MockLimiterMockRecorder is the mock recorder for MockLimiter.
type MockLimiterMockRecorder struct {
mock *MockLimiter
}
// NewMockLimiter creates a new mock instance.
func NewMockLimiter(ctrl *gomock.Controller) *MockLimiter {
mock := &MockLimiter{ctrl: ctrl}
mock.recorder = &MockLimiterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockLimiter) EXPECT() *MockLimiterMockRecorder {
return m.recorder
}
// Acquire mocks base method.
func (m *MockLimiter) Acquire(flowType string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Acquire", flowType)
ret0, _ := ret[0].(error)
return ret0
}
// Acquire indicates an expected call of Acquire.
func (mr *MockLimiterMockRecorder) Acquire(flowType any) *MockLimiterAcquireCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockLimiter)(nil).Acquire), flowType)
return &MockLimiterAcquireCall{Call: call}
}
// MockLimiterAcquireCall wrap *gomock.Call
type MockLimiterAcquireCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockLimiterAcquireCall) Return(arg0 error) *MockLimiterAcquireCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockLimiterAcquireCall) Do(f func(string) error) *MockLimiterAcquireCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockLimiterAcquireCall) DoAndReturn(f func(string) error) *MockLimiterAcquireCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Release mocks base method.
func (m *MockLimiter) Release() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Release")
}
// Release indicates an expected call of Release.
func (mr *MockLimiterMockRecorder) Release() *MockLimiterReleaseCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockLimiter)(nil).Release))
return &MockLimiterReleaseCall{Call: call}
}
// MockLimiterReleaseCall wrap *gomock.Call
type MockLimiterReleaseCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockLimiterReleaseCall) Return() *MockLimiterReleaseCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockLimiterReleaseCall) Do(f func()) *MockLimiterReleaseCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockLimiterReleaseCall) DoAndReturn(f func()) *MockLimiterReleaseCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// SetLimit mocks base method.
func (m *MockLimiter) SetLimit(arg0 uint64) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetLimit", arg0)
}
// SetLimit indicates an expected call of SetLimit.
func (mr *MockLimiterMockRecorder) SetLimit(arg0 any) *MockLimiterSetLimitCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLimit", reflect.TypeOf((*MockLimiter)(nil).SetLimit), arg0)
return &MockLimiterSetLimitCall{Call: call}
}
// MockLimiterSetLimitCall wrap *gomock.Call
type MockLimiterSetLimitCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockLimiterSetLimitCall) Return() *MockLimiterSetLimitCall {
c.Call = c.Call.Return()
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockLimiterSetLimitCall) Do(f func(uint64)) *MockLimiterSetLimitCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockLimiterSetLimitCall) DoAndReturn(f func(uint64)) *MockLimiterSetLimitCall {
c.Call = c.Call.DoAndReturn(f)
return c
}

5
mocks/mockgen.go Normal file
View File

@ -0,0 +1,5 @@
//go:build gomock || generate
package mocks
//go:generate sh -c "go run go.uber.org/mock/mockgen -typed -build_flags=\"-tags=gomock\" -package mocks -destination mock_limiter.go -source=../flow/limiter.go Limiter"

View File

@ -4,14 +4,17 @@ import (
"context"
"encoding/json"
"fmt"
"strconv"
"sync"
"sync/atomic"
"github.com/pkg/errors"
pkgerrors "github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/proxy"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -33,6 +36,8 @@ type Orchestrator struct {
// cloudflared Configuration
config *Config
tags []pogs.Tag
// flowLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
flowLimiter cfdflow.Limiter
log *zerolog.Logger
// orchestrator must not handle any more updates after shutdownC is closed
@ -54,6 +59,7 @@ func NewOrchestrator(ctx context.Context,
internalRules: internalRules,
config: config,
tags: tags,
flowLimiter: cfdflow.NewLimiter(config.WarpRouting.MaxActiveFlows),
log: log,
shutdownC: ctx.Done(),
}
@ -112,6 +118,30 @@ func (o *Orchestrator) UpdateConfig(version int32, config []byte) *pogs.UpdateCo
}
}
// overrideRemoteWarpRoutingWithLocalValues overrides the ingress.WarpRoutingConfig that comes from the remote with
// the local values if there is any.
func (o *Orchestrator) overrideRemoteWarpRoutingWithLocalValues(remoteWarpRouting *ingress.WarpRoutingConfig) error {
return o.overrideMaxActiveFlows(o.config.ConfigurationFlags[flags.MaxActiveFlows], remoteWarpRouting)
}
// overrideMaxActiveFlows checks the local configuration flags, and if a value is found for the flags.MaxActiveFlows
// overrides the value that comes on the remote ingress.WarpRoutingConfig with the local value.
func (o *Orchestrator) overrideMaxActiveFlows(maxActiveFlowsLocalConfig string, remoteWarpRouting *ingress.WarpRoutingConfig) error {
// If max active flows isn't defined locally just use the remote value
if maxActiveFlowsLocalConfig == "" {
return nil
}
maxActiveFlowsLocalOverride, err := strconv.ParseUint(maxActiveFlowsLocalConfig, 10, 64)
if err != nil {
return pkgerrors.Wrapf(err, "failed to parse %s", flags.MaxActiveFlows)
}
// Override the value that comes from the remote with the local value
remoteWarpRouting.MaxActiveFlows = maxActiveFlowsLocalOverride
return nil
}
// The caller is responsible to make sure there is no concurrent access
func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting ingress.WarpRoutingConfig) error {
select {
@ -120,6 +150,11 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
default:
}
// Overrides the local values, onto the remote values of the warp routing configuration
if err := o.overrideRemoteWarpRoutingWithLocalValues(&warpRouting); err != nil {
return pkgerrors.Wrap(err, "failed to merge local overrides into warp routing configuration")
}
// Assign the internal ingress rules to the parsed ingress
ingressRules.InternalRules = o.internalRules
@ -134,9 +169,13 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
// The downside is minimized because none of the ingress.OriginService implementation have that requirement
proxyShutdownC := make(chan struct{})
if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil {
return errors.Wrap(err, "failed to start origin")
return pkgerrors.Wrap(err, "failed to start origin")
}
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.config.WriteTimeout, o.log)
// Update the flow limit since the configuration might have changed
o.flowLimiter.SetLimit(warpRouting.MaxActiveFlows)
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.flowLimiter, o.config.WriteTimeout, o.log)
o.proxy.Store(proxy)
o.config.Ingress = &ingressRules
o.config.WarpRouting = warpRouting
@ -208,6 +247,12 @@ func (o *Orchestrator) GetOriginProxy() (connection.OriginProxy, error) {
return proxy, nil
}
// GetFlowLimiter returns the flow limiter used across cloudflared, that can be hot reload when
// the configuration changes.
func (o *Orchestrator) GetFlowLimiter() cfdflow.Limiter {
return o.flowLimiter
}
func (o *Orchestrator) waitToCloseLastProxy() {
<-o.shutdownC
o.lock.Lock()

View File

@ -16,8 +16,11 @@ import (
"github.com/google/uuid"
gows "github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
@ -106,25 +109,25 @@ func TestUpdateConfiguration(t *testing.T) {
require.Len(t, configV2.Ingress.Rules, 3)
// originRequest of this ingress rule overrides global default
require.Equal(t, config.CustomDuration{Duration: time.Second * 10}, configV2.Ingress.Rules[0].Config.ConnectTimeout)
require.Equal(t, true, configV2.Ingress.Rules[0].Config.NoTLSVerify)
require.True(t, configV2.Ingress.Rules[0].Config.NoTLSVerify)
// Inherited from global default
require.Equal(t, true, configV2.Ingress.Rules[0].Config.NoHappyEyeballs)
require.True(t, configV2.Ingress.Rules[0].Config.NoHappyEyeballs)
// Validate ingress rule 1
require.Equal(t, "jira.tunnel.org", configV2.Ingress.Rules[1].Hostname)
require.True(t, configV2.Ingress.Rules[1].Matches("jira.tunnel.org", "/users"))
require.Equal(t, "http://172.32.20.6:80", configV2.Ingress.Rules[1].Service.String())
// originRequest of this ingress rule overrides global default
require.Equal(t, config.CustomDuration{Duration: time.Second * 30}, configV2.Ingress.Rules[1].Config.ConnectTimeout)
require.Equal(t, true, configV2.Ingress.Rules[1].Config.NoTLSVerify)
require.True(t, configV2.Ingress.Rules[1].Config.NoTLSVerify)
// Inherited from global default
require.Equal(t, true, configV2.Ingress.Rules[1].Config.NoHappyEyeballs)
require.True(t, configV2.Ingress.Rules[1].Config.NoHappyEyeballs)
// Validate ingress rule 2, it's the catch-all rule
require.True(t, configV2.Ingress.Rules[2].Matches("blogs.tunnel.io", "/2022/02/10"))
// Inherited from global default
require.Equal(t, config.CustomDuration{Duration: time.Second * 90}, configV2.Ingress.Rules[2].Config.ConnectTimeout)
require.Equal(t, false, configV2.Ingress.Rules[2].Config.NoTLSVerify)
require.Equal(t, true, configV2.Ingress.Rules[2].Config.NoHappyEyeballs)
require.Equal(t, configV2.WarpRouting.ConnectTimeout.Duration, 10*time.Second)
require.False(t, configV2.Ingress.Rules[2].Config.NoTLSVerify)
require.True(t, configV2.Ingress.Rules[2].Config.NoHappyEyeballs)
require.Equal(t, 10*time.Second, configV2.WarpRouting.ConnectTimeout.Duration)
originProxyV2, err := orchestrator.GetOriginProxy()
require.NoError(t, err)
@ -317,7 +320,7 @@ func TestConcurrentUpdateAndRead(t *testing.T) {
go func(i int, originProxy connection.OriginProxy) {
defer wg.Done()
resp, err := proxyHTTP(originProxy, hostname)
require.NoError(t, err, "proxyHTTP %d failed %v", i, err)
assert.NoError(t, err, "proxyHTTP %d failed %v", i, err)
defer resp.Body.Close()
var warpRoutingDisabled bool
@ -326,16 +329,16 @@ func TestConcurrentUpdateAndRead(t *testing.T) {
// v1 proxy, warp enabled
case 200:
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, t.Name(), string(body))
assert.NoError(t, err)
assert.Equal(t, t.Name(), string(body))
warpRoutingDisabled = false
// v2 proxy, warp disabled
case 204:
require.Greater(t, i, concurrentRequests/4)
assert.Greater(t, i, concurrentRequests/4)
warpRoutingDisabled = true
// v3 proxy, warp enabled
case 418:
require.Greater(t, i, concurrentRequests/2)
assert.Greater(t, i, concurrentRequests/2)
warpRoutingDisabled = false
}
@ -358,11 +361,10 @@ func TestConcurrentUpdateAndRead(t *testing.T) {
err = proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), w, pr)
if warpRoutingDisabled {
require.Error(t, err, "expect proxyTCP %d to return error", i)
assert.Error(t, err, "expect proxyTCP %d to return error", i)
} else {
require.NoError(t, err, "proxyTCP %d failed %v", i, err)
assert.NoError(t, err, "proxyTCP %d failed %v", i, err)
}
}(i, originProxy)
if i == concurrentRequests/4 {
@ -388,6 +390,57 @@ func TestConcurrentUpdateAndRead(t *testing.T) {
wg.Wait()
}
// TestOverrideWarpRoutingConfigWithLocalValues tests that if a value is defined in the Config.ConfigurationFlags,
// it will override the value that comes from the remote result.
func TestOverrideWarpRoutingConfigWithLocalValues(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
assertMaxActiveFlows := func(orchestrator *Orchestrator, expectedValue uint64) {
configJson, err := orchestrator.GetConfigJSON()
require.NoError(t, err)
var result map[string]interface{}
err = json.Unmarshal(configJson, &result)
require.NoError(t, err)
warpRouting := result["warp-routing"].(map[string]interface{})
require.EqualValues(t, expectedValue, warpRouting["maxActiveFlows"])
}
remoteValue := uint64(100)
remoteIngress := ingress.Ingress{}
remoteWarpConfig := ingress.WarpRoutingConfig{
MaxActiveFlows: remoteValue,
}
remoteConfig := &Config{
Ingress: &remoteIngress,
WarpRouting: remoteWarpConfig,
ConfigurationFlags: map[string]string{},
}
orchestrator, err := NewOrchestrator(ctx, remoteConfig, testTags, []ingress.Rule{}, &testLogger)
require.NoError(t, err)
assertMaxActiveFlows(orchestrator, remoteValue)
// Add a local override for the maxActiveFlows
localValue := uint64(500)
remoteConfig.ConfigurationFlags[flags.MaxActiveFlows] = fmt.Sprintf("%d", localValue)
// Force a configuration refresh
err = orchestrator.updateIngress(remoteIngress, remoteWarpConfig)
require.NoError(t, err)
// Check the value being used is the local one
assertMaxActiveFlows(orchestrator, localValue)
// Remove local override for the maxActiveFlows
delete(remoteConfig.ConfigurationFlags, flags.MaxActiveFlows)
// Force a configuration refresh
err = orchestrator.updateIngress(remoteIngress, remoteWarpConfig)
require.NoError(t, err)
// Check the value being used is now the remote again
assertMaxActiveFlows(orchestrator, remoteValue)
}
func proxyHTTP(originProxy connection.OriginProxy, hostname string) (*http.Response, error) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", hostname), nil)
if err != nil {
@ -409,15 +462,16 @@ func proxyHTTP(originProxy connection.OriginProxy, hostname string) (*http.Respo
return w.Result(), nil
}
// nolint: testifylint // this is used inside go routines so it can't use `require.`
func tcpEyeball(t *testing.T, reqWriter io.WriteCloser, body string, respReadWriter *respReadWriteFlusher) {
writeN, err := reqWriter.Write([]byte(body))
require.NoError(t, err)
assert.NoError(t, err)
readBuffer := make([]byte, writeN)
n, err := respReadWriter.Read(readBuffer)
require.NoError(t, err)
require.Equal(t, body, string(readBuffer[:n]))
require.Equal(t, writeN, n)
assert.NoError(t, err)
assert.Equal(t, body, string(readBuffer[:n]))
assert.Equal(t, writeN, n)
}
func proxyTCP(ctx context.Context, originProxy connection.OriginProxy, originAddr string, w http.ResponseWriter, reqBody io.ReadCloser) error {
@ -458,14 +512,15 @@ func serveTCPOrigin(t *testing.T, tcpOrigin net.Listener, wg *sync.WaitGroup) {
}
}
// nolint: testifylint // this is used inside go routines so it can't use `require.`
func echoTCP(t *testing.T, conn net.Conn) {
readBuf := make([]byte, 1000)
readN, err := conn.Read(readBuf)
require.NoError(t, err)
assert.NoError(t, err)
writeN, err := conn.Write(readBuf[:readN])
require.NoError(t, err)
require.Equal(t, readN, writeN)
assert.NoError(t, err)
assert.Equal(t, readN, writeN)
}
type validateHostHandler struct {
@ -479,16 +534,17 @@ func (vhh *validateHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(vhh.body))
_, _ = w.Write([]byte(vhh.body))
}
// nolint: testifylint // this is used inside go routines so it can't use `require.`
func updateWithValidation(t *testing.T, orchestrator *Orchestrator, version int32, config []byte) {
resp := orchestrator.UpdateConfig(version, config)
require.NoError(t, resp.Err)
require.Equal(t, version, resp.LastAppliedVersion)
assert.NoError(t, resp.Err)
assert.Equal(t, version, resp.LastAppliedVersion)
}
// TestClosePreviousProxies makes sure proxies started in the pervious configuration version are shutdown
// TestClosePreviousProxies makes sure proxies started in the previous configuration version are shutdown
func TestClosePreviousProxies(t *testing.T) {
var (
hostname = "hello.tunnel1.org"
@ -532,6 +588,7 @@ func TestClosePreviousProxies(t *testing.T) {
originProxyV1, err := orchestrator.GetOriginProxy()
require.NoError(t, err)
// nolint: bodyclose
resp, err := proxyHTTP(originProxyV1, hostname)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
@ -540,12 +597,14 @@ func TestClosePreviousProxies(t *testing.T) {
originProxyV2, err := orchestrator.GetOriginProxy()
require.NoError(t, err)
// nolint: bodyclose
resp, err = proxyHTTP(originProxyV2, hostname)
require.NoError(t, err)
require.Equal(t, http.StatusTeapot, resp.StatusCode)
// The hello-world server in config v1 should have been stopped. We wait a bit since it's closed asynchronously.
time.Sleep(time.Millisecond * 10)
// nolint: bodyclose
resp, err = proxyHTTP(originProxyV1, hostname)
require.Error(t, err)
require.Nil(t, resp)
@ -557,6 +616,7 @@ func TestClosePreviousProxies(t *testing.T) {
require.NoError(t, err)
require.NotEqual(t, originProxyV1, originProxyV3)
// nolint: bodyclose
resp, err = proxyHTTP(originProxyV3, hostname)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
@ -566,6 +626,7 @@ func TestClosePreviousProxies(t *testing.T) {
// Wait for proxies to shutdown
time.Sleep(time.Millisecond * 10)
// nolint: bodyclose
resp, err = proxyHTTP(originProxyV3, hostname)
require.Error(t, err)
require.Nil(t, resp)
@ -622,7 +683,7 @@ func TestPersistentConnection(t *testing.T) {
go func() {
defer wg.Done()
conn, err := tcpOrigin.Accept()
require.NoError(t, err)
assert.NoError(t, err)
defer conn.Close()
// Expect 3 TCP messages
@ -630,26 +691,26 @@ func TestPersistentConnection(t *testing.T) {
echoTCP(t, conn)
}
}()
// Simulate cloudflared recieving a TCP connection
// Simulate cloudflared receiving a TCP connection
go func() {
defer wg.Done()
require.NoError(t, proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), tcpRespReadWriter, tcpReqReader))
assert.NoError(t, proxyTCP(ctx, originProxy, tcpOrigin.Addr().String(), tcpRespReadWriter, tcpReqReader))
}()
// Simulate cloudflared recieving a WS connection
// Simulate cloudflared receiving a WS connection
go func() {
defer wg.Done()
req, err := http.NewRequest(http.MethodGet, hostname, wsReqReader)
require.NoError(t, err)
assert.NoError(t, err)
// ProxyHTTP will add Connection, Upgrade and Sec-Websocket-Version headers
req.Header.Add("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
log := zerolog.Nop()
respWriter, err := connection.NewHTTP2RespWriter(req, wsRespReadWriter, connection.TypeWebsocket, &log)
require.NoError(t, err)
assert.NoError(t, err)
err = originProxy.ProxyHTTP(respWriter, tracing.NewTracedHTTPRequest(req, 0, &log), true)
require.NoError(t, err)
assert.NoError(t, err)
}()
// Simulate eyeball WS and TCP connections

View File

@ -9,10 +9,14 @@ import (
"time"
"github.com/pkg/errors"
pkgerrors "github.com/pkg/errors"
"github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/management"
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/cfio"
"github.com/cloudflare/cloudflared/connection"
@ -32,8 +36,8 @@ const (
type Proxy struct {
ingressRules ingress.Ingress
warpRouting *ingress.WarpRoutingService
management *ingress.ManagementService
tags []pogs.Tag
flowLimiter cfdflow.Limiter
log *zerolog.Logger
}
@ -42,12 +46,14 @@ func NewOriginProxy(
ingressRules ingress.Ingress,
warpRouting ingress.WarpRoutingConfig,
tags []pogs.Tag,
flowLimiter cfdflow.Limiter,
writeTimeout time.Duration,
log *zerolog.Logger,
) *Proxy {
proxy := &Proxy{
ingressRules: ingressRules,
tags: tags,
flowLimiter: flowLimiter,
log: log,
}
@ -64,7 +70,7 @@ func (p *Proxy) applyIngressMiddleware(rule *ingress.Rule, r *http.Request, w co
}
if result.ShouldFilterRequest {
w.WriteRespHeaders(result.StatusCode, nil)
_ = w.WriteRespHeaders(result.StatusCode, nil)
return fmt.Errorf("request filtered by middleware handler (%s) due to: %s", handler.Name(), result.Reason), true
}
}
@ -152,10 +158,18 @@ func (p *Proxy) ProxyTCP(
return err
}
logger := newTCPLogger(p.log, req)
// Try to start a new flow
if err := p.flowLimiter.Acquire(management.TCP.String()); err != nil {
logger.Warn().Msg("Too many concurrent flows being handled, rejecting tcp proxy")
return pkgerrors.Wrap(err, "failed to start tcp flow due to rate limiting")
}
defer p.flowLimiter.Release()
serveCtx, cancel := context.WithCancel(ctx)
defer cancel()
logger := newTCPLogger(p.log, req)
tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, &logger)
logger.Debug().Msg("tcp proxy stream started")

View File

@ -21,8 +21,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
"go.uber.org/mock/gomock"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/mocks"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/cfio"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
@ -71,11 +76,6 @@ func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
}
// respHeaders is a test function to read respHeaders
func (w *mockHTTPRespWriter) headers() http.Header {
return w.Header()
}
func (m *mockHTTPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
panic("Hijack not implemented")
}
@ -113,7 +113,7 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) {
return w.reader.Read(data)
}
func (m *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (w *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
panic("Hijack not implemented")
}
@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, time.Duration(0), &log)
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
t.Run("testProxyHTTP", testProxyHTTP(proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
t.Run("testProxySSE", testProxySSE(proxy))
@ -246,7 +246,7 @@ func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
_ = responseWriter.Close()
close(finished)
errGroup.Wait()
_ = errGroup.Wait()
}
}
@ -267,7 +267,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
defer wg.Done()
log := zerolog.Nop()
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false)
require.Equal(t, err.Error(), "context canceled")
require.Equal(t, "context canceled", err.Error())
require.Equal(t, http.StatusOK, responseWriter.Code)
}()
@ -275,7 +275,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
for i := 0; i < pushCount; i++ {
line := responseWriter.ReadBytes()
expect := fmt.Sprintf("%d\n\n", i)
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
require.Equal(t, []byte(expect), line, "Expect to read %v, got %v", expect, line)
}
cancel()
@ -290,7 +290,9 @@ func TestProxySSEAllData(t *testing.T) {
responseWriter := newMockSSERespWriter()
// responseWriter uses an unbuffered channel, so we call in a different go-routine
go cfio.Copy(responseWriter, eyeballReader)
go func() {
_, _ = cfio.Copy(responseWriter, eyeballReader)
}()
result := string(<-responseWriter.writeNotification)
require.Equal(t, "data\r\r", result)
@ -366,7 +368,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
ctx, cancel := context.WithCancel(context.Background())
require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, time.Duration(0), &log)
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
for _, test := range tests {
responseWriter := newMockHTTPRespWriter()
@ -414,25 +416,20 @@ func TestProxyError(t *testing.T) {
log := zerolog.Nop()
proxy := NewOriginProxy(ing, noWarpRouting, testTags, time.Duration(0), &log)
proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
assert.NoError(t, err)
require.NoError(t, err)
assert.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
require.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
}
type replayer struct {
sync.RWMutex
writeDone chan struct{}
rw *bytes.Buffer
}
func newReplayer(buffer *bytes.Buffer) {
}
func (r *replayer) Read(p []byte) (int, error) {
r.RLock()
defer r.RUnlock()
@ -471,7 +468,7 @@ func (r *replayer) Bytes() []byte {
// eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access).
func TestConnections(t *testing.T) {
logger := logger.Create(nil)
replayer := &replayer{rw: &bytes.Buffer{}}
replayer := &replayer{rw: bytes.NewBuffer([]byte{})}
type args struct {
ingressServiceScheme string
originService func(*testing.T, net.Listener)
@ -486,6 +483,9 @@ func TestConnections(t *testing.T) {
// requestheaders to be sent in the call to proxy.Proxy
requestHeaders http.Header
// flowLimiterResponse is the response of the cfdflow.Limiter#Acquire method call
flowLimiterResponse error
}
type want struct {
@ -663,6 +663,25 @@ func TestConnections(t *testing.T) {
err: true,
},
},
{
name: "tcp-* proxy rate limited flow",
args: args{
ingressServiceScheme: "tcp://",
originService: runEchoTCPService,
eyeballResponseWriter: newTCPRespWriter(replayer),
eyeballRequestBody: newTCPRequestBody([]byte("rate-limited")),
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
connectionType: connection.TypeTCP,
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
},
flowLimiterResponse: cfdflow.ErrTooManyActiveFlows,
},
want: want{
message: []byte{},
err: true,
},
},
}
for _, test := range tests {
@ -674,8 +693,16 @@ func TestConnections(t *testing.T) {
test.args.originService(t, ln)
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
ingressRule.StartOrigins(logger, ctx.Done())
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, time.Duration(0), logger)
_ = ingressRule.StartOrigins(logger, ctx.Done())
// Mock flow limiter
ctrl := gomock.NewController(t)
defer ctrl.Finish()
flowLimiter := mocks.NewMockLimiter(ctrl)
flowLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.flowLimiterResponse)
flowLimiter.EXPECT().Release().AnyTimes()
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, flowLimiter, time.Duration(0), logger)
proxy.warpRouting = test.args.warpRoutingService
dest := ln.Addr().String()
@ -693,7 +720,7 @@ func TestConnections(t *testing.T) {
respWriter = newTCPRespWriter(pipedReqBody.pipedConn)
go func() {
resp := pipedReqBody.roundtrip(test.args.ingressServiceScheme + ln.Addr().String())
replayer.Write(resp)
_, _ = replayer.Write(resp)
}()
}
if test.args.connectionType == connection.TypeTCP {
@ -705,9 +732,9 @@ func TestConnections(t *testing.T) {
}
cancel()
assert.Equal(t, test.want.err, err != nil)
assert.Equal(t, test.want.message, replayer.Bytes())
assert.Equal(t, test.want.headers, respWriter.Header())
require.Equal(t, test.want.err, err != nil)
require.Equal(t, test.want.message, replayer.Bytes())
require.Equal(t, test.want.headers, respWriter.Header())
replayer.rw.Reset()
})
}
@ -720,7 +747,9 @@ type requestBody struct {
func newWSRequestBody(data []byte) *requestBody {
pr, pw := io.Pipe()
go wsutil.WriteClientBinary(pw, data)
go func() {
_ = wsutil.WriteClientBinary(pw, data)
}()
return &requestBody{
pr: pr,
pw: pw,
@ -728,7 +757,9 @@ func newWSRequestBody(data []byte) *requestBody {
}
func newTCPRequestBody(data []byte) *requestBody {
pr, pw := io.Pipe()
go pw.Write(data)
go func() {
_, _ = pw.Write(data)
}()
return &requestBody{
pr: pr,
pw: pw,
@ -740,8 +771,8 @@ func (r *requestBody) Read(p []byte) (n int, err error) {
}
func (r *requestBody) Close() error {
r.pw.Close()
r.pr.Close()
_ = r.pw.Close()
_ = r.pr.Close()
return nil
}
@ -774,6 +805,7 @@ func (p *pipedRequestBody) roundtrip(addr string) []byte {
panic(err)
}
defer conn.Close()
defer resp.Body.Close()
if resp.StatusCode != http.StatusSwitchingProtocols {
panic(fmt.Errorf("resp returned status code: %d", resp.StatusCode))
@ -917,7 +949,9 @@ func runEchoTCPService(t *testing.T, l net.Listener) {
go func() {
for {
conn, err := l.Accept()
require.NoError(t, err)
if err != nil {
panic(err)
}
defer conn.Close()
for {
@ -971,12 +1005,15 @@ func runEchoWSService(t *testing.T, l net.Listener) {
}
}
// nolint: gosec
server := http.Server{
Handler: http.HandlerFunc(ws),
}
go func() {
err := server.Serve(l)
require.NoError(t, err)
if err != nil {
panic(err)
}
}()
}

View File

@ -116,7 +116,7 @@ func (s *UDPSessionRegistrationDatagram) MarshalBinary() (data []byte, err error
data = make([]byte, sessionRegistrationIPv4DatagramHeaderLen+len(s.Payload))
}
data[0] = byte(UDPSessionRegistrationType)
data[1] = byte(flags)
data[1] = flags
binary.BigEndian.PutUint16(data[2:4], s.Dest.Port())
binary.BigEndian.PutUint16(data[4:6], uint16(s.IdleDurationHint.Seconds()))
err = s.RequestID.MarshalBinaryTo(data[6:22])
@ -284,6 +284,8 @@ const (
ResponseDestinationUnreachable SessionRegistrationResp = 0x01
// Session registration was unable to bind to a local UDP socket.
ResponseUnableToBindSocket SessionRegistrationResp = 0x02
// Session registration failed due to the number of flows being higher than the limit.
ResponseTooManyActiveFlows SessionRegistrationResp = 0x03
// Session registration failed with an unexpected error but provided a message.
ResponseErrorWithMsg SessionRegistrationResp = 0xff
)
@ -311,6 +313,7 @@ func (s *UDPSessionRegistrationResponseDatagram) MarshalBinary() (data []byte, e
if len(s.ErrorMsg) > maxResponseErrorMessageLen {
return nil, wrapMarshalErr(ErrDatagramResponseMsgInvalidSize)
}
// nolint: gosec
errMsgLen := uint16(len(s.ErrorMsg))
data = make([]byte, datagramSessionRegistrationResponseLen+errMsgLen)

View File

@ -7,6 +7,10 @@ import (
"sync"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/management"
cfdflow "github.com/cloudflare/cloudflared/flow"
)
var (
@ -16,6 +20,8 @@ var (
ErrSessionBoundToOtherConn = errors.New("flow is in use by another connection")
// ErrSessionAlreadyRegistered is returned when a registration already exists for this connection.
ErrSessionAlreadyRegistered = errors.New("flow is already registered for this connection")
// ErrSessionRegistrationRateLimited is returned when a registration fails due to rate limiting on the number of active flows.
ErrSessionRegistrationRateLimited = errors.New("flow registration rate limited")
)
type SessionManager interface {
@ -38,14 +44,16 @@ type sessionManager struct {
sessions map[RequestID]Session
mutex sync.RWMutex
originDialer DialUDP
limiter cfdflow.Limiter
metrics Metrics
log *zerolog.Logger
}
func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP) SessionManager {
func NewSessionManager(metrics Metrics, log *zerolog.Logger, originDialer DialUDP, limiter cfdflow.Limiter) SessionManager {
return &sessionManager{
sessions: make(map[RequestID]Session),
originDialer: originDialer,
limiter: limiter,
metrics: metrics,
log: log,
}
@ -61,6 +69,12 @@ func (s *sessionManager) RegisterSession(request *UDPSessionRegistrationDatagram
}
return nil, ErrSessionBoundToOtherConn
}
// Try to start a new session
if err := s.limiter.Acquire(management.UDP.String()); err != nil {
return nil, ErrSessionRegistrationRateLimited
}
// Attempt to bind the UDP socket for the new session
origin, err := s.originDialer(request.Dest)
if err != nil {
@ -100,4 +114,5 @@ func (s *sessionManager) UnregisterSession(requestID RequestID) {
_ = session.Close()
}
delete(s.sessions, requestID)
s.limiter.Release()
}

View File

@ -8,14 +8,19 @@ import (
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/cloudflare/cloudflared/mocks"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/ingress"
v3 "github.com/cloudflare/cloudflared/quic/v3"
)
func TestRegisterSession(t *testing.T) {
log := zerolog.Nop()
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort)
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0))
request := v3.UDPSessionRegistrationDatagram{
RequestID: testRequestID,
@ -71,10 +76,32 @@ func TestRegisterSession(t *testing.T) {
func TestGetSession_Empty(t *testing.T) {
log := zerolog.Nop()
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort)
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0))
_, err := manager.GetSession(testRequestID)
if !errors.Is(err, v3.ErrSessionNotFound) {
t.Fatalf("get session find no session: %v", err)
}
}
func TestRegisterSessionRateLimit(t *testing.T) {
log := zerolog.Nop()
ctrl := gomock.NewController(t)
flowLimiterMock := mocks.NewMockLimiter(ctrl)
flowLimiterMock.EXPECT().Acquire("udp").Return(cfdflow.ErrTooManyActiveFlows)
flowLimiterMock.EXPECT().Release().Times(0)
manager := v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, flowLimiterMock)
request := v3.UDPSessionRegistrationDatagram{
RequestID: testRequestID,
Dest: netip.MustParseAddrPort("127.0.0.1:5000"),
Traced: false,
IdleDurationHint: 5 * time.Second,
Payload: nil,
}
_, err := manager.RegisterSession(&request, &noopEyeball{})
require.ErrorIs(t, err, v3.ErrSessionRegistrationRateLimited)
}

View File

@ -143,8 +143,6 @@ func (c *datagramConn) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.Raw
return c.SendICMPPacket(c.icmpRouter.ConvertToTTLExceeded(icmp, rawPacket))
}
var errReadTimeout error = errors.New("receive datagram timeout")
// pollDatagrams will read datagrams from the underlying connection until the provided context is done.
func (c *datagramConn) pollDatagrams(ctx context.Context) {
for ctx.Err() == nil {
@ -256,8 +254,12 @@ func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, da
// Session is already registered but to a different connection
c.handleSessionMigration(datagram.RequestID, &log)
return
case ErrSessionRegistrationRateLimited:
// There are too many concurrent sessions so we return an error to force a retry later
c.handleSessionRegistrationRateLimited(datagram, &log)
return
default:
log.Err(err).Msgf("flow registration failure")
log.Err(err).Msg("flow registration failure")
c.handleSessionRegistrationFailure(datagram.RequestID, &log)
return
}
@ -278,7 +280,7 @@ func (c *datagramConn) handleSessionRegistrationDatagram(ctx context.Context, da
// [Session.Serve] is blocking and will continue this go routine till the end of the session lifetime.
start := time.Now()
err = session.Serve(ctx)
elapsedMS := time.Now().Sub(start).Milliseconds()
elapsedMS := time.Since(start).Milliseconds()
log = log.With().Int64(logDurationKey, elapsedMS).Logger()
if err == nil {
// We typically don't expect a session to close without some error response. [SessionIdleErr] is the typical
@ -346,6 +348,16 @@ func (c *datagramConn) handleSessionRegistrationFailure(requestID RequestID, log
}
}
func (c *datagramConn) handleSessionRegistrationRateLimited(datagram *UDPSessionRegistrationDatagram, logger *zerolog.Logger) {
c.logger.Warn().Msg("Too many concurrent sessions being handled, rejecting udp proxy")
rateLimitResponse := ResponseTooManyActiveFlows
err := c.SendUDPSessionResponse(datagram.RequestID, rateLimitResponse)
if err != nil {
logger.Err(err).Msgf("unable to send flow registration error response (%d)", rateLimitResponse)
}
}
// Handles incoming datagrams that need to be sent to a registered session.
func (c *datagramConn) handleSessionPayloadDatagram(datagram *UDPSessionPayloadDatagram, logger *zerolog.Logger) {
s, err := c.sessionManager.GetSession(datagram.RequestID)

View File

@ -13,13 +13,14 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/google/gopacket/layers"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/packet"
v3 "github.com/cloudflare/cloudflared/quic/v3"
@ -87,7 +88,7 @@ func (m *mockEyeball) SendICMPTTLExceed(icmp *packet.ICMP, rawPacket packet.RawP
func TestDatagramConn_New(t *testing.T) {
log := zerolog.Nop()
conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
conn := v3.NewDatagramConn(newMockQuicConn(), v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
if conn == nil {
t.Fatal("expected valid connection")
}
@ -96,10 +97,12 @@ func TestDatagramConn_New(t *testing.T) {
func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
log := zerolog.Nop()
quic := newMockQuicConn()
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
payload := []byte{0xef, 0xef}
conn.SendUDPSessionDatagram(payload)
err := conn.SendUDPSessionDatagram(payload)
require.NoError(t, err)
p := <-quic.recv
if !slices.Equal(p, payload) {
t.Fatal("datagram sent does not match datagram received on quic side")
@ -109,15 +112,16 @@ func TestDatagramConn_SendUDPSessionDatagram(t *testing.T) {
func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
log := zerolog.Nop()
quic := newMockQuicConn()
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
err := conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
require.NoError(t, err)
conn.SendUDPSessionResponse(testRequestID, v3.ResponseDestinationUnreachable)
resp := <-quic.recv
var response v3.UDPSessionRegistrationResponseDatagram
err := response.UnmarshalBinary(resp)
if err != nil {
t.Fatal(err)
}
err = response.UnmarshalBinary(resp)
require.NoError(t, err)
expected := v3.UDPSessionRegistrationResponseDatagram{
RequestID: testRequestID,
ResponseType: v3.ResponseDestinationUnreachable,
@ -130,7 +134,7 @@ func TestDatagramConn_SendUDPSessionResponse(t *testing.T) {
func TestDatagramConnServe_ApplicationClosed(t *testing.T) {
log := zerolog.Nop()
quic := newMockQuicConn()
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
@ -146,7 +150,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
quic.ctx = ctx
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
err := conn.Serve(context.Background())
if !errors.Is(err, context.DeadlineExceeded) {
@ -157,7 +161,7 @@ func TestDatagramConnServe_ConnectionClosed(t *testing.T) {
func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
log := zerolog.Nop()
quic := &mockQuicConnReadError{err: net.ErrClosed}
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
conn := v3.NewDatagramConn(quic, v3.NewSessionManager(&noopMetrics{}, &log, ingress.DialUDPAddrPort, cfdflow.NewLimiter(0)), &noopICMPRouter{}, 0, &noopMetrics{}, &log)
err := conn.Serve(context.Background())
if !errors.Is(err, net.ErrClosed) {
@ -165,6 +169,38 @@ func TestDatagramConnServe_ReceiveDatagramError(t *testing.T) {
}
}
func TestDatagramConnServe_SessionRegistrationRateLimit(t *testing.T) {
log := zerolog.Nop()
quic := newMockQuicConn()
sessionManager := &mockSessionManager{
expectedRegErr: v3.ErrSessionRegistrationRateLimited,
}
conn := v3.NewDatagramConn(quic, sessionManager, &noopICMPRouter{}, 0, &noopMetrics{}, &log)
// Setup the muxer
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
done := make(chan error, 1)
go func() {
done <- conn.Serve(ctx)
}()
// Send new session registration
datagram := newRegisterSessionDatagram(testRequestID)
quic.send <- datagram
// Wait for session registration response with failure
datagram = <-quic.recv
var resp v3.UDPSessionRegistrationResponseDatagram
err := resp.UnmarshalBinary(datagram)
if err != nil {
t.Fatal(err)
}
require.EqualValues(t, testRequestID, resp.RequestID)
require.EqualValues(t, v3.ResponseTooManyActiveFlows, resp.ResponseType)
}
func TestDatagramConnServe_ErrorDatagramTypes(t *testing.T) {
for _, test := range []struct {
name string
@ -354,12 +390,10 @@ func TestDatagramConnServeDecodeMultipleICMPInParallel(t *testing.T) {
var receivedPackets []*packet.ICMP
go func() {
for ctx.Err() == nil {
select {
case icmpPacket := <-router.recv:
icmpPacket := <-router.recv
receivedPackets = append(receivedPackets, icmpPacket)
wg.Done()
}
}
}()
for _, p := range packets {
@ -677,7 +711,7 @@ func TestDatagramConnServe_ICMPDatagram_TTLExceeded(t *testing.T) {
datagram := newICMPDatagram(expectedICMP)
quic.send <- datagram
// Origin should not recieve a packet
// Origin should not receive a packet
select {
case <-router.recv:
t.Fatalf("TTL should be expired and no origin ICMP sent")
@ -719,18 +753,6 @@ func newRegisterSessionDatagram(id v3.RequestID) []byte {
return payload
}
func newRegisterResponseSessionDatagram(id v3.RequestID, resp v3.SessionRegistrationResp) []byte {
datagram := v3.UDPSessionRegistrationResponseDatagram{
RequestID: id,
ResponseType: resp,
}
payload, err := datagram.MarshalBinary()
if err != nil {
panic(err)
}
return payload
}
func newSessionPayloadDatagram(id v3.RequestID, payload []byte) []byte {
datagram := make([]byte, len(payload)+17)
err := v3.MarshalPayloadHeaderTo(id, datagram[:])

View File

@ -346,7 +346,7 @@ def parse_args():
)
parser.add_argument(
"--deb-based-releases", default=["bookworm", "bullseye", "buster", "jammy", "impish", "focal", "bionic",
"--deb-based-releases", default=["any", "bookworm", "bullseye", "buster", "noble", "jammy", "impish", "focal", "bionic",
"xenial", "trusty"],
help="list of debian based releases that need to be packaged for"
)

View File

@ -79,8 +79,8 @@ func (b *BackoffHandler) BackoffTimer() <-chan time.Time {
} else {
b.retries++
}
maxTimeToWait := time.Duration(b.GetBaseTime() * 1 << (b.retries))
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds()))
maxTimeToWait := b.GetBaseTime() * (1 << b.retries)
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds())) // #nosec G404
return b.Clock.After(timeToWait)
}
@ -99,11 +99,11 @@ func (b *BackoffHandler) Backoff(ctx context.Context) bool {
}
}
// Sets a grace period within which the the backoff timer is maintained. After the grace
// Sets a grace period within which the backoff timer is maintained. After the grace
// period expires, the number of retries & backoff duration is reset.
func (b *BackoffHandler) SetGracePeriod() time.Duration {
maxTimeToWait := b.GetBaseTime() * 2 << (b.retries + 1)
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds()))
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds())) // #nosec G404
b.resetDeadline = b.Clock.Now().Add(timeToWait)
return timeToWait
@ -118,7 +118,7 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
// Retries returns the number of retries consumed so far.
func (b *BackoffHandler) Retries() int {
return int(b.retries)
return int(b.retries) // #nosec G115
}
func (b *BackoffHandler) ReachedMaxRetries() bool {

View File

@ -7,30 +7,53 @@ import (
"github.com/cloudflare/cloudflared/features"
)
// When experimental post-quantum tunnels are enabled, and we're hitting an
// issue creating the tunnel, we'll report the first error
// to https://pqtunnels.cloudflareresearch.com.
const (
PQKex = tls.CurveID(0x6399) // X25519Kyber768Draft00
PQKexName = "X25519Kyber768Draft00"
X25519Kyber768Draft00PQKex = tls.CurveID(0x6399) // X25519Kyber768Draft00
X25519Kyber768Draft00PQKexName = "X25519Kyber768Draft00"
P256Kyber768Draft00PQKex = tls.CurveID(0xfe32) // P256Kyber768Draft00
P256Kyber768Draft00PQKexName = "P256Kyber768Draft00"
X25519MLKEM768PQKex = tls.CurveID(0x11ec) // X25519MLKEM768
X25519MLKEM768PQKexName = "X25519MLKEM768"
)
func curvePreference(pqMode features.PostQuantumMode, currentCurve []tls.CurveID) ([]tls.CurveID, error) {
var (
nonFipsPostQuantumStrictPKex []tls.CurveID = []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex}
nonFipsPostQuantumPreferPKex []tls.CurveID = []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex}
fipsPostQuantumStrictPKex []tls.CurveID = []tls.CurveID{P256Kyber768Draft00PQKex}
fipsPostQuantumPreferPKex []tls.CurveID = []tls.CurveID{P256Kyber768Draft00PQKex, tls.CurveP256}
)
func removeDuplicates(curves []tls.CurveID) []tls.CurveID {
bucket := make(map[tls.CurveID]bool)
var result []tls.CurveID
for _, curve := range curves {
if _, ok := bucket[curve]; !ok {
bucket[curve] = true
result = append(result, curve)
}
}
return result
}
func curvePreference(pqMode features.PostQuantumMode, fipsEnabled bool, currentCurve []tls.CurveID) ([]tls.CurveID, error) {
switch pqMode {
case features.PostQuantumStrict:
// If the user passes the -post-quantum flag, we override
// CurvePreferences to only support hybrid post-quantum key agreements.
return []tls.CurveID{PQKex}, nil
if fipsEnabled {
return fipsPostQuantumStrictPKex, nil
}
return nonFipsPostQuantumStrictPKex, nil
case features.PostQuantumPrefer:
if len(currentCurve) == 0 {
return []tls.CurveID{PQKex}, nil
if fipsEnabled {
// Ensure that all curves returned are FIPS compliant.
// Moreover the first curves are post-quantum and then the
// non post-quantum.
return fipsPostQuantumPreferPKex, nil
}
if currentCurve[0] != PQKex {
return append([]tls.CurveID{PQKex}, currentCurve...), nil
}
return currentCurve, nil
curves := append(nonFipsPostQuantumPreferPKex, currentCurve...)
curves = removeDuplicates(curves)
return curves, nil
default:
return nil, fmt.Errorf("Unexpected post quantum mode")
}

View File

@ -0,0 +1,84 @@
package supervisor
import (
"crypto/tls"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/features"
)
func TestCurvePreferences(t *testing.T) {
// This tests if the correct curves are returned
// given a PostQuantumMode and a FIPS enabled bool
t.Parallel()
tests := []struct {
name string
currentCurves []tls.CurveID
expectedCurves []tls.CurveID
pqMode features.PostQuantumMode
fipsEnabled bool
}{
{
name: "FIPS with Prefer PQ",
pqMode: features.PostQuantumPrefer,
fipsEnabled: true,
currentCurves: []tls.CurveID{tls.CurveP384},
expectedCurves: []tls.CurveID{P256Kyber768Draft00PQKex, tls.CurveP256},
},
{
name: "FIPS with Strict PQ",
pqMode: features.PostQuantumStrict,
fipsEnabled: true,
currentCurves: []tls.CurveID{tls.CurveP256, tls.CurveP384},
expectedCurves: []tls.CurveID{P256Kyber768Draft00PQKex},
},
{
name: "FIPS with Prefer PQ - no duplicates",
pqMode: features.PostQuantumPrefer,
fipsEnabled: true,
currentCurves: []tls.CurveID{tls.CurveP256},
expectedCurves: []tls.CurveID{P256Kyber768Draft00PQKex, tls.CurveP256},
},
{
name: "Non FIPS with Prefer PQ",
pqMode: features.PostQuantumPrefer,
fipsEnabled: false,
currentCurves: []tls.CurveID{tls.CurveP256},
expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256},
},
{
name: "Non FIPS with Prefer PQ - no duplicates",
pqMode: features.PostQuantumPrefer,
fipsEnabled: false,
currentCurves: []tls.CurveID{X25519Kyber768Draft00PQKex, tls.CurveP256},
expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256},
},
{
name: "Non FIPS with Prefer PQ - correct preference order",
pqMode: features.PostQuantumPrefer,
fipsEnabled: false,
currentCurves: []tls.CurveID{tls.CurveP256, X25519Kyber768Draft00PQKex},
expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256},
},
{
name: "Non FIPS with Strict PQ",
pqMode: features.PostQuantumStrict,
fipsEnabled: false,
currentCurves: []tls.CurveID{tls.CurveP256},
expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex},
},
}
for _, tcase := range tests {
t.Run(tcase.name, func(t *testing.T) {
t.Parallel()
curves, err := curvePreference(tcase.pqMode, tcase.fipsEnabled, tcase.currentCurves)
require.NoError(t, err)
assert.Equal(t, tcase.expectedCurves, curves)
})
}
}

View File

@ -26,12 +26,6 @@ const (
tunnelRetryDuration = time.Second * 10
// Interval between registering new tunnels
registrationInterval = time.Second
subsystemRefreshAuth = "refresh_auth"
// Maximum exponent for 'Authenticate' exponential backoff
refreshAuthMaxBackoff = 10
// Waiting time before retrying a failed 'Authenticate' connection
refreshAuthRetryDuration = time.Second * 10
)
// Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and
@ -84,7 +78,7 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
edgeBindAddr := config.EdgeBindAddr
datagramMetrics := v3.NewMetrics(prometheus.DefaultRegisterer)
sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort)
sessionManager := v3.NewSessionManager(datagramMetrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetFlowLimiter())
edgeTunnelServer := EdgeTunnelServer{
config: config,
@ -253,9 +247,7 @@ func (s *Supervisor) startFirstTunnel(
ctx context.Context,
connectedSignal *signal.Signal,
) {
var (
err error
)
var err error
const firstConnIndex = 0
isStaticEdge := len(s.config.EdgeAddrs) > 0
defer func() {
@ -306,13 +298,12 @@ func (s *Supervisor) startTunnel(
index int,
connectedSignal *signal.Signal,
) {
var (
err error
)
var err error
defer func() {
s.tunnelErrors <- tunnelError{index: index, err: err}
}()
// nolint: gosec
err = s.edgeTunnelServer.Serve(ctx, uint8(index), s.tunnelsProtocolFallback[index], connectedSignal)
}
@ -334,7 +325,3 @@ func (s *Supervisor) waitForNextTunnel(index int) bool {
}
return false
}
func (s *Supervisor) unusedIPs() bool {
return s.edgeIPs.AvailableAddrs() > s.config.HAConnections
}

View File

@ -7,11 +7,11 @@ import (
"net"
"net/netip"
"runtime/debug"
"slices"
"strings"
"sync"
"time"
"github.com/getsentry/sentry-go"
"github.com/pkg/errors"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
@ -21,6 +21,7 @@ import (
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
"github.com/cloudflare/cloudflared/features"
"github.com/cloudflare/cloudflared/fips"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management"
"github.com/cloudflare/cloudflared/orchestration"
@ -460,6 +461,7 @@ func (e *EdgeTunnelServer) serveConnection(
switch protocol {
case connection.QUIC:
// nolint: gosec
connOptions := e.config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
return e.serveQUIC(ctx,
addr.UDP.AddrPort(),
@ -475,6 +477,7 @@ func (e *EdgeTunnelServer) serveConnection(
return err, true
}
// nolint: gosec
connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
if err := e.serveHTTP2(
ctx,
@ -554,15 +557,13 @@ func (e *EdgeTunnelServer) serveQUIC(
tlsConfig := e.config.EdgeTLSConfigs[connection.QUIC]
pqMode := e.config.FeatureSelector.PostQuantumMode()
if pqMode == features.PostQuantumStrict || pqMode == features.PostQuantumPrefer {
connOptions.Client.Features = features.Dedup(append(connOptions.Client.Features, features.FeaturePostQuantum))
}
curvePref, err := curvePreference(pqMode, tlsConfig.CurvePreferences)
curvePref, err := curvePreference(pqMode, fips.IsFipsEnabled(), tlsConfig.CurvePreferences)
if err != nil {
return err, true
}
connLogger.Logger().Info().Msgf("Using %v as curve preferences", curvePref)
tlsConfig.CurvePreferences = curvePref
// quic-go 0.44 increases the initial packet size to 1280 by default. That breaks anyone running tunnel through WARP
@ -598,11 +599,13 @@ func (e *EdgeTunnelServer) serveQUIC(
)
if err != nil {
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to dial a quic connection")
e.reportErrorToSentry(err)
return err, true
}
var datagramSessionManager connection.DatagramSessionHandler
if slices.Contains(connOptions.Client.Features, features.FeatureDatagramV3) {
if e.config.FeatureSelector.DatagramVersion() == features.DatagramV3 {
datagramSessionManager = connection.NewDatagramV3Connection(
ctx,
conn,
@ -620,6 +623,7 @@ func (e *EdgeTunnelServer) serveQUIC(
connIndex,
e.config.RPCTimeout,
e.config.WriteStreamTimeout,
e.orchestrator.GetFlowLimiter(),
connLogger.Logger(),
)
}
@ -666,6 +670,26 @@ func (e *EdgeTunnelServer) serveQUIC(
return errGroup.Wait(), false
}
// The reportErrorToSentry is an helper function that handles
// verifies if an error should be reported to Sentry.
func (e *EdgeTunnelServer) reportErrorToSentry(err error) {
dialErr, ok := err.(*connection.EdgeQuicDialError)
if ok {
// The TransportError provides an Unwrap function however
// the err MAY not always be set
transportErr, ok := dialErr.Cause.(*quic.TransportError)
if ok &&
transportErr.ErrorCode.IsCryptoError() &&
fips.IsFipsEnabled() &&
e.config.FeatureSelector.PostQuantumMode() == features.PostQuantumStrict {
// Only report to Sentry when using FIPS, PQ,
// and the error is a Crypto error reported by
// an EdgeQuicDialError
sentry.CaptureException(err)
}
}
}
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error {
select {
case reconnect := <-reconnectCh:

View File

@ -53,7 +53,7 @@ type signalHandler struct {
}
type jwtPayload struct {
Aud []string `json:"aud"`
Aud []string `json:"-"`
Email string `json:"email"`
Exp int `json:"exp"`
Iat int `json:"iat"`
@ -68,6 +68,34 @@ type transferServiceResponse struct {
OrgToken string `json:"org_token"`
}
func (p *jwtPayload) UnmarshalJSON(data []byte) error {
type Alias jwtPayload
if err := json.Unmarshal(data, (*Alias)(p)); err != nil {
return err
}
var audParser struct {
Aud any `json:"aud"`
}
if err := json.Unmarshal(data, &audParser); err != nil {
return err
}
switch aud := audParser.Aud.(type) {
case string:
p.Aud = []string{aud}
case []any:
for _, a := range aud {
s, ok := a.(string)
if !ok {
return errors.New("aud array contains non-string elements")
}
p.Aud = append(p.Aud, s)
}
default:
return errors.New("aud field is not a string or an array of strings")
}
return nil
}
func (p jwtPayload) isExpired() bool {
return int(time.Now().Unix()) > p.Exp
}
@ -182,7 +210,9 @@ func getToken(appURL *url.URL, appInfo *AppInfo, useHostOnly bool, log *zerolog.
if err = fileLockAppToken.Acquire(); err != nil {
return "", errors.Wrap(err, "failed to acquire app token lock")
}
defer fileLockAppToken.Release()
defer func() {
_ = fileLockAppToken.Release()
}()
// check to see if another process has gotten a token while we waited for the lock
if token, err := GetAppTokenIfExists(appInfo); token != "" && err == nil {
@ -202,7 +232,9 @@ func getToken(appURL *url.URL, appInfo *AppInfo, useHostOnly bool, log *zerolog.
if err = fileLockOrgToken.Acquire(); err != nil {
return "", errors.Wrap(err, "failed to acquire org token lock")
}
defer fileLockOrgToken.Release()
defer func() {
_ = fileLockOrgToken.Release()
}()
// check if an org token has been created since the lock was acquired
orgToken, err = GetOrgTokenIfExists(appInfo.AuthDomain)
}
@ -218,7 +250,6 @@ func getToken(appURL *url.URL, appInfo *AppInfo, useHostOnly bool, log *zerolog.
}
}
return getTokensFromEdge(appURL, appInfo.AppAUD, appTokenPath, orgTokenPath, useHostOnly, log)
}
// getTokensFromEdge will attempt to use the transfer service to retrieve an app and org token, save them to disk,
@ -250,7 +281,6 @@ func getTokensFromEdge(appURL *url.URL, appAUD, appTokenPath, orgTokenPath strin
}
return resp.AppToken, nil
}
// GetAppInfo makes a request to the appURL and stops at the first redirect. The 302 location header will contain the
@ -320,7 +350,6 @@ func handleRedirects(req *http.Request, via []*http.Request, orgToken string) er
}
}
}
}
// stop after hitting authorized endpoint since it will contain the app token
@ -408,7 +437,6 @@ func GetAppTokenIfExists(appInfo *AppInfo) (string, error) {
return "", err
}
return token.CompactSerialize()
}
// GetTokenIfExists will return the token from local storage if it exists and not expired

View File

@ -1,6 +1,7 @@
package token
import (
"encoding/json"
"net/http"
"net/url"
"testing"
@ -11,7 +12,7 @@ func TestHandleRedirects_AttachOrgToken(t *testing.T) {
via := []*http.Request{}
orgToken := "orgTokenValue"
handleRedirects(req, via, orgToken)
_ = handleRedirects(req, via, orgToken)
// Check if the orgToken cookie is attached
cookies := req.Cookies()
@ -80,3 +81,55 @@ func TestHandleRedirects_StopAtAuthorizedEndpoint(t *testing.T) {
t.Errorf("Expected ErrUseLastResponse, got %v", err)
}
}
func TestJwtPayloadUnmarshal_AudAsString(t *testing.T) {
jwt := `{"aud":"7afbdaf987054f889b3bdd0d29ebfcd2"}`
var payload jwtPayload
if err := json.Unmarshal([]byte(jwt), &payload); err != nil {
t.Errorf("Expected no error, got %v", err)
}
if len(payload.Aud) != 1 || payload.Aud[0] != "7afbdaf987054f889b3bdd0d29ebfcd2" {
t.Errorf("Expected aud to be 7afbdaf987054f889b3bdd0d29ebfcd2, got %v", payload.Aud)
}
}
func TestJwtPayloadUnmarshal_AudAsSlice(t *testing.T) {
jwt := `{"aud":["7afbdaf987054f889b3bdd0d29ebfcd2", "f835c0016f894768976c01e076844efe"]}`
var payload jwtPayload
if err := json.Unmarshal([]byte(jwt), &payload); err != nil {
t.Errorf("Expected no error, got %v", err)
}
if len(payload.Aud) != 2 || payload.Aud[0] != "7afbdaf987054f889b3bdd0d29ebfcd2" || payload.Aud[1] != "f835c0016f894768976c01e076844efe" {
t.Errorf("Expected aud to be [7afbdaf987054f889b3bdd0d29ebfcd2, f835c0016f894768976c01e076844efe], got %v", payload.Aud)
}
}
func TestJwtPayloadUnmarshal_FailsWhenAudIsInt(t *testing.T) {
jwt := `{"aud":123}`
var payload jwtPayload
err := json.Unmarshal([]byte(jwt), &payload)
wantErr := "aud field is not a string or an array of strings"
if err.Error() != wantErr {
t.Errorf("Expected %v, got %v", wantErr, err)
}
}
func TestJwtPayloadUnmarshal_FailsWhenAudIsArrayOfInts(t *testing.T) {
jwt := `{"aud": [999, 123] }`
var payload jwtPayload
err := json.Unmarshal([]byte(jwt), &payload)
wantErr := "aud array contains non-string elements"
if err.Error() != wantErr {
t.Errorf("Expected %v, got %v", wantErr, err)
}
}
func TestJwtPayloadUnmarshal_FailsWhenAudIsOmitted(t *testing.T) {
jwt := `{}`
var payload jwtPayload
err := json.Unmarshal([]byte(jwt), &payload)
wantErr := "aud field is not a string or an array of strings"
if err.Error() != wantErr {
t.Errorf("Expected %v, got %v", wantErr, err)
}
}

View File

@ -70,7 +70,6 @@ func RunTransfer(transferURL *url.URL, appAUD, resourceName, key, value string,
}
return resourceData, nil
}
// BuildRequestURL creates a request suitable for a resource transfer.

View File

@ -18,6 +18,11 @@ const (
ConnectionTypeTCP
)
var (
// ErrorFlowConnectRateLimitedMetadata is the Metadata entry that allows to know if a request was rate limited on connect.
ErrorFlowConnectRateLimitedMetadata = Metadata{Key: "FlowConnectRateLimited", Val: "true"}
)
func (c ConnectionType) String() string {
switch c {
case ConnectionTypeHTTP:

View File

@ -38,6 +38,7 @@ func (rss *RequestServerStream) WriteConnectResponseData(respErr error, metadata
if respErr != nil {
connectResponse = &pogs.ConnectResponse{
Error: respErr.Error(),
Metadata: metadata,
}
} else {
connectResponse = &pogs.ConnectResponse{

View File

@ -98,12 +98,7 @@ func TestConnectResponseMeta(t *testing.T) {
reqClientStream := RequestClientStream{noopCloser{b}}
respMeta, err := reqClientStream.ReadConnectResponseData()
require.NoError(t, err)
if respMeta.Error == "" {
assert.Equal(t, test.metadata, respMeta.Metadata)
} else {
assert.Equal(t, 0, len(respMeta.Metadata))
}
require.Equal(t, test.metadata, respMeta.Metadata)
})
}
}
@ -153,21 +148,21 @@ func TestRegisterUdpSession(t *testing.T) {
}()
rpcClientStream, err := NewCloudflaredClient(context.Background(), clientStream, 5*time.Second)
assert.NoError(t, err)
require.NoError(t, err)
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
assert.NoError(t, err)
assert.NoError(t, reg.Err)
require.NoError(t, err)
require.NoError(t, reg.Err)
// Different sessionID, the RPC server should reject the registraion
// Different sessionID, the RPC server should reject the registration
reg, err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
assert.NoError(t, err)
assert.Error(t, reg.Err)
require.NoError(t, err)
require.Error(t, reg.Err)
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
require.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
// Different sessionID, the RPC server should reject the unregistraion
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
// Different sessionID, the RPC server should reject the unregistration
require.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
rpcClientStream.Close()
<-sessionRegisteredChan
@ -200,10 +195,10 @@ func TestManageConfiguration(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
rpcClientStream, err := NewCloudflaredClient(ctx, clientStream, 5*time.Second)
assert.NoError(t, err)
require.NoError(t, err)
result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)
assert.NoError(t, err)
require.NoError(t, err)
require.Equal(t, version, result.LastAppliedVersion)
require.NoError(t, result.Err)

View File

@ -51,7 +51,7 @@ func (s *SessionManagerServer) Serve(ctx context.Context, stream io.ReadWriteClo
select {
case <-rpcConn.Done():
return rpcConn.Err()
return nil
case <-ctx.Done():
return ctx.Err()
}

View File

@ -8,7 +8,9 @@ import (
"fmt"
"io"
"net"
"os"
"reflect"
"strconv"
"sync"
"sync/atomic"
"time"
@ -288,6 +290,16 @@ var newConnection = func(
s.logger,
)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
// Allow server to define custom MaxUDPPayloadSize
maxUDPPayloadSize := protocol.MaxPacketBufferSize
if maxPacketSize := os.Getenv("TUNNEL_MAX_QUIC_PACKET_SIZE"); maxPacketSize != "" {
if customMaxPacketSize, err := strconv.ParseUint(maxPacketSize, 10, 64); err == nil {
maxUDPPayloadSize = int(customMaxPacketSize)
} else {
utils.DefaultLogger.Errorf("failed to parse TUNNEL_MAX_QUIC_PACKET_SIZE: %v", err)
}
}
params := &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@ -298,7 +310,7 @@ var newConnection = func(
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
AckDelayExponent: protocol.AckDelayExponent,
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
MaxUDPPayloadSize: protocol.ByteCount(maxUDPPayloadSize),
DisableActiveMigration: true,
StatelessResetToken: &statelessResetToken,
OriginalDestinationConnectionID: origDestConnID,

View File

@ -12,7 +12,9 @@ import (
// These cipher suite implementations are copied from the standard library crypto/tls package.
const aeadNonceLength = 12
const (
aeadNonceLength = 12
)
type cipherSuite struct {
ID uint16
@ -44,12 +46,13 @@ func aeadAESGCMTLS13(key, nonceMask []byte) *xorNonceAEAD {
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
aead, err := newAEAD(aes)
if err != nil {
panic(err)
}
ret := &xorNonceAEAD{aead: aead}
ret := &xorNonceAEAD{aead: aead, hasSeenNonceZero: false}
copy(ret.nonceMask[:], nonceMask)
return ret
}
@ -73,6 +76,7 @@ func aeadChaCha20Poly1305(key, nonceMask []byte) *xorNonceAEAD {
type xorNonceAEAD struct {
nonceMask [aeadNonceLength]byte
aead cipher.AEAD
hasSeenNonceZero bool // This value denotes if the aead field was used with a nonce = 0
}
func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number
@ -80,6 +84,10 @@ func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() }
func (f *xorNonceAEAD) explicitNonceLen() int { return 0 }
func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte {
return f.seal(nonce, out, plaintext, additionalData)
}
func (f *xorNonceAEAD) doSeal(nonce, out, plaintext, additionalData []byte) []byte {
for i, b := range nonce {
f.nonceMask[4+i] ^= b
}

View File

@ -0,0 +1,51 @@
//go:build boringcrypto
package handshake
import (
"crypto/cipher"
"crypto/tls"
"os"
"strings"
)
var goBoringDisabled bool = strings.TrimSpace(os.Getenv("QUIC_GO_DISABLE_BORING")) == "1"
func newAEAD(aes cipher.Block) (cipher.AEAD, error) {
if goBoringDisabled {
// In case Go Boring is disabled then
// fallback to normal cryptographic procedure.
return cipher.NewGCM(aes)
}
return tls.NewGCMTLS13(aes)
}
func allZeros(nonce []byte) bool {
for _, e := range nonce {
if e != 0 {
return false
}
}
return true
}
func (f *xorNonceAEAD) sealZeroNonce() {
f.doSeal([]byte{}, []byte{}, []byte{}, []byte{})
}
func (f *xorNonceAEAD) seal(nonce, out, plaintext, additionalData []byte) []byte {
if !goBoringDisabled {
if !f.hasSeenNonceZero {
// BoringSSL expects that the first nonce passed to the
// AEAD instance is zero.
// At this point the nonce argument is either zero or
// an artificial one will be passed to the AEAD through
// [sealZeroNonce]
f.hasSeenNonceZero = true
if !allZeros(nonce) {
f.sealZeroNonce()
}
}
}
return f.doSeal(nonce, out, plaintext, additionalData)
}

View File

@ -0,0 +1,13 @@
//go:build !boringcrypto
package handshake
import "crypto/cipher"
func newAEAD(aes cipher.Block) (cipher.AEAD, error) {
return cipher.NewGCM(aes)
}
func (f *xorNonceAEAD) seal(nonce, out, plaintext, additionalData []byte) []byte {
return f.doSeal(nonce, out, plaintext, additionalData)
}

506
vendor/go.uber.org/mock/gomock/call.go generated vendored Normal file
View File

@ -0,0 +1,506 @@
// Copyright 2010 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gomock
import (
"fmt"
"reflect"
"strconv"
"strings"
)
// Call represents an expected call to a mock.
type Call struct {
t TestHelper // for triggering test failures on invalid call setup
receiver any // the receiver of the method call
method string // the name of the method
methodType reflect.Type // the type of the method
args []Matcher // the args
origin string // file and line number of call setup
preReqs []*Call // prerequisite calls
// Expectations
minCalls, maxCalls int
numCalls int // actual number made
// actions are called when this Call is called. Each action gets the args and
// can set the return values by returning a non-nil slice. Actions run in the
// order they are created.
actions []func([]any) []any
}
// newCall creates a *Call. It requires the method type in order to support
// unexported methods.
func newCall(t TestHelper, receiver any, method string, methodType reflect.Type, args ...any) *Call {
t.Helper()
// TODO: check arity, types.
mArgs := make([]Matcher, len(args))
for i, arg := range args {
if m, ok := arg.(Matcher); ok {
mArgs[i] = m
} else if arg == nil {
// Handle nil specially so that passing a nil interface value
// will match the typed nils of concrete args.
mArgs[i] = Nil()
} else {
mArgs[i] = Eq(arg)
}
}
// callerInfo's skip should be updated if the number of calls between the user's test
// and this line changes, i.e. this code is wrapped in another anonymous function.
// 0 is us, 1 is RecordCallWithMethodType(), 2 is the generated recorder, and 3 is the user's test.
origin := callerInfo(3)
actions := []func([]any) []any{func([]any) []any {
// Synthesize the zero value for each of the return args' types.
rets := make([]any, methodType.NumOut())
for i := 0; i < methodType.NumOut(); i++ {
rets[i] = reflect.Zero(methodType.Out(i)).Interface()
}
return rets
}}
return &Call{
t: t, receiver: receiver, method: method, methodType: methodType,
args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions,
}
}
// AnyTimes allows the expectation to be called 0 or more times
func (c *Call) AnyTimes() *Call {
c.minCalls, c.maxCalls = 0, 1e8 // close enough to infinity
return c
}
// MinTimes requires the call to occur at least n times. If AnyTimes or MaxTimes have not been called or if MaxTimes
// was previously called with 1, MinTimes also sets the maximum number of calls to infinity.
func (c *Call) MinTimes(n int) *Call {
c.minCalls = n
if c.maxCalls == 1 {
c.maxCalls = 1e8
}
return c
}
// MaxTimes limits the number of calls to n times. If AnyTimes or MinTimes have not been called or if MinTimes was
// previously called with 1, MaxTimes also sets the minimum number of calls to 0.
func (c *Call) MaxTimes(n int) *Call {
c.maxCalls = n
if c.minCalls == 1 {
c.minCalls = 0
}
return c
}
// DoAndReturn declares the action to run when the call is matched.
// The return values from this function are returned by the mocked function.
// It takes an any argument to support n-arity functions.
// The anonymous function must match the function signature mocked method.
func (c *Call) DoAndReturn(f any) *Call {
// TODO: Check arity and types here, rather than dying badly elsewhere.
v := reflect.ValueOf(f)
c.addAction(func(args []any) []any {
c.t.Helper()
ft := v.Type()
if c.methodType.NumIn() != ft.NumIn() {
if ft.IsVariadic() {
c.t.Fatalf("wrong number of arguments in DoAndReturn func for %T.%v The function signature must match the mocked method, a variadic function cannot be used.",
c.receiver, c.method)
} else {
c.t.Fatalf("wrong number of arguments in DoAndReturn func for %T.%v: got %d, want %d [%s]",
c.receiver, c.method, ft.NumIn(), c.methodType.NumIn(), c.origin)
}
return nil
}
vArgs := make([]reflect.Value, len(args))
for i := 0; i < len(args); i++ {
if args[i] != nil {
vArgs[i] = reflect.ValueOf(args[i])
} else {
// Use the zero value for the arg.
vArgs[i] = reflect.Zero(ft.In(i))
}
}
vRets := v.Call(vArgs)
rets := make([]any, len(vRets))
for i, ret := range vRets {
rets[i] = ret.Interface()
}
return rets
})
return c
}
// Do declares the action to run when the call is matched. The function's
// return values are ignored to retain backward compatibility. To use the
// return values call DoAndReturn.
// It takes an any argument to support n-arity functions.
// The anonymous function must match the function signature mocked method.
func (c *Call) Do(f any) *Call {
// TODO: Check arity and types here, rather than dying badly elsewhere.
v := reflect.ValueOf(f)
c.addAction(func(args []any) []any {
c.t.Helper()
ft := v.Type()
if c.methodType.NumIn() != ft.NumIn() {
if ft.IsVariadic() {
c.t.Fatalf("wrong number of arguments in Do func for %T.%v The function signature must match the mocked method, a variadic function cannot be used.",
c.receiver, c.method)
} else {
c.t.Fatalf("wrong number of arguments in Do func for %T.%v: got %d, want %d [%s]",
c.receiver, c.method, ft.NumIn(), c.methodType.NumIn(), c.origin)
}
return nil
}
vArgs := make([]reflect.Value, len(args))
for i := 0; i < len(args); i++ {
if args[i] != nil {
vArgs[i] = reflect.ValueOf(args[i])
} else {
// Use the zero value for the arg.
vArgs[i] = reflect.Zero(ft.In(i))
}
}
v.Call(vArgs)
return nil
})
return c
}
// Return declares the values to be returned by the mocked function call.
func (c *Call) Return(rets ...any) *Call {
c.t.Helper()
mt := c.methodType
if len(rets) != mt.NumOut() {
c.t.Fatalf("wrong number of arguments to Return for %T.%v: got %d, want %d [%s]",
c.receiver, c.method, len(rets), mt.NumOut(), c.origin)
}
for i, ret := range rets {
if got, want := reflect.TypeOf(ret), mt.Out(i); got == want {
// Identical types; nothing to do.
} else if got == nil {
// Nil needs special handling.
switch want.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
// ok
default:
c.t.Fatalf("argument %d to Return for %T.%v is nil, but %v is not nillable [%s]",
i, c.receiver, c.method, want, c.origin)
}
} else if got.AssignableTo(want) {
// Assignable type relation. Make the assignment now so that the generated code
// can return the values with a type assertion.
v := reflect.New(want).Elem()
v.Set(reflect.ValueOf(ret))
rets[i] = v.Interface()
} else {
c.t.Fatalf("wrong type of argument %d to Return for %T.%v: %v is not assignable to %v [%s]",
i, c.receiver, c.method, got, want, c.origin)
}
}
c.addAction(func([]any) []any {
return rets
})
return c
}
// Times declares the exact number of times a function call is expected to be executed.
func (c *Call) Times(n int) *Call {
c.minCalls, c.maxCalls = n, n
return c
}
// SetArg declares an action that will set the nth argument's value,
// indirected through a pointer. Or, in the case of a slice and map, SetArg
// will copy value's elements/key-value pairs into the nth argument.
func (c *Call) SetArg(n int, value any) *Call {
c.t.Helper()
mt := c.methodType
// TODO: This will break on variadic methods.
// We will need to check those at invocation time.
if n < 0 || n >= mt.NumIn() {
c.t.Fatalf("SetArg(%d, ...) called for a method with %d args [%s]",
n, mt.NumIn(), c.origin)
}
// Permit setting argument through an interface.
// In the interface case, we don't (nay, can't) check the type here.
at := mt.In(n)
switch at.Kind() {
case reflect.Ptr:
dt := at.Elem()
if vt := reflect.TypeOf(value); !vt.AssignableTo(dt) {
c.t.Fatalf("SetArg(%d, ...) argument is a %v, not assignable to %v [%s]",
n, vt, dt, c.origin)
}
case reflect.Interface, reflect.Slice, reflect.Map:
// nothing to do
default:
c.t.Fatalf("SetArg(%d, ...) referring to argument of non-pointer non-interface non-slice non-map type %v [%s]",
n, at, c.origin)
}
c.addAction(func(args []any) []any {
v := reflect.ValueOf(value)
switch reflect.TypeOf(args[n]).Kind() {
case reflect.Slice:
setSlice(args[n], v)
case reflect.Map:
setMap(args[n], v)
default:
reflect.ValueOf(args[n]).Elem().Set(v)
}
return nil
})
return c
}
// isPreReq returns true if other is a direct or indirect prerequisite to c.
func (c *Call) isPreReq(other *Call) bool {
for _, preReq := range c.preReqs {
if other == preReq || preReq.isPreReq(other) {
return true
}
}
return false
}
// After declares that the call may only match after preReq has been exhausted.
func (c *Call) After(preReq *Call) *Call {
c.t.Helper()
if c == preReq {
c.t.Fatalf("A call isn't allowed to be its own prerequisite")
}
if preReq.isPreReq(c) {
c.t.Fatalf("Loop in call order: %v is a prerequisite to %v (possibly indirectly).", c, preReq)
}
c.preReqs = append(c.preReqs, preReq)
return c
}
// Returns true if the minimum number of calls have been made.
func (c *Call) satisfied() bool {
return c.numCalls >= c.minCalls
}
// Returns true if the maximum number of calls have been made.
func (c *Call) exhausted() bool {
return c.numCalls >= c.maxCalls
}
func (c *Call) String() string {
args := make([]string, len(c.args))
for i, arg := range c.args {
args[i] = arg.String()
}
arguments := strings.Join(args, ", ")
return fmt.Sprintf("%T.%v(%s) %s", c.receiver, c.method, arguments, c.origin)
}
// Tests if the given call matches the expected call.
// If yes, returns nil. If no, returns error with message explaining why it does not match.
func (c *Call) matches(args []any) error {
if !c.methodType.IsVariadic() {
if len(args) != len(c.args) {
return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: %d",
c.origin, len(args), len(c.args))
}
for i, m := range c.args {
if !m.Matches(args[i]) {
return fmt.Errorf(
"expected call at %s doesn't match the argument at index %d.\nGot: %v\nWant: %v",
c.origin, i, formatGottenArg(m, args[i]), m,
)
}
}
} else {
if len(c.args) < c.methodType.NumIn()-1 {
return fmt.Errorf("expected call at %s has the wrong number of matchers. Got: %d, want: %d",
c.origin, len(c.args), c.methodType.NumIn()-1)
}
if len(c.args) != c.methodType.NumIn() && len(args) != len(c.args) {
return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: %d",
c.origin, len(args), len(c.args))
}
if len(args) < len(c.args)-1 {
return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: greater than or equal to %d",
c.origin, len(args), len(c.args)-1)
}
for i, m := range c.args {
if i < c.methodType.NumIn()-1 {
// Non-variadic args
if !m.Matches(args[i]) {
return fmt.Errorf("expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v",
c.origin, strconv.Itoa(i), formatGottenArg(m, args[i]), m)
}
continue
}
// The last arg has a possibility of a variadic argument, so let it branch
// sample: Foo(a int, b int, c ...int)
if i < len(c.args) && i < len(args) {
if m.Matches(args[i]) {
// Got Foo(a, b, c) want Foo(matcherA, matcherB, gomock.Any())
// Got Foo(a, b, c) want Foo(matcherA, matcherB, someSliceMatcher)
// Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC)
// Got Foo(a, b) want Foo(matcherA, matcherB)
// Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD)
continue
}
}
// The number of actual args don't match the number of matchers,
// or the last matcher is a slice and the last arg is not.
// If this function still matches it is because the last matcher
// matches all the remaining arguments or the lack of any.
// Convert the remaining arguments, if any, into a slice of the
// expected type.
vArgsType := c.methodType.In(c.methodType.NumIn() - 1)
vArgs := reflect.MakeSlice(vArgsType, 0, len(args)-i)
for _, arg := range args[i:] {
vArgs = reflect.Append(vArgs, reflect.ValueOf(arg))
}
if m.Matches(vArgs.Interface()) {
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any())
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher)
// Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any())
// Got Foo(a, b) want Foo(matcherA, matcherB, someEmptySliceMatcher)
break
}
// Wrong number of matchers or not match. Fail.
// Got Foo(a, b) want Foo(matcherA, matcherB, matcherC, matcherD)
// Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC, matcherD)
// Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD, matcherE)
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, matcherC, matcherD)
// Got Foo(a, b, c) want Foo(matcherA, matcherB)
return fmt.Errorf("expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v",
c.origin, strconv.Itoa(i), formatGottenArg(m, args[i:]), c.args[i])
}
}
// Check that all prerequisite calls have been satisfied.
for _, preReqCall := range c.preReqs {
if !preReqCall.satisfied() {
return fmt.Errorf("expected call at %s doesn't have a prerequisite call satisfied:\n%v\nshould be called before:\n%v",
c.origin, preReqCall, c)
}
}
// Check that the call is not exhausted.
if c.exhausted() {
return fmt.Errorf("expected call at %s has already been called the max number of times", c.origin)
}
return nil
}
// dropPrereqs tells the expected Call to not re-check prerequisite calls any
// longer, and to return its current set.
func (c *Call) dropPrereqs() (preReqs []*Call) {
preReqs = c.preReqs
c.preReqs = nil
return
}
func (c *Call) call() []func([]any) []any {
c.numCalls++
return c.actions
}
// InOrder declares that the given calls should occur in order.
// It panics if the type of any of the arguments isn't *Call or a generated
// mock with an embedded *Call.
func InOrder(args ...any) {
calls := make([]*Call, 0, len(args))
for i := 0; i < len(args); i++ {
if call := getCall(args[i]); call != nil {
calls = append(calls, call)
continue
}
panic(fmt.Sprintf(
"invalid argument at position %d of type %T, InOrder expects *gomock.Call or generated mock types with an embedded *gomock.Call",
i,
args[i],
))
}
for i := 1; i < len(calls); i++ {
calls[i].After(calls[i-1])
}
}
// getCall checks if the parameter is a *Call or a generated struct
// that wraps a *Call and returns the *Call pointer - if neither, it returns nil.
func getCall(arg any) *Call {
if call, ok := arg.(*Call); ok {
return call
}
t := reflect.ValueOf(arg)
if t.Kind() != reflect.Ptr && t.Kind() != reflect.Interface {
return nil
}
t = t.Elem()
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if !f.CanInterface() {
continue
}
if call, ok := f.Interface().(*Call); ok {
return call
}
}
return nil
}
func setSlice(arg any, v reflect.Value) {
va := reflect.ValueOf(arg)
for i := 0; i < v.Len(); i++ {
va.Index(i).Set(v.Index(i))
}
}
func setMap(arg any, v reflect.Value) {
va := reflect.ValueOf(arg)
for _, e := range va.MapKeys() {
va.SetMapIndex(e, reflect.Value{})
}
for _, e := range v.MapKeys() {
va.SetMapIndex(e, v.MapIndex(e))
}
}
func (c *Call) addAction(action func([]any) []any) {
c.actions = append(c.actions, action)
}
func formatGottenArg(m Matcher, arg any) string {
got := fmt.Sprintf("%v (%T)", arg, arg)
if gs, ok := m.(GotFormatter); ok {
got = gs.Got(arg)
}
return got
}

164
vendor/go.uber.org/mock/gomock/callset.go generated vendored Normal file
View File

@ -0,0 +1,164 @@
// Copyright 2011 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gomock
import (
"bytes"
"errors"
"fmt"
"sync"
)
// callSet represents a set of expected calls, indexed by receiver and method
// name.
type callSet struct {
// Calls that are still expected.
expected map[callSetKey][]*Call
expectedMu *sync.Mutex
// Calls that have been exhausted.
exhausted map[callSetKey][]*Call
// when set to true, existing call expectations are overridden when new call expectations are made
allowOverride bool
}
// callSetKey is the key in the maps in callSet
type callSetKey struct {
receiver any
fname string
}
func newCallSet() *callSet {
return &callSet{
expected: make(map[callSetKey][]*Call),
expectedMu: &sync.Mutex{},
exhausted: make(map[callSetKey][]*Call),
}
}
func newOverridableCallSet() *callSet {
return &callSet{
expected: make(map[callSetKey][]*Call),
expectedMu: &sync.Mutex{},
exhausted: make(map[callSetKey][]*Call),
allowOverride: true,
}
}
// Add adds a new expected call.
func (cs callSet) Add(call *Call) {
key := callSetKey{call.receiver, call.method}
cs.expectedMu.Lock()
defer cs.expectedMu.Unlock()
m := cs.expected
if call.exhausted() {
m = cs.exhausted
}
if cs.allowOverride {
m[key] = make([]*Call, 0)
}
m[key] = append(m[key], call)
}
// Remove removes an expected call.
func (cs callSet) Remove(call *Call) {
key := callSetKey{call.receiver, call.method}
cs.expectedMu.Lock()
defer cs.expectedMu.Unlock()
calls := cs.expected[key]
for i, c := range calls {
if c == call {
// maintain order for remaining calls
cs.expected[key] = append(calls[:i], calls[i+1:]...)
cs.exhausted[key] = append(cs.exhausted[key], call)
break
}
}
}
// FindMatch searches for a matching call. Returns error with explanation message if no call matched.
func (cs callSet) FindMatch(receiver any, method string, args []any) (*Call, error) {
key := callSetKey{receiver, method}
cs.expectedMu.Lock()
defer cs.expectedMu.Unlock()
// Search through the expected calls.
expected := cs.expected[key]
var callsErrors bytes.Buffer
for _, call := range expected {
err := call.matches(args)
if err != nil {
_, _ = fmt.Fprintf(&callsErrors, "\n%v", err)
} else {
return call, nil
}
}
// If we haven't found a match then search through the exhausted calls so we
// get useful error messages.
exhausted := cs.exhausted[key]
for _, call := range exhausted {
if err := call.matches(args); err != nil {
_, _ = fmt.Fprintf(&callsErrors, "\n%v", err)
continue
}
_, _ = fmt.Fprintf(
&callsErrors, "all expected calls for method %q have been exhausted", method,
)
}
if len(expected)+len(exhausted) == 0 {
_, _ = fmt.Fprintf(&callsErrors, "there are no expected calls of the method %q for that receiver", method)
}
return nil, errors.New(callsErrors.String())
}
// Failures returns the calls that are not satisfied.
func (cs callSet) Failures() []*Call {
cs.expectedMu.Lock()
defer cs.expectedMu.Unlock()
failures := make([]*Call, 0, len(cs.expected))
for _, calls := range cs.expected {
for _, call := range calls {
if !call.satisfied() {
failures = append(failures, call)
}
}
}
return failures
}
// Satisfied returns true in case all expected calls in this callSet are satisfied.
func (cs callSet) Satisfied() bool {
cs.expectedMu.Lock()
defer cs.expectedMu.Unlock()
for _, calls := range cs.expected {
for _, call := range calls {
if !call.satisfied() {
return false
}
}
}
return true
}

326
vendor/go.uber.org/mock/gomock/controller.go generated vendored Normal file
View File

@ -0,0 +1,326 @@
// Copyright 2010 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gomock
import (
"context"
"fmt"
"reflect"
"runtime"
"sync"
)
// A TestReporter is something that can be used to report test failures. It
// is satisfied by the standard library's *testing.T.
type TestReporter interface {
Errorf(format string, args ...any)
Fatalf(format string, args ...any)
}
// TestHelper is a TestReporter that has the Helper method. It is satisfied
// by the standard library's *testing.T.
type TestHelper interface {
TestReporter
Helper()
}
// cleanuper is used to check if TestHelper also has the `Cleanup` method. A
// common pattern is to pass in a `*testing.T` to
// `NewController(t TestReporter)`. In Go 1.14+, `*testing.T` has a cleanup
// method. This can be utilized to call `Finish()` so the caller of this library
// does not have to.
type cleanuper interface {
Cleanup(func())
}
// A Controller represents the top-level control of a mock ecosystem. It
// defines the scope and lifetime of mock objects, as well as their
// expectations. It is safe to call Controller's methods from multiple
// goroutines. Each test should create a new Controller.
//
// func TestFoo(t *testing.T) {
// ctrl := gomock.NewController(t)
// // ..
// }
//
// func TestBar(t *testing.T) {
// t.Run("Sub-Test-1", st) {
// ctrl := gomock.NewController(st)
// // ..
// })
// t.Run("Sub-Test-2", st) {
// ctrl := gomock.NewController(st)
// // ..
// })
// })
type Controller struct {
// T should only be called within a generated mock. It is not intended to
// be used in user code and may be changed in future versions. T is the
// TestReporter passed in when creating the Controller via NewController.
// If the TestReporter does not implement a TestHelper it will be wrapped
// with a nopTestHelper.
T TestHelper
mu sync.Mutex
expectedCalls *callSet
finished bool
}
// NewController returns a new Controller. It is the preferred way to create a Controller.
//
// Passing [*testing.T] registers cleanup function to automatically call [Controller.Finish]
// when the test and all its subtests complete.
func NewController(t TestReporter, opts ...ControllerOption) *Controller {
h, ok := t.(TestHelper)
if !ok {
h = &nopTestHelper{t}
}
ctrl := &Controller{
T: h,
expectedCalls: newCallSet(),
}
for _, opt := range opts {
opt.apply(ctrl)
}
if c, ok := isCleanuper(ctrl.T); ok {
c.Cleanup(func() {
ctrl.T.Helper()
ctrl.finish(true, nil)
})
}
return ctrl
}
// ControllerOption configures how a Controller should behave.
type ControllerOption interface {
apply(*Controller)
}
type overridableExpectationsOption struct{}
// WithOverridableExpectations allows for overridable call expectations
// i.e., subsequent call expectations override existing call expectations
func WithOverridableExpectations() overridableExpectationsOption {
return overridableExpectationsOption{}
}
func (o overridableExpectationsOption) apply(ctrl *Controller) {
ctrl.expectedCalls = newOverridableCallSet()
}
type cancelReporter struct {
t TestHelper
cancel func()
}
func (r *cancelReporter) Errorf(format string, args ...any) {
r.t.Errorf(format, args...)
}
func (r *cancelReporter) Fatalf(format string, args ...any) {
defer r.cancel()
r.t.Fatalf(format, args...)
}
func (r *cancelReporter) Helper() {
r.t.Helper()
}
// WithContext returns a new Controller and a Context, which is cancelled on any
// fatal failure.
func WithContext(ctx context.Context, t TestReporter) (*Controller, context.Context) {
h, ok := t.(TestHelper)
if !ok {
h = &nopTestHelper{t: t}
}
ctx, cancel := context.WithCancel(ctx)
return NewController(&cancelReporter{t: h, cancel: cancel}), ctx
}
type nopTestHelper struct {
t TestReporter
}
func (h *nopTestHelper) Errorf(format string, args ...any) {
h.t.Errorf(format, args...)
}
func (h *nopTestHelper) Fatalf(format string, args ...any) {
h.t.Fatalf(format, args...)
}
func (h nopTestHelper) Helper() {}
// RecordCall is called by a mock. It should not be called by user code.
func (ctrl *Controller) RecordCall(receiver any, method string, args ...any) *Call {
ctrl.T.Helper()
recv := reflect.ValueOf(receiver)
for i := 0; i < recv.Type().NumMethod(); i++ {
if recv.Type().Method(i).Name == method {
return ctrl.RecordCallWithMethodType(receiver, method, recv.Method(i).Type(), args...)
}
}
ctrl.T.Fatalf("gomock: failed finding method %s on %T", method, receiver)
panic("unreachable")
}
// RecordCallWithMethodType is called by a mock. It should not be called by user code.
func (ctrl *Controller) RecordCallWithMethodType(receiver any, method string, methodType reflect.Type, args ...any) *Call {
ctrl.T.Helper()
call := newCall(ctrl.T, receiver, method, methodType, args...)
ctrl.mu.Lock()
defer ctrl.mu.Unlock()
ctrl.expectedCalls.Add(call)
return call
}
// Call is called by a mock. It should not be called by user code.
func (ctrl *Controller) Call(receiver any, method string, args ...any) []any {
ctrl.T.Helper()
// Nest this code so we can use defer to make sure the lock is released.
actions := func() []func([]any) []any {
ctrl.T.Helper()
ctrl.mu.Lock()
defer ctrl.mu.Unlock()
expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
if err != nil {
// callerInfo's skip should be updated if the number of calls between the user's test
// and this line changes, i.e. this code is wrapped in another anonymous function.
// 0 is us, 1 is controller.Call(), 2 is the generated mock, and 3 is the user's test.
origin := callerInfo(3)
stringArgs := make([]string, len(args))
for i, arg := range args {
stringArgs[i] = getString(arg)
}
ctrl.T.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, stringArgs, origin, err)
}
// Two things happen here:
// * the matching call no longer needs to check prerequisite calls,
// * and the prerequisite calls are no longer expected, so remove them.
preReqCalls := expected.dropPrereqs()
for _, preReqCall := range preReqCalls {
ctrl.expectedCalls.Remove(preReqCall)
}
actions := expected.call()
if expected.exhausted() {
ctrl.expectedCalls.Remove(expected)
}
return actions
}()
var rets []any
for _, action := range actions {
if r := action(args); r != nil {
rets = r
}
}
return rets
}
// Finish checks to see if all the methods that were expected to be called were called.
// It is not idempotent and therefore can only be invoked once.
//
// Note: If you pass a *testing.T into [NewController], you no longer
// need to call ctrl.Finish() in your test methods.
func (ctrl *Controller) Finish() {
// If we're currently panicking, probably because this is a deferred call.
// This must be recovered in the deferred function.
err := recover()
ctrl.finish(false, err)
}
// Satisfied returns whether all expected calls bound to this Controller have been satisfied.
// Calling Finish is then guaranteed to not fail due to missing calls.
func (ctrl *Controller) Satisfied() bool {
ctrl.mu.Lock()
defer ctrl.mu.Unlock()
return ctrl.expectedCalls.Satisfied()
}
func (ctrl *Controller) finish(cleanup bool, panicErr any) {
ctrl.T.Helper()
ctrl.mu.Lock()
defer ctrl.mu.Unlock()
if ctrl.finished {
if _, ok := isCleanuper(ctrl.T); !ok {
ctrl.T.Fatalf("Controller.Finish was called more than once. It has to be called exactly once.")
}
return
}
ctrl.finished = true
// Short-circuit, pass through the panic.
if panicErr != nil {
panic(panicErr)
}
// Check that all remaining expected calls are satisfied.
failures := ctrl.expectedCalls.Failures()
for _, call := range failures {
ctrl.T.Errorf("missing call(s) to %v", call)
}
if len(failures) != 0 {
if !cleanup {
ctrl.T.Fatalf("aborting test due to missing call(s)")
return
}
ctrl.T.Errorf("aborting test due to missing call(s)")
}
}
// callerInfo returns the file:line of the call site. skip is the number
// of stack frames to skip when reporting. 0 is callerInfo's call site.
func callerInfo(skip int) string {
if _, file, line, ok := runtime.Caller(skip + 1); ok {
return fmt.Sprintf("%s:%d", file, line)
}
return "unknown file"
}
// isCleanuper checks it if t's base TestReporter has a Cleanup method.
func isCleanuper(t TestReporter) (cleanuper, bool) {
tr := unwrapTestReporter(t)
c, ok := tr.(cleanuper)
return c, ok
}
// unwrapTestReporter unwraps TestReporter to the base implementation.
func unwrapTestReporter(t TestReporter) TestReporter {
tr := t
switch nt := t.(type) {
case *cancelReporter:
tr = nt.t
if h, check := tr.(*nopTestHelper); check {
tr = h.t
}
case *nopTestHelper:
tr = nt.t
default:
// not wrapped
}
return tr
}

60
vendor/go.uber.org/mock/gomock/doc.go generated vendored Normal file
View File

@ -0,0 +1,60 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package gomock is a mock framework for Go.
//
// Standard usage:
//
// (1) Define an interface that you wish to mock.
// type MyInterface interface {
// SomeMethod(x int64, y string)
// }
// (2) Use mockgen to generate a mock from the interface.
// (3) Use the mock in a test:
// func TestMyThing(t *testing.T) {
// mockCtrl := gomock.NewController(t)
// mockObj := something.NewMockMyInterface(mockCtrl)
// mockObj.EXPECT().SomeMethod(4, "blah")
// // pass mockObj to a real object and play with it.
// }
//
// By default, expected calls are not enforced to run in any particular order.
// Call order dependency can be enforced by use of InOrder and/or Call.After.
// Call.After can create more varied call order dependencies, but InOrder is
// often more convenient.
//
// The following examples create equivalent call order dependencies.
//
// Example of using Call.After to chain expected call order:
//
// firstCall := mockObj.EXPECT().SomeMethod(1, "first")
// secondCall := mockObj.EXPECT().SomeMethod(2, "second").After(firstCall)
// mockObj.EXPECT().SomeMethod(3, "third").After(secondCall)
//
// Example of using InOrder to declare expected call order:
//
// gomock.InOrder(
// mockObj.EXPECT().SomeMethod(1, "first"),
// mockObj.EXPECT().SomeMethod(2, "second"),
// mockObj.EXPECT().SomeMethod(3, "third"),
// )
//
// The standard TestReporter most users will pass to `NewController` is a
// `*testing.T` from the context of the test. Note that this will use the
// standard `t.Error` and `t.Fatal` methods to report what happened in the test.
// In some cases this can leave your testing package in a weird state if global
// state is used since `t.Fatal` is like calling panic in the middle of a
// function. In these cases it is recommended that you pass in your own
// `TestReporter`.
package gomock

Some files were not shown because too many files have changed in this diff Show More