Merge branch 'cloudflare:master' into master

This commit is contained in:
Jauder Ho 2022-01-20 10:24:45 -08:00 committed by GitHub
commit 672a128776
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
409 changed files with 35414 additions and 8919 deletions

View File

@ -46,8 +46,8 @@ git reset --hard origin/master
URL="https://packages.argotunnel.com/dl/cloudflared-$VERSION-darwin-amd64.tgz"
tee cloudflared.rb <<EOF
class Cloudflared < Formula
desc 'Argo Tunnel'
homepage 'https://developers.cloudflare.com/argo-tunnel/'
desc 'Cloudflare Tunnel'
homepage 'https://developers.cloudflare.com/cloudflare-one/connections/connect-apps'
url '$URL'
sha256 '$SHA256'
version '$VERSION'
@ -62,6 +62,6 @@ git add cloudflared.rb
git diff
git config user.name "cloudflare-warp-bot"
git config user.email "warp-bot@cloudflare.com"
git commit -m "Release Argo Tunnel $VERSION"
git commit -m "Release Cloudflare Tunnel $VERSION"
git push -v origin master

View File

@ -1,5 +1,53 @@
**Experimental**: This is a new format for release notes. The format and availability is subject to change.
## 2021.12.2
### Bug Fixes
- Fix logging when `quic` transport is used and UDP traffic is proxied.
- FIPS compliant cloudflared binaries will now be released as separate artifacts. Recall that these are only for linux
and amd64.
## 2021.12.1
### Bug Fixes
- Fixes Github issue #530 where cloudflared 2021.12.0 could not reach origins that were HTTPS and using certain encryption
methods forbidden by FIPS compliance (such as Let's Encrypt certificates). To address this fix we have temporarily reverted
FIPS compliance from amd64 linux binaries that was recently introduced (or fixed actually as it was never working before).
## 2021.12.0
### New Features
- Cloudflared binary released for amd64 linux is now FIPS compliant.
### Improvements
- Logging about connectivity to Cloudflare edge now only yields `ERR` level logging if there are no connections to
Cloudflare edge that are active. Otherwise it logs `WARN` level.
### Bug Fixes
- Fixes Github issue #501.
## 2021.11.0
### Improvements
- Fallback from `protocol:quic` to `protocol:http2` immediately if UDP connectivity isn't available. This could be because of a firewall or
egress rule.
## 2021.10.4
### Improvements
- Collect quic transport metrics on RTT, packets and bytes transferred.
### Bug Fixes
- Fix race condition that was writing to the connection after the http2 handler returns.
## 2021.9.2
### New features
- `cloudflared` can now run with `quic` as the underlying tunnel transport protocol. To try it, change or add "protocol: quic" to your config.yml file or
run cloudflared with the `--protocol quic` flag. e.g:
`cloudflared tunnel --protocol quic run <tunnel-name>`
### Bug Fixes
- Fixed some generic transport bugs in `quic` mode. It's advised to upgrade to at least this version (2021.9.2) when running `cloudflared`
with `quic` protocol.
- `cloudflared` docker images will now show version.
## 2021.8.4
### Improvements
- Temporary tunnels (those hosted on trycloudflare.com that do not require a Cloudflare login) now run as Named Tunnels

View File

@ -1,25 +1,41 @@
VERSION := $(shell git describe --tags --always --dirty="-dev" --match "[0-9][0-9][0-9][0-9].*.*")
VERSION := $(shell git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*")
MSI_VERSION := $(shell git tag -l --sort=v:refname | grep "w" | tail -1 | cut -c2-)
#MSI_VERSION expects the format of the tag to be: (wX.X.X). Starts with the w character to not break cfsetup.
#e.g. w3.0.1 or w4.2.10. It trims off the w character when creating the MSI.
ifeq ($(FIPS), true)
GO_BUILD_TAGS := $(GO_BUILD_TAGS) fips
endif
ifneq ($(GO_BUILD_TAGS),)
GO_BUILD_TAGS := -tags $(GO_BUILD_TAGS)
ifeq ($(ORIGINAL_NAME), true)
# Used for builds that want FIPS compilation but want the artifacts generated to still have the original name.
BINARY_NAME := cloudflared
else ifeq ($(FIPS), true)
# Used for FIPS compliant builds that do not match the case above.
BINARY_NAME := cloudflared-fips
else
# Used for all other (non-FIPS) builds.
BINARY_NAME := cloudflared
endif
ifeq ($(NIGHTLY), true)
DEB_PACKAGE_NAME := cloudflared-nightly
DEB_PACKAGE_NAME := $(BINARY_NAME)-nightly
NIGHTLY_FLAGS := --conflicts cloudflared --replaces cloudflared
else
DEB_PACKAGE_NAME := cloudflared
DEB_PACKAGE_NAME := $(BINARY_NAME)
endif
DATE := $(shell date -u '+%Y-%m-%d-%H%M UTC')
VERSION_FLAGS := -ldflags='-X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)"'
VERSION_FLAGS := -X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)"
LINK_FLAGS :=
ifeq ($(FIPS), true)
LINK_FLAGS := -linkmode=external -extldflags=-static $(LINK_FLAGS)
# Prevent linking with libc regardless of CGO enabled or not.
GO_BUILD_TAGS := $(GO_BUILD_TAGS) osusergo netgo fips
VERSION_FLAGS := $(VERSION_FLAGS) -X "main.BuildType=FIPS"
endif
LDFLAGS := -ldflags='$(VERSION_FLAGS) $(LINK_FLAGS)'
ifneq ($(GO_BUILD_TAGS),)
GO_BUILD_TAGS := -tags "$(GO_BUILD_TAGS)"
endif
IMPORT_PATH := github.com/cloudflare/cloudflared
PACKAGE_DIR := $(CURDIR)/packaging
@ -61,9 +77,9 @@ else
endif
ifeq ($(TARGET_OS), windows)
EXECUTABLE_PATH=./cloudflared.exe
EXECUTABLE_PATH=./$(BINARY_NAME).exe
else
EXECUTABLE_PATH=./cloudflared
EXECUTABLE_PATH=./$(BINARY_NAME)
endif
ifeq ($(FLAVOR), centos-7)
@ -80,17 +96,15 @@ clean:
go clean
.PHONY: cloudflared
cloudflared:
cloudflared:
ifeq ($(FIPS), true)
$(info Building cloudflared with go-fips)
-test -f fips/fips.go && mv fips/fips.go fips/fips.go.linux-amd64
mv fips/fips.go.linux-amd64 fips/fips.go
cp -f fips/fips.go.linux-amd64 cmd/cloudflared/fips.go
endif
GOOS=$(TARGET_OS) GOARCH=$(TARGET_ARCH) go build -v -mod=vendor $(GO_BUILD_TAGS) $(VERSION_FLAGS) $(IMPORT_PATH)/cmd/cloudflared
GOOS=$(TARGET_OS) GOARCH=$(TARGET_ARCH) go build -v -mod=vendor $(GO_BUILD_TAGS) $(LDFLAGS) $(IMPORT_PATH)/cmd/cloudflared
ifeq ($(FIPS), true)
mv fips/fips.go fips/fips.go.linux-amd64
rm -f cmd/cloudflared/fips.go
./check-fips.sh cloudflared
endif
.PHONY: container
@ -100,10 +114,10 @@ container:
.PHONY: test
test: vet
ifndef CI
go test -v -mod=vendor -race $(VERSION_FLAGS) ./...
go test -v -mod=vendor -race $(LDFLAGS) ./...
else
@mkdir -p .cover
go test -v -mod=vendor -race $(VERSION_FLAGS) -coverprofile=".cover/c.out" ./...
go test -v -mod=vendor -race $(LDFLAGS) -coverprofile=".cover/c.out" ./...
go tool cover -html ".cover/c.out" -o .cover/all.html
endif
@ -112,10 +126,10 @@ test-ssh-server:
docker-compose -f ssh_server_tests/docker-compose.yml up
define publish_package
chmod 664 cloudflared*.$(1); \
chmod 664 $(BINARY_NAME)*.$(1); \
for HOST in $(CF_PKG_HOSTS); do \
ssh-keyscan -t rsa $$HOST >> ~/.ssh/known_hosts; \
scp -p -4 cloudflared*.$(1) cfsync@$$HOST:/state/cf-pkg/staging/$(2)/$(TARGET_PUBLIC_REPO)/cloudflared/; \
ssh-keyscan -t ecdsa $$HOST >> ~/.ssh/known_hosts; \
scp -p -4 $(BINARY_NAME)*.$(1) cfsync@$$HOST:/state/cf-pkg/staging/$(2)/$(TARGET_PUBLIC_REPO)/$(BINARY_NAME)/; \
done
endef
@ -127,12 +141,14 @@ publish-deb: cloudflared-deb
publish-rpm: cloudflared-rpm
$(call publish_package,rpm,yum)
# When we build packages, the package name will be FIPS-aware.
# But we keep the binary installed by it to be named "cloudflared" regardless.
define build_package
mkdir -p $(PACKAGE_DIR)
cp cloudflared $(PACKAGE_DIR)/cloudflared
cat cloudflared_man_template | sed -e 's/\$${VERSION}/$(VERSION)/; s/\$${DATE}/$(DATE)/' > $(PACKAGE_DIR)/cloudflared.1
fakeroot fpm -C $(PACKAGE_DIR) -s dir -t $(1) \
--description 'Cloudflare Argo tunnel daemon' \
--description 'Cloudflare Tunnel daemon' \
--vendor 'Cloudflare' \
--license 'Cloudflare Service Agreement' \
--url 'https://github.com/cloudflare/cloudflared' \
@ -247,8 +263,8 @@ tunnelrpc-deps:
capnp compile -ogo tunnelrpc/tunnelrpc.capnp
.PHONY: quic-deps
quic-deps:
which capnp
quic-deps:
which capnp
which capnpc-go
capnp compile -ogo quic/schema/quic_metadata_protocol.capnp
@ -258,9 +274,9 @@ vet:
# go get github.com/sudarshan-reddy/go-sumtype (don't do this in build directory or this will cause vendor issues)
# Note: If you have github.com/BurntSushi/go-sumtype then you might have to use the repo above instead
# for now because it uses an older version of golang.org/x/tools.
which go-sumtype
which go-sumtype
go-sumtype $$(go list -mod=vendor ./...)
.PHONY: goimports
goimports:
for d in $$(go list -mod=readonly -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc) ; do goimports -format-only -local github.com/cloudflare/cloudflared -w $$d ; done
for d in $$(go list -mod=readonly -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc) ; do goimports -format-only -local github.com/cloudflare/cloudflared -w $$d ; done

View File

@ -1,14 +1,28 @@
# Argo Tunnel client
# Cloudflare Tunnel client
Contains the command-line client for Cloudflare Tunnel, a tunneling daemon that proxies traffic from the Cloudflare network to your origins.
This daemon sits between Cloudflare network and your origin (e.g. a webserver). Cloudflare attracts client requests and sends them to you
via this daemon, without requiring you to poke holes on your firewall --- your origin can remain as closed as possible.
Extensive documentation can be found in the [Cloudflare Tunnel section](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps) of the Cloudflare Docs.
All usages related with proxying to your origins are available under `cloudflared tunnel help`.
You can also use `cloudflared` to access Tunnel origins (that are protected with `cloudflared tunnel`) for TCP traffic
at Layer 4 (i.e., not HTTP/websocket), which is relevant for use cases such as SSH, RDP, etc.
Such usages are available under `cloudflared access help`.
You can instead use [WARP client](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/configuration/private-networks)
to access private origins behind Tunnels for Layer 4 traffic without requiring `cloudflared access` commands on the client side.
Contains the command-line client for Argo Tunnel, a tunneling daemon that proxies any local webserver through the Cloudflare network. Extensive documentation can be found in the [Argo Tunnel section](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps) of the Cloudflare Docs.
## Before you get started
Before you use Argo Tunnel, you'll need to complete a few steps in the Cloudflare dashboard. The website you add to Cloudflare will be used to route traffic to your Tunnel.
Before you use Cloudflare Tunnel, you'll need to complete a few steps in the Cloudflare dashboard: you need to add a
website to your Cloudflare account. Note that today it is possible to use Tunnel without a website (e.g. for private
routing), but for legacy reasons this requirement is still necessary:
1. [Add a website to Cloudflare](https://support.cloudflare.com/hc/en-us/articles/201720164-Creating-a-Cloudflare-account-and-adding-a-website)
2. [Change your domain nameservers to Cloudflare](https://support.cloudflare.com/hc/en-us/articles/205195708)
## Installing `cloudflared`
Downloads are available as standalone binaries, a Docker image, and Debian, RPM, and Homebrew packages. You can also find releases here on the `cloudflared` GitHub repository.
@ -18,18 +32,23 @@ Downloads are available as standalone binaries, a Docker image, and Debian, RPM,
* A Docker image of `cloudflared` is [available on DockerHub](https://hub.docker.com/r/cloudflare/cloudflared)
* You can install on Windows machines with the [steps here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation#windows)
User documentation for Argo Tunnel can be found at https://developers.cloudflare.com/cloudflare-one/connections/connect-apps
User documentation for Cloudflare Tunnel can be found at https://developers.cloudflare.com/cloudflare-one/connections/connect-apps
## Creating Tunnels and routing traffic
Once installed, you can authenticate `cloudflared` into your Cloudflare account and begin creating Tunnels that serve traffic for hostnames in your account.
Once installed, you can authenticate `cloudflared` into your Cloudflare account and begin creating Tunnels to serve traffic to your origins.
* Create a Tunnel with [these instructions](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/create-tunnel)
* Route traffic to that Tunnel with [DNS records in Cloudflare](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/dns) or with a [Cloudflare Load Balancer](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/lb)
* Route traffic to that Tunnel:
* Via public [DNS records in Cloudflare](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/dns)
* Or via a public hostname guided by a [Cloudflare Load Balancer](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/lb)
* Or from [WARP client private traffic](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/configuration/private-networks)
## TryCloudflare
Want to test Argo Tunnel before adding a website to Cloudflare? You can do so with TryCloudflare using the documentation [available here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/trycloudflare).
Want to test Cloudflare Tunnel before adding a website to Cloudflare? You can do so with TryCloudflare using the documentation [available here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/trycloudflare).
## Deprecated versions

View File

@ -1,3 +1,107 @@
2022.1.2
- 2022-01-13 TUN-5650: Fix pynacl version to 1.4.0 and pygithub version to 1.55 so release doesn't break unexpectedly
2022.1.1
- 2022-01-10 TUN-5631: Build everything with go 1.17.5
- 2022-01-06 TUN-5623: Configure quic max datagram frame size to 1350 bytes for none Windows platforms
2022.1.0
- 2022-01-03 TUN-5612: Add support for specifying TLS min/max version
- 2022-01-03 TUN-5612: Make tls min/max version public visible
- 2022-01-03 TUN-5551: Internally published debian artifacts are now named just cloudflared even though they are FIPS compliant
- 2022-01-04 TUN-5600: Close QUIC transports as soon as possible while respecting graceful shutdown
- 2022-01-05 TUN-5616: Never fallback transport if user chooses it on purpose
- 2022-01-05 TUN-5204: Unregister QUIC transports on disconnect
- 2022-01-04 TUN-5600: Add coverage to component tests for various transports
2021.12.4
- 2021-12-27 TUN-5482: Refactor tunnelstore client related packages for more coherent package
- 2021-12-27 TUN-5551: Change internally published debian package to be FIPS compliant
- 2021-12-27 TUN-5551: Show whether the binary was built for FIPS compliance
2021.12.3
- 2021-12-22 TUN-5584: Changes for release 2021.12.2
- 2021-12-22 TUN-5590: QUIC datagram max user payload is 1217 bytes
- 2021-12-22 TUN-5593: Read full packet from UDP connection, even if it exceeds MTU of the transport. When packet length is greater than the MTU of the transport, we will silently drop packets (for now).
- 2021-12-23 TUN-5597: Log session ID when session is terminated by edge
2021.12.2
- 2021-12-20 TUN-5571: Remove redundant session manager log, it's already logged in origin/tunnel.ServeQUIC
- 2021-12-20 TUN-5570: Only log RPC server events at error level to reduce noise
- 2021-12-14 TUN-5494: Send a RPC with terminate reason to edge if the session is closed locally
- 2021-11-09 TUN-5551: Reintroduce FIPS compliance for linux amd64 now as separate binaries
2021.12.1
- 2021-12-16 TUN-5549: Revert "TUN-5277: Ensure cloudflared binary is FIPS compliant on linux amd64"
2021.12.0
- 2021-12-13 TUN-5530: Get current time from ticker
- 2021-12-15 TUN-5544: Update CHANGES.md for next release
- 2021-12-07 TUN-5519: Adjust URL for virtual_networks endpoint to match what we will publish
- 2021-12-02 TUN-5488: Close session after it's idle for a period defined by registerUdpSession RPC
- 2021-12-09 TUN-5504: Fix upload of packages to public repo
- 2021-11-30 TUN-5481: Create abstraction for Origin UDP Connection
- 2021-11-30 TUN-5422: Define RPC to unregister session
- 2021-11-26 TUN-5361: Commands for managing virtual networks
- 2021-11-29 TUN-5362: Adjust route ip commands to be aware of virtual networks
- 2021-11-23 TUN-5301: Separate datagram multiplex and session management logic from quic connection logic
- 2021-11-10 TUN-5405: Update net package to v0.0.0-20211109214657-ef0fda0de508
- 2021-11-10 TUN-5408: Update quic package to v0.24.0
- 2021-11-12 Fix typos
- 2021-11-13 Fix for Issue #501: Unexpected User-agent insertion when tunneling http request
- 2021-11-16 TUN-5129: Remove `-dev` suffix when computing version and Git has uncommitted changes
- 2021-11-18 TUN-5441: Fix message about available protocols
- 2021-11-12 TUN-5300: Define RPC to register UDP sessions
- 2021-11-14 TUN-5299: Send/receive QUIC datagram from edge and proxy to origin as UDP
- 2021-11-04 TUN-5387: Updated CHANGES.md for 2021.11.0
- 2021-11-08 TUN-5368: Log connection issues with LogLevel that depends on tunnel state
- 2021-11-09 TUN-5397: Log cloudflared output when it fails to connect tunnel
- 2021-11-09 TUN-5277: Ensure cloudflared binary is FIPS compliant on linux amd64
- 2021-11-08 TUN-5393: Content-length is no longer a control header for non-h2mux transports
2021.11.0
- 2021-11-03 TUN-5285: Fallback to HTTP2 immediately if connection times out with no network activity
- 2021-09-29 Add flag to 'tunnel create' subcommand to specify a base64-encoded secret
2021.10.5
- 2021-10-25 Update change log for release 2021.10.4
- 2021-10-25 Revert "TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown"
2021.10.4
- 2021-10-21 TUN-5287: Fix misuse of wait group in TestQUICServer that caused the test to exit immediately
- 2021-10-21 TUN-5286: Upgrade crypto/ssh package to fix CVE-2020-29652
- 2021-10-18 TUN-5262: Allow to configure max fetch size for listing queries
- 2021-10-19 TUN-5262: Improvements to `max-fetch-size` that allow to deal with large number of tunnels in account
- 2021-10-15 TUN-5261: Collect QUIC metrics about RTT, packets and bytes transfered and log events at tracing level
- 2021-10-19 TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown
2021.10.3
- 2021-10-14 TUN-5255: Fix potential panic if Cloudflare API fails to respond to GetTunnel(id) during delete command
- 2021-10-14 TUN-5257: Fix more cfsetup targets that were broken by recent package changes
2021.10.2
- 2021-10-11 TUN-5138: Switch to QUIC on auto protocol based on threshold
- 2021-10-14 TUN-5250: Add missing packages for cfsetup to succeed in github release pkgs target
2021.10.1
- 2021-10-12 TUN-5246: Use protocol: quic for Quick tunnels if one is not already set
- 2021-10-13 TUN-5249: Revert "TUN-5138: Switch to QUIC on auto protocol based on threshold"
2021.10.0
- 2021-10-11 TUN-5138: Switch to QUIC on auto protocol based on threshold
- 2021-10-07 TUN-5195: Do not set empty body if not applicable
- 2021-10-08 UN-5213: Increase MaxStreams value for QUIC transport
- 2021-09-28 TUN-5169: Release 2021.9.2 CHANGES.md
- 2021-09-28 TUN-5164: Update README and clean up references to Argo Tunnel (using Cloudflare Tunnel instead)
2021.9.2
- 2021-09-21 TUN-5129: Use go 1.17 and copy .git folder to docker build to compute version
- 2021-09-21 TUN-5128: Enforce maximum grace period
- 2021-09-22 TUN-5141: Make sure websocket pinger returns before streaming returns
- 2021-09-24 TUN-5142: Add asynchronous servecontrolstream for QUIC
- 2021-09-24 TUN-5142: defer close rpcconn inside unregister instead of ServeControlStream
- 2021-09-27 TUN-5160: Set request.ContentLength when this value is in request header
2021.9.1
- 2021-09-21 TUN-5118: Quic connection now detects duplicate connections similar to http2
- 2021-09-15 Fix TryCloudflare link

25
build-packages-fips.sh Executable file
View File

@ -0,0 +1,25 @@
VERSION=$(git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*")
echo $VERSION
# This controls the directory the built artifacts go into
export ARTIFACT_DIR=built_artifacts/
mkdir -p $ARTIFACT_DIR
arch=("amd64")
export TARGET_ARCH=$arch
export TARGET_OS=linux
export FIPS=true
# For BoringCrypto to link, we need CGO enabled. Otherwise compilation fails.
export CGO_ENABLED=1
make cloudflared-deb
mv cloudflared-fips\_$VERSION\_$arch.deb $ARTIFACT_DIR/cloudflared-fips-linux-$arch.deb
# rpm packages invert the - and _ and use x86_64 instead of amd64.
RPMVERSION=$(echo $VERSION|sed -r 's/-/_/g')
RPMARCH="x86_64"
make cloudflared-rpm
mv cloudflared-fips-$RPMVERSION-1.$RPMARCH.rpm $ARTIFACT_DIR/cloudflared-fips-linux-$RPMARCH.rpm
# finally move the linux binary as well.
mv ./cloudflared $ARTIFACT_DIR/cloudflared-fips-linux-$arch

View File

@ -1,12 +1,15 @@
VERSION=$(git describe --tags --always --dirty="-dev" --match "[0-9][0-9][0-9][0-9].*.*")
VERSION=$(git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*")
echo $VERSION
# Avoid depending on C code since we don't need it.
export CGO_ENABLED=0
# This controls the directory the built artifacts go into
export ARTIFACT_DIR=built_artifacts/
mkdir -p $ARTIFACT_DIR
windowsArchs=("amd64" "386")
export TARGET_OS=windows
for arch in ${windowsArchs[@]}; do
for arch in ${windowsArchs[@]}; do
export TARGET_ARCH=$arch
make cloudflared-msi
mv ./cloudflared.exe $ARTIFACT_DIR/cloudflared-windows-$arch.exe
@ -14,15 +17,14 @@ for arch in ${windowsArchs[@]}; do
done
export FIPS=true
linuxArchs=("amd64" "386" "arm" "arm64")
linuxArchs=("386" "amd64" "arm" "arm64")
export TARGET_OS=linux
for arch in ${linuxArchs[@]}; do
for arch in ${linuxArchs[@]}; do
export TARGET_ARCH=$arch
make cloudflared-deb
mv cloudflared\_$VERSION\_$arch.deb $ARTIFACT_DIR/cloudflared-linux-$arch.deb
# rpm packages invert the - and _ and use x86_64 instead of amd64.
# rpm packages invert the - and _ and use x86_64 instead of amd64.
RPMVERSION=$(echo $VERSION|sed -r 's/-/_/g')
RPMARCH=$arch
if [ $arch == "amd64" ];then
@ -37,4 +39,3 @@ for arch in ${linuxArchs[@]}; do
# finally move the linux binary as well.
mv ./cloudflared $ARTIFACT_DIR/cloudflared-linux-$arch
done

View File

@ -54,7 +54,7 @@ func (c *StdinoutStream) Write(p []byte) (int, error) {
return os.Stdout.Write(p)
}
// Helper to allow defering the response close with a check that the resp is not nil
// Helper to allow deferring the response close with a check that the resp is not nil
func closeRespBody(resp *http.Response) {
if resp != nil {
_ = resp.Body.Close()

View File

@ -66,7 +66,7 @@ func DecodeOriginCert(blocks []byte) (*OriginCert, error) {
originCert.ServiceKey = ntt.ServiceKey
originCert.AccountID = ntt.AccountID
} else {
// Try the older format, where the zoneID and service key are seperated by
// Try the older format, where the zoneID and service key are separated by
// a new line character
token := string(block.Bytes)
s := strings.Split(token, "\n")

186
cfapi/base_client.go Normal file
View File

@ -0,0 +1,186 @@
package cfapi
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"golang.org/x/net/http2"
)
const (
defaultTimeout = 15 * time.Second
jsonContentType = "application/json"
)
var (
ErrUnauthorized = errors.New("unauthorized")
ErrBadRequest = errors.New("incorrect request parameters")
ErrNotFound = errors.New("not found")
ErrAPINoSuccess = errors.New("API call failed")
)
type RESTClient struct {
baseEndpoints *baseEndpoints
authToken string
userAgent string
client http.Client
log *zerolog.Logger
}
type baseEndpoints struct {
accountLevel url.URL
zoneLevel url.URL
accountRoutes url.URL
accountVnets url.URL
}
var _ Client = (*RESTClient)(nil)
func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, log *zerolog.Logger) (*RESTClient, error) {
if strings.HasSuffix(baseURL, "/") {
baseURL = baseURL[:len(baseURL)-1]
}
accountLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag))
if err != nil {
return nil, errors.Wrap(err, "failed to create account level endpoint")
}
accountRoutesEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/teamnet/routes", baseURL, accountTag))
if err != nil {
return nil, errors.Wrap(err, "failed to create route account-level endpoint")
}
accountVnetsEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/teamnet/virtual_networks", baseURL, accountTag))
if err != nil {
return nil, errors.Wrap(err, "failed to create virtual network account-level endpoint")
}
zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag))
if err != nil {
return nil, errors.Wrap(err, "failed to create account level endpoint")
}
httpTransport := http.Transport{
TLSHandshakeTimeout: defaultTimeout,
ResponseHeaderTimeout: defaultTimeout,
}
http2.ConfigureTransport(&httpTransport)
return &RESTClient{
baseEndpoints: &baseEndpoints{
accountLevel: *accountLevelEndpoint,
zoneLevel: *zoneLevelEndpoint,
accountRoutes: *accountRoutesEndpoint,
accountVnets: *accountVnetsEndpoint,
},
authToken: authToken,
userAgent: userAgent,
client: http.Client{
Transport: &httpTransport,
Timeout: defaultTimeout,
},
log: log,
}, nil
}
func (r *RESTClient) sendRequest(method string, url url.URL, body interface{}) (*http.Response, error) {
var bodyReader io.Reader
if body != nil {
if bodyBytes, err := json.Marshal(body); err != nil {
return nil, errors.Wrap(err, "failed to serialize json body")
} else {
bodyReader = bytes.NewBuffer(bodyBytes)
}
}
req, err := http.NewRequest(method, url.String(), bodyReader)
if err != nil {
return nil, errors.Wrapf(err, "can't create %s request", method)
}
req.Header.Set("User-Agent", r.userAgent)
if bodyReader != nil {
req.Header.Set("Content-Type", jsonContentType)
}
req.Header.Add("X-Auth-User-Service-Key", r.authToken)
req.Header.Add("Accept", "application/json;version=1")
return r.client.Do(req)
}
func parseResponse(reader io.Reader, data interface{}) error {
// Schema for Tunnelstore responses in the v1 API.
// Roughly, it's a wrapper around a particular result that adds failures/errors/etc
var result response
// First, parse the wrapper and check the API call succeeded
if err := json.NewDecoder(reader).Decode(&result); err != nil {
return errors.Wrap(err, "failed to decode response")
}
if err := result.checkErrors(); err != nil {
return err
}
if !result.Success {
return ErrAPINoSuccess
}
// At this point we know the API call succeeded, so, parse out the inner
// result into the datatype provided as a parameter.
if err := json.Unmarshal(result.Result, &data); err != nil {
return errors.Wrap(err, "the Cloudflare API response was an unexpected type")
}
return nil
}
type response struct {
Success bool `json:"success,omitempty"`
Errors []apiErr `json:"errors,omitempty"`
Messages []string `json:"messages,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
}
func (r *response) checkErrors() error {
if len(r.Errors) == 0 {
return nil
}
if len(r.Errors) == 1 {
return r.Errors[0]
}
var messages string
for _, e := range r.Errors {
messages += fmt.Sprintf("%s; ", e)
}
return fmt.Errorf("API errors: %s", messages)
}
type apiErr struct {
Code json.Number `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
func (e apiErr) Error() string {
return fmt.Sprintf("code: %v, reason: %s", e.Code, e.Message)
}
func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error {
if resp.Header.Get("Content-Type") == "application/json" {
var errorsResp response
if json.NewDecoder(resp.Body).Decode(&errorsResp) == nil {
if err := errorsResp.checkErrors(); err != nil {
return errors.Errorf("Failed to %s: %s", op, err)
}
}
}
switch resp.StatusCode {
case http.StatusOK:
return nil
case http.StatusBadRequest:
return ErrBadRequest
case http.StatusUnauthorized, http.StatusForbidden:
return ErrUnauthorized
case http.StatusNotFound:
return ErrNotFound
}
return errors.Errorf("API call to %s failed with status %d: %s", op,
resp.StatusCode, http.StatusText(resp.StatusCode))
}

39
cfapi/client.go Normal file
View File

@ -0,0 +1,39 @@
package cfapi
import (
"github.com/google/uuid"
)
type TunnelClient interface {
CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
DeleteTunnel(tunnelID uuid.UUID) error
ListTunnels(filter *TunnelFilter) ([]*Tunnel, error)
ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error)
CleanupConnections(tunnelID uuid.UUID, params *CleanupParams) error
}
type HostnameClient interface {
RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error)
}
type IPRouteClient interface {
ListRoutes(filter *IpRouteFilter) ([]*DetailedRoute, error)
AddRoute(newRoute NewRoute) (Route, error)
DeleteRoute(params DeleteRouteParams) error
GetByIP(params GetRouteByIpParams) (DetailedRoute, error)
}
type VnetClient interface {
CreateVirtualNetwork(newVnet NewVirtualNetwork) (VirtualNetwork, error)
ListVirtualNetworks(filter *VnetFilter) ([]*VirtualNetwork, error)
DeleteVirtualNetwork(id uuid.UUID) error
UpdateVirtualNetwork(id uuid.UUID, updates UpdateVirtualNetwork) error
}
type Client interface {
TunnelClient
HostnameClient
IPRouteClient
VnetClient
}

192
cfapi/hostname.go Normal file
View File

@ -0,0 +1,192 @@
package cfapi
import (
"encoding/json"
"fmt"
"io"
"net/http"
"path"
"github.com/google/uuid"
"github.com/pkg/errors"
)
type Change = string
const (
ChangeNew = "new"
ChangeUpdated = "updated"
ChangeUnchanged = "unchanged"
)
// HostnameRoute represents a record type that can route to a tunnel
type HostnameRoute interface {
json.Marshaler
RecordType() string
UnmarshalResult(body io.Reader) (HostnameRouteResult, error)
String() string
}
type HostnameRouteResult interface {
// SuccessSummary explains what will route to this tunnel when it's provisioned successfully
SuccessSummary() string
}
type DNSRoute struct {
userHostname string
overwriteExisting bool
}
type DNSRouteResult struct {
route *DNSRoute
CName Change `json:"cname"`
Name string `json:"name"`
}
func NewDNSRoute(userHostname string, overwriteExisting bool) HostnameRoute {
return &DNSRoute{
userHostname: userHostname,
overwriteExisting: overwriteExisting,
}
}
func (dr *DNSRoute) MarshalJSON() ([]byte, error) {
s := struct {
Type string `json:"type"`
UserHostname string `json:"user_hostname"`
OverwriteExisting bool `json:"overwrite_existing"`
}{
Type: dr.RecordType(),
UserHostname: dr.userHostname,
OverwriteExisting: dr.overwriteExisting,
}
return json.Marshal(&s)
}
func (dr *DNSRoute) UnmarshalResult(body io.Reader) (HostnameRouteResult, error) {
var result DNSRouteResult
err := parseResponse(body, &result)
result.route = dr
return &result, err
}
func (dr *DNSRoute) RecordType() string {
return "dns"
}
func (dr *DNSRoute) String() string {
return fmt.Sprintf("%s %s", dr.RecordType(), dr.userHostname)
}
func (res *DNSRouteResult) SuccessSummary() string {
var msgFmt string
switch res.CName {
case ChangeNew:
msgFmt = "Added CNAME %s which will route to this tunnel"
case ChangeUpdated: // this is not currently returned by tunnelsore
msgFmt = "%s updated to route to your tunnel"
case ChangeUnchanged:
msgFmt = "%s is already configured to route to your tunnel"
}
return fmt.Sprintf(msgFmt, res.hostname())
}
// hostname yields the resulting name for the DNS route; if that is not available from Cloudflare API, then the
// requested name is returned instead (should not be the common path, it is just a fall-back).
func (res *DNSRouteResult) hostname() string {
if res.Name != "" {
return res.Name
}
return res.route.userHostname
}
type LBRoute struct {
lbName string
lbPool string
}
type LBRouteResult struct {
route *LBRoute
LoadBalancer Change `json:"load_balancer"`
Pool Change `json:"pool"`
}
func NewLBRoute(lbName, lbPool string) HostnameRoute {
return &LBRoute{
lbName: lbName,
lbPool: lbPool,
}
}
func (lr *LBRoute) MarshalJSON() ([]byte, error) {
s := struct {
Type string `json:"type"`
LBName string `json:"lb_name"`
LBPool string `json:"lb_pool"`
}{
Type: lr.RecordType(),
LBName: lr.lbName,
LBPool: lr.lbPool,
}
return json.Marshal(&s)
}
func (lr *LBRoute) RecordType() string {
return "lb"
}
func (lb *LBRoute) String() string {
return fmt.Sprintf("%s %s %s", lb.RecordType(), lb.lbName, lb.lbPool)
}
func (lr *LBRoute) UnmarshalResult(body io.Reader) (HostnameRouteResult, error) {
var result LBRouteResult
err := parseResponse(body, &result)
result.route = lr
return &result, err
}
func (res *LBRouteResult) SuccessSummary() string {
var msg string
switch res.LoadBalancer + "," + res.Pool {
case "new,new":
msg = "Created load balancer %s and added a new pool %s with this tunnel as an origin"
case "new,updated":
msg = "Created load balancer %s with an existing pool %s which was updated to use this tunnel as an origin"
case "new,unchanged":
msg = "Created load balancer %s with an existing pool %s which already has this tunnel as an origin"
case "updated,new":
msg = "Added new pool %[2]s with this tunnel as an origin to load balancer %[1]s"
case "updated,updated":
msg = "Updated pool %[2]s to use this tunnel as an origin and added it to load balancer %[1]s"
case "updated,unchanged":
msg = "Added pool %[2]s, which already has this tunnel as an origin, to load balancer %[1]s"
case "unchanged,updated":
msg = "Added this tunnel as an origin in pool %[2]s which is already used by load balancer %[1]s"
case "unchanged,unchanged":
msg = "Load balancer %s already uses pool %s which has this tunnel as an origin"
case "unchanged,new":
// this state is not possible
fallthrough
default:
msg = "Something went wrong: failed to modify load balancer %s with pool %s; please check traffic manager configuration in the dashboard"
}
return fmt.Sprintf(msg, res.route.lbName, res.route.lbPool)
}
func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error) {
endpoint := r.baseEndpoints.zoneLevel
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID))
resp, err := r.sendRequest("PUT", endpoint, route)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return route.UnmarshalResult(resp.Body)
}
return nil, r.statusCodeToError("add route", resp)
}

View File

