diff --git a/.teamcity/update-homebrew.sh b/.teamcity/update-homebrew.sh index 250e00e3..6450f0ef 100755 --- a/.teamcity/update-homebrew.sh +++ b/.teamcity/update-homebrew.sh @@ -46,8 +46,8 @@ git reset --hard origin/master URL="https://packages.argotunnel.com/dl/cloudflared-$VERSION-darwin-amd64.tgz" tee cloudflared.rb <` + +### Bug Fixes +- Fixed some generic transport bugs in `quic` mode. It's advised to upgrade to at least this version (2021.9.2) when running `cloudflared` +with `quic` protocol. +- `cloudflared` docker images will now show version. + + ## 2021.8.4 ### Improvements - Temporary tunnels (those hosted on trycloudflare.com that do not require a Cloudflare login) now run as Named Tunnels diff --git a/Makefile b/Makefile index a7918dca..11ecbc74 100644 --- a/Makefile +++ b/Makefile @@ -1,25 +1,41 @@ -VERSION := $(shell git describe --tags --always --dirty="-dev" --match "[0-9][0-9][0-9][0-9].*.*") +VERSION := $(shell git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*") MSI_VERSION := $(shell git tag -l --sort=v:refname | grep "w" | tail -1 | cut -c2-) #MSI_VERSION expects the format of the tag to be: (wX.X.X). Starts with the w character to not break cfsetup. #e.g. w3.0.1 or w4.2.10. It trims off the w character when creating the MSI. -ifeq ($(FIPS), true) - GO_BUILD_TAGS := $(GO_BUILD_TAGS) fips -endif - -ifneq ($(GO_BUILD_TAGS),) - GO_BUILD_TAGS := -tags $(GO_BUILD_TAGS) +ifeq ($(ORIGINAL_NAME), true) + # Used for builds that want FIPS compilation but want the artifacts generated to still have the original name. + BINARY_NAME := cloudflared +else ifeq ($(FIPS), true) + # Used for FIPS compliant builds that do not match the case above. + BINARY_NAME := cloudflared-fips +else + # Used for all other (non-FIPS) builds. + BINARY_NAME := cloudflared endif ifeq ($(NIGHTLY), true) - DEB_PACKAGE_NAME := cloudflared-nightly + DEB_PACKAGE_NAME := $(BINARY_NAME)-nightly NIGHTLY_FLAGS := --conflicts cloudflared --replaces cloudflared else - DEB_PACKAGE_NAME := cloudflared + DEB_PACKAGE_NAME := $(BINARY_NAME) endif DATE := $(shell date -u '+%Y-%m-%d-%H%M UTC') -VERSION_FLAGS := -ldflags='-X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)"' +VERSION_FLAGS := -X "main.Version=$(VERSION)" -X "main.BuildTime=$(DATE)" + +LINK_FLAGS := +ifeq ($(FIPS), true) + LINK_FLAGS := -linkmode=external -extldflags=-static $(LINK_FLAGS) + # Prevent linking with libc regardless of CGO enabled or not. + GO_BUILD_TAGS := $(GO_BUILD_TAGS) osusergo netgo fips + VERSION_FLAGS := $(VERSION_FLAGS) -X "main.BuildType=FIPS" +endif + +LDFLAGS := -ldflags='$(VERSION_FLAGS) $(LINK_FLAGS)' +ifneq ($(GO_BUILD_TAGS),) + GO_BUILD_TAGS := -tags "$(GO_BUILD_TAGS)" +endif IMPORT_PATH := github.com/cloudflare/cloudflared PACKAGE_DIR := $(CURDIR)/packaging @@ -61,9 +77,9 @@ else endif ifeq ($(TARGET_OS), windows) - EXECUTABLE_PATH=./cloudflared.exe + EXECUTABLE_PATH=./$(BINARY_NAME).exe else - EXECUTABLE_PATH=./cloudflared + EXECUTABLE_PATH=./$(BINARY_NAME) endif ifeq ($(FLAVOR), centos-7) @@ -80,17 +96,15 @@ clean: go clean .PHONY: cloudflared -cloudflared: +cloudflared: ifeq ($(FIPS), true) $(info Building cloudflared with go-fips) - -test -f fips/fips.go && mv fips/fips.go fips/fips.go.linux-amd64 - mv fips/fips.go.linux-amd64 fips/fips.go + cp -f fips/fips.go.linux-amd64 cmd/cloudflared/fips.go endif - - GOOS=$(TARGET_OS) GOARCH=$(TARGET_ARCH) go build -v -mod=vendor $(GO_BUILD_TAGS) $(VERSION_FLAGS) $(IMPORT_PATH)/cmd/cloudflared - + GOOS=$(TARGET_OS) GOARCH=$(TARGET_ARCH) go build -v -mod=vendor $(GO_BUILD_TAGS) $(LDFLAGS) $(IMPORT_PATH)/cmd/cloudflared ifeq ($(FIPS), true) - mv fips/fips.go fips/fips.go.linux-amd64 + rm -f cmd/cloudflared/fips.go + ./check-fips.sh cloudflared endif .PHONY: container @@ -100,10 +114,10 @@ container: .PHONY: test test: vet ifndef CI - go test -v -mod=vendor -race $(VERSION_FLAGS) ./... + go test -v -mod=vendor -race $(LDFLAGS) ./... else @mkdir -p .cover - go test -v -mod=vendor -race $(VERSION_FLAGS) -coverprofile=".cover/c.out" ./... + go test -v -mod=vendor -race $(LDFLAGS) -coverprofile=".cover/c.out" ./... go tool cover -html ".cover/c.out" -o .cover/all.html endif @@ -112,10 +126,10 @@ test-ssh-server: docker-compose -f ssh_server_tests/docker-compose.yml up define publish_package - chmod 664 cloudflared*.$(1); \ + chmod 664 $(BINARY_NAME)*.$(1); \ for HOST in $(CF_PKG_HOSTS); do \ - ssh-keyscan -t rsa $$HOST >> ~/.ssh/known_hosts; \ - scp -p -4 cloudflared*.$(1) cfsync@$$HOST:/state/cf-pkg/staging/$(2)/$(TARGET_PUBLIC_REPO)/cloudflared/; \ + ssh-keyscan -t ecdsa $$HOST >> ~/.ssh/known_hosts; \ + scp -p -4 $(BINARY_NAME)*.$(1) cfsync@$$HOST:/state/cf-pkg/staging/$(2)/$(TARGET_PUBLIC_REPO)/$(BINARY_NAME)/; \ done endef @@ -127,12 +141,14 @@ publish-deb: cloudflared-deb publish-rpm: cloudflared-rpm $(call publish_package,rpm,yum) +# When we build packages, the package name will be FIPS-aware. +# But we keep the binary installed by it to be named "cloudflared" regardless. define build_package mkdir -p $(PACKAGE_DIR) cp cloudflared $(PACKAGE_DIR)/cloudflared cat cloudflared_man_template | sed -e 's/\$${VERSION}/$(VERSION)/; s/\$${DATE}/$(DATE)/' > $(PACKAGE_DIR)/cloudflared.1 fakeroot fpm -C $(PACKAGE_DIR) -s dir -t $(1) \ - --description 'Cloudflare Argo tunnel daemon' \ + --description 'Cloudflare Tunnel daemon' \ --vendor 'Cloudflare' \ --license 'Cloudflare Service Agreement' \ --url 'https://github.com/cloudflare/cloudflared' \ @@ -247,8 +263,8 @@ tunnelrpc-deps: capnp compile -ogo tunnelrpc/tunnelrpc.capnp .PHONY: quic-deps -quic-deps: - which capnp +quic-deps: + which capnp which capnpc-go capnp compile -ogo quic/schema/quic_metadata_protocol.capnp @@ -258,9 +274,9 @@ vet: # go get github.com/sudarshan-reddy/go-sumtype (don't do this in build directory or this will cause vendor issues) # Note: If you have github.com/BurntSushi/go-sumtype then you might have to use the repo above instead # for now because it uses an older version of golang.org/x/tools. - which go-sumtype + which go-sumtype go-sumtype $$(go list -mod=vendor ./...) .PHONY: goimports goimports: - for d in $$(go list -mod=readonly -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc) ; do goimports -format-only -local github.com/cloudflare/cloudflared -w $$d ; done + for d in $$(go list -mod=readonly -f '{{.Dir}}' -a ./... | fgrep -v tunnelrpc) ; do goimports -format-only -local github.com/cloudflare/cloudflared -w $$d ; done \ No newline at end of file diff --git a/README.md b/README.md index 17b45e93..134e7cc9 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,28 @@ -# Argo Tunnel client +# Cloudflare Tunnel client + +Contains the command-line client for Cloudflare Tunnel, a tunneling daemon that proxies traffic from the Cloudflare network to your origins. +This daemon sits between Cloudflare network and your origin (e.g. a webserver). Cloudflare attracts client requests and sends them to you +via this daemon, without requiring you to poke holes on your firewall --- your origin can remain as closed as possible. +Extensive documentation can be found in the [Cloudflare Tunnel section](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps) of the Cloudflare Docs. +All usages related with proxying to your origins are available under `cloudflared tunnel help`. + +You can also use `cloudflared` to access Tunnel origins (that are protected with `cloudflared tunnel`) for TCP traffic +at Layer 4 (i.e., not HTTP/websocket), which is relevant for use cases such as SSH, RDP, etc. +Such usages are available under `cloudflared access help`. + +You can instead use [WARP client](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/configuration/private-networks) +to access private origins behind Tunnels for Layer 4 traffic without requiring `cloudflared access` commands on the client side. -Contains the command-line client for Argo Tunnel, a tunneling daemon that proxies any local webserver through the Cloudflare network. Extensive documentation can be found in the [Argo Tunnel section](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps) of the Cloudflare Docs. ## Before you get started -Before you use Argo Tunnel, you'll need to complete a few steps in the Cloudflare dashboard. The website you add to Cloudflare will be used to route traffic to your Tunnel. - +Before you use Cloudflare Tunnel, you'll need to complete a few steps in the Cloudflare dashboard: you need to add a +website to your Cloudflare account. Note that today it is possible to use Tunnel without a website (e.g. for private +routing), but for legacy reasons this requirement is still necessary: 1. [Add a website to Cloudflare](https://support.cloudflare.com/hc/en-us/articles/201720164-Creating-a-Cloudflare-account-and-adding-a-website) 2. [Change your domain nameservers to Cloudflare](https://support.cloudflare.com/hc/en-us/articles/205195708) + ## Installing `cloudflared` Downloads are available as standalone binaries, a Docker image, and Debian, RPM, and Homebrew packages. You can also find releases here on the `cloudflared` GitHub repository. @@ -18,18 +32,23 @@ Downloads are available as standalone binaries, a Docker image, and Debian, RPM, * A Docker image of `cloudflared` is [available on DockerHub](https://hub.docker.com/r/cloudflare/cloudflared) * You can install on Windows machines with the [steps here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation#windows) -User documentation for Argo Tunnel can be found at https://developers.cloudflare.com/cloudflare-one/connections/connect-apps +User documentation for Cloudflare Tunnel can be found at https://developers.cloudflare.com/cloudflare-one/connections/connect-apps + ## Creating Tunnels and routing traffic -Once installed, you can authenticate `cloudflared` into your Cloudflare account and begin creating Tunnels that serve traffic for hostnames in your account. +Once installed, you can authenticate `cloudflared` into your Cloudflare account and begin creating Tunnels to serve traffic to your origins. * Create a Tunnel with [these instructions](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/create-tunnel) -* Route traffic to that Tunnel with [DNS records in Cloudflare](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/dns) or with a [Cloudflare Load Balancer](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/lb) +* Route traffic to that Tunnel: + * Via public [DNS records in Cloudflare](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/dns) + * Or via a public hostname guided by a [Cloudflare Load Balancer](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/routing-to-tunnel/lb) + * Or from [WARP client private traffic](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/configuration/private-networks) + ## TryCloudflare -Want to test Argo Tunnel before adding a website to Cloudflare? You can do so with TryCloudflare using the documentation [available here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/trycloudflare). +Want to test Cloudflare Tunnel before adding a website to Cloudflare? You can do so with TryCloudflare using the documentation [available here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/trycloudflare). ## Deprecated versions diff --git a/RELEASE_NOTES b/RELEASE_NOTES index 397ce952..d1387417 100644 --- a/RELEASE_NOTES +++ b/RELEASE_NOTES @@ -1,3 +1,107 @@ +2022.1.2 +- 2022-01-13 TUN-5650: Fix pynacl version to 1.4.0 and pygithub version to 1.55 so release doesn't break unexpectedly + +2022.1.1 +- 2022-01-10 TUN-5631: Build everything with go 1.17.5 +- 2022-01-06 TUN-5623: Configure quic max datagram frame size to 1350 bytes for none Windows platforms + +2022.1.0 +- 2022-01-03 TUN-5612: Add support for specifying TLS min/max version +- 2022-01-03 TUN-5612: Make tls min/max version public visible +- 2022-01-03 TUN-5551: Internally published debian artifacts are now named just cloudflared even though they are FIPS compliant +- 2022-01-04 TUN-5600: Close QUIC transports as soon as possible while respecting graceful shutdown +- 2022-01-05 TUN-5616: Never fallback transport if user chooses it on purpose +- 2022-01-05 TUN-5204: Unregister QUIC transports on disconnect +- 2022-01-04 TUN-5600: Add coverage to component tests for various transports + +2021.12.4 +- 2021-12-27 TUN-5482: Refactor tunnelstore client related packages for more coherent package +- 2021-12-27 TUN-5551: Change internally published debian package to be FIPS compliant +- 2021-12-27 TUN-5551: Show whether the binary was built for FIPS compliance + +2021.12.3 +- 2021-12-22 TUN-5584: Changes for release 2021.12.2 +- 2021-12-22 TUN-5590: QUIC datagram max user payload is 1217 bytes +- 2021-12-22 TUN-5593: Read full packet from UDP connection, even if it exceeds MTU of the transport. When packet length is greater than the MTU of the transport, we will silently drop packets (for now). +- 2021-12-23 TUN-5597: Log session ID when session is terminated by edge + +2021.12.2 +- 2021-12-20 TUN-5571: Remove redundant session manager log, it's already logged in origin/tunnel.ServeQUIC +- 2021-12-20 TUN-5570: Only log RPC server events at error level to reduce noise +- 2021-12-14 TUN-5494: Send a RPC with terminate reason to edge if the session is closed locally +- 2021-11-09 TUN-5551: Reintroduce FIPS compliance for linux amd64 now as separate binaries + +2021.12.1 +- 2021-12-16 TUN-5549: Revert "TUN-5277: Ensure cloudflared binary is FIPS compliant on linux amd64" + +2021.12.0 +- 2021-12-13 TUN-5530: Get current time from ticker +- 2021-12-15 TUN-5544: Update CHANGES.md for next release +- 2021-12-07 TUN-5519: Adjust URL for virtual_networks endpoint to match what we will publish +- 2021-12-02 TUN-5488: Close session after it's idle for a period defined by registerUdpSession RPC +- 2021-12-09 TUN-5504: Fix upload of packages to public repo +- 2021-11-30 TUN-5481: Create abstraction for Origin UDP Connection +- 2021-11-30 TUN-5422: Define RPC to unregister session +- 2021-11-26 TUN-5361: Commands for managing virtual networks +- 2021-11-29 TUN-5362: Adjust route ip commands to be aware of virtual networks +- 2021-11-23 TUN-5301: Separate datagram multiplex and session management logic from quic connection logic +- 2021-11-10 TUN-5405: Update net package to v0.0.0-20211109214657-ef0fda0de508 +- 2021-11-10 TUN-5408: Update quic package to v0.24.0 +- 2021-11-12 Fix typos +- 2021-11-13 Fix for Issue #501: Unexpected User-agent insertion when tunneling http request +- 2021-11-16 TUN-5129: Remove `-dev` suffix when computing version and Git has uncommitted changes +- 2021-11-18 TUN-5441: Fix message about available protocols +- 2021-11-12 TUN-5300: Define RPC to register UDP sessions +- 2021-11-14 TUN-5299: Send/receive QUIC datagram from edge and proxy to origin as UDP +- 2021-11-04 TUN-5387: Updated CHANGES.md for 2021.11.0 +- 2021-11-08 TUN-5368: Log connection issues with LogLevel that depends on tunnel state +- 2021-11-09 TUN-5397: Log cloudflared output when it fails to connect tunnel +- 2021-11-09 TUN-5277: Ensure cloudflared binary is FIPS compliant on linux amd64 +- 2021-11-08 TUN-5393: Content-length is no longer a control header for non-h2mux transports + +2021.11.0 +- 2021-11-03 TUN-5285: Fallback to HTTP2 immediately if connection times out with no network activity +- 2021-09-29 Add flag to 'tunnel create' subcommand to specify a base64-encoded secret + +2021.10.5 +- 2021-10-25 Update change log for release 2021.10.4 +- 2021-10-25 Revert "TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown" + +2021.10.4 +- 2021-10-21 TUN-5287: Fix misuse of wait group in TestQUICServer that caused the test to exit immediately +- 2021-10-21 TUN-5286: Upgrade crypto/ssh package to fix CVE-2020-29652 +- 2021-10-18 TUN-5262: Allow to configure max fetch size for listing queries +- 2021-10-19 TUN-5262: Improvements to `max-fetch-size` that allow to deal with large number of tunnels in account +- 2021-10-15 TUN-5261: Collect QUIC metrics about RTT, packets and bytes transfered and log events at tracing level +- 2021-10-19 TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown + +2021.10.3 +- 2021-10-14 TUN-5255: Fix potential panic if Cloudflare API fails to respond to GetTunnel(id) during delete command +- 2021-10-14 TUN-5257: Fix more cfsetup targets that were broken by recent package changes + +2021.10.2 +- 2021-10-11 TUN-5138: Switch to QUIC on auto protocol based on threshold +- 2021-10-14 TUN-5250: Add missing packages for cfsetup to succeed in github release pkgs target + +2021.10.1 +- 2021-10-12 TUN-5246: Use protocol: quic for Quick tunnels if one is not already set +- 2021-10-13 TUN-5249: Revert "TUN-5138: Switch to QUIC on auto protocol based on threshold" + +2021.10.0 +- 2021-10-11 TUN-5138: Switch to QUIC on auto protocol based on threshold +- 2021-10-07 TUN-5195: Do not set empty body if not applicable +- 2021-10-08 UN-5213: Increase MaxStreams value for QUIC transport +- 2021-09-28 TUN-5169: Release 2021.9.2 CHANGES.md +- 2021-09-28 TUN-5164: Update README and clean up references to Argo Tunnel (using Cloudflare Tunnel instead) + +2021.9.2 +- 2021-09-21 TUN-5129: Use go 1.17 and copy .git folder to docker build to compute version +- 2021-09-21 TUN-5128: Enforce maximum grace period +- 2021-09-22 TUN-5141: Make sure websocket pinger returns before streaming returns +- 2021-09-24 TUN-5142: Add asynchronous servecontrolstream for QUIC +- 2021-09-24 TUN-5142: defer close rpcconn inside unregister instead of ServeControlStream +- 2021-09-27 TUN-5160: Set request.ContentLength when this value is in request header + 2021.9.1 - 2021-09-21 TUN-5118: Quic connection now detects duplicate connections similar to http2 - 2021-09-15 Fix TryCloudflare link diff --git a/build-packages-fips.sh b/build-packages-fips.sh new file mode 100755 index 00000000..6daec235 --- /dev/null +++ b/build-packages-fips.sh @@ -0,0 +1,25 @@ +VERSION=$(git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*") +echo $VERSION + +# This controls the directory the built artifacts go into +export ARTIFACT_DIR=built_artifacts/ +mkdir -p $ARTIFACT_DIR + +arch=("amd64") +export TARGET_ARCH=$arch +export TARGET_OS=linux +export FIPS=true +# For BoringCrypto to link, we need CGO enabled. Otherwise compilation fails. +export CGO_ENABLED=1 + +make cloudflared-deb +mv cloudflared-fips\_$VERSION\_$arch.deb $ARTIFACT_DIR/cloudflared-fips-linux-$arch.deb + +# rpm packages invert the - and _ and use x86_64 instead of amd64. +RPMVERSION=$(echo $VERSION|sed -r 's/-/_/g') +RPMARCH="x86_64" +make cloudflared-rpm +mv cloudflared-fips-$RPMVERSION-1.$RPMARCH.rpm $ARTIFACT_DIR/cloudflared-fips-linux-$RPMARCH.rpm + +# finally move the linux binary as well. +mv ./cloudflared $ARTIFACT_DIR/cloudflared-fips-linux-$arch \ No newline at end of file diff --git a/build-packages.sh b/build-packages.sh index 7b4ece82..2c01bf9b 100755 --- a/build-packages.sh +++ b/build-packages.sh @@ -1,12 +1,15 @@ -VERSION=$(git describe --tags --always --dirty="-dev" --match "[0-9][0-9][0-9][0-9].*.*") +VERSION=$(git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*") echo $VERSION + +# Avoid depending on C code since we don't need it. export CGO_ENABLED=0 + # This controls the directory the built artifacts go into export ARTIFACT_DIR=built_artifacts/ mkdir -p $ARTIFACT_DIR windowsArchs=("amd64" "386") export TARGET_OS=windows -for arch in ${windowsArchs[@]}; do +for arch in ${windowsArchs[@]}; do export TARGET_ARCH=$arch make cloudflared-msi mv ./cloudflared.exe $ARTIFACT_DIR/cloudflared-windows-$arch.exe @@ -14,15 +17,14 @@ for arch in ${windowsArchs[@]}; do done -export FIPS=true -linuxArchs=("amd64" "386" "arm" "arm64") +linuxArchs=("386" "amd64" "arm" "arm64") export TARGET_OS=linux -for arch in ${linuxArchs[@]}; do +for arch in ${linuxArchs[@]}; do export TARGET_ARCH=$arch make cloudflared-deb mv cloudflared\_$VERSION\_$arch.deb $ARTIFACT_DIR/cloudflared-linux-$arch.deb - # rpm packages invert the - and _ and use x86_64 instead of amd64. + # rpm packages invert the - and _ and use x86_64 instead of amd64. RPMVERSION=$(echo $VERSION|sed -r 's/-/_/g') RPMARCH=$arch if [ $arch == "amd64" ];then @@ -37,4 +39,3 @@ for arch in ${linuxArchs[@]}; do # finally move the linux binary as well. mv ./cloudflared $ARTIFACT_DIR/cloudflared-linux-$arch done - diff --git a/carrier/carrier.go b/carrier/carrier.go index 1e423b77..152430d1 100644 --- a/carrier/carrier.go +++ b/carrier/carrier.go @@ -54,7 +54,7 @@ func (c *StdinoutStream) Write(p []byte) (int, error) { return os.Stdout.Write(p) } -// Helper to allow defering the response close with a check that the resp is not nil +// Helper to allow deferring the response close with a check that the resp is not nil func closeRespBody(resp *http.Response) { if resp != nil { _ = resp.Body.Close() diff --git a/certutil/certutil.go b/certutil/certutil.go index e3b48c62..0e90ed4b 100644 --- a/certutil/certutil.go +++ b/certutil/certutil.go @@ -66,7 +66,7 @@ func DecodeOriginCert(blocks []byte) (*OriginCert, error) { originCert.ServiceKey = ntt.ServiceKey originCert.AccountID = ntt.AccountID } else { - // Try the older format, where the zoneID and service key are seperated by + // Try the older format, where the zoneID and service key are separated by // a new line character token := string(block.Bytes) s := strings.Split(token, "\n") diff --git a/cfapi/base_client.go b/cfapi/base_client.go new file mode 100644 index 00000000..42b99316 --- /dev/null +++ b/cfapi/base_client.go @@ -0,0 +1,186 @@ +package cfapi + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/rs/zerolog" + "golang.org/x/net/http2" +) + +const ( + defaultTimeout = 15 * time.Second + jsonContentType = "application/json" +) + +var ( + ErrUnauthorized = errors.New("unauthorized") + ErrBadRequest = errors.New("incorrect request parameters") + ErrNotFound = errors.New("not found") + ErrAPINoSuccess = errors.New("API call failed") +) + +type RESTClient struct { + baseEndpoints *baseEndpoints + authToken string + userAgent string + client http.Client + log *zerolog.Logger +} + +type baseEndpoints struct { + accountLevel url.URL + zoneLevel url.URL + accountRoutes url.URL + accountVnets url.URL +} + +var _ Client = (*RESTClient)(nil) + +func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, log *zerolog.Logger) (*RESTClient, error) { + if strings.HasSuffix(baseURL, "/") { + baseURL = baseURL[:len(baseURL)-1] + } + accountLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag)) + if err != nil { + return nil, errors.Wrap(err, "failed to create account level endpoint") + } + accountRoutesEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/teamnet/routes", baseURL, accountTag)) + if err != nil { + return nil, errors.Wrap(err, "failed to create route account-level endpoint") + } + accountVnetsEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/teamnet/virtual_networks", baseURL, accountTag)) + if err != nil { + return nil, errors.Wrap(err, "failed to create virtual network account-level endpoint") + } + zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag)) + if err != nil { + return nil, errors.Wrap(err, "failed to create account level endpoint") + } + httpTransport := http.Transport{ + TLSHandshakeTimeout: defaultTimeout, + ResponseHeaderTimeout: defaultTimeout, + } + http2.ConfigureTransport(&httpTransport) + return &RESTClient{ + baseEndpoints: &baseEndpoints{ + accountLevel: *accountLevelEndpoint, + zoneLevel: *zoneLevelEndpoint, + accountRoutes: *accountRoutesEndpoint, + accountVnets: *accountVnetsEndpoint, + }, + authToken: authToken, + userAgent: userAgent, + client: http.Client{ + Transport: &httpTransport, + Timeout: defaultTimeout, + }, + log: log, + }, nil +} + +func (r *RESTClient) sendRequest(method string, url url.URL, body interface{}) (*http.Response, error) { + var bodyReader io.Reader + if body != nil { + if bodyBytes, err := json.Marshal(body); err != nil { + return nil, errors.Wrap(err, "failed to serialize json body") + } else { + bodyReader = bytes.NewBuffer(bodyBytes) + } + } + + req, err := http.NewRequest(method, url.String(), bodyReader) + if err != nil { + return nil, errors.Wrapf(err, "can't create %s request", method) + } + req.Header.Set("User-Agent", r.userAgent) + if bodyReader != nil { + req.Header.Set("Content-Type", jsonContentType) + } + req.Header.Add("X-Auth-User-Service-Key", r.authToken) + req.Header.Add("Accept", "application/json;version=1") + return r.client.Do(req) +} + +func parseResponse(reader io.Reader, data interface{}) error { + // Schema for Tunnelstore responses in the v1 API. + // Roughly, it's a wrapper around a particular result that adds failures/errors/etc + var result response + // First, parse the wrapper and check the API call succeeded + if err := json.NewDecoder(reader).Decode(&result); err != nil { + return errors.Wrap(err, "failed to decode response") + } + if err := result.checkErrors(); err != nil { + return err + } + if !result.Success { + return ErrAPINoSuccess + } + // At this point we know the API call succeeded, so, parse out the inner + // result into the datatype provided as a parameter. + if err := json.Unmarshal(result.Result, &data); err != nil { + return errors.Wrap(err, "the Cloudflare API response was an unexpected type") + } + return nil +} + +type response struct { + Success bool `json:"success,omitempty"` + Errors []apiErr `json:"errors,omitempty"` + Messages []string `json:"messages,omitempty"` + Result json.RawMessage `json:"result,omitempty"` +} + +func (r *response) checkErrors() error { + if len(r.Errors) == 0 { + return nil + } + if len(r.Errors) == 1 { + return r.Errors[0] + } + var messages string + for _, e := range r.Errors { + messages += fmt.Sprintf("%s; ", e) + } + return fmt.Errorf("API errors: %s", messages) +} + +type apiErr struct { + Code json.Number `json:"code,omitempty"` + Message string `json:"message,omitempty"` +} + +func (e apiErr) Error() string { + return fmt.Sprintf("code: %v, reason: %s", e.Code, e.Message) +} + +func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error { + if resp.Header.Get("Content-Type") == "application/json" { + var errorsResp response + if json.NewDecoder(resp.Body).Decode(&errorsResp) == nil { + if err := errorsResp.checkErrors(); err != nil { + return errors.Errorf("Failed to %s: %s", op, err) + } + } + } + + switch resp.StatusCode { + case http.StatusOK: + return nil + case http.StatusBadRequest: + return ErrBadRequest + case http.StatusUnauthorized, http.StatusForbidden: + return ErrUnauthorized + case http.StatusNotFound: + return ErrNotFound + } + return errors.Errorf("API call to %s failed with status %d: %s", op, + resp.StatusCode, http.StatusText(resp.StatusCode)) +} diff --git a/cfapi/client.go b/cfapi/client.go new file mode 100644 index 00000000..b4f17927 --- /dev/null +++ b/cfapi/client.go @@ -0,0 +1,39 @@ +package cfapi + +import ( + "github.com/google/uuid" +) + +type TunnelClient interface { + CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) + GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) + DeleteTunnel(tunnelID uuid.UUID) error + ListTunnels(filter *TunnelFilter) ([]*Tunnel, error) + ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error) + CleanupConnections(tunnelID uuid.UUID, params *CleanupParams) error +} + +type HostnameClient interface { + RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error) +} + +type IPRouteClient interface { + ListRoutes(filter *IpRouteFilter) ([]*DetailedRoute, error) + AddRoute(newRoute NewRoute) (Route, error) + DeleteRoute(params DeleteRouteParams) error + GetByIP(params GetRouteByIpParams) (DetailedRoute, error) +} + +type VnetClient interface { + CreateVirtualNetwork(newVnet NewVirtualNetwork) (VirtualNetwork, error) + ListVirtualNetworks(filter *VnetFilter) ([]*VirtualNetwork, error) + DeleteVirtualNetwork(id uuid.UUID) error + UpdateVirtualNetwork(id uuid.UUID, updates UpdateVirtualNetwork) error +} + +type Client interface { + TunnelClient + HostnameClient + IPRouteClient + VnetClient +} diff --git a/cfapi/hostname.go b/cfapi/hostname.go new file mode 100644 index 00000000..b8ca8bd4 --- /dev/null +++ b/cfapi/hostname.go @@ -0,0 +1,192 @@ +package cfapi + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "path" + + "github.com/google/uuid" + "github.com/pkg/errors" +) + +type Change = string + +const ( + ChangeNew = "new" + ChangeUpdated = "updated" + ChangeUnchanged = "unchanged" +) + +// HostnameRoute represents a record type that can route to a tunnel +type HostnameRoute interface { + json.Marshaler + RecordType() string + UnmarshalResult(body io.Reader) (HostnameRouteResult, error) + String() string +} + +type HostnameRouteResult interface { + // SuccessSummary explains what will route to this tunnel when it's provisioned successfully + SuccessSummary() string +} + +type DNSRoute struct { + userHostname string + overwriteExisting bool +} + +type DNSRouteResult struct { + route *DNSRoute + CName Change `json:"cname"` + Name string `json:"name"` +} + +func NewDNSRoute(userHostname string, overwriteExisting bool) HostnameRoute { + return &DNSRoute{ + userHostname: userHostname, + overwriteExisting: overwriteExisting, + } +} + +func (dr *DNSRoute) MarshalJSON() ([]byte, error) { + s := struct { + Type string `json:"type"` + UserHostname string `json:"user_hostname"` + OverwriteExisting bool `json:"overwrite_existing"` + }{ + Type: dr.RecordType(), + UserHostname: dr.userHostname, + OverwriteExisting: dr.overwriteExisting, + } + return json.Marshal(&s) +} + +func (dr *DNSRoute) UnmarshalResult(body io.Reader) (HostnameRouteResult, error) { + var result DNSRouteResult + err := parseResponse(body, &result) + result.route = dr + return &result, err +} + +func (dr *DNSRoute) RecordType() string { + return "dns" +} + +func (dr *DNSRoute) String() string { + return fmt.Sprintf("%s %s", dr.RecordType(), dr.userHostname) +} + +func (res *DNSRouteResult) SuccessSummary() string { + var msgFmt string + switch res.CName { + case ChangeNew: + msgFmt = "Added CNAME %s which will route to this tunnel" + case ChangeUpdated: // this is not currently returned by tunnelsore + msgFmt = "%s updated to route to your tunnel" + case ChangeUnchanged: + msgFmt = "%s is already configured to route to your tunnel" + } + return fmt.Sprintf(msgFmt, res.hostname()) +} + +// hostname yields the resulting name for the DNS route; if that is not available from Cloudflare API, then the +// requested name is returned instead (should not be the common path, it is just a fall-back). +func (res *DNSRouteResult) hostname() string { + if res.Name != "" { + return res.Name + } + return res.route.userHostname +} + +type LBRoute struct { + lbName string + lbPool string +} + +type LBRouteResult struct { + route *LBRoute + LoadBalancer Change `json:"load_balancer"` + Pool Change `json:"pool"` +} + +func NewLBRoute(lbName, lbPool string) HostnameRoute { + return &LBRoute{ + lbName: lbName, + lbPool: lbPool, + } +} + +func (lr *LBRoute) MarshalJSON() ([]byte, error) { + s := struct { + Type string `json:"type"` + LBName string `json:"lb_name"` + LBPool string `json:"lb_pool"` + }{ + Type: lr.RecordType(), + LBName: lr.lbName, + LBPool: lr.lbPool, + } + return json.Marshal(&s) +} + +func (lr *LBRoute) RecordType() string { + return "lb" +} + +func (lb *LBRoute) String() string { + return fmt.Sprintf("%s %s %s", lb.RecordType(), lb.lbName, lb.lbPool) +} + +func (lr *LBRoute) UnmarshalResult(body io.Reader) (HostnameRouteResult, error) { + var result LBRouteResult + err := parseResponse(body, &result) + result.route = lr + return &result, err +} + +func (res *LBRouteResult) SuccessSummary() string { + var msg string + switch res.LoadBalancer + "," + res.Pool { + case "new,new": + msg = "Created load balancer %s and added a new pool %s with this tunnel as an origin" + case "new,updated": + msg = "Created load balancer %s with an existing pool %s which was updated to use this tunnel as an origin" + case "new,unchanged": + msg = "Created load balancer %s with an existing pool %s which already has this tunnel as an origin" + case "updated,new": + msg = "Added new pool %[2]s with this tunnel as an origin to load balancer %[1]s" + case "updated,updated": + msg = "Updated pool %[2]s to use this tunnel as an origin and added it to load balancer %[1]s" + case "updated,unchanged": + msg = "Added pool %[2]s, which already has this tunnel as an origin, to load balancer %[1]s" + case "unchanged,updated": + msg = "Added this tunnel as an origin in pool %[2]s which is already used by load balancer %[1]s" + case "unchanged,unchanged": + msg = "Load balancer %s already uses pool %s which has this tunnel as an origin" + case "unchanged,new": + // this state is not possible + fallthrough + default: + msg = "Something went wrong: failed to modify load balancer %s with pool %s; please check traffic manager configuration in the dashboard" + } + + return fmt.Sprintf(msg, res.route.lbName, res.route.lbPool) +} + +func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error) { + endpoint := r.baseEndpoints.zoneLevel + endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID)) + resp, err := r.sendRequest("PUT", endpoint, route) + if err != nil { + return nil, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return route.UnmarshalResult(resp.Body) + } + + return nil, r.statusCodeToError("add route", resp) +} diff --git a/tunnelstore/client_test.go b/cfapi/hostname_test.go similarity index 51% rename from tunnelstore/client_test.go rename to cfapi/hostname_test.go index a3e0ec97..5100465a 100644 --- a/tunnelstore/client_test.go +++ b/cfapi/hostname_test.go @@ -1,17 +1,9 @@ -package tunnelstore +package cfapi import ( - "bytes" - "fmt" - "io" - "io/ioutil" - "net" - "reflect" "strings" "testing" - "time" - "github.com/google/uuid" "github.com/stretchr/testify/assert" ) @@ -105,125 +97,3 @@ func TestLBRouteResultSuccessSummary(t *testing.T) { assert.Equal(t, tt.expected, actual, "case %d", i+1) } } - -func Test_parseListTunnels(t *testing.T) { - type args struct { - body string - } - tests := []struct { - name string - args args - want []*Tunnel - wantErr bool - }{ - { - name: "empty list", - args: args{body: `{"success": true, "result": []}`}, - want: []*Tunnel{}, - }, - { - name: "success is false", - args: args{body: `{"success": false, "result": []}`}, - wantErr: true, - }, - { - name: "errors are present", - args: args{body: `{"errors": [{"code": 1003, "message":"An A, AAAA or CNAME record already exists with that host"}], "result": []}`}, - wantErr: true, - }, - { - name: "invalid response", - args: args{body: `abc`}, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - body := ioutil.NopCloser(bytes.NewReader([]byte(tt.args.body))) - got, err := parseListTunnels(body) - if (err != nil) != tt.wantErr { - t.Errorf("parseListTunnels() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("parseListTunnels() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_unmarshalTunnel(t *testing.T) { - type args struct { - reader io.Reader - } - tests := []struct { - name string - args args - want *Tunnel - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := unmarshalTunnel(tt.args.reader) - if (err != nil) != tt.wantErr { - t.Errorf("unmarshalTunnel() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("unmarshalTunnel() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestUnmarshalTunnelOk(t *testing.T) { - - jsonBody := `{"success": true, "result": {"id": "00000000-0000-0000-0000-000000000000","name":"test","created_at":"0001-01-01T00:00:00Z","connections":[]}}` - expected := Tunnel{ - ID: uuid.Nil, - Name: "test", - CreatedAt: time.Time{}, - Connections: []Connection{}, - } - actual, err := unmarshalTunnel(bytes.NewReader([]byte(jsonBody))) - assert.NoError(t, err) - assert.Equal(t, &expected, actual) -} - -func TestUnmarshalTunnelErr(t *testing.T) { - - tests := []string{ - `abc`, - `{"success": true, "result": abc}`, - `{"success": false, "result": {"id": "00000000-0000-0000-0000-000000000000","name":"test","created_at":"0001-01-01T00:00:00Z","connections":[]}}}`, - `{"errors": [{"code": 1003, "message":"An A, AAAA or CNAME record already exists with that host"}], "result": {"id": "00000000-0000-0000-0000-000000000000","name":"test","created_at":"0001-01-01T00:00:00Z","connections":[]}}}`, - } - - for i, test := range tests { - _, err := unmarshalTunnel(bytes.NewReader([]byte(test))) - assert.Error(t, err, fmt.Sprintf("Test #%v failed", i)) - } -} - -func TestUnmarshalConnections(t *testing.T) { - jsonBody := `{"success":true,"messages":[],"errors":[],"result":[{"id":"d4041254-91e3-4deb-bd94-b46e11680b1e","features":["ha-origin"],"version":"2021.2.5","arch":"darwin_amd64","conns":[{"colo_name":"LIS","id":"ac2286e5-c708-4588-a6a0-ba6b51940019","is_pending_reconnect":false,"origin_ip":"148.38.28.2","opened_at":"0001-01-01T00:00:00Z"}],"run_at":"0001-01-01T00:00:00Z"}]}` - expected := ActiveClient{ - ID: uuid.MustParse("d4041254-91e3-4deb-bd94-b46e11680b1e"), - Features: []string{"ha-origin"}, - Version: "2021.2.5", - Arch: "darwin_amd64", - RunAt: time.Time{}, - Connections: []Connection{{ - ID: uuid.MustParse("ac2286e5-c708-4588-a6a0-ba6b51940019"), - ColoName: "LIS", - IsPendingReconnect: false, - OriginIP: net.ParseIP("148.38.28.2"), - OpenedAt: time.Time{}, - }}, - } - actual, err := parseConnectionsDetails(bytes.NewReader([]byte(jsonBody))) - assert.NoError(t, err) - assert.Equal(t, []*ActiveClient{&expected}, actual) -} diff --git a/cfapi/ip_route.go b/cfapi/ip_route.go new file mode 100644 index 00000000..6749aa89 --- /dev/null +++ b/cfapi/ip_route.go @@ -0,0 +1,240 @@ +package cfapi + +import ( + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "path" + "time" + + "github.com/google/uuid" + "github.com/pkg/errors" +) + +// Route is a mapping from customer's IP space to a tunnel. +// Each route allows the customer to route eyeballs in their corporate network +// to certain private IP ranges. Each Route represents an IP range in their +// network, and says that eyeballs can reach that route using the corresponding +// tunnel. +type Route struct { + Network CIDR `json:"network"` + TunnelID uuid.UUID `json:"tunnel_id"` + // Optional field. When unset, it means the Route belongs to the default virtual network. + VNetID *uuid.UUID `json:"virtual_network_id,omitempty"` + Comment string `json:"comment"` + CreatedAt time.Time `json:"created_at"` + DeletedAt time.Time `json:"deleted_at"` +} + +// CIDR is just a newtype wrapper around net.IPNet. It adds JSON unmarshalling. +type CIDR net.IPNet + +func (c CIDR) String() string { + n := net.IPNet(c) + return n.String() +} + +func (c CIDR) MarshalJSON() ([]byte, error) { + str := c.String() + json, err := json.Marshal(str) + if err != nil { + return nil, errors.Wrap(err, "error serializing CIDR into JSON") + } + return json, nil +} + +// UnmarshalJSON parses a JSON string into net.IPNet +func (c *CIDR) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return errors.Wrap(err, "error parsing cidr string") + } + _, network, err := net.ParseCIDR(s) + if err != nil { + return errors.Wrap(err, "error parsing invalid network from backend") + } + if network == nil { + return fmt.Errorf("backend returned invalid network %s", s) + } + *c = CIDR(*network) + return nil +} + +// NewRoute has all the parameters necessary to add a new route to the table. +type NewRoute struct { + Network net.IPNet + TunnelID uuid.UUID + Comment string + // Optional field. If unset, backend will assume the default vnet for the account. + VNetID *uuid.UUID +} + +// MarshalJSON handles fields with non-JSON types (e.g. net.IPNet). +func (r NewRoute) MarshalJSON() ([]byte, error) { + return json.Marshal(&struct { + TunnelID uuid.UUID `json:"tunnel_id"` + Comment string `json:"comment"` + VNetID *uuid.UUID `json:"virtual_network_id,omitempty"` + }{ + TunnelID: r.TunnelID, + Comment: r.Comment, + VNetID: r.VNetID, + }) +} + +// DetailedRoute is just a Route with some extra fields, e.g. TunnelName. +type DetailedRoute struct { + Network CIDR `json:"network"` + TunnelID uuid.UUID `json:"tunnel_id"` + // Optional field. When unset, it means the DetailedRoute belongs to the default virtual network. + VNetID *uuid.UUID `json:"virtual_network_id,omitempty"` + Comment string `json:"comment"` + CreatedAt time.Time `json:"created_at"` + DeletedAt time.Time `json:"deleted_at"` + TunnelName string `json:"tunnel_name"` +} + +// IsZero checks if DetailedRoute is the zero value. +func (r *DetailedRoute) IsZero() bool { + return r.TunnelID == uuid.Nil +} + +// TableString outputs a table row summarizing the route, to be used +// when showing the user their routing table. +func (r DetailedRoute) TableString() string { + deletedColumn := "-" + if !r.DeletedAt.IsZero() { + deletedColumn = r.DeletedAt.Format(time.RFC3339) + } + vnetColumn := "default" + if r.VNetID != nil { + vnetColumn = r.VNetID.String() + } + + return fmt.Sprintf( + "%s\t%s\t%s\t%s\t%s\t%s\t%s\t", + r.Network.String(), + vnetColumn, + r.Comment, + r.TunnelID, + r.TunnelName, + r.CreatedAt.Format(time.RFC3339), + deletedColumn, + ) +} + +type DeleteRouteParams struct { + Network net.IPNet + // Optional field. If unset, backend will assume the default vnet for the account. + VNetID *uuid.UUID +} + +type GetRouteByIpParams struct { + Ip net.IP + // Optional field. If unset, backend will assume the default vnet for the account. + VNetID *uuid.UUID +} + +// ListRoutes calls the Tunnelstore GET endpoint for all routes under an account. +func (r *RESTClient) ListRoutes(filter *IpRouteFilter) ([]*DetailedRoute, error) { + endpoint := r.baseEndpoints.accountRoutes + endpoint.RawQuery = filter.Encode() + resp, err := r.sendRequest("GET", endpoint, nil) + if err != nil { + return nil, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return parseListDetailedRoutes(resp.Body) + } + + return nil, r.statusCodeToError("list routes", resp) +} + +// AddRoute calls the Tunnelstore POST endpoint for a given route. +func (r *RESTClient) AddRoute(newRoute NewRoute) (Route, error) { + endpoint := r.baseEndpoints.accountRoutes + endpoint.Path = path.Join(endpoint.Path, "network", url.PathEscape(newRoute.Network.String())) + resp, err := r.sendRequest("POST", endpoint, newRoute) + if err != nil { + return Route{}, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return parseRoute(resp.Body) + } + + return Route{}, r.statusCodeToError("add route", resp) +} + +// DeleteRoute calls the Tunnelstore DELETE endpoint for a given route. +func (r *RESTClient) DeleteRoute(params DeleteRouteParams) error { + endpoint := r.baseEndpoints.accountRoutes + endpoint.Path = path.Join(endpoint.Path, "network", url.PathEscape(params.Network.String())) + setVnetParam(&endpoint, params.VNetID) + + resp, err := r.sendRequest("DELETE", endpoint, nil) + if err != nil { + return errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + _, err := parseRoute(resp.Body) + return err + } + + return r.statusCodeToError("delete route", resp) +} + +// GetByIP checks which route will proxy a given IP. +func (r *RESTClient) GetByIP(params GetRouteByIpParams) (DetailedRoute, error) { + endpoint := r.baseEndpoints.accountRoutes + endpoint.Path = path.Join(endpoint.Path, "ip", url.PathEscape(params.Ip.String())) + setVnetParam(&endpoint, params.VNetID) + + resp, err := r.sendRequest("GET", endpoint, nil) + if err != nil { + return DetailedRoute{}, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return parseDetailedRoute(resp.Body) + } + + return DetailedRoute{}, r.statusCodeToError("get route by IP", resp) +} + +func parseListDetailedRoutes(body io.ReadCloser) ([]*DetailedRoute, error) { + var routes []*DetailedRoute + err := parseResponse(body, &routes) + return routes, err +} + +func parseRoute(body io.ReadCloser) (Route, error) { + var route Route + err := parseResponse(body, &route) + return route, err +} + +func parseDetailedRoute(body io.ReadCloser) (DetailedRoute, error) { + var route DetailedRoute + err := parseResponse(body, &route) + return route, err +} + +// setVnetParam overwrites the URL's query parameters with a query param to scope the HostnameRoute action to a certain +// virtual network (if one is provided). +func setVnetParam(endpoint *url.URL, vnetID *uuid.UUID) { + queryParams := url.Values{} + if vnetID != nil { + queryParams.Set("virtual_network_id", vnetID.String()) + } + endpoint.RawQuery = queryParams.Encode() +} diff --git a/cfapi/ip_route_filter.go b/cfapi/ip_route_filter.go new file mode 100644 index 00000000..1f0301ac --- /dev/null +++ b/cfapi/ip_route_filter.go @@ -0,0 +1,165 @@ +package cfapi + +import ( + "fmt" + "net" + "net/url" + "strconv" + "time" + + "github.com/google/uuid" + "github.com/pkg/errors" + "github.com/urfave/cli/v2" +) + +var ( + filterIpRouteDeleted = cli.BoolFlag{ + Name: "filter-is-deleted", + Usage: "If false (default), only show non-deleted routes. If true, only show deleted routes.", + } + filterIpRouteTunnelID = cli.StringFlag{ + Name: "filter-tunnel-id", + Usage: "Show only routes with the given tunnel ID.", + } + filterSubsetIpRoute = cli.StringFlag{ + Name: "filter-network-is-subset-of", + Aliases: []string{"nsub"}, + Usage: "Show only routes whose network is a subset of the given network.", + } + filterSupersetIpRoute = cli.StringFlag{ + Name: "filter-network-is-superset-of", + Aliases: []string{"nsup"}, + Usage: "Show only routes whose network is a superset of the given network.", + } + filterIpRouteComment = cli.StringFlag{ + Name: "filter-comment-is", + Usage: "Show only routes with this comment.", + } + filterIpRouteByVnet = cli.StringFlag{ + Name: "filter-virtual-network-id", + Usage: "Show only routes that are attached to the given virtual network ID.", + } + + // Flags contains all filter flags. + IpRouteFilterFlags = []cli.Flag{ + &filterIpRouteDeleted, + &filterIpRouteTunnelID, + &filterSubsetIpRoute, + &filterSupersetIpRoute, + &filterIpRouteComment, + &filterIpRouteByVnet, + } +) + +// IpRouteFilter which routes get queried. +type IpRouteFilter struct { + queryParams url.Values +} + +// NewIpRouteFilterFromCLI parses CLI flags to discover which filters should get applied. +func NewIpRouteFilterFromCLI(c *cli.Context) (*IpRouteFilter, error) { + f := &IpRouteFilter{ + queryParams: url.Values{}, + } + + // Set deletion filter + if flag := filterIpRouteDeleted.Name; c.IsSet(flag) && c.Bool(flag) { + f.deleted() + } else { + f.notDeleted() + } + + if subset, err := cidrFromFlag(c, filterSubsetIpRoute); err != nil { + return nil, err + } else if subset != nil { + f.networkIsSupersetOf(*subset) + } + + if superset, err := cidrFromFlag(c, filterSupersetIpRoute); err != nil { + return nil, err + } else if superset != nil { + f.networkIsSupersetOf(*superset) + } + + if comment := c.String(filterIpRouteComment.Name); comment != "" { + f.commentIs(comment) + } + + if tunnelID := c.String(filterIpRouteTunnelID.Name); tunnelID != "" { + u, err := uuid.Parse(tunnelID) + if err != nil { + return nil, errors.Wrapf(err, "Couldn't parse UUID from %s", filterIpRouteTunnelID.Name) + } + f.tunnelID(u) + } + + if vnetId := c.String(filterIpRouteByVnet.Name); vnetId != "" { + u, err := uuid.Parse(vnetId) + if err != nil { + return nil, errors.Wrapf(err, "Couldn't parse UUID from %s", filterIpRouteByVnet.Name) + } + f.vnetID(u) + } + + if maxFetch := c.Int("max-fetch-size"); maxFetch > 0 { + f.MaxFetchSize(uint(maxFetch)) + } + + return f, nil +} + +// Parses a CIDR from the flag. If the flag was unset, returns (nil, nil). +func cidrFromFlag(c *cli.Context, flag cli.StringFlag) (*net.IPNet, error) { + if !c.IsSet(flag.Name) { + return nil, nil + } + + _, subset, err := net.ParseCIDR(c.String(flag.Name)) + if err != nil { + return nil, err + } else if subset == nil { + return nil, fmt.Errorf("Invalid CIDR supplied for %s", flag.Name) + } + + return subset, nil +} + +func (f *IpRouteFilter) commentIs(comment string) { + f.queryParams.Set("comment", comment) +} + +func (f *IpRouteFilter) notDeleted() { + f.queryParams.Set("is_deleted", "false") +} + +func (f *IpRouteFilter) deleted() { + f.queryParams.Set("is_deleted", "true") +} + +func (f *IpRouteFilter) networkIsSubsetOf(superset net.IPNet) { + f.queryParams.Set("network_subset", superset.String()) +} + +func (f *IpRouteFilter) networkIsSupersetOf(subset net.IPNet) { + f.queryParams.Set("network_superset", subset.String()) +} + +func (f *IpRouteFilter) existedAt(existedAt time.Time) { + f.queryParams.Set("existed_at", existedAt.Format(time.RFC3339)) +} + +func (f *IpRouteFilter) tunnelID(id uuid.UUID) { + f.queryParams.Set("tunnel_id", id.String()) +} + +func (f *IpRouteFilter) vnetID(id uuid.UUID) { + f.queryParams.Set("virtual_network_id", id.String()) +} + +func (f *IpRouteFilter) MaxFetchSize(max uint) { + f.queryParams.Set("per_page", strconv.Itoa(int(max))) +} + +func (f IpRouteFilter) Encode() string { + return f.queryParams.Encode() +} diff --git a/cfapi/ip_route_test.go b/cfapi/ip_route_test.go new file mode 100644 index 00000000..93fa45ea --- /dev/null +++ b/cfapi/ip_route_test.go @@ -0,0 +1,175 @@ +package cfapi + +import ( + "encoding/json" + "fmt" + "net" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestUnmarshalRoute(t *testing.T) { + testCases := []struct { + Json string + HasVnet bool + }{ + { + `{ + "network":"10.1.2.40/29", + "tunnel_id":"fba6ffea-807f-4e7a-a740-4184ee1b82c8", + "comment":"test", + "created_at":"2020-12-22T02:00:15.587008Z", + "deleted_at":null + }`, + false, + }, + { + `{ + "network":"10.1.2.40/29", + "tunnel_id":"fba6ffea-807f-4e7a-a740-4184ee1b82c8", + "comment":"test", + "created_at":"2020-12-22T02:00:15.587008Z", + "deleted_at":null, + "virtual_network_id":"38c95083-8191-4110-8339-3f438d44fdb9" + }`, + true, + }, + } + + for _, testCase := range testCases { + data := testCase.Json + + var r Route + err := json.Unmarshal([]byte(data), &r) + + // Check everything worked + require.NoError(t, err) + require.Equal(t, uuid.MustParse("fba6ffea-807f-4e7a-a740-4184ee1b82c8"), r.TunnelID) + require.Equal(t, "test", r.Comment) + _, cidr, err := net.ParseCIDR("10.1.2.40/29") + require.NoError(t, err) + require.Equal(t, CIDR(*cidr), r.Network) + require.Equal(t, "test", r.Comment) + + if testCase.HasVnet { + require.Equal(t, uuid.MustParse("38c95083-8191-4110-8339-3f438d44fdb9"), *r.VNetID) + } else { + require.Nil(t, r.VNetID) + } + } +} + +func TestDetailedRouteJsonRoundtrip(t *testing.T) { + testCases := []struct { + Json string + HasVnet bool + }{ + { + `{ + "network":"10.1.2.40/29", + "tunnel_id":"fba6ffea-807f-4e7a-a740-4184ee1b82c8", + "comment":"test", + "created_at":"2020-12-22T02:00:15.587008Z", + "deleted_at":"2021-01-14T05:01:42.183002Z", + "tunnel_name":"Mr. Tun" + }`, + false, + }, + { + `{ + "network":"10.1.2.40/29", + "tunnel_id":"fba6ffea-807f-4e7a-a740-4184ee1b82c8", + "virtual_network_id":"38c95083-8191-4110-8339-3f438d44fdb9", + "comment":"test", + "created_at":"2020-12-22T02:00:15.587008Z", + "deleted_at":"2021-01-14T05:01:42.183002Z", + "tunnel_name":"Mr. Tun" + }`, + true, + }, + } + + for _, testCase := range testCases { + data := testCase.Json + + var r DetailedRoute + err := json.Unmarshal([]byte(data), &r) + + // Check everything worked + require.NoError(t, err) + require.Equal(t, uuid.MustParse("fba6ffea-807f-4e7a-a740-4184ee1b82c8"), r.TunnelID) + require.Equal(t, "test", r.Comment) + _, cidr, err := net.ParseCIDR("10.1.2.40/29") + require.NoError(t, err) + require.Equal(t, CIDR(*cidr), r.Network) + require.Equal(t, "test", r.Comment) + require.Equal(t, "Mr. Tun", r.TunnelName) + + if testCase.HasVnet { + require.Equal(t, uuid.MustParse("38c95083-8191-4110-8339-3f438d44fdb9"), *r.VNetID) + } else { + require.Nil(t, r.VNetID) + } + + bytes, err := json.Marshal(r) + require.NoError(t, err) + obtainedJson := string(bytes) + data = strings.Replace(data, "\t", "", -1) + data = strings.Replace(data, "\n", "", -1) + require.Equal(t, data, obtainedJson) + } +} + +func TestMarshalNewRoute(t *testing.T) { + _, network, err := net.ParseCIDR("1.2.3.4/32") + require.NoError(t, err) + require.NotNil(t, network) + vnetId := uuid.New() + + newRoutes := []NewRoute{ + { + Network: *network, + TunnelID: uuid.New(), + Comment: "hi", + }, + { + Network: *network, + TunnelID: uuid.New(), + Comment: "hi", + VNetID: &vnetId, + }, + } + + for _, newRoute := range newRoutes { + // Test where receiver is struct + serialized, err := json.Marshal(newRoute) + require.NoError(t, err) + require.True(t, strings.Contains(string(serialized), "tunnel_id")) + + // Test where receiver is pointer to struct + serialized, err = json.Marshal(&newRoute) + require.NoError(t, err) + require.True(t, strings.Contains(string(serialized), "tunnel_id")) + + if newRoute.VNetID == nil { + require.False(t, strings.Contains(string(serialized), "virtual_network_id")) + } else { + require.True(t, strings.Contains(string(serialized), "virtual_network_id")) + } + } +} + +func TestRouteTableString(t *testing.T) { + _, network, err := net.ParseCIDR("1.2.3.4/32") + require.NoError(t, err) + require.NotNil(t, network) + r := DetailedRoute{ + Network: CIDR(*network), + } + row := r.TableString() + fmt.Println(row) + require.True(t, strings.HasPrefix(row, "1.2.3.4/32")) +} diff --git a/cfapi/tunnel.go b/cfapi/tunnel.go new file mode 100644 index 00000000..5d4ea298 --- /dev/null +++ b/cfapi/tunnel.go @@ -0,0 +1,183 @@ +package cfapi + +import ( + "fmt" + "io" + "net" + "net/http" + "net/url" + "path" + "time" + + "github.com/google/uuid" + "github.com/pkg/errors" +) + +var ErrTunnelNameConflict = errors.New("tunnel with name already exists") + +type Tunnel struct { + ID uuid.UUID `json:"id"` + Name string `json:"name"` + CreatedAt time.Time `json:"created_at"` + DeletedAt time.Time `json:"deleted_at"` + Connections []Connection `json:"connections"` +} + +type Connection struct { + ColoName string `json:"colo_name"` + ID uuid.UUID `json:"id"` + IsPendingReconnect bool `json:"is_pending_reconnect"` + OriginIP net.IP `json:"origin_ip"` + OpenedAt time.Time `json:"opened_at"` +} + +type ActiveClient struct { + ID uuid.UUID `json:"id"` + Features []string `json:"features"` + Version string `json:"version"` + Arch string `json:"arch"` + RunAt time.Time `json:"run_at"` + Connections []Connection `json:"conns"` +} + +type newTunnel struct { + Name string `json:"name"` + TunnelSecret []byte `json:"tunnel_secret"` +} + +type CleanupParams struct { + queryParams url.Values +} + +func NewCleanupParams() *CleanupParams { + return &CleanupParams{ + queryParams: url.Values{}, + } +} + +func (cp *CleanupParams) ForClient(clientID uuid.UUID) { + cp.queryParams.Set("client_id", clientID.String()) +} + +func (cp CleanupParams) encode() string { + return cp.queryParams.Encode() +} + +func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) { + if name == "" { + return nil, errors.New("tunnel name required") + } + if _, err := uuid.Parse(name); err == nil { + return nil, errors.New("you cannot use UUIDs as tunnel names") + } + body := &newTunnel{ + Name: name, + TunnelSecret: tunnelSecret, + } + + resp, err := r.sendRequest("POST", r.baseEndpoints.accountLevel, body) + if err != nil { + return nil, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + switch resp.StatusCode { + case http.StatusOK: + return unmarshalTunnel(resp.Body) + case http.StatusConflict: + return nil, ErrTunnelNameConflict + } + + return nil, r.statusCodeToError("create tunnel", resp) +} + +func (r *RESTClient) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) { + endpoint := r.baseEndpoints.accountLevel + endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID)) + resp, err := r.sendRequest("GET", endpoint, nil) + if err != nil { + return nil, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return unmarshalTunnel(resp.Body) + } + + return nil, r.statusCodeToError("get tunnel", resp) +} + +func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error { + endpoint := r.baseEndpoints.accountLevel + endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID)) + resp, err := r.sendRequest("DELETE", endpoint, nil) + if err != nil { + return errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + return r.statusCodeToError("delete tunnel", resp) +} + +func (r *RESTClient) ListTunnels(filter *TunnelFilter) ([]*Tunnel, error) { + endpoint := r.baseEndpoints.accountLevel + endpoint.RawQuery = filter.encode() + resp, err := r.sendRequest("GET", endpoint, nil) + if err != nil { + return nil, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return parseListTunnels(resp.Body) + } + + return nil, r.statusCodeToError("list tunnels", resp) +} + +func parseListTunnels(body io.ReadCloser) ([]*Tunnel, error) { + var tunnels []*Tunnel + err := parseResponse(body, &tunnels) + return tunnels, err +} + +func (r *RESTClient) ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error) { + endpoint := r.baseEndpoints.accountLevel + endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/connections", tunnelID)) + resp, err := r.sendRequest("GET", endpoint, nil) + if err != nil { + return nil, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return parseConnectionsDetails(resp.Body) + } + + return nil, r.statusCodeToError("list connection details", resp) +} + +func parseConnectionsDetails(reader io.Reader) ([]*ActiveClient, error) { + var clients []*ActiveClient + err := parseResponse(reader, &clients) + return clients, err +} + +func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID, params *CleanupParams) error { + endpoint := r.baseEndpoints.accountLevel + endpoint.RawQuery = params.encode() + endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/connections", tunnelID)) + resp, err := r.sendRequest("DELETE", endpoint, nil) + if err != nil { + return errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + return r.statusCodeToError("cleanup connections", resp) +} + +func unmarshalTunnel(reader io.Reader) (*Tunnel, error) { + var tunnel Tunnel + err := parseResponse(reader, &tunnel) + return &tunnel, err +} diff --git a/cfapi/tunnel_filter.go b/cfapi/tunnel_filter.go new file mode 100644 index 00000000..df8932bc --- /dev/null +++ b/cfapi/tunnel_filter.go @@ -0,0 +1,55 @@ +package cfapi + +import ( + "net/url" + "strconv" + "time" + + "github.com/google/uuid" +) + +const ( + TimeLayout = time.RFC3339 +) + +type TunnelFilter struct { + queryParams url.Values +} + +func NewTunnelFilter() *TunnelFilter { + return &TunnelFilter{ + queryParams: url.Values{}, + } +} + +func (f *TunnelFilter) ByName(name string) { + f.queryParams.Set("name", name) +} + +func (f *TunnelFilter) ByNamePrefix(namePrefix string) { + f.queryParams.Set("name_prefix", namePrefix) +} + +func (f *TunnelFilter) ExcludeNameWithPrefix(excludePrefix string) { + f.queryParams.Set("exclude_prefix", excludePrefix) +} + +func (f *TunnelFilter) NoDeleted() { + f.queryParams.Set("is_deleted", "false") +} + +func (f *TunnelFilter) ByExistedAt(existedAt time.Time) { + f.queryParams.Set("existed_at", existedAt.Format(TimeLayout)) +} + +func (f *TunnelFilter) ByTunnelID(tunnelID uuid.UUID) { + f.queryParams.Set("uuid", tunnelID.String()) +} + +func (f *TunnelFilter) MaxFetchSize(max uint) { + f.queryParams.Set("per_page", strconv.Itoa(int(max))) +} + +func (f TunnelFilter) encode() string { + return f.queryParams.Encode() +} diff --git a/cfapi/tunnel_test.go b/cfapi/tunnel_test.go new file mode 100644 index 00000000..fbb5fba8 --- /dev/null +++ b/cfapi/tunnel_test.go @@ -0,0 +1,149 @@ +package cfapi + +import ( + "bytes" + "fmt" + "io/ioutil" + "net" + "reflect" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +var loc, _ = time.LoadLocation("UTC") + +func Test_parseListTunnels(t *testing.T) { + type args struct { + body string + } + tests := []struct { + name string + args args + want []*Tunnel + wantErr bool + }{ + { + name: "empty list", + args: args{body: `{"success": true, "result": []}`}, + want: []*Tunnel{}, + }, + { + name: "success is false", + args: args{body: `{"success": false, "result": []}`}, + wantErr: true, + }, + { + name: "errors are present", + args: args{body: `{"errors": [{"code": 1003, "message":"An A, AAAA or CNAME record already exists with that host"}], "result": []}`}, + wantErr: true, + }, + { + name: "invalid response", + args: args{body: `abc`}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body := ioutil.NopCloser(bytes.NewReader([]byte(tt.args.body))) + got, err := parseListTunnels(body) + if (err != nil) != tt.wantErr { + t.Errorf("parseListTunnels() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseListTunnels() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_unmarshalTunnel(t *testing.T) { + type args struct { + body string + } + tests := []struct { + name string + args args + want *Tunnel + wantErr bool + }{ + { + name: "empty list", + args: args{body: `{"success": true, "result": {"id":"b34cc7ce-925b-46ee-bc23-4cb5c18d8292","created_at":"2021-07-29T13:46:14.090955Z","deleted_at":"2021-07-29T14:07:27.559047Z","name":"qt-bIWWN7D662ogh61pCPfu5s2XgqFY1OyV","account_id":6946212,"account_tag":"5ab4e9dfbd435d24068829fda0077963","conns_active_at":null,"conns_inactive_at":"2021-07-29T13:47:22.548482Z","tun_type":"cfd_tunnel","metadata":{"qtid":"a6fJROgkXutNruBGaJjD"}}}`}, + want: &Tunnel{ + ID: uuid.MustParse("b34cc7ce-925b-46ee-bc23-4cb5c18d8292"), + Name: "qt-bIWWN7D662ogh61pCPfu5s2XgqFY1OyV", + CreatedAt: time.Date(2021, 07, 29, 13, 46, 14, 90955000, loc), + DeletedAt: time.Date(2021, 07, 29, 14, 7, 27, 559047000, loc), + Connections: nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := unmarshalTunnel(strings.NewReader(tt.args.body)) + if (err != nil) != tt.wantErr { + t.Errorf("unmarshalTunnel() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("unmarshalTunnel() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnmarshalTunnelOk(t *testing.T) { + + jsonBody := `{"success": true, "result": {"id": "00000000-0000-0000-0000-000000000000","name":"test","created_at":"0001-01-01T00:00:00Z","connections":[]}}` + expected := Tunnel{ + ID: uuid.Nil, + Name: "test", + CreatedAt: time.Time{}, + Connections: []Connection{}, + } + actual, err := unmarshalTunnel(bytes.NewReader([]byte(jsonBody))) + assert.NoError(t, err) + assert.Equal(t, &expected, actual) +} + +func TestUnmarshalTunnelErr(t *testing.T) { + + tests := []string{ + `abc`, + `{"success": true, "result": abc}`, + `{"success": false, "result": {"id": "00000000-0000-0000-0000-000000000000","name":"test","created_at":"0001-01-01T00:00:00Z","connections":[]}}}`, + `{"errors": [{"code": 1003, "message":"An A, AAAA or CNAME record already exists with that host"}], "result": {"id": "00000000-0000-0000-0000-000000000000","name":"test","created_at":"0001-01-01T00:00:00Z","connections":[]}}}`, + } + + for i, test := range tests { + _, err := unmarshalTunnel(bytes.NewReader([]byte(test))) + assert.Error(t, err, fmt.Sprintf("Test #%v failed", i)) + } +} + +func TestUnmarshalConnections(t *testing.T) { + jsonBody := `{"success":true,"messages":[],"errors":[],"result":[{"id":"d4041254-91e3-4deb-bd94-b46e11680b1e","features":["ha-origin"],"version":"2021.2.5","arch":"darwin_amd64","conns":[{"colo_name":"LIS","id":"ac2286e5-c708-4588-a6a0-ba6b51940019","is_pending_reconnect":false,"origin_ip":"148.38.28.2","opened_at":"0001-01-01T00:00:00Z"}],"run_at":"0001-01-01T00:00:00Z"}]}` + expected := ActiveClient{ + ID: uuid.MustParse("d4041254-91e3-4deb-bd94-b46e11680b1e"), + Features: []string{"ha-origin"}, + Version: "2021.2.5", + Arch: "darwin_amd64", + RunAt: time.Time{}, + Connections: []Connection{{ + ID: uuid.MustParse("ac2286e5-c708-4588-a6a0-ba6b51940019"), + ColoName: "LIS", + IsPendingReconnect: false, + OriginIP: net.ParseIP("148.38.28.2"), + OpenedAt: time.Time{}, + }}, + } + actual, err := parseConnectionsDetails(bytes.NewReader([]byte(jsonBody))) + assert.NoError(t, err) + assert.Equal(t, []*ActiveClient{&expected}, actual) +} diff --git a/cfapi/virtual_network.go b/cfapi/virtual_network.go new file mode 100644 index 00000000..e346678a --- /dev/null +++ b/cfapi/virtual_network.go @@ -0,0 +1,127 @@ +package cfapi + +import ( + "fmt" + "io" + "net/http" + "net/url" + "path" + "strconv" + "time" + + "github.com/google/uuid" + "github.com/pkg/errors" +) + +type NewVirtualNetwork struct { + Name string `json:"name"` + Comment string `json:"comment"` + IsDefault bool `json:"is_default"` +} + +type VirtualNetwork struct { + ID uuid.UUID `json:"id"` + Comment string `json:"comment"` + Name string `json:"name"` + IsDefault bool `json:"is_default_network"` + CreatedAt time.Time `json:"created_at"` + DeletedAt time.Time `json:"deleted_at"` +} + +type UpdateVirtualNetwork struct { + Name *string `json:"name,omitempty"` + Comment *string `json:"comment,omitempty"` + IsDefault *bool `json:"is_default_network,omitempty"` +} + +func (virtualNetwork VirtualNetwork) TableString() string { + deletedColumn := "-" + if !virtualNetwork.DeletedAt.IsZero() { + deletedColumn = virtualNetwork.DeletedAt.Format(time.RFC3339) + } + return fmt.Sprintf( + "%s\t%s\t%s\t%s\t%s\t%s\t", + virtualNetwork.ID, + virtualNetwork.Name, + strconv.FormatBool(virtualNetwork.IsDefault), + virtualNetwork.Comment, + virtualNetwork.CreatedAt.Format(time.RFC3339), + deletedColumn, + ) +} + +func (r *RESTClient) CreateVirtualNetwork(newVnet NewVirtualNetwork) (VirtualNetwork, error) { + resp, err := r.sendRequest("POST", r.baseEndpoints.accountVnets, newVnet) + if err != nil { + return VirtualNetwork{}, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return parseVnet(resp.Body) + } + + return VirtualNetwork{}, r.statusCodeToError("add virtual network", resp) +} + +func (r *RESTClient) ListVirtualNetworks(filter *VnetFilter) ([]*VirtualNetwork, error) { + endpoint := r.baseEndpoints.accountVnets + endpoint.RawQuery = filter.Encode() + resp, err := r.sendRequest("GET", endpoint, nil) + if err != nil { + return nil, errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return parseListVnets(resp.Body) + } + + return nil, r.statusCodeToError("list virtual networks", resp) +} + +func (r *RESTClient) DeleteVirtualNetwork(id uuid.UUID) error { + endpoint := r.baseEndpoints.accountVnets + endpoint.Path = path.Join(endpoint.Path, url.PathEscape(id.String())) + resp, err := r.sendRequest("DELETE", endpoint, nil) + if err != nil { + return errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + _, err := parseVnet(resp.Body) + return err + } + + return r.statusCodeToError("delete virtual network", resp) +} + +func (r *RESTClient) UpdateVirtualNetwork(id uuid.UUID, updates UpdateVirtualNetwork) error { + endpoint := r.baseEndpoints.accountVnets + endpoint.Path = path.Join(endpoint.Path, url.PathEscape(id.String())) + resp, err := r.sendRequest("PATCH", endpoint, updates) + if err != nil { + return errors.Wrap(err, "REST request failed") + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + _, err := parseVnet(resp.Body) + return err + } + + return r.statusCodeToError("update virtual network", resp) +} + +func parseListVnets(body io.ReadCloser) ([]*VirtualNetwork, error) { + var vnets []*VirtualNetwork + err := parseResponse(body, &vnets) + return vnets, err +} + +func parseVnet(body io.ReadCloser) (VirtualNetwork, error) { + var vnet VirtualNetwork + err := parseResponse(body, &vnet) + return vnet, err +} diff --git a/cfapi/virtual_network_filter.go b/cfapi/virtual_network_filter.go new file mode 100644 index 00000000..ddba442d --- /dev/null +++ b/cfapi/virtual_network_filter.go @@ -0,0 +1,99 @@ +package cfapi + +import ( + "net/url" + "strconv" + + "github.com/google/uuid" + "github.com/pkg/errors" + "github.com/urfave/cli/v2" +) + +var ( + filterVnetId = cli.StringFlag{ + Name: "id", + Usage: "List virtual networks with the given `ID`", + } + filterVnetByName = cli.StringFlag{ + Name: "name", + Usage: "List virtual networks with the given `NAME`", + } + filterDefaultVnet = cli.BoolFlag{ + Name: "is-default", + Usage: "If true, lists the virtual network that is the default one. If false, lists all non-default virtual networks for the account. If absent, all are included in the results regardless of their default status.", + } + filterDeletedVnet = cli.BoolFlag{ + Name: "show-deleted", + Usage: "If false (default), only show non-deleted virtual networks. If true, only show deleted virtual networks.", + } + VnetFilterFlags = []cli.Flag{ + &filterVnetId, + &filterVnetByName, + &filterDefaultVnet, + &filterDeletedVnet, + } +) + +// VnetFilter which virtual networks get queried. +type VnetFilter struct { + queryParams url.Values +} + +func NewVnetFilter() *VnetFilter { + return &VnetFilter{ + queryParams: url.Values{}, + } +} + +func (f *VnetFilter) ById(vnetId uuid.UUID) { + f.queryParams.Set("id", vnetId.String()) +} + +func (f *VnetFilter) ByName(name string) { + f.queryParams.Set("name", name) +} + +func (f *VnetFilter) ByDefaultStatus(isDefault bool) { + f.queryParams.Set("is_default", strconv.FormatBool(isDefault)) +} + +func (f *VnetFilter) WithDeleted(isDeleted bool) { + f.queryParams.Set("is_deleted", strconv.FormatBool(isDeleted)) +} + +func (f *VnetFilter) MaxFetchSize(max uint) { + f.queryParams.Set("per_page", strconv.Itoa(int(max))) +} + +func (f VnetFilter) Encode() string { + return f.queryParams.Encode() +} + +// NewFromCLI parses CLI flags to discover which filters should get applied to list virtual networks. +func NewFromCLI(c *cli.Context) (*VnetFilter, error) { + f := NewVnetFilter() + + if id := c.String("id"); id != "" { + vnetId, err := uuid.Parse(id) + if err != nil { + return nil, errors.Wrapf(err, "%s is not a valid virtual network ID", id) + } + f.ById(vnetId) + } + + if name := c.String("name"); name != "" { + f.ByName(name) + } + + if c.IsSet("is-default") { + f.ByDefaultStatus(c.Bool("is-default")) + } + + f.WithDeleted(c.Bool("show-deleted")) + + if maxFetch := c.Int("max-fetch-size"); maxFetch > 0 { + f.MaxFetchSize(uint(maxFetch)) + } + + return f, nil +} diff --git a/cfapi/virtual_network_test.go b/cfapi/virtual_network_test.go new file mode 100644 index 00000000..528922c2 --- /dev/null +++ b/cfapi/virtual_network_test.go @@ -0,0 +1,79 @@ +package cfapi + +import ( + "encoding/json" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestVirtualNetworkJsonRoundtrip(t *testing.T) { + data := `{ + "id":"74fce949-351b-4752-b261-81a56cfd3130", + "comment":"New York DC1", + "name":"us-east-1", + "is_default_network":true, + "created_at":"2021-11-26T14:40:02.600673Z", + "deleted_at":"2021-12-01T10:23:13.102645Z" + }` + var v VirtualNetwork + err := json.Unmarshal([]byte(data), &v) + + require.NoError(t, err) + require.Equal(t, uuid.MustParse("74fce949-351b-4752-b261-81a56cfd3130"), v.ID) + require.Equal(t, "us-east-1", v.Name) + require.Equal(t, "New York DC1", v.Comment) + require.Equal(t, true, v.IsDefault) + + bytes, err := json.Marshal(v) + require.NoError(t, err) + obtainedJson := string(bytes) + data = strings.Replace(data, "\t", "", -1) + data = strings.Replace(data, "\n", "", -1) + require.Equal(t, data, obtainedJson) +} + +func TestMarshalNewVnet(t *testing.T) { + newVnet := NewVirtualNetwork{ + Name: "eu-west-1", + Comment: "London office", + IsDefault: true, + } + + serialized, err := json.Marshal(newVnet) + require.NoError(t, err) + require.True(t, strings.Contains(string(serialized), newVnet.Name)) +} + +func TestMarshalUpdateVnet(t *testing.T) { + newName := "bulgaria-1" + updates := UpdateVirtualNetwork{ + Name: &newName, + } + + // Test where receiver is struct + serialized, err := json.Marshal(updates) + require.NoError(t, err) + require.True(t, strings.Contains(string(serialized), newName)) +} + +func TestVnetTableString(t *testing.T) { + virtualNet := VirtualNetwork{ + ID: uuid.New(), + Name: "us-east-1", + Comment: "New York DC1", + IsDefault: true, + CreatedAt: time.Now(), + DeletedAt: time.Time{}, + } + + row := virtualNet.TableString() + require.True(t, strings.HasPrefix(row, virtualNet.ID.String())) + require.True(t, strings.Contains(row, virtualNet.Name)) + require.True(t, strings.Contains(row, virtualNet.Comment)) + require.True(t, strings.Contains(row, "true")) + require.True(t, strings.HasSuffix(row, "-\t")) +} diff --git a/cfsetup.yaml b/cfsetup.yaml index 1ebb1f64..39aa09a5 100644 --- a/cfsetup.yaml +++ b/cfsetup.yaml @@ -1,20 +1,10 @@ -pinned_go: &pinned_go go=1.17-1 -pinned_go_fips: &pinned_go_fips go-boring=1.16.6-6 +pinned_go: &pinned_go go=1.17.5-1 +pinned_go_fips: &pinned_go_fips go-boring=1.17.5-1 build_dir: &build_dir /cfsetup_build default-flavor: buster stretch: &stretch build: - build_dir: *build_dir - builddeps: - - *pinned_go_fips - - build-essential - post-cache: - - export GOOS=linux - - export GOARCH=amd64 - - export FIPS=true - - make cloudflared - build-non-fips: # helpful to catch problems with non-fips (only used for releasing non-linux artifacts) before releases build_dir: *build_dir builddeps: - *pinned_go @@ -23,27 +13,45 @@ stretch: &stretch - export GOOS=linux - export GOARCH=amd64 - make cloudflared - build-all-packages: #except osxpkg + build-fips: build_dir: *build_dir builddeps: - *pinned_go_fips - build-essential - - fakeroot - - rubygem-fpm - - rpm - - wget - # libmsi and libgcab are libraries the wixl binary depends on. - - libmsi-dev - - libgcab-dev - pre-cache: - # TODO: https://jira.cfops.it/browse/TUN-4792 Replace this wixl with the official one once msitools supports - # environment. - - wget https://github.com/sudarshan-reddy/msitools/releases/download/v0.101b/wixl -P /usr/local/bin - - chmod a+x /usr/local/bin/wixl post-cache: + - export GOOS=linux + - export GOARCH=amd64 - export FIPS=true - - ./build-packages.sh + - make cloudflared + # except FIPS (handled in github-fips-release-pkgs) and macos (handled in github-release-macos-amd64) github-release-pkgs: + build_dir: *build_dir + builddeps: + - *pinned_go + - build-essential + - fakeroot + - rubygem-fpm + - rpm + - wget + # libmsi and libgcab are libraries the wixl binary depends on. + - libmsi-dev + - libgcab-dev + - python3-dev + - libffi-dev + - python3-setuptools + - python3-pip + pre-cache: &github_release_pkgs_pre_cache + - wget https://github.com/sudarshan-reddy/msitools/releases/download/v0.101b/wixl -P /usr/local/bin + - chmod a+x /usr/local/bin/wixl + - pip3 install pynacl==1.4.0 + - pip3 install pygithub==1.55 + post-cache: + # build all packages (except macos and FIPS) and move them to /cfsetup/built_artifacts + - ./build-packages.sh + # release the packages built and moved to /cfsetup/built_artifacts + - make github-release-built-pkgs + # handle FIPS separately so that we built with gofips compiler + github-fips-release-pkgs: build_dir: *build_dir builddeps: - *pinned_go_fips @@ -55,20 +63,29 @@ stretch: &stretch # libmsi and libgcab are libraries the wixl binary depends on. - libmsi-dev - libgcab-dev + - python3-dev + - libffi-dev - python3-setuptools - python3-pip - pre-cache: - - wget https://github.com/sudarshan-reddy/msitools/releases/download/v0.101b/wixl -P /usr/local/bin - - chmod a+x /usr/local/bin/wixl - - pip3 install pygithub + pre-cache: *github_release_pkgs_pre_cache post-cache: - # build all packages and move them to /cfsetup/built_artifacts - - ./build-packages.sh - # release the packages built and moved to /cfsetup/built_artifacts + # same logic as above, but for FIPS packages only + - ./build-packages-fips.sh - make github-release-built-pkgs build-deb: build_dir: *build_dir builddeps: &build_deb_deps + - *pinned_go + - build-essential + - fakeroot + - rubygem-fpm + post-cache: + - export GOOS=linux + - export GOARCH=amd64 + - make cloudflared-deb + build-fips-internal-deb: + build_dir: *build_dir + builddeps: &build_fips_deb_deps - *pinned_go_fips - build-essential - fakeroot @@ -77,15 +94,17 @@ stretch: &stretch - export GOOS=linux - export GOARCH=amd64 - export FIPS=true + - export ORIGINAL_NAME=true - make cloudflared-deb - build-deb-nightly: + build-fips-internal-deb-nightly: build_dir: *build_dir - builddeps: *build_deb_deps + builddeps: *build_fips_deb_deps post-cache: - export GOOS=linux - export GOARCH=amd64 - - export FIPS=true - export NIGHTLY=true + - export FIPS=true + - export ORIGINAL_NAME=true - make cloudflared-deb build-deb-arm64: build_dir: *build_dir @@ -97,7 +116,7 @@ stretch: &stretch publish-deb: build_dir: *build_dir builddeps: - - *pinned_go_fips + - *pinned_go - build-essential - fakeroot - rubygem-fpm @@ -105,27 +124,43 @@ stretch: &stretch post-cache: - export GOOS=linux - export GOARCH=amd64 - - export FIPS=true - make publish-deb github-release-macos-amd64: build_dir: *build_dir builddeps: - *pinned_go + - build-essential + - python3-dev + - libffi-dev - python3-setuptools - python3-pip pre-cache: &install_pygithub - - pip3 install pygithub + - pip3 install pynacl==1.4.0 + - pip3 install pygithub==1.55 post-cache: - make github-mac-upload test: + build_dir: *build_dir + builddeps: + - *pinned_go + - build-essential + - gotest-to-teamcity + pre-cache: &test_pre_cache + - go get golang.org/x/tools/cmd/goimports + - go get github.com/sudarshan-reddy/go-sumtype@v0.0.0-20210827105221-82eca7e5abb1 + post-cache: + - export GOOS=linux + - export GOARCH=amd64 + - export PATH="$HOME/go/bin:$PATH" + - ./fmt-check.sh + - make test | gotest-to-teamcity + test-fips: build_dir: *build_dir builddeps: - *pinned_go_fips - build-essential - gotest-to-teamcity - pre-cache: - - go get golang.org/x/tools/cmd/goimports - - go get github.com/sudarshan-reddy/go-sumtype@v0.0.0-20210827105221-82eca7e5abb1 + pre-cache: *test_pre_cache post-cache: - export GOOS=linux - export GOARCH=amd64 @@ -150,7 +185,7 @@ stretch: &stretch post-cache: # Creates and routes a Named Tunnel for this build. Also constructs config file from env vars. - python3 component-tests/setup.py --type create - - pytest component-tests + - pytest component-tests -o log_cli=true --log-cli-level=INFO # The Named Tunnel is deleted and its route unprovisioned here. - python3 component-tests/setup.py --type cleanup update-homebrew: @@ -163,6 +198,9 @@ stretch: &stretch build_dir: *build_dir builddeps: - *pinned_go + - build-essential + - python3-dev + - libffi-dev - python3-setuptools - python3-pip pre-cache: *install_pygithub @@ -208,23 +246,10 @@ centos-7: pre-cache: - yum install -y fakeroot - yum upgrade -y binutils-2.27-44.base.el7.x86_64 - - wget https://golang.org/dl/go1.16.3.linux-amd64.tar.gz -P /tmp/ - - tar -C /usr/local -xzf /tmp/go1.16.3.linux-amd64.tar.gz + - wget https://go.dev/dl/go1.17.5.linux-amd64.tar.gz -P /tmp/ + - tar -C /usr/local -xzf /tmp/go1.17.5.linux-amd64.tar.gz post-cache: - export PATH=$PATH:/usr/local/go/bin - export GOOS=linux - export GOARCH=amd64 - - make publish-rpm - build-rpm: - build_dir: *build_dir - builddeps: *el7_builddeps - pre-cache: - - yum install -y fakeroot - - yum upgrade -y binutils-2.27-44.base.el7.x86_64 - - wget https://golang.org/dl/go1.16.3.linux-amd64.tar.gz -P /tmp/ - - tar -C /usr/local -xzf /tmp/go1.16.3.linux-amd64.tar.gz - post-cache: - - export PATH=$PATH:/usr/local/go/bin - - export GOOS=linux - - export GOARCH=amd64 - - make cloudflared-rpm + - make publish-rpm \ No newline at end of file diff --git a/check-fips.sh b/check-fips.sh new file mode 100755 index 00000000..98c05af1 --- /dev/null +++ b/check-fips.sh @@ -0,0 +1,15 @@ +# Pass the path to the executable to check for FIPS compliance +exe=$1 + +if [ "$(go tool nm "${exe}" | grep -c '_Cfunc__goboringcrypto_')" -eq 0 ]; then + # Asserts that executable is using FIPS-compliant boringcrypto + echo "${exe}: missing goboring symbols" >&2 + exit 1 +fi +if [ "$(go tool nm "${exe}" | grep -c 'crypto/internal/boring/sig.FIPSOnly')" -eq 0 ]; then + # Asserts that executable is using FIPS-only schemes + echo "${exe}: missing fipsonly symbols" >&2 + exit 1 +fi + +echo "${exe} is FIPS-compliant" diff --git a/cmd/cloudflared/access/cmd.go b/cmd/cloudflared/access/cmd.go index 25502666..50ccb663 100644 --- a/cmd/cloudflared/access/cmd.go +++ b/cmd/cloudflared/access/cmd.go @@ -240,7 +240,7 @@ func login(c *cli.Context) error { return nil } -// ensureURLScheme prepends a URL with https:// if it doesnt have a scheme. http:// URLs will not be converted. +// ensureURLScheme prepends a URL with https:// if it doesn't have a scheme. http:// URLs will not be converted. func ensureURLScheme(url string) string { url = strings.Replace(strings.ToLower(url), "http://", "https://", 1) if !strings.HasPrefix(url, "https://") { diff --git a/cmd/cloudflared/buildinfo/build_info.go b/cmd/cloudflared/cliutil/build_info.go similarity index 59% rename from cmd/cloudflared/buildinfo/build_info.go rename to cmd/cloudflared/cliutil/build_info.go index 069d3fd8..4d73701e 100644 --- a/cmd/cloudflared/buildinfo/build_info.go +++ b/cmd/cloudflared/cliutil/build_info.go @@ -1,4 +1,4 @@ -package buildinfo +package cliutil import ( "fmt" @@ -11,23 +11,39 @@ type BuildInfo struct { GoOS string `json:"go_os"` GoVersion string `json:"go_version"` GoArch string `json:"go_arch"` + BuildType string `json:"build_type"` CloudflaredVersion string `json:"cloudflared_version"` } -func GetBuildInfo(cloudflaredVersion string) *BuildInfo { +func GetBuildInfo(buildType, version string) *BuildInfo { return &BuildInfo{ GoOS: runtime.GOOS, GoVersion: runtime.Version(), GoArch: runtime.GOARCH, - CloudflaredVersion: cloudflaredVersion, + BuildType: buildType, + CloudflaredVersion: version, } } func (bi *BuildInfo) Log(log *zerolog.Logger) { log.Info().Msgf("Version %s", bi.CloudflaredVersion) + if bi.BuildType != "" { + log.Info().Msgf("Built%s", bi.GetBuildTypeMsg()) + } log.Info().Msgf("GOOS: %s, GOVersion: %s, GoArch: %s", bi.GoOS, bi.GoVersion, bi.GoArch) } func (bi *BuildInfo) OSArch() string { return fmt.Sprintf("%s_%s", bi.GoOS, bi.GoArch) } + +func (bi *BuildInfo) Version() string { + return bi.CloudflaredVersion +} + +func (bi *BuildInfo) GetBuildTypeMsg() string { + if bi.BuildType == "" { + return "" + } + return fmt.Sprintf(" with %s", bi.BuildType) +} diff --git a/cmd/cloudflared/cliutil/deprecated.go b/cmd/cloudflared/cliutil/deprecated.go index c151378e..c49d47c8 100644 --- a/cmd/cloudflared/cliutil/deprecated.go +++ b/cmd/cloudflared/cliutil/deprecated.go @@ -11,7 +11,7 @@ func RemovedCommand(name string) *cli.Command { Name: name, Action: func(context *cli.Context) error { return cli.Exit( - fmt.Sprintf("%s command is no longer supported by cloudflared. Consult Argo Tunnel documentation for possible alternative solutions.", name), + fmt.Sprintf("%s command is no longer supported by cloudflared. Consult Cloudflare Tunnel documentation for possible alternative solutions.", name), -1, ) }, diff --git a/cmd/cloudflared/generic_service.go b/cmd/cloudflared/generic_service.go index 1b237593..25872bbf 100644 --- a/cmd/cloudflared/generic_service.go +++ b/cmd/cloudflared/generic_service.go @@ -1,3 +1,4 @@ +//go:build !windows && !darwin && !linux // +build !windows,!darwin,!linux package main diff --git a/cmd/cloudflared/linux_service.go b/cmd/cloudflared/linux_service.go index d70c42ba..85d195fa 100644 --- a/cmd/cloudflared/linux_service.go +++ b/cmd/cloudflared/linux_service.go @@ -1,3 +1,4 @@ +//go:build linux // +build linux package main @@ -19,11 +20,11 @@ import ( func runApp(app *cli.App, graceShutdownC chan struct{}) { app.Commands = append(app.Commands, &cli.Command{ Name: "service", - Usage: "Manages the Argo Tunnel system service", + Usage: "Manages the Cloudflare Tunnel system service", Subcommands: []*cli.Command{ { Name: "install", - Usage: "Install Argo Tunnel as a system service", + Usage: "Install Cloudflare Tunnel as a system service", Action: cliutil.ConfiguredAction(installLinuxService), Flags: []cli.Flag{ &cli.BoolFlag{ @@ -34,7 +35,7 @@ func runApp(app *cli.App, graceShutdownC chan struct{}) { }, { Name: "uninstall", - Usage: "Uninstall the Argo Tunnel service", + Usage: "Uninstall the Cloudflare Tunnel service", Action: cliutil.ConfiguredAction(uninstallLinuxService), }, }, @@ -55,7 +56,7 @@ var systemdTemplates = []ServiceTemplate{ { Path: "/etc/systemd/system/cloudflared.service", Content: `[Unit] -Description=Argo Tunnel +Description=Cloudflare Tunnel After=network.target [Service] @@ -72,7 +73,7 @@ WantedBy=multi-user.target { Path: "/etc/systemd/system/cloudflared-update.service", Content: `[Unit] -Description=Update Argo Tunnel +Description=Update Cloudflare Tunnel After=network.target [Service] @@ -82,7 +83,7 @@ ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 11 ]; then s { Path: "/etc/systemd/system/cloudflared-update.timer", Content: `[Unit] -Description=Update Argo Tunnel +Description=Update Cloudflare Tunnel [Timer] OnCalendar=daily @@ -99,7 +100,7 @@ var sysvTemplate = ServiceTemplate{ Content: `#!/bin/sh # For RedHat and cousins: # chkconfig: 2345 99 01 -# description: Argo Tunnel agent +# description: Cloudflare Tunnel agent # processname: {{.Path}} ### BEGIN INIT INFO # Provides: {{.Path}} @@ -107,8 +108,8 @@ var sysvTemplate = ServiceTemplate{ # Required-Stop: # Default-Start: 2 3 4 5 # Default-Stop: 0 1 6 -# Short-Description: Argo Tunnel -# Description: Argo Tunnel agent +# Short-Description: Cloudflare Tunnel +# Description: Cloudflare Tunnel agent ### END INIT INFO name=$(basename $(readlink -f $0)) cmd="{{.Path}} --config /etc/cloudflared/config.yml --pidfile /var/run/$name.pid --autoupdate-freq 24h0m0s{{ range .ExtraArgs }} {{ . }}{{ end }}" diff --git a/cmd/cloudflared/macos_service.go b/cmd/cloudflared/macos_service.go index f28396c6..e987df87 100644 --- a/cmd/cloudflared/macos_service.go +++ b/cmd/cloudflared/macos_service.go @@ -1,3 +1,4 @@ +//go:build darwin // +build darwin package main @@ -20,16 +21,16 @@ const ( func runApp(app *cli.App, graceShutdownC chan struct{}) { app.Commands = append(app.Commands, &cli.Command{ Name: "service", - Usage: "Manages the Argo Tunnel launch agent", + Usage: "Manages the Cloudflare Tunnel launch agent", Subcommands: []*cli.Command{ { Name: "install", - Usage: "Install Argo Tunnel as an user launch agent", + Usage: "Install Cloudflare Tunnel as an user launch agent", Action: cliutil.ConfiguredAction(installLaunchd), }, { Name: "uninstall", - Usage: "Uninstall the Argo Tunnel launch agent", + Usage: "Uninstall the Cloudflare Tunnel launch agent", Action: cliutil.ConfiguredAction(uninstallLaunchd), }, }, @@ -110,13 +111,13 @@ func installLaunchd(c *cli.Context) error { log := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog) if isRootUser() { - log.Info().Msg("Installing Argo Tunnel client as a system launch daemon. " + - "Argo Tunnel client will run at boot") + log.Info().Msg("Installing Cloudflare Tunnel client as a system launch daemon. " + + "Cloudflare Tunnel client will run at boot") } else { - log.Info().Msg("Installing Argo Tunnel client as an user launch agent. " + - "Note that Argo Tunnel client will only run when the user is logged in. " + - "If you want to run Argo Tunnel client at boot, install with root permission. " + - "For more information, visit https://developers.cloudflare.com/argo-tunnel/reference/service/") + log.Info().Msg("Installing Cloudflare Tunnel client as an user launch agent. " + + "Note that Cloudflare Tunnel client will only run when the user is logged in. " + + "If you want to run Cloudflare Tunnel client at boot, install with root permission. " + + "For more information, visit https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/run-as-service") } etPath, err := os.Executable() if err != nil { @@ -159,9 +160,9 @@ func uninstallLaunchd(c *cli.Context) error { log := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog) if isRootUser() { - log.Info().Msg("Uninstalling Argo Tunnel as a system launch daemon") + log.Info().Msg("Uninstalling Cloudflare Tunnel as a system launch daemon") } else { - log.Info().Msg("Uninstalling Argo Tunnel as an user launch agent") + log.Info().Msg("Uninstalling Cloudflare Tunnel as an user launch agent") } installPath, err := installPath() if err != nil { diff --git a/cmd/cloudflared/main.go b/cmd/cloudflared/main.go index d5b26955..8732ad34 100644 --- a/cmd/cloudflared/main.go +++ b/cmd/cloudflared/main.go @@ -31,6 +31,7 @@ const ( var ( Version = "DEV" BuildTime = "unknown" + BuildType = "" // Mostly network errors that we don't want reported back to Sentry, this is done by substring match. ignoredErrors = []string{ "connection reset by peer", @@ -46,9 +47,10 @@ var ( func main() { rand.Seed(time.Now().UnixNano()) - metrics.RegisterBuildInfo(BuildTime, Version) + metrics.RegisterBuildInfo(BuildType, BuildTime, Version) raven.SetRelease(Version) maxprocs.Set() + bInfo := cliutil.GetBuildInfo(BuildType, Version) // Graceful shutdown channel used by the app. When closed, app must terminate gracefully. // Windows service manager closes this channel when it receives stop command. @@ -67,21 +69,21 @@ func main() { app.Copyright = fmt.Sprintf( `(c) %d Cloudflare Inc. Your installation of cloudflared software constitutes a symbol of your signature indicating that you accept - the terms of the Cloudflare License (https://developers.cloudflare.com/argo-tunnel/license/), + the terms of the Cloudflare License (https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/license), Terms (https://www.cloudflare.com/terms/) and Privacy Policy (https://www.cloudflare.com/privacypolicy/).`, time.Now().Year(), ) - app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime) + app.Version = fmt.Sprintf("%s (built %s%s)", Version, BuildTime, bInfo.GetBuildTypeMsg()) app.Description = `cloudflared connects your machine or user identity to Cloudflare's global network. You can use it to authenticate a session to reach an API behind Access, route web traffic to this machine, and configure access control. - See https://developers.cloudflare.com/argo-tunnel/ for more in-depth documentation.` + See https://developers.cloudflare.com/cloudflare-one/connections/connect-apps for more in-depth documentation.` app.Flags = flags() app.Action = action(graceShutdownC) app.Commands = commands(cli.ShowVersion) - tunnel.Init(Version, graceShutdownC) // we need this to support the tunnel sub command... + tunnel.Init(bInfo, graceShutdownC) // we need this to support the tunnel sub command... access.Init(graceShutdownC) updater.Init(Version) runApp(app, graceShutdownC) diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 5afd801f..1f9df38b 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -21,7 +21,7 @@ import ( "github.com/urfave/cli/v2" "github.com/urfave/cli/v2/altsrc" - "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" + "github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/proxydns" "github.com/cloudflare/cloudflared/cmd/cloudflared/ui" @@ -35,7 +35,6 @@ import ( "github.com/cloudflare/cloudflared/signal" "github.com/cloudflare/cloudflared/tlsconfig" "github.com/cloudflare/cloudflared/tunneldns" - "github.com/cloudflare/cloudflared/tunnelstore" ) const ( @@ -86,7 +85,7 @@ const ( var ( graceShutdownC chan struct{} - version string + buildInfo *cliutil.BuildInfo routeFailMsg = fmt.Sprintf("failed to provision routing, please create it manually via Cloudflare dashboard or UI; "+ "most likely you already have a conflicting record there. You can also rerun this command with --%s to overwrite "+ @@ -102,6 +101,8 @@ func Commands() []*cli.Command { buildLoginSubcommand(false), buildCreateCommand(), buildRouteCommand(), + // TODO TUN-5477 this should not be hidden + buildVirtualNetworkSubcommand(true), buildRunCommand(), buildListCommand(), buildInfoCommand(), @@ -126,14 +127,14 @@ func buildTunnelCommand(subcommands []*cli.Command) *cli.Command { Name: "tunnel", Action: cliutil.ConfiguredAction(TunnelCommand), Category: "Tunnel", - Usage: "Make a locally-running web service accessible over the internet using Argo Tunnel.", + Usage: "Make a locally-running web service accessible over the internet using Cloudflare Tunnel.", ArgsUsage: " ", - Description: `Argo Tunnel asks you to specify a hostname on a Cloudflare-powered + Description: `Cloudflare Tunnel asks you to specify a hostname on a Cloudflare-powered domain you control and a local address. Traffic from that hostname is routed (optionally via a Cloudflare Load Balancer) to this machine and appears on the specified port where it can be served. - This feature requires your Cloudflare account be subscribed to the Argo Smart Routing feature. + This feature requires your Cloudflare account be subscribed to the Cloudflare Smart Routing feature. To use, begin by calling login to download a certificate: @@ -173,15 +174,16 @@ func TunnelCommand(c *cli.Context) error { return runClassicTunnel(sc) } -func Init(ver string, gracefulShutdown chan struct{}) { - version, graceShutdownC = ver, gracefulShutdown +func Init(info *cliutil.BuildInfo, gracefulShutdown chan struct{}) { + buildInfo, graceShutdownC = info, gracefulShutdown } // runAdhocNamedTunnel create, route and run a named tunnel in one command func runAdhocNamedTunnel(sc *subcommandContext, name, credentialsOutputPath string) error { tunnel, ok, err := sc.tunnelActive(name) if err != nil || !ok { - tunnel, err = sc.create(name, credentialsOutputPath) + // pass empty string as secret to generate one + tunnel, err = sc.create(name, credentialsOutputPath, "") if err != nil { return errors.Wrap(err, "failed to create tunnel") } @@ -206,22 +208,22 @@ func runAdhocNamedTunnel(sc *subcommandContext, name, credentialsOutputPath stri // runClassicTunnel creates a "classic" non-named tunnel func runClassicTunnel(sc *subcommandContext) error { - return StartServer(sc.c, version, nil, sc.log, sc.isUIEnabled) + return StartServer(sc.c, buildInfo, nil, sc.log, sc.isUIEnabled) } -func routeFromFlag(c *cli.Context) (route tunnelstore.Route, ok bool) { +func routeFromFlag(c *cli.Context) (route cfapi.HostnameRoute, ok bool) { if hostname := c.String("hostname"); hostname != "" { if lbPool := c.String("lb-pool"); lbPool != "" { - return tunnelstore.NewLBRoute(hostname, lbPool), true + return cfapi.NewLBRoute(hostname, lbPool), true } - return tunnelstore.NewDNSRoute(hostname, c.Bool(overwriteDNSFlagName)), true + return cfapi.NewDNSRoute(hostname, c.Bool(overwriteDNSFlagName)), true } return nil, false } func StartServer( c *cli.Context, - version string, + info *cliutil.BuildInfo, namedTunnel *connection.NamedTunnelConfig, log *zerolog.Logger, isUIEnabled bool, @@ -268,8 +270,7 @@ func StartServer( defer trace.Stop() } - buildInfo := buildinfo.GetBuildInfo(version) - buildInfo.Log(log) + info.Log(log) logClientOptions(c, log) // this context drives the server, when it's cancelled tunnel and all other components (origins, dns, etc...) should stop @@ -333,7 +334,7 @@ func StartServer( observer.SendURL(quickTunnelURL) } - tunnelConfig, ingressRules, err := prepareTunnelConfig(c, buildInfo, version, log, logTransport, observer, namedTunnel) + tunnelConfig, ingressRules, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel) if err != nil { log.Err(err).Msg("Couldn't start tunnel") return err @@ -374,7 +375,7 @@ func StartServer( if isUIEnabled { tunnelUI := ui.NewUIModel( - version, + info.Version(), hostname, metricsListener.Addr().String(), &ingressRules, @@ -488,7 +489,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag { credentialsFileFlag, altsrc.NewBoolFlag(&cli.BoolFlag{ Name: "is-autoupdated", - Usage: "Signal the new process that Argo Tunnel connector has been autoupdated", + Usage: "Signal the new process that Cloudflare Tunnel connector has been autoupdated", Value: false, Hidden: true, }), @@ -646,6 +647,12 @@ func tunnelFlags(shouldHide bool) []cli.Flag { Value: "https://api.trycloudflare.com", Hidden: true, }), + altsrc.NewIntFlag(&cli.IntFlag{ + Name: "max-fetch-size", + Usage: `The maximum number of results that cloudflared can fetch from Cloudflare API for any listing operations needed`, + EnvVars: []string{"TUNNEL_MAX_FETCH_SIZE"}, + Hidden: true, + }), selectProtocolFlag, overwriteDNSFlag, }...) diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 901d192c..0d29cc7d 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -16,7 +16,8 @@ import ( "github.com/urfave/cli/v2" "golang.org/x/crypto/ssh/terminal" - "github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo" + "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" + "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/edgediscovery" @@ -148,8 +149,7 @@ func getOriginCert(originCertPath string, log *zerolog.Logger) ([]byte, error) { func prepareTunnelConfig( c *cli.Context, - buildInfo *buildinfo.BuildInfo, - version string, + info *cliutil.BuildInfo, log, logTransport *zerolog.Logger, observer *connection.Observer, namedTunnel *connection.NamedTunnelConfig, @@ -193,8 +193,8 @@ func prepareTunnelConfig( namedTunnel.Client = tunnelpogs.ClientInfo{ ClientID: clientUUID[:], Features: dedup(features), - Version: version, - Arch: buildInfo.OSArch(), + Version: info.Version(), + Arch: info.OSArch(), } ingressRules, err = ingress.ParseIngress(cfg) if err != nil && err != ingress.ErrNoIngressRules { @@ -238,7 +238,7 @@ func prepareTunnelConfig( log.Info().Msgf("Warp-routing is enabled") } - protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.HTTP2Percentage, origin.ResolveTTL, log) + protocolSelector, err := connection.NewProtocolSelector(c.String("protocol"), warpRoutingEnabled, namedTunnel, edgediscovery.ProtocolPercentage, origin.ResolveTTL, log) if err != nil { return nil, ingress.Ingress{}, err } @@ -281,7 +281,7 @@ func prepareTunnelConfig( return &origin.TunnelConfig{ ConnectionConfig: connectionConfig, - OSArch: buildInfo.OSArch(), + OSArch: info.OSArch(), ClientID: clientID, EdgeAddrs: c.StringSlice("edge"), Region: c.String("region"), @@ -293,7 +293,7 @@ func prepareTunnelConfig( Log: log, LogTransport: logTransport, Observer: observer, - ReportedVersion: version, + ReportedVersion: info.Version(), // Note TUN-3758 , we use Int because UInt is not supported with altsrc Retries: uint(c.Int("retries")), RunFromTerminal: isRunningFromTerminal(), diff --git a/cmd/cloudflared/tunnel/configuration_test.go b/cmd/cloudflared/tunnel/configuration_test.go index c18c6c09..b4edf636 100644 --- a/cmd/cloudflared/tunnel/configuration_test.go +++ b/cmd/cloudflared/tunnel/configuration_test.go @@ -1,4 +1,6 @@ +//go:build ignore // +build ignore + // TODO: Remove the above build tag and include this test when we start compiling with Golang 1.10.0+ package tunnel diff --git a/cmd/cloudflared/tunnel/info.go b/cmd/cloudflared/tunnel/info.go index d3610f46..7c20556a 100644 --- a/cmd/cloudflared/tunnel/info.go +++ b/cmd/cloudflared/tunnel/info.go @@ -5,12 +5,12 @@ import ( "github.com/google/uuid" - "github.com/cloudflare/cloudflared/tunnelstore" + "github.com/cloudflare/cloudflared/cfapi" ) type Info struct { - ID uuid.UUID `json:"id"` - Name string `json:"name"` - CreatedAt time.Time `json:"createdAt"` - Connectors []*tunnelstore.ActiveClient `json:"conns"` + ID uuid.UUID `json:"id"` + Name string `json:"name"` + CreatedAt time.Time `json:"createdAt"` + Connectors []*cfapi.ActiveClient `json:"conns"` } diff --git a/cmd/cloudflared/tunnel/quick_tunnel.go b/cmd/cloudflared/tunnel/quick_tunnel.go index f3441d84..08b5ff78 100644 --- a/cmd/cloudflared/tunnel/quick_tunnel.go +++ b/cmd/cloudflared/tunnel/quick_tunnel.go @@ -70,9 +70,13 @@ func RunQuickTunnel(sc *subcommandContext) error { sc.log.Info().Msg(line) } + if !sc.c.IsSet("protocol") { + sc.c.Set("protocol", "quic") + } + return StartServer( sc.c, - version, + buildInfo, &connection.NamedTunnelConfig{Credentials: credentials, QuickTunnelUrl: data.Result.Hostname}, sc.log, sc.isUIEnabled, diff --git a/cmd/cloudflared/tunnel/signal_test.go b/cmd/cloudflared/tunnel/signal_test.go index 43921c6f..294ed713 100644 --- a/cmd/cloudflared/tunnel/signal_test.go +++ b/cmd/cloudflared/tunnel/signal_test.go @@ -1,3 +1,4 @@ +//go:build !windows // +build !windows package tunnel diff --git a/cmd/cloudflared/tunnel/subcommand_context.go b/cmd/cloudflared/tunnel/subcommand_context.go index 4450cf3c..cb5b15be 100644 --- a/cmd/cloudflared/tunnel/subcommand_context.go +++ b/cmd/cloudflared/tunnel/subcommand_context.go @@ -1,6 +1,7 @@ package tunnel import ( + "encoding/base64" "encoding/json" "fmt" "os" @@ -13,9 +14,9 @@ import ( "github.com/urfave/cli/v2" "github.com/cloudflare/cloudflared/certutil" + "github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/logger" - "github.com/cloudflare/cloudflared/tunnelstore" ) type errInvalidJSONCredential struct { @@ -36,7 +37,7 @@ type subcommandContext struct { fs fileSystem // These fields should be accessed using their respective Getter - tunnelstoreClient tunnelstore.Client + tunnelstoreClient cfapi.Client userCredential *userCredential } @@ -67,7 +68,7 @@ type userCredential struct { certPath string } -func (sc *subcommandContext) client() (tunnelstore.Client, error) { +func (sc *subcommandContext) client() (cfapi.Client, error) { if sc.tunnelstoreClient != nil { return sc.tunnelstoreClient, nil } @@ -75,8 +76,8 @@ func (sc *subcommandContext) client() (tunnelstore.Client, error) { if err != nil { return nil, err } - userAgent := fmt.Sprintf("cloudflared/%s", version) - client, err := tunnelstore.NewRESTClient( + userAgent := fmt.Sprintf("cloudflared/%s", buildInfo.Version()) + client, err := cfapi.NewRESTClient( sc.c.String("api-url"), credential.cert.AccountID, credential.cert.ZoneID, @@ -148,15 +149,27 @@ func (sc *subcommandContext) readTunnelCredentials(credFinder CredFinder) (conne return credentials, nil } -func (sc *subcommandContext) create(name string, credentialsFilePath string) (*tunnelstore.Tunnel, error) { +func (sc *subcommandContext) create(name string, credentialsFilePath string, secret string) (*cfapi.Tunnel, error) { client, err := sc.client() if err != nil { - return nil, errors.Wrap(err, "couldn't create client to talk to Argo Tunnel backend") + return nil, errors.Wrap(err, "couldn't create client to talk to Cloudflare Tunnel backend") } - tunnelSecret, err := generateTunnelSecret() - if err != nil { - return nil, errors.Wrap(err, "couldn't generate the secret for your new tunnel") + var tunnelSecret []byte + if secret == "" { + tunnelSecret, err = generateTunnelSecret() + if err != nil { + return nil, errors.Wrap(err, "couldn't generate the secret for your new tunnel") + } + } else { + decodedSecret, err := base64.StdEncoding.DecodeString(secret) + if err != nil { + return nil, errors.Wrap(err, "Couldn't decode tunnel secret from base64") + } + tunnelSecret = []byte(decodedSecret) + if len(tunnelSecret) < 32 { + return nil, errors.New("Decoded tunnel secret must be at least 32 bytes long") + } } tunnel, err := client.CreateTunnel(name, tunnelSecret) @@ -211,7 +224,7 @@ func (sc *subcommandContext) create(name string, credentialsFilePath string) (*t return tunnel, nil } -func (sc *subcommandContext) list(filter *tunnelstore.Filter) ([]*tunnelstore.Tunnel, error) { +func (sc *subcommandContext) list(filter *cfapi.TunnelFilter) ([]*cfapi.Tunnel, error) { client, err := sc.client() if err != nil { return nil, err @@ -230,7 +243,7 @@ func (sc *subcommandContext) delete(tunnelIDs []uuid.UUID) error { for _, id := range tunnelIDs { tunnel, err := client.GetTunnel(id) if err != nil { - return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", tunnel.ID) + return errors.Wrapf(err, "Can't get tunnel information. Please check tunnel id: %s", id) } // Check if tunnel DeletedAt field has already been set @@ -238,7 +251,7 @@ func (sc *subcommandContext) delete(tunnelIDs []uuid.UUID) error { return fmt.Errorf("Tunnel %s has already been deleted", tunnel.ID) } if forceFlagSet { - if err := client.CleanupConnections(tunnel.ID, tunnelstore.NewCleanupParams()); err != nil { + if err := client.CleanupConnections(tunnel.ID, cfapi.NewCleanupParams()); err != nil { return errors.Wrapf(err, "Error cleaning up connections for tunnel %s", tunnel.ID) } } @@ -290,7 +303,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error { return StartServer( sc.c, - version, + buildInfo, &connection.NamedTunnelConfig{Credentials: credentials}, sc.log, sc.isUIEnabled, @@ -298,7 +311,7 @@ func (sc *subcommandContext) run(tunnelID uuid.UUID) error { } func (sc *subcommandContext) cleanupConnections(tunnelIDs []uuid.UUID) error { - params := tunnelstore.NewCleanupParams() + params := cfapi.NewCleanupParams() extraLog := "" if connector := sc.c.String("connector-id"); connector != "" { connectorID, err := uuid.Parse(connector) @@ -322,7 +335,7 @@ func (sc *subcommandContext) cleanupConnections(tunnelIDs []uuid.UUID) error { return nil } -func (sc *subcommandContext) route(tunnelID uuid.UUID, r tunnelstore.Route) (tunnelstore.RouteResult, error) { +func (sc *subcommandContext) route(tunnelID uuid.UUID, r cfapi.HostnameRoute) (cfapi.HostnameRouteResult, error) { client, err := sc.client() if err != nil { return nil, err @@ -332,8 +345,8 @@ func (sc *subcommandContext) route(tunnelID uuid.UUID, r tunnelstore.Route) (tun } // Query Tunnelstore to find the active tunnel with the given name. -func (sc *subcommandContext) tunnelActive(name string) (*tunnelstore.Tunnel, bool, error) { - filter := tunnelstore.NewFilter() +func (sc *subcommandContext) tunnelActive(name string) (*cfapi.Tunnel, bool, error) { + filter := cfapi.NewTunnelFilter() filter.NoDeleted() filter.ByName(name) tunnels, err := sc.list(filter) @@ -373,52 +386,42 @@ func (sc *subcommandContext) findID(input string) (uuid.UUID, error) { } // findIDs is just like mapping `findID` over a slice, but it only uses -// one Tunnelstore API call. +// one Tunnelstore API call per non-UUID input provided. func (sc *subcommandContext) findIDs(inputs []string) ([]uuid.UUID, error) { + uuids, names := splitUuids(inputs) - // Shortcut without Tunnelstore call if we find that all inputs are already UUIDs. - uuids, err := convertNamesToUuids(inputs, make(map[string]uuid.UUID)) - if err == nil { - return uuids, nil + for _, name := range names { + filter := cfapi.NewTunnelFilter() + filter.NoDeleted() + filter.ByName(name) + + tunnels, err := sc.list(filter) + if err != nil { + return nil, err + } + + if len(tunnels) != 1 { + return nil, fmt.Errorf("there should only be 1 non-deleted Tunnel named %s", name) + } + + uuids = append(uuids, tunnels[0].ID) } - // First, look up all tunnels the user has - filter := tunnelstore.NewFilter() - filter.NoDeleted() - tunnels, err := sc.list(filter) - if err != nil { - return nil, err - } - // Do the pure list-processing in its own function, so that it can be - // unit tested easily. - return findIDs(tunnels, inputs) + return uuids, nil } -func findIDs(tunnels []*tunnelstore.Tunnel, inputs []string) ([]uuid.UUID, error) { - // Put them into a dictionary for faster lookups - nameToID := make(map[string]uuid.UUID, len(tunnels)) - for _, tunnel := range tunnels { - nameToID[tunnel.Name] = tunnel.ID - } +func splitUuids(inputs []string) ([]uuid.UUID, []string) { + uuids := make([]uuid.UUID, 0) + names := make([]string, 0) - return convertNamesToUuids(inputs, nameToID) -} - -func convertNamesToUuids(inputs []string, nameToID map[string]uuid.UUID) ([]uuid.UUID, error) { - tunnelIDs := make([]uuid.UUID, len(inputs)) - var badInputs []string - for i, input := range inputs { - if id, err := uuid.Parse(input); err == nil { - tunnelIDs[i] = id - } else if id, ok := nameToID[input]; ok { - tunnelIDs[i] = id + for _, input := range inputs { + id, err := uuid.Parse(input) + if err != nil { + names = append(names, input) } else { - badInputs = append(badInputs, input) + uuids = append(uuids, id) } } - if len(badInputs) > 0 { - msg := "Please specify either the ID or name of a tunnel. The following inputs were neither: %s" - return nil, fmt.Errorf(msg, strings.Join(badInputs, ", ")) - } - return tunnelIDs, nil + + return uuids, names } diff --git a/cmd/cloudflared/tunnel/subcommand_context_teamnet.go b/cmd/cloudflared/tunnel/subcommand_context_teamnet.go index bde2ded2..4605172c 100644 --- a/cmd/cloudflared/tunnel/subcommand_context_teamnet.go +++ b/cmd/cloudflared/tunnel/subcommand_context_teamnet.go @@ -1,16 +1,14 @@ package tunnel import ( - "net" - "github.com/pkg/errors" - "github.com/cloudflare/cloudflared/teamnet" + "github.com/cloudflare/cloudflared/cfapi" ) const noClientMsg = "error while creating backend client" -func (sc *subcommandContext) listRoutes(filter *teamnet.Filter) ([]*teamnet.DetailedRoute, error) { +func (sc *subcommandContext) listRoutes(filter *cfapi.IpRouteFilter) ([]*cfapi.DetailedRoute, error) { client, err := sc.client() if err != nil { return nil, errors.Wrap(err, noClientMsg) @@ -18,26 +16,26 @@ func (sc *subcommandContext) listRoutes(filter *teamnet.Filter) ([]*teamnet.Deta return client.ListRoutes(filter) } -func (sc *subcommandContext) addRoute(newRoute teamnet.NewRoute) (teamnet.Route, error) { +func (sc *subcommandContext) addRoute(newRoute cfapi.NewRoute) (cfapi.Route, error) { client, err := sc.client() if err != nil { - return teamnet.Route{}, errors.Wrap(err, noClientMsg) + return cfapi.Route{}, errors.Wrap(err, noClientMsg) } return client.AddRoute(newRoute) } -func (sc *subcommandContext) deleteRoute(network net.IPNet) error { +func (sc *subcommandContext) deleteRoute(params cfapi.DeleteRouteParams) error { client, err := sc.client() if err != nil { return errors.Wrap(err, noClientMsg) } - return client.DeleteRoute(network) + return client.DeleteRoute(params) } -func (sc *subcommandContext) getRouteByIP(ip net.IP) (teamnet.DetailedRoute, error) { +func (sc *subcommandContext) getRouteByIP(params cfapi.GetRouteByIpParams) (cfapi.DetailedRoute, error) { client, err := sc.client() if err != nil { - return teamnet.DetailedRoute{}, errors.Wrap(err, noClientMsg) + return cfapi.DetailedRoute{}, errors.Wrap(err, noClientMsg) } - return client.GetByIP(ip) + return client.GetByIP(params) } diff --git a/cmd/cloudflared/tunnel/subcommand_context_test.go b/cmd/cloudflared/tunnel/subcommand_context_test.go index 2155b3dc..61a1e68b 100644 --- a/cmd/cloudflared/tunnel/subcommand_context_test.go +++ b/cmd/cloudflared/tunnel/subcommand_context_test.go @@ -13,83 +13,10 @@ import ( "github.com/rs/zerolog" "github.com/urfave/cli/v2" + "github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/connection" - "github.com/cloudflare/cloudflared/tunnelstore" ) -func Test_findIDs(t *testing.T) { - type args struct { - tunnels []*tunnelstore.Tunnel - inputs []string - } - tests := []struct { - name string - args args - want []uuid.UUID - wantErr bool - }{ - { - name: "input not found", - args: args{ - inputs: []string{"asdf"}, - }, - wantErr: true, - }, - { - name: "only UUID", - args: args{ - inputs: []string{"a8398a0b-876d-48ed-b609-3fcfd67a4950"}, - }, - want: []uuid.UUID{uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950")}, - }, - { - name: "only name", - args: args{ - tunnels: []*tunnelstore.Tunnel{ - { - ID: uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"), - Name: "tunnel1", - }, - }, - inputs: []string{"tunnel1"}, - }, - want: []uuid.UUID{uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950")}, - }, - { - name: "both UUID and name", - args: args{ - tunnels: []*tunnelstore.Tunnel{ - { - ID: uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"), - Name: "tunnel1", - }, - { - ID: uuid.MustParse("bf028b68-744f-466e-97f8-c46161d80aa5"), - Name: "tunnel2", - }, - }, - inputs: []string{"tunnel1", "bf028b68-744f-466e-97f8-c46161d80aa5"}, - }, - want: []uuid.UUID{ - uuid.MustParse("a8398a0b-876d-48ed-b609-3fcfd67a4950"), - uuid.MustParse("bf028b68-744f-466e-97f8-c46161d80aa5"), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := findIDs(tt.args.tunnels, tt.args.inputs) - if (err != nil) != tt.wantErr { - t.Errorf("findIDs() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("findIDs() = %v, want %v", got, tt.want) - } - }) - } -} - type mockFileSystem struct { rf func(string) ([]byte, error) vfp func(string) bool @@ -109,7 +36,7 @@ func Test_subcommandContext_findCredentials(t *testing.T) { log *zerolog.Logger isUIEnabled bool fs fileSystem - tunnelstoreClient tunnelstore.Client + tunnelstoreClient cfapi.Client userCredential *userCredential } type args struct { @@ -260,13 +187,13 @@ func Test_subcommandContext_findCredentials(t *testing.T) { } type deleteMockTunnelStore struct { - tunnelstore.Client + cfapi.Client mockTunnels map[uuid.UUID]mockTunnelBehaviour deletedTunnelIDs []uuid.UUID } type mockTunnelBehaviour struct { - tunnel tunnelstore.Tunnel + tunnel cfapi.Tunnel deleteErr error cleanupErr error } @@ -282,7 +209,7 @@ func newDeleteMockTunnelStore(tunnels ...mockTunnelBehaviour) *deleteMockTunnelS } } -func (d *deleteMockTunnelStore) GetTunnel(tunnelID uuid.UUID) (*tunnelstore.Tunnel, error) { +func (d *deleteMockTunnelStore) GetTunnel(tunnelID uuid.UUID) (*cfapi.Tunnel, error) { tunnel, ok := d.mockTunnels[tunnelID] if !ok { return nil, fmt.Errorf("Couldn't find tunnel: %v", tunnelID) @@ -306,7 +233,7 @@ func (d *deleteMockTunnelStore) DeleteTunnel(tunnelID uuid.UUID) error { return nil } -func (d *deleteMockTunnelStore) CleanupConnections(tunnelID uuid.UUID, _ *tunnelstore.CleanupParams) error { +func (d *deleteMockTunnelStore) CleanupConnections(tunnelID uuid.UUID, _ *cfapi.CleanupParams) error { tunnel, ok := d.mockTunnels[tunnelID] if !ok { return fmt.Errorf("Couldn't find tunnel: %v", tunnelID) @@ -357,10 +284,10 @@ func Test_subcommandContext_Delete(t *testing.T) { }(), tunnelstoreClient: newDeleteMockTunnelStore( mockTunnelBehaviour{ - tunnel: tunnelstore.Tunnel{ID: tunnelID1}, + tunnel: cfapi.Tunnel{ID: tunnelID1}, }, mockTunnelBehaviour{ - tunnel: tunnelstore.Tunnel{ID: tunnelID2}, + tunnel: cfapi.Tunnel{ID: tunnelID2}, }, ), }, diff --git a/cmd/cloudflared/tunnel/subcommand_context_vnets.go b/cmd/cloudflared/tunnel/subcommand_context_vnets.go new file mode 100644 index 00000000..14e055fe --- /dev/null +++ b/cmd/cloudflared/tunnel/subcommand_context_vnets.go @@ -0,0 +1,40 @@ +package tunnel + +import ( + "github.com/google/uuid" + "github.com/pkg/errors" + + "github.com/cloudflare/cloudflared/cfapi" +) + +func (sc *subcommandContext) addVirtualNetwork(newVnet cfapi.NewVirtualNetwork) (cfapi.VirtualNetwork, error) { + client, err := sc.client() + if err != nil { + return cfapi.VirtualNetwork{}, errors.Wrap(err, noClientMsg) + } + return client.CreateVirtualNetwork(newVnet) +} + +func (sc *subcommandContext) listVirtualNetworks(filter *cfapi.VnetFilter) ([]*cfapi.VirtualNetwork, error) { + client, err := sc.client() + if err != nil { + return nil, errors.Wrap(err, noClientMsg) + } + return client.ListVirtualNetworks(filter) +} + +func (sc *subcommandContext) deleteVirtualNetwork(vnetId uuid.UUID) error { + client, err := sc.client() + if err != nil { + return errors.Wrap(err, noClientMsg) + } + return client.DeleteVirtualNetwork(vnetId) +} + +func (sc *subcommandContext) updateVirtualNetwork(vnetId uuid.UUID, updates cfapi.UpdateVirtualNetwork) error { + client, err := sc.client() + if err != nil { + return errors.Wrap(err, noClientMsg) + } + return client.UpdateVirtualNetwork(vnetId, updates) +} diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 6cfe3af2..22dca2a4 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -21,11 +21,11 @@ import ( "golang.org/x/net/idna" yaml "gopkg.in/yaml.v2" + "github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/updater" "github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/connection" - "github.com/cloudflare/cloudflared/tunnelstore" ) const ( @@ -64,8 +64,8 @@ var ( Name: "when", Aliases: []string{"w"}, Usage: "List tunnels that are active at the given `TIME` in RFC3339 format", - Layout: tunnelstore.TimeLayout, - DefaultText: fmt.Sprintf("current time, %s", time.Now().Format(tunnelstore.TimeLayout)), + Layout: cfapi.TimeLayout, + DefaultText: fmt.Sprintf("current time, %s", time.Now().Format(cfapi.TimeLayout)), } listIDFlag = &cli.StringFlag{ Name: "id", @@ -156,6 +156,12 @@ var ( Usage: `Overwrites existing DNS records with this hostname`, EnvVars: []string{"TUNNEL_FORCE_PROVISIONING_DNS"}, } + createSecretFlag = &cli.StringFlag{ + Name: "secret", + Aliases: []string{"s"}, + Usage: "Base64 encoded secret to set for the tunnel. The decoded secret must be at least 32 bytes long. If not specified, a random 32-byte secret will be generated.", + EnvVars: []string{"TUNNEL_CREATE_SECRET"}, + } ) func buildCreateCommand() *cli.Command { @@ -170,7 +176,7 @@ func buildCreateCommand() *cli.Command { For example, to create a tunnel named 'my-tunnel' run: $ cloudflared tunnel create my-tunnel`, - Flags: []cli.Flag{outputFormatFlag, credentialsFileFlagCLIOnly}, + Flags: []cli.Flag{outputFormatFlag, credentialsFileFlagCLIOnly, createSecretFlag}, CustomHelpTemplate: commandHelpTemplate(), } } @@ -196,7 +202,7 @@ func createCommand(c *cli.Context) error { warningChecker := updater.StartWarningCheck(c) defer warningChecker.LogWarningIfAny(sc.log) - _, err = sc.create(name, c.String(CredFileFlag)) + _, err = sc.create(name, c.String(CredFileFlag), c.String(createSecretFlag.Name)) return errors.Wrap(err, "failed to create tunnel") } @@ -254,7 +260,7 @@ func listCommand(c *cli.Context) error { warningChecker := updater.StartWarningCheck(c) defer warningChecker.LogWarningIfAny(sc.log) - filter := tunnelstore.NewFilter() + filter := cfapi.NewTunnelFilter() if !c.Bool("show-deleted") { filter.NoDeleted() } @@ -277,6 +283,9 @@ func listCommand(c *cli.Context) error { } filter.ByTunnelID(tunnelID) } + if maxFetch := c.Int("max-fetch-size"); maxFetch > 0 { + filter.MaxFetchSize(uint(maxFetch)) + } tunnels, err := sc.list(filter) if err != nil { @@ -320,13 +329,13 @@ func listCommand(c *cli.Context) error { if len(tunnels) > 0 { formatAndPrintTunnelList(tunnels, c.Bool("show-recently-disconnected")) } else { - fmt.Println("You have no tunnels, use 'cloudflared tunnel create' to define a new tunnel") + fmt.Println("No tunnels were found for the given filter flags. You can use 'cloudflared tunnel create' to create a tunnel.") } return nil } -func formatAndPrintTunnelList(tunnels []*tunnelstore.Tunnel, showRecentlyDisconnected bool) { +func formatAndPrintTunnelList(tunnels []*cfapi.Tunnel, showRecentlyDisconnected bool) { writer := tabWriter() defer writer.Flush() @@ -348,7 +357,7 @@ func formatAndPrintTunnelList(tunnels []*tunnelstore.Tunnel, showRecentlyDisconn } } -func fmtConnections(connections []tunnelstore.Connection, showRecentlyDisconnected bool) string { +func fmtConnections(connections []cfapi.Connection, showRecentlyDisconnected bool) string { // Count connections per colo numConnsPerColo := make(map[string]uint, len(connections)) @@ -468,8 +477,8 @@ func tunnelInfo(c *cli.Context) error { return nil } -func getTunnel(sc *subcommandContext, tunnelID uuid.UUID) (*tunnelstore.Tunnel, error) { - filter := tunnelstore.NewFilter() +func getTunnel(sc *subcommandContext, tunnelID uuid.UUID) (*cfapi.Tunnel, error) { + filter := cfapi.NewTunnelFilter() filter.ByTunnelID(tunnelID) tunnels, err := sc.list(filter) if err != nil { @@ -702,7 +711,7 @@ Further information about managing Cloudflare WARP traffic to your tunnel is ava { Name: "dns", Action: cliutil.ConfiguredAction(routeDnsCommand), - Usage: "Route a hostname by creating a DNS CNAME record to a tunnel", + Usage: "HostnameRoute a hostname by creating a DNS CNAME record to a tunnel", UsageText: "cloudflared tunnel route dns [TUNNEL] [HOSTNAME]", Description: `Creates a DNS CNAME record hostname that points to the tunnel.`, Flags: []cli.Flag{overwriteDNSFlag}, @@ -719,7 +728,7 @@ Further information about managing Cloudflare WARP traffic to your tunnel is ava } } -func dnsRouteFromArg(c *cli.Context, overwriteExisting bool) (tunnelstore.Route, error) { +func dnsRouteFromArg(c *cli.Context, overwriteExisting bool) (cfapi.HostnameRoute, error) { const ( userHostnameIndex = 1 expectedNArgs = 2 @@ -733,10 +742,10 @@ func dnsRouteFromArg(c *cli.Context, overwriteExisting bool) (tunnelstore.Route, } else if !validateHostname(userHostname, true) { return nil, errors.Errorf("%s is not a valid hostname", userHostname) } - return tunnelstore.NewDNSRoute(userHostname, overwriteExisting), nil + return cfapi.NewDNSRoute(userHostname, overwriteExisting), nil } -func lbRouteFromArg(c *cli.Context) (tunnelstore.Route, error) { +func lbRouteFromArg(c *cli.Context) (cfapi.HostnameRoute, error) { const ( lbNameIndex = 1 lbPoolIndex = 2 @@ -759,7 +768,7 @@ func lbRouteFromArg(c *cli.Context) (tunnelstore.Route, error) { return nil, errors.Errorf("%s is not a valid pool name", lbPool) } - return tunnelstore.NewLBRoute(lbName, lbPool), nil + return cfapi.NewLBRoute(lbName, lbPool), nil } var nameRegex = regexp.MustCompile("^[_a-zA-Z0-9][-_.a-zA-Z0-9]*$") @@ -806,7 +815,7 @@ func routeCommand(c *cli.Context, routeType string) error { if err != nil { return err } - var route tunnelstore.Route + var route cfapi.HostnameRoute switch routeType { case "dns": route, err = dnsRouteFromArg(c, c.Bool(overwriteDNSFlagName)) diff --git a/cmd/cloudflared/tunnel/subcommands_test.go b/cmd/cloudflared/tunnel/subcommands_test.go index 67aad6ea..4ebbd922 100644 --- a/cmd/cloudflared/tunnel/subcommands_test.go +++ b/cmd/cloudflared/tunnel/subcommands_test.go @@ -8,12 +8,12 @@ import ( homedir "github.com/mitchellh/go-homedir" "github.com/stretchr/testify/assert" - "github.com/cloudflare/cloudflared/tunnelstore" + "github.com/cloudflare/cloudflared/cfapi" ) func Test_fmtConnections(t *testing.T) { type args struct { - connections []tunnelstore.Connection + connections []cfapi.Connection } tests := []struct { name string @@ -23,14 +23,14 @@ func Test_fmtConnections(t *testing.T) { { name: "empty", args: args{ - connections: []tunnelstore.Connection{}, + connections: []cfapi.Connection{}, }, want: "", }, { name: "trivial", args: args{ - connections: []tunnelstore.Connection{ + connections: []cfapi.Connection{ { ColoName: "DFW", ID: uuid.MustParse("ea550130-57fd-4463-aab1-752822231ddd"), @@ -42,7 +42,7 @@ func Test_fmtConnections(t *testing.T) { { name: "with a pending reconnect", args: args{ - connections: []tunnelstore.Connection{ + connections: []cfapi.Connection{ { ColoName: "DFW", ID: uuid.MustParse("ea550130-57fd-4463-aab1-752822231ddd"), @@ -55,7 +55,7 @@ func Test_fmtConnections(t *testing.T) { { name: "many colos", args: args{ - connections: []tunnelstore.Connection{ + connections: []cfapi.Connection{ { ColoName: "YRV", ID: uuid.MustParse("ea550130-57fd-4463-aab1-752822231ddd"), diff --git a/cmd/cloudflared/tunnel/teamnet_subcommands.go b/cmd/cloudflared/tunnel/teamnet_subcommands.go index d87c9892..fbcec009 100644 --- a/cmd/cloudflared/tunnel/teamnet_subcommands.go +++ b/cmd/cloudflared/tunnel/teamnet_subcommands.go @@ -6,35 +6,54 @@ import ( "os" "text/tabwriter" + "github.com/google/uuid" "github.com/pkg/errors" + "github.com/urfave/cli/v2" + "github.com/cloudflare/cloudflared/cfapi" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" "github.com/cloudflare/cloudflared/cmd/cloudflared/updater" - "github.com/cloudflare/cloudflared/teamnet" +) - "github.com/urfave/cli/v2" +var ( + vnetFlag = &cli.StringFlag{ + Name: "virtual-network", + Aliases: []string{"vn"}, + Usage: "The ID or name of the virtual network to which the route is associated to.", + } ) func buildRouteIPSubcommand() *cli.Command { return &cli.Command{ Name: "ip", - Usage: "Configure and query Cloudflare WARP routing to services or private networks available through this tunnel.", + Usage: "Configure and query Cloudflare WARP routing to private IP networks made available through Cloudflare Tunnels.", UsageText: "cloudflared tunnel [--config FILEPATH] route COMMAND [arguments...]", - Description: `cloudflared can provision private routes from any IP space to origins in your corporate network. -Users enrolled in your Cloudflare for Teams organization can reach those routes through the -Cloudflare WARP client. You can also build rules to determine who can reach certain routes.`, + Description: `cloudflared can provision routes for any IP space in your corporate network. Users enrolled in +your Cloudflare for Teams organization can reach those IPs through the Cloudflare WARP +client. You can then configure L7/L4 filtering on https://dash.teams.cloudflare.com to +determine who can reach certain routes. +By default IP routes all exist within a single virtual network. If you use the same IP +space(s) in different physical private networks, all meant to be reachable via IP routes, +then you have to manage the ambiguous IP routes by associating them to virtual networks. +See "cloudflared tunnel network --help" for more information.`, Subcommands: []*cli.Command{ { Name: "add", Action: cliutil.ConfiguredAction(addRouteCommand), - Usage: "Add any new network to the routing table reachable via the tunnel", - UsageText: "cloudflared tunnel [--config FILEPATH] route ip add [CIDR] [TUNNEL] [COMMENT?]", - Description: `Adds any network route space (represented as a CIDR) to your routing table. -That network space becomes reachable for requests egressing from a user's machine + Usage: "Add a new network to the routing table reachable via a Tunnel", + UsageText: "cloudflared tunnel [--config FILEPATH] route ip add [flags] [CIDR] [TUNNEL] [COMMENT?]", + Description: `Adds a network IP route space (represented as a CIDR) to your routing table. +That network IP space becomes reachable for requests egressing from a user's machine as long as it is using Cloudflare WARP client and is enrolled in the same account -that is running the tunnel chosen here. Further, those requests will be proxied to -the specified tunnel, and reach an IP in the given CIDR, as long as that IP is -reachable from the tunnel.`, +that is running the Tunnel chosen here. Further, those requests will be proxied to +the specified Tunnel, and reach an IP in the given CIDR, as long as that IP is +reachable from cloudflared. +If the CIDR exists in more than one private network, to be connected with Cloudflare +Tunnels, then you have to manage those IP routes with virtual networks (see +"cloudflared tunnel network --help)". In those cases, you then have to tell +which virtual network's routing table you want to add the route to with: +"cloudflared tunnel route ip add --virtual-network [ID/name] [CIDR] [TUNNEL]".`, + Flags: []cli.Flag{vnetFlag}, }, { Name: "show", @@ -49,17 +68,22 @@ reachable from the tunnel.`, Name: "delete", Action: cliutil.ConfiguredAction(deleteRouteCommand), Usage: "Delete a row from your organization's private routing table", - UsageText: "cloudflared tunnel [--config FILEPATH] route ip delete [CIDR]", - Description: `Deletes the row for a given CIDR from your routing table. That portion -of your network will no longer be reachable by the WARP clients.`, + UsageText: "cloudflared tunnel [--config FILEPATH] route ip delete [flags] [CIDR]", + Description: `Deletes the row for a given CIDR from your routing table. That portion of your network +will no longer be reachable by the WARP clients. Note that if you use virtual +networks, then you have to tell which virtual network whose routing table you +have a row deleted from.`, + Flags: []cli.Flag{vnetFlag}, }, { Name: "get", Action: cliutil.ConfiguredAction(getRouteByIPCommand), Usage: "Check which row of the routing table matches a given IP.", - UsageText: "cloudflared tunnel [--config FILEPATH] route ip get [IP]", - Description: `Checks which row of the routing table will be used to proxy a given IP. - This helps check and validate your config.`, + UsageText: "cloudflared tunnel [--config FILEPATH] route ip get [flags] [IP]", + Description: `Checks which row of the routing table will be used to proxy a given IP. This helps check +and validate your config. Note that if you use virtual networks, then you have +to tell which virtual network whose routing table you want to use.`, + Flags: []cli.Flag{vnetFlag}, }, }, } @@ -67,7 +91,7 @@ of your network will no longer be reachable by the WARP clients.`, func showRoutesFlags() []cli.Flag { flags := make([]cli.Flag, 0) - flags = append(flags, teamnet.FilterFlags...) + flags = append(flags, cfapi.IpRouteFilterFlags...) flags = append(flags, outputFormatFlag) return flags } @@ -78,7 +102,7 @@ func showRoutesCommand(c *cli.Context) error { return err } - filter, err := teamnet.NewFromCLI(c) + filter, err := cfapi.NewIpRouteFilterFromCLI(c) if err != nil { return errors.Wrap(err, "invalid config for routing filters") } @@ -98,7 +122,7 @@ func showRoutesCommand(c *cli.Context) error { if len(routes) > 0 { formatAndPrintRouteList(routes) } else { - fmt.Println("You have no routes, use 'cloudflared tunnel route ip add' to add a route") + fmt.Println("No routes were found for the given filter flags. You can use 'cloudflared tunnel route ip add' to add a route.") } return nil @@ -112,7 +136,9 @@ func addRouteCommand(c *cli.Context) error { if c.NArg() < 2 { return errors.New("You must supply at least 2 arguments, first the network you wish to route (in CIDR form e.g. 1.2.3.4/32) and then the tunnel ID to proxy with") } + args := c.Args() + _, network, err := net.ParseCIDR(args.Get(0)) if err != nil { return errors.Wrap(err, "Invalid network CIDR") @@ -120,19 +146,32 @@ func addRouteCommand(c *cli.Context) error { if network == nil { return errors.New("Invalid network CIDR") } + tunnelRef := args.Get(1) tunnelID, err := sc.findID(tunnelRef) if err != nil { return errors.Wrap(err, "Invalid tunnel") } + comment := "" if c.NArg() >= 3 { comment = args.Get(2) } - _, err = sc.addRoute(teamnet.NewRoute{ + + var vnetId *uuid.UUID + if c.IsSet(vnetFlag.Name) { + id, err := getVnetId(sc, c.String(vnetFlag.Name)) + if err != nil { + return err + } + vnetId = &id + } + + _, err = sc.addRoute(cfapi.NewRoute{ Comment: comment, Network: *network, TunnelID: tunnelID, + VNetID: vnetId, }) if err != nil { return errors.Wrap(err, "API error") @@ -146,9 +185,11 @@ func deleteRouteCommand(c *cli.Context) error { if err != nil { return err } + if c.NArg() != 1 { return errors.New("You must supply exactly one argument, the network whose route you want to delete (in CIDR form e.g. 1.2.3.4/32)") } + _, network, err := net.ParseCIDR(c.Args().First()) if err != nil { return errors.Wrap(err, "Invalid network CIDR") @@ -156,7 +197,20 @@ func deleteRouteCommand(c *cli.Context) error { if network == nil { return errors.New("Invalid network CIDR") } - if err := sc.deleteRoute(*network); err != nil { + + params := cfapi.DeleteRouteParams{ + Network: *network, + } + + if c.IsSet(vnetFlag.Name) { + vnetId, err := getVnetId(sc, c.String(vnetFlag.Name)) + if err != nil { + return err + } + params.VNetID = &vnetId + } + + if err := sc.deleteRoute(params); err != nil { return errors.Wrap(err, "API error") } fmt.Printf("Successfully deleted route for %s\n", network) @@ -177,19 +231,32 @@ func getRouteByIPCommand(c *cli.Context) error { if ip == nil { return fmt.Errorf("Invalid IP %s", ipInput) } - route, err := sc.getRouteByIP(ip) + + params := cfapi.GetRouteByIpParams{ + Ip: ip, + } + + if c.IsSet(vnetFlag.Name) { + vnetId, err := getVnetId(sc, c.String(vnetFlag.Name)) + if err != nil { + return err + } + params.VNetID = &vnetId + } + + route, err := sc.getRouteByIP(params) if err != nil { return errors.Wrap(err, "API error") } if route.IsZero() { fmt.Printf("No route matches the IP %s\n", ip) } else { - formatAndPrintRouteList([]*teamnet.DetailedRoute{&route}) + formatAndPrintRouteList([]*cfapi.DetailedRoute{&route}) } return nil } -func formatAndPrintRouteList(routes []*teamnet.DetailedRoute) { +func formatAndPrintRouteList(routes []*cfapi.DetailedRoute) { const ( minWidth = 0 tabWidth = 8 @@ -202,7 +269,7 @@ func formatAndPrintRouteList(routes []*teamnet.DetailedRoute) { defer writer.Flush() // Print column headers with tabbed columns - _, _ = fmt.Fprintln(writer, "NETWORK\tCOMMENT\tTUNNEL ID\tTUNNEL NAME\tCREATED\tDELETED\t") + _, _ = fmt.Fprintln(writer, "NETWORK\tVIRTUAL NET ID\tCOMMENT\tTUNNEL ID\tTUNNEL NAME\tCREATED\tDELETED\t") // Loop through routes, create formatted string for each, and print using tabwriter for _, route := range routes { diff --git a/cmd/cloudflared/tunnel/vnets_subcommands.go b/cmd/cloudflared/tunnel/vnets_subcommands.go new file mode 100644 index 00000000..e55327ce --- /dev/null +++ b/cmd/cloudflared/tunnel/vnets_subcommands.go @@ -0,0 +1,285 @@ +package tunnel + +import ( + "fmt" + "os" + "text/tabwriter" + + "github.com/google/uuid" + "github.com/pkg/errors" + "github.com/urfave/cli/v2" + + "github.com/cloudflare/cloudflared/cfapi" + "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" + "github.com/cloudflare/cloudflared/cmd/cloudflared/updater" +) + +var ( + makeDefaultFlag = &cli.BoolFlag{ + Name: "default", + Aliases: []string{"d"}, + Usage: "The virtual network becomes the default one for the account. This means that all operations that " + + "omit a virtual network will now implicitly be using this virtual network (i.e., the default one) such " + + "as new IP routes that are created. When this flag is not set, the virtual network will not become the " + + "default one in the account.", + } + newNameFlag = &cli.StringFlag{ + Name: "name", + Aliases: []string{"n"}, + Usage: "The new name for the virtual network.", + } + newCommentFlag = &cli.StringFlag{ + Name: "comment", + Aliases: []string{"c"}, + Usage: "A new comment describing the purpose of the virtual network.", + } +) + +func buildVirtualNetworkSubcommand(hidden bool) *cli.Command { + return &cli.Command{ + Name: "network", + Usage: "Configure and query virtual networks to manage private IP routes with overlapping IPs.", + UsageText: "cloudflared tunnel [--config FILEPATH] network COMMAND [arguments...]", + Description: `cloudflared allows to manage IP routes that expose origins in your private network space via their IP directly +to clients outside (e.g. using WARP client) --- those are configurable via "cloudflared tunnel route ip" commands. +By default, all those IP routes live in the same virtual network. Managing virtual networks (e.g. by creating a +new one) becomes relevant when you have different private networks that have overlapping IPs. E.g.: if you have +a private network A running Tunnel 1, and private network B running Tunnel 2, it is possible that both Tunnels +expose the same IP space (say 10.0.0.0/8); to handle that, you have to add each IP Route (one that points to +Tunnel 1 and another that points to Tunnel 2) in different Virtual Networks. That way, if your clients are on +Virtual Network X, they will see Tunnel 1 (via Route A) and not see Tunnel 2 (since its Route B is associated +to another Virtual Network Y).`, + Hidden: hidden, + Subcommands: []*cli.Command{ + { + Name: "add", + Action: cliutil.ConfiguredAction(addVirtualNetworkCommand), + Usage: "Add a new virtual network to which IP routes can be attached", + UsageText: "cloudflared tunnel [--config FILEPATH] network add [flags] NAME [\"comment\"]", + Description: `Adds a new virtual network. You can then attach IP routes to this virtual network with "cloudflared tunnel route ip" +commands. By doing so, such route(s) become segregated from route(s) in another virtual networks. Note that all +routes exist within some virtual network. If you do not specify any, then the system pre-creates a default virtual +network to which all routes belong. That is fine if you do not have overlapping IPs within different physical +private networks in your infrastructure exposed via Cloudflare Tunnel. Note: if a virtual network is added as +the new default, then the previous existing default virtual network will be automatically modified to no longer +be the current default.`, + Flags: []cli.Flag{makeDefaultFlag}, + Hidden: hidden, + }, + { + Name: "list", + Action: cliutil.ConfiguredAction(listVirtualNetworksCommand), + Usage: "Lists the virtual networks", + UsageText: "cloudflared tunnel [--config FILEPATH] network list [flags]", + Description: "Lists the virtual networks based on the given filter flags.", + Flags: listVirtualNetworksFlags(), + Hidden: hidden, + }, + { + Name: "delete", + Action: cliutil.ConfiguredAction(deleteVirtualNetworkCommand), + Usage: "Delete a virtual network", + UsageText: "cloudflared tunnel [--config FILEPATH] network delete VIRTUAL_NETWORK", + Description: `Deletes the virtual network (given its ID or name). This is only possible if that virtual network is unused. +A virtual network may be used by IP routes or by WARP devices.`, + Hidden: hidden, + }, + { + Name: "update", + Action: cliutil.ConfiguredAction(updateVirtualNetworkCommand), + Usage: "Update a virtual network", + UsageText: "cloudflared tunnel [--config FILEPATH] network update [flags] VIRTUAL_NETWORK", + Description: `Updates the virtual network (given its ID or name). If this virtual network is updated to become the new +default, then the previously existing default virtual network will also be modified to no longer be the default. +You cannot update a virtual network to not be the default anymore directly. Instead, you should create a new +default or update an existing one to become the default.`, + Flags: []cli.Flag{newNameFlag, newCommentFlag, makeDefaultFlag}, + Hidden: hidden, + }, + }, + } +} + +func listVirtualNetworksFlags() []cli.Flag { + flags := make([]cli.Flag, 0) + flags = append(flags, cfapi.VnetFilterFlags...) + flags = append(flags, outputFormatFlag) + return flags +} + +func addVirtualNetworkCommand(c *cli.Context) error { + sc, err := newSubcommandContext(c) + if err != nil { + return err + } + if c.NArg() < 1 { + return errors.New("You must supply at least 1 argument, the name of the virtual network you wish to add.") + } + + warningChecker := updater.StartWarningCheck(c) + defer warningChecker.LogWarningIfAny(sc.log) + + args := c.Args() + + name := args.Get(0) + + comment := "" + if c.NArg() >= 2 { + comment = args.Get(1) + } + + newVnet := cfapi.NewVirtualNetwork{ + Name: name, + Comment: comment, + IsDefault: c.Bool(makeDefaultFlag.Name), + } + createdVnet, err := sc.addVirtualNetwork(newVnet) + + if err != nil { + return errors.Wrap(err, "Could not add virtual network") + } + + extraMsg := "" + if createdVnet.IsDefault { + extraMsg = " (as the new default for this account) " + } + fmt.Printf( + "Successfully added virtual 'network' %s with ID: %s%s\n"+ + "You can now add IP routes attached to this virtual network. See `cloudflared tunnel route ip add -help`\n", + name, createdVnet.ID, extraMsg, + ) + return nil +} + +func listVirtualNetworksCommand(c *cli.Context) error { + sc, err := newSubcommandContext(c) + if err != nil { + return err + } + + warningChecker := updater.StartWarningCheck(c) + defer warningChecker.LogWarningIfAny(sc.log) + + filter, err := cfapi.NewFromCLI(c) + if err != nil { + return errors.Wrap(err, "invalid flags for filtering virtual networks") + } + + vnets, err := sc.listVirtualNetworks(filter) + if err != nil { + return err + } + + if outputFormat := c.String(outputFormatFlag.Name); outputFormat != "" { + return renderOutput(outputFormat, vnets) + } + + if len(vnets) > 0 { + formatAndPrintVnetsList(vnets) + } else { + fmt.Println("No virtual networks were found for the given filter flags. You can use 'cloudflared tunnel network add' to add a virtual network.") + } + + return nil +} + +func deleteVirtualNetworkCommand(c *cli.Context) error { + sc, err := newSubcommandContext(c) + if err != nil { + return err + } + if c.NArg() != 1 { + return errors.New("You must supply exactly one argument, either the ID or name of the virtual network to delete") + } + + input := c.Args().Get(0) + vnetId, err := getVnetId(sc, input) + if err != nil { + return err + } + + if err := sc.deleteVirtualNetwork(vnetId); err != nil { + return errors.Wrap(err, "API error") + } + fmt.Printf("Successfully deleted virtual network '%s'\n", input) + return nil +} + +func updateVirtualNetworkCommand(c *cli.Context) error { + sc, err := newSubcommandContext(c) + if err != nil { + return err + } + if c.NArg() != 1 { + return errors.New(" You must supply exactly one argument, either the ID or (current) name of the virtual network to update") + } + + input := c.Args().Get(0) + vnetId, err := getVnetId(sc, input) + if err != nil { + return err + } + + updates := cfapi.UpdateVirtualNetwork{} + + if c.IsSet(newNameFlag.Name) { + newName := c.String(newNameFlag.Name) + updates.Name = &newName + } + if c.IsSet(newCommentFlag.Name) { + newComment := c.String(newCommentFlag.Name) + updates.Comment = &newComment + } + if c.IsSet(makeDefaultFlag.Name) { + isDefault := c.Bool(makeDefaultFlag.Name) + updates.IsDefault = &isDefault + } + + if err := sc.updateVirtualNetwork(vnetId, updates); err != nil { + return errors.Wrap(err, "API error") + } + fmt.Printf("Successfully updated virtual network '%s'\n", input) + return nil +} + +func getVnetId(sc *subcommandContext, input string) (uuid.UUID, error) { + val, err := uuid.Parse(input) + if err == nil { + return val, nil + } + + filter := cfapi.NewVnetFilter() + filter.WithDeleted(false) + filter.ByName(input) + + vnets, err := sc.listVirtualNetworks(filter) + if err != nil { + return uuid.Nil, err + } + + if len(vnets) != 1 { + return uuid.Nil, fmt.Errorf("there should only be 1 non-deleted virtual network named %s", input) + } + + return vnets[0].ID, nil +} + +func formatAndPrintVnetsList(vnets []*cfapi.VirtualNetwork) { + const ( + minWidth = 0 + tabWidth = 8 + padding = 1 + padChar = ' ' + flags = 0 + ) + + writer := tabwriter.NewWriter(os.Stdout, minWidth, tabWidth, padding, padChar, flags) + defer writer.Flush() + + _, _ = fmt.Fprintln(writer, "ID\tNAME\tIS DEFAULT\tCOMMENT\tCREATED\tDELETED\t") + + for _, virtualNetwork := range vnets { + formattedStr := virtualNetwork.TableString() + _, _ = fmt.Fprintln(writer, formattedStr) + } +} diff --git a/cmd/cloudflared/updater/update.go b/cmd/cloudflared/updater/update.go index b93d84e5..c385cf04 100644 --- a/cmd/cloudflared/updater/update.go +++ b/cmd/cloudflared/updater/update.go @@ -19,7 +19,7 @@ import ( const ( DefaultCheckUpdateFreq = time.Hour * 24 - noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/" + noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/run-as-service" noUpdateOnWindowsMessage = "cloudflared will not automatically update on Windows systems." noUpdateManagedPackageMessage = "cloudflared will not automatically update if installed by a package manager." isManagedInstallFile = ".installedFromPackageManager" diff --git a/cmd/cloudflared/updater/workers_service_test.go b/cmd/cloudflared/updater/workers_service_test.go index 94cf8a7e..72a51c9c 100644 --- a/cmd/cloudflared/updater/workers_service_test.go +++ b/cmd/cloudflared/updater/workers_service_test.go @@ -1,3 +1,4 @@ +//go:build !windows // +build !windows package updater diff --git a/cmd/cloudflared/updater/workers_update.go b/cmd/cloudflared/updater/workers_update.go index 9bc96b1a..443f896d 100644 --- a/cmd/cloudflared/updater/workers_update.go +++ b/cmd/cloudflared/updater/workers_update.go @@ -47,7 +47,7 @@ type batchData struct { } // WorkersVersion implements the Version interface. -// It contains everything needed to preform a version upgrade +// It contains everything needed to perform a version upgrade type WorkersVersion struct { downloadURL string checksum string diff --git a/cmd/cloudflared/windows_service.go b/cmd/cloudflared/windows_service.go index cc0b0a3c..2eba4780 100644 --- a/cmd/cloudflared/windows_service.go +++ b/cmd/cloudflared/windows_service.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package main @@ -25,8 +26,8 @@ import ( const ( windowsServiceName = "Cloudflared" - windowsServiceDescription = "Argo Tunnel agent" - windowsServiceUrl = "https://developers.cloudflare.com/argo-tunnel/reference/service/" + windowsServiceDescription = "Cloudflare Tunnel agent" + windowsServiceUrl = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/run-as-service#windows" recoverActionDelay = time.Second * 20 failureCountResetPeriod = time.Hour * 24 @@ -45,16 +46,16 @@ const ( func runApp(app *cli.App, graceShutdownC chan struct{}) { app.Commands = append(app.Commands, &cli.Command{ Name: "service", - Usage: "Manages the Argo Tunnel Windows service", + Usage: "Manages the Cloudflare Tunnel Windows service", Subcommands: []*cli.Command{ { Name: "install", - Usage: "Install Argo Tunnel as a Windows service", + Usage: "Install Cloudflare Tunnel as a Windows service", Action: cliutil.ConfiguredAction(installWindowsService), }, { Name: "uninstall", - Usage: "Uninstall the Argo Tunnel service", + Usage: "Uninstall the Cloudflare Tunnel service", Action: cliutil.ConfiguredAction(uninstallWindowsService), }, }, @@ -176,7 +177,7 @@ func (s *windowsService) Execute(serviceArgs []string, r <-chan svc.ChangeReques func installWindowsService(c *cli.Context) error { zeroLogger := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog) - zeroLogger.Info().Msg("Installing Argo Tunnel Windows service") + zeroLogger.Info().Msg("Installing Cloudflare Tunnel Windows service") exepath, err := os.Executable() if err != nil { return errors.Wrap(err, "Cannot find path name that start the process") @@ -198,7 +199,7 @@ func installWindowsService(c *cli.Context) error { return errors.Wrap(err, "Cannot install service") } defer s.Close() - log.Info().Msg("Argo Tunnel agent service is installed") + log.Info().Msg("Cloudflare Tunnel agent service is installed") err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info) if err != nil { s.Delete() @@ -217,7 +218,7 @@ func uninstallWindowsService(c *cli.Context) error { With(). Str(LogFieldWindowsServiceName, windowsServiceName).Logger() - log.Info().Msg("Uninstalling Argo Tunnel Windows Service") + log.Info().Msg("Uninstalling Cloudflare Tunnel Windows Service") m, err := mgr.Connect() if err != nil { return errors.Wrap(err, "Cannot establish a connection to the service control manager") @@ -232,7 +233,7 @@ func uninstallWindowsService(c *cli.Context) error { if err != nil { return errors.Wrap(err, "Cannot delete service") } - log.Info().Msg("Argo Tunnel agent service is uninstalled") + log.Info().Msg("Cloudflare Tunnel agent service is uninstalled") err = eventlog.Remove(windowsServiceName) if err != nil { return errors.Wrap(err, "Cannot remove event logger") diff --git a/component-tests/constants.py b/component-tests/constants.py index 0212017c..bb4b9d22 100644 --- a/component-tests/constants.py +++ b/component-tests/constants.py @@ -3,3 +3,7 @@ MAX_RETRIES = 5 BACKOFF_SECS = 7 PROXY_DNS_PORT = 9053 + + +def protocols(): + return ["h2mux", "http2", "quic"] diff --git a/component-tests/test_logging.py b/component-tests/test_logging.py index f83c42e6..51282ded 100644 --- a/component-tests/test_logging.py +++ b/component-tests/test_logging.py @@ -75,7 +75,7 @@ class TestLogging: } config = component_tests_config(extra_config) with start_cloudflared(tmp_path, config, new_process=True, capture_output=False): - wait_tunnel_ready(tunnel_url=config.get_url()) + wait_tunnel_ready(tunnel_url=config.get_url(), cfd_logs=str(log_file)) assert_log_in_file(log_file) assert_json_log(log_file) @@ -88,5 +88,5 @@ class TestLogging: } config = component_tests_config(extra_config) with start_cloudflared(tmp_path, config, new_process=True, capture_output=False): - wait_tunnel_ready(tunnel_url=config.get_url()) + wait_tunnel_ready(tunnel_url=config.get_url(), cfd_logs=str(log_dir)) assert_log_to_dir(config, log_dir) diff --git a/component-tests/test_reconnect.py b/component-tests/test_reconnect.py index 02ad7ba2..0b601171 100644 --- a/component-tests/test_reconnect.py +++ b/component-tests/test_reconnect.py @@ -7,6 +7,7 @@ import pytest from flaky import flaky from conftest import CfdModes +from constants import protocols from util import start_cloudflared, wait_tunnel_ready, check_tunnel_not_connected @@ -18,9 +19,16 @@ class TestReconnect: "stdin-control": True, } + def _extra_config(self, protocol): + return { + "stdin-control": True, + "protocol": protocol, + } + @pytest.mark.skipif(platform.system() == "Windows", reason=f"Currently buggy on Windows TUN-4584") - def test_named_reconnect(self, tmp_path, component_tests_config): - config = component_tests_config(self.extra_config) + @pytest.mark.parametrize("protocol", protocols()) + def test_named_reconnect(self, tmp_path, component_tests_config, protocol): + config = component_tests_config(self._extra_config(protocol)) with start_cloudflared(tmp_path, config, new_process=True, allow_input=True, capture_output=False) as cloudflared: # Repeat the test multiple times because some issues only occur after multiple reconnects self.assert_reconnect(config, cloudflared, 5) diff --git a/component-tests/test_termination.py b/component-tests/test_termination.py index fbca69d0..ef12edc8 100644 --- a/component-tests/test_termination.py +++ b/component-tests/test_termination.py @@ -8,6 +8,7 @@ import time import pytest import requests +from constants import protocols from util import start_cloudflared, wait_tunnel_ready, check_tunnel_not_connected @@ -17,17 +18,21 @@ def supported_signals(): return [signal.SIGTERM, signal.SIGINT] -class TestTermination(): +class TestTermination: grace_period = 5 timeout = 10 - extra_config = { - "grace-period": f"{grace_period}s", - } sse_endpoint = "/sse?freq=1s" + def _extra_config(self, protocol): + return { + "grace-period": f"{self.grace_period}s", + "protocol": protocol, + } + @pytest.mark.parametrize("signal", supported_signals()) - def test_graceful_shutdown(self, tmp_path, component_tests_config, signal): - config = component_tests_config(self.extra_config) + @pytest.mark.parametrize("protocol", protocols()) + def test_graceful_shutdown(self, tmp_path, component_tests_config, signal, protocol): + config = component_tests_config(self._extra_config(protocol)) with start_cloudflared( tmp_path, config, new_process=True, capture_output=False) as cloudflared: wait_tunnel_ready(tunnel_url=config.get_url()) @@ -47,8 +52,9 @@ class TestTermination(): # test cloudflared terminates before grace period expires when all eyeball # connections are drained @pytest.mark.parametrize("signal", supported_signals()) - def test_shutdown_once_no_connection(self, tmp_path, component_tests_config, signal): - config = component_tests_config(self.extra_config) + @pytest.mark.parametrize("protocol", protocols()) + def test_shutdown_once_no_connection(self, tmp_path, component_tests_config, signal, protocol): + config = component_tests_config(self._extra_config(protocol)) with start_cloudflared( tmp_path, config, new_process=True, capture_output=False) as cloudflared: wait_tunnel_ready(tunnel_url=config.get_url()) @@ -66,8 +72,9 @@ class TestTermination(): self.wait_eyeball_thread(in_flight_req, self.grace_period) @pytest.mark.parametrize("signal", supported_signals()) - def test_no_connection_shutdown(self, tmp_path, component_tests_config, signal): - config = component_tests_config(self.extra_config) + @pytest.mark.parametrize("protocol", protocols()) + def test_no_connection_shutdown(self, tmp_path, component_tests_config, signal, protocol): + config = component_tests_config(self._extra_config(protocol)) with start_cloudflared( tmp_path, config, new_process=True, capture_output=False) as cloudflared: wait_tunnel_ready(tunnel_url=config.get_url()) diff --git a/component-tests/util.py b/component-tests/util.py index a8731d9d..3c42d2a7 100644 --- a/component-tests/util.py +++ b/component-tests/util.py @@ -1,12 +1,13 @@ -from contextlib import contextmanager import logging -import requests -from retrying import retry +import os import subprocess -import yaml - +from contextlib import contextmanager from time import sleep +import requests +import yaml +from retrying import retry + from constants import METRICS_PORT, MAX_RETRIES, BACKOFF_SECS LOGGER = logging.getLogger(__name__) @@ -19,7 +20,8 @@ def write_config(directory, config): return config_path -def start_cloudflared(directory, config, cfd_args=["run"], cfd_pre_args=["tunnel"], new_process=False, allow_input=False, capture_output=True, root=False): +def start_cloudflared(directory, config, cfd_args=["run"], cfd_pre_args=["tunnel"], new_process=False, + allow_input=False, capture_output=True, root=False): config_path = write_config(directory, config.full_config) cmd = cloudflared_cmd(config, config_path, cfd_args, cfd_pre_args, root) if new_process: @@ -53,18 +55,42 @@ def run_cloudflared_background(cmd, allow_input, capture_output): LOGGER.info(f"cloudflared log: {cfd.stderr.read()}") +def wait_tunnel_ready(tunnel_url=None, require_min_connections=1, cfd_logs=None): + try: + inner_wait_tunnel_ready(tunnel_url, require_min_connections) + except Exception as e: + if cfd_logs is not None: + _log_cloudflared_logs(cfd_logs) + raise e + + @retry(stop_max_attempt_number=MAX_RETRIES, wait_fixed=BACKOFF_SECS * 1000) -def wait_tunnel_ready(tunnel_url=None, require_min_connections=1): +def inner_wait_tunnel_ready(tunnel_url=None, require_min_connections=1): metrics_url = f'http://localhost:{METRICS_PORT}/ready' with requests.Session() as s: resp = send_request(s, metrics_url, True) - assert resp.json()[ - "readyConnections"] >= require_min_connections, f"Ready endpoint returned {resp.json()} but we expect at least {require_min_connections} connections" + + assert resp.json()["readyConnections"] >= require_min_connections, \ + f"Ready endpoint returned {resp.json()} but we expect at least {require_min_connections} connections" + if tunnel_url is not None: send_request(s, tunnel_url, True) +def _log_cloudflared_logs(cfd_logs): + log_file = cfd_logs + if os.path.isdir(cfd_logs): + files = os.listdir(cfd_logs) + if len(files) == 0: + return + log_file = os.path.join(cfd_logs, files[0]) + with open(log_file, "r") as f: + LOGGER.warning("Cloudflared Tunnel was not ready:") + for line in f.readlines(): + LOGGER.warning(line) + + @retry(stop_max_attempt_number=MAX_RETRIES * BACKOFF_SECS, wait_fixed=1000) def check_tunnel_not_connected(): url = f'http://localhost:{METRICS_PORT}/ready' diff --git a/connection/connection.go b/connection/connection.go index f061672c..2a57229f 100644 --- a/connection/connection.go +++ b/connection/connection.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "math" "net/http" "strconv" "strings" @@ -19,6 +20,7 @@ const ( lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;" LogFieldConnIndex = "connIndex" MaxGracePeriod = time.Minute * 3 + MaxConcurrentStreams = math.MaxUint32 ) var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols)) diff --git a/connection/control.go b/connection/control.go index f586f842..c0c6a1d7 100644 --- a/connection/control.go +++ b/connection/control.go @@ -29,7 +29,9 @@ type controlStream struct { // ControlStreamHandler registers connections with origintunneld and initiates graceful shutdown. type ControlStreamHandler interface { - ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, shouldWaitForUnregister bool) error + // ServeControlStream handles the control plane of the transport in the current goroutine calling this + ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions) error + // IsStopped tells whether the method above has finished IsStopped() bool } @@ -61,7 +63,6 @@ func (c *controlStream) ServeControlStream( ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, - shouldWaitForUnregister bool, ) error { rpcClient := c.newRPCClientFunc(ctx, rw, c.observer.log) @@ -71,12 +72,7 @@ func (c *controlStream) ServeControlStream( } c.connectedFuse.Connected() - if shouldWaitForUnregister { - c.waitForUnregister(ctx, rpcClient) - } else { - go c.waitForUnregister(ctx, rpcClient) - } - + c.waitForUnregister(ctx, rpcClient) return nil } diff --git a/connection/h2mux_header.go b/connection/h2mux_header.go new file mode 100644 index 00000000..3987f0db --- /dev/null +++ b/connection/h2mux_header.go @@ -0,0 +1,128 @@ +package connection + +import ( + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/pkg/errors" + + "github.com/cloudflare/cloudflared/h2mux" +) + +// H2RequestHeadersToH1Request converts the HTTP/2 headers coming from origintunneld +// to an HTTP/1 Request object destined for the local origin web service. +// This operation includes conversion of the pseudo-headers into their closest +// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3 +func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error { + for _, header := range h2 { + name := strings.ToLower(header.Name) + if !IsH2muxControlRequestHeader(name) { + continue + } + + switch name { + case ":method": + h1.Method = header.Value + case ":scheme": + // noop - use the preexisting scheme from h1.URL + case ":authority": + // Otherwise the host header will be based on the origin URL + h1.Host = header.Value + case ":path": + // We don't want to be an "opinionated" proxy, so ideally we would use :path as-is. + // However, this HTTP/1 Request object belongs to the Go standard library, + // whose URL package makes some opinionated decisions about the encoding of + // URL characters: see the docs of https://godoc.org/net/url#URL, + // in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath, + // which is always used when computing url.URL.String(), whether we'd like it or not. + // + // Well, not *always*. We could circumvent this by using url.URL.Opaque. But + // that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque + // is treated differently when HTTP_PROXY is set! + // See https://github.com/golang/go/issues/5684#issuecomment-66080888 + // + // This means we are subject to the behavior of net/url's function `shouldEscape` + // (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101 + + if header.Value == "*" { + h1.URL.Path = "*" + continue + } + // Due to the behavior of validation.ValidateUrl, h1.URL may + // already have a partial value, with or without a trailing slash. + base := h1.URL.String() + base = strings.TrimRight(base, "/") + // But we know :path begins with '/', because we handled '*' above - see RFC7540 + requestURL, err := url.Parse(base + header.Value) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value)) + } + h1.URL = requestURL + case "content-length": + contentLength, err := strconv.ParseInt(header.Value, 10, 64) + if err != nil { + return fmt.Errorf("unparseable content length") + } + h1.ContentLength = contentLength + case RequestUserHeaders: + // Do not forward the serialized headers to the origin -- deserialize them, and ditch the serialized version + // Find and parse user headers serialized into a single one + userHeaders, err := DeserializeHeaders(header.Value) + if err != nil { + return errors.Wrap(err, "Unable to parse user headers") + } + for _, userHeader := range userHeaders { + h1.Header.Add(userHeader.Name, userHeader.Value) + } + default: + // All other control headers shall just be proxied transparently + h1.Header.Add(header.Name, header.Value) + } + } + + return nil +} + +func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []h2mux.Header) { + h2 = []h2mux.Header{ + {Name: ":status", Value: strconv.Itoa(status)}, + } + userHeaders := make(http.Header, len(h1)) + for header, values := range h1 { + h2name := strings.ToLower(header) + if h2name == "content-length" { + // This header has meaning in HTTP/2 and will be used by the edge, + // so it should be sent as an HTTP/2 response header. + + // Since these are http2 headers, they're required to be lowercase + h2 = append(h2, h2mux.Header{Name: "content-length", Value: values[0]}) + } else if !IsH2muxControlResponseHeader(h2name) || IsWebsocketClientHeader(h2name) { + // User headers, on the other hand, must all be serialized so that + // HTTP/2 header validation won't be applied to HTTP/1 header values + userHeaders[header] = values + } + } + + // Perform user header serialization and set them in the single header + h2 = append(h2, h2mux.Header{Name: ResponseUserHeaders, Value: SerializeHeaders(userHeaders)}) + return h2 +} + +// IsH2muxControlRequestHeader is called in the direction of eyeball -> origin. +func IsH2muxControlRequestHeader(headerName string) bool { + return headerName == "content-length" || + headerName == "connection" || headerName == "upgrade" || // Websocket request headers + strings.HasPrefix(headerName, ":") || + strings.HasPrefix(headerName, "cf-") +} + +// IsH2muxControlResponseHeader is called in the direction of eyeball <- origin. +func IsH2muxControlResponseHeader(headerName string) bool { + return headerName == "content-length" || + strings.HasPrefix(headerName, ":") || + strings.HasPrefix(headerName, "cf-int-") || + strings.HasPrefix(headerName, "cf-cloudflared-") +} diff --git a/connection/h2mux_header_test.go b/connection/h2mux_header_test.go new file mode 100644 index 00000000..a78e02f4 --- /dev/null +++ b/connection/h2mux_header_test.go @@ -0,0 +1,642 @@ +package connection + +import ( + "fmt" + "math/rand" + "net/http" + "net/url" + "reflect" + "regexp" + "strings" + "testing" + "testing/quick" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudflare/cloudflared/h2mux" +) + +type ByName []h2mux.Header + +func (a ByName) Len() int { return len(a) } +func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a ByName) Less(i, j int) bool { + if a[i].Name == a[j].Name { + return a[i].Value < a[j].Value + } + + return a[i].Name < a[j].Name +} + +func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + assert.NoError(t, err) + + mockHeaders := http.Header{ + "Mock header 1": {"Mock value 1"}, + "Mock header 2": {"Mock value 2"}, + } + + headersConversionErr := H2RequestHeadersToH1Request(createSerializedHeaders(RequestUserHeaders, mockHeaders), request) + + assert.True(t, reflect.DeepEqual(mockHeaders, request.Header)) + assert.NoError(t, headersConversionErr) +} + +func createSerializedHeaders(headersField string, headers http.Header) []h2mux.Header { + return []h2mux.Header{{ + Name: headersField, + Value: SerializeHeaders(headers), + }} +} + +func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + assert.NoError(t, err) + + emptyHeaders := make(http.Header) + headersConversionErr := H2RequestHeadersToH1Request( + []h2mux.Header{{ + Name: RequestUserHeaders, + Value: SerializeHeaders(emptyHeaders), + }}, + request, + ) + + assert.True(t, reflect.DeepEqual(emptyHeaders, request.Header)) + assert.NoError(t, headersConversionErr) +} + +func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + assert.NoError(t, err) + + mockRequestHeaders := []h2mux.Header{ + {Name: ":path", Value: "//bad_path/"}, + {Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, + } + + headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) + + assert.Equal(t, http.Header{ + "Mock header": []string{"Mock value"}, + }, request.Header) + + assert.Equal(t, "http://example.com//bad_path/", request.URL.String()) + + assert.NoError(t, headersConversionErr) +} + +func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) + assert.NoError(t, err) + + mockRequestHeaders := []h2mux.Header{ + {Name: ":path", Value: "/?query=mock%20value"}, + {Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, + } + + headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) + + assert.Equal(t, http.Header{ + "Mock header": []string{"Mock value"}, + }, request.Header) + + assert.Equal(t, "http://example.com/?query=mock%20value", request.URL.String()) + + assert.NoError(t, headersConversionErr) +} + +func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) { + request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) + assert.NoError(t, err) + + mockRequestHeaders := []h2mux.Header{ + {Name: ":path", Value: "/mock%20path"}, + {Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, + } + + headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) + + assert.Equal(t, http.Header{ + "Mock header": []string{"Mock value"}, + }, request.Header) + + assert.Equal(t, "http://example.com/mock%20path", request.URL.String()) + + assert.NoError(t, headersConversionErr) +} + +func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) { + type testCase struct { + path string + want string + } + testCases := []testCase{ + { + path: "", + want: "", + }, + { + path: "/", + want: "/", + }, + { + path: "//", + want: "//", + }, + { + path: "/test", + want: "/test", + }, + { + path: "//test", + want: "//test", + }, + { + // https://github.com/cloudflare/cloudflared/issues/81 + path: "//test/", + want: "//test/", + }, + { + path: "/%2Ftest", + want: "/%2Ftest", + }, + { + path: "//%20test", + want: "//%20test", + }, + { + // https://github.com/cloudflare/cloudflared/issues/124 + path: "/test?get=somthing%20a", + want: "/test?get=somthing%20a", + }, + { + path: "/%20", + want: "/%20", + }, + { + // stdlib's EscapedPath() will always percent-encode ' ' + path: "/ ", + want: "/%20", + }, + { + path: "/ a ", + want: "/%20a%20", + }, + { + path: "/a%20b", + want: "/a%20b", + }, + { + path: "/foo/bar;param?query#frag", + want: "/foo/bar;param?query#frag", + }, + { + // stdlib's EscapedPath() will always percent-encode non-ASCII chars + path: "/a␠b", + want: "/a%E2%90%A0b", + }, + { + path: "/a-umlaut-ä", + want: "/a-umlaut-%C3%A4", + }, + { + path: "/a-umlaut-%C3%A4", + want: "/a-umlaut-%C3%A4", + }, + { + path: "/a-umlaut-%c3%a4", + want: "/a-umlaut-%c3%a4", + }, + { + // here the second '#' is treated as part of the fragment + path: "/a#b#c", + want: "/a#b%23c", + }, + { + path: "/a#b␠c", + want: "/a#b%E2%90%A0c", + }, + { + path: "/a#b%20c", + want: "/a#b%20c", + }, + { + path: "/a#b c", + want: "/a#b%20c", + }, + { + // stdlib's EscapedPath() will always percent-encode '\' + path: "/\\", + want: "/%5C", + }, + { + path: "/a\\", + want: "/a%5C", + }, + { + path: "/a,b.c.", + want: "/a,b.c.", + }, + { + path: "/.", + want: "/.", + }, + { + // stdlib's EscapedPath() will always percent-encode '`' + path: "/a`", + want: "/a%60", + }, + { + path: "/a[0]", + want: "/a[0]", + }, + { + path: "/?a[0]=5 &b[]=", + want: "/?a[0]=5 &b[]=", + }, + { + path: "/?a=%22b%20%22", + want: "/?a=%22b%20%22", + }, + } + + for index, testCase := range testCases { + requestURL := "https://example.com" + + request, err := http.NewRequest(http.MethodGet, requestURL, nil) + assert.NoError(t, err) + + mockRequestHeaders := []h2mux.Header{ + {Name: ":path", Value: testCase.path}, + {Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, + } + + headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) + assert.NoError(t, headersConversionErr) + + assert.Equal(t, + http.Header{ + "Mock header": []string{"Mock value"}, + }, + request.Header) + + assert.Equal(t, + "https://example.com"+testCase.want, + request.URL.String(), + "Failed URL index: %v %#v", index, testCase) + } +} + +func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) { + config := &quick.Config{ + Values: func(args []reflect.Value, rand *rand.Rand) { + args[0] = reflect.ValueOf(randomHTTP2Path(t, rand)) + }, + } + + type testOrigin struct { + url string + + expectedScheme string + expectedBasePath string + } + testOrigins := []testOrigin{ + { + url: "http://origin.hostname.example.com:8080", + expectedScheme: "http", + expectedBasePath: "http://origin.hostname.example.com:8080", + }, + { + url: "http://origin.hostname.example.com:8080/", + expectedScheme: "http", + expectedBasePath: "http://origin.hostname.example.com:8080", + }, + { + url: "http://origin.hostname.example.com:8080/api", + expectedScheme: "http", + expectedBasePath: "http://origin.hostname.example.com:8080/api", + }, + { + url: "http://origin.hostname.example.com:8080/api/", + expectedScheme: "http", + expectedBasePath: "http://origin.hostname.example.com:8080/api", + }, + { + url: "https://origin.hostname.example.com:8080/api", + expectedScheme: "https", + expectedBasePath: "https://origin.hostname.example.com:8080/api", + }, + } + + // use multiple schemes to demonstrate that the URL is based on the + // origin's scheme, not the :scheme header + for _, testScheme := range []string{"http", "https"} { + for _, testOrigin := range testOrigins { + assertion := func(testPath string) bool { + const expectedMethod = "POST" + const expectedHostname = "request.hostname.example.com" + + h2 := []h2mux.Header{ + {Name: ":method", Value: expectedMethod}, + {Name: ":scheme", Value: testScheme}, + {Name: ":authority", Value: expectedHostname}, + {Name: ":path", Value: testPath}, + {Name: RequestUserHeaders, Value: ""}, + } + h1, err := http.NewRequest("GET", testOrigin.url, nil) + require.NoError(t, err) + + err = H2RequestHeadersToH1Request(h2, h1) + return assert.NoError(t, err) && + assert.Equal(t, expectedMethod, h1.Method) && + assert.Equal(t, expectedHostname, h1.Host) && + assert.Equal(t, testOrigin.expectedScheme, h1.URL.Scheme) && + assert.Equal(t, testOrigin.expectedBasePath+testPath, h1.URL.String()) + } + err := quick.Check(assertion, config) + assert.NoError(t, err) + } + } +} + +func randomASCIIPrintableChar(rand *rand.Rand) int { + // smallest printable ASCII char is 32, largest is 126 + const startPrintable = 32 + const endPrintable = 127 + return startPrintable + rand.Intn(endPrintable-startPrintable) +} + +// randomASCIIText generates an ASCII string, some of whose characters may be +// percent-encoded. Its "logical length" (ignoring percent-encoding) is +// between 1 and `maxLength`. +func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string { + length := minLength + rand.Intn(maxLength) + var result strings.Builder + for i := 0; i < length; i++ { + c := randomASCIIPrintableChar(rand) + + // 1/4 chance of using percent encoding when not necessary + if c == '%' || rand.Intn(4) == 0 { + result.WriteString(fmt.Sprintf("%%%02X", c)) + } else { + result.WriteByte(byte(c)) + } + } + return result.String() +} + +// Calls `randomASCIIText` and ensures the result is a valid URL path, +// i.e. one that can pass unchanged through url.URL.String() +func randomHTTP1Path(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string { + text := randomASCIIText(rand, minLength, maxLength) + re, err := regexp.Compile("[^/;,]*") + require.NoError(t, err) + return "/" + re.ReplaceAllStringFunc(text, url.PathEscape) +} + +// Calls `randomASCIIText` and ensures the result is a valid URL query, +// i.e. one that can pass unchanged through url.URL.String() +func randomHTTP1Query(rand *rand.Rand, minLength int, maxLength int) string { + text := randomASCIIText(rand, minLength, maxLength) + return "?" + strings.ReplaceAll(text, "#", "%23") +} + +// Calls `randomASCIIText` and ensures the result is a valid URL fragment, +// i.e. one that can pass unchanged through url.URL.String() +func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string { + text := randomASCIIText(rand, minLength, maxLength) + u, err := url.Parse("#" + text) + require.NoError(t, err) + return u.String() +} + +// Assemble a random :path pseudoheader that is legal by Go stdlib standards +// (i.e. all characters will satisfy "net/url".shouldEscape for their respective locations) +func randomHTTP2Path(t *testing.T, rand *rand.Rand) string { + result := randomHTTP1Path(t, rand, 1, 64) + if rand.Intn(2) == 1 { + result += randomHTTP1Query(rand, 1, 32) + } + if rand.Intn(2) == 1 { + result += randomHTTP1Fragment(t, rand, 1, 16) + } + return result +} + +func stdlibHeaderToH2muxHeader(headers http.Header) (h2muxHeaders []h2mux.Header) { + for name, values := range headers { + for _, value := range values { + h2muxHeaders = append(h2muxHeaders, h2mux.Header{Name: name, Value: value}) + } + } + + return h2muxHeaders +} + +func TestParseRequestHeaders(t *testing.T) { + mockUserHeadersToSerialize := http.Header{ + "Mock-Header-One": {"1", "1.5"}, + "Mock-Header-Two": {"2"}, + "Mock-Header-Three": {"3"}, + } + + mockHeaders := []h2mux.Header{ + {Name: "One", Value: "1"}, // will be dropped + {Name: "Cf-Two", Value: "cf-value-1"}, + {Name: "Cf-Two", Value: "cf-value-2"}, + {Name: RequestUserHeaders, Value: SerializeHeaders(mockUserHeadersToSerialize)}, + } + + expectedHeaders := []h2mux.Header{ + {Name: "Cf-Two", Value: "cf-value-1"}, + {Name: "Cf-Two", Value: "cf-value-2"}, + {Name: "Mock-Header-One", Value: "1"}, + {Name: "Mock-Header-One", Value: "1.5"}, + {Name: "Mock-Header-Two", Value: "2"}, + {Name: "Mock-Header-Three", Value: "3"}, + } + h1 := &http.Request{ + Header: make(http.Header), + } + err := H2RequestHeadersToH1Request(mockHeaders, h1) + assert.NoError(t, err) + assert.ElementsMatch(t, expectedHeaders, stdlibHeaderToH2muxHeader(h1.Header)) +} + +func TestIsH2muxControlRequestHeader(t *testing.T) { + controlRequestHeaders := []string{ + // Anything that begins with cf- + "cf-sample-header", + + // Any http2 pseudoheader + ":sample-pseudo-header", + + // content-length is a special case, it has to be there + // for some requests to work (per the HTTP2 spec) + "content-length", + + // Websocket request headers + "connection", + "upgrade", + } + + for _, header := range controlRequestHeaders { + assert.True(t, IsH2muxControlRequestHeader(header)) + } +} + +func TestIsH2muxControlResponseHeader(t *testing.T) { + controlResponseHeaders := []string{ + // Anything that begins with cf-int- or cf-cloudflared- + "cf-int-sample-header", + "cf-cloudflared-sample-header", + + // Any http2 pseudoheader + ":sample-pseudo-header", + + // content-length is a special case, it has to be there + // for some requests to work (per the HTTP2 spec) + "content-length", + } + + for _, header := range controlResponseHeaders { + assert.True(t, IsH2muxControlResponseHeader(header)) + } +} + +func TestIsNotH2muxControlRequestHeader(t *testing.T) { + notControlRequestHeaders := []string{ + "mock-header", + "another-sample-header", + } + + for _, header := range notControlRequestHeaders { + assert.False(t, IsH2muxControlRequestHeader(header)) + } +} + +func TestIsNotH2muxControlResponseHeader(t *testing.T) { + notControlResponseHeaders := []string{ + "mock-header", + "another-sample-header", + "upgrade", + "connection", + "cf-whatever", // On the response path, we only want to filter cf-int- and cf-cloudflared- + } + + for _, header := range notControlResponseHeaders { + assert.False(t, IsH2muxControlResponseHeader(header)) + } +} + +func TestH1ResponseToH2ResponseHeaders(t *testing.T) { + mockHeaders := http.Header{ + "User-header-one": {""}, + "User-header-two": {"1", "2"}, + "cf-header": {"cf-value"}, + "cf-int-header": {"cf-int-value"}, + "cf-cloudflared-header": {"cf-cloudflared-value"}, + "Content-Length": {"123"}, + } + mockResponse := http.Response{ + StatusCode: 200, + Header: mockHeaders, + } + + headers := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header) + + serializedHeadersIndex := -1 + for i, header := range headers { + if header.Name == ResponseUserHeaders { + serializedHeadersIndex = i + break + } + } + assert.NotEqual(t, -1, serializedHeadersIndex) + actualControlHeaders := append( + headers[:serializedHeadersIndex], + headers[serializedHeadersIndex+1:]..., + ) + expectedControlHeaders := []h2mux.Header{ + {Name: ":status", Value: "200"}, + {Name: "content-length", Value: "123"}, + } + + assert.ElementsMatch(t, expectedControlHeaders, actualControlHeaders) + + actualUserHeaders, err := DeserializeHeaders(headers[serializedHeadersIndex].Value) + expectedUserHeaders := []h2mux.Header{ + {Name: "User-header-one", Value: ""}, + {Name: "User-header-two", Value: "1"}, + {Name: "User-header-two", Value: "2"}, + {Name: "cf-header", Value: "cf-value"}, + } + assert.NoError(t, err) + assert.ElementsMatch(t, expectedUserHeaders, actualUserHeaders) +} + +// The purpose of this test is to check that our code and the http.Header +// implementation don't throw validation errors about header size +func TestHeaderSize(t *testing.T) { + largeValue := randSeq(5 * 1024 * 1024) // 5Mb + largeHeaders := http.Header{ + "User-header": {largeValue}, + } + mockResponse := http.Response{ + StatusCode: 200, + Header: largeHeaders, + } + + serializedHeaders := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header) + request, err := http.NewRequest(http.MethodGet, "https://example.com/", nil) + assert.NoError(t, err) + for _, header := range serializedHeaders { + request.Header.Set(header.Name, header.Value) + } + + for _, header := range serializedHeaders { + if header.Name != ResponseUserHeaders { + continue + } + + deserializedHeaders, err := DeserializeHeaders(header.Value) + assert.NoError(t, err) + assert.Equal(t, largeValue, deserializedHeaders[0].Value) + } +} + +func randSeq(n int) string { + randomizer := rand.New(rand.NewSource(17)) + var letters = []rune(":;,+/=abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + b := make([]rune, n) + for i := range b { + b[i] = letters[randomizer.Intn(len(letters))] + } + return string(b) +} + +func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) { + ser := "eC1mb3J3YXJkZWQtcHJvdG8:aHR0cHM;dXBncmFkZS1pbnNlY3VyZS1yZXF1ZXN0cw:MQ;YWNjZXB0LWxhbmd1YWdl:ZW4tVVMsZW47cT0wLjkscnU7cT0wLjg;YWNjZXB0LWVuY29kaW5n:Z3ppcA;eC1mb3J3YXJkZWQtZm9y:MTczLjI0NS42MC42;dXNlci1hZ2VudA:TW96aWxsYS81LjAgKE1hY2ludG9zaDsgSW50ZWwgTWFjIE9TIFggMTBfMTRfNikgQXBwbGVXZWJLaXQvNTM3LjM2IChLSFRNTCwgbGlrZSBHZWNrbykgQ2hyb21lLzg0LjAuNDE0Ny44OSBTYWZhcmkvNTM3LjM2;c2VjLWZldGNoLW1vZGU:bmF2aWdhdGU;Y2RuLWxvb3A:Y2xvdWRmbGFyZQ;c2VjLWZldGNoLWRlc3Q:ZG9jdW1lbnQ;c2VjLWZldGNoLXVzZXI:PzE;c2VjLWZldGNoLXNpdGU:bm9uZQ;Y29va2ll:X19jZmR1aWQ9ZGNkOWZjOGNjNWMxMzE0NTMyYTFkMjhlZDEyOWRhOTYwMTU2OTk1MTYzNDsgX19jZl9ibT1mYzY2MzMzYzAzZmM0MWFiZTZmOWEyYzI2ZDUwOTA0YzIxYzZhMTQ2LTE1OTU2MjIzNDEtMTgwMC1BZTVzS2pIU2NiWGVFM05mMUhrTlNQMG1tMHBLc2pQWkloVnM1Z2g1SkNHQkFhS1UxVDB2b003alBGN3FjMHVSR2NjZGcrWHdhL1EzbTJhQzdDVU4xZ2M9;YWNjZXB0:dGV4dC9odG1sLGFwcGxpY2F0aW9uL3hodG1sK3htbCxhcHBsaWNhdGlvbi94bWw7cT0wLjksaW1hZ2Uvd2VicCxpbWFnZS9hcG5nLCovKjtxPTAuOCxhcHBsaWNhdGlvbi9zaWduZWQtZXhjaGFuZ2U7dj1iMztxPTAuOQ" + h2, _ := DeserializeHeaders(ser) + h1 := make(http.Header) + for _, header := range h2 { + h1.Add(header.Name, header.Value) + } + h1.Add("Content-Length", "200") + h1.Add("Cf-Something", "Else") + h1.Add("Upgrade", "websocket") + + h1resp := &http.Response{ + StatusCode: 200, + Header: h1, + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = H1ResponseToH2ResponseHeaders(h1resp.StatusCode, h1resp.Header) + } +} diff --git a/connection/header.go b/connection/header.go index 61242f6b..d1544263 100644 --- a/connection/header.go +++ b/connection/header.go @@ -4,8 +4,6 @@ import ( "encoding/base64" "fmt" "net/http" - "net/url" - "strconv" "strings" "github.com/pkg/errors" @@ -44,93 +42,9 @@ func mustInitRespMetaHeader(src string) string { var headerEncoding = base64.RawStdEncoding -// note: all h2mux headers should be lower-case (http/2 style) -const () - -// H2RequestHeadersToH1Request converts the HTTP/2 headers coming from origintunneld -// to an HTTP/1 Request object destined for the local origin web service. -// This operation includes conversion of the pseudo-headers into their closest -// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3 -func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error { - for _, header := range h2 { - name := strings.ToLower(header.Name) - if !IsControlRequestHeader(name) { - continue - } - - switch name { - case ":method": - h1.Method = header.Value - case ":scheme": - // noop - use the preexisting scheme from h1.URL - case ":authority": - // Otherwise the host header will be based on the origin URL - h1.Host = header.Value - case ":path": - // We don't want to be an "opinionated" proxy, so ideally we would use :path as-is. - // However, this HTTP/1 Request object belongs to the Go standard library, - // whose URL package makes some opinionated decisions about the encoding of - // URL characters: see the docs of https://godoc.org/net/url#URL, - // in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath, - // which is always used when computing url.URL.String(), whether we'd like it or not. - // - // Well, not *always*. We could circumvent this by using url.URL.Opaque. But - // that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque - // is treated differently when HTTP_PROXY is set! - // See https://github.com/golang/go/issues/5684#issuecomment-66080888 - // - // This means we are subject to the behavior of net/url's function `shouldEscape` - // (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101 - - if header.Value == "*" { - h1.URL.Path = "*" - continue - } - // Due to the behavior of validation.ValidateUrl, h1.URL may - // already have a partial value, with or without a trailing slash. - base := h1.URL.String() - base = strings.TrimRight(base, "/") - // But we know :path begins with '/', because we handled '*' above - see RFC7540 - requestURL, err := url.Parse(base + header.Value) - if err != nil { - return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value)) - } - h1.URL = requestURL - case "content-length": - contentLength, err := strconv.ParseInt(header.Value, 10, 64) - if err != nil { - return fmt.Errorf("unparseable content length") - } - h1.ContentLength = contentLength - case RequestUserHeaders: - // Do not forward the serialized headers to the origin -- deserialize them, and ditch the serialized version - // Find and parse user headers serialized into a single one - userHeaders, err := DeserializeHeaders(header.Value) - if err != nil { - return errors.Wrap(err, "Unable to parse user headers") - } - for _, userHeader := range userHeaders { - h1.Header.Add(userHeader.Name, userHeader.Value) - } - default: - // All other control headers shall just be proxied transparently - h1.Header.Add(header.Name, header.Value) - } - } - - return nil -} - -func IsControlRequestHeader(headerName string) bool { - return headerName == "content-length" || - headerName == "connection" || headerName == "upgrade" || // Websocket request headers - strings.HasPrefix(headerName, ":") || - strings.HasPrefix(headerName, "cf-") -} - +// IsControlResponseHeader is called in the direction of eyeball <- origin. func IsControlResponseHeader(headerName string) bool { - return headerName == "content-length" || - strings.HasPrefix(headerName, ":") || + return strings.HasPrefix(headerName, ":") || strings.HasPrefix(headerName, "cf-int-") || strings.HasPrefix(headerName, "cf-cloudflared-") } @@ -142,31 +56,6 @@ func IsWebsocketClientHeader(headerName string) bool { headerName == "upgrade" } -func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []h2mux.Header) { - h2 = []h2mux.Header{ - {Name: ":status", Value: strconv.Itoa(status)}, - } - userHeaders := make(http.Header, len(h1)) - for header, values := range h1 { - h2name := strings.ToLower(header) - if h2name == "content-length" { - // This header has meaning in HTTP/2 and will be used by the edge, - // so it should be sent as an HTTP/2 response header. - - // Since these are http2 headers, they're required to be lowercase - h2 = append(h2, h2mux.Header{Name: "content-length", Value: values[0]}) - } else if !IsControlResponseHeader(h2name) || IsWebsocketClientHeader(h2name) { - // User headers, on the other hand, must all be serialized so that - // HTTP/2 header validation won't be applied to HTTP/1 header values - userHeaders[header] = values - } - } - - // Perform user header serialization and set them in the single header - h2 = append(h2, h2mux.Header{Name: ResponseUserHeaders, Value: SerializeHeaders(userHeaders)}) - return h2 -} - // Serialize HTTP1.x headers by base64-encoding each header name and value, // and then joining them in the format of [key:value;] func SerializeHeaders(h1Headers http.Header) string { diff --git a/connection/header_test.go b/connection/header_test.go index ad6881b1..88add316 100644 --- a/connection/header_test.go +++ b/connection/header_test.go @@ -2,441 +2,14 @@ package connection import ( "fmt" - "math/rand" "net/http" - "net/url" "reflect" - "regexp" "sort" - "strings" "testing" - "testing/quick" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/cloudflare/cloudflared/h2mux" ) -type ByName []h2mux.Header - -func (a ByName) Len() int { return len(a) } -func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a ByName) Less(i, j int) bool { - if a[i].Name == a[j].Name { - return a[i].Value < a[j].Value - } - - return a[i].Name < a[j].Name -} - -func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) { - request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) - assert.NoError(t, err) - - mockHeaders := http.Header{ - "Mock header 1": {"Mock value 1"}, - "Mock header 2": {"Mock value 2"}, - } - - headersConversionErr := H2RequestHeadersToH1Request(createSerializedHeaders(RequestUserHeaders, mockHeaders), request) - - assert.True(t, reflect.DeepEqual(mockHeaders, request.Header)) - assert.NoError(t, headersConversionErr) -} - -func createSerializedHeaders(headersField string, headers http.Header) []h2mux.Header { - return []h2mux.Header{{ - Name: headersField, - Value: SerializeHeaders(headers), - }} -} - -func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) { - request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) - assert.NoError(t, err) - - emptyHeaders := make(http.Header) - headersConversionErr := H2RequestHeadersToH1Request( - []h2mux.Header{{ - Name: RequestUserHeaders, - Value: SerializeHeaders(emptyHeaders), - }}, - request, - ) - - assert.True(t, reflect.DeepEqual(emptyHeaders, request.Header)) - assert.NoError(t, headersConversionErr) -} - -func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) { - request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) - assert.NoError(t, err) - - mockRequestHeaders := []h2mux.Header{ - {Name: ":path", Value: "//bad_path/"}, - {Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, - } - - headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) - - assert.Equal(t, http.Header{ - "Mock header": []string{"Mock value"}, - }, request.Header) - - assert.Equal(t, "http://example.com//bad_path/", request.URL.String()) - - assert.NoError(t, headersConversionErr) -} - -func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) { - request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) - assert.NoError(t, err) - - mockRequestHeaders := []h2mux.Header{ - {Name: ":path", Value: "/?query=mock%20value"}, - {Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, - } - - headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) - - assert.Equal(t, http.Header{ - "Mock header": []string{"Mock value"}, - }, request.Header) - - assert.Equal(t, "http://example.com/?query=mock%20value", request.URL.String()) - - assert.NoError(t, headersConversionErr) -} - -func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) { - request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) - assert.NoError(t, err) - - mockRequestHeaders := []h2mux.Header{ - {Name: ":path", Value: "/mock%20path"}, - {Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, - } - - headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) - - assert.Equal(t, http.Header{ - "Mock header": []string{"Mock value"}, - }, request.Header) - - assert.Equal(t, "http://example.com/mock%20path", request.URL.String()) - - assert.NoError(t, headersConversionErr) -} - -func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) { - type testCase struct { - path string - want string - } - testCases := []testCase{ - { - path: "", - want: "", - }, - { - path: "/", - want: "/", - }, - { - path: "//", - want: "//", - }, - { - path: "/test", - want: "/test", - }, - { - path: "//test", - want: "//test", - }, - { - // https://github.com/cloudflare/cloudflared/issues/81 - path: "//test/", - want: "//test/", - }, - { - path: "/%2Ftest", - want: "/%2Ftest", - }, - { - path: "//%20test", - want: "//%20test", - }, - { - // https://github.com/cloudflare/cloudflared/issues/124 - path: "/test?get=somthing%20a", - want: "/test?get=somthing%20a", - }, - { - path: "/%20", - want: "/%20", - }, - { - // stdlib's EscapedPath() will always percent-encode ' ' - path: "/ ", - want: "/%20", - }, - { - path: "/ a ", - want: "/%20a%20", - }, - { - path: "/a%20b", - want: "/a%20b", - }, - { - path: "/foo/bar;param?query#frag", - want: "/foo/bar;param?query#frag", - }, - { - // stdlib's EscapedPath() will always percent-encode non-ASCII chars - path: "/a␠b", - want: "/a%E2%90%A0b", - }, - { - path: "/a-umlaut-ä", - want: "/a-umlaut-%C3%A4", - }, - { - path: "/a-umlaut-%C3%A4", - want: "/a-umlaut-%C3%A4", - }, - { - path: "/a-umlaut-%c3%a4", - want: "/a-umlaut-%c3%a4", - }, - { - // here the second '#' is treated as part of the fragment - path: "/a#b#c", - want: "/a#b%23c", - }, - { - path: "/a#b␠c", - want: "/a#b%E2%90%A0c", - }, - { - path: "/a#b%20c", - want: "/a#b%20c", - }, - { - path: "/a#b c", - want: "/a#b%20c", - }, - { - // stdlib's EscapedPath() will always percent-encode '\' - path: "/\\", - want: "/%5C", - }, - { - path: "/a\\", - want: "/a%5C", - }, - { - path: "/a,b.c.", - want: "/a,b.c.", - }, - { - path: "/.", - want: "/.", - }, - { - // stdlib's EscapedPath() will always percent-encode '`' - path: "/a`", - want: "/a%60", - }, - { - path: "/a[0]", - want: "/a[0]", - }, - { - path: "/?a[0]=5 &b[]=", - want: "/?a[0]=5 &b[]=", - }, - { - path: "/?a=%22b%20%22", - want: "/?a=%22b%20%22", - }, - } - - for index, testCase := range testCases { - requestURL := "https://example.com" - - request, err := http.NewRequest(http.MethodGet, requestURL, nil) - assert.NoError(t, err) - - mockRequestHeaders := []h2mux.Header{ - {Name: ":path", Value: testCase.path}, - {Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})}, - } - - headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request) - assert.NoError(t, headersConversionErr) - - assert.Equal(t, - http.Header{ - "Mock header": []string{"Mock value"}, - }, - request.Header) - - assert.Equal(t, - "https://example.com"+testCase.want, - request.URL.String(), - "Failed URL index: %v %#v", index, testCase) - } -} - -func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) { - config := &quick.Config{ - Values: func(args []reflect.Value, rand *rand.Rand) { - args[0] = reflect.ValueOf(randomHTTP2Path(t, rand)) - }, - } - - type testOrigin struct { - url string - - expectedScheme string - expectedBasePath string - } - testOrigins := []testOrigin{ - { - url: "http://origin.hostname.example.com:8080", - expectedScheme: "http", - expectedBasePath: "http://origin.hostname.example.com:8080", - }, - { - url: "http://origin.hostname.example.com:8080/", - expectedScheme: "http", - expectedBasePath: "http://origin.hostname.example.com:8080", - }, - { - url: "http://origin.hostname.example.com:8080/api", - expectedScheme: "http", - expectedBasePath: "http://origin.hostname.example.com:8080/api", - }, - { - url: "http://origin.hostname.example.com:8080/api/", - expectedScheme: "http", - expectedBasePath: "http://origin.hostname.example.com:8080/api", - }, - { - url: "https://origin.hostname.example.com:8080/api", - expectedScheme: "https", - expectedBasePath: "https://origin.hostname.example.com:8080/api", - }, - } - - // use multiple schemes to demonstrate that the URL is based on the - // origin's scheme, not the :scheme header - for _, testScheme := range []string{"http", "https"} { - for _, testOrigin := range testOrigins { - assertion := func(testPath string) bool { - const expectedMethod = "POST" - const expectedHostname = "request.hostname.example.com" - - h2 := []h2mux.Header{ - {Name: ":method", Value: expectedMethod}, - {Name: ":scheme", Value: testScheme}, - {Name: ":authority", Value: expectedHostname}, - {Name: ":path", Value: testPath}, - {Name: RequestUserHeaders, Value: ""}, - } - h1, err := http.NewRequest("GET", testOrigin.url, nil) - require.NoError(t, err) - - err = H2RequestHeadersToH1Request(h2, h1) - return assert.NoError(t, err) && - assert.Equal(t, expectedMethod, h1.Method) && - assert.Equal(t, expectedHostname, h1.Host) && - assert.Equal(t, testOrigin.expectedScheme, h1.URL.Scheme) && - assert.Equal(t, testOrigin.expectedBasePath+testPath, h1.URL.String()) - } - err := quick.Check(assertion, config) - assert.NoError(t, err) - } - } -} - -func randomASCIIPrintableChar(rand *rand.Rand) int { - // smallest printable ASCII char is 32, largest is 126 - const startPrintable = 32 - const endPrintable = 127 - return startPrintable + rand.Intn(endPrintable-startPrintable) -} - -// randomASCIIText generates an ASCII string, some of whose characters may be -// percent-encoded. Its "logical length" (ignoring percent-encoding) is -// between 1 and `maxLength`. -func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string { - length := minLength + rand.Intn(maxLength) - var result strings.Builder - for i := 0; i < length; i++ { - c := randomASCIIPrintableChar(rand) - - // 1/4 chance of using percent encoding when not necessary - if c == '%' || rand.Intn(4) == 0 { - result.WriteString(fmt.Sprintf("%%%02X", c)) - } else { - result.WriteByte(byte(c)) - } - } - return result.String() -} - -// Calls `randomASCIIText` and ensures the result is a valid URL path, -// i.e. one that can pass unchanged through url.URL.String() -func randomHTTP1Path(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string { - text := randomASCIIText(rand, minLength, maxLength) - re, err := regexp.Compile("[^/;,]*") - require.NoError(t, err) - return "/" + re.ReplaceAllStringFunc(text, url.PathEscape) -} - -// Calls `randomASCIIText` and ensures the result is a valid URL query, -// i.e. one that can pass unchanged through url.URL.String() -func randomHTTP1Query(rand *rand.Rand, minLength int, maxLength int) string { - text := randomASCIIText(rand, minLength, maxLength) - return "?" + strings.ReplaceAll(text, "#", "%23") -} - -// Calls `randomASCIIText` and ensures the result is a valid URL fragment, -// i.e. one that can pass unchanged through url.URL.String() -func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string { - text := randomASCIIText(rand, minLength, maxLength) - u, err := url.Parse("#" + text) - require.NoError(t, err) - return u.String() -} - -// Assemble a random :path pseudoheader that is legal by Go stdlib standards -// (i.e. all characters will satisfy "net/url".shouldEscape for their respective locations) -func randomHTTP2Path(t *testing.T, rand *rand.Rand) string { - result := randomHTTP1Path(t, rand, 1, 64) - if rand.Intn(2) == 1 { - result += randomHTTP1Query(rand, 1, 32) - } - if rand.Intn(2) == 1 { - result += randomHTTP1Fragment(t, rand, 1, 16) - } - return result -} - -func stdlibHeaderToH2muxHeader(headers http.Header) (h2muxHeaders []h2mux.Header) { - for name, values := range headers { - for _, value := range values { - h2muxHeaders = append(h2muxHeaders, h2mux.Header{Name: name, Value: value}) - } - } - - return h2muxHeaders -} - func TestSerializeHeaders(t *testing.T) { request, err := http.NewRequest(http.MethodGet, "http://example.com", nil) assert.NoError(t, err) @@ -511,70 +84,13 @@ func TestDeserializeMalformed(t *testing.T) { } } -func TestParseRequestHeaders(t *testing.T) { - mockUserHeadersToSerialize := http.Header{ - "Mock-Header-One": {"1", "1.5"}, - "Mock-Header-Two": {"2"}, - "Mock-Header-Three": {"3"}, - } - - mockHeaders := []h2mux.Header{ - {Name: "One", Value: "1"}, // will be dropped - {Name: "Cf-Two", Value: "cf-value-1"}, - {Name: "Cf-Two", Value: "cf-value-2"}, - {Name: RequestUserHeaders, Value: SerializeHeaders(mockUserHeadersToSerialize)}, - } - - expectedHeaders := []h2mux.Header{ - {Name: "Cf-Two", Value: "cf-value-1"}, - {Name: "Cf-Two", Value: "cf-value-2"}, - {Name: "Mock-Header-One", Value: "1"}, - {Name: "Mock-Header-One", Value: "1.5"}, - {Name: "Mock-Header-Two", Value: "2"}, - {Name: "Mock-Header-Three", Value: "3"}, - } - h1 := &http.Request{ - Header: make(http.Header), - } - err := H2RequestHeadersToH1Request(mockHeaders, h1) - assert.NoError(t, err) - assert.ElementsMatch(t, expectedHeaders, stdlibHeaderToH2muxHeader(h1.Header)) -} - -func TestIsControlRequestHeader(t *testing.T) { - controlRequestHeaders := []string{ - // Anything that begins with cf- - "cf-sample-header", - - // Any http2 pseudoheader - ":sample-pseudo-header", - - // content-length is a special case, it has to be there - // for some requests to work (per the HTTP2 spec) - "content-length", - - // Websocket request headers - "connection", - "upgrade", - } - - for _, header := range controlRequestHeaders { - assert.True(t, IsControlRequestHeader(header)) - } -} - func TestIsControlResponseHeader(t *testing.T) { controlResponseHeaders := []string{ // Anything that begins with cf-int- or cf-cloudflared- "cf-int-sample-header", "cf-cloudflared-sample-header", - // Any http2 pseudoheader ":sample-pseudo-header", - - // content-length is a special case, it has to be there - // for some requests to work (per the HTTP2 spec) - "content-length", } for _, header := range controlResponseHeaders { @@ -582,17 +98,6 @@ func TestIsControlResponseHeader(t *testing.T) { } } -func TestIsNotControlRequestHeader(t *testing.T) { - notControlRequestHeaders := []string{ - "mock-header", - "another-sample-header", - } - - for _, header := range notControlRequestHeaders { - assert.False(t, IsControlRequestHeader(header)) - } -} - func TestIsNotControlResponseHeader(t *testing.T) { notControlResponseHeaders := []string{ "mock-header", @@ -606,112 +111,3 @@ func TestIsNotControlResponseHeader(t *testing.T) { assert.False(t, IsControlResponseHeader(header)) } } - -func TestH1ResponseToH2ResponseHeaders(t *testing.T) { - mockHeaders := http.Header{ - "User-header-one": {""}, - "User-header-two": {"1", "2"}, - "cf-header": {"cf-value"}, - "cf-int-header": {"cf-int-value"}, - "cf-cloudflared-header": {"cf-cloudflared-value"}, - "Content-Length": {"123"}, - } - mockResponse := http.Response{ - StatusCode: 200, - Header: mockHeaders, - } - - headers := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header) - - serializedHeadersIndex := -1 - for i, header := range headers { - if header.Name == ResponseUserHeaders { - serializedHeadersIndex = i - break - } - } - assert.NotEqual(t, -1, serializedHeadersIndex) - actualControlHeaders := append( - headers[:serializedHeadersIndex], - headers[serializedHeadersIndex+1:]..., - ) - expectedControlHeaders := []h2mux.Header{ - {Name: ":status", Value: "200"}, - {Name: "content-length", Value: "123"}, - } - - assert.ElementsMatch(t, expectedControlHeaders, actualControlHeaders) - - actualUserHeaders, err := DeserializeHeaders(headers[serializedHeadersIndex].Value) - expectedUserHeaders := []h2mux.Header{ - {Name: "User-header-one", Value: ""}, - {Name: "User-header-two", Value: "1"}, - {Name: "User-header-two", Value: "2"}, - {Name: "cf-header", Value: "cf-value"}, - } - assert.NoError(t, err) - assert.ElementsMatch(t, expectedUserHeaders, actualUserHeaders) -} - -// The purpose of this test is to check that our code and the http.Header -// implementation don't throw validation errors about header size -func TestHeaderSize(t *testing.T) { - largeValue := randSeq(5 * 1024 * 1024) // 5Mb - largeHeaders := http.Header{ - "User-header": {largeValue}, - } - mockResponse := http.Response{ - StatusCode: 200, - Header: largeHeaders, - } - - serializedHeaders := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header) - request, err := http.NewRequest(http.MethodGet, "https://example.com/", nil) - assert.NoError(t, err) - for _, header := range serializedHeaders { - request.Header.Set(header.Name, header.Value) - } - - for _, header := range serializedHeaders { - if header.Name != ResponseUserHeaders { - continue - } - - deserializedHeaders, err := DeserializeHeaders(header.Value) - assert.NoError(t, err) - assert.Equal(t, largeValue, deserializedHeaders[0].Value) - } -} - -func randSeq(n int) string { - randomizer := rand.New(rand.NewSource(17)) - var letters = []rune(":;,+/=abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") - b := make([]rune, n) - for i := range b { - b[i] = letters[randomizer.Intn(len(letters))] - } - return string(b) -} - -func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) { - ser := "eC1mb3J3YXJkZWQtcHJvdG8:aHR0cHM;dXBncmFkZS1pbnNlY3VyZS1yZXF1ZXN0cw:MQ;YWNjZXB0LWxhbmd1YWdl:ZW4tVVMsZW47cT0wLjkscnU7cT0wLjg;YWNjZXB0LWVuY29kaW5n:Z3ppcA;eC1mb3J3YXJkZWQtZm9y:MTczLjI0NS42MC42;dXNlci1hZ2VudA:TW96aWxsYS81LjAgKE1hY2ludG9zaDsgSW50ZWwgTWFjIE9TIFggMTBfMTRfNikgQXBwbGVXZWJLaXQvNTM3LjM2IChLSFRNTCwgbGlrZSBHZWNrbykgQ2hyb21lLzg0LjAuNDE0Ny44OSBTYWZhcmkvNTM3LjM2;c2VjLWZldGNoLW1vZGU:bmF2aWdhdGU;Y2RuLWxvb3A:Y2xvdWRmbGFyZQ;c2VjLWZldGNoLWRlc3Q:ZG9jdW1lbnQ;c2VjLWZldGNoLXVzZXI:PzE;c2VjLWZldGNoLXNpdGU:bm9uZQ;Y29va2ll:X19jZmR1aWQ9ZGNkOWZjOGNjNWMxMzE0NTMyYTFkMjhlZDEyOWRhOTYwMTU2OTk1MTYzNDsgX19jZl9ibT1mYzY2MzMzYzAzZmM0MWFiZTZmOWEyYzI2ZDUwOTA0YzIxYzZhMTQ2LTE1OTU2MjIzNDEtMTgwMC1BZTVzS2pIU2NiWGVFM05mMUhrTlNQMG1tMHBLc2pQWkloVnM1Z2g1SkNHQkFhS1UxVDB2b003alBGN3FjMHVSR2NjZGcrWHdhL1EzbTJhQzdDVU4xZ2M9;YWNjZXB0:dGV4dC9odG1sLGFwcGxpY2F0aW9uL3hodG1sK3htbCxhcHBsaWNhdGlvbi94bWw7cT0wLjksaW1hZ2Uvd2VicCxpbWFnZS9hcG5nLCovKjtxPTAuOCxhcHBsaWNhdGlvbi9zaWduZWQtZXhjaGFuZ2U7dj1iMztxPTAuOQ" - h2, _ := DeserializeHeaders(ser) - h1 := make(http.Header) - for _, header := range h2 { - h1.Add(header.Name, header.Value) - } - h1.Add("Content-Length", "200") - h1.Add("Cf-Something", "Else") - h1.Add("Upgrade", "websocket") - - h1resp := &http.Response{ - StatusCode: 200, - Header: h1, - } - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = H1ResponseToH2ResponseHeaders(h1resp.StatusCode, h1resp.Header) - } -} diff --git a/connection/http2.go b/connection/http2.go index 2e869890..c0ab8f23 100644 --- a/connection/http2.go +++ b/connection/http2.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "math" "net" "net/http" "runtime/debug" @@ -60,7 +59,7 @@ func NewHTTP2Connection( return &HTTP2Connection{ conn: conn, server: &http2.Server{ - MaxConcurrentStreams: math.MaxUint32, + MaxConcurrentStreams: MaxConcurrentStreams, }, config: config, connOptions: connOptions, @@ -109,7 +108,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch connType { case TypeControlStream: - if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, true); err != nil { + if err := c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions); err != nil { c.controlStreamErr = err c.log.Error().Err(err) respWriter.WriteErrorResponse() @@ -126,7 +125,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) { case TypeTCP: host, err := getRequestHost(r) if err != nil { - err := fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %w`, err) + err := fmt.Errorf(`cloudflared received a warp-routing request with an empty host value: %w`, err) c.log.Error().Err(err) respWriter.WriteErrorResponse() } @@ -185,12 +184,14 @@ func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) erro for name, values := range header { // Since these are http2 headers, they're required to be lowercase h2name := strings.ToLower(name) + if h2name == "content-length" { // This header has meaning in HTTP/2 and will be used by the edge, - // so it should be sent as an HTTP/2 response header. + // so it should be sent *also* as an HTTP/2 response header. dest[name] = values - // Since these are http2 headers, they're required to be lowercase - } else if !IsControlResponseHeader(h2name) || IsWebsocketClientHeader(h2name) { + } + + if !IsControlResponseHeader(h2name) || IsWebsocketClientHeader(h2name) { // User headers, on the other hand, must all be serialized so that // HTTP/2 header validation won't be applied to HTTP/1 header values userHeaders[name] = values diff --git a/connection/protocol.go b/connection/protocol.go index 1eb37cc1..399e6d9d 100644 --- a/connection/protocol.go +++ b/connection/protocol.go @@ -7,24 +7,24 @@ import ( "time" "github.com/rs/zerolog" + + "github.com/cloudflare/cloudflared/edgediscovery" ) const ( - AvailableProtocolFlagMessage = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux" + AvailableProtocolFlagMessage = "Available protocols: 'auto' - automatically chooses the best protocol over time (the default; and also the recommended one); 'quic' - based on QUIC, relying on UDP egress to Cloudflare edge; 'http2' - using Go's HTTP2 library, relying on TCP egress to Cloudflare edge; 'h2mux' - Cloudflare's implementation of HTTP/2, deprecated" // edgeH2muxTLSServerName is the server name to establish h2mux connection with edge edgeH2muxTLSServerName = "cftunnel.com" // edgeH2TLSServerName is the server name to establish http2 connection with edge edgeH2TLSServerName = "h2.cftunnel.com" // edgeQUICServerName is the server name to establish quic connection with edge. edgeQUICServerName = "quic.cftunnel.com" - // threshold to switch back to h2mux when the user intentionally pick --protocol http2 - explicitHTTP2FallbackThreshold = -1 - autoSelectFlag = "auto" + autoSelectFlag = "auto" ) var ( // ProtocolList represents a list of supported protocols for communication with the edge. - ProtocolList = []Protocol{H2mux, HTTP2, QUIC} + ProtocolList = []Protocol{H2mux, HTTP2, HTTP2Warp, QUIC, QUICWarp} ) type Protocol int64 @@ -36,6 +36,12 @@ const ( HTTP2 // QUIC is used only with named tunnels. QUIC + // HTTP2Warp is used only with named tunnels. It's useful for warp-routing where we don't want to fallback to + // H2mux on HTTP2 failure to connect. + HTTP2Warp + //QUICWarp is used only with named tunnels. It's useful for warp-routing where we want to fallback to HTTP2 but + // don't want HTTP2 to fallback to H2mux + QUICWarp ) // Fallback returns the fallback protocol and whether the protocol has a fallback @@ -45,8 +51,12 @@ func (p Protocol) fallback() (Protocol, bool) { return 0, false case HTTP2: return H2mux, true + case HTTP2Warp: + return 0, false case QUIC: return HTTP2, true + case QUICWarp: + return HTTP2Warp, true default: return 0, false } @@ -56,9 +66,9 @@ func (p Protocol) String() string { switch p { case H2mux: return "h2mux" - case HTTP2: + case HTTP2, HTTP2Warp: return "http2" - case QUIC: + case QUIC, QUICWarp: return "quic" default: return fmt.Sprintf("unknown protocol") @@ -71,11 +81,11 @@ func (p Protocol) TLSSettings() *TLSSettings { return &TLSSettings{ ServerName: edgeH2muxTLSServerName, } - case HTTP2: + case HTTP2, HTTP2Warp: return &TLSSettings{ ServerName: edgeH2TLSServerName, } - case QUIC: + case QUIC, QUICWarp: return &TLSSettings{ ServerName: edgeQUICServerName, NextProtos: []string{"argotunnel"}, @@ -108,29 +118,36 @@ func (s *staticProtocolSelector) Fallback() (Protocol, bool) { } type autoProtocolSelector struct { - lock sync.RWMutex - current Protocol - switchThrehold int32 - fetchFunc PercentageFetcher - refreshAfter time.Time - ttl time.Duration - log *zerolog.Logger + lock sync.RWMutex + + current Protocol + + // protocolPool is desired protocols in the order of priority they should be picked in. + protocolPool []Protocol + + switchThreshold int32 + fetchFunc PercentageFetcher + refreshAfter time.Time + ttl time.Duration + log *zerolog.Logger } func newAutoProtocolSelector( current Protocol, - switchThrehold int32, + protocolPool []Protocol, + switchThreshold int32, fetchFunc PercentageFetcher, ttl time.Duration, log *zerolog.Logger, ) *autoProtocolSelector { return &autoProtocolSelector{ - current: current, - switchThrehold: switchThrehold, - fetchFunc: fetchFunc, - refreshAfter: time.Now().Add(ttl), - ttl: ttl, - log: log, + current: current, + protocolPool: protocolPool, + switchThreshold: switchThreshold, + fetchFunc: fetchFunc, + refreshAfter: time.Now().Add(ttl), + ttl: ttl, + log: log, } } @@ -141,28 +158,39 @@ func (s *autoProtocolSelector) Current() Protocol { return s.current } - percentage, err := s.fetchFunc() + protocol, err := getProtocol(s.protocolPool, s.fetchFunc, s.switchThreshold) if err != nil { s.log.Err(err).Msg("Failed to refresh protocol") return s.current } + s.current = protocol - if s.switchThrehold < percentage { - s.current = HTTP2 - } else { - s.current = H2mux - } s.refreshAfter = time.Now().Add(s.ttl) return s.current } +func getProtocol(protocolPool []Protocol, fetchFunc PercentageFetcher, switchThreshold int32) (Protocol, error) { + protocolPercentages, err := fetchFunc() + if err != nil { + return 0, err + } + for _, protocol := range protocolPool { + protocolPercentage := protocolPercentages.GetPercentage(protocol.String()) + if protocolPercentage > switchThreshold { + return protocol, nil + } + } + + return protocolPool[len(protocolPool)-1], nil +} + func (s *autoProtocolSelector) Fallback() (Protocol, bool) { s.lock.RLock() defer s.lock.RUnlock() return s.current.fallback() } -type PercentageFetcher func() (int32, error) +type PercentageFetcher func() (edgediscovery.ProtocolPercents, error) func NewProtocolSelector( protocolFlag string, @@ -179,54 +207,76 @@ func NewProtocolSelector( }, nil } - // warp routing cannot be served over h2mux connections - if warpRoutingEnabled { - if protocolFlag == H2mux.String() { - log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.") - } - - if protocolFlag == QUIC.String() { - return &staticProtocolSelector{ - current: QUIC, - }, nil - } - return &staticProtocolSelector{ - current: HTTP2, - }, nil - } - - if protocolFlag == H2mux.String() { - return &staticProtocolSelector{ - current: H2mux, - }, nil - } - - if protocolFlag == QUIC.String() { - return newAutoProtocolSelector(QUIC, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil - } - - http2Percentage, err := fetchFunc() + threshold := switchThreshold(namedTunnel.Credentials.AccountTag) + fetchedProtocol, err := getProtocol([]Protocol{QUIC, HTTP2}, fetchFunc, threshold) if err != nil { log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.") return &staticProtocolSelector{ current: HTTP2, }, nil } - if protocolFlag == HTTP2.String() { - if http2Percentage < 0 { - return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil + if warpRoutingEnabled { + if protocolFlag == H2mux.String() || fetchedProtocol == H2mux { + log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.") + protocolFlag = HTTP2.String() + fetchedProtocol = HTTP2Warp } - return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil + return selectWarpRoutingProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol) } - if protocolFlag != autoSelectFlag { - return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) + return selectNamedTunnelProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol) +} + +func selectNamedTunnelProtocols( + protocolFlag string, + fetchFunc PercentageFetcher, + ttl time.Duration, + log *zerolog.Logger, + threshold int32, + protocol Protocol, +) (ProtocolSelector, error) { + // If the user picks a protocol, then we stick to it no matter what. + switch protocolFlag { + case H2mux.String(): + return &staticProtocolSelector{current: H2mux}, nil + case QUIC.String(): + return &staticProtocolSelector{current: QUIC}, nil + case HTTP2.String(): + return &staticProtocolSelector{current: HTTP2}, nil } - threshold := switchThreshold(namedTunnel.Credentials.AccountTag) - if threshold < http2Percentage { - return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, log), nil + + // If the user does not pick (hopefully the majority) then we use the one derived from the TXT DNS record and + // fallback on failures. + if protocolFlag == autoSelectFlag { + return newAutoProtocolSelector(protocol, []Protocol{QUIC, HTTP2, H2mux}, threshold, fetchFunc, ttl, log), nil } - return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, log), nil + + return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) +} + +func selectWarpRoutingProtocols( + protocolFlag string, + fetchFunc PercentageFetcher, + ttl time.Duration, + log *zerolog.Logger, + threshold int32, + protocol Protocol, +) (ProtocolSelector, error) { + // If the user picks a protocol, then we stick to it no matter what. + switch protocolFlag { + case QUIC.String(): + return &staticProtocolSelector{current: QUICWarp}, nil + case HTTP2.String(): + return &staticProtocolSelector{current: HTTP2Warp}, nil + } + + // If the user does not pick (hopefully the majority) then we use the one derived from the TXT DNS record and + // fallback on failures. + if protocolFlag == autoSelectFlag { + return newAutoProtocolSelector(protocol, []Protocol{QUICWarp, HTTP2Warp}, threshold, fetchFunc, ttl, log), nil + } + + return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage) } func switchThreshold(accountTag string) int32 { diff --git a/connection/protocol_test.go b/connection/protocol_test.go index b4a6299a..9bb8c50c 100644 --- a/connection/protocol_test.go +++ b/connection/protocol_test.go @@ -6,6 +6,8 @@ import ( "time" "github.com/stretchr/testify/assert" + + "github.com/cloudflare/cloudflared/edgediscovery" ) const ( @@ -21,29 +23,23 @@ var ( } ) -func mockFetcher(percentage int32) PercentageFetcher { - return func() (int32, error) { - return percentage, nil - } -} - -func mockFetcherWithError() PercentageFetcher { - return func() (int32, error) { - return 0, fmt.Errorf("failed to fetch precentage") +func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) PercentageFetcher { + return func() (edgediscovery.ProtocolPercents, error) { + if getError { + return nil, fmt.Errorf("failed to fetch percentage") + } + return protocolPercent, nil } } type dynamicMockFetcher struct { - percentage int32 - err error + protocolPercents edgediscovery.ProtocolPercents + err error } func (dmf *dynamicMockFetcher) fetch() PercentageFetcher { - return func() (int32, error) { - if dmf.err != nil { - return 0, dmf.err - } - return dmf.percentage, nil + return func() (edgediscovery.ProtocolPercents, error) { + return dmf.protocolPercents, dmf.err } } @@ -69,36 +65,59 @@ func TestNewProtocolSelector(t *testing.T) { name: "named tunnel over h2mux", protocol: "h2mux", expectedProtocol: H2mux, + fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil }, namedTunnelConfig: testNamedTunnelConfig, }, { name: "named tunnel over http2", protocol: "http2", expectedProtocol: HTTP2, - hasFallback: true, - expectedFallback: H2mux, - fetchFunc: mockFetcher(0), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), namedTunnelConfig: testNamedTunnelConfig, }, { - name: "named tunnel http2 disabled", + name: "named tunnel http2 disabled still gets http2 because it is manually picked", protocol: "http2", + expectedProtocol: HTTP2, + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "named tunnel quic disabled still gets quic because it is manually picked", + protocol: "quic", + expectedProtocol: QUIC, + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "named tunnel quic and http2 disabled", + protocol: "auto", expectedProtocol: H2mux, - fetchFunc: mockFetcher(-1), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "named tunnel quic disabled", + protocol: "auto", + expectedProtocol: HTTP2, + // Hasfallback true is because if http2 fails, then we further fallback to h2mux. + hasFallback: true, + expectedFallback: H2mux, + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}), namedTunnelConfig: testNamedTunnelConfig, }, { name: "named tunnel auto all http2 disabled", protocol: "auto", expectedProtocol: H2mux, - fetchFunc: mockFetcher(-1), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), namedTunnelConfig: testNamedTunnelConfig, }, { name: "named tunnel auto to h2mux", protocol: "auto", expectedProtocol: H2mux, - fetchFunc: mockFetcher(0), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}), namedTunnelConfig: testNamedTunnelConfig, }, { @@ -107,36 +126,71 @@ func TestNewProtocolSelector(t *testing.T) { expectedProtocol: HTTP2, hasFallback: true, expectedFallback: H2mux, - fetchFunc: mockFetcher(100), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "named tunnel auto to quic", + protocol: "auto", + expectedProtocol: QUIC, + hasFallback: true, + expectedFallback: HTTP2, + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), namedTunnelConfig: testNamedTunnelConfig, }, { name: "warp routing requesting h2mux", protocol: "h2mux", - expectedProtocol: HTTP2, + expectedProtocol: HTTP2Warp, hasFallback: false, - expectedFallback: H2mux, - fetchFunc: mockFetcher(100), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), + warpRoutingEnabled: true, + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1", + protocol: "h2mux", + expectedProtocol: HTTP2Warp, + hasFallback: false, + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}), warpRoutingEnabled: true, namedTunnelConfig: testNamedTunnelConfig, }, { name: "warp routing http2", protocol: "http2", - expectedProtocol: HTTP2, + expectedProtocol: HTTP2Warp, hasFallback: false, - expectedFallback: H2mux, - fetchFunc: mockFetcher(100), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), + warpRoutingEnabled: true, + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "warp routing quic", + protocol: "auto", + expectedProtocol: QUICWarp, + hasFallback: true, + expectedFallback: HTTP2Warp, + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), warpRoutingEnabled: true, namedTunnelConfig: testNamedTunnelConfig, }, { name: "warp routing auto", protocol: "auto", - expectedProtocol: HTTP2, + expectedProtocol: HTTP2Warp, hasFallback: false, - expectedFallback: H2mux, - fetchFunc: mockFetcher(100), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), + warpRoutingEnabled: true, + namedTunnelConfig: testNamedTunnelConfig, + }, + { + name: "warp routing auto- quic", + protocol: "auto", + expectedProtocol: QUICWarp, + hasFallback: true, + expectedFallback: HTTP2Warp, + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}), warpRoutingEnabled: true, namedTunnelConfig: testNamedTunnelConfig, }, @@ -149,14 +203,14 @@ func TestNewProtocolSelector(t *testing.T) { { name: "named tunnel unknown protocol", protocol: "unknown", - fetchFunc: mockFetcher(100), + fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}), namedTunnelConfig: testNamedTunnelConfig, wantErr: true, }, { name: "named tunnel fetch error", - protocol: "unknown", - fetchFunc: mockFetcherWithError(), + protocol: "auto", + fetchFunc: mockFetcher(true), namedTunnelConfig: testNamedTunnelConfig, expectedProtocol: HTTP2, wantErr: false, @@ -164,18 +218,20 @@ func TestNewProtocolSelector(t *testing.T) { } for _, test := range tests { - selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log) - if test.wantErr { - assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) - } else { - assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name)) - assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name)) - fallback, ok := selector.Fallback() - assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name)) - if test.hasFallback { - assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name)) + t.Run(test.name, func(t *testing.T) { + selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log) + if test.wantErr { + assert.Error(t, err, fmt.Sprintf("test %s failed", test.name)) + } else { + assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name)) + assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name)) + fallback, ok := selector.Fallback() + assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name)) + if test.hasFallback { + assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name)) + } } - } + }) } } @@ -185,64 +241,66 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) { assert.NoError(t, err) assert.Equal(t, H2mux, selector.Current()) - fetcher.percentage = 100 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = 0 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} assert.Equal(t, H2mux, selector.Current()) - fetcher.percentage = 100 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} assert.Equal(t, HTTP2, selector.Current()) fetcher.err = fmt.Errorf("failed to fetch") assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = -1 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} fetcher.err = nil assert.Equal(t, H2mux, selector.Current()) - fetcher.percentage = 0 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} assert.Equal(t, H2mux, selector.Current()) - fetcher.percentage = 100 - assert.Equal(t, HTTP2, selector.Current()) + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} + assert.Equal(t, QUIC, selector.Current()) } func TestHTTP2ProtocolSelectorRefresh(t *testing.T) { fetcher := dynamicMockFetcher{} + // Since the user chooses http2 on purpose, we always stick to it. selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), testNoTTL, &log) assert.NoError(t, err) assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = 100 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = 0 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} assert.Equal(t, HTTP2, selector.Current()) fetcher.err = fmt.Errorf("failed to fetch") assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = -1 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} fetcher.err = nil - assert.Equal(t, H2mux, selector.Current()) - - fetcher.percentage = 0 assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = 100 + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}} assert.Equal(t, HTTP2, selector.Current()) - fetcher.percentage = -1 - assert.Equal(t, H2mux, selector.Current()) + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}} + assert.Equal(t, HTTP2, selector.Current()) + + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}} + assert.Equal(t, HTTP2, selector.Current()) } func TestProtocolSelectorRefreshTTL(t *testing.T) { - fetcher := dynamicMockFetcher{percentage: 100} + fetcher := dynamicMockFetcher{} + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}} selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log) assert.NoError(t, err) - assert.Equal(t, HTTP2, selector.Current()) + assert.Equal(t, QUIC, selector.Current()) - fetcher.percentage = 0 - assert.Equal(t, HTTP2, selector.Current()) + fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 0}} + assert.Equal(t, QUIC, selector.Current()) } diff --git a/connection/quic.go b/connection/quic.go index c4f7e0ae..c1b4ff9d 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -9,11 +9,16 @@ import ( "net/http" "strconv" "strings" + "time" + "github.com/google/uuid" "github.com/lucas-clemente/quic-go" "github.com/pkg/errors" "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" + "github.com/cloudflare/cloudflared/datagramsession" + "github.com/cloudflare/cloudflared/ingress" quicpogs "github.com/cloudflare/cloudflared/quic" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) @@ -29,60 +34,101 @@ const ( // QUICConnection represents the type that facilitates Proxying via QUIC streams. type QUICConnection struct { - session quic.Session - logger *zerolog.Logger - httpProxy OriginProxy - gracefulShutdownC <-chan struct{} - stoppedGracefully bool + session quic.Session + logger *zerolog.Logger + httpProxy OriginProxy + sessionManager datagramsession.Manager + controlStreamHandler ControlStreamHandler + connOptions *tunnelpogs.ConnectionOptions } // NewQUICConnection returns a new instance of QUICConnection. func NewQUICConnection( - ctx context.Context, quicConfig *quic.Config, edgeAddr net.Addr, tlsConfig *tls.Config, httpProxy OriginProxy, connOptions *tunnelpogs.ConnectionOptions, controlStreamHandler ControlStreamHandler, - observer *Observer, + logger *zerolog.Logger, ) (*QUICConnection, error) { session, err := quic.DialAddr(edgeAddr.String(), tlsConfig, quicConfig) if err != nil { - return nil, errors.Wrap(err, "failed to dial to edge") + return nil, fmt.Errorf("failed to dial to edge: %w", err) } - registrationStream, err := session.OpenStream() + datagramMuxer, err := quicpogs.NewDatagramMuxer(session) if err != nil { - return nil, errors.Wrap(err, "failed to open a registration stream") - } - - err = controlStreamHandler.ServeControlStream(ctx, registrationStream, connOptions, false) - if err != nil { - // Not wrapping error here to be consistent with the http2 message. return nil, err } + sessionManager := datagramsession.NewManager(datagramMuxer, logger) + return &QUICConnection{ - session: session, - httpProxy: httpProxy, - logger: observer.log, + session: session, + httpProxy: httpProxy, + logger: logger, + sessionManager: sessionManager, + controlStreamHandler: controlStreamHandler, + connOptions: connOptions, }, nil } // Serve starts a QUIC session that begins accepting streams. func (q *QUICConnection) Serve(ctx context.Context) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() + // origintunneld assumes the first stream is used for the control plane + controlStream, err := q.session.OpenStream() + if err != nil { + return fmt.Errorf("failed to open a registration control stream: %w", err) + } + // If either goroutine returns nil error, we rely on this cancellation to make sure the other goroutine exits + // as fast as possible as well. Nil error means we want to exit for good (caller code won't retry serving this + // connection). + // If either goroutine returns a non nil error, then the error group cancels the context, thus also canceling the + // other goroutine as fast as possible. + ctx, cancel := context.WithCancel(ctx) + errGroup, ctx := errgroup.WithContext(ctx) + + // In the future, if cloudflared can autonomously push traffic to the edge, we have to make sure the control + // stream is already fully registered before the other goroutines can proceed. + errGroup.Go(func() error { + defer cancel() + return q.serveControlStream(ctx, controlStream) + }) + errGroup.Go(func() error { + defer cancel() + return q.acceptStream(ctx) + }) + errGroup.Go(func() error { + defer cancel() + return q.sessionManager.Serve(ctx) + }) + + return errGroup.Wait() +} + +func (q *QUICConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error { + // This blocks until the control plane is done. + err := q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions) + if err != nil { + // Not wrapping error here to be consistent with the http2 message. + return err + } + + return nil +} + +func (q *QUICConnection) acceptStream(ctx context.Context) error { + defer q.Close() for { stream, err := q.session.AcceptStream(ctx) if err != nil { // context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional. - if errors.Is(err, context.Canceled) { + if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() { return nil } - return errors.Wrap(err, "failed to accept QUIC stream") + return fmt.Errorf("failed to accept QUIC stream: %w", err) } go func() { defer stream.Close() @@ -99,7 +145,30 @@ func (q *QUICConnection) Close() { } func (q *QUICConnection) handleStream(stream quic.Stream) error { - connectRequest, err := quicpogs.ReadConnectRequestData(stream) + signature, err := quicpogs.DetermineProtocol(stream) + if err != nil { + return err + } + switch signature { + case quicpogs.DataStreamProtocolSignature: + reqServerStream, err := quicpogs.NewRequestServerStream(stream, signature) + if err != nil { + return nil + } + return q.handleDataStream(reqServerStream) + case quicpogs.RPCStreamProtocolSignature: + rpcStream, err := quicpogs.NewRPCServerStream(stream, signature) + if err != nil { + return err + } + return q.handleRPCStream(rpcStream) + default: + return fmt.Errorf("unknown protocol %v", signature) + } +} + +func (q *QUICConnection) handleDataStream(stream *quicpogs.RequestServerStream) error { + connectRequest, err := stream.ReadConnectRequestData() if err != nil { return err } @@ -114,32 +183,99 @@ func (q *QUICConnection) handleStream(stream quic.Stream) error { w := newHTTPResponseAdapter(stream) return q.httpProxy.ProxyHTTP(w, req, connectRequest.Type == quicpogs.ConnectionTypeWebsocket) case quicpogs.ConnectionTypeTCP: - rwa := &streamReadWriteAcker{ - ReadWriter: stream, - } + rwa := &streamReadWriteAcker{stream} return q.httpProxy.ProxyTCP(context.Background(), rwa, &TCPRequest{Dest: connectRequest.Dest}) } return nil } +func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) error { + return rpcStream.Serve(q, q.logger) +} + +// RegisterUdpSession is the RPC method invoked by edge to register and run a session +func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration) error { + // Each session is a series of datagram from an eyeball to a dstIP:dstPort. + // (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket. + originProxy, err := ingress.DialUDP(dstIP, dstPort) + if err != nil { + q.logger.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort) + return err + } + session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy) + if err != nil { + q.logger.Err(err).Str("sessionID", sessionID.String()).Msgf("Failed to register udp session") + return err + } + + go q.serveUDPSession(session, closeAfterIdleHint) + + q.logger.Debug().Msgf("Registered session %v, %v, %v", sessionID, dstIP, dstPort) + return nil +} + +func (q *QUICConnection) serveUDPSession(session *datagramsession.Session, closeAfterIdleHint time.Duration) { + ctx := q.session.Context() + closedByRemote, err := session.Serve(ctx, closeAfterIdleHint) + // If session is terminated by remote, then we know it has been unregistered from session manager and edge + if !closedByRemote { + if err != nil { + q.closeUDPSession(ctx, session.ID, err.Error()) + } else { + q.closeUDPSession(ctx, session.ID, "terminated without error") + } + } + q.logger.Debug().Err(err).Str("sessionID", session.ID.String()).Msg("Session terminated") +} + +// closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge +func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) { + q.sessionManager.UnregisterSession(ctx, sessionID, message, false) + stream, err := q.session.OpenStream() + if err != nil { + // Log this at debug because this is not an error if session was closed due to lost connection + // with edge + q.logger.Debug().Err(err).Str("sessionID", sessionID.String()). + Msgf("Failed to open quic stream to unregister udp session with edge") + return + } + rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.logger) + if err != nil { + // Log this at debug because this is not an error if session was closed due to lost connection + // with edge + q.logger.Err(err).Str("sessionID", sessionID.String()). + Msgf("Failed to open rpc stream to unregister udp session with edge") + return + } + if err := rpcClientStream.UnregisterUdpSession(ctx, sessionID, message); err != nil { + q.logger.Err(err).Str("sessionID", sessionID.String()). + Msgf("Failed to unregister udp session with edge") + } +} + +// UnregisterUdpSession is the RPC method invoked by edge to unregister and terminate a sesssion +func (q *QUICConnection) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error { + return q.sessionManager.UnregisterSession(ctx, sessionID, message, true) +} + // streamReadWriteAcker is a light wrapper over QUIC streams with a callback to send response back to // the client. type streamReadWriteAcker struct { - io.ReadWriter + *quicpogs.RequestServerStream } // AckConnection acks response back to the proxy. func (s *streamReadWriteAcker) AckConnection() error { - return quicpogs.WriteConnectResponseData(s, nil) + return s.WriteConnectResponseData(nil) } // httpResponseAdapter translates responses written by the HTTP Proxy into ones that can be used in QUIC. type httpResponseAdapter struct { - io.Writer + *quicpogs.RequestServerStream } -func newHTTPResponseAdapter(w io.Writer) httpResponseAdapter { - return httpResponseAdapter{w} +func newHTTPResponseAdapter(s *quicpogs.RequestServerStream) httpResponseAdapter { + return httpResponseAdapter{s} } func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { @@ -151,18 +287,19 @@ func (hrw httpResponseAdapter) WriteRespHeaders(status int, header http.Header) metadata = append(metadata, quicpogs.Metadata{Key: httpHeaderKey, Val: v}) } } - return quicpogs.WriteConnectResponseData(hrw, nil, metadata...) + return hrw.WriteConnectResponseData(nil, metadata...) } func (hrw httpResponseAdapter) WriteErrorResponse(err error) { - quicpogs.WriteConnectResponseData(hrw, err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) + hrw.WriteConnectResponseData(err, quicpogs.Metadata{Key: "HttpStatus", Val: strconv.Itoa(http.StatusBadGateway)}) } -func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.Reader) (*http.Request, error) { +func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.ReadCloser) (*http.Request, error) { metadata := connectRequest.MetadataMap() dest := connectRequest.Dest method := metadata[HTTPMethodKey] host := metadata[HTTPHostKey] + isWebsocket := connectRequest.Type == quicpogs.ConnectionTypeWebsocket req, err := http.NewRequest(method, dest, body) if err != nil { @@ -175,11 +312,42 @@ func buildHTTPRequest(connectRequest *quicpogs.ConnectRequest, body io.Reader) ( // metadata.Key is off the format httpHeaderKey: httpHeaderKey := strings.Split(metadata.Key, ":") if len(httpHeaderKey) != 2 { - return nil, fmt.Errorf("Header Key: %s malformed", metadata.Key) + return nil, fmt.Errorf("header Key: %s malformed", metadata.Key) } req.Header.Add(httpHeaderKey[1], metadata.Val) } } + // Go's http.Client automatically sends chunked request body if this value is not set on the + // *http.Request struct regardless of header: + // https://go.googlesource.com/go/+/go1.8rc2/src/net/http/transfer.go#154. + if err := setContentLength(req); err != nil { + return nil, fmt.Errorf("Error setting content-length: %w", err) + } + + // Go's client defaults to chunked encoding after a 200ms delay if the following cases are true: + // * the request body blocks + // * the content length is not set (or set to -1) + // * the method doesn't usually have a body (GET, HEAD, DELETE, ...) + // * there is no transfer-encoding=chunked already set. + // So, if transfer cannot be chunked and content length is 0, we dont set a request body. + if !isWebsocket && !isTransferEncodingChunked(req) && req.ContentLength == 0 { + req.Body = nil + } stripWebsocketUpgradeHeader(req) return req, err } + +func setContentLength(req *http.Request) error { + var err error + if contentLengthStr := req.Header.Get("Content-Length"); contentLengthStr != "" { + req.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64) + } + return err +} + +func isTransferEncodingChunked(req *http.Request) bool { + transferEncodingVal := req.Header.Get("Transfer-Encoding") + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding suggests that this can be a comma + // separated value as well. + return strings.Contains(strings.ToLower(transferEncodingVal), "chunked") +} diff --git a/connection/quic_test.go b/connection/quic_test.go index 332eb987..ac945400 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -13,32 +13,36 @@ import ( "math/big" "net" "net/http" + "net/url" "os" "sync" "testing" + "time" "github.com/gobwas/ws/wsutil" + "github.com/google/uuid" "github.com/lucas-clemente/quic-go" "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/cloudflare/cloudflared/datagramsession" quicpogs "github.com/cloudflare/cloudflared/quic" - "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" ) +var ( + testTLSServerConfig = generateTLSConfig() + testQUICConfig = &quic.Config{ + KeepAlive: true, + EnableDatagrams: true, + } +) + // TestQUICServer tests if a quic server accepts and responds to a quic client with the acceptance protocol. // It also serves as a demonstration for communication with the QUIC connection started by a cloudflared. func TestQUICServer(t *testing.T) { - quicConfig := &quic.Config{ - KeepAlive: true, - } - - // Setup test. - log := zerolog.New(os.Stdout) - // Start a UDP Listener for QUIC. udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") require.NoError(t, err) @@ -46,18 +50,6 @@ func TestQUICServer(t *testing.T) { require.NoError(t, err) defer udpListener.Close() - // Create a simple tls config. - tlsConfig := generateTLSConfig() - - // Create a client config - tlsClientConfig := &tls.Config{ - InsecureSkipVerify: true, - NextProtos: []string{"argotunnel"}, - } - - // Start a mock httpProxy - originProxy := &mockOriginProxyWithRequest{} - // This is simply a sample websocket frame message. wsBuf := &bytes.Buffer{} wsutil.WriteClientText(wsBuf, []byte("Hello")) @@ -75,15 +67,15 @@ func TestQUICServer(t *testing.T) { dest: "/ok", connectionType: quicpogs.ConnectionTypeHTTP, metadata: []quicpogs.Metadata{ - quicpogs.Metadata{ + { Key: "HttpHeader:Cf-Ray", Val: "123123123", }, - quicpogs.Metadata{ + { Key: "HttpHost", Val: "cf.host", }, - quicpogs.Metadata{ + { Key: "HttpMethod", Val: "GET", }, @@ -95,18 +87,22 @@ func TestQUICServer(t *testing.T) { dest: "/echo_body", connectionType: quicpogs.ConnectionTypeHTTP, metadata: []quicpogs.Metadata{ - quicpogs.Metadata{ + { Key: "HttpHeader:Cf-Ray", Val: "123123123", }, - quicpogs.Metadata{ + { Key: "HttpHost", Val: "cf.host", }, - quicpogs.Metadata{ + { Key: "HttpMethod", Val: "POST", }, + { + Key: "HttpHeader:Content-Length", + Val: "24", + }, }, message: []byte("This is the message body"), expectedResponse: []byte("This is the message body"), @@ -116,19 +112,19 @@ func TestQUICServer(t *testing.T) { dest: "/ok", connectionType: quicpogs.ConnectionTypeWebsocket, metadata: []quicpogs.Metadata{ - quicpogs.Metadata{ + { Key: "HttpHeader:Cf-Cloudflared-Proxy-Connection-Upgrade", Val: "Websocket", }, - quicpogs.Metadata{ + { Key: "HttpHeader:Another-Header", Val: "Misc", }, - quicpogs.Metadata{ + { Key: "HttpHost", Val: "cf.host", }, - quicpogs.Metadata{ + { Key: "HttpMethod", Val: "get", }, @@ -149,29 +145,17 @@ func TestQUICServer(t *testing.T) { t.Run(test.desc, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup + wg.Add(1) go func() { - wg.Add(1) defer wg.Done() quicServer( - t, udpListener, tlsConfig, quicConfig, + t, udpListener, testTLSServerConfig, testQUICConfig, test.dest, test.connectionType, test.metadata, test.message, test.expectedResponse, ) }() - controlStream := fakeControlStream{} - - qC, err := NewQUICConnection( - ctx, - quicConfig, - udpListener.LocalAddr(), - tlsClientConfig, - originProxy, - &pogs.ConnectionOptions{}, - controlStream, - NewObserver(&log, &log, false), - ) - require.NoError(t, err) - go qC.Serve(ctx) + qc := testQUICConnection(udpListener.LocalAddr(), t) + go qc.Serve(ctx) wg.Wait() cancel() @@ -183,7 +167,8 @@ type fakeControlStream struct { ControlStreamHandler } -func (fakeControlStream) ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions, shouldWaitForUnregister bool) error { +func (fakeControlStream) ServeControlStream(ctx context.Context, rw io.ReadWriteCloser, connOptions *tunnelpogs.ConnectionOptions) error { + <-ctx.Done() return nil } func (fakeControlStream) IsStopped() bool { @@ -213,10 +198,11 @@ func quicServer( stream, err := session.OpenStreamSync(context.Background()) require.NoError(t, err) - err = quicpogs.WriteConnectRequestData(stream, dest, connectionType, metadata...) + reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream} + err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...) require.NoError(t, err) - _, err = quicpogs.ReadConnectResponseData(stream) + _, err = reqClientStream.ReadConnectResponseData() require.NoError(t, err) if message != nil { @@ -292,8 +278,390 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque return nil } +func TestBuildHTTPRequest(t *testing.T) { + var tests = []struct { + name string + connectRequest *quicpogs.ConnectRequest + body io.ReadCloser + req *http.Request + }{ + { + name: "check if http.Request is built correctly with content length", + connectRequest: &quicpogs.ConnectRequest{ + Dest: "http://test.com", + Metadata: []quicpogs.Metadata{ + { + Key: "HttpHeader:Cf-Cloudflared-Proxy-Connection-Upgrade", + Val: "Websocket", + }, + { + Key: "HttpHeader:Content-Length", + Val: "514", + }, + { + Key: "HttpHeader:Another-Header", + Val: "Misc", + }, + { + Key: "HttpHost", + Val: "cf.host", + }, + { + Key: "HttpMethod", + Val: "get", + }, + }, + }, + req: &http.Request{ + Method: "get", + URL: &url.URL{ + Scheme: "http", + Host: "test.com", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + "Another-Header": []string{"Misc"}, + "Content-Length": []string{"514"}, + }, + ContentLength: 514, + Host: "cf.host", + Body: io.NopCloser(&bytes.Buffer{}), + }, + body: io.NopCloser(&bytes.Buffer{}), + }, + { + name: "if content length isn't part of request headers, then it's not set", + connectRequest: &quicpogs.ConnectRequest{ + Dest: "http://test.com", + Metadata: []quicpogs.Metadata{ + { + Key: "HttpHeader:Cf-Cloudflared-Proxy-Connection-Upgrade", + Val: "Websocket", + }, + { + Key: "HttpHeader:Another-Header", + Val: "Misc", + }, + { + Key: "HttpHost", + Val: "cf.host", + }, + { + Key: "HttpMethod", + Val: "get", + }, + }, + }, + req: &http.Request{ + Method: "get", + URL: &url.URL{ + Scheme: "http", + Host: "test.com", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + "Another-Header": []string{"Misc"}, + }, + ContentLength: 0, + Host: "cf.host", + Body: nil, + }, + body: io.NopCloser(&bytes.Buffer{}), + }, + { + name: "if content length is 0, but transfer-encoding is chunked, body is not nil", + connectRequest: &quicpogs.ConnectRequest{ + Dest: "http://test.com", + Metadata: []quicpogs.Metadata{ + { + Key: "HttpHeader:Another-Header", + Val: "Misc", + }, + { + Key: "HttpHeader:Transfer-Encoding", + Val: "chunked", + }, + { + Key: "HttpHost", + Val: "cf.host", + }, + { + Key: "HttpMethod", + Val: "get", + }, + }, + }, + req: &http.Request{ + Method: "get", + URL: &url.URL{ + Scheme: "http", + Host: "test.com", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + "Another-Header": []string{"Misc"}, + "Transfer-Encoding": []string{"chunked"}, + }, + ContentLength: 0, + Host: "cf.host", + Body: io.NopCloser(&bytes.Buffer{}), + }, + body: io.NopCloser(&bytes.Buffer{}), + }, + { + name: "if content length is 0, but transfer-encoding is gzip,chunked, body is not nil", + connectRequest: &quicpogs.ConnectRequest{ + Dest: "http://test.com", + Metadata: []quicpogs.Metadata{ + { + Key: "HttpHeader:Another-Header", + Val: "Misc", + }, + { + Key: "HttpHeader:Transfer-Encoding", + Val: "gzip,chunked", + }, + { + Key: "HttpHost", + Val: "cf.host", + }, + { + Key: "HttpMethod", + Val: "get", + }, + }, + }, + req: &http.Request{ + Method: "get", + URL: &url.URL{ + Scheme: "http", + Host: "test.com", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + "Another-Header": []string{"Misc"}, + "Transfer-Encoding": []string{"gzip,chunked"}, + }, + ContentLength: 0, + Host: "cf.host", + Body: io.NopCloser(&bytes.Buffer{}), + }, + body: io.NopCloser(&bytes.Buffer{}), + }, + { + name: "if content length is 0, and connect request is a websocket, body is not nil", + connectRequest: &quicpogs.ConnectRequest{ + Type: quicpogs.ConnectionTypeWebsocket, + Dest: "http://test.com", + Metadata: []quicpogs.Metadata{ + { + Key: "HttpHeader:Another-Header", + Val: "Misc", + }, + { + Key: "HttpHost", + Val: "cf.host", + }, + { + Key: "HttpMethod", + Val: "get", + }, + }, + }, + req: &http.Request{ + Method: "get", + URL: &url.URL{ + Scheme: "http", + Host: "test.com", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + "Another-Header": []string{"Misc"}, + }, + ContentLength: 0, + Host: "cf.host", + Body: io.NopCloser(&bytes.Buffer{}), + }, + body: io.NopCloser(&bytes.Buffer{}), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + req, err := buildHTTPRequest(test.connectRequest, test.body) + assert.NoError(t, err) + test.req = test.req.WithContext(req.Context()) + assert.Equal(t, test.req, req) + }) + } +} + func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error { rwa.AckConnection() io.Copy(rwa, rwa) return nil } + +func TestServeUDPSession(t *testing.T) { + // Start a UDP Listener for QUIC. + udpAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + require.NoError(t, err) + udpListener, err := net.ListenUDP(udpAddr.Network(), udpAddr) + require.NoError(t, err) + defer udpListener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + // Establish QUIC connection with edge + edgeQUICSessionChan := make(chan quic.Session) + go func() { + earlyListener, err := quic.Listen(udpListener, testTLSServerConfig, testQUICConfig) + require.NoError(t, err) + + edgeQUICSession, err := earlyListener.Accept(ctx) + require.NoError(t, err) + edgeQUICSessionChan <- edgeQUICSession + }() + + qc := testQUICConnection(udpListener.LocalAddr(), t) + go qc.Serve(ctx) + + edgeQUICSession := <-edgeQUICSessionChan + serveSession(ctx, qc, edgeQUICSession, closedByOrigin, io.EOF.Error(), t) + serveSession(ctx, qc, edgeQUICSession, closedByTimeout, datagramsession.SessionIdleErr(time.Millisecond*50).Error(), t) + serveSession(ctx, qc, edgeQUICSession, closedByRemote, "eyeball closed connection", t) + cancel() +} + +func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.Session, closeType closeReason, expectedReason string, t *testing.T) { + var ( + payload = []byte(t.Name()) + ) + sessionID := uuid.New() + cfdConn, originConn := net.Pipe() + // Registers and run a new session + session, err := qc.sessionManager.RegisterSession(ctx, sessionID, cfdConn) + require.NoError(t, err) + + sessionDone := make(chan struct{}) + go func() { + qc.serveUDPSession(session, time.Millisecond*50) + close(sessionDone) + }() + + // Send a message to the quic session on edge side, it should be deumx to this datagram session + muxedPayload := append(payload, sessionID[:]...) + err = edgeQUICSession.SendMessage(muxedPayload) + require.NoError(t, err) + + readBuffer := make([]byte, len(payload)+1) + n, err := originConn.Read(readBuffer) + require.NoError(t, err) + require.Equal(t, len(payload), n) + require.True(t, bytes.Equal(payload, readBuffer[:n])) + + // Close connection to terminate session + switch closeType { + case closedByOrigin: + originConn.Close() + case closedByRemote: + err = qc.UnregisterUdpSession(ctx, sessionID, expectedReason) + require.NoError(t, err) + case closedByTimeout: + } + + if closeType != closedByRemote { + // Session was not closed by remote, so closeUDPSession should be invoked to unregister from remote + unregisterFromEdgeChan := make(chan struct{}) + rpcServer := &mockSessionRPCServer{ + sessionID: sessionID, + unregisterReason: expectedReason, + calledUnregisterChan: unregisterFromEdgeChan, + } + go runMockSessionRPCServer(ctx, edgeQUICSession, rpcServer, t) + + <-unregisterFromEdgeChan + } + + <-sessionDone +} + +type closeReason uint8 + +const ( + closedByOrigin closeReason = iota + closedByRemote + closedByTimeout +) + +func runMockSessionRPCServer(ctx context.Context, session quic.Session, rpcServer *mockSessionRPCServer, t *testing.T) { + stream, err := session.AcceptStream(ctx) + require.NoError(t, err) + + if stream.StreamID() == 0 { + // Skip the first stream, it's the control stream of the QUIC connection + stream, err = session.AcceptStream(ctx) + require.NoError(t, err) + } + protocol, err := quicpogs.DetermineProtocol(stream) + assert.NoError(t, err) + rpcServerStream, err := quicpogs.NewRPCServerStream(stream, protocol) + assert.NoError(t, err) + + log := zerolog.New(os.Stdout) + err = rpcServerStream.Serve(rpcServer, &log) + assert.NoError(t, err) +} + +type mockSessionRPCServer struct { + sessionID uuid.UUID + unregisterReason string + calledUnregisterChan chan struct{} +} + +func (s mockSessionRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error { + return fmt.Errorf("mockSessionRPCServer doesn't implement RegisterUdpSession") +} + +func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, reason string) error { + if s.sessionID != sessionID { + return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID) + } + if s.unregisterReason != reason { + return fmt.Errorf("expect unregister reason %s, got %s", s.unregisterReason, reason) + } + close(s.calledUnregisterChan) + fmt.Println("unregister from edge") + return nil +} + +func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection { + tlsClientConfig := &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"argotunnel"}, + } + // Start a mock httpProxy + originProxy := &mockOriginProxyWithRequest{} + log := zerolog.New(os.Stdout) + qc, err := NewQUICConnection( + testQUICConfig, + udpListenerAddr, + tlsClientConfig, + originProxy, + &tunnelpogs.ConnectionOptions{}, + fakeControlStream{}, + &log, + ) + require.NoError(t, err) + return qc +} diff --git a/datagramsession/event.go b/datagramsession/event.go new file mode 100644 index 00000000..d79c6b31 --- /dev/null +++ b/datagramsession/event.go @@ -0,0 +1,50 @@ +package datagramsession + +import ( + "fmt" + "io" + + "github.com/google/uuid" +) + +// registerSessionEvent is an event to start tracking a new session +type registerSessionEvent struct { + sessionID uuid.UUID + originProxy io.ReadWriteCloser + resultChan chan *Session +} + +func newRegisterSessionEvent(sessionID uuid.UUID, originProxy io.ReadWriteCloser) *registerSessionEvent { + return ®isterSessionEvent{ + sessionID: sessionID, + originProxy: originProxy, + resultChan: make(chan *Session, 1), + } +} + +// unregisterSessionEvent is an event to stop tracking and terminate the session. +type unregisterSessionEvent struct { + sessionID uuid.UUID + err *errClosedSession +} + +// ClosedSessionError represent a condition that closes the session other than I/O +// I/O error is not included, because the side that closes the session is ambiguous. +type errClosedSession struct { + message string + byRemote bool +} + +func (sc *errClosedSession) Error() string { + if sc.byRemote { + return fmt.Sprintf("session closed by remote due to %s", sc.message) + } else { + return fmt.Sprintf("session closed by local due to %s", sc.message) + } +} + +// newDatagram is an event when transport receives new datagram +type newDatagram struct { + sessionID uuid.UUID + payload []byte +} diff --git a/datagramsession/manager.go b/datagramsession/manager.go new file mode 100644 index 00000000..caf9d4fd --- /dev/null +++ b/datagramsession/manager.go @@ -0,0 +1,144 @@ +package datagramsession + +import ( + "context" + "io" + + "github.com/google/uuid" + "github.com/lucas-clemente/quic-go" + "github.com/rs/zerolog" + "golang.org/x/sync/errgroup" +) + +const ( + requestChanCapacity = 16 +) + +// Manager defines the APIs to manage sessions from the same transport. +type Manager interface { + // Serve starts the event loop + Serve(ctx context.Context) error + // RegisterSession starts tracking a session. Caller is responsible for starting the session + RegisterSession(ctx context.Context, sessionID uuid.UUID, dstConn io.ReadWriteCloser) (*Session, error) + // UnregisterSession stops tracking the session and terminates it + UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error +} + +type manager struct { + registrationChan chan *registerSessionEvent + unregistrationChan chan *unregisterSessionEvent + datagramChan chan *newDatagram + transport transport + sessions map[uuid.UUID]*Session + log *zerolog.Logger +} + +func NewManager(transport transport, log *zerolog.Logger) Manager { + return &manager{ + registrationChan: make(chan *registerSessionEvent), + unregistrationChan: make(chan *unregisterSessionEvent), + // datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events + datagramChan: make(chan *newDatagram, requestChanCapacity), + transport: transport, + sessions: make(map[uuid.UUID]*Session), + log: log, + } +} + +func (m *manager) Serve(ctx context.Context) error { + errGroup, ctx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + for { + sessionID, payload, err := m.transport.ReceiveFrom() + if err != nil { + if aerr, ok := err.(*quic.ApplicationError); ok && uint64(aerr.ErrorCode) == uint64(quic.NoError) { + return nil + } else { + return err + } + } + datagram := &newDatagram{ + sessionID: sessionID, + payload: payload, + } + select { + case <-ctx.Done(): + return ctx.Err() + // Only the event loop routine can update/lookup the sessions map to avoid concurrent access + // Send the datagram to the event loop. It will find the session to send to + case m.datagramChan <- datagram: + } + } + }) + errGroup.Go(func() error { + for { + select { + case <-ctx.Done(): + return nil + case datagram := <-m.datagramChan: + m.sendToSession(datagram) + case registration := <-m.registrationChan: + m.registerSession(ctx, registration) + // TODO: TUN-5422: Unregister inactive session upon timeout + case unregistration := <-m.unregistrationChan: + m.unregisterSession(unregistration) + } + } + }) + return errGroup.Wait() +} + +func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) { + event := newRegisterSessionEvent(sessionID, originProxy) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case m.registrationChan <- event: + session := <-event.resultChan + return session, nil + } +} + +func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) { + session := newSession(registration.sessionID, m.transport, registration.originProxy, m.log) + m.sessions[registration.sessionID] = session + registration.resultChan <- session +} + +func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error { + event := &unregisterSessionEvent{ + sessionID: sessionID, + err: &errClosedSession{ + message: message, + byRemote: byRemote, + }, + } + select { + case <-ctx.Done(): + return ctx.Err() + case m.unregistrationChan <- event: + return nil + } +} + +func (m *manager) unregisterSession(unregistration *unregisterSessionEvent) { + session, ok := m.sessions[unregistration.sessionID] + if ok { + delete(m.sessions, unregistration.sessionID) + session.close(unregistration.err) + } +} + +func (m *manager) sendToSession(datagram *newDatagram) { + session, ok := m.sessions[datagram.sessionID] + if !ok { + m.log.Error().Str("sessionID", datagram.sessionID.String()).Msg("session not found") + return + } + // session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't + // need to run in another go routine + _, err := session.transportToDst(datagram.payload) + if err != nil { + m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session") + } +} diff --git a/datagramsession/manager_test.go b/datagramsession/manager_test.go new file mode 100644 index 00000000..e81147df --- /dev/null +++ b/datagramsession/manager_test.go @@ -0,0 +1,222 @@ +package datagramsession + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "testing" + "time" + + "github.com/google/uuid" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +func TestManagerServe(t *testing.T) { + const ( + sessions = 20 + msgs = 50 + remoteUnregisterMsg = "eyeball closed connection" + ) + log := zerolog.Nop() + transport := &mockQUICTransport{ + reqChan: newDatagramChannel(1), + respChan: newDatagramChannel(1), + } + mg := NewManager(transport, &log) + + eyeballTracker := make(map[uuid.UUID]*datagramChannel) + for i := 0; i < sessions; i++ { + sessionID := uuid.New() + eyeballTracker[sessionID] = newDatagramChannel(1) + } + + ctx, cancel := context.WithCancel(context.Background()) + serveDone := make(chan struct{}) + go func(ctx context.Context) { + mg.Serve(ctx) + close(serveDone) + }(ctx) + + go func(ctx context.Context) { + for { + sessionID, payload, err := transport.respChan.Receive(ctx) + if err != nil { + require.Equal(t, context.Canceled, err) + return + } + respChan := eyeballTracker[sessionID] + require.NoError(t, respChan.Send(ctx, sessionID, payload)) + } + }(ctx) + + errGroup, ctx := errgroup.WithContext(ctx) + for sID, receiver := range eyeballTracker { + // Assign loop variables to local variables + sessionID := sID + eyeballRespReceiver := receiver + errGroup.Go(func() error { + payload := testPayload(sessionID) + expectResp := testResponse(payload) + + cfdConn, originConn := net.Pipe() + + origin := mockOrigin{ + expectMsgCount: msgs, + expectedMsg: payload, + expectedResp: expectResp, + conn: originConn, + } + eyeball := mockEyeball{ + expectMsgCount: msgs, + expectedMsg: expectResp, + expectSessionID: sessionID, + respReceiver: eyeballRespReceiver, + } + + reqErrGroup, reqCtx := errgroup.WithContext(ctx) + reqErrGroup.Go(func() error { + return origin.serve() + }) + reqErrGroup.Go(func() error { + return eyeball.serve(reqCtx) + }) + + session, err := mg.RegisterSession(ctx, sessionID, cfdConn) + require.NoError(t, err) + + sessionDone := make(chan struct{}) + go func() { + closedByRemote, err := session.Serve(ctx, time.Minute*2) + closeSession := &errClosedSession{ + message: remoteUnregisterMsg, + byRemote: true, + } + require.Equal(t, closeSession, err) + require.True(t, closedByRemote) + close(sessionDone) + }() + + for i := 0; i < msgs; i++ { + require.NoError(t, transport.newRequest(ctx, sessionID, testPayload(sessionID))) + } + + // Make sure eyeball and origin have received all messages before unregistering the session + require.NoError(t, reqErrGroup.Wait()) + + require.NoError(t, mg.UnregisterSession(ctx, sessionID, remoteUnregisterMsg, true)) + <-sessionDone + + return nil + }) + } + + require.NoError(t, errGroup.Wait()) + cancel() + transport.close() + <-serveDone +} + +type mockOrigin struct { + expectMsgCount int + expectedMsg []byte + expectedResp []byte + conn io.ReadWriteCloser +} + +func (mo *mockOrigin) serve() error { + expectedMsgLen := len(mo.expectedMsg) + readBuffer := make([]byte, expectedMsgLen+1) + for i := 0; i < mo.expectMsgCount; i++ { + n, err := mo.conn.Read(readBuffer) + if err != nil { + return err + } + if n != expectedMsgLen { + return fmt.Errorf("Expect to read %d bytes, read %d", expectedMsgLen, n) + } + if !bytes.Equal(readBuffer[:n], mo.expectedMsg) { + return fmt.Errorf("Expect %v, read %v", mo.expectedMsg, readBuffer[:n]) + } + + _, err = mo.conn.Write(mo.expectedResp) + if err != nil { + return err + } + } + return nil +} + +func testPayload(sessionID uuid.UUID) []byte { + return []byte(fmt.Sprintf("Message from %s", sessionID)) +} + +func testResponse(msg []byte) []byte { + return []byte(fmt.Sprintf("Response to %v", msg)) +} + +type mockEyeball struct { + expectMsgCount int + expectedMsg []byte + expectSessionID uuid.UUID + respReceiver *datagramChannel +} + +func (me *mockEyeball) serve(ctx context.Context) error { + for i := 0; i < me.expectMsgCount; i++ { + sessionID, msg, err := me.respReceiver.Receive(ctx) + if err != nil { + return err + } + if sessionID != me.expectSessionID { + return fmt.Errorf("Expect session %s, got %s", me.expectSessionID, sessionID) + } + if !bytes.Equal(msg, me.expectedMsg) { + return fmt.Errorf("Expect %v, read %v", me.expectedMsg, msg) + } + } + return nil +} + +// datagramChannel is a channel for Datagram with wrapper to send/receive with context +type datagramChannel struct { + datagramChan chan *newDatagram + closedChan chan struct{} +} + +func newDatagramChannel(capacity uint) *datagramChannel { + return &datagramChannel{ + datagramChan: make(chan *newDatagram, capacity), + closedChan: make(chan struct{}), + } +} + +func (rc *datagramChannel) Send(ctx context.Context, sessionID uuid.UUID, payload []byte) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-rc.closedChan: + return fmt.Errorf("datagram channel closed") + case rc.datagramChan <- &newDatagram{sessionID: sessionID, payload: payload}: + return nil + } +} + +func (rc *datagramChannel) Receive(ctx context.Context) (uuid.UUID, []byte, error) { + select { + case <-ctx.Done(): + return uuid.Nil, nil, ctx.Err() + case <-rc.closedChan: + return uuid.Nil, nil, fmt.Errorf("datagram channel closed") + case msg := <-rc.datagramChan: + return msg.sessionID, msg.payload, nil + } +} + +func (rc *datagramChannel) Close() { + // No need to close msgChan, it will be garbage collect once there is no reference to it + close(rc.closedChan) +} diff --git a/datagramsession/session.go b/datagramsession/session.go new file mode 100644 index 00000000..fc46aa0f --- /dev/null +++ b/datagramsession/session.go @@ -0,0 +1,140 @@ +package datagramsession + +import ( + "context" + "fmt" + "io" + "time" + + "github.com/google/uuid" + "github.com/rs/zerolog" +) + +const ( + defaultCloseIdleAfter = time.Second * 210 +) + +func SessionIdleErr(timeout time.Duration) error { + return fmt.Errorf("session idle for %v", timeout) +} + +// Session is a bidirectional pipe of datagrams between transport and dstConn +// Currently the only implementation of transport is quic DatagramMuxer +// Destination can be a connection with origin or with eyeball +// When the destination is origin: +// - Datagrams from edge are read by Manager from the transport. Manager finds the corresponding Session and calls the +// write method of the Session to send to origin +// - Datagrams from origin are read from conn and SentTo transport. Transport will return them to eyeball +// When the destination is eyeball: +// - Datagrams from eyeball are read from conn and SentTo transport. Transport will send them to cloudflared +// - Datagrams from cloudflared are read by Manager from the transport. Manager finds the corresponding Session and calls the +// write method of the Session to send to eyeball +type Session struct { + ID uuid.UUID + transport transport + dstConn io.ReadWriteCloser + // activeAtChan is used to communicate the last read/write time + activeAtChan chan time.Time + closeChan chan error + log *zerolog.Logger +} + +func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser, log *zerolog.Logger) *Session { + return &Session{ + ID: id, + transport: transport, + dstConn: dstConn, + // activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will + // drop instead of blocking because last active time only needs to be an approximation + activeAtChan: make(chan time.Time, 2), + // capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel + closeChan: make(chan error, 2), + log: log, + } +} + +func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (closedByRemote bool, err error) { + go func() { + // QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975 + // This makes it safe to share readBuffer between iterations + const maxPacketSize = 1500 + readBuffer := make([]byte, maxPacketSize) + for { + if err := s.dstToTransport(readBuffer); err != nil { + s.closeChan <- err + return + } + } + }() + err = s.waitForCloseCondition(ctx, closeAfterIdle) + if closeSession, ok := err.(*errClosedSession); ok { + closedByRemote = closeSession.byRemote + } + return closedByRemote, err +} + +func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) error { + // Closing dstConn cancels read so dstToTransport routine in Serve() can return + defer s.dstConn.Close() + if closeAfterIdle == 0 { + // provide deafult is caller doesn't specify one + closeAfterIdle = defaultCloseIdleAfter + } + + checkIdleFreq := closeAfterIdle / 8 + checkIdleTicker := time.NewTicker(checkIdleFreq) + defer checkIdleTicker.Stop() + + activeAt := time.Now() + for { + select { + case <-ctx.Done(): + return ctx.Err() + case reason := <-s.closeChan: + return reason + // TODO: TUN-5423 evaluate if using atomic is more efficient + case now := <-checkIdleTicker.C: + // The session is considered inactive if current time is after (last active time + allowed idle time) + if now.After(activeAt.Add(closeAfterIdle)) { + return SessionIdleErr(closeAfterIdle) + } + case activeAt = <-s.activeAtChan: // Update last active time + } + } +} + +func (s *Session) dstToTransport(buffer []byte) error { + n, err := s.dstConn.Read(buffer) + s.markActive() + if n > 0 { + if n <= int(s.transport.MTU()) { + err = s.transport.SendTo(s.ID, buffer[:n]) + } else { + // drop packet for now, eventually reply with ICMP for PMTUD + s.log.Debug(). + Str("session", s.ID.String()). + Int("len", n). + Int("mtu", s.transport.MTU()). + Msg("dropped packet exceeding MTU") + } + } + return err +} + +func (s *Session) transportToDst(payload []byte) (int, error) { + s.markActive() + return s.dstConn.Write(payload) +} + +// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there +// are many concurrent read/write. It is fine to lose some precision +func (s *Session) markActive() { + select { + case s.activeAtChan <- time.Now(): + default: + } +} + +func (s *Session) close(err *errClosedSession) { + s.closeChan <- err +} diff --git a/datagramsession/session_test.go b/datagramsession/session_test.go new file mode 100644 index 00000000..591fa3c6 --- /dev/null +++ b/datagramsession/session_test.go @@ -0,0 +1,197 @@ +package datagramsession + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +// TestCloseSession makes sure a session will stop after context is done +func TestSessionCtxDone(t *testing.T) { + testSessionReturns(t, closeByContext, time.Minute*2) +} + +// TestCloseSession makes sure a session will stop after close method is called +func TestCloseSession(t *testing.T) { + testSessionReturns(t, closeByCallingClose, time.Minute*2) +} + +// TestCloseIdle makess sure a session will stop after there is no read/write for a period defined by closeAfterIdle +func TestCloseIdle(t *testing.T) { + testSessionReturns(t, closeByTimeout, time.Millisecond*100) +} + +func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.Duration) { + var ( + localCloseReason = &errClosedSession{ + message: "connection closed by origin", + byRemote: false, + } + ) + sessionID := uuid.New() + cfdConn, originConn := net.Pipe() + payload := testPayload(sessionID) + transport := &mockQUICTransport{ + reqChan: newDatagramChannel(1), + respChan: newDatagramChannel(1), + } + log := zerolog.Nop() + session := newSession(sessionID, transport, cfdConn, &log) + + ctx, cancel := context.WithCancel(context.Background()) + sessionDone := make(chan struct{}) + go func() { + closedByRemote, err := session.Serve(ctx, closeAfterIdle) + switch closeBy { + case closeByContext: + require.Equal(t, context.Canceled, err) + require.False(t, closedByRemote) + case closeByCallingClose: + require.Equal(t, localCloseReason, err) + require.Equal(t, localCloseReason.byRemote, closedByRemote) + case closeByTimeout: + require.Equal(t, SessionIdleErr(closeAfterIdle), err) + require.False(t, closedByRemote) + } + close(sessionDone) + }() + + go func() { + n, err := session.transportToDst(payload) + require.NoError(t, err) + require.Equal(t, len(payload), n) + }() + + readBuffer := make([]byte, len(payload)+1) + n, err := originConn.Read(readBuffer) + require.NoError(t, err) + require.Equal(t, len(payload), n) + + lastRead := time.Now() + + switch closeBy { + case closeByContext: + cancel() + case closeByCallingClose: + session.close(localCloseReason) + } + + <-sessionDone + if closeBy == closeByTimeout { + require.True(t, time.Now().After(lastRead.Add(closeAfterIdle))) + } + // call cancelled again otherwise the linter will warn about possible context leak + cancel() +} + +type closeMethod int + +const ( + closeByContext closeMethod = iota + closeByCallingClose + closeByTimeout +) + +func TestWriteToDstSessionPreventClosed(t *testing.T) { + testActiveSessionNotClosed(t, false, true) +} + +func TestReadFromDstSessionPreventClosed(t *testing.T) { + testActiveSessionNotClosed(t, true, false) +} + +func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) { + const closeAfterIdle = time.Millisecond * 100 + const activeTime = time.Millisecond * 500 + + sessionID := uuid.New() + cfdConn, originConn := net.Pipe() + payload := testPayload(sessionID) + transport := &mockQUICTransport{ + reqChan: newDatagramChannel(100), + respChan: newDatagramChannel(100), + } + log := zerolog.Nop() + session := newSession(sessionID, transport, cfdConn, &log) + + startTime := time.Now() + activeUntil := startTime.Add(activeTime) + ctx, cancel := context.WithCancel(context.Background()) + errGroup, ctx := errgroup.WithContext(ctx) + errGroup.Go(func() error { + session.Serve(ctx, closeAfterIdle) + if time.Now().Before(startTime.Add(activeTime)) { + return fmt.Errorf("session closed while it's still active") + } + return nil + }) + + if readFromDst { + errGroup.Go(func() error { + for { + if time.Now().After(activeUntil) { + return nil + } + if _, err := originConn.Write(payload); err != nil { + return err + } + time.Sleep(closeAfterIdle / 2) + } + }) + } + if writeToDst { + errGroup.Go(func() error { + readBuffer := make([]byte, len(payload)) + for { + n, err := originConn.Read(readBuffer) + if err != nil { + if err == io.EOF || err == io.ErrClosedPipe { + return nil + } + return err + } + if !bytes.Equal(payload, readBuffer[:n]) { + return fmt.Errorf("payload %v is not equal to %v", readBuffer[:n], payload) + } + } + }) + errGroup.Go(func() error { + for { + if time.Now().After(activeUntil) { + return nil + } + if _, err := session.transportToDst(payload); err != nil { + return err + } + time.Sleep(closeAfterIdle / 2) + } + }) + } + + require.NoError(t, errGroup.Wait()) + cancel() +} + +func TestMarkActiveNotBlocking(t *testing.T) { + const concurrentCalls = 50 + session := newSession(uuid.New(), nil, nil, nil) + var wg sync.WaitGroup + wg.Add(concurrentCalls) + for i := 0; i < concurrentCalls; i++ { + go func() { + session.markActive() + wg.Done() + }() + } + wg.Wait() +} diff --git a/datagramsession/transport.go b/datagramsession/transport.go new file mode 100644 index 00000000..f41e1e83 --- /dev/null +++ b/datagramsession/transport.go @@ -0,0 +1,13 @@ +package datagramsession + +import "github.com/google/uuid" + +// Transport is a connection between cloudflared and edge that can multiplex datagrams from multiple sessions +type transport interface { + // SendTo writes payload for a session to the transport + SendTo(sessionID uuid.UUID, payload []byte) error + // ReceiveFrom reads the next datagram from the transport + ReceiveFrom() (uuid.UUID, []byte, error) + // Max transmission unit to receive from the transport + MTU() int +} diff --git a/datagramsession/transport_test.go b/datagramsession/transport_test.go new file mode 100644 index 00000000..6b187722 --- /dev/null +++ b/datagramsession/transport_test.go @@ -0,0 +1,36 @@ +package datagramsession + +import ( + "context" + + "github.com/google/uuid" +) + +type mockQUICTransport struct { + reqChan *datagramChannel + respChan *datagramChannel +} + +func (mt *mockQUICTransport) SendTo(sessionID uuid.UUID, payload []byte) error { + buf := make([]byte, len(payload)) + // The QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975 + copy(buf, payload) + return mt.respChan.Send(context.Background(), sessionID, buf) +} + +func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) { + return mt.reqChan.Receive(context.Background()) +} + +func (mt *mockQUICTransport) MTU() int { + return 1280 +} + +func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error { + return mt.reqChan.Send(ctx, sessionID, payload) +} + +func (mt *mockQUICTransport) close() { + mt.reqChan.Close() + mt.respChan.Close() +} diff --git a/dev.Dockerfile b/dev.Dockerfile index a9197c9b..add922fb 100644 --- a/dev.Dockerfile +++ b/dev.Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.16.4 as builder +FROM golang:1.17.5 as builder ENV GO111MODULE=on \ CGO_ENABLED=0 WORKDIR /go/src/github.com/cloudflare/cloudflared/ diff --git a/edgediscovery/protocol.go b/edgediscovery/protocol.go index 8d9f5039..5bbb1e91 100644 --- a/edgediscovery/protocol.go +++ b/edgediscovery/protocol.go @@ -1,45 +1,50 @@ package edgediscovery import ( + "encoding/json" "fmt" "net" - "strconv" "strings" ) const ( - protocolRecord = "protocol.argotunnel.com" + protocolRecord = "protocol-v2.argotunnel.com" ) var ( errNoProtocolRecord = fmt.Errorf("No TXT record found for %s to determine connection protocol", protocolRecord) ) -func HTTP2Percentage() (int32, error) { +// ProtocolPercent represents a single Protocol Percentage combination. +type ProtocolPercent struct { + Protocol string `json:"protocol"` + Percentage int32 `json:"percentage"` +} + +// ProtocolPercents represents the preferred distribution ratio of protocols when protocol isn't specified. +type ProtocolPercents []ProtocolPercent + +// GetPercentage returns the threshold percentage of a single protocol. +func (p ProtocolPercents) GetPercentage(protocol string) int32 { + for _, protocolPercent := range p { + if strings.ToLower(protocolPercent.Protocol) == strings.ToLower(protocol) { + return protocolPercent.Percentage + } + } + return 0 +} + +// ProtocolPercentage returns the ratio of protocols and a specification ratio for their selection. +func ProtocolPercentage() (ProtocolPercents, error) { records, err := net.LookupTXT(protocolRecord) if err != nil { - return 0, err + return nil, err } if len(records) == 0 { - return 0, errNoProtocolRecord + return nil, errNoProtocolRecord } - return parseHTTP2Precentage(records[0]) -} - -// The record looks like http2=percentage -func parseHTTP2Precentage(record string) (int32, error) { - const key = "http2" - slices := strings.Split(record, "=") - if len(slices) != 2 { - return 0, fmt.Errorf("Malformed TXT record %s, expect http2=percentage", record) - } - if slices[0] != key { - return 0, fmt.Errorf("Incorrect key %s, expect %s", slices[0], key) - } - percentage, err := strconv.ParseInt(slices[1], 10, 32) - if err != nil { - return 0, err - } - return int32(percentage), nil + var protocolsWithPercent ProtocolPercents + err = json.Unmarshal([]byte(records[0]), &protocolsWithPercent) + return protocolsWithPercent, err } diff --git a/edgediscovery/protocol_test.go b/edgediscovery/protocol_test.go index 874ab6ee..37b9353f 100644 --- a/edgediscovery/protocol_test.go +++ b/edgediscovery/protocol_test.go @@ -6,75 +6,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestHTTP2Percentage(t *testing.T) { - _, err := HTTP2Percentage() +func TestProtocolPercentage(t *testing.T) { + _, err := ProtocolPercentage() assert.NoError(t, err) } - -func TestParseHTTP2Precentage(t *testing.T) { - tests := []struct { - record string - percentage int32 - wantErr bool - }{ - { - record: "http2=-1", - percentage: -1, - wantErr: false, - }, - { - record: "http2=0", - percentage: 0, - wantErr: false, - }, - { - record: "http2=50", - percentage: 50, - wantErr: false, - }, - { - record: "http2=100", - percentage: 100, - wantErr: false, - }, - { - record: "http2=1000", - percentage: 1000, - wantErr: false, - }, - { - record: "http2=10.5", - wantErr: true, - }, - { - record: "http2=10 h2mux=90", - wantErr: true, - }, - { - record: "http2=ten", - wantErr: true, - }, - - { - record: "h2mux=100", - wantErr: true, - }, - { - record: "http2", - wantErr: true, - }, - { - record: "http2=", - wantErr: true, - }, - } - - for _, test := range tests { - p, err := parseHTTP2Precentage(test.record) - if test.wantErr { - assert.Error(t, err) - } else { - assert.Equal(t, test.percentage, p) - } - } -} diff --git a/github_release.py b/github_release.py index 532d0eae..e59a0cbf 100755 --- a/github_release.py +++ b/github_release.py @@ -72,7 +72,7 @@ def get_or_create_release(repo, version, dry_run=False): except UnknownObjectException: logging.info("Release %s not found", version) - # We dont want to create a new release tag if one doesnt already exist + # We don't want to create a new release tag if one doesn't already exist assert_tag_exists(repo, version) if dry_run: @@ -198,7 +198,7 @@ def upload_asset(release, filepath, filename, release_version, kv_account_id, na pass # the macOS release copy fails with being the same file (already in the artifacts directory) def main(): - """ Attempts to upload Asset to Github Release. Creates Release if it doesnt exist """ + """ Attempts to upload Asset to Github Release. Creates Release if it doesn't exist """ try: args = parse_args() client = Github(args.api_key) diff --git a/go.mod b/go.mod index 0344cd32..8655ad1b 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,7 @@ require ( github.com/json-iterator/go v1.1.10 github.com/kr/text v0.2.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect - github.com/lucas-clemente/quic-go v0.23.0 + github.com/lucas-clemente/quic-go v0.24.0 github.com/mattn/go-colorable v0.1.8 github.com/miekg/dns v1.1.31 github.com/mitchellh/go-homedir v1.1.0 @@ -45,11 +45,11 @@ require ( github.com/stretchr/testify v1.6.0 github.com/urfave/cli/v2 v2.2.0 go.uber.org/automaxprocs v1.4.0 - golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a - golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 + golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 + golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c - golang.org/x/sys v0.0.0-20210510120138-977fb7262007 + golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d // indirect google.golang.org/grpc v1.32.0 // indirect @@ -70,6 +70,7 @@ require ( github.com/cheekybits/genny v1.0.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect + github.com/francoispqt/gojay v1.2.13 // indirect github.com/gdamore/encoding v1.0.0 // indirect github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect github.com/golang/protobuf v1.5.2 // indirect @@ -77,6 +78,7 @@ require ( github.com/lucasb-eyer/go-colorful v1.0.3 // indirect github.com/marten-seemann/qtls-go1-16 v0.1.4 // indirect github.com/marten-seemann/qtls-go1-17 v0.1.0 // indirect + github.com/marten-seemann/qtls-go1-18 v0.1.0-beta.1 // indirect github.com/mattn/go-isatty v0.0.12 // indirect github.com/mattn/go-runewidth v0.0.8 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect @@ -97,3 +99,5 @@ require ( ) replace github.com/urfave/cli/v2 => github.com/ipostelnik/cli/v2 v2.3.1-0.20210324024421-b6ea8234fe3d + +replace github.com/lucas-clemente/quic-go => github.com/chungthuang/quic-go v0.24.1-0.20220110095058-981dc498cb62 diff --git a/go.sum b/go.sum index 9d2dddd9..14bfb9d2 100644 --- a/go.sum +++ b/go.sum @@ -125,6 +125,12 @@ github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+ github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= +github.com/chungthuang/quic-go v0.24.1-0.20220106111256-154e7d8a89a9 h1:sHrAhwM2NHkb/5z7+cxDFMCvG3WnSAPbjqSbujLB3nU= +github.com/chungthuang/quic-go v0.24.1-0.20220106111256-154e7d8a89a9/go.mod h1:YtzP8bxRVCBlO77yRanE264+fY/T2U9ZlW1AaHOsMOg= +github.com/chungthuang/quic-go v0.24.1-0.20220106164320-fc99d36b9daa h1:QSi2gWSBtNtCH2/8Y6zFs4H5bnrHQQxFCzl7zJsPp28= +github.com/chungthuang/quic-go v0.24.1-0.20220106164320-fc99d36b9daa/go.mod h1:YtzP8bxRVCBlO77yRanE264+fY/T2U9ZlW1AaHOsMOg= +github.com/chungthuang/quic-go v0.24.1-0.20220110095058-981dc498cb62 h1:PLTB4iA6sOgAItzQY642tYdcGKfG/7i2gu93JQGgUcM= +github.com/chungthuang/quic-go v0.24.1-0.20220110095058-981dc498cb62/go.mod h1:YtzP8bxRVCBlO77yRanE264+fY/T2U9ZlW1AaHOsMOg= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -197,6 +203,7 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= +github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= @@ -421,6 +428,8 @@ github.com/liquidweb/liquidweb-go v1.6.0/go.mod h1:UDcVnAMDkZxpw4Y7NOHkqoeiGacVL github.com/lucas-clemente/quic-go v0.13.1/go.mod h1:Vn3/Fb0/77b02SGhQk36KzOUmXgVpFfizUfW5WMaqyU= github.com/lucas-clemente/quic-go v0.23.0 h1:5vFnKtZ6nHDFsc/F3uuiF4T3y/AXaQdxjUqiVw26GZE= github.com/lucas-clemente/quic-go v0.23.0/go.mod h1:paZuzjXCE5mj6sikVLMvqXk8lJV2AsqtJ6bDhjEfxx0= +github.com/lucas-clemente/quic-go v0.24.0 h1:ToR7SIIEdrgOhgVTHvPgdVRJfgVy+N0wQAagH7L4d5g= +github.com/lucas-clemente/quic-go v0.24.0/go.mod h1:paZuzjXCE5mj6sikVLMvqXk8lJV2AsqtJ6bDhjEfxx0= github.com/lucasb-eyer/go-colorful v1.0.2/go.mod h1:0MS4r+7BZKSJ5mw4/S5MPN+qHFF1fYclkSPilDOKW0s= github.com/lucasb-eyer/go-colorful v1.0.3 h1:QIbQXiugsb+q10B+MI+7DI1oQLdmnep86tWFlaaUAac= github.com/lucasb-eyer/go-colorful v1.0.3/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= @@ -437,6 +446,8 @@ github.com/marten-seemann/qtls-go1-16 v0.1.4 h1:xbHbOGGhrenVtII6Co8akhLEdrawwB2i github.com/marten-seemann/qtls-go1-16 v0.1.4/go.mod h1:gNpI2Ol+lRS3WwSOtIUUtRwZEQMXjYK+dQSBFbethAk= github.com/marten-seemann/qtls-go1-17 v0.1.0 h1:P9ggrs5xtwiqXv/FHNwntmuLMNq3KaSIG93AtAZ48xk= github.com/marten-seemann/qtls-go1-17 v0.1.0/go.mod h1:fz4HIxByo+LlWcreM4CZOYNuz3taBQ8rN2X6FqvaWo8= +github.com/marten-seemann/qtls-go1-18 v0.1.0-beta.1 h1:EnzzN9fPUkUck/1CuY1FlzBaIYMoiBsdwTNmNGkwUUM= +github.com/marten-seemann/qtls-go1-18 v0.1.0-beta.1/go.mod h1:PUhIQk19LoFt2174H4+an8TYvWOGjb/hHwphBeaDHwI= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= @@ -732,6 +743,8 @@ golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 h1:7I4JAnoQBe7ZtJcBaYHi5UtiO8tQHbUSXxL+pnGRANg= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -815,6 +828,10 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= +golang.org/x/net v0.0.0-20211109214657-ef0fda0de508 h1:v3NKo+t/Kc3EASxaKZ82lwK6mCf4ZeObQBduYFZHo7c= +golang.org/x/net v0.0.0-20211109214657-ef0fda0de508/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d h1:1n1fc535VhN8SYtD4cDUyNlfpAF2ROMM9+11equK3hs= +golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -899,6 +916,8 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007 h1:gG67DSER+11cZvqIMb8S8bt0vZtiN6xWYARwirrOSfE= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 h1:SrN+KX8Art/Sf4HNj6Zcz06G7VEz+7w9tdXTPOZ7+l4= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf h1:MZ2shdL+ZM/XzY3ZGOnh4Nlpnxz5GSOhOmtHo3iPU6M= golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/h2mux/activestreammap.go b/h2mux/activestreammap.go index d8079d07..02386db3 100644 --- a/h2mux/activestreammap.go +++ b/h2mux/activestreammap.go @@ -93,7 +93,7 @@ func (m *activeStreamMap) Set(newStream *MuxedStream) bool { return true } -// Delete stops tracking the stream. It should be called only after it is closed and resetted. +// Delete stops tracking the stream. It should be called only after it is closed and reset. func (m *activeStreamMap) Delete(streamID uint32) { m.Lock() defer m.Unlock() diff --git a/h2mux/h2_compressor_brotli.go b/h2mux/h2_compressor_brotli.go index ed0b85b7..6ff99514 100644 --- a/h2mux/h2_compressor_brotli.go +++ b/h2mux/h2_compressor_brotli.go @@ -1,3 +1,4 @@ +//go:build cgo // +build cgo package h2mux diff --git a/h2mux/h2_compressor_none.go b/h2mux/h2_compressor_none.go index 1ca4d284..bfd6bafe 100644 --- a/h2mux/h2_compressor_none.go +++ b/h2mux/h2_compressor_none.go @@ -1,3 +1,4 @@ +//go:build !cgo // +build !cgo package h2mux diff --git a/h2mux/h2_dictionaries.go b/h2mux/h2_dictionaries.go index bf92d58d..5d11bee7 100644 --- a/h2mux/h2_dictionaries.go +++ b/h2mux/h2_dictionaries.go @@ -12,7 +12,7 @@ import ( /* This is an implementation of https://github.com/vkrasnov/h2-compression-dictionaries but modified for tunnels in a few key ways: Since tunnels is a server-to-server service, some aspects of the spec would cause -unnessasary head-of-line blocking on the CPU and on the network, hence this implementation +unnecessary head-of-line blocking on the CPU and on the network, hence this implementation allows for parallel compression on the "client", and buffering on the "server" to solve this problem. */ @@ -67,7 +67,7 @@ var compressionPresets = map[CompressionSetting]CompressionPreset{ } func compressionSettingVal(version, fmt, sz, nd uint8) uint32 { - // Currently the compression settings are inlcude: + // Currently the compression settings are include: // * version: only 1 is supported // * fmt: only 2 for brotli is supported // * sz: log2 of the maximal allowed dictionary size @@ -438,7 +438,7 @@ func assignDictToStream(s *MuxedStream, p []byte) bool { h2d.dictLock.Lock() if w.comp != nil { - // Check again with lock, in therory the inteface allows for unordered writes + // Check again with lock, in therory the interface allows for unordered writes h2d.dictLock.Unlock() return false } @@ -468,7 +468,7 @@ func assignDictToStream(s *MuxedStream, p []byte) bool { } } else { // Use the overflow dictionary as last resort - // If slots are availabe generate new dictioanries for path and content-type + // If slots are available generate new dictionaries for path and content-type useID, _ = h2d.getGenericDictID() pathID, pathFound = h2d.getNextDictID() if pathFound { diff --git a/h2mux/h2mux.go b/h2mux/h2mux.go index c1bf05bb..40c48e4f 100644 --- a/h2mux/h2mux.go +++ b/h2mux/h2mux.go @@ -174,7 +174,7 @@ func Handshake( pingTimestamp := NewPingTimestamp() connActive := NewSignal() idleDuration := config.HeartbeatInterval - // Sanity check to enusre idelDuration is sane + // Sanity check to ensure idelDuration is sane if idleDuration == 0 || idleDuration < defaultTimeout { idleDuration = defaultTimeout config.Log.Info().Msgf("muxer: Minimum idle time has been adjusted to %d", defaultTimeout) @@ -274,7 +274,7 @@ func (m *Muxer) readPeerSettings(magic uint32) error { m.compressionQuality = compressionPresets[CompressionNone] return nil } - // Values used for compression are the mimimum between the two peers + // Values used for compression are the minimum between the two peers if sz < m.compressionQuality.dictSize { m.compressionQuality.dictSize = sz } diff --git a/h2mux/muxmetrics_test.go b/h2mux/muxmetrics_test.go index d65e35af..a9213a2c 100644 --- a/h2mux/muxmetrics_test.go +++ b/h2mux/muxmetrics_test.go @@ -130,7 +130,7 @@ func TestMuxMetricsUpdater(t *testing.T) { m.updateReceiveWindow(uint32(j)) m.updateSendWindow(uint32(j)) - // should always be disgarded since the send time is before readerSend + // should always be discarded since the send time is before readerSend rm := &roundTripMeasurement{receiveTime: readerStart, sendTime: readerStart.Add(-time.Duration(j*dataPoints) * time.Millisecond)} m.updateRTT(rm) diff --git a/hello/hello.go b/hello/hello.go index 419040fa..785b341a 100644 --- a/hello/hello.go +++ b/hello/hello.go @@ -37,7 +37,7 @@ type OriginUpTime struct { UpTime string `json:"uptime"` } -const defaultServerName = "the Argo Tunnel test server" +const defaultServerName = "the Cloudflare Tunnel test server" const indexTemplate = ` @@ -45,10 +45,10 @@ const indexTemplate = ` - Argo Tunnel Connection + Cloudflare Tunnel Connection - +