@ -1,17 +1,9 @@
package tunnelstore
package cfapi
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"reflect"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
@ -105,125 +97,3 @@ func TestLBRouteResultSuccessSummary(t *testing.T) {
assert.Equal(t, tt.expected, actual, "case %d", i+1)
}
}
func Test_parseListTunnels(t *testing.T) {
type args struct {
body string
}
tests := []struct {
name string
args args
want []*Tunnel
wantErr bool
}{
{
name: "empty list",
args: args{body: `{"success": true, "result": []}`},
want: []*Tunnel{},
},
{
name: "success is false",
args: args{body: `{"success": false, "result": []}`},
wantErr: true,
},
{
name: "errors are present",
args: args{body: `{"errors": [{"code": 1003, "message":"An A, AAAA or CNAME record already exists with that host"}], "result": []}`},
wantErr: true,
},
{
name: "invalid response",
args: args{body: `abc`},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body := ioutil.NopCloser(bytes.NewReader([]byte(tt.args.body)))
got, err := parseListTunnels(body)
if (err != nil) != tt.wantErr {
t.Errorf("parseListTunnels() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseListTunnels() = %v, want %v", got, tt.want)
}
})
}
}
func Test_unmarshalTunnel(t *testing.T) {
type args struct {
reader io.Reader
}
tests := []struct {
name string
args args
want *Tunnel
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := unmarshalTunnel(tt.args.reader)
if (err != nil) != tt.wantErr {
t.Errorf("unmarshalTunnel() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("unmarshalTunnel() = %v, want %v", got, tt.want)
}
})
}
}
func TestUnmarshalTunnelOk(t *testing.T) {
jsonBody := `{"success": true, "result": {"id": "00000000-0000-0000-0000-000000000000","name":"test","created_at":"0001-01-01T00:00:00Z","connections":[]}}`
expected := Tunnel{
ID: uuid.Nil,
Name: "test",
CreatedAt: time.Time{},
Connections: []Connection{},
}
actual, err := unmarshalTunnel(bytes.NewReader([]byte(jsonBody)))
assert.NoError(t, err)
assert.Equal(t, &expected, actual)
}
func TestUnmarshalTunnelErr(t *testing.T) {
tests := []string{
`abc`,
`{"success": true, "result": abc}`,
`{"success": false, "result": {"id": "00000000-0000-0000-0000-000000000000","name":"test","created_at":"0001-01-01T00:00:00Z","connections":[]}}}`,
`{"errors": [{"code": 1003, "message":"An A, AAAA or CNAME record already exists with that host"}], "result": {"id": "00000000-0000-0000-0000-000000000000","name":"test","created_at":"0001-01-01T00:00:00Z","connections":[]}}}`,
}
for i, test := range tests {
_, err := unmarshalTunnel(bytes.NewReader([]byte(test)))
assert.Error(t, err, fmt.Sprintf("Test #%v failed", i))
}
}
func TestUnmarshalConnections(t *testing.T) {
jsonBody := `{"success":true,"messages":[],"errors":[],"result":[{"id":"d4041254-91e3-4deb-bd94-b46e11680b1e","features":["ha-origin"],"version":"2021.2.5","arch":"darwin_amd64","conns":[{"colo_name":"LIS","id":"ac2286e5-c708-4588-a6a0-ba6b51940019","is_pending_reconnect":false,"origin_ip":"148.38.28.2","opened_at":"0001-01-01T00:00:00Z"}],"run_at":"0001-01-01T00:00:00Z"}]}`
expected := ActiveClient{
ID: uuid.MustParse("d4041254-91e3-4deb-bd94-b46e11680b1e"),
Features: []string{"ha-origin"},
Version: "2021.2.5",
Arch: "darwin_amd64",
RunAt: time.Time{},
Connections: []Connection{{
ID: uuid.MustParse("ac2286e5-c708-4588-a6a0-ba6b51940019"),
ColoName: "LIS",
IsPendingReconnect: false,
OriginIP: net.ParseIP("148.38.28.2"),
OpenedAt: time.Time{},
}},
}
actual, err := parseConnectionsDetails(bytes.NewReader([]byte(jsonBody)))
assert.NoError(t, err)
assert.Equal(t, []*ActiveClient{&expected}, actual)
}

240
cfapi/ip_route.go Normal file
View File

@ -0,0 +1,240 @@
package cfapi
import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"path"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
)
// Route is a mapping from customer's IP space to a tunnel.
// Each route allows the customer to route eyeballs in their corporate network
// to certain private IP ranges. Each Route represents an IP range in their
// network, and says that eyeballs can reach that route using the corresponding
// tunnel.
type Route struct {
Network CIDR `json:"network"`
TunnelID uuid.UUID `json:"tunnel_id"`
// Optional field. When unset, it means the Route belongs to the default virtual network.
VNetID *uuid.UUID `json:"virtual_network_id,omitempty"`
Comment string `json:"comment"`
CreatedAt time.Time `json:"created_at"`
DeletedAt time.Time `json:"deleted_at"`
}
// CIDR is just a newtype wrapper around net.IPNet. It adds JSON unmarshalling.
type CIDR net.IPNet
func (c CIDR) String() string {
n := net.IPNet(c)
return n.String()
}
func (c CIDR) MarshalJSON() ([]byte, error) {
str := c.String()
json, err := json.Marshal(str)
if err != nil {
return nil, errors.Wrap(err, "error serializing CIDR into JSON")
}
return json, nil
}
// UnmarshalJSON parses a JSON string into net.IPNet
func (c *CIDR) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return errors.Wrap(err, "error parsing cidr string")
}
_, network, err := net.ParseCIDR(s)
if err != nil {
return errors.Wrap(err, "error parsing invalid network from backend")
}
if network == nil {
return fmt.Errorf("backend returned invalid network %s", s)
}
*c = CIDR(*network)
return nil
}
// NewRoute has all the parameters necessary to add a new route to the table.
type NewRoute struct {
Network net.IPNet
TunnelID uuid.UUID
Comment string
// Optional field. If unset, backend will assume the default vnet for the account.
VNetID *uuid.UUID
}
// MarshalJSON handles fields with non-JSON types (e.g. net.IPNet).
func (r NewRoute) MarshalJSON() ([]byte, error) {
return json.Marshal(&struct {
TunnelID uuid.UUID `json:"tunnel_id"`
Comment string `json:"comment"`
VNetID *uuid.UUID `json:"virtual_network_id,omitempty"`
}{
TunnelID: r.TunnelID,
Comment: r.Comment,
VNetID: r.VNetID,
})
}
// DetailedRoute is just a Route with some extra fields, e.g. TunnelName.
type DetailedRoute struct {
Network CIDR `json:"network"`
TunnelID uuid.UUID `json:"tunnel_id"`
// Optional field. When unset, it means the DetailedRoute belongs to the default virtual network.
VNetID *uuid.UUID `json:"virtual_network_id,omitempty"`
Comment string `json:"comment"`
CreatedAt time.Time `json:"created_at"`
DeletedAt time.Time `json:"deleted_at"`
TunnelName string `json:"tunnel_name"`
}
// IsZero checks if DetailedRoute is the zero value.
func (r *DetailedRoute) IsZero() bool {
return r.TunnelID == uuid.Nil
}
// TableString outputs a table row summarizing the route, to be used
// when showing the user their routing table.
func (r DetailedRoute) TableString() string {
deletedColumn := "-"
if !r.DeletedAt.IsZero() {
deletedColumn = r.DeletedAt.Format(time.RFC3339)
}
vnetColumn := "default"
if r.VNetID != nil {
vnetColumn = r.VNetID.String()
}
return fmt.Sprintf(
"%s\t%s\t%s\t%s\t%s\t%s\t%s\t",
r.Network.String(),
vnetColumn,
r.Comment,
r.TunnelID,
r.TunnelName,
r.CreatedAt.Format(time.RFC3339),
deletedColumn,
)
}
type DeleteRouteParams struct {
Network net.IPNet
// Optional field. If unset, backend will assume the default vnet for the account.
VNetID *uuid.UUID
}
type GetRouteByIpParams struct {
Ip net.IP
// Optional field. If unset, backend will assume the default vnet for the account.
VNetID *uuid.UUID
}
// ListRoutes calls the Tunnelstore GET endpoint for all routes under an account.
func (r *RESTClient) ListRoutes(filter *IpRouteFilter) ([]*DetailedRoute, error) {
endpoint := r.baseEndpoints.accountRoutes
endpoint.RawQuery = filter.Encode()
resp, err := r.sendRequest("GET", endpoint, nil)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return parseListDetailedRoutes(resp.Body)
}
return nil, r.statusCodeToError("list routes", resp)
}
// AddRoute calls the Tunnelstore POST endpoint for a given route.
func (r *RESTClient) AddRoute(newRoute NewRoute) (Route, error) {
endpoint := r.baseEndpoints.accountRoutes
endpoint.Path = path.Join(endpoint.Path, "network", url.PathEscape(newRoute.Network.String()))
resp, err := r.sendRequest("POST", endpoint, newRoute)
if err != nil {
return Route{}, errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return parseRoute(resp.Body)
}
return Route{}, r.statusCodeToError("add route", resp)
}
// DeleteRoute calls the Tunnelstore DELETE endpoint for a given route.
func (r *RESTClient) DeleteRoute(params DeleteRouteParams) error {
endpoint := r.baseEndpoints.accountRoutes
endpoint.Path = path.Join(endpoint.Path, "network", url.PathEscape(params.Network.String()))
setVnetParam(&endpoint, params.VNetID)
resp, err := r.sendRequest("DELETE", endpoint, nil)
if err != nil {
return errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
_, err := parseRoute(resp.Body)
return err
}
return r.statusCodeToError("delete route", resp)
}
// GetByIP checks which route will proxy a given IP.
func (r *RESTClient) GetByIP(params GetRouteByIpParams) (DetailedRoute, error) {
endpoint := r.baseEndpoints.accountRoutes
endpoint.Path = path.Join(endpoint.Path, "ip", url.PathEscape(params.Ip.String()))
setVnetParam(&endpoint, params.VNetID)
resp, err := r.sendRequest("GET", endpoint, nil)
if err != nil {
return DetailedRoute{}, errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return parseDetailedRoute(resp.Body)
}
return DetailedRoute{}, r.statusCodeToError("get route by IP", resp)
}
func parseListDetailedRoutes(body io.ReadCloser) ([]*DetailedRoute, error) {
var routes []*DetailedRoute
err := parseResponse(body, &routes)
return routes, err
}
func parseRoute(body io.ReadCloser) (Route, error) {
var route Route
err := parseResponse(body, &route)
return route, err
}
func parseDetailedRoute(body io.ReadCloser) (DetailedRoute, error) {
var route DetailedRoute
err := parseResponse(body, &route)
return route, err
}
// setVnetParam overwrites the URL's query parameters with a query param to scope the HostnameRoute action to a certain
// virtual network (if one is provided).
func setVnetParam(endpoint *url.URL, vnetID *uuid.UUID) {
queryParams := url.Values{}
if vnetID != nil {
queryParams.Set("virtual_network_id", vnetID.String())
}
endpoint.RawQuery = queryParams.Encode()
}

165
cfapi/ip_route_filter.go Normal file
View File

@ -0,0 +1,165 @@
package cfapi
import (
"fmt"
"net"
"net/url"
"strconv"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/urfave/cli/v2"
)
var (
filterIpRouteDeleted = cli.BoolFlag{
Name: "filter-is-deleted",
Usage: "If false (default), only show non-deleted routes. If true, only show deleted routes.",
}
filterIpRouteTunnelID = cli.StringFlag{
Name: "filter-tunnel-id",
Usage: "Show only routes with the given tunnel ID.",
}
filterSubsetIpRoute = cli.StringFlag{
Name: "filter-network-is-subset-of",
Aliases: []string{"nsub"},
Usage: "Show only routes whose network is a subset of the given network.",
}
filterSupersetIpRoute = cli.StringFlag{
Name: "filter-network-is-superset-of",
Aliases: []string{"nsup"},
Usage: "Show only routes whose network is a superset of the given network.",
}
filterIpRouteComment = cli.StringFlag{
Name: "filter-comment-is",
Usage: "Show only routes with this comment.",
}
filterIpRouteByVnet = cli.StringFlag{
Name: "filter-virtual-network-id",
Usage: "Show only routes that are attached to the given virtual network ID.",
}
// Flags contains all filter flags.
IpRouteFilterFlags = []cli.Flag{
&filterIpRouteDeleted,
&filterIpRouteTunnelID,
&filterSubsetIpRoute,
&filterSupersetIpRoute,
&filterIpRouteComment,
&filterIpRouteByVnet,
}
)
// IpRouteFilter which routes get queried.
type IpRouteFilter struct {
queryParams url.Values
}
// NewIpRouteFilterFromCLI parses CLI flags to discover which filters should get applied.
func NewIpRouteFilterFromCLI(c *cli.Context) (*IpRouteFilter, error) {
f := &IpRouteFilter{
queryParams: url.Values{},
}
// Set deletion filter
if flag := filterIpRouteDeleted.Name; c.IsSet(flag) && c.Bool(flag) {
f.deleted()
} else {
f.notDeleted()
}
if subset, err := cidrFromFlag(c, filterSubsetIpRoute); err != nil {
return nil, err
} else if subset != nil {
f.networkIsSupersetOf(*subset)
}
if superset, err := cidrFromFlag(c, filterSupersetIpRoute); err != nil {
return nil, err
} else if superset != nil {
f.networkIsSupersetOf(*superset)
}
if comment := c.String(filterIpRouteComment.Name); comment != "" {
f.commentIs(comment)
}
if tunnelID := c.String(filterIpRouteTunnelID.Name); tunnelID != "" {
u, err := uuid.Parse(tunnelID)
if err != nil {
return nil, errors.Wrapf(err, "Couldn't parse UUID from %s", filterIpRouteTunnelID.Name)
}
f.tunnelID(u)
}
if vnetId := c.String(filterIpRouteByVnet.Name); vnetId != "" {
u, err := uuid.Parse(vnetId)
if err != nil {
return nil, errors.Wrapf(err, "Couldn't parse UUID from %s", filterIpRouteByVnet.Name)
}
f.vnetID(u)
}
if maxFetch := c.Int("max-fetch-size"); maxFetch > 0 {
f.MaxFetchSize(uint(maxFetch))
}
return f, nil
}
// Parses a CIDR from the flag. If the flag was unset, returns (nil, nil).
func cidrFromFlag(c *cli.Context, flag cli.StringFlag) (*net.IPNet, error) {
if !c.IsSet(flag.Name) {
return nil, nil
}
_, subset, err := net.ParseCIDR(c.String(flag.Name))
if err != nil {
return nil, err
} else if subset == nil {
return nil, fmt.Errorf("Invalid CIDR supplied for %s", flag.Name)
}
return subset, nil
}
func (f *IpRouteFilter) commentIs(comment string) {
f.queryParams.Set("comment", comment)
}
func (f *IpRouteFilter) notDeleted() {
f.queryParams.Set("is_deleted", "false")
}
func (f *IpRouteFilter) deleted() {
f.queryParams.Set("is_deleted", "true")
}
func (f *IpRouteFilter) networkIsSubsetOf(superset net.IPNet) {
f.queryParams.Set("network_subset", superset.String())
}
func (f *IpRouteFilter) networkIsSupersetOf(subset net.IPNet) {
f.queryParams.Set("network_superset", subset.String())
}
func (f *IpRouteFilter) existedAt(existedAt time.Time) {
f.queryParams.Set("existed_at", existedAt.Format(time.RFC3339))
}
func (f *IpRouteFilter) tunnelID(id uuid.UUID) {
f.queryParams.Set("tunnel_id", id.String())
}
func (f *IpRouteFilter) vnetID(id uuid.UUID) {
f.queryParams.Set("virtual_network_id", id.String())
}
func (f *IpRouteFilter) MaxFetchSize(max uint) {
f.queryParams.Set("per_page", strconv.Itoa(int(max)))
}
func (f IpRouteFilter) Encode() string {
return f.queryParams.Encode()
}

175
cfapi/ip_route_test.go Normal file
View File

@ -0,0 +1,175 @@
package cfapi
import (
"encoding/json"
"fmt"
"net"
"strings"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func TestUnmarshalRoute(t *testing.T) {
testCases := []struct {
Json string
HasVnet bool
}{
{
`{
"network":"10.1.2.40/29",
"tunnel_id":"fba6ffea-807f-4e7a-a740-4184ee1b82c8",
"comment":"test",
"created_at":"2020-12-22T02:00:15.587008Z",
"deleted_at":null
}`,
false,
},
{
`{
"network":"10.1.2.40/29",
"tunnel_id":"fba6ffea-807f-4e7a-a740-4184ee1b82c8",
"comment":"test",
"created_at":"2020-12-22T02:00:15.587008Z",
"deleted_at":null,
"virtual_network_id":"38c95083-8191-4110-8339-3f438d44fdb9"
}`,
true,
},
}
for _, testCase := range testCases {
data := testCase.Json
var r Route
err := json.Unmarshal([]byte(data), &r)
// Check everything worked
require.NoError(t, err)
require.Equal(t, uuid.MustParse("fba6ffea-807f-4e7a-a740-4184ee1b82c8"), r.TunnelID)
require.Equal(t, "test", r.Comment)
_, cidr, err := net.ParseCIDR("10.1.2.40/29")
require.NoError(t, err)
require.Equal(t, CIDR(*cidr), r.Network)
require.Equal(t, "test", r.Comment)
if testCase.HasVnet {
require.Equal(t, uuid.MustParse("38c95083-8191-4110-8339-3f438d44fdb9"), *r.VNetID)
} else {
require.Nil(t, r.VNetID)
}
}
}
func TestDetailedRouteJsonRoundtrip(t *testing.T) {
testCases := []struct {
Json string
HasVnet bool
}{
{
`{
"network":"10.1.2.40/29",
"tunnel_id":"fba6ffea-807f-4e7a-a740-4184ee1b82c8",
"comment":"test",
"created_at":"2020-12-22T02:00:15.587008Z",
"deleted_at":"2021-01-14T05:01:42.183002Z",
"tunnel_name":"Mr. Tun"
}`,
false,
},
{
`{
"network":"10.1.2.40/29",
"tunnel_id":"fba6ffea-807f-4e7a-a740-4184ee1b82c8",
"virtual_network_id":"38c95083-8191-4110-8339-3f438d44fdb9",
"comment":"test",
"created_at":"2020-12-22T02:00:15.587008Z",
"deleted_at":"2021-01-14T05:01:42.183002Z",
"tunnel_name":"Mr. Tun"
}`,
true,
},
}
for _, testCase := range testCases {
data := testCase.Json
var r DetailedRoute
err := json.Unmarshal([]byte(data), &r)
// Check everything worked
require.NoError(t, err)
require.Equal(t, uuid.MustParse("fba6ffea-807f-4e7a-a740-4184ee1b82c8"), r.TunnelID)
require.Equal(t, "test", r.Comment)
_, cidr, err := net.ParseCIDR("10.1.2.40/29")
require.NoError(t, err)
require.Equal(t, CIDR(*cidr), r.Network)
require.Equal(t, "test", r.Comment)
require.Equal(t, "Mr. Tun", r.TunnelName)
if testCase.HasVnet {
require.Equal(t, uuid.MustParse("38c95083-8191-4110-8339-3f438d44fdb9"), *r.VNetID)
} else {
require.Nil(t, r.VNetID)
}
bytes, err := json.Marshal(r)
require.NoError(t, err)
obtainedJson := string(bytes)
data = strings.Replace(data, "\t", "", -1)
data = strings.Replace(data, "\n", "", -1)
require.Equal(t, data, obtainedJson)
}
}
func TestMarshalNewRoute(t *testing.T) {
_, network, err := net.ParseCIDR("1.2.3.4/32")
require.NoError(t, err)
require.NotNil(t, network)
vnetId := uuid.New()
newRoutes := []NewRoute{
{
Network: *network,
TunnelID: uuid.New(),
Comment: "hi",
},
{
Network: *network,
TunnelID: uuid.New(),
Comment: "hi",
VNetID: &vnetId,
},
}
for _, newRoute := range newRoutes {
// Test where receiver is struct
serialized, err := json.Marshal(newRoute)
require.NoError(t, err)
require.True(t, strings.Contains(string(serialized), "tunnel_id"))
// Test where receiver is pointer to struct
serialized, err = json.Marshal(&newRoute)
require.NoError(t, err)
require.True(t, strings.Contains(string(serialized), "tunnel_id"))
if newRoute.VNetID == nil {
require.False(t, strings.Contains(string(serialized), "virtual_network_id"))
} else {
require.True(t, strings.Contains(string(serialized), "virtual_network_id"))
}
}
}
func TestRouteTableString(t *testing.T) {
_, network, err := net.ParseCIDR("1.2.3.4/32")
require.NoError(t, err)
require.NotNil(t, network)
r := DetailedRoute{
Network: CIDR(*network),
}
row := r.TableString()
fmt.Println(row)
require.True(t, strings.HasPrefix(row, "1.2.3.4/32"))
}

183
cfapi/tunnel.go Normal file
View File

@ -0,0 +1,183 @@
package cfapi
import (
"fmt"
"io"
"net"
"net/http"
"net/url"
"path"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
)
var ErrTunnelNameConflict = errors.New("tunnel with name already exists")
type Tunnel struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
CreatedAt time.Time `json:"created_at"`
DeletedAt time.Time `json:"deleted_at"`
Connections []Connection `json:"connections"`
}
type Connection struct {
ColoName string `json:"colo_name"`
ID uuid.UUID `json:"id"`
IsPendingReconnect bool `json:"is_pending_reconnect"`
OriginIP net.IP `json:"origin_ip"`
OpenedAt time.Time `json:"opened_at"`
}
type ActiveClient struct {
ID uuid.UUID `json:"id"`
Features []string `json:"features"`
Version string `json:"version"`
Arch string `json:"arch"`
RunAt time.Time `json:"run_at"`
Connections []Connection `json:"conns"`
}
type newTunnel struct {
Name string `json:"name"`
TunnelSecret []byte `json:"tunnel_secret"`
}
type CleanupParams struct {
queryParams url.Values
}
func NewCleanupParams() *CleanupParams {
return &CleanupParams{
queryParams: url.Values{},
}
}
func (cp *CleanupParams) ForClient(clientID uuid.UUID) {
cp.queryParams.Set("client_id", clientID.String())
}
func (cp CleanupParams) encode() string {
return cp.queryParams.Encode()
}
func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) {
if name == "" {
return nil, errors.New("tunnel name required")
}
if _, err := uuid.Parse(name); err == nil {
return nil, errors.New("you cannot use UUIDs as tunnel names")
}
body := &newTunnel{
Name: name,
TunnelSecret: tunnelSecret,
}
resp, err := r.sendRequest("POST", r.baseEndpoints.accountLevel, body)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
switch resp.StatusCode {
case http.StatusOK:
return unmarshalTunnel(resp.Body)
case http.StatusConflict:
return nil, ErrTunnelNameConflict
}
return nil, r.statusCodeToError("create tunnel", resp)
}
func (r *RESTClient) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) {
endpoint := r.baseEndpoints.accountLevel
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID))
resp, err := r.sendRequest("GET", endpoint, nil)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return unmarshalTunnel(resp.Body)
}
return nil, r.statusCodeToError("get tunnel", resp)
}
func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error {
endpoint := r.baseEndpoints.accountLevel
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID))
resp, err := r.sendRequest("DELETE", endpoint, nil)
if err != nil {
return errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
return r.statusCodeToError("delete tunnel", resp)
}
func (r *RESTClient) ListTunnels(filter *TunnelFilter) ([]*Tunnel, error) {
endpoint := r.baseEndpoints.accountLevel
endpoint.RawQuery = filter.encode()
resp, err := r.sendRequest("GET", endpoint, nil)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return parseListTunnels(resp.Body)
}
return nil, r.statusCodeToError("list tunnels", resp)
}
func parseListTunnels(body io.ReadCloser) ([]*Tunnel, error) {
var tunnels []*Tunnel
err := parseResponse(body, &tunnels)
return tunnels, err
}
func (r *RESTClient) ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error) {
endpoint := r.baseEndpoints.accountLevel
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/connections", tunnelID))
resp, err := r.sendRequest("GET", endpoint, nil)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return parseConnectionsDetails(resp.Body)
}
return nil, r.statusCodeToError("list connection details", resp)
}
func parseConnectionsDetails(reader io.Reader) ([]*ActiveClient, error) {
var clients []*ActiveClient
err := parseResponse(reader, &clients)
return clients, err
}
func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID, params *CleanupParams) error {
endpoint := r.baseEndpoints.accountLevel
endpoint.RawQuery = params.encode()
endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/connections", tunnelID))
resp, err := r.sendRequest("DELETE", endpoint, nil)
if err != nil {
return errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
return r.statusCodeToError("cleanup connections", resp)
}
func unmarshalTunnel(reader io.Reader) (*Tunnel, error) {
var tunnel Tunnel
err := parseResponse(reader, &tunnel)
return &tunnel, err
}

55
cfapi/tunnel_filter.go Normal file
View File

@ -0,0 +1,55 @@
package cfapi
import (
"net/url"
"strconv"
"time"
"github.com/google/uuid"
)
const (
TimeLayout = time.RFC3339
)
type TunnelFilter struct {
queryParams url.Values
}
func NewTunnelFilter() *TunnelFilter {
return &TunnelFilter{
queryParams: url.Values{},
}
}
func (f *TunnelFilter) ByName(name string) {
f.queryParams.Set("name", name)
}
func (f *TunnelFilter) ByNamePrefix(namePrefix string) {
f.queryParams.Set("name_prefix", namePrefix)
}
func (f *TunnelFilter) ExcludeNameWithPrefix(excludePrefix string) {
f.queryParams.Set("exclude_prefix", excludePrefix)
}
func (f *TunnelFilter) NoDeleted() {
f.queryParams.Set("is_deleted", "false")
}
func (f *TunnelFilter) ByExistedAt(existedAt time.Time) {
f.queryParams.Set("existed_at", existedAt.Format(TimeLayout))
}
func (f *TunnelFilter) ByTunnelID(tunnelID uuid.UUID) {
f.queryParams.Set("uuid", tunnelID.String())
}
func (f *TunnelFilter) MaxFetchSize(max uint) {
f.queryParams.Set("per_page", strconv.Itoa(int(max)))
}
func (f TunnelFilter) encode() string {
return f.queryParams.Encode()
}

149
cfapi/tunnel_test.go Normal file
View File

@ -0,0 +1,149 @@
package cfapi
import (
"bytes"
"fmt"
"io/ioutil"
"net"
"reflect"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
var loc, _ = time.LoadLocation("UTC")
func Test_parseListTunnels(t *testing.T) {
type args struct {
body string
}
tests := []struct {
name string
args args
want []*Tunnel
wantErr bool
}{
{
name: "empty list",
args: args{body: `{"success": true, "result": []}`},
want: []*Tunnel{},
},
{
name: "success is false",
args: args{body: `{"success": false, "result": []}`},
wantErr: true,
},
{
name: "errors are present",
args: args{body: `{"errors": [{"code": 1003, "message":"An A, AAAA or CNAME record already exists with that host"}], "result": []}`},
wantErr: true,
},
{
name: "invalid response",
args: args{body: `abc`},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body := ioutil.NopCloser(bytes.NewReader([]byte(tt.args.body)))
got, err := parseListTunnels(body)
if (err != nil) != tt.wantErr {
t.Errorf("parseListTunnels() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseListTunnels() = %v, want %v", got, tt.want)
}
})
}
}
func Test_unmarshalTunnel(t *testing.T) {
type args struct {
body string
}
tests := []struct {
name string
args args
want *Tunnel
wantErr bool
}{
{
name: "empty list",
args: args{body: `{"success": true, "result": {"id":"b34cc7ce-925b-46ee-bc23-4cb5c18d8292","created_at":"2021-07-29T13:46:14.090955Z","deleted_at":"2021-07-29T14:07:27.559047Z","name":"qt-bIWWN7D662ogh61pCPfu5s2XgqFY1OyV","account_id":6946212,"account_tag":"5ab4e9dfbd435d24068829fda0077963","conns_active_at":null,"conns_inactive_at":"2021-07-29T13:47:22.548482Z","tun_type":"cfd_tunnel","metadata":{"qtid":"a6fJROgkXutNruBGaJjD"}}}`},
want: &Tunnel{
ID: uuid.MustParse("b34cc7ce-925b-46ee-bc23-4cb5c18d8292"),
Name: "qt-bIWWN7D662ogh61pCPfu5s2XgqFY1OyV",
CreatedAt: time.Date(2021, 07, 29, 13, 46, 14, 90955000, loc),
DeletedAt: time.Date(2021, 07, 29, 14, 7, 27, 559047000, loc),
Connections: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := unmarshalTunnel(strings.NewReader(tt.args.body))
if (err != nil) != tt.wantErr {
t.Errorf("unmarshalTunnel() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("unmarshalTunnel() = %v, want %v", got, tt.want)
}
})
}
}
func TestUnmarshalTunnelOk(t *testing.T) {
jsonBody := `{"success": true, "result": {"id": "00000000-0000-0000-0000-000000000000","name":"test","created_at":"0001-01-01T00:00:00Z","connections":[]}}`
expected := Tunnel{
ID: uuid.Nil,
Name: "test",
CreatedAt: time.Time{},
Connections: []Connection{},
}
actual, err := unmarshalTunnel(bytes.NewReader([]byte(jsonBody)))
assert.NoError(t, err)
assert.Equal(t, &expected, actual)
}
func TestUnmarshalTunnelErr(t *testing.T) {
tests := []string{
`abc`,
`{"success": true, "result": abc}`,
`{"success": false, "result": {"id": "00000000-0000-0000-0000-000000000000","name":"test","created_at":"0001-01-01T00:00:00Z","connections":[]}}}`,
`{"errors": [{"code": 1003, "message":"An A, AAAA or CNAME record already exists with that host"}], "result": {"id": "00000000-0000-0000-0000-000000000000","name":"test","created_at":"0001-01-01T00:00:00Z","connections":[]}}}`,
}
for i, test := range tests {
_, err := unmarshalTunnel(bytes.NewReader([]byte(test)))
assert.Error(t, err, fmt.Sprintf("Test #%v failed", i))
}
}
func TestUnmarshalConnections(t *testing.T) {
jsonBody := `{"success":true,"messages":[],"errors":[],"result":[{"id":"d4041254-91e3-4deb-bd94-b46e11680b1e","features":["ha-origin"],"version":"2021.2.5","arch":"darwin_amd64","conns":[{"colo_name":"LIS","id":"ac2286e5-c708-4588-a6a0-ba6b51940019","is_pending_reconnect":false,"origin_ip":"148.38.28.2","opened_at":"0001-01-01T00:00:00Z"}],"run_at":"0001-01-01T00:00:00Z"}]}`
expected := ActiveClient{
ID: uuid.MustParse("d4041254-91e3-4deb-bd94-b46e11680b1e"),
Features: []string{"ha-origin"},
Version: "2021.2.5",
Arch: "darwin_amd64",
RunAt: time.Time{},
Connections: []Connection{{
ID: uuid.MustParse("ac2286e5-c708-4588-a6a0-ba6b51940019"),
ColoName: "LIS",
IsPendingReconnect: false,
OriginIP: net.ParseIP("148.38.28.2"),
OpenedAt: time.Time{},
}},
}
actual, err := parseConnectionsDetails(bytes.NewReader([]byte(jsonBody)))
assert.NoError(t, err)
assert.Equal(t, []*ActiveClient{&expected}, actual)
}

127
cfapi/virtual_network.go Normal file
View File

@ -0,0 +1,127 @@
package cfapi
import (
"fmt"
"io"
"net/http"
"net/url"
"path"
"strconv"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
)
type NewVirtualNetwork struct {
Name string `json:"name"`
Comment string `json:"comment"`
IsDefault bool `json:"is_default"`
}
type VirtualNetwork struct {
ID uuid.UUID `json:"id"`
Comment string `json:"comment"`
Name string `json:"name"`
IsDefault bool `json:"is_default_network"`
CreatedAt time.Time `json:"created_at"`
DeletedAt time.Time `json:"deleted_at"`
}
type UpdateVirtualNetwork struct {
Name *string `json:"name,omitempty"`
Comment *string `json:"comment,omitempty"`
IsDefault *bool `json:"is_default_network,omitempty"`
}
func (virtualNetwork VirtualNetwork) TableString() string {
deletedColumn := "-"
if !virtualNetwork.DeletedAt.IsZero() {
deletedColumn = virtualNetwork.DeletedAt.Format(time.RFC3339)
}
return fmt.Sprintf(
"%s\t%s\t%s\t%s\t%s\t%s\t",
virtualNetwork.ID,
virtualNetwork.Name,
strconv.FormatBool(virtualNetwork.IsDefault),
virtualNetwork.Comment,
virtualNetwork.CreatedAt.Format(time.RFC3339),
deletedColumn,
)
}
func (r *RESTClient) CreateVirtualNetwork(newVnet NewVirtualNetwork) (VirtualNetwork, error) {
resp, err := r.sendRequest("POST", r.baseEndpoints.accountVnets, newVnet)
if err != nil {
return VirtualNetwork{}, errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return parseVnet(resp.Body)
}
return VirtualNetwork{}, r.statusCodeToError("add virtual network", resp)
}
func (r *RESTClient) ListVirtualNetworks(filter *VnetFilter) ([]*VirtualNetwork, error) {
endpoint := r.baseEndpoints.accountVnets
endpoint.RawQuery = filter.Encode()
resp, err := r.sendRequest("GET", endpoint, nil)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return parseListVnets(resp.Body)
}
return nil, r.statusCodeToError("list virtual networks", resp)
}
func (r *RESTClient) DeleteVirtualNetwork(id uuid.UUID) error {
endpoint := r.baseEndpoints.accountVnets
endpoint.Path = path.Join(endpoint.Path, url.PathEscape(id.String()))
resp, err := r.sendRequest("DELETE", endpoint, nil)
if err != nil {
return errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
_, err := parseVnet(resp.Body)
return err
}
return r.statusCodeToError("delete virtual network", resp)
}
func (r *RESTClient) UpdateVirtualNetwork(id uuid.UUID, updates UpdateVirtualNetwork) error {
endpoint := r.baseEndpoints.accountVnets
endpoint.Path = path.Join(endpoint.Path, url.PathEscape(id.String()))
resp, err := r.sendRequest("PATCH", endpoint, updates)
if err != nil {
return errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
_, err := parseVnet(resp.Body)
return err
}
return r.statusCodeToError("update virtual network", resp)
}
func parseListVnets(body io.ReadCloser) ([]*VirtualNetwork, error) {
var vnets []*VirtualNetwork
err := parseResponse(body, &vnets)
return vnets, err
}
func parseVnet(body io.ReadCloser) (VirtualNetwork, error) {
var vnet VirtualNetwork
err := parseResponse(body, &vnet)
return vnet, err
}

View File

@ -0,0 +1,99 @@
package cfapi
import (
"net/url"
"strconv"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/urfave/cli/v2"
)
var (
filterVnetId = cli.StringFlag{
Name: "id",
Usage: "List virtual networks with the given `ID`",
}
filterVnetByName = cli.StringFlag{
Name: "name",
Usage: "List virtual networks with the given `NAME`",
}
filterDefaultVnet = cli.BoolFlag{
Name: "is-default",
Usage: "If true, lists the virtual network that is the default one. If false, lists all non-default virtual networks for the account. If absent, all are included in the results regardless of their default status.",
}
filterDeletedVnet = cli.BoolFlag{
Name: "show-deleted",
Usage: "If false (default), only show non-deleted virtual networks. If true, only show deleted virtual networks.",
}
VnetFilterFlags = []cli.Flag{
&filterVnetId,
&filterVnetByName,
&filterDefaultVnet,
&filterDeletedVnet,
}
)
// VnetFilter which virtual networks get queried.
type VnetFilter struct {
queryParams url.Values
}
func NewVnetFilter() *VnetFilter {
return &VnetFilter{
queryParams: url.Values{},
}
}
func (f *VnetFilter) ById(vnetId uuid.UUID) {
f.queryParams.Set("id", vnetId.String())
}
func (f *VnetFilter) ByName(name string) {
f.queryParams.Set("name", name)
}
func (f *VnetFilter) ByDefaultStatus(isDefault bool) {
f.queryParams.Set("is_default", strconv.FormatBool(isDefault))
}
func (f *VnetFilter) WithDeleted(isDeleted bool) {
f.queryParams.Set("is_deleted", strconv.FormatBool(isDeleted))
}
func (f *VnetFilter) MaxFetchSize(max uint) {
f.queryParams.Set("per_page", strconv.Itoa(int(max)))
}
func (f VnetFilter) Encode() string {
return f.queryParams.Encode()
}
// NewFromCLI parses CLI flags to discover which filters should get applied to list virtual networks.
func NewFromCLI(c *cli.Context) (*VnetFilter, error) {
f := NewVnetFilter()
if id := c.String("id"); id != "" {
vnetId, err := uuid.Parse(id)
if err != nil {
return nil, errors.Wrapf(err, "%s is not a valid virtual network ID", id)
}
f.ById(vnetId)
}
if name := c.String("name"); name != "" {
f.ByName(name)
}
if c.IsSet("is-default") {
f.ByDefaultStatus(c.Bool("is-default"))
}
f.WithDeleted(c.Bool("show-deleted"))
if maxFetch := c.Int("max-fetch-size"); maxFetch > 0 {
f.MaxFetchSize(uint(maxFetch))
}
return f, nil
}

View File

@ -0,0 +1,79 @@
package cfapi
import (
"encoding/json"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func TestVirtualNetworkJsonRoundtrip(t *testing.T) {
data := `{
"id":"74fce949-351b-4752-b261-81a56cfd3130",
"comment":"New York DC1",
"name":"us-east-1",
"is_default_network":true,
"created_at":"2021-11-26T14:40:02.600673Z",
"deleted_at":"2021-12-01T10:23:13.102645Z"
}`
var v VirtualNetwork
err := json.Unmarshal([]byte(data), &v)
require.NoError(t, err)
require.Equal(t, uuid.MustParse("74fce949-351b-4752-b261-81a56cfd3130"), v.ID)
require.Equal(t, "us-east-1", v.Name)
require.Equal(t, "New York DC1", v.Comment)
require.Equal(t, true, v.IsDefault)
bytes, err := json.Marshal(v)
require.NoError(t, err)
obtainedJson := string(bytes)
data = strings.Replace(data, "\t", "", -1)
data = strings.Replace(data, "\n", "", -1)
require.Equal(t, data, obtainedJson)
}
func TestMarshalNewVnet(t *testing.T) {
newVnet := NewVirtualNetwork{
Name: "eu-west-1",
Comment: "London office",
IsDefault: true,
}
serialized, err := json.Marshal(newVnet)
require.NoError(t, err)
require.True(t, strings.Contains(string(serialized), newVnet.Name))
}
func TestMarshalUpdateVnet(t *testing.T) {
newName := "bulgaria-1"
updates := UpdateVirtualNetwork{
Name: &newName,
}
// Test where receiver is struct
serialized, err := json.Marshal(updates)
require.NoError(t, err)
require.True(t, strings.Contains(string(serialized), newName))
}
func TestVnetTableString(t *testing.T) {
virtualNet := VirtualNetwork{
ID: uuid.New(),
Name: "us-east-1",
Comment: "New York DC1",
IsDefault: true,
CreatedAt: time.Now(),
DeletedAt: time.Time{},
}
row := virtualNet.TableString()
require.True(t, strings.HasPrefix(row, virtualNet.ID.String()))
require.True(t, strings.Contains(row, virtualNet.Name))
require.True(t, strings.Contains(row, virtualNet.Comment))
require.True(t, strings.Contains(row, "true"))
require.True(t, strings.HasSuffix(row, "-\t"))
}

View File

@ -1,20 +1,10 @@
pinned_go: &pinned_go go=1.17-1
pinned_go_fips: &pinned_go_fips go-boring=1.16.6-6
pinned_go: &pinned_go go=1.17.5-1
pinned_go_fips: &pinned_go_fips go-boring=1.17.5-1
build_dir: &build_dir /cfsetup_build
default-flavor: buster
stretch: &stretch
build:
build_dir: *build_dir
builddeps:
- *pinned_go_fips
- build-essential
post-cache:
- export GOOS=linux
- export GOARCH=amd64
- export FIPS=true
- make cloudflared
build-non-fips: # helpful to catch problems with non-fips (only used for releasing non-linux artifacts) before releases
build_dir: *build_dir
builddeps:
- *pinned_go
@ -23,27 +13,45 @@ stretch: &stretch
- export GOOS=linux
- export GOARCH=amd64
- make cloudflared
build-all-packages: #except osxpkg
build-fips:
build_dir: *build_dir
builddeps:
- *pinned_go_fips
- build-essential
- fakeroot
- rubygem-fpm
- rpm
- wget
# libmsi and libgcab are libraries the wixl binary depends on.
- libmsi-dev
- libgcab-dev
pre-cache:
# TODO: https://jira.cfops.it/browse/TUN-4792 Replace this wixl with the official one once msitools supports
# environment.
- wget https://github.com/sudarshan-reddy/msitools/releases/download/v0.101b/wixl -P /usr/local/bin
- chmod a+x /usr/local/bin/wixl
post-cache:
- export GOOS=linux
- export GOARCH=amd64
- export FIPS=true
- ./build-packages.sh
- make cloudflared
# except FIPS (handled in github-fips-release-pkgs) and macos (handled in github-release-macos-amd64)
github-release-pkgs:
build_dir: *build_dir
builddeps:
- *pinned_go
- build-essential
- fakeroot
- rubygem-fpm
- rpm
- wget
# libmsi and libgcab are libraries the wixl binary depends on.
- libmsi-dev
- libgcab-dev
- python3-dev
- libffi-dev
- python3-setuptools
- python3-pip
pre-cache: &github_release_pkgs_pre_cache
- wget https://github.com/sudarshan-reddy/msitools/releases/download/v0.101b/wixl -P /usr/local/bin
- chmod a+x /usr/local/bin/wixl
- pip3 install pynacl==1.4.0
- pip3 install pygithub==1.55
post-cache:
# build all packages (except macos and FIPS) and move them to /cfsetup/built_artifacts
- ./build-packages.sh
# release the packages built and moved to /cfsetup/built_artifacts
- make github-release-built-pkgs
# handle FIPS separately so that we built with gofips compiler
github-fips-release-pkgs:
build_dir: *build_dir
builddeps:
- *pinned_go_fips
@ -55,20 +63,29 @@ stretch: &stretch
# libmsi and libgcab are libraries the wixl binary depends on.
- libmsi-dev
- libgcab-dev
- python3-dev
- libffi-dev
- python3-setuptools
- python3-pip
pre-cache:
- wget https://github.com/sudarshan-reddy/msitools/releases/download/v0.101b/wixl -P /usr/local/bin
- chmod a+x /usr/local/bin/wixl
- pip3 install pygithub
pre-cache: *github_release_pkgs_pre_cache
post-cache:
# build all packages and move them to /cfsetup/built_artifacts
- ./build-packages.sh
# release the packages built and moved to /cfsetup/built_artifacts
# same logic as above, but for FIPS packages only
- ./build-packages-fips.sh
- make github-release-built-pkgs
build-deb:
build_dir: *build_dir
builddeps: &build_deb_deps
- *pinned_go
- build-essential
- fakeroot
- rubygem-fpm
post-cache:
- export GOOS=linux
- export GOARCH=amd64
- make cloudflared-deb
build-fips-internal-deb:
build_dir: *build_dir
builddeps: &build_fips_deb_deps
- *pinned_go_fips
- build-essential
- fakeroot
@ -77,15 +94,17 @@ stretch: &stretch
- export GOOS=linux
- export GOARCH=amd64
- export FIPS=true
- export ORIGINAL_NAME=true
- make cloudflared-deb
build-deb-nightly:
build-fips-internal-deb-nightly:
build_dir: *build_dir
builddeps: *build_deb_deps
builddeps: *build_fips_deb_deps
post-cache:
- export GOOS=linux
- export GOARCH=amd64
- export FIPS=true
- export NIGHTLY=true
- export FIPS=true
- export ORIGINAL_NAME=true
- make cloudflared-deb
build-deb-arm64:
build_dir: *build_dir
@ -97,7 +116,7 @@ stretch: &stretch
publish-deb:
build_dir: *build_dir
builddeps:
- *pinned_go_fips
- *pinned_go
- build-essential
- fakeroot
- rubygem-fpm
@ -105,27 +124,43 @@ stretch: &stretch
post-cache:
- export GOOS=linux
- export GOARCH=amd64
- export FIPS=true
- make publish-deb
github-release-macos-amd64:
build_dir: *build_dir
builddeps:
- *pinned_go
- build-essential
- python3-dev
- libffi-dev
- python3-setuptools
- python3-pip
pre-cache: &install_pygithub
- pip3 install pygithub
- pip3 install pynacl==1.4.0
- pip3 install pygithub==1.55
post-cache:
- make github-mac-upload
test:
build_dir: *build_dir
builddeps:
- *pinned_go
- build-essential
- gotest-to-teamcity
pre-cache: &test_pre_cache
- go get golang.org/x/tools/cmd/goimports
- go get github.com/sudarshan-reddy/go-sumtype@v0.0.0-20210827105221-82eca7e5abb1
post-cache:
- export GOOS=linux
- export GOARCH=amd64
- export PATH="$HOME/go/bin:$PATH"
- ./fmt-check.sh
- make test | gotest-to-teamcity
test-fips:
build_dir: *build_dir
builddeps:
- *pinned_go_fips
- build-essential
- gotest-to-teamcity
pre-cache:
- go get golang.org/x/tools/cmd/goimports
- go get github.com/sudarshan-reddy/go-sumtype@v0.0.0-20210827105221-82eca7e5abb1
pre-cache: *test_pre_cache
post-cache:
- export GOOS=linux
- export GOARCH=amd64
@ -150,7 +185,7 @@ stretch: &stretch
post-cache:
# Creates and routes a Named Tunnel for this build. Also constructs config file from env vars.
- python3 component-tests/setup.py --type create
- pytest component-tests
- pytest component-tests -o log_cli=true --log-cli-level=INFO
# The Named Tunnel is deleted and its route unprovisioned here.
- python3 component-tests/setup.py --type cleanup
update-homebrew:
@ -163,6 +198,9 @@ stretch: &stretch
build_dir: *build_dir
builddeps:
- *pinned_go
- build-essential
- python3-dev
- libffi-dev
- python3-setuptools
- python3-pip
pre-cache: *install_pygithub
@ -208,23 +246,10 @@ centos-7:
pre-cache:
- yum install -y fakeroot
- yum upgrade -y binutils-2.27-44.base.el7.x86_64
- wget https://golang.org/dl/go1.16.3.linux-amd64.tar.gz -P /tmp/
- tar -C /usr/local -xzf /tmp/go1.16.3.linux-amd64.tar.gz
- wget https://go.dev/dl/go1.17.5.linux-amd64.tar.gz -P /tmp/
- tar -C /usr/local -xzf /tmp/go1.17.5.linux-amd64.tar.gz
post-cache:
- export PATH=$PATH:/usr/local/go/bin
- export GOOS=linux
- export GOARCH=amd64
- make publish-rpm
build-rpm:
build_dir: *build_dir
builddeps: *el7_builddeps
pre-cache:
- yum install -y fakeroot
- yum upgrade -y binutils-2.27-44.base.el7.x86_64
- wget https://golang.org/dl/go1.16.3.linux-amd64.tar.gz -P /tmp/
- tar -C /usr/local -xzf /tmp/go1.16.3.linux-amd64.tar.gz
post-cache:
- export PATH=$PATH:/usr/local/go/bin
- export GOOS=linux
- export GOARCH=amd64
- make cloudflared-rpm
- make publish-rpm

15
check-fips.sh Executable file
View File

@ -0,0 +1,15 @@
# Pass the path to the executable to check for FIPS compliance
exe=$1
if [ "$(go tool nm "${exe}" | grep -c '_Cfunc__goboringcrypto_')" -eq 0 ]; then
# Asserts that executable is using FIPS-compliant boringcrypto
echo "${exe}: missing goboring symbols" >&2
exit 1
fi
if [ "$(go tool nm "${exe}" | grep -c 'crypto/internal/boring/sig.FIPSOnly')" -eq 0 ]; then
# Asserts that executable is using FIPS-only schemes
echo "${exe}: missing fipsonly symbols" >&2
exit 1
fi
echo "${exe} is FIPS-compliant"

View File

@ -240,7 +240,7 @@ func login(c *cli.Context) error {
return nil
}
// ensureURLScheme prepends a URL with https:// if it doesnt have a scheme. http:// URLs will not be converted.
// ensureURLScheme prepends a URL with https:// if it doesn't have a scheme. http:// URLs will not be converted.
func ensureURLScheme(url string) string {
url = strings.Replace(strings.ToLower(url), "http://", "https://", 1)
if !strings.HasPrefix(url, "https://") {

View File

@ -1,4 +1,4 @@
package buildinfo
package cliutil
import (
"fmt"
@ -11,23 +11,39 @@ type BuildInfo struct {
GoOS string `json:"go_os"`
GoVersion string `json:"go_version"`
GoArch string `json:"go_arch"`
BuildType string `json:"build_type"`
CloudflaredVersion string `json:"cloudflared_version"`
}
func GetBuildInfo(cloudflaredVersion string) *BuildInfo {
func GetBuildInfo(buildType, version string) *BuildInfo {
return &BuildInfo{
GoOS: runtime.GOOS,
GoVersion: runtime.Version(),
GoArch: runtime.GOARCH,
CloudflaredVersion: cloudflaredVersion,
BuildType: buildType,
CloudflaredVersion: version,
}
}
func (bi *BuildInfo) Log(log *zerolog.Logger) {
log.Info().Msgf("Version %s", bi.CloudflaredVersion)
if bi.BuildType != "" {
log.Info().Msgf("Built%s", bi.GetBuildTypeMsg())
}
log.Info().Msgf("GOOS: %s, GOVersion: %s, GoArch: %s", bi.GoOS, bi.GoVersion, bi.GoArch)
}
func (bi *BuildInfo) OSArch() string {
return fmt.Sprintf("%s_%s", bi.GoOS, bi.GoArch)
}
func (bi *BuildInfo) Version() string {
return bi.CloudflaredVersion
}
func (bi *BuildInfo) GetBuildTypeMsg() string {
if bi.BuildType == "" {
return ""
}
return fmt.Sprintf(" with %s", bi.BuildType)
}

View File

@ -11,7 +11,7 @@ func RemovedCommand(name string) *cli.Command {
Name: name,
Action: func(context *cli.Context) error {
return cli.Exit(
fmt.Sprintf("%s command is no longer supported by cloudflared. Consult Argo Tunnel documentation for possible alternative solutions.", name),
fmt.Sprintf("%s command is no longer supported by cloudflared. Consult Cloudflare Tunnel documentation for possible alternative solutions.", name),
-1,
)
},

View File

@ -1,3 +1,4 @@
//go:build !windows && !darwin && !linux
// +build !windows,!darwin,!linux
package main

View File

@ -1,3 +1,4 @@
//go:build linux
// +build linux
package main
@ -19,11 +20,11 @@ import (
func runApp(app *cli.App, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel system service",
Usage: "Manages the Cloudflare Tunnel system service",
Subcommands: []*cli.Command{
{
Name: "install",
Usage: "Install Argo Tunnel as a system service",
Usage: "Install Cloudflare Tunnel as a system service",
Action: cliutil.ConfiguredAction(installLinuxService),
Flags: []cli.Flag{
&cli.BoolFlag{
@ -34,7 +35,7 @@ func runApp(app *cli.App, graceShutdownC chan struct{}) {
},
{
Name: "uninstall",
Usage: "Uninstall the Argo Tunnel service",
Usage: "Uninstall the Cloudflare Tunnel service",
Action: cliutil.ConfiguredAction(uninstallLinuxService),
},
},
@ -55,7 +56,7 @@ var systemdTemplates = []ServiceTemplate{
{
Path: "/etc/systemd/system/cloudflared.service",
Content: `[Unit]
Description=Argo Tunnel
Description=Cloudflare Tunnel
After=network.target
[Service]
@ -72,7 +73,7 @@ WantedBy=multi-user.target
{
Path: "/etc/systemd/system/cloudflared-update.service",
Content: `[Unit]
Description=Update Argo Tunnel
Description=Update Cloudflare Tunnel
After=network.target
[Service]
@ -82,7 +83,7 @@ ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 11 ]; then s
{
Path: "/etc/systemd/system/cloudflared-update.timer",
Content: `[Unit]
Description=Update Argo Tunnel
Description=Update Cloudflare Tunnel
[Timer]
OnCalendar=daily
@ -99,7 +100,7 @@ var sysvTemplate = ServiceTemplate{
Content: `#!/bin/sh
# For RedHat and cousins:
# chkconfig: 2345 99 01
# description: Argo Tunnel agent
# description: Cloudflare Tunnel agent
# processname: {{.Path}}
### BEGIN INIT INFO
# Provides: {{.Path}}
@ -107,8 +108,8 @@ var sysvTemplate = ServiceTemplate{
# Required-Stop:
# Default-Start: 2 3 4 5
# Default-Stop: 0 1 6
# Short-Description: Argo Tunnel
# Description: Argo Tunnel agent
# Short-Description: Cloudflare Tunnel
# Description: Cloudflare Tunnel agent
### END INIT INFO
name=$(basename $(readlink -f $0))
cmd="{{.Path}} --config /etc/cloudflared/config.yml --pidfile /var/run/$name.pid --autoupdate-freq 24h0m0s{{ range .ExtraArgs }} {{ . }}{{ end }}"

View File

@ -1,3 +1,4 @@
//go:build darwin
// +build darwin
package main
@ -20,16 +21,16 @@ const (
func runApp(app *cli.App, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel launch agent",
Usage: "Manages the Cloudflare Tunnel launch agent",
Subcommands: []*cli.Command{
{
Name: "install",
Usage: "Install Argo Tunnel as an user launch agent",
Usage: "Install Cloudflare Tunnel as an user launch agent",
Action: cliutil.ConfiguredAction(installLaunchd),
},
{
Name: "uninstall",
Usage: "Uninstall the Argo Tunnel launch agent",
Usage: "Uninstall the Cloudflare Tunnel launch agent",
Action: cliutil.ConfiguredAction(uninstallLaunchd),
},
},
@ -110,13 +111,13 @@ func installLaunchd(c *cli.Context) error {
log := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog)
if isRootUser() {
log.Info().Msg("Installing Argo Tunnel client as a system launch daemon. " +
"Argo Tunnel client will run at boot")
log.Info().Msg("Installing Cloudflare Tunnel client as a system launch daemon. " +
"Cloudflare Tunnel client will run at boot")
} else {
log.Info().Msg("Installing Argo Tunnel client as an user launch agent. " +
"Note that Argo Tunnel client will only run when the user is logged in. " +
"If you want to run Argo Tunnel client at boot, install with root permission. " +
"For more information, visit https://developers.cloudflare.com/argo-tunnel/reference/service/")
log.Info().Msg("Installing Cloudflare Tunnel client as an user launch agent. " +
"Note that Cloudflare Tunnel client will only run when the user is logged in. " +
"If you want to run Cloudflare Tunnel client at boot, install with root permission. " +
"For more information, visit https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/run-as-service")
}
etPath, err := os.Executable()
if err != nil {
@ -159,9 +160,9 @@ func uninstallLaunchd(c *cli.Context) error {
log := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog)
if isRootUser() {
log.Info().Msg("Uninstalling Argo Tunnel as a system launch daemon")
log.Info().Msg("Uninstalling Cloudflare Tunnel as a system launch daemon")
} else {
log.Info().Msg("Uninstalling Argo Tunnel as an user launch agent")
log.Info().Msg("Uninstalling Cloudflare Tunnel as an user launch agent")
}
installPath, err := installPath()
if err != nil {

View File

@ -31,6 +31,7 @@ const (
var (
Version = "DEV"
BuildTime = "unknown"
BuildType = ""
// Mostly network errors that we don't want reported back to Sentry, this is done by substring match.
ignoredErrors = []string{
"connection reset by peer",
@ -46,9 +47,10 @@ var (
func main() {
rand.Seed(time.Now().UnixNano())
metrics.RegisterBuildInfo(BuildTime, Version)
metrics.RegisterBuildInfo(BuildType, BuildTime, Version)
raven.SetRelease(Version)
maxprocs.Set()
bInfo := cliutil.GetBuildInfo(BuildType, Version)
// Graceful shutdown channel used by the app. When closed, app must terminate gracefully.
// Windows service manager closes this channel when it receives stop command.
@ -67,21 +69,21 @@ func main() {
app.Copyright = fmt.Sprintf(
`(c) %d Cloudflare Inc.
Your installation of cloudflared software constitutes a symbol of your signature indicating that you accept
the terms of the Cloudflare License (https://developers.cloudflare.com/argo-tunnel/license/),
the terms of the Cloudflare License (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/license),
Terms (https://www.cloudflare.com/terms/) and Privacy Policy (https://www.cloudflare.com/privacypolicy/).`,
time.Now().Year(),
)
app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime)
app.Version = fmt.Sprintf("%s (built %s%s)", Version, BuildTime, bInfo.GetBuildTypeMsg())
app.Description = `cloudflared connects your machine or user identity to Cloudflare's global network.
You can use it to authenticate a session to reach an API behind Access, route web traffic to this machine,
and configure access control.
See https://developers.cloudflare.com/argo-tunnel/ for more in-depth documentation.`
See https://developers.cloudflare.com/cloudflare-one/connections/connect-apps for more in-depth documentation.`
app.Flags = flags()
app.Action = action(graceShutdownC)
app.Commands = commands(cli.ShowVersion)
tunnel.Init(Version, graceShutdownC) // we need this to support the tunnel sub command...
tunnel.Init(bInfo, graceShutdownC) // we need this to support the tunnel sub command...
access.Init(graceShutdownC)
updater.Init(Version)
runApp(app, graceShutdownC)

View File

@ -21,7 +21,7 @@ import (
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/proxydns"
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
@ -35,7 +35,6 @@ import (
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunneldns"
"github.com/cloudflare/cloudflared/tunnelstore"
)
const (
@ -86,7 +85,7 @@ const (
var (
graceShutdownC chan struct{}
version string
buildInfo *cliutil.BuildInfo
routeFailMsg = fmt.Sprintf("failed to provision routing, please create it manually via Cloudflare dashboard or UI; "+
"most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+
@ -102,6 +101,8 @@ func Commands() []*cli.Command {
buildLoginSubcommand(false),
buildCreateCommand(),
buildRouteCommand(),
// TODO TUN-5477 this should not be hidden
buildVirtualNetworkSubcommand(true),
buildRunCommand(),
buildListCommand(),
buildInfoCommand(),
@ -126,14 +127,14 @@ func buildTunnelCommand(subcommands []*cli.Command) *cli.Command {
Name: "tunnel",
Action: cliutil.ConfiguredAction(TunnelCommand),
Category: "Tunnel",
Usage: "Make a locally-running web service accessible over the internet using Argo Tunnel.",
Usage: "Make a locally-running web service accessible over the internet using Cloudflare Tunnel.",
ArgsUsage: " ",
Description: `Argo Tunnel asks you to specify a hostname on a Cloudflare-powered
Description: `Cloudflare Tunnel asks you to specify a hostname on a Cloudflare-powered
domain you control and a local address. Traffic from that hostname is routed
(optionally via a Cloudflare Load Balancer) to this machine and appears on the
specified port where it can be served.
This feature requires your Cloudflare account be subscribed to the Argo Smart Routing feature.
This feature requires your Cloudflare account be subscribed to the Cloudflare Smart Routing feature.
To use, begin by calling login to download a certificate:
@ -173,15 +174,16 @@ func TunnelCommand(c *cli.Context) error {
return runClassicTunnel(sc)
}
func Init(ver string, gracefulShutdown chan struct{}) {
version, graceShutdownC = ver, gracefulShutdown
func Init(info *cliutil.BuildInfo, gracefulShutdown chan struct{}) {
buildInfo, graceShutdownC = info, gracefulShutdown
}
// runAdhocNamedTunnel create, route and run a named tunnel in one command
func runAdhocNamedTunnel(sc *subcommandContext, name, credentialsOutputPath string) error {
tunnel, ok, err := sc.tunnelActive(name)
if err != nil || !ok {
tunnel, err = sc.create(name, credentialsOutputPath)
// pass empty string as secret to generate one
tunnel, err = sc.create(name, credentialsOutputPath, "")
if err != nil {
return errors.Wrap(err, "failed to create tunnel")
}
@ -206,22 +208,22 @@ func runAdhocNamedTunnel(sc *subcommandContext, name, credentialsOutputPath stri
// runClassicTunnel creates a "classic" non-named tunnel
func runClassicTunnel(sc *subcommandContext) error {
return StartServer(sc.c, version, nil, sc.log, sc.isUIEnabled)
return StartServer(sc.c, buildInfo, nil, sc.log, sc.isUIEnabled)
}
func routeFromFlag(c *cli.Context) (route tunnelstore.Route, ok bool) {
func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) {
if hostname := c.String("hostname"); hostname != "" {
if lbPool := c.String("lb-pool"); lbPool != "" {
return tunnelstore.NewLBRoute(hostname, lbPool), true
return cfapi.NewLBRoute(hostname, lbPool), true
}
return tunnelstore.NewDNSRoute(hostname, c.Bool(overwriteDNSFlagName)), true
return cfapi.NewDNSRoute(hostname, c.Bool(overwriteDNSFlagName)), true
}
return nil, false
}
func StartServer(
c *cli.Context,
version string,
info *cliutil.BuildInfo,
namedTunnel *connection.NamedTunnelConfig,
log *zerolog.Logger,
isUIEnabled bool,
@ -268,8 +270,7 @@ func StartServer(
defer trace.Stop()
}
buildInfo := buildinfo.GetBuildInfo(version)
buildInfo.Log(log)
info.Log(log)
logClientOptions(c, log)
// this context drives the server, when it's cancelled tunnel and all other components (origins, dns, etc...) should stop
@ -333,7 +334,7 @@ func StartServer(
observer.SendURL(quickTunnelURL)
}
tunnelConfig, ingressRules, err := prepareTunnelConfig(c, buildInfo, version, log, logTransport, observer, namedTunnel)
tunnelConfig, ingressRules, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel)
if err != nil {
log.Err(err).Msg("Couldn't start tunnel")
return err
@ -374,7 +375,7 @@ func StartServer(
if isUIEnabled {
tunnelUI := ui.NewUIModel(
version,
info.Version(),
hostname,
metricsListener.Addr().String(),
&ingressRules,
@ -488,7 +489,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
credentialsFileFlag,
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "is-autoupdated",
Usage: "Signal the new process that Argo Tunnel connector has been autoupdated",
Usage: "Signal the new process that Cloudflare Tunnel connector has been autoupdated",
Value: false,
Hidden: true,
}),
@ -646,6 +647,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Value: "https://api.trycloudflare.com",
Hidden: true,
}),
altsrc.NewIntFlag(&cli.IntFlag{
Name: "max-fetch-size",
Usage: `The maximum number of results that cloudflared can fetch from Cloudflare API for any listing operations needed`,
EnvVars: []string{"TUNNEL_MAX_FETCH_SIZE"},
Hidden: true,
}),
selectProtocolFlag,
overwriteDNSFlag,
}...)

View File

@ -16,7 +16,8 @@ import (
"github.com/urfave/cli/v2"
"golang.org/x/crypto/ssh/terminal"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery"
@ -148,8 +149,7 @@ func getOriginCert(originCertPath string, log *zerolog.Logger) ([]byte, error) {
func prepareTunnelConfig(
c *cli.Context,
buildInfo *buildinfo.BuildInfo,
version string,
info *cliutil.BuildInfo,
log, logTransport *zerolog.Logger,
observer *connection.Observer,
namedTunnel *connection.NamedTunnelConfig,
@ -193,8 +193,8 @@ func prepareTunnelConfig(
namedTunnel.Client = tunnelpogs.ClientInfo{
ClientID: clientUUID[:],
Features: dedup(features),
Version: version,
Arch: buildInfo.OSArch(),
Version: info.Version(),
Arch: info.OSArch(),
}
ingressRules, err = ingress.ParseIngress(cfg)
if err != nil && err != ingress.ErrNoIngressRules {
@ -238,7 +238,7 @@ func prepareTunnelConfig(
log.Info().Msgf("Warp-routing is enabled")
}
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, log)
protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, origin.ResolveTTL, log)
if err != nil {
return nil, ingress.Ingress{}, err
}
@ -281,7 +281,7 @@ func prepareTunnelConfig(
return &origin.TunnelConfig{
ConnectionConfig: connectionConfig,
OSArch: buildInfo.OSArch(),
OSArch: info.OSArch(),
ClientID: clientID,
EdgeAddrs: c.StringSlice("edge"),
Region: c.String("region"),
@ -293,7 +293,7 @@ func prepareTunnelConfig(
Log: log,
LogTransport: logTransport,
Observer: observer,
ReportedVersion: version,
ReportedVersion: info.Version(),
// Note TUN-3758 , we use Int because UInt is not supported with altsrc
Retries: uint(c.Int("retries")),
RunFromTerminal: isRunningFromTerminal(),

View File

@ -1,4 +1,6 @@
//go:build ignore
// +build ignore
// TODO: Remove the above build tag and include this test when we start compiling with Golang 1.10.0+
package tunnel

View File

@ -5,12 +5,12 @@ import (
"github.com/google/uuid"
"github.com/cloudflare/cloudflared/tunnelstore"
"github.com/cloudflare/cloudflared/cfapi"
)
type Info struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
CreatedAt time.Time `json:"createdAt"`
Connectors []*tunnelstore.ActiveClient `json:"conns"`
ID uuid.UUID `json:"id"`
Name string `json:"name"`
CreatedAt time.Time `json:"createdAt"`
Connectors []*cfapi.ActiveClient `json:"conns"`
}

View File

@ -70,9 +70,13 @@ func RunQuickTunnel(sc *subcommandContext) error {
sc.log.Info().Msg(line)
}
if !sc.c.IsSet("protocol") {
sc.c.Set("protocol", "quic")
}
return StartServer(
sc.c,
version,
buildInfo,
&connection.NamedTunnelConfig{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname},
sc.log,
sc.isUIEnabled,

View File

@ -1,3 +1,4 @@
//go:build !windows
// +build !windows
package tunnel

View File

@ -1,6 +1,7 @@
package tunnel
import (
"encoding/base64"
"encoding/json"
"fmt"
"os"
@ -13,9 +14,9 @@ import (
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/certutil"
"github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelstore"
)
type errInvalidJSONCredential struct {
@ -36,7 +37,7 @@ type subcommandContext struct {
fs fileSystem
// These fields should be accessed using their respective Getter
tunnelstoreClient tunnelstore.Client
tunnelstoreClient cfapi.Client
userCredential *userCredential
}
@ -67,7 +68,7 @@ type userCredential struct {
certPath string
}
func (sc *subcommandContext) client() (tunnelstore.Client, error) {
func (sc *subcommandContext) client() (cfapi.Client, error) {
if sc.tunnelstoreClient != nil {
return sc.tunnelstoreClient, nil
}
@ -75,8 +76,8 @@ func (sc *subcommandContext) client() (tunnelstore.Client, error) {
if err != nil {
return nil, err
}
userAgent := fmt.Sprintf("cloudflared/%s", version)
client, err := tunnelstore.NewRESTClient(
userAgent := fmt.Sprintf("cloudflared/%s", buildInfo.Version())
client, err := cfapi.NewRESTClient(
sc.c.String("api-url"),
credential.cert.AccountID,
credential.cert.ZoneID,
@ -148,15 +149,27 @@ func (sc *subcommandContext) readTunnelCredentials(credFinder CredFinder) (conne
return credentials, nil
}
func (sc *subcommandContext) create(name string, credentialsFilePath string) (*tunnelstore.Tunnel, error) {
func (sc *subcommandContext) create(name string, credentialsFilePath string, secret string) (*cfapi.Tunnel, error) {
client, err := sc.client()
if err != nil {
return nil, errors.Wrap(err, "couldn't create client to talk to Argo Tunnel backend")
return nil, errors.Wrap(err, "couldn't create client to talk to Cloudflare Tunnel backend")
}
tunnelSecret, err := generateTunnelSecret()
if err != nil {
return nil, errors.Wrap(err, "couldn't generate the secret for your new tunnel")
var tunnelSecret []byte
if secret == "" {
tunnelSecret, err = generateTunnelSecret()
if err != nil {
return nil, errors.Wrap(err, "couldn't generate the secret for your new tunnel")
}
} else {
decodedSecret, err := base64.StdEncoding.DecodeString(secret)
if err != nil {
return nil, errors.Wrap(err, "Couldn't decode tunnel secret from base64")
}
tunnelSecret = []byte(decodedSecret)
if len(tunnelSecret) < 32 {
return nil, errors.New("Decoded tunnel secret must be at least 32 bytes long")
}
}
tunnel, err := client.CreateTunnel(name, tunnelSecret)
@ -211,7 +224,7 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string) (*t
return tunnel, nil
}
func (sc *subcommandContext) list(filter *tunnelstore.Filter) ([]*tunnelstore.Tunnel, error) {
func (sc *subcommandContext) list(filter *cfapi.TunnelFilter) ([]*cfapi.Tunnel, error) {
client, err := sc.client()
if err != nil {
return nil, err
@ -230,7 +243,7 @@ func (sc *subcommandContext) delete(tunnelIDs []uuid.UUID) error {
for _, id := range tunnelIDs {
tunnel, err := client.GetTunnel(id)
if err != nil {
return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", tunnel.ID)
return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", id)
}
// Check if tunnel DeletedAt field has already been set
@ -238,7 +251,7 @@ func (sc *subcommandContext) delete(tunnelIDs []uuid.UUID) error {
return fmt.Errorf("Tunnel %s has already been deleted", tunnel.ID)
}
if forceFlagSet {
if err := client.CleanupConnections(tunnel.ID, tunnelstore.NewCleanupParams()); err != nil {
if err := client.CleanupConnections(tunnel.ID, cfapi.NewCleanupParams()); err != nil {
return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", tunnel.ID)
}
}
@ -290,7 +303,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
return StartServer(
sc.c,
version,
buildInfo,
&connection.NamedTunnelConfig{Credentials: credentials},
sc.log,
sc.isUIEnabled,
@ -298,7 +311,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error {
}
func (sc *subcommandContext) cleanupConnections(tunnelIDs []uuid.UUID) error {
params := tunnelstore.NewCleanupParams()
params := cfapi.NewCleanupParams()
extraLog := ""
if connector := sc.c.String("connector-id"); connector != "" {
connectorID, err := uuid.Parse(connector)
@ -322,7 +335,7 @@ func (sc *subcommandContext) cleanupConnections(tunnelIDs []uuid.UUID) error {
return nil
}
func (sc *subcommandContext) route(tunnelID uuid.UUID, r tunnelstore.Route) (tunnelstore.RouteResult, error) {
func (sc *subcommandContext) route(tunnelID uuid.UUID, r cfapi.HostnameRoute) (cfapi.HostnameRouteResult, error) {
client, err := sc.client()
if err != nil {
return nil, err
@ -332,8 +345,8 @@ func (sc *subcommandContext) route(tunnelID uuid.UUID, r tunnelstore.Route) (tun
}
// Query Tunnelstore to find the active tunnel with the given name.
func (sc *subcommandContext) tunnelActive(name string) (*tunnelstore.Tunnel, bool, error) {
filter := tunnelstore.NewFilter()
func (sc *subcommandContext) tunnelActive(name string) (*cfapi.Tunnel, bool, error) {
filter := cfapi.NewTunnelFilter()
filter.NoDeleted()
filter.ByName(name)
tunnels, err := sc.list(filter)
@ -373,52 +386,42 @@ func (sc *subcommandContext) findID(input string) (uuid.UUID, error) {
}
// findIDs is just like mapping `findID` over a slice, but it only uses
// one Tunnelstore API call.
// one Tunnelstore API call per non-UUID input provided.
func (sc *subcommandContext) findIDs(inputs []string) ([]uuid.UUID, error) {
uuids, names := splitUuids(inputs)
// Shortcut without Tunnelstore call if we find that all inputs are already UUIDs.
uuids, err := convertNamesToUuids(inputs, make(map[string]uuid.UUID))
if err == nil {
return uuids, nil
for _, name := range names {
filter := cfapi.NewTunnelFilter()
filter.NoDeleted()
filter.ByName(name)
tunnels, err := sc.list(filter)
if err != nil {
return nil, err
}
if len(tunnels) != 1 {
return nil, fmt.Errorf("there should only be 1 non-deleted Tunnel named %s", name)
}
uuids = append(uuids, tunnels[0].ID)
}
// First, look up all tunnels the user has
filter := tunnelstore.NewFilter()
filter.NoDeleted()
tunnels, err := sc.list(filter)
if err != nil {
return nil, err
}
// Do the pure list-processing in its own function, so that it can be
// unit tested easily.
return findIDs(tunnels, inputs)
return uuids, nil
}
func findIDs(tunnels []*tunnelstore.Tunnel, inputs []string) ([]uuid.UUID, error) {
// Put them into a dictionary for faster lookups
nameToID := make(map[string]uuid.UUID, len(tunnels))
for _, tunnel := range tunnels {
nameToID[tunnel.Name] = tunnel.ID
}
func splitUuids(inputs []string) ([]uuid.UUID, []string) {
uuids := make([]uuid.UUID, 0)
names := make([]string, 0)
return convertNamesToUuids(inputs, nameToID)
}
func convertNamesToUuids(inputs []string, nameToID map[string]uuid.UUID) ([]uuid.UUID, error) {
tunnelIDs := make([]uuid.UUID, len(inputs))
var badInputs []string
for i, input := range inputs {
if id, err := uuid.Parse(input); err == nil {
tunnelIDs[i] = id
} else if id, ok := nameToID[input]; ok {
tunnelIDs[i] = id
for _, input := range inputs {
id, err := uuid.Parse(input)
if err != nil {
names = append(names, input)
} else {
badInputs = append(badInputs, input)
uuids = append(uuids, id)
}
}
if len(badInputs) > 0 {
msg := "Please specify either the ID or name of a tunnel. The following inputs were neither: %s"
return nil, fmt.Errorf(msg, strings.Join(badInputs, ", "))
}
return tunnelIDs, nil
return uuids, names
}

View File

@ -1,16 +1,14 @@
package tunnel
import (
"net"
"github.com/pkg/errors"
"github.com/cloudflare/cloudflared/teamnet"
"github.com/cloudflare/cloudflared/cfapi"
)
const noClientMsg = "error while creating backend client"
func (sc *subcommandContext) listRoutes(filter *teamnet.Filter) ([]*teamnet.DetailedRoute, error) {
func (sc *subcommandContext) listRoutes(filter *cfapi.IpRouteFilter) ([]*cfapi.DetailedRoute, error) {
client, err := sc.client()
if err != nil {
return nil, errors.Wrap(err, noClientMsg)
@ -18,26 +16,26 @@ func (sc *subcommandContext) listRoutes(filter *teamnet.Filter) ([]*teamnet.Deta
return client.ListRoutes(filter)
}
func (sc *subcommandContext) addRoute(newRoute teamnet.NewRoute) (teamnet.Route, error) {
func (sc *subcommandContext) addRoute(newRoute cfapi.NewRoute) (cfapi.Route, error) {
client, err := sc.client()
if err != nil {
return teamnet.Route{}, errors.Wrap(err, noClientMsg)
return cfapi.Route{}, errors.Wrap(err, noClientMsg)
}
return client.AddRoute(newRoute)
}
func (sc *subcommandContext) deleteRoute(network net.IPNet) error {
func (sc *subcommandContext) deleteRoute(params cfapi.DeleteRouteParams) error {
client, err := sc.client()
if err != nil {
return errors.Wrap(err, noClientMsg)
}
return client.DeleteRoute(network)
return client.DeleteRoute(params)
}
func (sc *subcommandContext) getRouteByIP(ip net.IP) (teamnet.DetailedRoute, error) {
func (sc *subcommandContext) getRouteByIP(params cfapi.GetRouteByIpParams) (cfapi.DetailedRoute, error) {
client, err := sc.client()
if err != nil {
return teamnet.DetailedRoute{}, errors.Wrap(err, noClientMsg)
return cfapi.DetailedRoute{}, errors.Wrap(err, noClientMsg)
}
return client.GetByIP(ip)
return client.GetByIP(params)
}

View File

@ -13,83 +13,10 @@ import (
"github.com/rs/zerolog"
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/tunnelstore"
)
func Test_findIDs(t *testing.T) {
type args struct {
tunnels []*tunnelstore.Tunnel
inputs []string
}
tests := []struct {
name string
args args
want []uuid.UUID
wantErr bool
}{
{
name: "input not found",
args: args{
inputs: []string{"asdf"},
},
wantErr: true,
},
{
name: "only UUID",
args: args{
inputs: []string{"a8398a0b-876d-48ed-b609-3fcfd67a4950"},
},
want: []uuid.UUID{uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950")},
},
{
name: "only name",
args: args{
tunnels: []*tunnelstore.Tunnel{
{
ID: uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"),
Name: "tunnel1",
},
},
inputs: []string{"tunnel1"},
},
want: []uuid.UUID{uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950")},
},
{
name: "both UUID and name",
args: args{
tunnels: []*tunnelstore.Tunnel{
{
ID: uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"),
Name: "tunnel1",
},
{
ID: uuid.MustParse("bf028b68-744f-466e-97f8-c46161d80aa5"),
Name: "tunnel2",
},
},
inputs: []string{"tunnel1", "bf028b68-744f-466e-97f8-c46161d80aa5"},
},
want: []uuid.UUID{
uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"),
uuid.MustParse("bf028b68-744f-466e-97f8-c46161d80aa5"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := findIDs(tt.args.tunnels, tt.args.inputs)
if (err != nil) != tt.wantErr {
t.Errorf("findIDs() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("findIDs() = %v, want %v", got, tt.want)
}
})
}
}
type mockFileSystem struct {
rf func(string) ([]byte, error)
vfp func(string) bool
@ -109,7 +36,7 @@ func Test_subcommandContext_findCredentials(t *testing.T) {
log *zerolog.Logger
isUIEnabled bool
fs fileSystem
tunnelstoreClient tunnelstore.Client
tunnelstoreClient cfapi.Client
userCredential *userCredential
}
type args struct {
@ -260,13 +187,13 @@ func Test_subcommandContext_findCredentials(t *testing.T) {
}
type deleteMockTunnelStore struct {
tunnelstore.Client
cfapi.Client
mockTunnels map[uuid.UUID]mockTunnelBehaviour
deletedTunnelIDs []uuid.UUID
}
type mockTunnelBehaviour struct {
tunnel tunnelstore.Tunnel
tunnel cfapi.Tunnel
deleteErr error
cleanupErr error
}
@ -282,7 +209,7 @@ func newDeleteMockTunnelStore(tunnels ...mockTunnelBehaviour) *deleteMockTunnelS
}
}
func (d *deleteMockTunnelStore) GetTunnel(tunnelID uuid.UUID) (*tunnelstore.Tunnel, error) {
func (d *deleteMockTunnelStore) GetTunnel(tunnelID uuid.UUID) (*cfapi.Tunnel, error) {
tunnel, ok := d.mockTunnels[tunnelID]
if !ok {
return nil, fmt.Errorf("Couldn't find tunnel: %v", tunnelID)
@ -306,7 +233,7 @@ func (d *deleteMockTunnelStore) DeleteTunnel(tunnelID uuid.UUID) error {
return nil
}
func (d *deleteMockTunnelStore) CleanupConnections(tunnelID uuid.UUID, _ *tunnelstore.CleanupParams) error {
func (d *deleteMockTunnelStore) CleanupConnections(tunnelID uuid.UUID, _ *cfapi.CleanupParams) error {
tunnel, ok := d.mockTunnels[tunnelID]
if !ok {
return fmt.Errorf("Couldn't find tunnel: %v", tunnelID)
@ -357,10 +284,10 @@ func Test_subcommandContext_Delete(t *testing.T) {
}(),
tunnelstoreClient: newDeleteMockTunnelStore(
mockTunnelBehaviour{
tunnel: tunnelstore.Tunnel{ID: tunnelID1},
tunnel: cfapi.Tunnel{ID: tunnelID1},
},
mockTunnelBehaviour{
tunnel: tunnelstore.Tunnel{ID: tunnelID2},
tunnel: cfapi.Tunnel{ID: tunnelID2},
},
),
},

View File

@ -0,0 +1,40 @@
package tunnel
import (
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/cloudflare/cloudflared/cfapi"
)
func (sc *subcommandContext) addVirtualNetwork(newVnet cfapi.NewVirtualNetwork) (cfapi.VirtualNetwork, error) {
client, err := sc.client()
if err != nil {
return cfapi.VirtualNetwork{}, errors.Wrap(err, noClientMsg)
}
return client.CreateVirtualNetwork(newVnet)
}
func (sc *subcommandContext) listVirtualNetworks(filter *cfapi.VnetFilter) ([]*cfapi.VirtualNetwork, error) {
client, err := sc.client()
if err != nil {
return nil, errors.Wrap(err, noClientMsg)
}
return client.ListVirtualNetworks(filter)
}
func (sc *subcommandContext) deleteVirtualNetwork(vnetId uuid.UUID) error {
client, err := sc.client()
if err != nil {
return errors.Wrap(err, noClientMsg)
}
return client.DeleteVirtualNetwork(vnetId)
}
func (sc *subcommandContext) updateVirtualNetwork(vnetId uuid.UUID, updates cfapi.UpdateVirtualNetwork) error {
client, err := sc.client()
if err != nil {
return errors.Wrap(err, noClientMsg)
}
return client.UpdateVirtualNetwork(vnetId, updates)
}

View File

@ -21,11 +21,11 @@ import (
"golang.org/x/net/idna"
yaml "gopkg.in/yaml.v2"
"github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/tunnelstore"
)
const (
@ -64,8 +64,8 @@ var (
Name: "when",
Aliases: []string{"w"},
Usage: "List tunnels that are active at the given `TIME` in RFC3339 format",
Layout: tunnelstore.TimeLayout,
DefaultText: fmt.Sprintf("current time, %s", time.Now().Format(tunnelstore.TimeLayout)),
Layout: cfapi.TimeLayout,
DefaultText: fmt.Sprintf("current time, %s", time.Now().Format(cfapi.TimeLayout)),
}
listIDFlag = &cli.StringFlag{
Name: "id",
@ -156,6 +156,12 @@ var (
Usage: `Overwrites existing DNS records with this hostname`,
EnvVars: []string{"TUNNEL_FORCE_PROVISIONING_DNS"},
}
createSecretFlag = &cli.StringFlag{
Name: "secret",
Aliases: []string{"s"},
Usage: "Base64 encoded secret to set for the tunnel. The decoded secret must be at least 32 bytes long. If not specified, a random 32-byte secret will be generated.",
EnvVars: []string{"TUNNEL_CREATE_SECRET"},
}
)
func buildCreateCommand() *cli.Command {
@ -170,7 +176,7 @@ func buildCreateCommand() *cli.Command {
For example, to create a tunnel named 'my-tunnel' run:
$ cloudflared tunnel create my-tunnel`,
Flags: []cli.Flag{outputFormatFlag, credentialsFileFlagCLIOnly},
Flags: []cli.Flag{outputFormatFlag, credentialsFileFlagCLIOnly, createSecretFlag},
CustomHelpTemplate: commandHelpTemplate(),
}
}
@ -196,7 +202,7 @@ func createCommand(c *cli.Context) error {
warningChecker := updater.StartWarningCheck(c)
defer warningChecker.LogWarningIfAny(sc.log)
_, err = sc.create(name, c.String(CredFileFlag))
_, err = sc.create(name, c.String(CredFileFlag), c.String(createSecretFlag.Name))
return errors.Wrap(err, "failed to create tunnel")
}
@ -254,7 +260,7 @@ func listCommand(c *cli.Context) error {
warningChecker := updater.StartWarningCheck(c)
defer warningChecker.LogWarningIfAny(sc.log)
filter := tunnelstore.NewFilter()
filter := cfapi.NewTunnelFilter()
if !c.Bool("show-deleted") {
filter.NoDeleted()
}
@ -277,6 +283,9 @@ func listCommand(c *cli.Context) error {
}
filter.ByTunnelID(tunnelID)
}
if maxFetch := c.Int("max-fetch-size"); maxFetch > 0 {
filter.MaxFetchSize(uint(maxFetch))
}
tunnels, err := sc.list(filter)
if err != nil {
@ -320,13 +329,13 @@ func listCommand(c *cli.Context) error {
if len(tunnels) > 0 {
formatAndPrintTunnelList(tunnels, c.Bool("show-recently-disconnected"))
} else {
fmt.Println("You have no tunnels, use 'cloudflared tunnel create' to define a new tunnel")
fmt.Println("No tunnels were found for the given filter flags. You can use 'cloudflared tunnel create' to create a tunnel.")
}
return nil
}
func formatAndPrintTunnelList(tunnels []*tunnelstore.Tunnel, showRecentlyDisconnected bool) {
func formatAndPrintTunnelList(tunnels []*cfapi.Tunnel, showRecentlyDisconnected bool) {
writer := tabWriter()
defer writer.Flush()
@ -348,7 +357,7 @@ func formatAndPrintTunnelList(tunnels []*tunnelstore.Tunnel, showRecentlyDisconn
}
}
func fmtConnections(connections []tunnelstore.Connection, showRecentlyDisconnected bool) string {
func fmtConnections(connections []cfapi.Connection, showRecentlyDisconnected bool) string {
// Count connections per colo
numConnsPerColo := make(map[string]uint, len(connections))
@ -468,8 +477,8 @@ func tunnelInfo(c *cli.Context) error {
return nil
}
func getTunnel(sc *subcommandContext, tunnelID uuid.UUID) (*tunnelstore.Tunnel, error) {
filter := tunnelstore.NewFilter()
func getTunnel(sc *subcommandContext, tunnelID uuid.UUID) (*cfapi.Tunnel, error) {
filter := cfapi.NewTunnelFilter()
filter.ByTunnelID(tunnelID)
tunnels, err := sc.list(filter)
if err != nil {
@ -702,7 +711,7 @@ Further information about managing Cloudflare WARP traffic to your tunnel is ava
{
Name: "dns",
Action: cliutil.ConfiguredAction(routeDnsCommand),
Usage: "Route a hostname by creating a DNS CNAME record to a tunnel",
Usage: "HostnameRoute a hostname by creating a DNS CNAME record to a tunnel",
UsageText: "cloudflared tunnel route dns [TUNNEL] [HOSTNAME]",
Description: `Creates a DNS CNAME record hostname that points to the tunnel.`,
Flags: []cli.Flag{overwriteDNSFlag},
@ -719,7 +728,7 @@ Further information about managing Cloudflare WARP traffic to your tunnel is ava
}
}
func dnsRouteFromArg(c *cli.Context, overwriteExisting bool) (tunnelstore.Route, error) {
func dnsRouteFromArg(c *cli.Context, overwriteExisting bool) (cfapi.HostnameRoute, error) {
const (
userHostnameIndex = 1
expectedNArgs = 2
@ -733,10 +742,10 @@ func dnsRouteFromArg(c *cli.Context, overwriteExisting bool) (tunnelstore.Route,
} else if !validateHostname(userHostname, true) {
return nil, errors.Errorf("%s is not a valid hostname", userHostname)
}
return tunnelstore.NewDNSRoute(userHostname, overwriteExisting), nil
return cfapi.NewDNSRoute(userHostname, overwriteExisting), nil
}
func lbRouteFromArg(c *cli.Context) (tunnelstore.Route, error) {
func lbRouteFromArg(c *cli.Context) (cfapi.HostnameRoute, error) {
const (
lbNameIndex = 1
lbPoolIndex = 2
@ -759,7 +768,7 @@ func lbRouteFromArg(c *cli.Context) (tunnelstore.Route, error) {
return nil, errors.Errorf("%s is not a valid pool name", lbPool)
}
return tunnelstore.NewLBRoute(lbName, lbPool), nil
return cfapi.NewLBRoute(lbName, lbPool), nil
}
var nameRegex = regexp.MustCompile("^[_a-zA-Z0-9][-_.a-zA-Z0-9]*$")
@ -806,7 +815,7 @@ func routeCommand(c *cli.Context, routeType string) error {
if err != nil {
return err
}
var route tunnelstore.Route
var route cfapi.HostnameRoute
switch routeType {
case "dns":
route, err = dnsRouteFromArg(c, c.Bool(overwriteDNSFlagName))

View File

@ -8,12 +8,12 @@ import (
homedir "github.com/mitchellh/go-homedir"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/tunnelstore"
"github.com/cloudflare/cloudflared/cfapi"
)
func Test_fmtConnections(t *testing.T) {
type args struct {
connections []tunnelstore.Connection
connections []cfapi.Connection
}
tests := []struct {
name string
@ -23,14 +23,14 @@ func Test_fmtConnections(t *testing.T) {
{
name: "empty",
args: args{
connections: []tunnelstore.Connection{},
connections: []cfapi.Connection{},
},
want: "",
},
{
name: "trivial",
args: args{
connections: []tunnelstore.Connection{
connections: []cfapi.Connection{
{
ColoName: "DFW",
ID: uuid.MustParse("ea550130-57fd-4463-aab1-752822231ddd"),
@ -42,7 +42,7 @@ func Test_fmtConnections(t *testing.T) {
{
name: "with a pending reconnect",
args: args{
connections: []tunnelstore.Connection{
connections: []cfapi.Connection{
{
ColoName: "DFW",
ID: uuid.MustParse("ea550130-57fd-4463-aab1-752822231ddd"),
@ -55,7 +55,7 @@ func Test_fmtConnections(t *testing.T) {
{
name: "many colos",
args: args{
connections: []tunnelstore.Connection{
connections: []cfapi.Connection{
{
ColoName: "YRV",
ID: uuid.MustParse("ea550130-57fd-4463-aab1-752822231ddd"),

View File

@ -6,35 +6,54 @@ import (
"os"
"text/tabwriter"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/teamnet"
)
"github.com/urfave/cli/v2"
var (
vnetFlag = &cli.StringFlag{
Name: "virtual-network",
Aliases: []string{"vn"},
Usage: "The ID or name of the virtual network to which the route is associated to.",
}
)
func buildRouteIPSubcommand() *cli.Command {
return &cli.Command{
Name: "ip",
Usage: "Configure and query Cloudflare WARP routing to services or private networks available through this tunnel.",
Usage: "Configure and query Cloudflare WARP routing to private IP networks made available through Cloudflare Tunnels.",
UsageText: "cloudflared tunnel [--config FILEPATH] route COMMAND [arguments...]",
Description: `cloudflared can provision private routes from any IP space to origins in your corporate network.
Users enrolled in your Cloudflare for Teams organization can reach those routes through the
Cloudflare WARP client. You can also build rules to determine who can reach certain routes.`,
Description: `cloudflared can provision routes for any IP space in your corporate network. Users enrolled in
your Cloudflare for Teams organization can reach those IPs through the Cloudflare WARP
client. You can then configure L7/L4 filtering on https://dash.teams.cloudflare.com to
determine who can reach certain routes.
By default IP routes all exist within a single virtual network. If you use the same IP
space(s) in different physical private networks, all meant to be reachable via IP routes,
then you have to manage the ambiguous IP routes by associating them to virtual networks.
See "cloudflared tunnel network --help" for more information.`,
Subcommands: []*cli.Command{
{
Name: "add",
Action: cliutil.ConfiguredAction(addRouteCommand),
Usage: "Add any new network to the routing table reachable via the tunnel",
UsageText: "cloudflared tunnel [--config FILEPATH] route ip add [CIDR] [TUNNEL] [COMMENT?]",
Description: `Adds any network route space (represented as a CIDR) to your routing table.
That network space becomes reachable for requests egressing from a user's machine
Usage: "Add a new network to the routing table reachable via a Tunnel",
UsageText: "cloudflared tunnel [--config FILEPATH] route ip add [flags] [CIDR] [TUNNEL] [COMMENT?]",
Description: `Adds a network IP route space (represented as a CIDR) to your routing table.
That network IP space becomes reachable for requests egressing from a user's machine
as long as it is using Cloudflare WARP client and is enrolled in the same account
that is running the tunnel chosen here. Further, those requests will be proxied to
the specified tunnel, and reach an IP in the given CIDR, as long as that IP is
reachable from the tunnel.`,
that is running the Tunnel chosen here. Further, those requests will be proxied to
the specified Tunnel, and reach an IP in the given CIDR, as long as that IP is
reachable from cloudflared.
If the CIDR exists in more than one private network, to be connected with Cloudflare
Tunnels, then you have to manage those IP routes with virtual networks (see
"cloudflared tunnel network --help)". In those cases, you then have to tell
which virtual network's routing table you want to add the route to with:
"cloudflared tunnel route ip add --virtual-network [ID/name] [CIDR] [TUNNEL]".`,
Flags: []cli.Flag{vnetFlag},
},
{
Name: "show",
@ -49,17 +68,22 @@ reachable from the tunnel.`,
Name: "delete",
Action: cliutil.ConfiguredAction(deleteRouteCommand),
Usage: "Delete a row from your organization's private routing table",
UsageText: "cloudflared tunnel [--config FILEPATH] route ip delete [CIDR]",
Description: `Deletes the row for a given CIDR from your routing table. That portion
of your network will no longer be reachable by the WARP clients.`,
UsageText: "cloudflared tunnel [--config FILEPATH] route ip delete [flags] [CIDR]",
Description: `Deletes the row for a given CIDR from your routing table. That portion of your network
will no longer be reachable by the WARP clients. Note that if you use virtual
networks, then you have to tell which virtual network whose routing table you
have a row deleted from.`,
Flags: []cli.Flag{vnetFlag},
},
{
Name: "get",
Action: cliutil.ConfiguredAction(getRouteByIPCommand),
Usage: "Check which row of the routing table matches a given IP.",
UsageText: "cloudflared tunnel [--config FILEPATH] route ip get [IP]",
Description: `Checks which row of the routing table will be used to proxy a given IP.
This helps check and validate your config.`,
UsageText: "cloudflared tunnel [--config FILEPATH] route ip get [flags] [IP]",
Description: `Checks which row of the routing table will be used to proxy a given IP. This helps check
and validate your config. Note that if you use virtual networks, then you have
to tell which virtual network whose routing table you want to use.`,
Flags: []cli.Flag{vnetFlag},
},
},
}
@ -67,7 +91,7 @@ of your network will no longer be reachable by the WARP clients.`,
func showRoutesFlags() []cli.Flag {
flags := make([]cli.Flag, 0)
flags = append(flags, teamnet.FilterFlags...)
flags = append(flags, cfapi.IpRouteFilterFlags...)
flags = append(flags, outputFormatFlag)
return flags
}
@ -78,7 +102,7 @@ func showRoutesCommand(c *cli.Context) error {
return err
}
filter, err := teamnet.NewFromCLI(c)
filter, err := cfapi.NewIpRouteFilterFromCLI(c)
if err != nil {
return errors.Wrap(err, "invalid config for routing filters")
}
@ -98,7 +122,7 @@ func showRoutesCommand(c *cli.Context) error {
if len(routes) > 0 {
formatAndPrintRouteList(routes)
} else {
fmt.Println("You have no routes, use 'cloudflared tunnel route ip add' to add a route")
fmt.Println("No routes were found for the given filter flags. You can use 'cloudflared tunnel route ip add' to add a route.")
}
return nil
@ -112,7 +136,9 @@ func addRouteCommand(c *cli.Context) error {
if c.NArg() < 2 {
return errors.New("You must supply at least 2 arguments, first the network you wish to route (in CIDR form e.g. 1.2.3.4/32) and then the tunnel ID to proxy with")
}
args := c.Args()
_, network, err := net.ParseCIDR(args.Get(0))
if err != nil {
return errors.Wrap(err, "Invalid network CIDR")
@ -120,19 +146,32 @@ func addRouteCommand(c *cli.Context) error {
if network == nil {
return errors.New("Invalid network CIDR")
}
tunnelRef := args.Get(1)
tunnelID, err := sc.findID(tunnelRef)
if err != nil {
return errors.Wrap(err, "Invalid tunnel")
}
comment := ""
if c.NArg() >= 3 {
comment = args.Get(2)
}
_, err = sc.addRoute(teamnet.NewRoute{
var vnetId *uuid.UUID
if c.IsSet(vnetFlag.Name) {
id, err := getVnetId(sc, c.String(vnetFlag.Name))
if err != nil {
return err
}
vnetId = &id
}
_, err = sc.addRoute(cfapi.NewRoute{
Comment: comment,
Network: *network,
TunnelID: tunnelID,
VNetID: vnetId,
})
if err != nil {
return errors.Wrap(err, "API error")
@ -146,9 +185,11 @@ func deleteRouteCommand(c *cli.Context) error {
if err != nil {
return err
}
if c.NArg() != 1 {
return errors.New("You must supply exactly one argument, the network whose route you want to delete (in CIDR form e.g. 1.2.3.4/32)")
}
_, network, err := net.ParseCIDR(c.Args().First())
if err != nil {
return errors.Wrap(err, "Invalid network CIDR")
@ -156,7 +197,20 @@ func deleteRouteCommand(c *cli.Context) error {
if network == nil {
return errors.New("Invalid network CIDR")
}
if err := sc.deleteRoute(*network); err != nil {
params := cfapi.DeleteRouteParams{
Network: *network,
}
if c.IsSet(vnetFlag.Name) {
vnetId, err := getVnetId(sc, c.String(vnetFlag.Name))
if err != nil {
return err
}
params.VNetID = &vnetId
}
if err := sc.deleteRoute(params); err != nil {
return errors.Wrap(err, "API error")
}
fmt.Printf("Successfully deleted route for %s\n", network)
@ -177,19 +231,32 @@ func getRouteByIPCommand(c *cli.Context) error {
if ip == nil {
return fmt.Errorf("Invalid IP %s", ipInput)
}
route, err := sc.getRouteByIP(ip)
params := cfapi.GetRouteByIpParams{
Ip: ip,
}
if c.IsSet(vnetFlag.Name) {
vnetId, err := getVnetId(sc, c.String(vnetFlag.Name))
if err != nil {
return err
}
params.VNetID = &vnetId
}
route, err := sc.getRouteByIP(params)
if err != nil {
return errors.Wrap(err, "API error")
}
if route.IsZero() {
fmt.Printf("No route matches the IP %s\n", ip)
} else {
formatAndPrintRouteList([]*teamnet.DetailedRoute{&route})
formatAndPrintRouteList([]*cfapi.DetailedRoute{&route})
}
return nil
}
func formatAndPrintRouteList(routes []*teamnet.DetailedRoute) {
func formatAndPrintRouteList(routes []*cfapi.DetailedRoute) {
const (
minWidth = 0
tabWidth = 8
@ -202,7 +269,7 @@ func formatAndPrintRouteList(routes []*teamnet.DetailedRoute) {
defer writer.Flush()
// Print column headers with tabbed columns
_, _ = fmt.Fprintln(writer, "NETWORK\tCOMMENT\tTUNNEL ID\tTUNNEL NAME\tCREATED\tDELETED\t")
_, _ = fmt.Fprintln(writer, "NETWORK\tVIRTUAL NET ID\tCOMMENT\tTUNNEL ID\tTUNNEL NAME\tCREATED\tDELETED\t")
// Loop through routes, create formatted string for each, and print using tabwriter
for _, route := range routes {

View File

@ -0,0 +1,285 @@
package tunnel
import (
"fmt"
"os"
"text/tabwriter"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/cfapi"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
)
var (
makeDefaultFlag = &cli.BoolFlag{
Name: "default",
Aliases: []string{"d"},
Usage: "The virtual network becomes the default one for the account. This means that all operations that " +
"omit a virtual network will now implicitly be using this virtual network (i.e., the default one) such " +
"as new IP routes that are created. When this flag is not set, the virtual network will not become the " +
"default one in the account.",
}
newNameFlag = &cli.StringFlag{
Name: "name",
Aliases: []string{"n"},
Usage: "The new name for the virtual network.",
}
newCommentFlag = &cli.StringFlag{
Name: "comment",
Aliases: []string{"c"},
Usage: "A new comment describing the purpose of the virtual network.",
}
)
func buildVirtualNetworkSubcommand(hidden bool) *cli.Command {
return &cli.Command{
Name: "network",
Usage: "Configure and query virtual networks to manage private IP routes with overlapping IPs.",
UsageText: "cloudflared tunnel [--config FILEPATH] network COMMAND [arguments...]",
Description: `cloudflared allows to manage IP routes that expose origins in your private network space via their IP directly
to clients outside (e.g. using WARP client) --- those are configurable via "cloudflared tunnel route ip" commands.
By default, all those IP routes live in the same virtual network. Managing virtual networks (e.g. by creating a
new one) becomes relevant when you have different private networks that have overlapping IPs. E.g.: if you have
a private network A running Tunnel 1, and private network B running Tunnel 2, it is possible that both Tunnels
expose the same IP space (say 10.0.0.0/8); to handle that, you have to add each IP Route (one that points to
Tunnel 1 and another that points to Tunnel 2) in different Virtual Networks. That way, if your clients are on
Virtual Network X, they will see Tunnel 1 (via Route A) and not see Tunnel 2 (since its Route B is associated
to another Virtual Network Y).`,
Hidden: hidden,
Subcommands: []*cli.Command{
{
Name: "add",
Action: cliutil.ConfiguredAction(addVirtualNetworkCommand),
Usage: "Add a new virtual network to which IP routes can be attached",
UsageText: "cloudflared tunnel [--config FILEPATH] network add [flags] NAME [\"comment\"]",
Description: `Adds a new virtual network. You can then attach IP routes to this virtual network with "cloudflared tunnel route ip"
commands. By doing so, such route(s) become segregated from route(s) in another virtual networks. Note that all
routes exist within some virtual network. If you do not specify any, then the system pre-creates a default virtual
network to which all routes belong. That is fine if you do not have overlapping IPs within different physical
private networks in your infrastructure exposed via Cloudflare Tunnel. Note: if a virtual network is added as
the new default, then the previous existing default virtual network will be automatically modified to no longer
be the current default.`,
Flags: []cli.Flag{makeDefaultFlag},
Hidden: hidden,
},
{
Name: "list",
Action: cliutil.ConfiguredAction(listVirtualNetworksCommand),
Usage: "Lists the virtual networks",
UsageText: "cloudflared tunnel [--config FILEPATH] network list [flags]",
Description: "Lists the virtual networks based on the given filter flags.",
Flags: listVirtualNetworksFlags(),
Hidden: hidden,
},
{
Name: "delete",
Action: cliutil.ConfiguredAction(deleteVirtualNetworkCommand),
Usage: "Delete a virtual network",
UsageText: "cloudflared tunnel [--config FILEPATH] network delete VIRTUAL_NETWORK",
Description: `Deletes the virtual network (given its ID or name). This is only possible if that virtual network is unused.
A virtual network may be used by IP routes or by WARP devices.`,
Hidden: hidden,
},
{
Name: "update",
Action: cliutil.ConfiguredAction(updateVirtualNetworkCommand),
Usage: "Update a virtual network",
UsageText: "cloudflared tunnel [--config FILEPATH] network update [flags] VIRTUAL_NETWORK",
Description: `Updates the virtual network (given its ID or name). If this virtual network is updated to become the new
default, then the previously existing default virtual network will also be modified to no longer be the default.
You cannot update a virtual network to not be the default anymore directly. Instead, you should create a new
default or update an existing one to become the default.`,
Flags: []cli.Flag{newNameFlag, newCommentFlag, makeDefaultFlag},
Hidden: hidden,
},
},
}
}
func listVirtualNetworksFlags() []cli.Flag {
flags := make([]cli.Flag, 0)
flags = append(flags, cfapi.VnetFilterFlags...)
flags = append(flags, outputFormatFlag)
return flags
}
func addVirtualNetworkCommand(c *cli.Context) error {
sc, err := newSubcommandContext(c)
if err != nil {
return err
}
if c.NArg() < 1 {
return errors.New("You must supply at least 1 argument, the name of the virtual network you wish to add.")
}
warningChecker := updater.StartWarningCheck(c)
defer warningChecker.LogWarningIfAny(sc.log)
args := c.Args()
name := args.Get(0)
comment := ""
if c.NArg() >= 2 {
comment = args.Get(1)
}
newVnet := cfapi.NewVirtualNetwork{
Name: name,
Comment: comment,
IsDefault: c.Bool(makeDefaultFlag.Name),
}
createdVnet, err := sc.addVirtualNetwork(newVnet)
if err != nil {
return errors.Wrap(err, "Could not add virtual network")
}
extraMsg := ""
if createdVnet.IsDefault {
extraMsg = " (as the new default for this account) "
}
fmt.Printf(
"Successfully added virtual 'network' %s with ID: %s%s\n"+
"You can now add IP routes attached to this virtual network. See `cloudflared tunnel route ip add -help`\n",
name, createdVnet.ID, extraMsg,
)
return nil
}
func listVirtualNetworksCommand(c *cli.Context) error {
sc, err := newSubcommandContext(c)
if err != nil {
return err
}
warningChecker := updater.StartWarningCheck(c)
defer warningChecker.LogWarningIfAny(sc.log)
filter, err := cfapi.NewFromCLI(c)
if err != nil {
return errors.Wrap(err, "invalid flags for filtering virtual networks")
}
vnets, err := sc.listVirtualNetworks(filter)
if err != nil {
return err
}
if outputFormat := c.String(outputFormatFlag.Name); outputFormat != "" {
return renderOutput(outputFormat, vnets)
}
if len(vnets) > 0 {
formatAndPrintVnetsList(vnets)
} else {
fmt.Println("No virtual networks were found for the given filter flags. You can use 'cloudflared tunnel network add' to add a virtual network.")
}
return nil
}
func deleteVirtualNetworkCommand(c *cli.Context) error {
sc, err := newSubcommandContext(c)
if err != nil {
return err
}
if c.NArg() != 1 {
return errors.New("You must supply exactly one argument, either the ID or name of the virtual network to delete")
}
input := c.Args().Get(0)
vnetId, err := getVnetId(sc, input)
if err != nil {
return err
}
if err := sc.deleteVirtualNetwork(vnetId); err != nil {
return errors.Wrap(err, "API error")
}
fmt.Printf("Successfully deleted virtual network '%s'\n", input)
return nil
}
func updateVirtualNetworkCommand(c *cli.Context) error {
sc, err := newSubcommandContext(c)
if err != nil {
return err
}
if c.NArg() != 1 {
return errors.New(" You must supply exactly one argument, either the ID or (current) name of the virtual network to update")
}
input := c.Args().Get(0)
vnetId, err := getVnetId(sc, input)
if err != nil {
return err
}
updates := cfapi.UpdateVirtualNetwork{}
if c.IsSet(newNameFlag.Name) {
newName := c.String(newNameFlag.Name)
updates.Name = &newName
}
if c.IsSet(newCommentFlag.Name) {
newComment := c.String(newCommentFlag.Name)
updates.Comment = &newComment
}
if c.IsSet(makeDefaultFlag.Name) {
isDefault := c.Bool(makeDefaultFlag.Name)
updates.IsDefault = &isDefault
}
if err := sc.updateVirtualNetwork(vnetId, updates); err != nil {
return errors.Wrap(err, "API error")
}
fmt.Printf("Successfully updated virtual network '%s'\n", input)
return nil
}
func getVnetId(sc *subcommandContext, input string) (uuid.UUID, error) {
val, err := uuid.Parse(input)
if err == nil {
return val, nil
}
filter := cfapi.NewVnetFilter()
filter.WithDeleted(false)
filter.ByName(input)
vnets, err := sc.listVirtualNetworks(filter)
if err != nil {
return uuid.Nil, err
}
if len(vnets) != 1 {
return uuid.Nil, fmt.Errorf("there should only be 1 non-deleted virtual network named %s", input)
}
return vnets[0].ID, nil
}
func formatAndPrintVnetsList(vnets []*cfapi.VirtualNetwork) {
const (
minWidth = 0
tabWidth = 8
padding = 1
padChar = ' '
flags = 0
)
writer := tabwriter.NewWriter(os.Stdout, minWidth, tabWidth, padding, padChar, flags)
defer writer.Flush()
_, _ = fmt.Fprintln(writer, "ID\tNAME\tIS DEFAULT\tCOMMENT\tCREATED\tDELETED\t")
for _, virtualNetwork := range vnets {
formattedStr := virtualNetwork.TableString()
_, _ = fmt.Fprintln(writer, formattedStr)
}
}

View File

@ -19,7 +19,7 @@ import (
const (
DefaultCheckUpdateFreq = time.Hour * 24
noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/"
noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/run-as-service"
noUpdateOnWindowsMessage = "cloudflared will not automatically update on Windows systems."
noUpdateManagedPackageMessage = "cloudflared will not automatically update if installed by a package manager."
isManagedInstallFile = ".installedFromPackageManager"

View File

@ -1,3 +1,4 @@
//go:build !windows
// +build !windows
package updater

View File

@ -47,7 +47,7 @@ type batchData struct {
}
// WorkersVersion implements the Version interface.
// It contains everything needed to preform a version upgrade
// It contains everything needed to perform a version upgrade
type WorkersVersion struct {
downloadURL string
checksum string

View File

@ -1,3 +1,4 @@
//go:build windows
// +build windows
package main
@ -25,8 +26,8 @@ import (
const (
windowsServiceName = "Cloudflared"
windowsServiceDescription = "Argo Tunnel agent"
windowsServiceUrl = "https://developers.cloudflare.com/argo-tunnel/reference/service/"
windowsServiceDescription = "Cloudflare Tunnel agent"
windowsServiceUrl = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/run-as-service#windows"
recoverActionDelay = time.Second * 20
failureCountResetPeriod = time.Hour * 24
@ -45,16 +46,16 @@ const (
func runApp(app *cli.App, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel Windows service",
Usage: "Manages the Cloudflare Tunnel Windows service",
Subcommands: []*cli.Command{
{
Name: "install",
Usage: "Install Argo Tunnel as a Windows service",
Usage: "Install Cloudflare Tunnel as a Windows service",
Action: cliutil.ConfiguredAction(installWindowsService),
},
{
Name: "uninstall",
Usage: "Uninstall the Argo Tunnel service",
Usage: "Uninstall the Cloudflare Tunnel service",
Action: cliutil.ConfiguredAction(uninstallWindowsService),
},
},
@ -176,7 +177,7 @@ func (s *windowsService) Execute(serviceArgs []string, r <-chan svc.ChangeReques
func installWindowsService(c *cli.Context) error {
zeroLogger := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog)
zeroLogger.Info().Msg("Installing Argo Tunnel Windows service")
zeroLogger.Info().Msg("Installing Cloudflare Tunnel Windows service")
exepath, err := os.Executable()
if err != nil {
return errors.Wrap(err, "Cannot find path name that start the process")
@ -198,7 +199,7 @@ func installWindowsService(c *cli.Context) error {
return errors.Wrap(err, "Cannot install service")
}
defer s.Close()
log.Info().Msg("Argo Tunnel agent service is installed")
log.Info().Msg("Cloudflare Tunnel agent service is installed")
err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info)
if err != nil {
s.Delete()
@ -217,7 +218,7 @@ func uninstallWindowsService(c *cli.Context) error {
With().
Str(LogFieldWindowsServiceName, windowsServiceName).Logger()
log.Info().Msg("Uninstalling Argo Tunnel Windows Service")
log.Info().Msg("Uninstalling Cloudflare Tunnel Windows Service")
m, err := mgr.Connect()
if err != nil {
return errors.Wrap(err, "Cannot establish a connection to the service control manager")
@ -232,7 +233,7 @@ func uninstallWindowsService(c *cli.Context) error {
if err != nil {
return errors.Wrap(err, "Cannot delete service")
}
log.Info().Msg("Argo Tunnel agent service is uninstalled")
log.Info().Msg("Cloudflare Tunnel agent service is uninstalled")
err = eventlog.Remove(windowsServiceName)
if err != nil {
return errors.Wrap(err, "Cannot remove event logger")

View File

@ -3,3 +3,7 @@ MAX_RETRIES = 5
BACKOFF_SECS = 7
PROXY_DNS_PORT = 9053
def protocols():
return ["h2mux", "http2", "quic"]

View File

@ -75,7 +75,7 @@ class TestLogging:
}
config = component_tests_config(extra_config)
with start_cloudflared(tmp_path, config, new_process=True, capture_output=False):
wait_tunnel_ready(tunnel_url=config.get_url())
wait_tunnel_ready(tunnel_url=config.get_url(), cfd_logs=str(log_file))
assert_log_in_file(log_file)
assert_json_log(log_file)
@ -88,5 +88,5 @@ class TestLogging:
}
config = component_tests_config(extra_config)
with start_cloudflared(tmp_path, config, new_process=True, capture_output=False):
wait_tunnel_ready(tunnel_url=config.get_url())
wait_tunnel_ready(tunnel_url=config.get_url(), cfd_logs=str(log_dir))
assert_log_to_dir(config, log_dir)

View File

@ -7,6 +7,7 @@ import pytest
from flaky import flaky
from conftest import CfdModes
from constants import protocols
from util import start_cloudflared, wait_tunnel_ready, check_tunnel_not_connected
@ -18,9 +19,16 @@ class TestReconnect:
"stdin-control": True,
}
def _extra_config(self, protocol):
return {
"stdin-control": True,
"protocol": protocol,
}
@pytest.mark.skipif(platform.system() == "Windows", reason=f"Currently buggy on Windows TUN-4584")
def test_named_reconnect(self, tmp_path, component_tests_config):
config = component_tests_config(self.extra_config)
@pytest.mark.parametrize("protocol", protocols())
def test_named_reconnect(self, tmp_path, component_tests_config, protocol):
config = component_tests_config(self._extra_config(protocol))
with start_cloudflared(tmp_path, config, new_process=True, allow_input=True, capture_output=False) as cloudflared:
# Repeat the test multiple times because some issues only occur after multiple reconnects
self.assert_reconnect(config, cloudflared, 5)

View File

@ -8,6 +8,7 @@ import time
import pytest
import requests
from constants import protocols
from util import start_cloudflared, wait_tunnel_ready, check_tunnel_not_connected
@ -17,17 +18,21 @@ def supported_signals():
return [signal.SIGTERM, signal.SIGINT]
class TestTermination():
class TestTermination:
grace_period = 5
timeout = 10
extra_config = {
"grace-period": f"{grace_period}s",
}
sse_endpoint = "/sse?freq=1s"
def _extra_config(self, protocol):
return {
"grace-period": f"{self.grace_period}s",
"protocol": protocol,
}
@pytest.mark.parametrize("signal", supported_signals())
def test_graceful_shutdown(self, tmp_path, component_tests_config, signal):
config = component_tests_config(self.extra_config)
@pytest.mark.parametrize("protocol", protocols())
def test_graceful_shutdown(self, tmp_path, component_tests_config, signal, protocol):
config = component_tests_config(self._extra_config(protocol))
with start_cloudflared(
tmp_path, config, new_process=True, capture_output=False) as cloudflared:
wait_tunnel_ready(tunnel_url=config.get_url())
@ -47,8 +52,9 @@ class TestTermination():
# test cloudflared terminates before grace period expires when all eyeball
# connections are drained
@pytest.mark.parametrize("signal", supported_signals())
def test_shutdown_once_no_connection(self, tmp_path, component_tests_config, signal):
config = component_tests_config(self.extra_config)
@pytest.mark.parametrize("protocol", protocols())
def test_shutdown_once_no_connection(self, tmp_path, component_tests_config, signal, protocol):
config = component_tests_config(self._extra_config(protocol))
with start_cloudflared(
tmp_path, config, new_process=True, capture_output=False) as cloudflared:
wait_tunnel_ready(tunnel_url=config.get_url())
@ -66,8 +72,9 @@ class TestTermination():
self.wait_eyeball_thread(in_flight_req, self.grace_period)
@pytest.mark.parametrize("signal", supported_signals())
def test_no_connection_shutdown(self, tmp_path, component_tests_config, signal):
config = component_tests_config(self.extra_config)
@pytest.mark.parametrize("protocol", protocols())
def test_no_connection_shutdown(self, tmp_path, component_tests_config, signal, protocol):
config = component_tests_config(self._extra_config(protocol))
with start_cloudflared(
tmp_path, config, new_process=True, capture_output=False) as cloudflared:
wait_tunnel_ready(tunnel_url=config.get_url())

View File

@ -1,12 +1,13 @@
from contextlib import contextmanager
import logging
import requests
from retrying import retry
import os
import subprocess
import yaml
from contextlib import contextmanager
from time import sleep
import requests
import yaml
from retrying import retry
from constants import METRICS_PORT, MAX_RETRIES, BACKOFF_SECS
LOGGER = logging.getLogger(__name__)
@ -19,7 +20,8 @@ def write_config(directory, config):
return config_path
def start_cloudflared(directory, config, cfd_args=["run"], cfd_pre_args=["tunnel"], new_process=False, allow_input=False, capture_output=True, root=False):
def start_cloudflared(directory, config, cfd_args=["run"], cfd_pre_args=["tunnel"], new_process=False,
allow_input=False, capture_output=True, root=False):
config_path = write_config(directory, config.full_config)
cmd = cloudflared_cmd(config, config_path, cfd_args, cfd_pre_args, root)
if new_process:
@ -53,18 +55,42 @@ def run_cloudflared_background(cmd, allow_input, capture_output):
LOGGER.info(f"cloudflared log: {cfd.stderr.read()}")
def wait_tunnel_ready(tunnel_url=None, require_min_connections=1, cfd_logs=None):
try:
inner_wait_tunnel_ready(tunnel_url, require_min_connections)
except Exception as e:
if cfd_logs is not None:
_log_cloudflared_logs(cfd_logs)
raise e
@retry(stop_max_attempt_number=MAX_RETRIES, wait_fixed=BACKOFF_SECS * 1000)
def wait_tunnel_ready(tunnel_url=None, require_min_connections=1):
def inner_wait_tunnel_ready(tunnel_url=None, require_min_connections=1):
metrics_url = f'http://localhost:{METRICS_PORT}/ready'
with requests.Session() as s:
resp = send_request(s, metrics_url, True)
assert resp.json()[
"readyConnections"] >= require_min_connections, f"Ready endpoint returned {resp.json()} but we expect at least {require_min_connections} connections"
assert resp.json()["readyConnections"] >= require_min_connections, \
f"Ready endpoint returned {resp.json()} but we expect at least {require_min_connections} connections"
if tunnel_url is not None:
send_request(s, tunnel_url, True)
def _log_cloudflared_logs(cfd_logs):
log_file = cfd_logs
if os.path.isdir(cfd_logs):
files = os.listdir(cfd_logs)
if len(files) == 0:
return
log_file = os.path.join(cfd_logs, files[0])
with open(log_file, "r") as f:
LOGGER.warning("Cloudflared Tunnel was not ready:")
for line in f.readlines():
LOGGER.warning(line)
@retry(stop_max_attempt_number=MAX_RETRIES * BACKOFF_SECS, wait_fixed=1000)
def check_tunnel_not_connected():
url = f'http://localhost:{METRICS_PORT}/ready'

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"math"
"net/http"
"strconv"
"strings"
@ -19,6 +20,7 @@ const (
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
LogFieldConnIndex = "connIndex"
MaxGracePeriod = time.Minute * 3
MaxConcurrentStreams = math.MaxUint32
)
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))

View File

@ -29,7 +29,9 @@ type controlStream struct {
// ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown.
type ControlStreamHandler interface {
ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, shouldWaitForUnregister bool) error
// ServeControlStream handles the control plane of the transport in the current goroutine calling this
ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions) error
// IsStopped tells whether the method above has finished
IsStopped() bool
}
@ -61,7 +63,6 @@ func (c *controlStream) ServeControlStream(
ctx context.Context,
rw io.ReadWriteCloser,
connOptions *tunnelpogs.ConnectionOptions,
shouldWaitForUnregister bool,
) error {
rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log)
@ -71,12 +72,7 @@ func (c *controlStream) ServeControlStream(
}
c.connectedFuse.Connected()
if shouldWaitForUnregister {
c.waitForUnregister(ctx, rpcClient)
} else {
go c.waitForUnregister(ctx, rpcClient)
}
c.waitForUnregister(ctx, rpcClient)
return nil
}

128
connection/h2mux_header.go Normal file
View File

@ -0,0 +1,128 @@
package connection
import (
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/pkg/errors"
"github.com/cloudflare/cloudflared/h2mux"
)
// H2RequestHeadersToH1Request converts the HTTP/2 headers coming from origintunneld
// to an HTTP/1 Request object destined for the local origin web service.
// This operation includes conversion of the pseudo-headers into their closest
// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
for _, header := range h2 {
name := strings.ToLower(header.Name)
if !IsH2muxControlRequestHeader(name) {
continue
}
switch name {
case ":method":
h1.Method = header.Value
case ":scheme":
// noop - use the preexisting scheme from h1.URL
case ":authority":
// Otherwise the host header will be based on the origin URL
h1.Host = header.Value
case ":path":
// We don't want to be an "opinionated" proxy, so ideally we would use :path as-is.
// However, this HTTP/1 Request object belongs to the Go standard library,
// whose URL package makes some opinionated decisions about the encoding of
// URL characters: see the docs of https://godoc.org/net/url#URL,
// in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath,
// which is always used when computing url.URL.String(), whether we'd like it or not.
//
// Well, not *always*. We could circumvent this by using url.URL.Opaque. But
// that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque
// is treated differently when HTTP_PROXY is set!
// See https://github.com/golang/go/issues/5684#issuecomment-66080888
//
// This means we are subject to the behavior of net/url's function `shouldEscape`
// (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101
if header.Value == "*" {
h1.URL.Path = "*"
continue
}
// Due to the behavior of validation.ValidateUrl, h1.URL may
// already have a partial value, with or without a trailing slash.
base := h1.URL.String()
base = strings.TrimRight(base, "/")
// But we know :path begins with '/', because we handled '*' above - see RFC7540
requestURL, err := url.Parse(base + header.Value)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value))
}
h1.URL = requestURL
case "content-length":
contentLength, err := strconv.ParseInt(header.Value, 10, 64)
if err != nil {
return fmt.Errorf("unparseable content length")
}
h1.ContentLength = contentLength
case RequestUserHeaders:
// Do not forward the serialized headers to the origin -- deserialize them, and ditch the serialized version
// Find and parse user headers serialized into a single one
userHeaders, err := DeserializeHeaders(header.Value)
if err != nil {
return errors.Wrap(err, "Unable to parse user headers")
}
for _, userHeader := range userHeaders {
h1.Header.Add(userHeader.Name, userHeader.Value)
}
default:
// All other control headers shall just be proxied transparently
h1.Header.Add(header.Name, header.Value)
}
}
return nil
}
func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []h2mux.Header) {
h2 = []h2mux.Header{
{Name: ":status", Value: strconv.Itoa(status)},
}
userHeaders := make(http.Header, len(h1))
for header, values := range h1 {
h2name := strings.ToLower(header)
if h2name == "content-length" {
// This header has meaning in HTTP/2 and will be used by the edge,
// so it should be sent as an HTTP/2 response header.
// Since these are http2 headers, they're required to be lowercase
h2 = append(h2, h2mux.Header{Name: "content-length", Value: values[0]})
} else if !IsH2muxControlResponseHeader(h2name) || IsWebsocketClientHeader(h2name) {
// User headers, on the other hand, must all be serialized so that
// HTTP/2 header validation won't be applied to HTTP/1 header values
userHeaders[header] = values
}
}
// Perform user header serialization and set them in the single header
h2 = append(h2, h2mux.Header{Name: ResponseUserHeaders, Value: SerializeHeaders(userHeaders)})
return h2
}
// IsH2muxControlRequestHeader is called in the direction of eyeball -> origin.
func IsH2muxControlRequestHeader(headerName string) bool {
return headerName == "content-length" ||
headerName == "connection" || headerName == "upgrade" || // Websocket request headers
strings.HasPrefix(headerName, ":") ||
strings.HasPrefix(headerName, "cf-")
}
// IsH2muxControlResponseHeader is called in the direction of eyeball <- origin.
func IsH2muxControlResponseHeader(headerName string) bool {
return headerName == "content-length" ||
strings.HasPrefix(headerName, ":") ||
strings.HasPrefix(headerName, "cf-int-") ||
strings.HasPrefix(headerName, "cf-cloudflared-")
}

View File

@ -0,0 +1,642 @@
package connection
import (
"fmt"
"math/rand"
"net/http"
"net/url"
"reflect"
"regexp"
"strings"
"testing"
"testing/quick"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/h2mux"
)
type ByName []h2mux.Header
func (a ByName) Len() int { return len(a) }
func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByName) Less(i, j int) bool {
if a[i].Name == a[j].Name {
return a[i].Value < a[j].Value
}
return a[i].Name < a[j].Name
}
func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
mockHeaders := http.Header{
"Mock header 1": {"Mock value 1"},
"Mock header 2": {"Mock value 2"},
}
headersConversionErr := H2RequestHeadersToH1Request(createSerializedHeaders(RequestUserHeaders, mockHeaders), request)
assert.True(t, reflect.DeepEqual(mockHeaders, request.Header))
assert.NoError(t, headersConversionErr)
}
func createSerializedHeaders(headersField string, headers http.Header) []h2mux.Header {
return []h2mux.Header{{
Name: headersField,
Value: SerializeHeaders(headers),
}}
}
func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
emptyHeaders := make(http.Header)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{{
Name: RequestUserHeaders,
Value: SerializeHeaders(emptyHeaders),
}},
request,
)
assert.True(t, reflect.DeepEqual(emptyHeaders, request.Header))
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
mockRequestHeaders := []h2mux.Header{
{Name: ":path", Value: "//bad_path/"},
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
}
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
assert.Equal(t, http.Header{
"Mock header": []string{"Mock value"},
}, request.Header)
assert.Equal(t, "http://example.com//bad_path/", request.URL.String())
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
assert.NoError(t, err)
mockRequestHeaders := []h2mux.Header{
{Name: ":path", Value: "/?query=mock%20value"},
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
}
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
assert.Equal(t, http.Header{
"Mock header": []string{"Mock value"},
}, request.Header)
assert.Equal(t, "http://example.com/?query=mock%20value", request.URL.String())
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
assert.NoError(t, err)
mockRequestHeaders := []h2mux.Header{
{Name: ":path", Value: "/mock%20path"},
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
}
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
assert.Equal(t, http.Header{
"Mock header": []string{"Mock value"},
}, request.Header)
assert.Equal(t, "http://example.com/mock%20path", request.URL.String())
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) {
type testCase struct {
path string
want string
}
testCases := []testCase{
{
path: "",
want: "",
},
{
path: "/",
want: "/",
},
{
path: "//",
want: "//",
},
{
path: "/test",
want: "/test",
},
{
path: "//test",
want: "//test",
},
{
// https://github.com/cloudflare/cloudflared/issues/81
path: "//test/",
want: "//test/",
},
{
path: "/%2Ftest",
want: "/%2Ftest",
},
{
path: "//%20test",
want: "//%20test",
},
{
// https://github.com/cloudflare/cloudflared/issues/124
path: "/test?get=somthing%20a",
want: "/test?get=somthing%20a",
},
{
path: "/%20",
want: "/%20",
},
{
// stdlib's EscapedPath() will always percent-encode ' '
path: "/ ",
want: "/%20",
},
{
path: "/ a ",
want: "/%20a%20",
},
{
path: "/a%20b",
want: "/a%20b",
},
{
path: "/foo/bar;param?query#frag",
want: "/foo/bar;param?query#frag",
},
{
// stdlib's EscapedPath() will always percent-encode non-ASCII chars
path: "/a␠b",
want: "/a%E2%90%A0b",
},
{
path: "/a-umlaut-ä",
want: "/a-umlaut-%C3%A4",
},
{
path: "/a-umlaut-%C3%A4",
want: "/a-umlaut-%C3%A4",
},
{
path: "/a-umlaut-%c3%a4",
want: "/a-umlaut-%c3%a4",
},
{
// here the second '#' is treated as part of the fragment
path: "/a#b#c",
want: "/a#b%23c",
},
{
path: "/a#b␠c",
want: "/a#b%E2%90%A0c",
},
{
path: "/a#b%20c",
want: "/a#b%20c",
},
{
path: "/a#b c",
want: "/a#b%20c",
},
{
// stdlib's EscapedPath() will always percent-encode '\'
path: "/\\",
want: "/%5C",
},
{
path: "/a\\",
want: "/a%5C",
},
{
path: "/a,b.c.",
want: "/a,b.c.",
},
{
path: "/.",
want: "/.",
},
{
// stdlib's EscapedPath() will always percent-encode '`'
path: "/a`",
want: "/a%60",
},
{
path: "/a[0]",
want: "/a[0]",
},
{
path: "/?a[0]=5 &b[]=",
want: "/?a[0]=5 &b[]=",
},
{
path: "/?a=%22b%20%22",
want: "/?a=%22b%20%22",
},
}
for index, testCase := range testCases {
requestURL := "https://example.com"
request, err := http.NewRequest(http.MethodGet, requestURL, nil)
assert.NoError(t, err)
mockRequestHeaders := []h2mux.Header{
{Name: ":path", Value: testCase.path},
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
}
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
assert.NoError(t, headersConversionErr)
assert.Equal(t,
http.Header{
"Mock header": []string{"Mock value"},
},
request.Header)
assert.Equal(t,
"https://example.com"+testCase.want,
request.URL.String(),
"Failed URL index: %v %#v", index, testCase)
}
}
func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) {
config := &quick.Config{
Values: func(args []reflect.Value, rand *rand.Rand) {
args[0] = reflect.ValueOf(randomHTTP2Path(t, rand))
},
}
type testOrigin struct {
url string
expectedScheme string
expectedBasePath string
}
testOrigins := []testOrigin{
{
url: "http://origin.hostname.example.com:8080",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080",
},
{
url: "http://origin.hostname.example.com:8080/",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080",
},
{
url: "http://origin.hostname.example.com:8080/api",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080/api",
},
{
url: "http://origin.hostname.example.com:8080/api/",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080/api",
},
{
url: "https://origin.hostname.example.com:8080/api",
expectedScheme: "https",
expectedBasePath: "https://origin.hostname.example.com:8080/api",
},
}
// use multiple schemes to demonstrate that the URL is based on the
// origin's scheme, not the :scheme header
for _, testScheme := range []string{"http", "https"} {
for _, testOrigin := range testOrigins {
assertion := func(testPath string) bool {
const expectedMethod = "POST"
const expectedHostname = "request.hostname.example.com"
h2 := []h2mux.Header{
{Name: ":method", Value: expectedMethod},
{Name: ":scheme", Value: testScheme},
{Name: ":authority", Value: expectedHostname},
{Name: ":path", Value: testPath},
{Name: RequestUserHeaders, Value: ""},
}
h1, err := http.NewRequest("GET", testOrigin.url, nil)
require.NoError(t, err)
err = H2RequestHeadersToH1Request(h2, h1)
return assert.NoError(t, err) &&
assert.Equal(t, expectedMethod, h1.Method) &&
assert.Equal(t, expectedHostname, h1.Host) &&
assert.Equal(t, testOrigin.expectedScheme, h1.URL.Scheme) &&
assert.Equal(t, testOrigin.expectedBasePath+testPath, h1.URL.String())
}
err := quick.Check(assertion, config)
assert.NoError(t, err)
}
}
}
func randomASCIIPrintableChar(rand *rand.Rand) int {
// smallest printable ASCII char is 32, largest is 126
const startPrintable = 32
const endPrintable = 127
return startPrintable + rand.Intn(endPrintable-startPrintable)
}
// randomASCIIText generates an ASCII string, some of whose characters may be
// percent-encoded. Its "logical length" (ignoring percent-encoding) is
// between 1 and `maxLength`.
func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string {
length := minLength + rand.Intn(maxLength)
var result strings.Builder
for i := 0; i < length; i++ {
c := randomASCIIPrintableChar(rand)
// 1/4 chance of using percent encoding when not necessary
if c == '%' || rand.Intn(4) == 0 {
result.WriteString(fmt.Sprintf("%%%02X", c))
} else {
result.WriteByte(byte(c))
}
}
return result.String()
}
// Calls `randomASCIIText` and ensures the result is a valid URL path,
// i.e. one that can pass unchanged through url.URL.String()
func randomHTTP1Path(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
text := randomASCIIText(rand, minLength, maxLength)
re, err := regexp.Compile("[^/;,]*")
require.NoError(t, err)
return "/" + re.ReplaceAllStringFunc(text, url.PathEscape)
}
// Calls `randomASCIIText` and ensures the result is a valid URL query,
// i.e. one that can pass unchanged through url.URL.String()
func randomHTTP1Query(rand *rand.Rand, minLength int, maxLength int) string {
text := randomASCIIText(rand, minLength, maxLength)
return "?" + strings.ReplaceAll(text, "#", "%23")
}
// Calls `randomASCIIText` and ensures the result is a valid URL fragment,
// i.e. one that can pass unchanged through url.URL.String()
func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
text := randomASCIIText(rand, minLength, maxLength)
u, err := url.Parse("#" + text)
require.NoError(t, err)
return u.String()
}
// Assemble a random :path pseudoheader that is legal by Go stdlib standards
// (i.e. all characters will satisfy "net/url".shouldEscape for their respective locations)
func randomHTTP2Path(t *testing.T, rand *rand.Rand) string {
result := randomHTTP1Path(t, rand, 1, 64)
if rand.Intn(2) == 1 {
result += randomHTTP1Query(rand, 1, 32)
}
if rand.Intn(2) == 1 {
result += randomHTTP1Fragment(t, rand, 1, 16)
}
return result
}
func stdlibHeaderToH2muxHeader(headers http.Header) (h2muxHeaders []h2mux.Header) {
for name, values := range headers {
for _, value := range values {
h2muxHeaders = append(h2muxHeaders, h2mux.Header{Name: name, Value: value})
}
}
return h2muxHeaders
}
func TestParseRequestHeaders(t *testing.T) {
mockUserHeadersToSerialize := http.Header{
"Mock-Header-One": {"1", "1.5"},
"Mock-Header-Two": {"2"},
"Mock-Header-Three": {"3"},
}
mockHeaders := []h2mux.Header{
{Name: "One", Value: "1"}, // will be dropped
{Name: "Cf-Two", Value: "cf-value-1"},
{Name: "Cf-Two", Value: "cf-value-2"},
{Name: RequestUserHeaders, Value: SerializeHeaders(mockUserHeadersToSerialize)},
}
expectedHeaders := []h2mux.Header{
{Name: "Cf-Two", Value: "cf-value-1"},
{Name: "Cf-Two", Value: "cf-value-2"},
{Name: "Mock-Header-One", Value: "1"},
{Name: "Mock-Header-One", Value: "1.5"},
{Name: "Mock-Header-Two", Value: "2"},
{Name: "Mock-Header-Three", Value: "3"},
}
h1 := &http.Request{
Header: make(http.Header),
}
err := H2RequestHeadersToH1Request(mockHeaders, h1)
assert.NoError(t, err)
assert.ElementsMatch(t, expectedHeaders, stdlibHeaderToH2muxHeader(h1.Header))
}
func TestIsH2muxControlRequestHeader(t *testing.T) {
controlRequestHeaders := []string{
// Anything that begins with cf-
"cf-sample-header",
// Any http2 pseudoheader
":sample-pseudo-header",
// content-length is a special case, it has to be there
// for some requests to work (per the HTTP2 spec)
"content-length",
// Websocket request headers
"connection",
"upgrade",
}
for _, header := range controlRequestHeaders {
assert.True(t, IsH2muxControlRequestHeader(header))
}
}
func TestIsH2muxControlResponseHeader(t *testing.T) {
controlResponseHeaders := []string{
// Anything that begins with cf-int- or cf-cloudflared-
"cf-int-sample-header",
"cf-cloudflared-sample-header",
// Any http2 pseudoheader
":sample-pseudo-header",
// content-length is a special case, it has to be there
// for some requests to work (per the HTTP2 spec)
"content-length",
}
for _, header := range controlResponseHeaders {
assert.True(t, IsH2muxControlResponseHeader(header))
}
}
func TestIsNotH2muxControlRequestHeader(t *testing.T) {
notControlRequestHeaders := []string{
"mock-header",
"another-sample-header",
}
for _, header := range notControlRequestHeaders {
assert.False(t, IsH2muxControlRequestHeader(header))
}
}
func TestIsNotH2muxControlResponseHeader(t *testing.T) {
notControlResponseHeaders := []string{
"mock-header",
"another-sample-header",
"upgrade",
"connection",
"cf-whatever", // On the response path, we only want to filter cf-int- and cf-cloudflared-
}
for _, header := range notControlResponseHeaders {
assert.False(t, IsH2muxControlResponseHeader(header))
}
}
func TestH1ResponseToH2ResponseHeaders(t *testing.T) {
mockHeaders := http.Header{
"User-header-one": {""},
"User-header-two": {"1", "2"},
"cf-header": {"cf-value"},
"cf-int-header": {"cf-int-value"},
"cf-cloudflared-header": {"cf-cloudflared-value"},
"Content-Length": {"123"},
}
mockResponse := http.Response{
StatusCode: 200,
Header: mockHeaders,
}
headers := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header)
serializedHeadersIndex := -1
for i, header := range headers {
if header.Name == ResponseUserHeaders {
serializedHeadersIndex = i
break
}
}
assert.NotEqual(t, -1, serializedHeadersIndex)
actualControlHeaders := append(
headers[:serializedHeadersIndex],
headers[serializedHeadersIndex+1:]...,
)
expectedControlHeaders := []h2mux.Header{
{Name: ":status", Value: "200"},
{Name: "content-length", Value: "123"},
}
assert.ElementsMatch(t, expectedControlHeaders, actualControlHeaders)
actualUserHeaders, err := DeserializeHeaders(headers[serializedHeadersIndex].Value)
expectedUserHeaders := []h2mux.Header{
{Name: "User-header-one", Value: ""},
{Name: "User-header-two", Value: "1"},
{Name: "User-header-two", Value: "2"},
{Name: "cf-header", Value: "cf-value"},
}
assert.NoError(t, err)
assert.ElementsMatch(t, expectedUserHeaders, actualUserHeaders)
}
// The purpose of this test is to check that our code and the http.Header
// implementation don't throw validation errors about header size
func TestHeaderSize(t *testing.T) {
largeValue := randSeq(5 * 1024 * 1024) // 5Mb
largeHeaders := http.Header{
"User-header": {largeValue},
}
mockResponse := http.Response{
StatusCode: 200,
Header: largeHeaders,
}
serializedHeaders := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header)
request, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
assert.NoError(t, err)
for _, header := range serializedHeaders {
request.Header.Set(header.Name, header.Value)
}
for _, header := range serializedHeaders {
if header.Name != ResponseUserHeaders {
continue
}
deserializedHeaders, err := DeserializeHeaders(header.Value)
assert.NoError(t, err)
assert.Equal(t, largeValue, deserializedHeaders[0].Value)
}
}
func randSeq(n int) string {
randomizer := rand.New(rand.NewSource(17))
var letters = []rune(":;,+/=abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
b := make([]rune, n)
for i := range b {
b[i] = letters[randomizer.Intn(len(letters))]
}
return string(b)
}
func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) {
ser := "eC1mb3J3YXJkZWQtcHJvdG8:aHR0cHM;dXBncmFkZS1pbnNlY3VyZS1yZXF1ZXN0cw:MQ;YWNjZXB0LWxhbmd1YWdl:ZW4tVVMsZW47cT0wLjkscnU7cT0wLjg;YWNjZXB0LWVuY29kaW5n:Z3ppcA;eC1mb3J3YXJkZWQtZm9y:MTczLjI0NS42MC42;dXNlci1hZ2VudA:TW96aWxsYS81LjAgKE1hY2ludG9zaDsgSW50ZWwgTWFjIE9TIFggMTBfMTRfNikgQXBwbGVXZWJLaXQvNTM3LjM2IChLSFRNTCwgbGlrZSBHZWNrbykgQ2hyb21lLzg0LjAuNDE0Ny44OSBTYWZhcmkvNTM3LjM2;c2VjLWZldGNoLW1vZGU:bmF2aWdhdGU;Y2RuLWxvb3A:Y2xvdWRmbGFyZQ;c2VjLWZldGNoLWRlc3Q:ZG9jdW1lbnQ;c2VjLWZldGNoLXVzZXI:PzE;c2VjLWZldGNoLXNpdGU:bm9uZQ;Y29va2ll:X19jZmR1aWQ9ZGNkOWZjOGNjNWMxMzE0NTMyYTFkMjhlZDEyOWRhOTYwMTU2OTk1MTYzNDsgX19jZl9ibT1mYzY2MzMzYzAzZmM0MWFiZTZmOWEyYzI2ZDUwOTA0YzIxYzZhMTQ2LTE1OTU2MjIzNDEtMTgwMC1BZTVzS2pIU2NiWGVFM05mMUhrTlNQMG1tMHBLc2pQWkloVnM1Z2g1SkNHQkFhS1UxVDB2b003alBGN3FjMHVSR2NjZGcrWHdhL1EzbTJhQzdDVU4xZ2M9;YWNjZXB0:dGV4dC9odG1sLGFwcGxpY2F0aW9uL3hodG1sK3htbCxhcHBsaWNhdGlvbi94bWw7cT0wLjksaW1hZ2Uvd2VicCxpbWFnZS9hcG5nLCovKjtxPTAuOCxhcHBsaWNhdGlvbi9zaWduZWQtZXhjaGFuZ2U7dj1iMztxPTAuOQ"
h2, _ := DeserializeHeaders(ser)
h1 := make(http.Header)
for _, header := range h2 {
h1.Add(header.Name, header.Value)
}
h1.Add("Content-Length", "200")
h1.Add("Cf-Something", "Else")
h1.Add("Upgrade", "websocket")
h1resp := &http.Response{
StatusCode: 200,
Header: h1,
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = H1ResponseToH2ResponseHeaders(h1resp.StatusCode, h1resp.Header)
}
}

View File

@ -4,8 +4,6 @@ import (
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/pkg/errors"
@ -44,93 +42,9 @@ func mustInitRespMetaHeader(src string) string {
var headerEncoding = base64.RawStdEncoding
// note: all h2mux headers should be lower-case (http/2 style)
const ()
// H2RequestHeadersToH1Request converts the HTTP/2 headers coming from origintunneld
// to an HTTP/1 Request object destined for the local origin web service.
// This operation includes conversion of the pseudo-headers into their closest
// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
for _, header := range h2 {
name := strings.ToLower(header.Name)
if !IsControlRequestHeader(name) {
continue
}
switch name {
case ":method":
h1.Method = header.Value
case ":scheme":
// noop - use the preexisting scheme from h1.URL
case ":authority":
// Otherwise the host header will be based on the origin URL
h1.Host = header.Value
case ":path":
// We don't want to be an "opinionated" proxy, so ideally we would use :path as-is.
// However, this HTTP/1 Request object belongs to the Go standard library,
// whose URL package makes some opinionated decisions about the encoding of
// URL characters: see the docs of https://godoc.org/net/url#URL,
// in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath,
// which is always used when computing url.URL.String(), whether we'd like it or not.
//
// Well, not *always*. We could circumvent this by using url.URL.Opaque. But
// that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque
// is treated differently when HTTP_PROXY is set!
// See https://github.com/golang/go/issues/5684#issuecomment-66080888
//
// This means we are subject to the behavior of net/url's function `shouldEscape`
// (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101
if header.Value == "*" {
h1.URL.Path = "*"
continue
}
// Due to the behavior of validation.ValidateUrl, h1.URL may
// already have a partial value, with or without a trailing slash.
base := h1.URL.String()
base = strings.TrimRight(base, "/")
// But we know :path begins with '/', because we handled '*' above - see RFC7540
requestURL, err := url.Parse(base + header.Value)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value))
}
h1.URL = requestURL
case "content-length":
contentLength, err := strconv.ParseInt(header.Value, 10, 64)
if err != nil {
return fmt.Errorf("unparseable content length")
}
h1.ContentLength = contentLength
case RequestUserHeaders:
// Do not forward the serialized headers to the origin -- deserialize them, and ditch the serialized version
// Find and parse user headers serialized into a single one
userHeaders, err := DeserializeHeaders(header.Value)
if err != nil {
return errors.Wrap(err, "Unable to parse user headers")
}
for _, userHeader := range userHeaders {
h1.Header.Add(userHeader.Name, userHeader.Value)
}
default:
// All other control headers shall just be proxied transparently
h1.Header.Add(header.Name, header.Value)
}
}
return nil
}
func IsControlRequestHeader(headerName string) bool {
return headerName == "content-length" ||
headerName == "connection" || headerName == "upgrade" || // Websocket request headers
strings.HasPrefix(headerName, ":") ||
strings.HasPrefix(headerName, "cf-")
}
// IsControlResponseHeader is called in the direction of eyeball <- origin.
func IsControlResponseHeader(headerName string) bool {
return headerName == "content-length" ||
strings.HasPrefix(headerName, ":") ||
return strings.HasPrefix(headerName, ":") ||
strings.HasPrefix(headerName, "cf-int-") ||
strings.HasPrefix(headerName, "cf-cloudflared-")
}
@ -142,31 +56,6 @@ func IsWebsocketClientHeader(headerName string) bool {
headerName == "upgrade"
}
func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []h2mux.Header) {
h2 = []h2mux.Header{
{Name: ":status", Value: strconv.Itoa(status)},
}
userHeaders := make(http.Header, len(h1))
for header, values := range h1 {
h2name := strings.ToLower(header)
if h2name == "content-length" {
// This header has meaning in HTTP/2 and will be used by the edge,
// so it should be sent as an HTTP/2 response header.
// Since these are http2 headers, they're required to be lowercase
h2 = append(h2, h2mux.Header{Name: "content-length", Value: values[0]})
} else if !IsControlResponseHeader(h2name) || IsWebsocketClientHeader(h2name) {
// User headers, on the other hand, must all be serialized so that
// HTTP/2 header validation won't be applied to HTTP/1 header values
userHeaders[header] = values
}
}
// Perform user header serialization and set them in the single header
h2 = append(h2, h2mux.Header{Name: ResponseUserHeaders, Value: SerializeHeaders(userHeaders)})
return h2
}
// Serialize HTTP1.x headers by base64-encoding each header name and value,
// and then joining them in the format of [key:value;]
func SerializeHeaders(h1Headers http.Header) string {

View File

@ -2,441 +2,14 @@ package connection
import (
"fmt"
"math/rand"
"net/http"
"net/url"
"reflect"
"regexp"
"sort"
"strings"
"testing"
"testing/quick"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/h2mux"
)
type ByName []h2mux.Header
func (a ByName) Len() int { return len(a) }
func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByName) Less(i, j int) bool {
if a[i].Name == a[j].Name {
return a[i].Value < a[j].Value
}
return a[i].Name < a[j].Name
}
func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
mockHeaders := http.Header{
"Mock header 1": {"Mock value 1"},
"Mock header 2": {"Mock value 2"},
}
headersConversionErr := H2RequestHeadersToH1Request(createSerializedHeaders(RequestUserHeaders, mockHeaders), request)
assert.True(t, reflect.DeepEqual(mockHeaders, request.Header))
assert.NoError(t, headersConversionErr)
}
func createSerializedHeaders(headersField string, headers http.Header) []h2mux.Header {
return []h2mux.Header{{
Name: headersField,
Value: SerializeHeaders(headers),
}}
}
func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
emptyHeaders := make(http.Header)
headersConversionErr := H2RequestHeadersToH1Request(
[]h2mux.Header{{
Name: RequestUserHeaders,
Value: SerializeHeaders(emptyHeaders),
}},
request,
)
assert.True(t, reflect.DeepEqual(emptyHeaders, request.Header))
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
mockRequestHeaders := []h2mux.Header{
{Name: ":path", Value: "//bad_path/"},
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
}
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
assert.Equal(t, http.Header{
"Mock header": []string{"Mock value"},
}, request.Header)
assert.Equal(t, "http://example.com//bad_path/", request.URL.String())
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
assert.NoError(t, err)
mockRequestHeaders := []h2mux.Header{
{Name: ":path", Value: "/?query=mock%20value"},
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
}
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
assert.Equal(t, http.Header{
"Mock header": []string{"Mock value"},
}, request.Header)
assert.Equal(t, "http://example.com/?query=mock%20value", request.URL.String())
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
assert.NoError(t, err)
mockRequestHeaders := []h2mux.Header{
{Name: ":path", Value: "/mock%20path"},
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
}
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
assert.Equal(t, http.Header{
"Mock header": []string{"Mock value"},
}, request.Header)
assert.Equal(t, "http://example.com/mock%20path", request.URL.String())
assert.NoError(t, headersConversionErr)
}
func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) {
type testCase struct {
path string
want string
}
testCases := []testCase{
{
path: "",
want: "",
},
{
path: "/",
want: "/",
},
{
path: "//",
want: "//",
},
{
path: "/test",
want: "/test",
},
{
path: "//test",
want: "//test",
},
{
// https://github.com/cloudflare/cloudflared/issues/81
path: "//test/",
want: "//test/",
},
{
path: "/%2Ftest",
want: "/%2Ftest",
},
{
path: "//%20test",
want: "//%20test",
},
{
// https://github.com/cloudflare/cloudflared/issues/124
path: "/test?get=somthing%20a",
want: "/test?get=somthing%20a",
},
{
path: "/%20",
want: "/%20",
},
{
// stdlib's EscapedPath() will always percent-encode ' '
path: "/ ",
want: "/%20",
},
{
path: "/ a ",
want: "/%20a%20",
},
{
path: "/a%20b",
want: "/a%20b",
},
{
path: "/foo/bar;param?query#frag",
want: "/foo/bar;param?query#frag",
},
{
// stdlib's EscapedPath() will always percent-encode non-ASCII chars
path: "/a␠b",
want: "/a%E2%90%A0b",
},
{
path: "/a-umlaut-ä",
want: "/a-umlaut-%C3%A4",
},
{
path: "/a-umlaut-%C3%A4",
want: "/a-umlaut-%C3%A4",
},
{
path: "/a-umlaut-%c3%a4",
want: "/a-umlaut-%c3%a4",
},
{
// here the second '#' is treated as part of the fragment
path: "/a#b#c",
want: "/a#b%23c",
},
{
path: "/a#b␠c",
want: "/a#b%E2%90%A0c",
},
{
path: "/a#b%20c",
want: "/a#b%20c",
},
{
path: "/a#b c",
want: "/a#b%20c",
},
{
// stdlib's EscapedPath() will always percent-encode '\'
path: "/\\",
want: "/%5C",
},
{
path: "/a\\",
want: "/a%5C",
},
{
path: "/a,b.c.",
want: "/a,b.c.",
},
{
path: "/.",
want: "/.",
},
{
// stdlib's EscapedPath() will always percent-encode '`'
path: "/a`",
want: "/a%60",
},
{
path: "/a[0]",
want: "/a[0]",
},
{
path: "/?a[0]=5 &b[]=",
want: "/?a[0]=5 &b[]=",
},
{
path: "/?a=%22b%20%22",
want: "/?a=%22b%20%22",
},
}
for index, testCase := range testCases {
requestURL := "https://example.com"
request, err := http.NewRequest(http.MethodGet, requestURL, nil)
assert.NoError(t, err)
mockRequestHeaders := []h2mux.Header{
{Name: ":path", Value: testCase.path},
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
}
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
assert.NoError(t, headersConversionErr)
assert.Equal(t,
http.Header{
"Mock header": []string{"Mock value"},
},
request.Header)
assert.Equal(t,
"https://example.com"+testCase.want,
request.URL.String(),
"Failed URL index: %v %#v", index, testCase)
}
}
func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) {
config := &quick.Config{
Values: func(args []reflect.Value, rand *rand.Rand) {
args[0] = reflect.ValueOf(randomHTTP2Path(t, rand))
},
}
type testOrigin struct {
url string
expectedScheme string
expectedBasePath string
}
testOrigins := []testOrigin{
{
url: "http://origin.hostname.example.com:8080",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080",
},
{
url: "http://origin.hostname.example.com:8080/",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080",
},
{
url: "http://origin.hostname.example.com:8080/api",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080/api",
},
{
url: "http://origin.hostname.example.com:8080/api/",
expectedScheme: "http",
expectedBasePath: "http://origin.hostname.example.com:8080/api",
},
{
url: "https://origin.hostname.example.com:8080/api",
expectedScheme: "https",
expectedBasePath: "https://origin.hostname.example.com:8080/api",
},
}
// use multiple schemes to demonstrate that the URL is based on the
// origin's scheme, not the :scheme header
for _, testScheme := range []string{"http", "https"} {
for _, testOrigin := range testOrigins {
assertion := func(testPath string) bool {
const expectedMethod = "POST"
const expectedHostname = "request.hostname.example.com"
h2 := []h2mux.Header{
{Name: ":method", Value: expectedMethod},
{Name: ":scheme", Value: testScheme},
{Name: ":authority", Value: expectedHostname},
{Name: ":path", Value: testPath},
{Name: RequestUserHeaders, Value: ""},
}
h1, err := http.NewRequest("GET", testOrigin.url, nil)
require.NoError(t, err)
err = H2RequestHeadersToH1Request(h2, h1)
return assert.NoError(t, err) &&
assert.Equal(t, expectedMethod, h1.Method) &&
assert.Equal(t, expectedHostname, h1.Host) &&
assert.Equal(t, testOrigin.expectedScheme, h1.URL.Scheme) &&
assert.Equal(t, testOrigin.expectedBasePath+testPath, h1.URL.String())
}
err := quick.Check(assertion, config)
assert.NoError(t, err)
}
}
}
func randomASCIIPrintableChar(rand *rand.Rand) int {
// smallest printable ASCII char is 32, largest is 126
const startPrintable = 32
const endPrintable = 127
return startPrintable + rand.Intn(endPrintable-startPrintable)
}
// randomASCIIText generates an ASCII string, some of whose characters may be
// percent-encoded. Its "logical length" (ignoring percent-encoding) is
// between 1 and `maxLength`.
func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string {
length := minLength + rand.Intn(maxLength)
var result strings.Builder
for i := 0; i < length; i++ {
c := randomASCIIPrintableChar(rand)
// 1/4 chance of using percent encoding when not necessary
if c == '%' || rand.Intn(4) == 0 {
result.WriteString(fmt.Sprintf("%%%02X", c))
} else {
result.WriteByte(byte(c))
}
}
return result.String()
}
// Calls `randomASCIIText` and ensures the result is a valid URL path,
// i.e. one that can pass unchanged through url.URL.String()
func randomHTTP1Path(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
text := randomASCIIText(rand, minLength, maxLength)
re, err := regexp.Compile("[^/;,]*")
require.NoError(t, err)
return "/" + re.ReplaceAllStringFunc(text, url.PathEscape)
}
// Calls `randomASCIIText` and ensures the result is a valid URL query,
// i.e. one that can pass unchanged through url.URL.String()
func randomHTTP1Query(rand *rand.Rand, minLength int, maxLength int) string {
text := randomASCIIText(rand, minLength, maxLength)
return "?" + strings.ReplaceAll(text, "#", "%23")
}
// Calls `randomASCIIText` and ensures the result is a valid URL fragment,
// i.e. one that can pass unchanged through url.URL.String()
func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
text := randomASCIIText(rand, minLength, maxLength)
u, err := url.Parse("#" + text)
require.NoError(t, err)
return u.String()
}
// Assemble a random :path pseudoheader that is legal by Go stdlib standards
// (i.e. all characters will satisfy "net/url".shouldEscape for their respective locations)
func randomHTTP2Path(t *testing.T, rand *rand.Rand) string {
result := randomHTTP1Path(t, rand, 1, 64)
if rand.Intn(2) == 1 {
result += randomHTTP1Query(rand, 1, 32)
}
if rand.Intn(2) == 1 {
result += randomHTTP1Fragment(t, rand, 1, 16)
}
return result
}
func stdlibHeaderToH2muxHeader(headers http.Header) (h2muxHeaders []h2mux.Header) {
for name, values := range headers {
for _, value := range values {
h2muxHeaders = append(h2muxHeaders, h2mux.Header{Name: name, Value: value})
}
}
return h2muxHeaders
}
func TestSerializeHeaders(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
assert.NoError(t, err)
@ -511,70 +84,13 @@ func TestDeserializeMalformed(t *testing.T) {
}
}
func TestParseRequestHeaders(t *testing.T) {
mockUserHeadersToSerialize := http.Header{
"Mock-Header-One": {"1", "1.5"},
"Mock-Header-Two": {"2"},
"Mock-Header-Three": {"3"},
}
mockHeaders := []h2mux.Header{
{Name: "One", Value: "1"}, // will be dropped
{Name: "Cf-Two", Value: "cf-value-1"},
{Name: "Cf-Two", Value: "cf-value-2"},
{Name: RequestUserHeaders, Value: SerializeHeaders(mockUserHeadersToSerialize)},
}
expectedHeaders := []h2mux.Header{
{Name: "Cf-Two", Value: "cf-value-1"},
{Name: "Cf-Two", Value: "cf-value-2"},
{Name: "Mock-Header-One", Value: "1"},
{Name: "Mock-Header-One", Value: "1.5"},
{Name: "Mock-Header-Two", Value: "2"},
{Name: "Mock-Header-Three", Value: "3"},
}
h1 := &http.Request{
Header: make(http.Header),
}
err := H2RequestHeadersToH1Request(mockHeaders, h1)
assert.NoError(t, err)
assert.ElementsMatch(t, expectedHeaders, stdlibHeaderToH2muxHeader(h1.Header))
}
func TestIsControlRequestHeader(t *testing.T) {
controlRequestHeaders := []string{
// Anything that begins with cf-
"cf-sample-header",
// Any http2 pseudoheader
":sample-pseudo-header",
// content-length is a special case, it has to be there
// for some requests to work (per the HTTP2 spec)
"content-length",
// Websocket request headers
"connection",
"upgrade",
}
for _, header := range controlRequestHeaders {
assert.True(t, IsControlRequestHeader(header))
}
}
func TestIsControlResponseHeader(t *testing.T) {
controlResponseHeaders := []string{
// Anything that begins with cf-int- or cf-cloudflared-
"cf-int-sample-header",
"cf-cloudflared-sample-header",
// Any http2 pseudoheader
":sample-pseudo-header",
// content-length is a special case, it has to be there
// for some requests to work (per the HTTP2 spec)
"content-length",
}
for _, header := range controlResponseHeaders {
@ -582,17 +98,6 @@ func TestIsControlResponseHeader(t *testing.T) {
}
}
func TestIsNotControlRequestHeader(t *testing.T) {
notControlRequestHeaders := []string{
"mock-header",
"another-sample-header",
}
for _, header := range notControlRequestHeaders {
assert.False(t, IsControlRequestHeader(header))
}
}
func TestIsNotControlResponseHeader(t *testing.T) {
notControlResponseHeaders := []string{
"mock-header",
@ -606,112 +111,3 @@ func TestIsNotControlResponseHeader(t *testing.T) {
assert.False(t, IsControlResponseHeader(header))
}
}
func TestH1ResponseToH2ResponseHeaders(t *testing.T) {
mockHeaders := http.Header{
"User-header-one": {""},
"User-header-two": {"1", "2"},
"cf-header": {"cf-value"},
"cf-int-header": {"cf-int-value"},
"cf-cloudflared-header": {"cf-cloudflared-value"},
"Content-Length": {"123"},
}
mockResponse := http.Response{
StatusCode: 200,
Header: mockHeaders,
}
headers := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header)
serializedHeadersIndex := -1
for i, header := range headers {
if header.Name == ResponseUserHeaders {
serializedHeadersIndex = i
break
}
}
assert.NotEqual(t, -1, serializedHeadersIndex)
actualControlHeaders := append(
headers[:serializedHeadersIndex],
headers[serializedHeadersIndex+1:]...,
)
expectedControlHeaders := []h2mux.Header{
{Name: ":status", Value: "200"},
{Name: "content-length", Value: "123"},
}
assert.ElementsMatch(t, expectedControlHeaders, actualControlHeaders)
actualUserHeaders, err := DeserializeHeaders(headers[serializedHeadersIndex].Value)
expectedUserHeaders := []h2mux.Header{
{Name: "User-header-one", Value: ""},
{Name: "User-header-two", Value: "1"},
{Name: "User-header-two", Value: "2"},
{Name: "cf-header", Value: "cf-value"},
}
assert.NoError(t, err)
assert.ElementsMatch(t, expectedUserHeaders, actualUserHeaders)
}
// The purpose of this test is to check that our code and the http.Header
// implementation don't throw validation errors about header size
func TestHeaderSize(t *testing.T) {
largeValue := randSeq(5 * 1024 * 1024) // 5Mb
largeHeaders := http.Header{
"User-header": {largeValue},
}
mockResponse := http.Response{
StatusCode: 200,
Header: largeHeaders,
}
serializedHeaders := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header)
request, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
assert.NoError(t, err)
for _, header := range serializedHeaders {
request.Header.Set(header.Name, header.Value)
}
for _, header := range serializedHeaders {
if header.Name != ResponseUserHeaders {
continue
}
deserializedHeaders, err := DeserializeHeaders(header.Value)
assert.NoError(t, err)
assert.Equal(t, largeValue, deserializedHeaders[0].Value)
}
}
func randSeq(n int) string {
randomizer := rand.New(rand.NewSource(17))
var letters = []rune(":;,+/=abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
b := make([]rune, n)
for i := range b {
b[i] = letters[randomizer.Intn(len(letters))]
}
return string(b)
}
func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) {
ser := "eC1mb3J3YXJkZWQtcHJvdG8:aHR0cHM;dXBncmFkZS1pbnNlY3VyZS1yZXF1ZXN0cw:MQ;YWNjZXB0LWxhbmd1YWdl:ZW4tVVMsZW47cT0wLjkscnU7cT0wLjg;YWNjZXB0LWVuY29kaW5n:Z3ppcA;eC1mb3J3YXJkZWQtZm9y:MTczLjI0NS42MC42;dXNlci1hZ2VudA:TW96aWxsYS81LjAgKE1hY2ludG9zaDsgSW50ZWwgTWFjIE9TIFggMTBfMTRfNikgQXBwbGVXZWJLaXQvNTM3LjM2IChLSFRNTCwgbGlrZSBHZWNrbykgQ2hyb21lLzg0LjAuNDE0Ny44OSBTYWZhcmkvNTM3LjM2;c2VjLWZldGNoLW1vZGU:bmF2aWdhdGU;Y2RuLWxvb3A:Y2xvdWRmbGFyZQ;c2VjLWZldGNoLWRlc3Q:ZG9jdW1lbnQ;c2VjLWZldGNoLXVzZXI:PzE;c2VjLWZldGNoLXNpdGU:bm9uZQ;Y29va2ll:X19jZmR1aWQ9ZGNkOWZjOGNjNWMxMzE0NTMyYTFkMjhlZDEyOWRhOTYwMTU2OTk1MTYzNDsgX19jZl9ibT1mYzY2MzMzYzAzZmM0MWFiZTZmOWEyYzI2ZDUwOTA0YzIxYzZhMTQ2LTE1OTU2MjIzNDEtMTgwMC1BZTVzS2pIU2NiWGVFM05mMUhrTlNQMG1tMHBLc2pQWkloVnM1Z2g1SkNHQkFhS1UxVDB2b003alBGN3FjMHVSR2NjZGcrWHdhL1EzbTJhQzdDVU4xZ2M9;YWNjZXB0:dGV4dC9odG1sLGFwcGxpY2F0aW9uL3hodG1sK3htbCxhcHBsaWNhdGlvbi94bWw7cT0wLjksaW1hZ2Uvd2VicCxpbWFnZS9hcG5nLCovKjtxPTAuOCxhcHBsaWNhdGlvbi9zaWduZWQtZXhjaGFuZ2U7dj1iMztxPTAuOQ"
h2, _ := DeserializeHeaders(ser)
h1 := make(http.Header)
for _, header := range h2 {
h1.Add(header.Name, header.Value)
}
h1.Add("Content-Length", "200")
h1.Add("Cf-Something", "Else")
h1.Add("Upgrade", "websocket")
h1resp := &http.Response{
StatusCode: 200,
Header: h1,
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = H1ResponseToH2ResponseHeaders(h1resp.StatusCode, h1resp.Header)
}
}

View File

@ -4,7 +4,6 @@ import (
"context"
"fmt"
"io"
"math"
"net"
"net/http"
"runtime/debug"
@ -60,7 +59,7 @@ func NewHTTP2Connection(
return &HTTP2Connection{
conn: conn,
server: &http2.Server{
MaxConcurrentStreams: math.MaxUint32,
MaxConcurrentStreams: MaxConcurrentStreams,
},
config: config,
connOptions: connOptions,
@ -109,7 +108,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch connType {
case TypeControlStream:
if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, true); err != nil {
if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions); err != nil {
c.controlStreamErr = err
c.log.Error().Err(err)
respWriter.WriteErrorResponse()
@ -126,7 +125,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case TypeTCP:
host, err := getRequestHost(r)
if err != nil {
err := fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %w`, err)
err := fmt.Errorf(`cloudflared received a warp-routing request with an empty host value: %w`, err)
c.log.Error().Err(err)
respWriter.WriteErrorResponse()
}
@ -185,12 +184,14 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro
for name, values := range header {
// Since these are http2 headers, they're required to be lowercase
h2name := strings.ToLower(name)
if h2name == "content-length" {
// This header has meaning in HTTP/2 and will be used by the edge,
// so it should be sent as an HTTP/2 response header.
// so it should be sent *also* as an HTTP/2 response header.
dest[name] = values
// Since these are http2 headers, they're required to be lowercase
} else if !IsControlResponseHeader(h2name) || IsWebsocketClientHeader(h2name) {
}
if !IsControlResponseHeader(h2name) || IsWebsocketClientHeader(h2name) {
// User headers, on the other hand, must all be serialized so that
// HTTP/2 header validation won't be applied to HTTP/1 header values
userHeaders[name] = values

View File

@ -7,24 +7,24 @@ import (
"time"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/edgediscovery"
)
const (
AvailableProtocolFlagMessage = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux"
AvailableProtocolFlagMessage = "Available protocols: 'auto' - automatically chooses the best protocol over time (the default; and also the recommended one); 'quic' - based on QUIC, relying on UDP egress to Cloudflare edge; 'http2' - using Go's HTTP2 library, relying on TCP egress to Cloudflare edge; 'h2mux' - Cloudflare's implementation of HTTP/2, deprecated"
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
edgeH2muxTLSServerName = "cftunnel.com"
// edgeH2TLSServerName is the server name to establish http2 connection with edge
edgeH2TLSServerName = "h2.cftunnel.com"
// edgeQUICServerName is the server name to establish quic connection with edge.
edgeQUICServerName = "quic.cftunnel.com"
// threshold to switch back to h2mux when the user intentionally pick --protocol http2
explicitHTTP2FallbackThreshold = -1
autoSelectFlag = "auto"
autoSelectFlag = "auto"
)
var (
// ProtocolList represents a list of supported protocols for communication with the edge.
ProtocolList = []Protocol{H2mux, HTTP2, QUIC}
ProtocolList = []Protocol{H2mux, HTTP2, HTTP2Warp, QUIC, QUICWarp}
)
type Protocol int64
@ -36,6 +36,12 @@ const (
HTTP2
// QUIC is used only with named tunnels.
QUIC
// HTTP2Warp is used only with named tunnels. It's useful for warp-routing where we don't want to fallback to
// H2mux on HTTP2 failure to connect.
HTTP2Warp
//QUICWarp is used only with named tunnels. It's useful for warp-routing where we want to fallback to HTTP2 but
// don't want HTTP2 to fallback to H2mux
QUICWarp
)
// Fallback returns the fallback protocol and whether the protocol has a fallback
@ -45,8 +51,12 @@ func (p Protocol) fallback() (Protocol, bool) {
return 0, false
case HTTP2:
return H2mux, true
case HTTP2Warp:
return 0, false
case QUIC:
return HTTP2, true
case QUICWarp:
return HTTP2Warp, true
default:
return 0, false
}
@ -56,9 +66,9 @@ func (p Protocol) String() string {
switch p {
case H2mux:
return "h2mux"
case HTTP2:
case HTTP2, HTTP2Warp:
return "http2"
case QUIC:
case QUIC, QUICWarp:
return "quic"
default:
return fmt.Sprintf("unknown protocol")
@ -71,11 +81,11 @@ func (p Protocol) TLSSettings() *TLSSettings {
return &TLSSettings{
ServerName: edgeH2muxTLSServerName,
}
case HTTP2:
case HTTP2, HTTP2Warp:
return &TLSSettings{
ServerName: edgeH2TLSServerName,
}
case QUIC:
case QUIC, QUICWarp:
return &TLSSettings{
ServerName: edgeQUICServerName,
NextProtos: []string{"argotunnel"},
@ -108,29 +118,36 @@ func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
}
type autoProtocolSelector struct {
lock sync.RWMutex
current Protocol
switchThrehold int32
fetchFunc PercentageFetcher
refreshAfter time.Time
ttl time.Duration
log *zerolog.Logger
lock sync.RWMutex
current Protocol
// protocolPool is desired protocols in the order of priority they should be picked in.
protocolPool []Protocol
switchThreshold int32
fetchFunc PercentageFetcher
refreshAfter time.Time
ttl time.Duration
log *zerolog.Logger
}
func newAutoProtocolSelector(
current Protocol,
switchThrehold int32,
protocolPool []Protocol,
switchThreshold int32,
fetchFunc PercentageFetcher,
ttl time.Duration,
log *zerolog.Logger,
) *autoProtocolSelector {
return &autoProtocolSelector{
current: current,
switchThrehold: switchThrehold,
fetchFunc: fetchFunc,
refreshAfter: time.Now().Add(ttl),
ttl: ttl,
log: log,
current: current,
protocolPool: protocolPool,
switchThreshold: switchThreshold,
fetchFunc: fetchFunc,
refreshAfter: time.Now().Add(ttl),
ttl: ttl,
log: log,
}
}
@ -141,28 +158,39 @@ func (s *autoProtocolSelector) Current() Protocol {
return s.current
}
percentage, err := s.fetchFunc()
protocol, err := getProtocol(s.protocolPool, s.fetchFunc, s.switchThreshold)
if err != nil {
s.log.Err(err).Msg("Failed to refresh protocol")
return s.current
}
s.current = protocol
if s.switchThrehold < percentage {
s.current = HTTP2
} else {
s.current = H2mux
}
s.refreshAfter = time.Now().Add(s.ttl)
return s.current
}
func getProtocol(protocolPool []Protocol, fetchFunc PercentageFetcher, switchThreshold int32) (Protocol, error) {
protocolPercentages, err := fetchFunc()
if err != nil {
return 0, err
}
for _, protocol := range protocolPool {
protocolPercentage := protocolPercentages.GetPercentage(protocol.String())
if protocolPercentage > switchThreshold {
return protocol, nil
}
}
return protocolPool[len(protocolPool)-1], nil
}
func (s *autoProtocolSelector) Fallback() (Protocol, bool) {
s.lock.RLock()
defer s.lock.RUnlock()
return s.current.fallback()
}
type PercentageFetcher func() (int32, error)
type PercentageFetcher func() (edgediscovery.ProtocolPercents, error)
func NewProtocolSelector(
protocolFlag string,
@ -179,54 +207,76 @@ func NewProtocolSelector(
}, nil
}
// warp routing cannot be served over h2mux connections
if warpRoutingEnabled {
if protocolFlag == H2mux.String() {
log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.")
}
if protocolFlag == QUIC.String() {
return &staticProtocolSelector{
current: QUIC,
}, nil
}
return &staticProtocolSelector{
current: HTTP2,
}, nil
}
if protocolFlag == H2mux.String() {
return &staticProtocolSelector{
current: H2mux,
}, nil
}
if protocolFlag == QUIC.String() {
return newAutoProtocolSelector(QUIC, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
}
http2Percentage, err := fetchFunc()
threshold := switchThreshold(namedTunnel.Credentials.AccountTag)
fetchedProtocol, err := getProtocol([]Protocol{QUIC, HTTP2}, fetchFunc, threshold)
if err != nil {
log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.")
return &staticProtocolSelector{
current: HTTP2,
}, nil
}
if protocolFlag == HTTP2.String() {
if http2Percentage < 0 {
return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
if warpRoutingEnabled {
if protocolFlag == H2mux.String() || fetchedProtocol == H2mux {
log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.")
protocolFlag = HTTP2.String()
fetchedProtocol = HTTP2Warp
}
return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
return selectWarpRoutingProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol)
}
if protocolFlag != autoSelectFlag {
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
return selectNamedTunnelProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol)
}
func selectNamedTunnelProtocols(
protocolFlag string,
fetchFunc PercentageFetcher,
ttl time.Duration,
log *zerolog.Logger,
threshold int32,
protocol Protocol,
) (ProtocolSelector, error) {
// If the user picks a protocol, then we stick to it no matter what.
switch protocolFlag {
case H2mux.String():
return &staticProtocolSelector{current: H2mux}, nil
case QUIC.String():
return &staticProtocolSelector{current: QUIC}, nil
case HTTP2.String():
return &staticProtocolSelector{current: HTTP2}, nil
}
threshold := switchThreshold(namedTunnel.Credentials.AccountTag)
if threshold < http2Percentage {
return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, log), nil
// If the user does not pick (hopefully the majority) then we use the one derived from the TXT DNS record and
// fallback on failures.
if protocolFlag == autoSelectFlag {
return newAutoProtocolSelector(protocol, []Protocol{QUIC, HTTP2, H2mux}, threshold, fetchFunc, ttl, log), nil
}
return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, log), nil
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
}
func selectWarpRoutingProtocols(
protocolFlag string,
fetchFunc PercentageFetcher,
ttl time.Duration,
log *zerolog.Logger,
threshold int32,
protocol Protocol,
) (ProtocolSelector, error) {
// If the user picks a protocol, then we stick to it no matter what.
switch protocolFlag {
case QUIC.String():
return &staticProtocolSelector{current: QUICWarp}, nil
case HTTP2.String():
return &staticProtocolSelector{current: HTTP2Warp}, nil
}
// If the user does not pick (hopefully the majority) then we use the one derived from the TXT DNS record and
// fallback on failures.
if protocolFlag == autoSelectFlag {
return newAutoProtocolSelector(protocol, []Protocol{QUICWarp, HTTP2Warp}, threshold, fetchFunc, ttl, log), nil
}
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
}
func switchThreshold(accountTag string) int32 {

View File

@ -6,6 +6,8 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/edgediscovery"
)
const (
@ -21,29 +23,23 @@ var (
}
)
func mockFetcher(percentage int32) PercentageFetcher {
return func() (int32, error) {
return percentage, nil
}
}
func mockFetcherWithError() PercentageFetcher {
return func() (int32, error) {
return 0, fmt.Errorf("failed to fetch precentage")
func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) PercentageFetcher {
return func() (edgediscovery.ProtocolPercents, error) {
if getError {
return nil, fmt.Errorf("failed to fetch percentage")
}
return protocolPercent, nil
}
}
type dynamicMockFetcher struct {
percentage int32
err error
protocolPercents edgediscovery.ProtocolPercents
err error
}
func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
return func() (int32, error) {
if dmf.err != nil {
return 0, dmf.err
}
return dmf.percentage, nil
return func() (edgediscovery.ProtocolPercents, error) {
return dmf.protocolPercents, dmf.err
}
}
@ -69,36 +65,59 @@ func TestNewProtocolSelector(t *testing.T) {
name: "named tunnel over h2mux",
protocol: "h2mux",
expectedProtocol: H2mux,
fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil },
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel over http2",
protocol: "http2",
expectedProtocol: HTTP2,
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(0),
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel http2 disabled",
name: "named tunnel http2 disabled still gets http2 because it is manually picked",
protocol: "http2",
expectedProtocol: HTTP2,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel quic disabled still gets quic because it is manually picked",
protocol: "quic",
expectedProtocol: QUIC,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel quic and http2 disabled",
protocol: "auto",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(-1),
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel quic disabled",
protocol: "auto",
expectedProtocol: HTTP2,
// Hasfallback true is because if http2 fails, then we further fallback to h2mux.
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto all http2 disabled",
protocol: "auto",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(-1),
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto to h2mux",
protocol: "auto",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(0),
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
@ -107,36 +126,71 @@ func TestNewProtocolSelector(t *testing.T) {
expectedProtocol: HTTP2,
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(100),
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto to quic",
protocol: "auto",
expectedProtocol: QUIC,
hasFallback: true,
expectedFallback: HTTP2,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing requesting h2mux",
protocol: "h2mux",
expectedProtocol: HTTP2,
expectedProtocol: HTTP2Warp,
hasFallback: false,
expectedFallback: H2mux,
fetchFunc: mockFetcher(100),
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1",
protocol: "h2mux",
expectedProtocol: HTTP2Warp,
hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing http2",
protocol: "http2",
expectedProtocol: HTTP2,
expectedProtocol: HTTP2Warp,
hasFallback: false,
expectedFallback: H2mux,
fetchFunc: mockFetcher(100),
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing quic",
protocol: "auto",
expectedProtocol: QUICWarp,
hasFallback: true,
expectedFallback: HTTP2Warp,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing auto",
protocol: "auto",
expectedProtocol: HTTP2,
expectedProtocol: HTTP2Warp,
hasFallback: false,
expectedFallback: H2mux,
fetchFunc: mockFetcher(100),
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing auto- quic",
protocol: "auto",
expectedProtocol: QUICWarp,
hasFallback: true,
expectedFallback: HTTP2Warp,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
@ -149,14 +203,14 @@ func TestNewProtocolSelector(t *testing.T) {
{
name: "named tunnel unknown protocol",
protocol: "unknown",
fetchFunc: mockFetcher(100),
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig,
wantErr: true,
},
{
name: "named tunnel fetch error",
protocol: "unknown",
fetchFunc: mockFetcherWithError(),
protocol: "auto",
fetchFunc: mockFetcher(true),
namedTunnelConfig: testNamedTunnelConfig,
expectedProtocol: HTTP2,
wantErr: false,
@ -164,18 +218,20 @@ func TestNewProtocolSelector(t *testing.T) {
}
for _, test := range tests {
selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
if test.wantErr {
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
} else {
assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
fallback, ok := selector.Fallback()
assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
if test.hasFallback {
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
t.Run(test.name, func(t *testing.T) {
selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
if test.wantErr {
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
} else {
assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
fallback, ok := selector.Fallback()
assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
if test.hasFallback {
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
}
}
}
})
}
}
@ -185,64 +241,66 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 100
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 0
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 100
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
assert.Equal(t, HTTP2, selector.Current())
fetcher.err = fmt.Errorf("failed to fetch")
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = -1
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
fetcher.err = nil
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 0
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
assert.Equal(t, QUIC, selector.Current())
}
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
fetcher := dynamicMockFetcher{}
// Since the user chooses http2 on purpose, we always stick to it.
selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log)
assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 100
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 0
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
assert.Equal(t, HTTP2, selector.Current())
fetcher.err = fmt.Errorf("failed to fetch")
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = -1
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
fetcher.err = nil
assert.Equal(t, H2mux, selector.Current())
fetcher.percentage = 0
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = 100
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
assert.Equal(t, HTTP2, selector.Current())
fetcher.percentage = -1
assert.Equal(t, H2mux, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
assert.Equal(t, HTTP2, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
assert.Equal(t, HTTP2, selector.Current())
}
func TestProtocolSelectorRefreshTTL(t *testing.T) {
fetcher := dynamicMockFetcher{percentage: 100}
fetcher := dynamicMockFetcher{}
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log)
assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current())
assert.Equal(t, QUIC, selector.Current())
fetcher.percentage = 0
assert.Equal(t, HTTP2, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 0}}
assert.Equal(t, QUIC, selector.Current())
}

View File

@ -9,11 +9,16 @@ import (
"net/http"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/datagramsession"
"github.com/cloudflare/cloudflared/ingress"
quicpogs "github.com/cloudflare/cloudflared/quic"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
@ -29,60 +34,101 @@ const (
// QUICConnection represents the type that facilitates Proxying via QUIC streams.
type QUICConnection struct {
session quic.Session
logger *zerolog.Logger
httpProxy OriginProxy
gracefulShutdownC <-chan struct{}
stoppedGracefully bool
session quic.Session
logger *zerolog.Logger
httpProxy OriginProxy
sessionManager datagramsession.Manager
controlStreamHandler ControlStreamHandler
connOptions *tunnelpogs.ConnectionOptions
}
// NewQUICConnection returns a new instance of QUICConnection.
func NewQUICConnection(
ctx context.Context,
quicConfig *quic.Config,
edgeAddr net.Addr,
tlsConfig *tls.Config,
httpProxy OriginProxy,
connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler ControlStreamHandler,
observer *Observer,
logger *zerolog.Logger,
) (*QUICConnection, error) {
session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig)
if err != nil {
return nil, errors.Wrap(err, "failed to dial to edge")
return nil, fmt.Errorf("failed to dial to edge: %w", err)
}
registrationStream, err := session.OpenStream()
datagramMuxer, err := quicpogs.NewDatagramMuxer(session)
if err != nil {
return nil, errors.Wrap(err, "failed to open a registration stream")
}
err = controlStreamHandler.ServeControlStream(ctx, registrationStream, connOptions, false)
if err != nil {
// Not wrapping error here to be consistent with the http2 message.
return nil, err
}
sessionManager := datagramsession.NewManager(datagramMuxer, logger)
return &QUICConnection{
session: session,
httpProxy: httpProxy,
logger: observer.log,
session: session,
httpProxy: httpProxy,
logger: logger,
sessionManager: sessionManager,
controlStreamHandler: controlStreamHandler,
connOptions: connOptions,
}, nil
}
// Serve starts a QUIC session that begins accepting streams.
func (q *QUICConnection) Serve(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// origintunneld assumes the first stream is used for the control plane
controlStream, err := q.session.OpenStream()
if err != nil {
return fmt.Errorf("failed to open a registration control stream: %w", err)
}
// If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits
// as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this
// connection).
// If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the
// other goroutine as fast as possible.
ctx, cancel := context.WithCancel(ctx)
errGroup, ctx := errgroup.WithContext(ctx)
// In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control
// stream is already fully registered before the other goroutines can proceed.
errGroup.Go(func() error {
defer cancel()
return q.serveControlStream(ctx, controlStream)
})
errGroup.Go(func() error {
defer cancel()
return q.acceptStream(ctx)
})
errGroup.Go(func() error {
defer cancel()
return q.sessionManager.Serve(ctx)
})
return errGroup.Wait()
}
func (q *QUICConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error {
// This blocks until the control plane is done.
err := q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions)
if err != nil {
// Not wrapping error here to be consistent with the http2 message.
return err
}
return nil
}
func (q *QUICConnection) acceptStream(ctx context.Context) error {
defer q.Close()
for {
stream, err := q.session.AcceptStream(ctx)
if err != nil {
// context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional.
if errors.Is(err, context.Canceled) {
if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() {
return nil
}
return errors.Wrap(err, "failed to accept QUIC stream")
return fmt.Errorf("failed to accept QUIC stream: %w", err)
}
go func() {
defer stream.Close()
@ -99,7 +145,30 @@ func (q *QUICConnection) Close() {
}
func (q *QUICConnection) handleStream(stream quic.Stream) error {
connectRequest, err := quicpogs.ReadConnectRequestData(stream)
signature, err := quicpogs.DetermineProtocol(stream)
if err != nil {
return err
}
switch signature {
case quicpogs.DataStreamProtocolSignature:
reqServerStream, err := quicpogs.NewRequestServerStream(stream, signature)
if err != nil {
return nil
}
return q.handleDataStream(reqServerStream)
case quicpogs.RPCStreamProtocolSignature:
rpcStream, err := quicpogs.NewRPCServerStream(stream, signature)
if err != nil {
return err
}
return q.handleRPCStream(rpcStream)
default:
return fmt.Errorf("unknown protocol %v", signature)
}
}
func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) error {
connectRequest, err := stream.ReadConnectRequestData()
if err != nil {
return err
}
@ -114,32 +183,99 @@ func (q *QUICConnection) handleStream(stream quic.Stream) error {
w := newHTTPResponseAdapter(stream)
return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket)
case quicpogs.ConnectionTypeTCP:
rwa := &streamReadWriteAcker{
ReadWriter: stream,
}
rwa := &streamReadWriteAcker{stream}
return q.httpProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest})
}
return nil
}
func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error {
return rpcStream.Serve(q, q.logger)
}
// RegisterUdpSession is the RPC method invoked by edge to register and run a session
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration) error {
// Each session is a series of datagram from an eyeball to a dstIP:dstPort.
// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
originProxy, err := ingress.DialUDP(dstIP, dstPort)
if err != nil {
q.logger.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
return err
}
session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy)
if err != nil {
q.logger.Err(err).Str("sessionID", sessionID.String()).Msgf("Failed to register udp session")
return err
}
go q.serveUDPSession(session, closeAfterIdleHint)
q.logger.Debug().Msgf("Registered session %v, %v, %v", sessionID, dstIP, dstPort)
return nil
}
func (q *QUICConnection) serveUDPSession(session *datagramsession.Session, closeAfterIdleHint time.Duration) {
ctx := q.session.Context()
closedByRemote, err := session.Serve(ctx, closeAfterIdleHint)
// If session is terminated by remote, then we know it has been unregistered from session manager and edge
if !closedByRemote {
if err != nil {
q.closeUDPSession(ctx, session.ID, err.Error())
} else {
q.closeUDPSession(ctx, session.ID, "terminated without error")
}
}
q.logger.Debug().Err(err).Str("sessionID", session.ID.String()).Msg("Session terminated")
}
// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge
func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) {
q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
stream, err := q.session.OpenStream()
if err != nil {
// Log this at debug because this is not an error if session was closed due to lost connection
// with edge
q.logger.Debug().Err(err).Str("sessionID", sessionID.String()).
Msgf("Failed to open quic stream to unregister udp session with edge")
return
}
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.logger)
if err != nil {
// Log this at debug because this is not an error if session was closed due to lost connection
// with edge
q.logger.Err(err).Str("sessionID", sessionID.String()).
Msgf("Failed to open rpc stream to unregister udp session with edge")
return
}
if err := rpcClientStream.UnregisterUdpSession(ctx, sessionID, message); err != nil {
q.logger.Err(err).Str("sessionID", sessionID.String()).
Msgf("Failed to unregister udp session with edge")
}
}
// UnregisterUdpSession is the RPC method invoked by edge to unregister and terminate a sesssion
func (q *QUICConnection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
return q.sessionManager.UnregisterSession(ctx, sessionID, message, true)
}
// streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to
// the client.
type streamReadWriteAcker struct {
io.ReadWriter
*quicpogs.RequestServerStream
}
// AckConnection acks response back to the proxy.
func (s *streamReadWriteAcker) AckConnection() error {
return quicpogs.WriteConnectResponseData(s, nil)
return s.WriteConnectResponseData(nil)
}
// httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC.
type httpResponseAdapter struct {
io.Writer
*quicpogs.RequestServerStream
}
func newHTTPResponseAdapter(w io.Writer) httpResponseAdapter {
return httpResponseAdapter{w}
func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter {
return httpResponseAdapter{s}
}
func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
@ -151,18 +287,19 @@ func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header)
metadata = append(metadata, quicpogs.Metadata{Key: httpHeaderKey, Val: v})
}
}
return quicpogs.WriteConnectResponseData(hrw, nil, metadata...)
return hrw.WriteConnectResponseData(nil, metadata...)
}
func (hrw httpResponseAdapter) WriteErrorResponse(err error) {
quicpogs.WriteConnectResponseData(hrw, err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)})
}
func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.Reader) (*http.Request, error) {
func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*http.Request, error) {
metadata := connectRequest.MetadataMap()
dest := connectRequest.Dest
method := metadata[HTTPMethodKey]
host := metadata[HTTPHostKey]
isWebsocket := connectRequest.Type == quicpogs.ConnectionTypeWebsocket
req, err := http.NewRequest(method, dest, body)
if err != nil {
@ -175,11 +312,42 @@ func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.Reader) (
// metadata.Key is off the format httpHeaderKey:<HTTPHeader>
httpHeaderKey := strings.Split(metadata.Key, ":")
if len(httpHeaderKey) != 2 {
return nil, fmt.Errorf("Header Key: %s malformed", metadata.Key)
return nil, fmt.Errorf("header Key: %s malformed", metadata.Key)
}
req.Header.Add(httpHeaderKey[1], metadata.Val)
}
}
// Go's http.Client automatically sends chunked request body if this value is not set on the
// *http.Request struct regardless of header:
// https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154.
if err := setContentLength(req); err != nil {
return nil, fmt.Errorf("Error setting content-length: %w", err)
}
// Go's client defaults to chunked encoding after a 200ms delay if the following cases are true:
// * the request body blocks
// * the content length is not set (or set to -1)
// * the method doesn't usually have a body (GET, HEAD, DELETE, ...)
// * there is no transfer-encoding=chunked already set.
// So, if transfer cannot be chunked and content length is 0, we dont set a request body.
if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 {
req.Body = nil
}
stripWebsocketUpgradeHeader(req)
return req, err
}
func setContentLength(req *http.Request) error {
var err error
if contentLengthStr := req.Header.Get("Content-Length"); contentLengthStr != "" {
req.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
}
return err
}
func isTransferEncodingChunked(req *http.Request) bool {
transferEncodingVal := req.Header.Get("Transfer-Encoding")
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding suggests that this can be a comma
// separated value as well.
return strings.Contains(strings.ToLower(transferEncodingVal), "chunked")
}

View File

@ -13,32 +13,36 @@ import (
"math/big"
"net"
"net/http"
"net/url"
"os"
"sync"
"testing"
"time"
"github.com/gobwas/ws/wsutil"
"github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/datagramsession"
quicpogs "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
var (
testTLSServerConfig = generateTLSConfig()
testQUICConfig = &quic.Config{
KeepAlive: true,
EnableDatagrams: true,
}
)
// TestQUICServer tests if a quic server accepts and responds to a quic client with the acceptance protocol.
// It also serves as a demonstration for communication with the QUIC connection started by a cloudflared.
func TestQUICServer(t *testing.T) {
quicConfig := &quic.Config{
KeepAlive: true,
}
// Setup test.
log := zerolog.New(os.Stdout)
// Start a UDP Listener for QUIC.
udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
require.NoError(t, err)
@ -46,18 +50,6 @@ func TestQUICServer(t *testing.T) {
require.NoError(t, err)
defer udpListener.Close()
// Create a simple tls config.
tlsConfig := generateTLSConfig()
// Create a client config
tlsClientConfig := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"argotunnel"},
}
// Start a mock httpProxy
originProxy := &mockOriginProxyWithRequest{}
// This is simply a sample websocket frame message.
wsBuf := &bytes.Buffer{}
wsutil.WriteClientText(wsBuf, []byte("Hello"))
@ -75,15 +67,15 @@ func TestQUICServer(t *testing.T) {
dest: "/ok",
connectionType: quicpogs.ConnectionTypeHTTP,
metadata: []quicpogs.Metadata{
quicpogs.Metadata{
{
Key: "HttpHeader:Cf-Ray",
Val: "123123123",
},
quicpogs.Metadata{
{
Key: "HttpHost",
Val: "cf.host",
},
quicpogs.Metadata{
{
Key: "HttpMethod",
Val: "GET",
},
@ -95,18 +87,22 @@ func TestQUICServer(t *testing.T) {
dest: "/echo_body",
connectionType: quicpogs.ConnectionTypeHTTP,
metadata: []quicpogs.Metadata{
quicpogs.Metadata{
{
Key: "HttpHeader:Cf-Ray",
Val: "123123123",
},
quicpogs.Metadata{
{
Key: "HttpHost",
Val: "cf.host",
},
quicpogs.Metadata{
{
Key: "HttpMethod",
Val: "POST",
},
{
Key: "HttpHeader:Content-Length",
Val: "24",
},
},
message: []byte("This is the message body"),
expectedResponse: []byte("This is the message body"),
@ -116,19 +112,19 @@ func TestQUICServer(t *testing.T) {
dest: "/ok",
connectionType: quicpogs.ConnectionTypeWebsocket,
metadata: []quicpogs.Metadata{
quicpogs.Metadata{
{
Key: "HttpHeader:Cf-Cloudflared-Proxy-Connection-Upgrade",
Val: "Websocket",
},
quicpogs.Metadata{
{
Key: "HttpHeader:Another-Header",
Val: "Misc",
},
quicpogs.Metadata{
{
Key: "HttpHost",
Val: "cf.host",
},
quicpogs.Metadata{
{
Key: "HttpMethod",
Val: "get",
},
@ -149,29 +145,17 @@ func TestQUICServer(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
wg.Add(1)
defer wg.Done()
quicServer(
t, udpListener, tlsConfig, quicConfig,
t, udpListener, testTLSServerConfig, testQUICConfig,
test.dest, test.connectionType, test.metadata, test.message, test.expectedResponse,
)
}()
controlStream := fakeControlStream{}
qC, err := NewQUICConnection(
ctx,
quicConfig,
udpListener.LocalAddr(),
tlsClientConfig,
originProxy,
&pogs.ConnectionOptions{},
controlStream,
NewObserver(&log, &log, false),
)
require.NoError(t, err)
go qC.Serve(ctx)
qc := testQUICConnection(udpListener.LocalAddr(), t)
go qc.Serve(ctx)
wg.Wait()
cancel()
@ -183,7 +167,8 @@ type fakeControlStream struct {
ControlStreamHandler
}
func (fakeControlStream) ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, shouldWaitForUnregister bool) error {
func (fakeControlStream) ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions) error {
<-ctx.Done()
return nil
}
func (fakeControlStream) IsStopped() bool {
@ -213,10 +198,11 @@ func quicServer(
stream, err := session.OpenStreamSync(context.Background())
require.NoError(t, err)
err = quicpogs.WriteConnectRequestData(stream, dest, connectionType, metadata...)
reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream}
err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...)
require.NoError(t, err)
_, err = quicpogs.ReadConnectResponseData(stream)
_, err = reqClientStream.ReadConnectResponseData()
require.NoError(t, err)
if message != nil {
@ -292,8 +278,390 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
return nil
}
func TestBuildHTTPRequest(t *testing.T) {
var tests = []struct {
name string
connectRequest *quicpogs.ConnectRequest
body io.ReadCloser
req *http.Request
}{
{
name: "check if http.Request is built correctly with content length",
connectRequest: &quicpogs.ConnectRequest{
Dest: "http://test.com",
Metadata: []quicpogs.Metadata{
{
Key: "HttpHeader:Cf-Cloudflared-Proxy-Connection-Upgrade",
Val: "Websocket",
},
{
Key: "HttpHeader:Content-Length",
Val: "514",
},
{
Key: "HttpHeader:Another-Header",
Val: "Misc",
},
{
Key: "HttpHost",
Val: "cf.host",
},
{
Key: "HttpMethod",
Val: "get",
},
},
},
req: &http.Request{
Method: "get",
URL: &url.URL{
Scheme: "http",
Host: "test.com",
},
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{
"Another-Header": []string{"Misc"},
"Content-Length": []string{"514"},
},
ContentLength: 514,
Host: "cf.host",
Body: io.NopCloser(&bytes.Buffer{}),
},
body: io.NopCloser(&bytes.Buffer{}),
},
{
name: "if content length isn't part of request headers, then it's not set",
connectRequest: &quicpogs.ConnectRequest{
Dest: "http://test.com",
Metadata: []quicpogs.Metadata{
{
Key: "HttpHeader:Cf-Cloudflared-Proxy-Connection-Upgrade",
Val: "Websocket",
},
{
Key: "HttpHeader:Another-Header",
Val: "Misc",
},
{
Key: "HttpHost",
Val: "cf.host",
},
{
Key: "HttpMethod",
Val: "get",
},
},
},
req: &http.Request{
Method: "get",
URL: &url.URL{
Scheme: "http",
Host: "test.com",
},
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{
"Another-Header": []string{"Misc"},
},
ContentLength: 0,
Host: "cf.host",
Body: nil,
},
body: io.NopCloser(&bytes.Buffer{}),
},
{
name: "if content length is 0, but transfer-encoding is chunked, body is not nil",
connectRequest: &quicpogs.ConnectRequest{
Dest: "http://test.com",
Metadata: []quicpogs.Metadata{
{
Key: "HttpHeader:Another-Header",
Val: "Misc",
},
{
Key: "HttpHeader:Transfer-Encoding",
Val: "chunked",
},
{
Key: "HttpHost",
Val: "cf.host",
},
{
Key: "HttpMethod",
Val: "get",
},
},
},
req: &http.Request{
Method: "get",
URL: &url.URL{
Scheme: "http",
Host: "test.com",
},
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{
"Another-Header": []string{"Misc"},
"Transfer-Encoding": []string{"chunked"},
},
ContentLength: 0,
Host: "cf.host",
Body: io.NopCloser(&bytes.Buffer{}),
},
body: io.NopCloser(&bytes.Buffer{}),
},
{
name: "if content length is 0, but transfer-encoding is gzip,chunked, body is not nil",
connectRequest: &quicpogs.ConnectRequest{
Dest: "http://test.com",
Metadata: []quicpogs.Metadata{
{
Key: "HttpHeader:Another-Header",
Val: "Misc",
},
{
Key: "HttpHeader:Transfer-Encoding",
Val: "gzip,chunked",
},
{
Key: "HttpHost",
Val: "cf.host",
},
{
Key: "HttpMethod",
Val: "get",
},
},
},
req: &http.Request{
Method: "get",
URL: &url.URL{
Scheme: "http",
Host: "test.com",
},
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{
"Another-Header": []string{"Misc"},
"Transfer-Encoding": []string{"gzip,chunked"},
},
ContentLength: 0,
Host: "cf.host",
Body: io.NopCloser(&bytes.Buffer{}),
},
body: io.NopCloser(&bytes.Buffer{}),
},
{
name: "if content length is 0, and connect request is a websocket, body is not nil",
connectRequest: &quicpogs.ConnectRequest{
Type: quicpogs.ConnectionTypeWebsocket,
Dest: "http://test.com",
Metadata: []quicpogs.Metadata{
{
Key: "HttpHeader:Another-Header",
Val: "Misc",
},
{
Key: "HttpHost",
Val: "cf.host",
},
{
Key: "HttpMethod",
Val: "get",
},
},
},
req: &http.Request{
Method: "get",
URL: &url.URL{
Scheme: "http",
Host: "test.com",
},
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{
"Another-Header": []string{"Misc"},
},
ContentLength: 0,
Host: "cf.host",
Body: io.NopCloser(&bytes.Buffer{}),
},
body: io.NopCloser(&bytes.Buffer{}),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req, err := buildHTTPRequest(test.connectRequest, test.body)
assert.NoError(t, err)
test.req = test.req.WithContext(req.Context())
assert.Equal(t, test.req, req)
})
}
}
func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
rwa.AckConnection()
io.Copy(rwa, rwa)
return nil
}
func TestServeUDPSession(t *testing.T) {
// Start a UDP Listener for QUIC.
udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0")
require.NoError(t, err)
udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr)
require.NoError(t, err)
defer udpListener.Close()
ctx, cancel := context.WithCancel(context.Background())
// Establish QUIC connection with edge
edgeQUICSessionChan := make(chan quic.Session)
go func() {
earlyListener, err := quic.Listen(udpListener, testTLSServerConfig, testQUICConfig)
require.NoError(t, err)
edgeQUICSession, err := earlyListener.Accept(ctx)
require.NoError(t, err)
edgeQUICSessionChan <- edgeQUICSession
}()
qc := testQUICConnection(udpListener.LocalAddr(), t)
go qc.Serve(ctx)
edgeQUICSession := <-edgeQUICSessionChan
serveSession(ctx, qc, edgeQUICSession, closedByOrigin, io.EOF.Error(), t)
serveSession(ctx, qc, edgeQUICSession, closedByTimeout, datagramsession.SessionIdleErr(time.Millisecond*50).Error(), t)
serveSession(ctx, qc, edgeQUICSession, closedByRemote, "eyeball closed connection", t)
cancel()
}
func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.Session, closeType closeReason, expectedReason string, t *testing.T) {
var (
payload = []byte(t.Name())
)
sessionID := uuid.New()
cfdConn, originConn := net.Pipe()
// Registers and run a new session
session, err := qc.sessionManager.RegisterSession(ctx, sessionID, cfdConn)
require.NoError(t, err)
sessionDone := make(chan struct{})
go func() {
qc.serveUDPSession(session, time.Millisecond*50)
close(sessionDone)
}()
// Send a message to the quic session on edge side, it should be deumx to this datagram session
muxedPayload := append(payload, sessionID[:]...)
err = edgeQUICSession.SendMessage(muxedPayload)
require.NoError(t, err)
readBuffer := make([]byte, len(payload)+1)
n, err := originConn.Read(readBuffer)
require.NoError(t, err)
require.Equal(t, len(payload), n)
require.True(t, bytes.Equal(payload, readBuffer[:n]))
// Close connection to terminate session
switch closeType {
case closedByOrigin:
originConn.Close()
case closedByRemote:
err = qc.UnregisterUdpSession(ctx, sessionID, expectedReason)
require.NoError(t, err)
case closedByTimeout:
}
if closeType != closedByRemote {
// Session was not closed by remote, so closeUDPSession should be invoked to unregister from remote
unregisterFromEdgeChan := make(chan struct{})
rpcServer := &mockSessionRPCServer{
sessionID: sessionID,
unregisterReason: expectedReason,
calledUnregisterChan: unregisterFromEdgeChan,
}
go runMockSessionRPCServer(ctx, edgeQUICSession, rpcServer, t)
<-unregisterFromEdgeChan
}
<-sessionDone
}
type closeReason uint8
const (
closedByOrigin closeReason = iota
closedByRemote
closedByTimeout
)
func runMockSessionRPCServer(ctx context.Context, session quic.Session, rpcServer *mockSessionRPCServer, t *testing.T) {
stream, err := session.AcceptStream(ctx)
require.NoError(t, err)
if stream.StreamID() == 0 {
// Skip the first stream, it's the control stream of the QUIC connection
stream, err = session.AcceptStream(ctx)
require.NoError(t, err)
}
protocol, err := quicpogs.DetermineProtocol(stream)
assert.NoError(t, err)
rpcServerStream, err := quicpogs.NewRPCServerStream(stream, protocol)
assert.NoError(t, err)
log := zerolog.New(os.Stdout)
err = rpcServerStream.Serve(rpcServer, &log)
assert.NoError(t, err)
}
type mockSessionRPCServer struct {
sessionID uuid.UUID
unregisterReason string
calledUnregisterChan chan struct{}
}
func (s mockSessionRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error {
return fmt.Errorf("mockSessionRPCServer doesn't implement RegisterUdpSession")
}
func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, reason string) error {
if s.sessionID != sessionID {
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
}
if s.unregisterReason != reason {
return fmt.Errorf("expect unregister reason %s, got %s", s.unregisterReason, reason)
}
close(s.calledUnregisterChan)
fmt.Println("unregister from edge")
return nil
}
func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection {
tlsClientConfig := &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"argotunnel"},
}
// Start a mock httpProxy
originProxy := &mockOriginProxyWithRequest{}
log := zerolog.New(os.Stdout)
qc, err := NewQUICConnection(
testQUICConfig,
udpListenerAddr,
tlsClientConfig,
originProxy,
&tunnelpogs.ConnectionOptions{},
fakeControlStream{},
&log,
)
require.NoError(t, err)
return qc
}

50
datagramsession/event.go Normal file
View File

@ -0,0 +1,50 @@
package datagramsession
import (
"fmt"
"io"
"github.com/google/uuid"
)
// registerSessionEvent is an event to start tracking a new session
type registerSessionEvent struct {
sessionID uuid.UUID
originProxy io.ReadWriteCloser
resultChan chan *Session
}
func newRegisterSessionEvent(sessionID uuid.UUID, originProxy io.ReadWriteCloser) *registerSessionEvent {
return &registerSessionEvent{
sessionID: sessionID,
originProxy: originProxy,
resultChan: make(chan *Session, 1),
}
}
// unregisterSessionEvent is an event to stop tracking and terminate the session.
type unregisterSessionEvent struct {
sessionID uuid.UUID
err *errClosedSession
}
// ClosedSessionError represent a condition that closes the session other than I/O
// I/O error is not included, because the side that closes the session is ambiguous.
type errClosedSession struct {
message string
byRemote bool
}
func (sc *errClosedSession) Error() string {
if sc.byRemote {
return fmt.Sprintf("session closed by remote due to %s", sc.message)
} else {
return fmt.Sprintf("session closed by local due to %s", sc.message)
}
}
// newDatagram is an event when transport receives new datagram
type newDatagram struct {
sessionID uuid.UUID
payload []byte
}

144
datagramsession/manager.go Normal file
View File

@ -0,0 +1,144 @@
package datagramsession
import (
"context"
"io"
"github.com/google/uuid"
"github.com/lucas-clemente/quic-go"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
)
const (
requestChanCapacity = 16
)
// Manager defines the APIs to manage sessions from the same transport.
type Manager interface {
// Serve starts the event loop
Serve(ctx context.Context) error
// RegisterSession starts tracking a session. Caller is responsible for starting the session
RegisterSession(ctx context.Context, sessionID uuid.UUID, dstConn io.ReadWriteCloser) (*Session, error)
// UnregisterSession stops tracking the session and terminates it
UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error
}
type manager struct {
registrationChan chan *registerSessionEvent
unregistrationChan chan *unregisterSessionEvent
datagramChan chan *newDatagram
transport transport
sessions map[uuid.UUID]*Session
log *zerolog.Logger
}
func NewManager(transport transport, log *zerolog.Logger) Manager {
return &manager{
registrationChan: make(chan *registerSessionEvent),
unregistrationChan: make(chan *unregisterSessionEvent),
// datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events
datagramChan: make(chan *newDatagram, requestChanCapacity),
transport: transport,
sessions: make(map[uuid.UUID]*Session),
log: log,
}
}
func (m *manager) Serve(ctx context.Context) error {
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
for {
sessionID, payload, err := m.transport.ReceiveFrom()
if err != nil {
if aerr, ok := err.(*quic.ApplicationError); ok && uint64(aerr.ErrorCode) == uint64(quic.NoError) {
return nil
} else {
return err
}
}
datagram := &newDatagram{
sessionID: sessionID,
payload: payload,
}
select {
case <-ctx.Done():
return ctx.Err()
// Only the event loop routine can update/lookup the sessions map to avoid concurrent access
// Send the datagram to the event loop. It will find the session to send to
case m.datagramChan <- datagram:
}
}
})
errGroup.Go(func() error {
for {
select {
case <-ctx.Done():
return nil
case datagram := <-m.datagramChan:
m.sendToSession(datagram)
case registration := <-m.registrationChan:
m.registerSession(ctx, registration)
// TODO: TUN-5422: Unregister inactive session upon timeout
case unregistration := <-m.unregistrationChan:
m.unregisterSession(unregistration)
}
}
})
return errGroup.Wait()
}
func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) {
event := newRegisterSessionEvent(sessionID, originProxy)
select {
case <-ctx.Done():
return nil, ctx.Err()
case m.registrationChan <- event:
session := <-event.resultChan
return session, nil
}
}
func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) {
session := newSession(registration.sessionID, m.transport, registration.originProxy, m.log)
m.sessions[registration.sessionID] = session
registration.resultChan <- session
}
func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error {
event := &unregisterSessionEvent{
sessionID: sessionID,
err: &errClosedSession{
message: message,
byRemote: byRemote,
},
}
select {
case <-ctx.Done():
return ctx.Err()
case m.unregistrationChan <- event:
return nil
}
}
func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) {
session, ok := m.sessions[unregistration.sessionID]
if ok {
delete(m.sessions, unregistration.sessionID)
session.close(unregistration.err)
}
}
func (m *manager) sendToSession(datagram *newDatagram) {
session, ok := m.sessions[datagram.sessionID]
if !ok {
m.log.Error().Str("sessionID", datagram.sessionID.String()).Msg("session not found")
return
}
// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
// need to run in another go routine
_, err := session.transportToDst(datagram.payload)
if err != nil {
m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session")
}
}

View File

@ -0,0 +1,222 @@
package datagramsession
import (
"bytes"
"context"
"fmt"
"io"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
func TestManagerServe(t *testing.T) {
const (
sessions = 20
msgs = 50
remoteUnregisterMsg = "eyeball closed connection"
)
log := zerolog.Nop()
transport := &mockQUICTransport{
reqChan: newDatagramChannel(1),
respChan: newDatagramChannel(1),
}
mg := NewManager(transport, &log)
eyeballTracker := make(map[uuid.UUID]*datagramChannel)
for i := 0; i < sessions; i++ {
sessionID := uuid.New()
eyeballTracker[sessionID] = newDatagramChannel(1)
}
ctx, cancel := context.WithCancel(context.Background())
serveDone := make(chan struct{})
go func(ctx context.Context) {
mg.Serve(ctx)
close(serveDone)
}(ctx)
go func(ctx context.Context) {
for {
sessionID, payload, err := transport.respChan.Receive(ctx)
if err != nil {
require.Equal(t, context.Canceled, err)
return
}
respChan := eyeballTracker[sessionID]
require.NoError(t, respChan.Send(ctx, sessionID, payload))
}
}(ctx)
errGroup, ctx := errgroup.WithContext(ctx)
for sID, receiver := range eyeballTracker {
// Assign loop variables to local variables
sessionID := sID
eyeballRespReceiver := receiver
errGroup.Go(func() error {
payload := testPayload(sessionID)
expectResp := testResponse(payload)
cfdConn, originConn := net.Pipe()
origin := mockOrigin{
expectMsgCount: msgs,
expectedMsg: payload,
expectedResp: expectResp,
conn: originConn,
}
eyeball := mockEyeball{
expectMsgCount: msgs,
expectedMsg: expectResp,
expectSessionID: sessionID,
respReceiver: eyeballRespReceiver,
}
reqErrGroup, reqCtx := errgroup.WithContext(ctx)
reqErrGroup.Go(func() error {
return origin.serve()
})
reqErrGroup.Go(func() error {
return eyeball.serve(reqCtx)
})
session, err := mg.RegisterSession(ctx, sessionID, cfdConn)
require.NoError(t, err)
sessionDone := make(chan struct{})
go func() {
closedByRemote, err := session.Serve(ctx, time.Minute*2)
closeSession := &errClosedSession{
message: remoteUnregisterMsg,
byRemote: true,
}
require.Equal(t, closeSession, err)
require.True(t, closedByRemote)
close(sessionDone)
}()
for i := 0; i < msgs; i++ {
require.NoError(t, transport.newRequest(ctx, sessionID, testPayload(sessionID)))
}
// Make sure eyeball and origin have received all messages before unregistering the session
require.NoError(t, reqErrGroup.Wait())
require.NoError(t, mg.UnregisterSession(ctx, sessionID, remoteUnregisterMsg, true))
<-sessionDone
return nil
})
}
require.NoError(t, errGroup.Wait())
cancel()
transport.close()
<-serveDone
}
type mockOrigin struct {
expectMsgCount int
expectedMsg []byte
expectedResp []byte
conn io.ReadWriteCloser
}
func (mo *mockOrigin) serve() error {
expectedMsgLen := len(mo.expectedMsg)
readBuffer := make([]byte, expectedMsgLen+1)
for i := 0; i < mo.expectMsgCount; i++ {
n, err := mo.conn.Read(readBuffer)
if err != nil {
return err
}
if n != expectedMsgLen {
return fmt.Errorf("Expect to read %d bytes, read %d", expectedMsgLen, n)
}
if !bytes.Equal(readBuffer[:n], mo.expectedMsg) {
return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n])
}
_, err = mo.conn.Write(mo.expectedResp)
if err != nil {
return err
}
}
return nil
}
func testPayload(sessionID uuid.UUID) []byte {
return []byte(fmt.Sprintf("Message from %s", sessionID))
}
func testResponse(msg []byte) []byte {
return []byte(fmt.Sprintf("Response to %v", msg))
}
type mockEyeball struct {
expectMsgCount int
expectedMsg []byte
expectSessionID uuid.UUID
respReceiver *datagramChannel
}
func (me *mockEyeball) serve(ctx context.Context) error {
for i := 0; i < me.expectMsgCount; i++ {
sessionID, msg, err := me.respReceiver.Receive(ctx)
if err != nil {
return err
}
if sessionID != me.expectSessionID {
return fmt.Errorf("Expect session %s, got %s", me.expectSessionID, sessionID)
}
if !bytes.Equal(msg, me.expectedMsg) {
return fmt.Errorf("Expect %v, read %v", me.expectedMsg, msg)
}
}
return nil
}
// datagramChannel is a channel for Datagram with wrapper to send/receive with context
type datagramChannel struct {
datagramChan chan *newDatagram
closedChan chan struct{}
}
func newDatagramChannel(capacity uint) *datagramChannel {
return &datagramChannel{
datagramChan: make(chan *newDatagram, capacity),
closedChan: make(chan struct{}),
}
}
func (rc *datagramChannel) Send(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-rc.closedChan:
return fmt.Errorf("datagram channel closed")
case rc.datagramChan <- &newDatagram{sessionID: sessionID, payload: payload}:
return nil
}
}
func (rc *datagramChannel) Receive(ctx context.Context) (uuid.UUID, []byte, error) {
select {
case <-ctx.Done():
return uuid.Nil, nil, ctx.Err()
case <-rc.closedChan:
return uuid.Nil, nil, fmt.Errorf("datagram channel closed")
case msg := <-rc.datagramChan:
return msg.sessionID, msg.payload, nil
}
}
func (rc *datagramChannel) Close() {
// No need to close msgChan, it will be garbage collect once there is no reference to it
close(rc.closedChan)
}

140
datagramsession/session.go Normal file
View File

@ -0,0 +1,140 @@
package datagramsession
import (
"context"
"fmt"
"io"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog"
)
const (
defaultCloseIdleAfter = time.Second * 210
)
func SessionIdleErr(timeout time.Duration) error {
return fmt.Errorf("session idle for %v", timeout)
}
// Session is a bidirectional pipe of datagrams between transport and dstConn
// Currently the only implementation of transport is quic DatagramMuxer
// Destination can be a connection with origin or with eyeball
// When the destination is origin:
// - Datagrams from edge are read by Manager from the transport. Manager finds the corresponding Session and calls the
// write method of the Session to send to origin
// - Datagrams from origin are read from conn and SentTo transport. Transport will return them to eyeball
// When the destination is eyeball:
// - Datagrams from eyeball are read from conn and SentTo transport. Transport will send them to cloudflared
// - Datagrams from cloudflared are read by Manager from the transport. Manager finds the corresponding Session and calls the
// write method of the Session to send to eyeball
type Session struct {
ID uuid.UUID
transport transport
dstConn io.ReadWriteCloser
// activeAtChan is used to communicate the last read/write time
activeAtChan chan time.Time
closeChan chan error
log *zerolog.Logger
}
func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser, log *zerolog.Logger) *Session {
return &Session{
ID: id,
transport: transport,
dstConn: dstConn,
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
// drop instead of blocking because last active time only needs to be an approximation
activeAtChan: make(chan time.Time, 2),
// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
closeChan: make(chan error, 2),
log: log,
}
}
func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (closedByRemote bool, err error) {
go func() {
// QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
// This makes it safe to share readBuffer between iterations
const maxPacketSize = 1500
readBuffer := make([]byte, maxPacketSize)
for {
if err := s.dstToTransport(readBuffer); err != nil {
s.closeChan <- err
return
}
}
}()
err = s.waitForCloseCondition(ctx, closeAfterIdle)
if closeSession, ok := err.(*errClosedSession); ok {
closedByRemote = closeSession.byRemote
}
return closedByRemote, err
}
func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error {
// Closing dstConn cancels read so dstToTransport routine in Serve() can return
defer s.dstConn.Close()
if closeAfterIdle == 0 {
// provide deafult is caller doesn't specify one
closeAfterIdle = defaultCloseIdleAfter
}
checkIdleFreq := closeAfterIdle / 8
checkIdleTicker := time.NewTicker(checkIdleFreq)
defer checkIdleTicker.Stop()
activeAt := time.Now()
for {
select {
case <-ctx.Done():
return ctx.Err()
case reason := <-s.closeChan:
return reason
// TODO: TUN-5423 evaluate if using atomic is more efficient
case now := <-checkIdleTicker.C:
// The session is considered inactive if current time is after (last active time + allowed idle time)
if now.After(activeAt.Add(closeAfterIdle)) {
return SessionIdleErr(closeAfterIdle)
}
case activeAt = <-s.activeAtChan: // Update last active time
}
}
}
func (s *Session) dstToTransport(buffer []byte) error {
n, err := s.dstConn.Read(buffer)
s.markActive()
if n > 0 {
if n <= int(s.transport.MTU()) {
err = s.transport.SendTo(s.ID, buffer[:n])
} else {
// drop packet for now, eventually reply with ICMP for PMTUD
s.log.Debug().
Str("session", s.ID.String()).
Int("len", n).
Int("mtu", s.transport.MTU()).
Msg("dropped packet exceeding MTU")
}
}
return err
}
func (s *Session) transportToDst(payload []byte) (int, error) {
s.markActive()
return s.dstConn.Write(payload)
}
// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there
// are many concurrent read/write. It is fine to lose some precision
func (s *Session) markActive() {
select {
case s.activeAtChan <- time.Now():
default:
}
}
func (s *Session) close(err *errClosedSession) {
s.closeChan <- err
}

View File

@ -0,0 +1,197 @@
package datagramsession
import (
"bytes"
"context"
"fmt"
"io"
"net"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
// TestCloseSession makes sure a session will stop after context is done
func TestSessionCtxDone(t *testing.T) {
testSessionReturns(t, closeByContext, time.Minute*2)
}
// TestCloseSession makes sure a session will stop after close method is called
func TestCloseSession(t *testing.T) {
testSessionReturns(t, closeByCallingClose, time.Minute*2)
}
// TestCloseIdle makess sure a session will stop after there is no read/write for a period defined by closeAfterIdle
func TestCloseIdle(t *testing.T) {
testSessionReturns(t, closeByTimeout, time.Millisecond*100)
}
func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.Duration) {
var (
localCloseReason = &errClosedSession{
message: "connection closed by origin",
byRemote: false,
}
)
sessionID := uuid.New()
cfdConn, originConn := net.Pipe()
payload := testPayload(sessionID)
transport := &mockQUICTransport{
reqChan: newDatagramChannel(1),
respChan: newDatagramChannel(1),
}
log := zerolog.Nop()
session := newSession(sessionID, transport, cfdConn, &log)
ctx, cancel := context.WithCancel(context.Background())
sessionDone := make(chan struct{})
go func() {
closedByRemote, err := session.Serve(ctx, closeAfterIdle)
switch closeBy {
case closeByContext:
require.Equal(t, context.Canceled, err)
require.False(t, closedByRemote)
case closeByCallingClose:
require.Equal(t, localCloseReason, err)
require.Equal(t, localCloseReason.byRemote, closedByRemote)
case closeByTimeout:
require.Equal(t, SessionIdleErr(closeAfterIdle), err)
require.False(t, closedByRemote)
}
close(sessionDone)
}()
go func() {
n, err := session.transportToDst(payload)
require.NoError(t, err)
require.Equal(t, len(payload), n)
}()
readBuffer := make([]byte, len(payload)+1)
n, err := originConn.Read(readBuffer)
require.NoError(t, err)
require.Equal(t, len(payload), n)
lastRead := time.Now()
switch closeBy {
case closeByContext:
cancel()
case closeByCallingClose:
session.close(localCloseReason)
}
<-sessionDone
if closeBy == closeByTimeout {
require.True(t, time.Now().After(lastRead.Add(closeAfterIdle)))
}
// call cancelled again otherwise the linter will warn about possible context leak
cancel()
}
type closeMethod int
const (
closeByContext closeMethod = iota
closeByCallingClose
closeByTimeout
)
func TestWriteToDstSessionPreventClosed(t *testing.T) {
testActiveSessionNotClosed(t, false, true)
}
func TestReadFromDstSessionPreventClosed(t *testing.T) {
testActiveSessionNotClosed(t, true, false)
}
func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) {
const closeAfterIdle = time.Millisecond * 100
const activeTime = time.Millisecond * 500
sessionID := uuid.New()
cfdConn, originConn := net.Pipe()
payload := testPayload(sessionID)
transport := &mockQUICTransport{
reqChan: newDatagramChannel(100),
respChan: newDatagramChannel(100),
}
log := zerolog.Nop()
session := newSession(sessionID, transport, cfdConn, &log)
startTime := time.Now()
activeUntil := startTime.Add(activeTime)
ctx, cancel := context.WithCancel(context.Background())
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
session.Serve(ctx, closeAfterIdle)
if time.Now().Before(startTime.Add(activeTime)) {
return fmt.Errorf("session closed while it's still active")
}
return nil
})
if readFromDst {
errGroup.Go(func() error {
for {
if time.Now().After(activeUntil) {
return nil
}
if _, err := originConn.Write(payload); err != nil {
return err
}
time.Sleep(closeAfterIdle / 2)
}
})
}
if writeToDst {
errGroup.Go(func() error {
readBuffer := make([]byte, len(payload))
for {
n, err := originConn.Read(readBuffer)
if err != nil {
if err == io.EOF || err == io.ErrClosedPipe {
return nil
}
return err
}
if !bytes.Equal(payload, readBuffer[:n]) {
return fmt.Errorf("payload %v is not equal to %v", readBuffer[:n], payload)
}
}
})
errGroup.Go(func() error {
for {
if time.Now().After(activeUntil) {
return nil
}
if _, err := session.transportToDst(payload); err != nil {
return err
}
time.Sleep(closeAfterIdle / 2)
}
})
}
require.NoError(t, errGroup.Wait())
cancel()
}
func TestMarkActiveNotBlocking(t *testing.T) {
const concurrentCalls = 50
session := newSession(uuid.New(), nil, nil, nil)
var wg sync.WaitGroup
wg.Add(concurrentCalls)
for i := 0; i < concurrentCalls; i++ {
go func() {
session.markActive()
wg.Done()
}()
}
wg.Wait()
}

View File

@ -0,0 +1,13 @@
package datagramsession
import "github.com/google/uuid"
// Transport is a connection between cloudflared and edge that can multiplex datagrams from multiple sessions
type transport interface {
// SendTo writes payload for a session to the transport
SendTo(sessionID uuid.UUID, payload []byte) error
// ReceiveFrom reads the next datagram from the transport
ReceiveFrom() (uuid.UUID, []byte, error)
// Max transmission unit to receive from the transport
MTU() int
}

View File

@ -0,0 +1,36 @@
package datagramsession
import (
"context"
"github.com/google/uuid"
)
type mockQUICTransport struct {
reqChan *datagramChannel
respChan *datagramChannel
}
func (mt *mockQUICTransport) SendTo(sessionID uuid.UUID, payload []byte) error {
buf := make([]byte, len(payload))
// The QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
copy(buf, payload)
return mt.respChan.Send(context.Background(), sessionID, buf)
}
func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
return mt.reqChan.Receive(context.Background())
}
func (mt *mockQUICTransport) MTU() int {
return 1280
}
func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
return mt.reqChan.Send(ctx, sessionID, payload)
}
func (mt *mockQUICTransport) close() {
mt.reqChan.Close()
mt.respChan.Close()
}

View File

@ -1,4 +1,4 @@
FROM golang:1.16.4 as builder
FROM golang:1.17.5 as builder
ENV GO111MODULE=on \
CGO_ENABLED=0
WORKDIR /go/src/github.com/cloudflare/cloudflared/

View File

@ -1,45 +1,50 @@
package edgediscovery
import (
"encoding/json"
"fmt"
"net"
"strconv"
"strings"
)
const (
protocolRecord = "protocol.argotunnel.com"
protocolRecord = "protocol-v2.argotunnel.com"
)
var (
errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord)
)
func HTTP2Percentage() (int32, error) {
// ProtocolPercent represents a single Protocol Percentage combination.
type ProtocolPercent struct {
Protocol string `json:"protocol"`
Percentage int32 `json:"percentage"`
}
// ProtocolPercents represents the preferred distribution ratio of protocols when protocol isn't specified.
type ProtocolPercents []ProtocolPercent
// GetPercentage returns the threshold percentage of a single protocol.
func (p ProtocolPercents) GetPercentage(protocol string) int32 {
for _, protocolPercent := range p {
if strings.ToLower(protocolPercent.Protocol) == strings.ToLower(protocol) {
return protocolPercent.Percentage
}
}
return 0
}
// ProtocolPercentage returns the ratio of protocols and a specification ratio for their selection.
func ProtocolPercentage() (ProtocolPercents, error) {
records, err := net.LookupTXT(protocolRecord)
if err != nil {
return 0, err
return nil, err
}
if len(records) == 0 {
return 0, errNoProtocolRecord
return nil, errNoProtocolRecord
}
return parseHTTP2Precentage(records[0])
}
// The record looks like http2=percentage
func parseHTTP2Precentage(record string) (int32, error) {
const key = "http2"
slices := strings.Split(record, "=")
if len(slices) != 2 {
return 0, fmt.Errorf("Malformed TXT record %s, expect http2=percentage", record)
}
if slices[0] != key {
return 0, fmt.Errorf("Incorrect key %s, expect %s", slices[0], key)
}
percentage, err := strconv.ParseInt(slices[1], 10, 32)
if err != nil {
return 0, err
}
return int32(percentage), nil
var protocolsWithPercent ProtocolPercents
err = json.Unmarshal([]byte(records[0]), &protocolsWithPercent)
return protocolsWithPercent, err
}

View File

@ -6,75 +6,7 @@ import (
"github.com/stretchr/testify/assert"
)
func TestHTTP2Percentage(t *testing.T) {
_, err := HTTP2Percentage()
func TestProtocolPercentage(t *testing.T) {
_, err := ProtocolPercentage()
assert.NoError(t, err)
}
func TestParseHTTP2Precentage(t *testing.T) {
tests := []struct {
record string
percentage int32
wantErr bool
}{
{
record: "http2=-1",
percentage: -1,
wantErr: false,
},
{
record: "http2=0",
percentage: 0,
wantErr: false,
},
{
record: "http2=50",
percentage: 50,
wantErr: false,
},
{
record: "http2=100",
percentage: 100,
wantErr: false,
},
{
record: "http2=1000",
percentage: 1000,
wantErr: false,
},
{
record: "http2=10.5",
wantErr: true,
},
{
record: "http2=10 h2mux=90",
wantErr: true,
},
{
record: "http2=ten",
wantErr: true,
},
{
record: "h2mux=100",
wantErr: true,
},
{
record: "http2",
wantErr: true,
},
{
record: "http2=",
wantErr: true,
},
}
for _, test := range tests {
p, err := parseHTTP2Precentage(test.record)
if test.wantErr {
assert.Error(t, err)
} else {
assert.Equal(t, test.percentage, p)
}
}
}

View File

@ -72,7 +72,7 @@ def get_or_create_release(repo, version, dry_run=False):
except UnknownObjectException:
logging.info("Release %s not found", version)
# We dont want to create a new release tag if one doesnt already exist
# We don't want to create a new release tag if one doesn't already exist
assert_tag_exists(repo, version)
if dry_run:
@ -198,7 +198,7 @@ def upload_asset(release, filepath, filename, release_version, kv_account_id, na
pass # the macOS release copy fails with being the same file (already in the artifacts directory)
def main():
""" Attempts to upload Asset to Github Release. Creates Release if it doesnt exist """
""" Attempts to upload Asset to Github Release. Creates Release if it doesn't exist """
try:
args = parse_args()
client = Github(args.api_key)

12
go.mod
View File

@ -28,7 +28,7 @@ require (
github.com/json-iterator/go v1.1.10
github.com/kr/text v0.2.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/lucas-clemente/quic-go v0.23.0
github.com/lucas-clemente/quic-go v0.24.0
github.com/mattn/go-colorable v0.1.8
github.com/miekg/dns v1.1.31
github.com/mitchellh/go-homedir v1.1.0
@ -45,11 +45,11 @@ require (
github.com/stretchr/testify v1.6.0
github.com/urfave/cli/v2 v2.2.0
go.uber.org/automaxprocs v1.4.0
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d
golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20210510120138-977fb7262007
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1
golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf
google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d // indirect
google.golang.org/grpc v1.32.0 // indirect
@ -70,6 +70,7 @@ require (
github.com/cheekybits/genny v1.0.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect
github.com/francoispqt/gojay v1.2.13 // indirect
github.com/gdamore/encoding v1.0.0 // indirect
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect
github.com/golang/protobuf v1.5.2 // indirect
@ -77,6 +78,7 @@ require (
github.com/lucasb-eyer/go-colorful v1.0.3 // indirect
github.com/marten-seemann/qtls-go1-16 v0.1.4 // indirect
github.com/marten-seemann/qtls-go1-17 v0.1.0 // indirect
github.com/marten-seemann/qtls-go1-18 v0.1.0-beta.1 // indirect
github.com/mattn/go-isatty v0.0.12 // indirect
github.com/mattn/go-runewidth v0.0.8 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
@ -97,3 +99,5 @@ require (
)
replace github.com/urfave/cli/v2 => github.com/ipostelnik/cli/v2 v2.3.1-0.20210324024421-b6ea8234fe3d
replace github.com/lucas-clemente/quic-go => github.com/chungthuang/quic-go v0.24.1-0.20220110095058-981dc498cb62

19
go.sum
View File

@ -125,6 +125,12 @@ github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE=
github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ=
github.com/chungthuang/quic-go v0.24.1-0.20220106111256-154e7d8a89a9 h1:sHrAhwM2NHkb/5z7+cxDFMCvG3WnSAPbjqSbujLB3nU=
github.com/chungthuang/quic-go v0.24.1-0.20220106111256-154e7d8a89a9/go.mod h1:YtzP8bxRVCBlO77yRanE264+fY/T2U9ZlW1AaHOsMOg=
github.com/chungthuang/quic-go v0.24.1-0.20220106164320-fc99d36b9daa h1:QSi2gWSBtNtCH2/8Y6zFs4H5bnrHQQxFCzl7zJsPp28=
github.com/chungthuang/quic-go v0.24.1-0.20220106164320-fc99d36b9daa/go.mod h1:YtzP8bxRVCBlO77yRanE264+fY/T2U9ZlW1AaHOsMOg=
github.com/chungthuang/quic-go v0.24.1-0.20220110095058-981dc498cb62 h1:PLTB4iA6sOgAItzQY642tYdcGKfG/7i2gu93JQGgUcM=
github.com/chungthuang/quic-go v0.24.1-0.20220110095058-981dc498cb62/go.mod h1:YtzP8bxRVCBlO77yRanE264+fY/T2U9ZlW1AaHOsMOg=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
@ -197,6 +203,7 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk=
github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY=
github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4=
github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20=
@ -421,6 +428,8 @@ github.com/liquidweb/liquidweb-go v1.6.0/go.mod h1:UDcVnAMDkZxpw4Y7NOHkqoeiGacVL
github.com/lucas-clemente/quic-go v0.13.1/go.mod h1:Vn3/Fb0/77b02SGhQk36KzOUmXgVpFfizUfW5WMaqyU=
github.com/lucas-clemente/quic-go v0.23.0 h1:5vFnKtZ6nHDFsc/F3uuiF4T3y/AXaQdxjUqiVw26GZE=
github.com/lucas-clemente/quic-go v0.23.0/go.mod h1:paZuzjXCE5mj6sikVLMvqXk8lJV2AsqtJ6bDhjEfxx0=
github.com/lucas-clemente/quic-go v0.24.0 h1:ToR7SIIEdrgOhgVTHvPgdVRJfgVy+N0wQAagH7L4d5g=
github.com/lucas-clemente/quic-go v0.24.0/go.mod h1:paZuzjXCE5mj6sikVLMvqXk8lJV2AsqtJ6bDhjEfxx0=
github.com/lucasb-eyer/go-colorful v1.0.2/go.mod h1:0MS4r+7BZKSJ5mw4/S5MPN+qHFF1fYclkSPilDOKW0s=
github.com/lucasb-eyer/go-colorful v1.0.3 h1:QIbQXiugsb+q10B+MI+7DI1oQLdmnep86tWFlaaUAac=
github.com/lucasb-eyer/go-colorful v1.0.3/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
@ -437,6 +446,8 @@ github.com/marten-seemann/qtls-go1-16 v0.1.4 h1:xbHbOGGhrenVtII6Co8akhLEdrawwB2i
github.com/marten-seemann/qtls-go1-16 v0.1.4/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk=
github.com/marten-seemann/qtls-go1-17 v0.1.0 h1:P9ggrs5xtwiqXv/FHNwntmuLMNq3KaSIG93AtAZ48xk=
github.com/marten-seemann/qtls-go1-17 v0.1.0/go.mod h1:fz4HIxByo+LlWcreM4CZOYNuz3taBQ8rN2X6FqvaWo8=
github.com/marten-seemann/qtls-go1-18 v0.1.0-beta.1 h1:EnzzN9fPUkUck/1CuY1FlzBaIYMoiBsdwTNmNGkwUUM=
github.com/marten-seemann/qtls-go1-18 v0.1.0-beta.1/go.mod h1:PUhIQk19LoFt2174H4+an8TYvWOGjb/hHwphBeaDHwI=
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8=
github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
@ -732,6 +743,8 @@ golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPh
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM=
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -815,6 +828,10 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0=
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20211109214657-ef0fda0de508 h1:v3NKo+t/Kc3EASxaKZ82lwK6mCf4ZeObQBduYFZHo7c=
golang.org/x/net v0.0.0-20211109214657-ef0fda0de508/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d h1:1n1fc535VhN8SYtD4cDUyNlfpAF2ROMM9+11equK3hs=
golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@ -899,6 +916,8 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007 h1:gG67DSER+11cZvqIMb8S8bt0vZtiN6xWYARwirrOSfE=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf h1:MZ2shdL+ZM/XzY3ZGOnh4Nlpnxz5GSOhOmtHo3iPU6M=
golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=

View File

@ -93,7 +93,7 @@ func (m *activeStreamMap) Set(newStream *MuxedStream) bool {
return true
}
// Delete stops tracking the stream. It should be called only after it is closed and resetted.
// Delete stops tracking the stream. It should be called only after it is closed and reset.
func (m *activeStreamMap) Delete(streamID uint32) {
m.Lock()
defer m.Unlock()

View File

@ -1,3 +1,4 @@
//go:build cgo
// +build cgo
package h2mux

View File

@ -1,3 +1,4 @@
//go:build !cgo
// +build !cgo
package h2mux

View File

@ -12,7 +12,7 @@ import (
/* This is an implementation of https://github.com/vkrasnov/h2-compression-dictionaries
but modified for tunnels in a few key ways:
Since tunnels is a server-to-server service, some aspects of the spec would cause
unnessasary head-of-line blocking on the CPU and on the network, hence this implementation
unnecessary head-of-line blocking on the CPU and on the network, hence this implementation
allows for parallel compression on the "client", and buffering on the "server" to solve
this problem. */
@ -67,7 +67,7 @@ var compressionPresets = map[CompressionSetting]CompressionPreset{
}
func compressionSettingVal(version, fmt, sz, nd uint8) uint32 {
// Currently the compression settings are inlcude:
// Currently the compression settings are include:
// * version: only 1 is supported
// * fmt: only 2 for brotli is supported
// * sz: log2 of the maximal allowed dictionary size
@ -438,7 +438,7 @@ func assignDictToStream(s *MuxedStream, p []byte) bool {
h2d.dictLock.Lock()
if w.comp != nil {
// Check again with lock, in therory the inteface allows for unordered writes
// Check again with lock, in therory the interface allows for unordered writes
h2d.dictLock.Unlock()
return false
}
@ -468,7 +468,7 @@ func assignDictToStream(s *MuxedStream, p []byte) bool {
}
} else {
// Use the overflow dictionary as last resort
// If slots are availabe generate new dictioanries for path and content-type
// If slots are available generate new dictionaries for path and content-type
useID, _ = h2d.getGenericDictID()
pathID, pathFound = h2d.getNextDictID()
if pathFound {

View File

@ -174,7 +174,7 @@ func Handshake(
pingTimestamp := NewPingTimestamp()
connActive := NewSignal()
idleDuration := config.HeartbeatInterval
// Sanity check to enusre idelDuration is sane
// Sanity check to ensure idelDuration is sane
if idleDuration == 0 || idleDuration < defaultTimeout {
idleDuration = defaultTimeout
config.Log.Info().Msgf("muxer: Minimum idle time has been adjusted to %d", defaultTimeout)
@ -274,7 +274,7 @@ func (m *Muxer) readPeerSettings(magic uint32) error {
m.compressionQuality = compressionPresets[CompressionNone]
return nil
}
// Values used for compression are the mimimum between the two peers
// Values used for compression are the minimum between the two peers
if sz < m.compressionQuality.dictSize {
m.compressionQuality.dictSize = sz
}

View File

@ -130,7 +130,7 @@ func TestMuxMetricsUpdater(t *testing.T) {
m.updateReceiveWindow(uint32(j))
m.updateSendWindow(uint32(j))
// should always be disgarded since the send time is before readerSend
// should always be discarded since the send time is before readerSend
rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart.Add(-time.Duration(j*dataPoints) * time.Millisecond)}
m.updateRTT(rm)

View File

@ -37,7 +37,7 @@ type OriginUpTime struct {
UpTime string `json:"uptime"`
}
const defaultServerName = "the Argo Tunnel test server"
const defaultServerName = "the Cloudflare Tunnel test server"
const indexTemplate = `
<!DOCTYPE html>
<html lang="en">
@ -45,10 +45,10 @@ const indexTemplate = `
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=Edge">
<title>
Argo Tunnel Connection
Cloudflare Tunnel Connection
</title>
<meta name="author" content="">
<meta name="description" content="Argo Tunnel Connection">
<meta name="description" content="Cloudflare Tunnel Connection">
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>
html{line-height:1.15;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}body{margin:0}section{display:block}h1{font-size:2em;margin:.67em 0}a{background-color:transparent;-webkit-text-decoration-skip:objects}/* 1 */::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}/* 1 */a,body,dd,div,dl,dt,h1,h4,html,p,section{box-sizing:border-box}.bt{border-top-style:solid;border-top-width:1px}.bl{border-left-style:solid;border-left-width:1px}.b--orange{border-color:#f38020}.br1{border-radius:.125rem}.bw2{border-width:.25rem}.dib{display:inline-block}.sans-serif{font-family:open sans,-apple-system,BlinkMacSystemFont,avenir next,avenir,helvetica neue,helvetica,ubuntu,roboto,noto,segoe ui,arial,sans-serif}.overflow-x-auto{overflow-x:auto}.code{font-family:Consolas,monaco,monospace}.b{font-weight:700}.fw3{font-weight:300}.fw4{font-weight:400}.fw5{font-weight:500}.fw6{font-weight:600}.lh-copy{line-height:1.5}.link{text-decoration:none}.link,.link:active,.link:focus,.link:hover,.link:link,.link:visited{transition:color .15s ease-in}.link:focus{outline:1px dotted currentColor}.mw-100{max-width:100%}.mw4{max-width:8rem}.mw7{max-width:48rem}.bg-light-gray{background-color:#f7f7f7}.link-hover:hover{background-color:#1f679e}.white{color:#fff}.bg-white{background-color:#fff}.bg-blue{background-color:#408bc9}.pb2{padding-bottom:.5rem}.pb6{padding-bottom:8rem}.pt3{padding-top:1rem}.pt5{padding-top:4rem}.pv2{padding-top:.5rem;padding-bottom:.5rem}.ph3{padding-left:1rem;padding-right:1rem}.ph4{padding-left:2rem;padding-right:2rem}.ml0{margin-left:0}.mb1{margin-bottom:.25rem}.mb2{margin-bottom:.5rem}.mb3{margin-bottom:1rem}.mt5{margin-top:4rem}.ttu{text-transform:uppercase}.f4{font-size:1.25rem}.f5{font-size:1rem}.f6{font-size:.875rem}.f7{font-size:.75rem}.measure{max-width:30em}.center{margin-left:auto}.center{margin-right:auto}@media screen and (min-width:30em){.f2-ns{font-size:2.25rem}}@media screen and (min-width:30em) and (max-width:60em){.f5-m{font-size:1rem}}@media screen and (min-width:60em){.f4-l{font-size:1.25rem}}
@ -66,7 +66,7 @@ const indexTemplate = `
</svg>
<h1 class="f4 f2-ns mt5 fw5">Congrats! You created a tunnel!</h1>
<p class="f6 f5-m f4-l measure lh-copy fw3">
Argo Tunnel exposes locally running applications to the internet by
Cloudflare Tunnel exposes locally running applications to the internet by
running an encrypted, virtual tunnel from your laptop or server to
Cloudflare's edge network.
</p>
@ -74,7 +74,7 @@ const indexTemplate = `
<a
class="fw6 link white bg-blue ph4 pv2 br1 dib f5 link-hover"
style="border-bottom: 1px solid #1f679e"
href="https://developers.cloudflare.com/argo-tunnel/">
href="https://developers.cloudflare.com/cloudflare-one/connections/connect-apps">
Get started here
</a>
<section>

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
@ -19,7 +20,6 @@ import (
"golang.org/x/net/proxy"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/websocket"
@ -97,7 +97,7 @@ func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
}
// TestSocksStreamWSOverTCPConnection simulates proxying in socks mode.
// Eyeball side runs cloudflared accesss tcp with --url flag to start a websocket forwarder which
// Eyeball side runs cloudflared access tcp with --url flag to start a websocket forwarder which
// wraps SOCKS5 traffic in websocket
// Origin side runs a tcpOverWSConnection with socks.StreamHandler
func TestSocksStreamWSOverTCPConnection(t *testing.T) {
@ -192,8 +192,10 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
eyeballConn, err := connection.NewHTTP2RespWriter(r, w, connection.TypeWebsocket)
assert.NoError(t, err)
eyeballConn := &readWriter{
w: w,
r: r.Body,
}
cfdConn, originConn := net.Pipe()
tcpOverWSConn := tcpOverWSConnection{
@ -319,3 +321,16 @@ func echoTCPOrigin(t *testing.T, conn net.Conn) {
_, err = conn.Write(testResponse)
assert.NoError(t, err)
}
type readWriter struct {
w io.Writer
r io.Reader
}
func (r *readWriter) Read(p []byte) (n int, err error) {
return r.r.Read(p)
}
func (r *readWriter) Write(p []byte) (n int, err error) {
return r.w.Write(p)
}

View File

@ -40,6 +40,8 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
if o.hostHeader != "" {
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
// Pass the original Host header as X-Forwarded-Host.
req.Header.Set("X-Forwarded-Host", req.Host)
req.Host = o.hostHeader
}
return o.transport.RoundTrip(req)

View File

@ -3,6 +3,7 @@ package ingress
import (
"context"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
@ -118,7 +119,9 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
w.WriteHeader(http.StatusSwitchingProtocols)
return
}
w.Write([]byte("ok"))
// return the X-Forwarded-Host header for assertions
// as the httptest Server URL isn't available here yet
w.Write([]byte(r.Header.Get("X-Forwarded-Host")))
}
origin := httptest.NewServer(http.HandlerFunc(handler))
defer origin.Close()
@ -141,6 +144,10 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
respBody, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, respBody, []byte(originURL.Host))
}
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {

View File

@ -0,0 +1,27 @@
package ingress
import (
"fmt"
"io"
"net"
)
type UDPProxy struct {
io.ReadWriteCloser
}
func DialUDP(dstIP net.IP, dstPort uint16) (*UDPProxy, error) {
dstAddr := &net.UDPAddr{
IP: dstIP,
Port: int(dstPort),
}
// We use nil as local addr to force runtime to find the best suitable local address IP given the destination
// address as context.
udpConn, err := net.DialUDP("udp", nil, dstAddr)
if err != nil {
return nil, fmt.Errorf("unable to create UDP proxy to origin (%v:%v): %w", dstIP, dstPort, err)
}
return &UDPProxy{udpConn}, nil
}

View File

@ -1,6 +1,6 @@
---
basename: "cloudflared"
comment: "Cloudflare Argo Tunnel"
comment: "Cloudflare Tunnel"
copyright: "Cloudflare, Inc"
arch: "x86"
abi: "64"

View File

@ -83,15 +83,15 @@ func ServeMetrics(
return err
}
func RegisterBuildInfo(buildTime string, version string) {
func RegisterBuildInfo(buildType, buildTime, version string) {
buildInfo := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
// Don't namespace build_info, since we want it to be consistent across all Cloudflare services
Name: "build_info",
Help: "Build and version information",
},
[]string{"goversion", "revision", "version"},
[]string{"goversion", "type", "revision", "version"},
)
prometheus.MustRegister(buildInfo)
buildInfo.WithLabelValues(runtime.Version(), buildTime, version).Set(1)
buildInfo.WithLabelValues(runtime.Version(), buildType, buildTime, version).Set(1)
}

View File

@ -4,46 +4,32 @@ import (
"encoding/json"
"fmt"
"net/http"
"sync"
conn "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/tunnelstate"
"github.com/rs/zerolog"
)
// ReadyServer serves HTTP 200 if the tunnel can serve traffic. Intended for k8s readiness checks.
type ReadyServer struct {
sync.RWMutex
isConnected map[int]bool
log *zerolog.Logger
tracker *tunnelstate.ConnTracker
}
// NewReadyServer initializes a ReadyServer and starts listening for dis/connection events.
func NewReadyServer(log *zerolog.Logger) *ReadyServer {
return &ReadyServer{
isConnected: make(map[int]bool, 0),
log: log,
tracker: tunnelstate.NewConnTracker(log),
}
}
func (rs *ReadyServer) OnTunnelEvent(c conn.Event) {
switch c.EventType {
case conn.Connected:
rs.Lock()
rs.isConnected[int(c.Index)] = true
rs.Unlock()
case conn.Disconnected, conn.Reconnecting, conn.RegisteringTunnel, conn.Unregistering:
rs.Lock()
rs.isConnected[int(c.Index)] = false
rs.Unlock()
default:
rs.log.Error().Msgf("Unknown connection event case %v", c)
}
rs.tracker.OnTunnelEvent(c)
}
type body struct {
Status int `json:"status"`
ReadyConnections int `json:"readyConnections"`
Status int `json:"status"`
ReadyConnections uint `json:"readyConnections"`
}
// ServeHTTP responds with HTTP 200 if the tunnel is connected to the edge.
@ -63,15 +49,11 @@ func (rs *ReadyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// This is the bulk of the logic for ServeHTTP, broken into its own pure function
// to make unit testing easy.
func (rs *ReadyServer) makeResponse() (statusCode, readyConnections int) {
statusCode = http.StatusServiceUnavailable
rs.RLock()
defer rs.RUnlock()
for _, connected := range rs.isConnected {
if connected {
statusCode = http.StatusOK
readyConnections++
}
func (rs *ReadyServer) makeResponse() (statusCode int, readyConnections uint) {
readyConnections = rs.tracker.CountActiveConns()
if readyConnections > 0 {
return http.StatusOK, readyConnections
} else {
return http.StatusServiceUnavailable, readyConnections
}
return statusCode, readyConnections
}

View File

@ -7,6 +7,8 @@ import (
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/tunnelstate"
"github.com/cloudflare/cloudflared/connection"
)
@ -18,7 +20,7 @@ func TestReadyServer_makeResponse(t *testing.T) {
name string
fields fields
wantOK bool
wantReadyConnections int
wantReadyConnections uint
}{
{
name: "One connection online => HTTP 200",
@ -49,7 +51,7 @@ func TestReadyServer_makeResponse(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := &ReadyServer{
isConnected: tt.fields.isConnected,
tracker: tunnelstate.MockedConnTracker(tt.fields.isConnected),
}
gotStatusCode, gotReadyConnections := rs.makeResponse()
if tt.wantOK && gotStatusCode != http.StatusOK {

View File

@ -0,0 +1,42 @@
package origin
import (
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/tunnelstate"
)
type ConnAwareLogger struct {
tracker *tunnelstate.ConnTracker
logger *zerolog.Logger
}
func NewConnAwareLogger(logger *zerolog.Logger, observer *connection.Observer) *ConnAwareLogger {
connAwareLogger := &ConnAwareLogger{
tracker: tunnelstate.NewConnTracker(logger),
logger: logger,
}
observer.RegisterSink(connAwareLogger.tracker)
return connAwareLogger
}
func (c *ConnAwareLogger) ReplaceLogger(logger *zerolog.Logger) *ConnAwareLogger {
return &ConnAwareLogger{
tracker: c.tracker,
logger: logger,
}
}
func (c *ConnAwareLogger) ConnAwareLogger() *zerolog.Event {
if c.tracker.CountActiveConns() == 0 {
return c.logger.Error()
}
return c.logger.Warn()
}
func (c *ConnAwareLogger) Logger() *zerolog.Logger {
return c.logger
}

View File

@ -177,6 +177,11 @@ func (p *Proxy) proxyHTTPRequest(
roundTripReq.Header.Set("Connection", "keep-alive")
}
// Set the User-Agent as an empty string if not provided to avoid inserting golang default UA
if roundTripReq.Header.Get("User-Agent") == "" {
roundTripReq.Header.Set("User-Agent", "")
}
resp, err := httpService.RoundTrip(roundTripReq)
if err != nil {
return errors.Wrap(err, "Unable to reach the origin service. The service may be down or it may not be responding to traffic from cloudflared")

View File

@ -1,3 +1,4 @@
//go:build !windows
// +build !windows
package origin

View File

@ -767,7 +767,7 @@ func newWSRespWriter(w io.Writer) *wsRespWriter {
func (w *wsRespWriter) Write(p []byte) (int, error) {
returnedMsg, err := wsutil.ReadServerBinary(bytes.NewBuffer(p))
if err != nil {
// The data was not returned by a websocket connecton.
// The data was not returned by a websocket connection.
if err != io.ErrUnexpectedEOF {
return w.w.Write(p)
}

View File

@ -91,7 +91,7 @@ func (cm *reconnectCredentialManager) ConnDigest(connID uint8) ([]byte, error) {
defer cm.mu.RUnlock()
digest, ok := cm.connDigest[connID]
if !ok {
return nil, fmt.Errorf("no conneciton digest for connection %v", connID)
return nil, fmt.Errorf("no connection digest for connection %v", connID)
}
return digest, nil
}

View File

@ -46,7 +46,7 @@ type Supervisor struct {
nextConnectedIndex int
nextConnectedSignal chan struct{}
log *zerolog.Logger
log *ConnAwareLogger
logTransport *zerolog.Logger
reconnectCredentialManager *reconnectCredentialManager
@ -91,7 +91,7 @@ func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, grace
edgeIPs: edgeIPs,
tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{},
log: config.Log,
log: NewConnAwareLogger(config.Log, config.Observer),
logTransport: config.LogTransport,
reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
useReconnectToken: useReconnectToken,
@ -123,7 +123,7 @@ func (s *Supervisor) Run(
if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
refreshAuthBackoffTimer = timer
} else {
s.log.Err(err).
s.log.Logger().Err(err).
Dur("refreshAuthRetryDuration", refreshAuthRetryDuration).
Msgf("supervisor: initial refreshAuth failed, retrying in %v", refreshAuthRetryDuration)
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
@ -145,7 +145,7 @@ func (s *Supervisor) Run(
case tunnelError := <-s.tunnelErrors:
tunnelsActive--
if tunnelError.err != nil && !shuttingDown {
s.log.Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
s.log.ConnAwareLogger().Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
s.waitForNextTunnel(tunnelError.index)
@ -170,7 +170,7 @@ func (s *Supervisor) Run(
case <-refreshAuthBackoffTimer:
newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
if err != nil {
s.log.Err(err).Msg("supervisor: Authentication failed")
s.log.Logger().Err(err).Msg("supervisor: Authentication failed")
// Permanent failure. Leave the `select` without setting the
// channel to be non-null, so we'll never hit this case of the `select` again.
continue
@ -195,7 +195,7 @@ func (s *Supervisor) initialize(
) error {
availableAddrs := s.edgeIPs.AvailableAddrs()
if s.config.HAConnections > availableAddrs {
s.log.Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
s.log.Logger().Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
s.config.HAConnections = availableAddrs
}
@ -244,6 +244,7 @@ func (s *Supervisor) startFirstTunnel(
s.reconnectCredentialManager,
s.config,
addr,
s.log,
firstConnIndex,
connectedSignal,
s.cloudflaredUUID,
@ -277,6 +278,7 @@ func (s *Supervisor) startFirstTunnel(
s.reconnectCredentialManager,
s.config,
addr,
s.log,
firstConnIndex,
connectedSignal,
s.cloudflaredUUID,
@ -310,6 +312,7 @@ func (s *Supervisor) startTunnel(
s.reconnectCredentialManager,
s.config,
addr,
s.log,
uint8(index),
connectedSignal,
s.cloudflaredUUID,
@ -373,7 +376,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
if err != nil {
return nil, err
}
rpcClient := connection.NewTunnelServerClient(ctx, stream, s.log)
rpcClient := connection.NewTunnelServerClient(ctx, stream, s.log.Logger())
defer rpcClient.Close()
const arbitraryConnectionID = uint8(0)

View File

@ -20,6 +20,7 @@ import (
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
"github.com/cloudflare/cloudflared/h2mux"
quicpogs "github.com/cloudflare/cloudflared/quic"
"github.com/cloudflare/cloudflared/retry"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/tunnelrpc"
@ -34,13 +35,6 @@ const (
quicMaxIdleTimeout = 15 * time.Second
)
type rpcName string
const (
reconnect rpcName = "reconnect"
authenticate rpcName = " authenticate"
)
type TunnelConfig struct {
ConnectionConfig *connection.Config
OSArch string
@ -130,6 +124,7 @@ func ServeTunnelLoop(
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
addr *allregions.EdgeAddr,
connAwareLogger *ConnAwareLogger,
connIndex uint8,
connectedSignal *signal.Signal,
cloudflaredUUID uuid.UUID,
@ -139,7 +134,8 @@ func ServeTunnelLoop(
haConnections.Inc()
defer haConnections.Dec()
connLog := config.Log.With().Uint8(connection.LogFieldConnIndex, connIndex).Logger()
logger := config.Log.With().Uint8(connection.LogFieldConnIndex, connIndex).Logger()
connLog := connAwareLogger.ReplaceLogger(&logger)
protocolFallback := &protocolFallback{
retry.BackoffHandler{MaxRetries: config.Retries},
@ -159,7 +155,7 @@ func ServeTunnelLoop(
for {
err, recoverable := ServeTunnel(
ctx,
&connLog,
connLog,
credentialManager,
config,
addr,
@ -181,7 +177,7 @@ func ServeTunnelLoop(
if !ok {
return err
}
connLog.Info().Msgf("Retrying connection in up to %s seconds", duration)
connLog.Logger().Info().Msgf("Retrying connection in up to %s seconds", duration)
select {
case <-ctx.Done():
@ -189,7 +185,13 @@ func ServeTunnelLoop(
case <-gracefulShutdownC:
return nil
case <-protocolFallback.BackoffTimer():
if !selectNextProtocol(&connLog, protocolFallback, config.ProtocolSelector) {
var idleTimeoutError *quic.IdleTimeoutError
if !selectNextProtocol(
connLog.Logger(),
protocolFallback,
config.ProtocolSelector,
errors.As(err, &idleTimeoutError),
) {
return err
}
}
@ -221,8 +223,9 @@ func selectNextProtocol(
connLog *zerolog.Logger,
protocolBackoff *protocolFallback,
selector connection.ProtocolSelector,
isNetworkActivityTimeout bool,
) bool {
if protocolBackoff.ReachedMaxRetries() {
if protocolBackoff.ReachedMaxRetries() || isNetworkActivityTimeout {
fallback, hasFallback := selector.Fallback()
if !hasFallback {
return false
@ -247,7 +250,7 @@ func selectNextProtocol(
// on error returns a flag indicating if error can be retried
func ServeTunnel(
ctx context.Context,
connLog *zerolog.Logger,
connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
addr *allregions.EdgeAddr,
@ -291,29 +294,29 @@ func ServeTunnel(
if err != nil {
switch err := err.(type) {
case connection.DupConnRegisterTunnelError:
connLog.Err(err).Msg("Unable to establish connection.")
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection.")
// don't retry this connection anymore, let supervisor pick a new address
return err, false
case connection.ServerRegisterTunnelError:
connLog.Err(err).Msg("Register tunnel error from server side")
connLog.ConnAwareLogger().Err(err).Msg("Register tunnel error from server side")
// Don't send registration error return from server to Sentry. They are
// logged on server side
if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 {
connLog.Error().Msg(activeIncidentsMsg(incidents))
connLog.ConnAwareLogger().Msg(activeIncidentsMsg(incidents))
}
return err.Cause, !err.Permanent
case ReconnectSignal:
connLog.Info().
connLog.Logger().Info().
Uint8(connection.LogFieldConnIndex, connIndex).
Msgf("Restarting connection due to reconnect signal in %s", err.Delay)
err.DelayBeforeReconnect()
return err, true
default:
if err == context.Canceled {
connLog.Debug().Err(err).Msgf("Serve tunnel error")
connLog.Logger().Debug().Err(err).Msgf("Serve tunnel error")
return err, false
}
connLog.Err(err).Msgf("Serve tunnel error")
connLog.ConnAwareLogger().Err(err).Msgf("Serve tunnel error")
_, permanent := err.(unrecoverableError)
return err, !permanent
}
@ -323,7 +326,7 @@ func ServeTunnel(
func serveTunnel(
ctx context.Context,
connLog *zerolog.Logger,
connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
addr *allregions.EdgeAddr,
@ -351,21 +354,22 @@ func serveTunnel(
)
switch protocol {
case connection.QUIC:
case connection.QUIC, connection.QUICWarp:
connOptions := config.ConnectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
return ServeQUIC(ctx,
addr.UDP,
config,
connLog,
connOptions,
controlStream,
connectedFuse,
connIndex,
reconnectCh,
gracefulShutdownC)
case connection.HTTP2:
case connection.HTTP2, connection.HTTP2Warp:
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP)
if err != nil {
connLog.Err(err).Msg("Unable to establish connection with Cloudflare edge")
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
return err, true
}
@ -387,7 +391,7 @@ func serveTunnel(
default:
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr.TCP)
if err != nil {
connLog.Err(err).Msg("Unable to establish connection with Cloudflare edge")
connLog.ConnAwareLogger().Err(err).Msg("Unable to establish connection with Cloudflare edge")
return err, true
}
@ -419,7 +423,7 @@ func (r unrecoverableError) Error() string {
func ServeH2mux(
ctx context.Context,
connLog *zerolog.Logger,
connLog *ConnAwareLogger,
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
edgeConn net.Conn,
@ -429,7 +433,7 @@ func ServeH2mux(
reconnectCh chan ReconnectSignal,
gracefulShutdownC <-chan struct{},
) error {
connLog.Debug().Msgf("Connecting via h2mux")
connLog.Logger().Debug().Msgf("Connecting via h2mux")
// Returns error from parsing the origin URL or handshake errors
handler, err, recoverable := connection.NewH2muxConnection(
config.ConnectionConfig,
@ -466,7 +470,7 @@ func ServeH2mux(
func ServeHTTP2(
ctx context.Context,
connLog *zerolog.Logger,
connLog *ConnAwareLogger,
config *TunnelConfig,
tlsServerConn net.Conn,
connOptions *tunnelpogs.ConnectionOptions,
@ -475,7 +479,7 @@ func ServeHTTP2(
gracefulShutdownC <-chan struct{},
reconnectCh chan ReconnectSignal,
) error {
connLog.Debug().Msgf("Connecting via http2")
connLog.Logger().Debug().Msgf("Connecting via http2")
h2conn := connection.NewHTTP2Connection(
tlsServerConn,
config.ConnectionConfig,
@ -507,69 +511,57 @@ func ServeQUIC(
ctx context.Context,
edgeAddr *net.UDPAddr,
config *TunnelConfig,
connLogger *ConnAwareLogger,
connOptions *tunnelpogs.ConnectionOptions,
controlStreamHandler connection.ControlStreamHandler,
connectedFuse connection.ConnectedFuse,
connIndex uint8,
reconnectCh chan ReconnectSignal,
gracefulShutdownC <-chan struct{},
) (err error, recoverable bool) {
tlsConfig := config.EdgeTLSConfigs[connection.QUIC]
quicConfig := &quic.Config{
HandshakeIdleTimeout: quicHandshakeIdleTimeout,
MaxIdleTimeout: quicMaxIdleTimeout,
KeepAlive: true,
HandshakeIdleTimeout: quicHandshakeIdleTimeout,
MaxIdleTimeout: quicMaxIdleTimeout,
MaxIncomingStreams: connection.MaxConcurrentStreams,
MaxIncomingUniStreams: connection.MaxConcurrentStreams,
KeepAlive: true,
EnableDatagrams: true,
MaxDatagramFrameSize: quicpogs.MaxDatagramFrameSize,
Tracer: quicpogs.NewClientTracer(connLogger.Logger(), connIndex),
}
for {
select {
case <-ctx.Done():
return
default:
quicConn, err := connection.NewQUICConnection(
ctx,
quicConfig,
edgeAddr,
tlsConfig,
config.ConnectionConfig.OriginProxy,
connOptions,
controlStreamHandler,
config.Observer)
if err != nil {
config.Log.Error().Msgf("Failed to create new quic connection, err: %v", err)
return err, true
}
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
err := quicConn.Serve(ctx)
if err != nil {
config.Log.Error().Msgf("Failed to serve quic connection, err: %v", err)
}
return fmt.Errorf("Connection with edge closed")
})
quicConn, err := connection.NewQUICConnection(
quicConfig,
edgeAddr,
tlsConfig,
config.ConnectionConfig.OriginProxy,
connOptions,
controlStreamHandler,
connLogger.Logger())
if err != nil {
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")
return err, true
}
errGroup.Go(func() error {
return listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)
})
err = errGroup.Wait()
if err == nil {
return nil, false
}
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
err := quicConn.Serve(serveCtx)
if err != nil {
connLogger.ConnAwareLogger().Err(err).Msg("Failed to serve quic connection")
}
}
}
return err
})
type quicLogger struct {
*zerolog.Logger
}
errGroup.Go(func() error {
err := listenReconnect(serveCtx, reconnectCh, gracefulShutdownC)
if err != nil {
// forcefully break the connection (this is only used for testing)
quicConn.Close()
}
return err
})
func (ql *quicLogger) Write(p []byte) (n int, err error) {
ql.Debug().Msgf("quic log: %v", string(p))
return len(p), nil
}
func (ql *quicLogger) Close() error {
return nil
return errGroup.Wait(), false
}
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error {

Some files were not shown because too many files have changed in this diff Show More