Merge branch 'cloudflare:master' into tunnel-health

This commit is contained in:
Mads Jon Nielsen 2024-04-23 08:08:04 +02:00 committed by GitHub
commit 6db3cb2f1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
300 changed files with 3724 additions and 16101 deletions

View File

@ -9,10 +9,10 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v3 uses: actions/setup-go@v5
with: with:
go-version: ${{ matrix.go-version }} go-version: ${{ matrix.go-version }}
- name: Checkout code - name: Checkout code
uses: actions/checkout@v3 uses: actions/checkout@v4
- name: Test - name: Test
run: make test run: make test

8
.teamcity/install-cloudflare-go.sh vendored Executable file
View File

@ -0,0 +1,8 @@
# !/usr/bin/env bash
cd /tmp
git clone -q https://github.com/cloudflare/go
cd go/src
# https://github.com/cloudflare/go/tree/34129e47042e214121b6bbff0ded4712debed18e is version go1.21.5-devel-cf
git checkout -q 34129e47042e214121b6bbff0ded4712debed18e
./make.bash

View File

@ -1,13 +1,8 @@
cd /tmp/ rm -rf /tmp/go
rm -rf go
rm -rf gocache
export GOCACHE=/tmp/gocache export GOCACHE=/tmp/gocache
rm -rf $GOCACHE
git clone -q https://github.com/cloudflare/go ./.teamcity/install-cloudflare-go.sh
cd go/src
# https://github.com/cloudflare/go/tree/34129e47042e214121b6bbff0ded4712debed18e is version go1.21.5-devel-cf
git checkout -q 34129e47042e214121b6bbff0ded4712debed18e
./make.bash
export PATH="/tmp/go/bin:$PATH" export PATH="/tmp/go/bin:$PATH"
go version go version

View File

@ -1,26 +0,0 @@
#!/bin/bash
set -euo pipefail
if ! VERSION="$(git describe --tags --exact-match 2>/dev/null)" ; then
echo "Skipping public release for an untagged commit."
echo "##teamcity[buildStatus status='SUCCESS' text='Skipped due to lack of tag']"
exit 0
fi
if [[ "${HOMEBREW_GITHUB_API_TOKEN:-}" == "" ]] ; then
echo "Missing GITHUB_API_TOKEN"
exit 1
fi
# "install" Homebrew
git clone https://github.com/Homebrew/brew tmp/homebrew
eval "$(tmp/homebrew/bin/brew shellenv)"
brew update --force --quiet
chmod -R go-w "$(brew --prefix)/share/zsh"
git config --global user.name "cloudflare-warp-bot"
git config --global user.email "warp-bot@cloudflare.com"
# bump formula pr
brew bump-formula-pr cloudflared --version="$VERSION" --no-browse --no-audit

View File

@ -1,66 +0,0 @@
#!/bin/bash
set -euo pipefail
FILENAME="${PWD}/artifacts/cloudflared-darwin-amd64.tgz"
if ! VERSION="$(git describe --tags --exact-match 2>/dev/null)" ; then
echo "Skipping public release for an untagged commit."
echo "##teamcity[buildStatus status='SUCCESS' text='Skipped due to lack of tag']"
exit 0
fi
if [[ ! -f "$FILENAME" ]] ; then
echo "Missing $FILENAME"
exit 1
fi
if [[ "${GITHUB_PRIVATE_KEY_B64:-}" == "" ]] ; then
echo "Missing GITHUB_PRIVATE_KEY_B64"
exit 1
fi
# upload to s3 bucket for use by Homebrew formula
s3cmd \
--acl-public --access_key="$AWS_ACCESS_KEY_ID" --secret_key="$AWS_SECRET_ACCESS_KEY" --host-bucket="%(bucket)s.s3.cfdata.org" \
put "$FILENAME" "s3://cftunnel-docs/dl/cloudflared-$VERSION-darwin-amd64.tgz"
s3cmd \
--acl-public --access_key="$AWS_ACCESS_KEY_ID" --secret_key="$AWS_SECRET_ACCESS_KEY" --host-bucket="%(bucket)s.s3.cfdata.org" \
cp "s3://cftunnel-docs/dl/cloudflared-$VERSION-darwin-amd64.tgz" "s3://cftunnel-docs/dl/cloudflared-stable-darwin-amd64.tgz"
SHA256=$(sha256sum "$FILENAME" | cut -b1-64)
# set up git (note that UserKnownHostsFile is an absolute path so we can cd wherever)
mkdir -p tmp
ssh-keyscan -t rsa github.com > tmp/github.txt
echo "$GITHUB_PRIVATE_KEY_B64" | base64 --decode > tmp/private.key
chmod 0400 tmp/private.key
export GIT_SSH_COMMAND="ssh -o UserKnownHostsFile=$PWD/tmp/github.txt -i $PWD/tmp/private.key -o IdentitiesOnly=yes"
# clone Homebrew repo into tmp/homebrew-cloudflare
git clone git@github.com:cloudflare/homebrew-cloudflare.git tmp/homebrew-cloudflare
cd tmp/homebrew-cloudflare
git checkout -f master
git reset --hard origin/master
# modify cloudflared.rb
URL="https://packages.argotunnel.com/dl/cloudflared-$VERSION-darwin-amd64.tgz"
tee cloudflared.rb <<EOF
class Cloudflared < Formula
desc 'Cloudflare Tunnel'
homepage 'https://developers.cloudflare.com/cloudflare-one/connections/connect-apps'
url '$URL'
sha256 '$SHA256'
version '$VERSION'
def install
bin.install 'cloudflared'
end
end
EOF
# push cloudflared.rb
git add cloudflared.rb
git diff
git config user.name "cloudflare-warp-bot"
git config user.email "warp-bot@cloudflare.com"
git commit -m "Release Cloudflare Tunnel $VERSION"
git push -v origin master

View File

@ -1,3 +1,7 @@
## 2024.2.1
### Notices
- Starting from this version, tunnel diagnostics will be enabled by default. This will allow the engineering team to remotely get diagnostics from cloudflared during debug activities. Users still have the capability to opt-out of this feature by defining `--management-diagnostics=false` (or env `TUNNEL_MANAGEMENT_DIAGNOSTICS`).
## 2023.9.0 ## 2023.9.0
### Notices ### Notices
- The `warp-routing` `enabled: boolean` flag is no longer supported in the configuration file. Warp Routing traffic (eg TCP, UDP, ICMP) traffic is proxied to cloudflared if routes to the target tunnel are configured. This change does not affect remotely managed tunnels, but for locally managed tunnels, users that might be relying on this feature flag to block traffic should instead guarantee that tunnel has no Private Routes configured for the tunnel. - The `warp-routing` `enabled: boolean` flag is no longer supported in the configuration file. Warp Routing traffic (eg TCP, UDP, ICMP) traffic is proxied to cloudflared if routes to the target tunnel are configured. This change does not affect remotely managed tunnels, but for locally managed tunnels, users that might be relying on this feature flag to block traffic should instead guarantee that tunnel has no Private Routes configured for the tunnel.

View File

@ -12,8 +12,10 @@ WORKDIR /go/src/github.com/cloudflare/cloudflared/
# copy our sources into the builder image # copy our sources into the builder image
COPY . . COPY . .
RUN .teamcity/install-cloudflare-go.sh
# compile cloudflared # compile cloudflared
RUN make cloudflared RUN PATH="/tmp/go/bin:$PATH" make cloudflared
# use a distroless base image with glibc # use a distroless base image with glibc
FROM gcr.io/distroless/base-debian11:nonroot FROM gcr.io/distroless/base-debian11:nonroot

View File

@ -8,8 +8,10 @@ WORKDIR /go/src/github.com/cloudflare/cloudflared/
# copy our sources into the builder image # copy our sources into the builder image
COPY . . COPY . .
RUN .teamcity/install-cloudflare-go.sh
# compile cloudflared # compile cloudflared
RUN GOOS=linux GOARCH=amd64 make cloudflared RUN GOOS=linux GOARCH=amd64 PATH="/tmp/go/bin:$PATH" make cloudflared
# use a distroless base image with glibc # use a distroless base image with glibc
FROM gcr.io/distroless/base-debian11:nonroot FROM gcr.io/distroless/base-debian11:nonroot

View File

@ -8,8 +8,10 @@ WORKDIR /go/src/github.com/cloudflare/cloudflared/
# copy our sources into the builder image # copy our sources into the builder image
COPY . . COPY . .
RUN .teamcity/install-cloudflare-go.sh
# compile cloudflared # compile cloudflared
RUN GOOS=linux GOARCH=arm64 make cloudflared RUN GOOS=linux GOARCH=arm64 PATH="/tmp/go/bin:$PATH" make cloudflared
# use a distroless base image with glibc # use a distroless base image with glibc
FROM gcr.io/distroless/base-debian11:nonroot-arm64 FROM gcr.io/distroless/base-debian11:nonroot-arm64

View File

@ -1,3 +1,6 @@
# The targets cannot be run in parallel
.NOTPARALLEL:
VERSION := $(shell git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*") VERSION := $(shell git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*")
MSI_VERSION := $(shell git tag -l --sort=v:refname | grep "w" | tail -1 | cut -c2-) MSI_VERSION := $(shell git tag -l --sort=v:refname | grep "w" | tail -1 | cut -c2-)
#MSI_VERSION expects the format of the tag to be: (wX.X.X). Starts with the w character to not break cfsetup. #MSI_VERSION expects the format of the tag to be: (wX.X.X). Starts with the w character to not break cfsetup.
@ -49,6 +52,8 @@ PACKAGE_DIR := $(CURDIR)/packaging
PREFIX := /usr PREFIX := /usr
INSTALL_BINDIR := $(PREFIX)/bin/ INSTALL_BINDIR := $(PREFIX)/bin/
INSTALL_MANDIR := $(PREFIX)/share/man/man1/ INSTALL_MANDIR := $(PREFIX)/share/man/man1/
CF_GO_PATH := /tmp/go
PATH := $(CF_GO_PATH)/bin:$(PATH)
LOCAL_ARCH ?= $(shell uname -m) LOCAL_ARCH ?= $(shell uname -m)
ifneq ($(GOARCH),) ifneq ($(GOARCH),)
@ -164,10 +169,19 @@ cover:
test-ssh-server: test-ssh-server:
docker-compose -f ssh_server_tests/docker-compose.yml up docker-compose -f ssh_server_tests/docker-compose.yml up
.PHONY: install-go
install-go:
rm -rf ${CF_GO_PATH}
./.teamcity/install-cloudflare-go.sh
.PHONY: cleanup-go
cleanup-go:
rm -rf ${CF_GO_PATH}
cloudflared.1: cloudflared_man_template cloudflared.1: cloudflared_man_template
sed -e 's/\$${VERSION}/$(VERSION)/; s/\$${DATE}/$(DATE)/' cloudflared_man_template > cloudflared.1 sed -e 's/\$${VERSION}/$(VERSION)/; s/\$${DATE}/$(DATE)/' cloudflared_man_template > cloudflared.1
install: cloudflared cloudflared.1 install: install-go cloudflared cloudflared.1 cleanup-go
mkdir -p $(DESTDIR)$(INSTALL_BINDIR) $(DESTDIR)$(INSTALL_MANDIR) mkdir -p $(DESTDIR)$(INSTALL_BINDIR) $(DESTDIR)$(INSTALL_MANDIR)
install -m755 cloudflared $(DESTDIR)$(INSTALL_BINDIR)/cloudflared install -m755 cloudflared $(DESTDIR)$(INSTALL_BINDIR)/cloudflared
install -m644 cloudflared.1 $(DESTDIR)$(INSTALL_MANDIR)/cloudflared.1 install -m644 cloudflared.1 $(DESTDIR)$(INSTALL_MANDIR)/cloudflared.1
@ -209,15 +223,6 @@ cloudflared-darwin-amd64.tgz: cloudflared
tar czf cloudflared-darwin-amd64.tgz cloudflared tar czf cloudflared-darwin-amd64.tgz cloudflared
rm cloudflared rm cloudflared
.PHONY: homebrew-upload
homebrew-upload: cloudflared-darwin-amd64.tgz
aws s3 --endpoint-url $(S3_ENDPOINT) cp --acl public-read $$^ $(S3_URI)/cloudflared-$$(VERSION)-$1.tgz
aws s3 --endpoint-url $(S3_ENDPOINT) cp --acl public-read $(S3_URI)/cloudflared-$$(VERSION)-$1.tgz $(S3_URI)/cloudflared-stable-$1.tgz
.PHONY: homebrew-release
homebrew-release: homebrew-upload
./publish-homebrew-formula.sh cloudflared-darwin-amd64.tgz $(VERSION) homebrew-cloudflare
.PHONY: github-release .PHONY: github-release
github-release: cloudflared github-release: cloudflared
python3 github_release.py --path $(EXECUTABLE_PATH) --release-version $(VERSION) python3 github_release.py --path $(EXECUTABLE_PATH) --release-version $(VERSION)

View File

@ -31,7 +31,7 @@ Downloads are available as standalone binaries, a Docker image, and Debian, RPM,
* Binaries, Debian, and RPM packages for Linux [can be found here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation#linux) * Binaries, Debian, and RPM packages for Linux [can be found here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation#linux)
* A Docker image of `cloudflared` is [available on DockerHub](https://hub.docker.com/r/cloudflare/cloudflared) * A Docker image of `cloudflared` is [available on DockerHub](https://hub.docker.com/r/cloudflare/cloudflared)
* You can install on Windows machines with the [steps here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation#windows) * You can install on Windows machines with the [steps here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation#windows)
* Build from source with the [instructions here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation#build-from-source) * To build from source, first you need to download the go toolchain by running `./.teamcity/install-cloudflare-go.sh` and follow the output. Then you can run `make cloudflared`
User documentation for Cloudflare Tunnel can be found at https://developers.cloudflare.com/cloudflare-one/connections/connect-apps User documentation for Cloudflare Tunnel can be found at https://developers.cloudflare.com/cloudflare-one/connections/connect-apps

View File

@ -1,3 +1,78 @@
2024.4.0
- 2024-04-02 feat: provide short version (#1206)
- 2024-04-02 Format code
- 2024-01-18 feat: auto tls sni
- 2023-12-24 fix checkInPingGroup bugs
- 2023-12-15 Add environment variables for TCP tunnel hostname / destination / URL.
2024.3.0
- 2024-03-14 TUN-8281: Run cloudflared query list tunnels/routes endpoint in a paginated way
- 2024-03-13 TUN-8297: Improve write timeout logging on safe_stream.go
- 2024-03-07 TUN-8290: Remove `|| true` from postrm.sh
- 2024-03-05 TUN-8275: Skip write timeout log on "no network activity"
- 2024-01-23 Update postrm.sh to fix incomplete uninstall
- 2024-01-05 fix typo in errcheck for response parsing logic in CreateTunnel routine
- 2023-12-23 Update linux_service.go
- 2023-12-07 ci: bump actions/checkout to v4
- 2023-12-07 ci/check: bump actions/setup-go to v5
- 2023-04-28 check.yaml: bump actions/setup-go to v4
2024.2.1
- 2024-02-20 TUN-8242: Update Changes.md file with new remote diagnostics behaviour
- 2024-02-19 TUN-8238: Fix type mismatch introduced by fast-forward
- 2024-02-16 TUN-8243: Collect metrics on the number of QUIC frames sent/received
- 2024-02-15 TUN-8238: Refactor proxy logging
- 2024-02-14 TUN-8242: Enable remote diagnostics by default
- 2024-02-12 TUN-8236: Add write timeout to quic and tcp connections
- 2024-02-09 TUN-8224: Fix safety of TCP stream logging, separate connect and ack log messages
2024.2.0
- 2024-02-07 TUN-8224: Count and collect metrics on stream connect successes/errors
2024.1.5
- 2024-01-22 TUN-8176: Support ARM platforms that don't have an FPU or have it enabled in kernel
- 2024-01-15 TUN-8158: Bring back commit e6537418859afcac29e56a39daa08bcabc09e048 and fixes infinite loop on linux when the socket is closed
2024.1.4
- 2024-01-19 Revert "TUN-8158: Add logging to confirm when ICMP reply is returned to the edge"
2024.1.3
- 2024-01-15 TUN-8161: Fix broken ARM build for armv6
- 2024-01-15 TUN-8158: Add logging to confirm when ICMP reply is returned to the edge
2024.1.2
- 2024-01-11 TUN-8147: Disable ECN usage due to bugs in detecting if supported
- 2024-01-11 TUN-8146: Fix export path for install-go command
- 2024-01-11 TUN-8146: Fix Makefile targets should not be run in parallel and install-go script was missing shebang
- 2024-01-10 TUN-8140: Remove homebrew scripts
2024.1.1
- 2024-01-10 TUN-8134: Revert installed prefix to /usr
- 2024-01-09 TUN-8130: Fix path to install go for mac build
- 2024-01-09 TUN-8129: Use the same build command between branch and release builds
- 2024-01-09 TUN-8130: Install go tool chain in /tmp on build agents
- 2024-01-09 TUN-8134: Install cloudflare go as part of make install
- 2024-01-08 TUN-8118: Disable FIPS module to build with go-boring without CGO_ENABLED
2024.1.0
- 2024-01-01 TUN-7934: Update quic-go to a version that queues datagrams for better throughput and drops large datagram
- 2023-12-20 TUN-8072: Need to set GOCACHE in mac go installation script
- 2023-12-17 TUN-8072: Add script to download cloudflare go for Mac build agents
- 2023-12-15 Fix nil pointer dereference segfault when passing "null" config json to cloudflared tunnel ingress validate (#1070)
- 2023-12-15 configuration.go: fix developerPortal link (#960)
- 2023-12-14 tunnelrpc/pogs: fix dropped test errors (#1106)
- 2023-12-14 cmd/cloudflared/updater: fix dropped error (#1055)
- 2023-12-14 use os.Executable to discover the path to cloudflared (#1040)
- 2023-12-14 Remove extraneous `period` from Path Environment Variable (#1009)
- 2023-12-14 Use CLI context when running tunnel (#597)
- 2023-12-14 TUN-8066: Define scripts to build on Windows agents
- 2023-12-11 TUN-8052: Update go to 1.21.5
- 2023-12-07 TUN-7970: Default to enable post quantum encryption for quic transport
- 2023-12-04 TUN-8006: Update quic-go to latest upstream
- 2023-11-15 VULN-44842 Add a flag that allows users to not send the Access JWT to stdout
- 2023-11-13 TUN-7965: Remove legacy incident status page check
- 2023-11-13 AUTH-5682 Org token flow in Access logins should pass CF_AppSession cookie
2023.10.0 2023.10.0
- 2023-10-06 TUN-7864: Document cloudflared versions support - 2023-10-06 TUN-7864: Document cloudflared versions support
- 2023-10-03 CUSTESC-33731: Make rule match test report rule in 0-index base - 2023-10-03 CUSTESC-33731: Make rule match test report rule in 0-index base

View File

@ -1,3 +1,4 @@
#!/bin/bash
VERSION=$(git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*") VERSION=$(git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*")
echo $VERSION echo $VERSION

View File

@ -1,7 +1,9 @@
#!/bin/bash
VERSION=$(git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*") VERSION=$(git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*")
echo $VERSION echo $VERSION
# Avoid depending on C code since we don't need it. # Disable FIPS module in go-boring
export GOEXPERIMENT=noboringcrypto
export CGO_ENABLED=0 export CGO_ENABLED=0
# This controls the directory the built artifacts go into # This controls the directory the built artifacts go into
@ -14,6 +16,12 @@ for arch in ${linuxArchs[@]}; do
unset TARGET_ARM unset TARGET_ARM
export TARGET_ARCH=$arch export TARGET_ARCH=$arch
## Support for arm platforms without hardware FPU enabled
if [[ $arch == arm ]] ; then
export TARGET_ARCH=arm
export TARGET_ARM=5
fi
## Support for armhf builds ## Support for armhf builds
if [[ $arch == armhf ]] ; then if [[ $arch == armhf ]] ; then
export TARGET_ARCH=arm export TARGET_ARCH=arm

View File

@ -109,20 +109,34 @@ func (r *RESTClient) sendRequest(method string, url url.URL, body interface{}) (
return r.client.Do(req) return r.client.Do(req)
} }
func parseResponse(reader io.Reader, data interface{}) error { func parseResponseEnvelope(reader io.Reader) (*response, error) {
// Schema for Tunnelstore responses in the v1 API. // Schema for Tunnelstore responses in the v1 API.
// Roughly, it's a wrapper around a particular result that adds failures/errors/etc // Roughly, it's a wrapper around a particular result that adds failures/errors/etc
var result response var result response
// First, parse the wrapper and check the API call succeeded // First, parse the wrapper and check the API call succeeded
if err := json.NewDecoder(reader).Decode(&result); err != nil { if err := json.NewDecoder(reader).Decode(&result); err != nil {
return errors.Wrap(err, "failed to decode response") return nil, errors.Wrap(err, "failed to decode response")
} }
if err := result.checkErrors(); err != nil { if err := result.checkErrors(); err != nil {
return err return nil, err
} }
if !result.Success { if !result.Success {
return ErrAPINoSuccess return nil, ErrAPINoSuccess
} }
return &result, nil
}
func parseResponse(reader io.Reader, data interface{}) error {
result, err := parseResponseEnvelope(reader)
if err != nil {
return err
}
return parseResponseBody(result, data)
}
func parseResponseBody(result *response, data interface{}) error {
// At this point we know the API call succeeded, so, parse out the inner // At this point we know the API call succeeded, so, parse out the inner
// result into the datatype provided as a parameter. // result into the datatype provided as a parameter.
if err := json.Unmarshal(result.Result, &data); err != nil { if err := json.Unmarshal(result.Result, &data); err != nil {
@ -131,11 +145,58 @@ func parseResponse(reader io.Reader, data interface{}) error {
return nil return nil
} }
func fetchExhaustively[T any](requestFn func(int) (*http.Response, error)) ([]*T, error) {
page := 0
var fullResponse []*T
for {
page += 1
envelope, parsedBody, err := fetchPage[T](requestFn, page)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("Error Parsing page %d", page))
}
fullResponse = append(fullResponse, parsedBody...)
if envelope.Pagination.Count < envelope.Pagination.PerPage || len(fullResponse) >= envelope.Pagination.TotalCount {
break
}
}
return fullResponse, nil
}
func fetchPage[T any](requestFn func(int) (*http.Response, error), page int) (*response, []*T, error) {
pageResp, err := requestFn(page)
if err != nil {
return nil, nil, errors.Wrap(err, "REST request failed")
}
defer pageResp.Body.Close()
if pageResp.StatusCode == http.StatusOK {
envelope, err := parseResponseEnvelope(pageResp.Body)
if err != nil {
return nil, nil, err
}
var parsedRspBody []*T
return envelope, parsedRspBody, parseResponseBody(envelope, &parsedRspBody)
}
return nil, nil, errors.New(fmt.Sprintf("Failed to fetch page. Server returned: %d", pageResp.StatusCode))
}
type response struct { type response struct {
Success bool `json:"success,omitempty"` Success bool `json:"success,omitempty"`
Errors []apiErr `json:"errors,omitempty"` Errors []apiErr `json:"errors,omitempty"`
Messages []string `json:"messages,omitempty"` Messages []string `json:"messages,omitempty"`
Result json.RawMessage `json:"result,omitempty"` Result json.RawMessage `json:"result,omitempty"`
Pagination Pagination `json:"result_info,omitempty"`
}
type Pagination struct {
Count int `json:"count,omitempty"`
Page int `json:"page,omitempty"`
PerPage int `json:"per_page,omitempty"`
TotalCount int `json:"total_count,omitempty"`
} }
func (r *response) checkErrors() error { func (r *response) checkErrors() error {

View File

@ -137,20 +137,24 @@ type GetRouteByIpParams struct {
} }
// ListRoutes calls the Tunnelstore GET endpoint for all routes under an account. // ListRoutes calls the Tunnelstore GET endpoint for all routes under an account.
// Due to pagination on the server side it will call the endpoint multiple times if needed.
func (r *RESTClient) ListRoutes(filter *IpRouteFilter) ([]*DetailedRoute, error) { func (r *RESTClient) ListRoutes(filter *IpRouteFilter) ([]*DetailedRoute, error) {
fetchFn := func(page int) (*http.Response, error) {
endpoint := r.baseEndpoints.accountRoutes endpoint := r.baseEndpoints.accountRoutes
filter.Page(page)
endpoint.RawQuery = filter.Encode() endpoint.RawQuery = filter.Encode()
resp, err := r.sendRequest("GET", endpoint, nil) rsp, err := r.sendRequest("GET", endpoint, nil)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "REST request failed") return nil, errors.Wrap(err, "REST request failed")
} }
defer resp.Body.Close() if rsp.StatusCode != http.StatusOK {
rsp.Body.Close()
if resp.StatusCode == http.StatusOK { return nil, r.statusCodeToError("list routes", rsp)
return parseListDetailedRoutes(resp.Body)
} }
return rsp, nil
return nil, r.statusCodeToError("list routes", resp) }
return fetchExhaustively[DetailedRoute](fetchFn)
} }
// AddRoute calls the Tunnelstore POST endpoint for a given route. // AddRoute calls the Tunnelstore POST endpoint for a given route.
@ -208,12 +212,6 @@ func (r *RESTClient) GetByIP(params GetRouteByIpParams) (DetailedRoute, error) {
return DetailedRoute{}, r.statusCodeToError("get route by IP", resp) return DetailedRoute{}, r.statusCodeToError("get route by IP", resp)
} }
func parseListDetailedRoutes(body io.ReadCloser) ([]*DetailedRoute, error) {
var routes []*DetailedRoute
err := parseResponse(body, &routes)
return routes, err
}
func parseRoute(body io.ReadCloser) (Route, error) { func parseRoute(body io.ReadCloser) (Route, error) {
var route Route var route Route
err := parseResponse(body, &route) err := parseResponse(body, &route)

View File

@ -167,6 +167,10 @@ func (f *IpRouteFilter) MaxFetchSize(max uint) {
f.queryParams.Set("per_page", strconv.Itoa(int(max))) f.queryParams.Set("per_page", strconv.Itoa(int(max)))
} }
func (f *IpRouteFilter) Page(page int) {
f.queryParams.Set("page", strconv.Itoa(page))
}
func (f IpRouteFilter) Encode() string { func (f IpRouteFilter) Encode() string {
return f.queryParams.Encode() return f.queryParams.Encode()
} }

View File

@ -93,7 +93,7 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*TunnelWith
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusOK: case http.StatusOK:
var tunnel TunnelWithToken var tunnel TunnelWithToken
if serdeErr := parseResponse(resp.Body, &tunnel); err != nil { if serdeErr := parseResponse(resp.Body, &tunnel); serdeErr != nil {
return nil, serdeErr return nil, serdeErr
} }
return &tunnel, nil return &tunnel, nil
@ -177,25 +177,22 @@ func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID, cascade bool) error {
} }
func (r *RESTClient) ListTunnels(filter *TunnelFilter) ([]*Tunnel, error) { func (r *RESTClient) ListTunnels(filter *TunnelFilter) ([]*Tunnel, error) {
fetchFn := func(page int) (*http.Response, error) {
endpoint := r.baseEndpoints.accountLevel endpoint := r.baseEndpoints.accountLevel
filter.Page(page)
endpoint.RawQuery = filter.encode() endpoint.RawQuery = filter.encode()
resp, err := r.sendRequest("GET", endpoint, nil) rsp, err := r.sendRequest("GET", endpoint, nil)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "REST request failed") return nil, errors.Wrap(err, "REST request failed")
} }
defer resp.Body.Close() if rsp.StatusCode != http.StatusOK {
rsp.Body.Close()
if resp.StatusCode == http.StatusOK { return nil, r.statusCodeToError("list tunnels", rsp)
return parseListTunnels(resp.Body) }
return rsp, nil
} }
return nil, r.statusCodeToError("list tunnels", resp) return fetchExhaustively[Tunnel](fetchFn)
}
func parseListTunnels(body io.ReadCloser) ([]*Tunnel, error) {
var tunnels []*Tunnel
err := parseResponse(body, &tunnels)
return tunnels, err
} }
func (r *RESTClient) ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error) { func (r *RESTClient) ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error) {

View File

@ -50,6 +50,10 @@ func (f *TunnelFilter) MaxFetchSize(max uint) {
f.queryParams.Set("per_page", strconv.Itoa(int(max))) f.queryParams.Set("per_page", strconv.Itoa(int(max)))
} }
func (f *TunnelFilter) Page(page int) {
f.queryParams.Set("page", strconv.Itoa(page))
}
func (f TunnelFilter) encode() string { func (f TunnelFilter) encode() string {
return f.queryParams.Encode() return f.queryParams.Encode()
} }

View File

@ -3,7 +3,6 @@ package cfapi
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
"net" "net"
"reflect" "reflect"
"strings" "strings"
@ -16,52 +15,6 @@ import (
var loc, _ = time.LoadLocation("UTC") var loc, _ = time.LoadLocation("UTC")
func Test_parseListTunnels(t *testing.T) {
type args struct {
body string
}
tests := []struct {
name string
args args
want []*Tunnel
wantErr bool
}{
{
name: "empty list",
args: args{body: `{"success": true, "result": []}`},
want: []*Tunnel{},
},
{
name: "success is false",
args: args{body: `{"success": false, "result": []}`},
wantErr: true,
},
{
name: "errors are present",
args: args{body: `{"errors": [{"code": 1003, "message":"An A, AAAA or CNAME record already exists with that host"}], "result": []}`},
wantErr: true,
},
{
name: "invalid response",
args: args{body: `abc`},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body := io.NopCloser(bytes.NewReader([]byte(tt.args.body)))
got, err := parseListTunnels(body)
if (err != nil) != tt.wantErr {
t.Errorf("parseListTunnels() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseListTunnels() = %v, want %v", got, tt.want)
}
})
}
}
func Test_unmarshalTunnel(t *testing.T) { func Test_unmarshalTunnel(t *testing.T) {
type args struct { type args struct {
body string body string

View File

@ -9,22 +9,30 @@ buster: &buster
- *pinned_go - *pinned_go
- build-essential - build-essential
- gotest-to-teamcity - gotest-to-teamcity
- fakeroot
- rubygem-fpm
- rpm
- libffi-dev
- reprepro
- createrepo
pre-cache: &build_pre_cache pre-cache: &build_pre_cache
- export GOCACHE=/cfsetup_build/.cache/go-build - export GOCACHE=/cfsetup_build/.cache/go-build
- go install golang.org/x/tools/cmd/goimports@latest - go install golang.org/x/tools/cmd/goimports@latest
post-cache: post-cache:
- export GOOS=linux # TODO: TUN-8126 this is temporary to make sure packages can be built before release
- export GOARCH=amd64 - ./build-packages.sh
- make cloudflared # Build binary for component test
- GOOS=linux GOARCH=amd64 make cloudflared
build-fips: build-fips:
build_dir: *build_dir build_dir: *build_dir
builddeps: *build_deps builddeps: *build_deps
pre-cache: *build_pre_cache pre-cache: *build_pre_cache
post-cache: post-cache:
- export GOOS=linux
- export GOARCH=amd64
- make cloudflared
- export FIPS=true - export FIPS=true
# TODO: TUN-8126 this is temporary to make sure packages can be built before release
- ./build-packages-fips.sh
# Build binary for component test
- GOOS=linux GOARCH=amd64 make cloudflared
cover: cover:
build_dir: *build_dir build_dir: *build_dir
builddeps: *build_deps builddeps: *build_deps
@ -234,16 +242,6 @@ buster: &buster
- component-tests/requirements.txt - component-tests/requirements.txt
pre-cache: *component_test_pre_cache pre-cache: *component_test_pre_cache
post-cache: *component_test_post_cache post-cache: *component_test_post_cache
update-homebrew:
builddeps:
- openssh-client
- s3cmd
- jq
- build-essential
- procps
post-cache:
- .teamcity/update-homebrew.sh
- .teamcity/update-homebrew-core.sh
github-message-release: github-message-release:
build_dir: *build_dir build_dir: *build_dir
builddeps: *build_pygithub builddeps: *build_pygithub

View File

@ -132,15 +132,18 @@ func Commands() []*cli.Command {
Name: sshHostnameFlag, Name: sshHostnameFlag,
Aliases: []string{"tunnel-host", "T"}, Aliases: []string{"tunnel-host", "T"},
Usage: "specify the hostname of your application.", Usage: "specify the hostname of your application.",
EnvVars: []string{"TUNNEL_SERVICE_HOSTNAME"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: sshDestinationFlag, Name: sshDestinationFlag,
Usage: "specify the destination address of your SSH server.", Usage: "specify the destination address of your SSH server.",
EnvVars: []string{"TUNNEL_SERVICE_DESTINATION"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: sshURLFlag, Name: sshURLFlag,
Aliases: []string{"listener", "L"}, Aliases: []string{"listener", "L"},
Usage: "specify the host:port to forward data to Cloudflare edge.", Usage: "specify the host:port to forward data to Cloudflare edge.",
EnvVars: []string{"TUNNEL_SERVICE_URL"},
}, },
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: sshHeaderFlag, Name: sshHeaderFlag,

View File

@ -55,7 +55,8 @@ var systemdAllTemplates = map[string]ServiceTemplate{
Path: fmt.Sprintf("/etc/systemd/system/%s", cloudflaredService), Path: fmt.Sprintf("/etc/systemd/system/%s", cloudflaredService),
Content: `[Unit] Content: `[Unit]
Description=cloudflared Description=cloudflared
After=network.target After=network-online.target
Wants=network-online.target
[Service] [Service]
TimeoutStartSec=0 TimeoutStartSec=0
@ -72,7 +73,8 @@ WantedBy=multi-user.target
Path: fmt.Sprintf("/etc/systemd/system/%s", cloudflaredUpdateService), Path: fmt.Sprintf("/etc/systemd/system/%s", cloudflaredUpdateService),
Content: `[Unit] Content: `[Unit]
Description=Update cloudflared Description=Update cloudflared
After=network.target After=network-online.target
Wants=network-online.target
[Service] [Service]
ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 11 ]; then systemctl restart cloudflared; exit 0; fi; exit $code' ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 11 ]; then systemctl restart cloudflared; exit 0; fi; exit $code'

View File

@ -3,6 +3,7 @@ package main
import ( import (
"fmt" "fmt"
"math/rand" "math/rand"
"os"
"strings" "strings"
"time" "time"
@ -49,6 +50,9 @@ var (
) )
func main() { func main() {
// FIXME: TUN-8148: Disable QUIC_GO ECN due to bugs in proper detection if supported
os.Setenv("QUIC_GO_DISABLE_ECN", "1")
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
metrics.RegisterBuildInfo(BuildType, BuildTime, Version) metrics.RegisterBuildInfo(BuildType, BuildTime, Version)
maxprocs.Set() maxprocs.Set()
@ -130,11 +134,22 @@ To determine if an update happened in a script, check for error code 11.`,
{ {
Name: "version", Name: "version",
Action: func(c *cli.Context) (err error) { Action: func(c *cli.Context) (err error) {
if c.Bool("short") {
fmt.Println(strings.Split(c.App.Version, " ")[0])
return nil
}
version(c) version(c)
return nil return nil
}, },
Usage: versionText, Usage: versionText,
Description: versionText, Description: versionText,
Flags: []cli.Flag{
&cli.BoolFlag{
Name: "short",
Aliases: []string{"s"},
Usage: "print just the version number",
},
},
}, },
} }
cmds = append(cmds, tunnel.Commands()...) cmds = append(cmds, tunnel.Commands()...)

View File

@ -81,6 +81,9 @@ const (
// udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge // udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge
udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout" udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout"
// writeStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
writeStreamTimeout = "write-stream-timeout"
// quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size. // quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size. // Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets. // Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
@ -697,6 +700,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Value: 5 * time.Second, Value: 5 * time.Second,
Hidden: true, Hidden: true,
}), }),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: writeStreamTimeout,
EnvVars: []string{"TUNNEL_STREAM_WRITE_TIMEOUT"},
Usage: "Use this option to add a stream write timeout for connections when writing towards the origin or edge. Default is 0 which disables the write timeout.",
Value: 0 * time.Second,
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{ altsrc.NewBoolFlag(&cli.BoolFlag{
Name: quicDisablePathMTUDiscovery, Name: quicDisablePathMTUDiscovery,
EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"}, EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"},
@ -781,7 +791,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Name: "management-diagnostics", Name: "management-diagnostics",
Usage: "Enables the in-depth diagnostic routes to be made available over the management service (/debug/pprof, /metrics, etc.)", Usage: "Enables the in-depth diagnostic routes to be made available over the management service (/debug/pprof, /metrics, etc.)",
EnvVars: []string{"TUNNEL_MANAGEMENT_DIAGNOSTICS"}, EnvVars: []string{"TUNNEL_MANAGEMENT_DIAGNOSTICS"},
Value: false, Value: true,
}), }),
selectProtocolFlag, selectProtocolFlag,
overwriteDNSFlag, overwriteDNSFlag,

View File

@ -30,7 +30,10 @@ import (
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
const secretValue = "*****" const (
secretValue = "*****"
icmpFunnelTimeout = time.Second * 10
)
var ( var (
developerPortal = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup" developerPortal = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup"
@ -244,6 +247,7 @@ func prepareTunnelConfig(
FeatureSelector: featureSelector, FeatureSelector: featureSelector,
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")), MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag), UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag),
WriteStreamTimeout: c.Duration(writeStreamTimeout),
DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery), DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
} }
packetConfig, err := newPacketConfig(c, log) packetConfig, err := newPacketConfig(c, log)
@ -256,6 +260,7 @@ func prepareTunnelConfig(
Ingress: &ingressRules, Ingress: &ingressRules,
WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting), WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting),
ConfigurationFlags: parseConfigFlags(c), ConfigurationFlags: parseConfigFlags(c),
WriteTimeout: c.Duration(writeStreamTimeout),
} }
return tunnelConfig, orchestratorConfig, nil return tunnelConfig, orchestratorConfig, nil
} }
@ -361,7 +366,7 @@ func newPacketConfig(c *cli.Context, logger *zerolog.Logger) (*ingress.GlobalRou
logger.Info().Msgf("ICMP proxy will use %s as source for IPv6", ipv6Src) logger.Info().Msgf("ICMP proxy will use %s as source for IPv6", ipv6Src)
} }
icmpRouter, err := ingress.NewICMPRouter(ipv4Src, ipv6Src, zone, logger) icmpRouter, err := ingress.NewICMPRouter(ipv4Src, ipv6Src, zone, logger, icmpFunnelTimeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -55,7 +55,7 @@ class TestManagement:
config = component_tests_config(cfd_mode=CfdModes.NAMED, run_proxy_dns=False, provide_ingress=False) config = component_tests_config(cfd_mode=CfdModes.NAMED, run_proxy_dns=False, provide_ingress=False)
LOGGER.debug(config) LOGGER.debug(config)
config_path = write_config(tmp_path, config.full_config) config_path = write_config(tmp_path, config.full_config)
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1", "--management-diagnostics"], new_process=True): with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], new_process=True):
wait_tunnel_ready(require_min_connections=1) wait_tunnel_ready(require_min_connections=1)
cfd_cli = CloudflaredCli(config, config_path, LOGGER) cfd_cli = CloudflaredCli(config, config_path, LOGGER)
url = cfd_cli.get_management_url("metrics", config, config_path) url = cfd_cli.get_management_url("metrics", config, config_path)
@ -76,7 +76,7 @@ class TestManagement:
config = component_tests_config(cfd_mode=CfdModes.NAMED, run_proxy_dns=False, provide_ingress=False) config = component_tests_config(cfd_mode=CfdModes.NAMED, run_proxy_dns=False, provide_ingress=False)
LOGGER.debug(config) LOGGER.debug(config)
config_path = write_config(tmp_path, config.full_config) config_path = write_config(tmp_path, config.full_config)
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1", "--management-diagnostics"], new_process=True): with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], new_process=True):
wait_tunnel_ready(require_min_connections=1) wait_tunnel_ready(require_min_connections=1)
cfd_cli = CloudflaredCli(config, config_path, LOGGER) cfd_cli = CloudflaredCli(config, config_path, LOGGER)
url = cfd_cli.get_management_url("debug/pprof/heap", config, config_path) url = cfd_cli.get_management_url("debug/pprof/heap", config, config_path)
@ -97,7 +97,7 @@ class TestManagement:
config = component_tests_config(cfd_mode=CfdModes.NAMED, run_proxy_dns=False, provide_ingress=False) config = component_tests_config(cfd_mode=CfdModes.NAMED, run_proxy_dns=False, provide_ingress=False)
LOGGER.debug(config) LOGGER.debug(config)
config_path = write_config(tmp_path, config.full_config) config_path = write_config(tmp_path, config.full_config)
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], new_process=True): with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1", "--management-diagnostics=false"], new_process=True):
wait_tunnel_ready(require_min_connections=1) wait_tunnel_ready(require_min_connections=1)
cfd_cli = CloudflaredCli(config, config_path, LOGGER) cfd_cli = CloudflaredCli(config, config_path, LOGGER)
url = cfd_cli.get_management_url("metrics", config, config_path) url = cfd_cli.get_management_url("metrics", config, config_path)

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
from conftest import CfdModes from conftest import CfdModes
from constants import METRICS_PORT from constants import METRICS_PORT
import time
from util import LOGGER, start_cloudflared, wait_tunnel_ready, get_quicktunnel_url, send_requests from util import LOGGER, start_cloudflared, wait_tunnel_ready, get_quicktunnel_url, send_requests
class TestQuickTunnels: class TestQuickTunnels:
@ -9,6 +10,7 @@ class TestQuickTunnels:
LOGGER.debug(config) LOGGER.debug(config)
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], cfd_args=["--hello-world"], new_process=True): with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], cfd_args=["--hello-world"], new_process=True):
wait_tunnel_ready(require_min_connections=1) wait_tunnel_ready(require_min_connections=1)
time.sleep(10)
url = get_quicktunnel_url() url = get_quicktunnel_url()
send_requests(url, 3, True) send_requests(url, 3, True)
@ -17,6 +19,7 @@ class TestQuickTunnels:
LOGGER.debug(config) LOGGER.debug(config)
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], cfd_args=["--url", f"http://localhost:{METRICS_PORT}/"], new_process=True): with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], cfd_args=["--url", f"http://localhost:{METRICS_PORT}/"], new_process=True):
wait_tunnel_ready(require_min_connections=1) wait_tunnel_ready(require_min_connections=1)
time.sleep(10)
url = get_quicktunnel_url() url = get_quicktunnel_url()
send_requests(url+"/ready", 3, True) send_requests(url+"/ready", 3, True)

View File

@ -205,6 +205,8 @@ type OriginRequestConfig struct {
HTTPHostHeader *string `yaml:"httpHostHeader" json:"httpHostHeader,omitempty"` HTTPHostHeader *string `yaml:"httpHostHeader" json:"httpHostHeader,omitempty"`
// Hostname on the origin server certificate. // Hostname on the origin server certificate.
OriginServerName *string `yaml:"originServerName" json:"originServerName,omitempty"` OriginServerName *string `yaml:"originServerName" json:"originServerName,omitempty"`
// Auto configure the Hostname on the origin server certificate.
MatchSNIToHost *bool `yaml:"matchSNItoHost" json:"matchSNItoHost,omitempty"`
// Path to the CA for the certificate of your origin. // Path to the CA for the certificate of your origin.
// This option should be used only if your certificate is not signed by Cloudflare. // This option should be used only if your certificate is not signed by Cloudflare.
CAPool *string `yaml:"caPool" json:"caPool,omitempty"` CAPool *string `yaml:"caPool" json:"caPool,omitempty"`

View File

@ -66,6 +66,7 @@ type QUICConnection struct {
connIndex uint8 connIndex uint8
udpUnregisterTimeout time.Duration udpUnregisterTimeout time.Duration
streamWriteTimeout time.Duration
} }
// NewQUICConnection returns a new instance of QUICConnection. // NewQUICConnection returns a new instance of QUICConnection.
@ -82,6 +83,7 @@ func NewQUICConnection(
logger *zerolog.Logger, logger *zerolog.Logger,
packetRouterConfig *ingress.GlobalRouterConfig, packetRouterConfig *ingress.GlobalRouterConfig,
udpUnregisterTimeout time.Duration, udpUnregisterTimeout time.Duration,
streamWriteTimeout time.Duration,
) (*QUICConnection, error) { ) (*QUICConnection, error) {
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger) udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
if err != nil { if err != nil {
@ -117,6 +119,7 @@ func NewQUICConnection(
connOptions: connOptions, connOptions: connOptions,
connIndex: connIndex, connIndex: connIndex,
udpUnregisterTimeout: udpUnregisterTimeout, udpUnregisterTimeout: udpUnregisterTimeout,
streamWriteTimeout: streamWriteTimeout,
}, nil }, nil
} }
@ -195,7 +198,7 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
func (q *QUICConnection) runStream(quicStream quic.Stream) { func (q *QUICConnection) runStream(quicStream quic.Stream) {
ctx := quicStream.Context() ctx := quicStream.Context()
stream := quicpogs.NewSafeStreamCloser(quicStream) stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close() defer stream.Close()
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
@ -321,6 +324,7 @@ func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.
session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy) session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy)
if err != nil { if err != nil {
originProxy.Close()
log.Err(err).Str("sessionID", sessionID.String()).Msgf("Failed to register udp session") log.Err(err).Str("sessionID", sessionID.String()).Msgf("Failed to register udp session")
tracing.EndWithErrorStatus(registerSpan, err) tracing.EndWithErrorStatus(registerSpan, err)
return nil, err return nil, err
@ -373,7 +377,7 @@ func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUI
return return
} }
stream := quicpogs.NewSafeStreamCloser(quicStream) stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close() defer stream.Close()
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger) rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger)
if err != nil { if err != nil {

View File

@ -35,6 +35,7 @@ var (
KeepAlivePeriod: 5 * time.Second, KeepAlivePeriod: 5 * time.Second,
EnableDatagrams: true, EnableDatagrams: true,
} }
defaultQUICTimeout = 30 * time.Second
) )
var _ ReadWriteAcker = (*streamReadWriteAcker)(nil) var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
@ -197,7 +198,7 @@ func quicServer(
quicStream, err := session.OpenStreamSync(context.Background()) quicStream, err := session.OpenStreamSync(context.Background())
require.NoError(t, err) require.NoError(t, err)
stream := quicpogs.NewSafeStreamCloser(quicStream) stream := quicpogs.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log)
reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream} reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream}
err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...) err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...)
@ -726,6 +727,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU
&log, &log,
nil, nil,
5*time.Second, 5*time.Second,
0*time.Second,
) )
require.NoError(t, err) require.NoError(t, err)
return qc return qc

View File

@ -4,6 +4,7 @@ ENV GO111MODULE=on \
WORKDIR /go/src/github.com/cloudflare/cloudflared/ WORKDIR /go/src/github.com/cloudflare/cloudflared/
RUN apt-get update RUN apt-get update
COPY . . COPY . .
RUN .teamcity/install-cloudflare-go.sh
# compile cloudflared # compile cloudflared
RUN make cloudflared RUN PATH="/tmp/go/bin:$PATH" make cloudflared
RUN cp /go/src/github.com/cloudflare/cloudflared/cloudflared /usr/local/bin/ RUN cp /go/src/github.com/cloudflare/cloudflared/cloudflared /usr/local/bin/

18
go.mod
View File

@ -4,14 +4,15 @@ go 1.21
require ( require (
github.com/coredns/coredns v1.10.0 github.com/coredns/coredns v1.10.0
github.com/coreos/go-oidc/v3 v3.6.0 github.com/coreos/go-oidc/v3 v3.10.0
github.com/coreos/go-systemd/v22 v22.5.0 github.com/coreos/go-systemd/v22 v22.5.0
github.com/facebookgo/grace v0.0.0-20180706040059-75cf19382434 github.com/facebookgo/grace v0.0.0-20180706040059-75cf19382434
github.com/fortytw2/leaktest v1.3.0
github.com/fsnotify/fsnotify v1.4.9 github.com/fsnotify/fsnotify v1.4.9
github.com/getsentry/sentry-go v0.16.0 github.com/getsentry/sentry-go v0.16.0
github.com/go-chi/chi/v5 v5.0.8 github.com/go-chi/chi/v5 v5.0.8
github.com/go-chi/cors v1.2.1 github.com/go-chi/cors v1.2.1
github.com/go-jose/go-jose/v3 v3.0.0 github.com/go-jose/go-jose/v4 v4.0.1
github.com/gobwas/ws v1.0.4 github.com/gobwas/ws v1.0.4
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
github.com/google/gopacket v1.1.19 github.com/google/gopacket v1.1.19
@ -24,7 +25,7 @@ require (
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.13.0 github.com/prometheus/client_golang v1.13.0
github.com/prometheus/client_model v0.2.0 github.com/prometheus/client_model v0.2.0
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55 github.com/quic-go/quic-go v0.42.0
github.com/rs/zerolog v1.20.0 github.com/rs/zerolog v1.20.0
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
github.com/urfave/cli/v2 v2.3.0 github.com/urfave/cli/v2 v2.3.0
@ -35,11 +36,11 @@ require (
go.opentelemetry.io/otel/trace v1.21.0 go.opentelemetry.io/otel/trace v1.21.0
go.opentelemetry.io/proto/otlp v1.0.0 go.opentelemetry.io/proto/otlp v1.0.0
go.uber.org/automaxprocs v1.4.0 go.uber.org/automaxprocs v1.4.0
golang.org/x/crypto v0.16.0 golang.org/x/crypto v0.21.0
golang.org/x/net v0.19.0 golang.org/x/net v0.21.0
golang.org/x/sync v0.4.0 golang.org/x/sync v0.4.0
golang.org/x/sys v0.15.0 golang.org/x/sys v0.18.0
golang.org/x/term v0.15.0 golang.org/x/term v0.18.0
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.31.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
@ -81,10 +82,9 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect
go.opentelemetry.io/otel/metric v1.21.0 // indirect go.opentelemetry.io/otel/metric v1.21.0 // indirect
go.uber.org/mock v0.3.0 // indirect go.uber.org/mock v0.4.0 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.11.0 // indirect golang.org/x/mod v0.11.0 // indirect
golang.org/x/oauth2 v0.13.0 // indirect golang.org/x/oauth2 v0.13.0 // indirect

39
go.sum
View File

@ -60,8 +60,8 @@ github.com/coredns/caddy v1.1.1 h1:2eYKZT7i6yxIfGP3qLJoJ7HAsDJqYB+X68g4NYjSrE0=
github.com/coredns/caddy v1.1.1/go.mod h1:A6ntJQlAWuQfFlsd9hvigKbo2WS0VUs2l1e2F+BawD4= github.com/coredns/caddy v1.1.1/go.mod h1:A6ntJQlAWuQfFlsd9hvigKbo2WS0VUs2l1e2F+BawD4=
github.com/coredns/coredns v1.10.0 h1:jCfuWsBjTs0dapkkhISfPCzn5LqvSRtrFtaf/Tjj4DI= github.com/coredns/coredns v1.10.0 h1:jCfuWsBjTs0dapkkhISfPCzn5LqvSRtrFtaf/Tjj4DI=
github.com/coredns/coredns v1.10.0/go.mod h1:CIfRU5TgpuoIiJBJ4XrofQzfFQpPFh32ERpUevrSlaw= github.com/coredns/coredns v1.10.0/go.mod h1:CIfRU5TgpuoIiJBJ4XrofQzfFQpPFh32ERpUevrSlaw=
github.com/coreos/go-oidc/v3 v3.6.0 h1:AKVxfYw1Gmkn/w96z0DbT/B/xFnzTd3MkZvWLjF4n/o= github.com/coreos/go-oidc/v3 v3.10.0 h1:tDnXHnLyiTVyT/2zLDGj09pFPkhND8Gl8lnTRhoEaJU=
github.com/coreos/go-oidc/v3 v3.6.0/go.mod h1:ZpHUsHBucTUj6WOkrP4E20UPynbLZzhTQ1XKCXkxyPc= github.com/coreos/go-oidc/v3 v3.10.0/go.mod h1:5j11xcw0D3+SGxn6Z/WFADsgcWVMyNAlSQupk0KK3ac=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
@ -88,6 +88,8 @@ github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870 h1:E2s37DuLxFhQD
github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870/go.mod h1:5tD+neXqOorC30/tWg0LCSkrqj/AR6gu8yY8/fpw1q0= github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870/go.mod h1:5tD+neXqOorC30/tWg0LCSkrqj/AR6gu8yY8/fpw1q0=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/getsentry/sentry-go v0.16.0 h1:owk+S+5XcgJLlGR/3+3s6N4d+uKwqYvh/eS0AIMjPWo= github.com/getsentry/sentry-go v0.16.0 h1:owk+S+5XcgJLlGR/3+3s6N4d+uKwqYvh/eS0AIMjPWo=
@ -106,8 +108,8 @@ github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3Bop
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-jose/go-jose/v3 v3.0.0 h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo= github.com/go-jose/go-jose/v4 v4.0.1 h1:QVEPDE3OluqXBQZDcnNvQrInro2h0e4eqNbnZSWqS6U=
github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= github.com/go-jose/go-jose/v4 v4.0.1/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@ -320,10 +322,8 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo= github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4= github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5nfFs= github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55 h1:I4N3ZRnkZPbDN935Tg8QDf8fRpHp3bZ0U0/L42jBgNE=
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55/go.mod h1:PeN7kuVJ4xZbxSv/4OX6S1USOX8MJvydwpTx31vx60c=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
@ -381,18 +381,17 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0= go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q= go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -464,8 +463,8 @@ golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -534,12 +533,12 @@ golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -553,6 +552,8 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=

View File

@ -32,6 +32,7 @@ const (
ProxyKeepAliveTimeoutFlag = "proxy-keepalive-timeout" ProxyKeepAliveTimeoutFlag = "proxy-keepalive-timeout"
HTTPHostHeaderFlag = "http-host-header" HTTPHostHeaderFlag = "http-host-header"
OriginServerNameFlag = "origin-server-name" OriginServerNameFlag = "origin-server-name"
MatchSNIToHostFlag = "match-sni-to-host"
NoTLSVerifyFlag = "no-tls-verify" NoTLSVerifyFlag = "no-tls-verify"
NoChunkedEncodingFlag = "no-chunked-encoding" NoChunkedEncodingFlag = "no-chunked-encoding"
ProxyAddressFlag = "proxy-address" ProxyAddressFlag = "proxy-address"
@ -118,6 +119,7 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
var keepAliveTimeout = defaultKeepAliveTimeout var keepAliveTimeout = defaultKeepAliveTimeout
var httpHostHeader string var httpHostHeader string
var originServerName string var originServerName string
var matchSNItoHost bool
var caPool string var caPool string
var noTLSVerify bool var noTLSVerify bool
var disableChunkedEncoding bool var disableChunkedEncoding bool
@ -150,6 +152,9 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
if flag := OriginServerNameFlag; c.IsSet(flag) { if flag := OriginServerNameFlag; c.IsSet(flag) {
originServerName = c.String(flag) originServerName = c.String(flag)
} }
if flag := MatchSNIToHostFlag; c.IsSet(flag) {
matchSNItoHost = c.Bool(flag)
}
if flag := tlsconfig.OriginCAPoolFlag; c.IsSet(flag) { if flag := tlsconfig.OriginCAPoolFlag; c.IsSet(flag) {
caPool = c.String(flag) caPool = c.String(flag)
} }
@ -185,6 +190,7 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
KeepAliveTimeout: keepAliveTimeout, KeepAliveTimeout: keepAliveTimeout,
HTTPHostHeader: httpHostHeader, HTTPHostHeader: httpHostHeader,
OriginServerName: originServerName, OriginServerName: originServerName,
MatchSNIToHost: matchSNItoHost,
CAPool: caPool, CAPool: caPool,
NoTLSVerify: noTLSVerify, NoTLSVerify: noTLSVerify,
DisableChunkedEncoding: disableChunkedEncoding, DisableChunkedEncoding: disableChunkedEncoding,
@ -229,6 +235,9 @@ func originRequestFromConfig(c config.OriginRequestConfig) OriginRequestConfig {
if c.OriginServerName != nil { if c.OriginServerName != nil {
out.OriginServerName = *c.OriginServerName out.OriginServerName = *c.OriginServerName
} }
if c.MatchSNIToHost != nil {
out.MatchSNIToHost = *c.MatchSNIToHost
}
if c.CAPool != nil { if c.CAPool != nil {
out.CAPool = *c.CAPool out.CAPool = *c.CAPool
} }
@ -287,6 +296,8 @@ type OriginRequestConfig struct {
HTTPHostHeader string `yaml:"httpHostHeader" json:"httpHostHeader"` HTTPHostHeader string `yaml:"httpHostHeader" json:"httpHostHeader"`
// Hostname on the origin server certificate. // Hostname on the origin server certificate.
OriginServerName string `yaml:"originServerName" json:"originServerName"` OriginServerName string `yaml:"originServerName" json:"originServerName"`
// Auto configure the Hostname on the origin server certificate.
MatchSNIToHost bool `yaml:"matchSNItoHost" json:"matchSNItoHost"`
// Path to the CA for the certificate of your origin. // Path to the CA for the certificate of your origin.
// This option should be used only if your certificate is not signed by Cloudflare. // This option should be used only if your certificate is not signed by Cloudflare.
CAPool string `yaml:"caPool" json:"caPool"` CAPool string `yaml:"caPool" json:"caPool"`
@ -362,6 +373,12 @@ func (defaults *OriginRequestConfig) setOriginServerName(overrides config.Origin
} }
} }
func (defaults *OriginRequestConfig) setMatchSNIToHost(overrides config.OriginRequestConfig) {
if val := overrides.MatchSNIToHost; val != nil {
defaults.MatchSNIToHost = *val
}
}
func (defaults *OriginRequestConfig) setCAPool(overrides config.OriginRequestConfig) { func (defaults *OriginRequestConfig) setCAPool(overrides config.OriginRequestConfig) {
if val := overrides.CAPool; val != nil { if val := overrides.CAPool; val != nil {
defaults.CAPool = *val defaults.CAPool = *val
@ -447,6 +464,7 @@ func setConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfi
cfg.setTCPKeepAlive(overrides) cfg.setTCPKeepAlive(overrides)
cfg.setHTTPHostHeader(overrides) cfg.setHTTPHostHeader(overrides)
cfg.setOriginServerName(overrides) cfg.setOriginServerName(overrides)
cfg.setMatchSNIToHost(overrides)
cfg.setCAPool(overrides) cfg.setCAPool(overrides)
cfg.setNoTLSVerify(overrides) cfg.setNoTLSVerify(overrides)
cfg.setDisableChunkedEncoding(overrides) cfg.setDisableChunkedEncoding(overrides)
@ -501,6 +519,7 @@ func ConvertToRawOriginConfig(c OriginRequestConfig) config.OriginRequestConfig
KeepAliveTimeout: keepAliveTimeout, KeepAliveTimeout: keepAliveTimeout,
HTTPHostHeader: emptyStringToNil(c.HTTPHostHeader), HTTPHostHeader: emptyStringToNil(c.HTTPHostHeader),
OriginServerName: emptyStringToNil(c.OriginServerName), OriginServerName: emptyStringToNil(c.OriginServerName),
MatchSNIToHost: defaultBoolToNil(c.MatchSNIToHost),
CAPool: emptyStringToNil(c.CAPool), CAPool: emptyStringToNil(c.CAPool),
NoTLSVerify: defaultBoolToNil(c.NoTLSVerify), NoTLSVerify: defaultBoolToNil(c.NoTLSVerify),
DisableChunkedEncoding: defaultBoolToNil(c.DisableChunkedEncoding), DisableChunkedEncoding: defaultBoolToNil(c.DisableChunkedEncoding),

View File

@ -0,0 +1,7 @@
package ingress
import "github.com/cloudflare/cloudflared/logger"
var (
TestLogger = logger.Create(nil)
)

View File

@ -131,7 +131,7 @@ func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idle
} }
func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *packetResponder) error { func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *packetResponder) error {
ctx, span := responder.requestSpan(ctx, pk) _, span := responder.requestSpan(ctx, pk)
defer responder.exportSpan() defer responder.exportSpan()
originalEcho, err := getICMPEcho(pk.Message) originalEcho, err := getICMPEcho(pk.Message)
@ -139,10 +139,8 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)
return err return err
} }
span.SetAttributes( observeICMPRequest(ip.logger, span, pk.Src.String(), pk.Dst.String(), originalEcho.ID, originalEcho.Seq)
attribute.Int("originalEchoID", originalEcho.ID),
attribute.Int("seq", originalEcho.Seq),
)
echoIDTrackerKey := flow3Tuple{ echoIDTrackerKey := flow3Tuple{
srcIP: pk.Src, srcIP: pk.Src,
dstIP: pk.Dst, dstIP: pk.Dst,
@ -189,6 +187,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)
return err return err
} }
err = icmpFlow.sendToDst(pk.Dst, pk.Message) err = icmpFlow.sendToDst(pk.Dst, pk.Message)
if err != nil { if err != nil {
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)
@ -269,15 +268,12 @@ func (ip *icmpProxy) sendReply(ctx context.Context, reply *echoReply) error {
_, span := icmpFlow.responder.replySpan(ctx, ip.logger) _, span := icmpFlow.responder.replySpan(ctx, ip.logger)
defer icmpFlow.responder.exportSpan() defer icmpFlow.responder.exportSpan()
span.SetAttributes(
attribute.String("dst", reply.from.String()),
attribute.Int("echoID", reply.echo.ID),
attribute.Int("seq", reply.echo.Seq),
attribute.Int("originalEchoID", icmpFlow.originalEchoID),
)
if err := icmpFlow.returnToSrc(reply); err != nil { if err := icmpFlow.returnToSrc(reply); err != nil {
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)
return err
} }
observeICMPReply(ip.logger, span, reply.from.String(), reply.echo.ID, reply.echo.Seq)
span.SetAttributes(attribute.Int("originalEchoID", icmpFlow.originalEchoID))
tracing.End(span) tracing.End(span)
return nil return nil
} }

View File

@ -78,19 +78,19 @@ func checkInPingGroup() error {
if err != nil { if err != nil {
return err return err
} }
groupID := os.Getgid() groupID := uint64(os.Getegid())
// Example content: 999 59999 // Example content: 999 59999
found := findGroupIDRegex.FindAll(file, 2) found := findGroupIDRegex.FindAll(file, 2)
if len(found) == 2 { if len(found) == 2 {
groupMin, err := strconv.ParseInt(string(found[0]), 10, 32) groupMin, err := strconv.ParseUint(string(found[0]), 10, 32)
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to determine minimum ping group ID") return errors.Wrapf(err, "failed to determine minimum ping group ID")
} }
groupMax, err := strconv.ParseInt(string(found[1]), 10, 32) groupMax, err := strconv.ParseUint(string(found[1]), 10, 32)
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to determine minimum ping group ID") return errors.Wrapf(err, "failed to determine maximum ping group ID")
} }
if groupID < int(groupMin) || groupID > int(groupMax) { if groupID < groupMin || groupID > groupMax {
return fmt.Errorf("Group ID %d is not between ping group %d to %d", groupID, groupMin, groupMax) return fmt.Errorf("Group ID %d is not between ping group %d to %d", groupID, groupMin, groupMax)
} }
return nil return nil
@ -107,10 +107,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)
return err return err
} }
span.SetAttributes( observeICMPRequest(ip.logger, span, pk.Src.String(), pk.Dst.String(), originalEcho.ID, originalEcho.Seq)
attribute.Int("originalEchoID", originalEcho.ID),
attribute.Int("seq", originalEcho.Seq),
)
shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder.datagramMuxer, pk, originalEcho.ID) shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder.datagramMuxer, pk, originalEcho.ID)
newFunnelFunc := func() (packet.Funnel, error) { newFunnelFunc := func() (packet.Funnel, error) {
@ -156,14 +153,8 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
Int("originalEchoID", originalEcho.ID). Int("originalEchoID", originalEcho.ID).
Msg("New flow") Msg("New flow")
go func() { go func() {
defer ip.srcFunnelTracker.Unregister(funnelID, icmpFlow) ip.listenResponse(ctx, icmpFlow)
if err := ip.listenResponse(ctx, icmpFlow); err != nil { ip.srcFunnelTracker.Unregister(funnelID, icmpFlow)
ip.logger.Debug().Err(err).
Str("src", pk.Src.String()).
Str("dst", pk.Dst.String()).
Int("originalEchoID", originalEcho.ID).
Msg("Failed to listen for ICMP echo response")
}
}() }()
} }
if err := icmpFlow.sendToDst(pk.Dst, pk.Message); err != nil { if err := icmpFlow.sendToDst(pk.Dst, pk.Message); err != nil {
@ -179,17 +170,17 @@ func (ip *icmpProxy) Serve(ctx context.Context) error {
return ctx.Err() return ctx.Err()
} }
func (ip *icmpProxy) listenResponse(ctx context.Context, flow *icmpEchoFlow) error { func (ip *icmpProxy) listenResponse(ctx context.Context, flow *icmpEchoFlow) {
buf := make([]byte, mtu) buf := make([]byte, mtu)
for { for {
retryable, err := ip.handleResponse(ctx, flow, buf) if done := ip.handleResponse(ctx, flow, buf); done {
if err != nil && !retryable { return
return err
} }
} }
} }
func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf []byte) (retryableErr bool, err error) { // Listens for ICMP response and handles error logging
func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf []byte) (done bool) {
_, span := flow.responder.replySpan(ctx, ip.logger) _, span := flow.responder.replySpan(ctx, ip.logger)
defer flow.responder.exportSpan() defer flow.responder.exportSpan()
@ -199,33 +190,36 @@ func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf
n, from, err := flow.originConn.ReadFrom(buf) n, from, err := flow.originConn.ReadFrom(buf)
if err != nil { if err != nil {
if flow.IsClosed() {
tracing.EndWithErrorStatus(span, fmt.Errorf("flow was closed"))
return true
}
ip.logger.Error().Err(err).Str("socket", flow.originConn.LocalAddr().String()).Msg("Failed to read from ICMP socket")
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)
return false, err return true
} }
reply, err := parseReply(from, buf[:n]) reply, err := parseReply(from, buf[:n])
if err != nil { if err != nil {
ip.logger.Error().Err(err).Str("dst", from.String()).Msg("Failed to parse ICMP reply") ip.logger.Error().Err(err).Str("dst", from.String()).Msg("Failed to parse ICMP reply")
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)
return true, err return false
} }
if !isEchoReply(reply.msg) { if !isEchoReply(reply.msg) {
err := fmt.Errorf("Expect ICMP echo reply, got %s", reply.msg.Type) err := fmt.Errorf("Expect ICMP echo reply, got %s", reply.msg.Type)
ip.logger.Debug().Str("dst", from.String()).Msgf("Drop ICMP %s from reply", reply.msg.Type) ip.logger.Debug().Str("dst", from.String()).Msgf("Drop ICMP %s from reply", reply.msg.Type)
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)
return true, err return false
} }
span.SetAttributes(
attribute.String("dst", reply.from.String()),
attribute.Int("echoID", reply.echo.ID),
attribute.Int("seq", reply.echo.Seq),
)
if err := flow.returnToSrc(reply); err != nil { if err := flow.returnToSrc(reply); err != nil {
ip.logger.Debug().Err(err).Str("dst", from.String()).Msg("Failed to send ICMP reply") ip.logger.Error().Err(err).Str("dst", from.String()).Msg("Failed to send ICMP reply")
tracing.EndWithErrorStatus(span, err) tracing.EndWithErrorStatus(span, err)
return true, err return false
} }
observeICMPReply(ip.logger, span, from.String(), reply.echo.ID, reply.echo.Seq)
tracing.End(span) tracing.End(span)
return true, nil return false
} }
// Only linux uses flow3Tuple as FunnelID // Only linux uses flow3Tuple as FunnelID

View File

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sync/atomic"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -46,6 +47,7 @@ type flow3Tuple struct {
type icmpEchoFlow struct { type icmpEchoFlow struct {
*packet.ActivityTracker *packet.ActivityTracker
closeCallback func() error closeCallback func() error
closed *atomic.Bool
src netip.Addr src netip.Addr
originConn *icmp.PacketConn originConn *icmp.PacketConn
responder *packetResponder responder *packetResponder
@ -59,6 +61,7 @@ func newICMPEchoFlow(src netip.Addr, closeCallback func() error, originConn *icm
return &icmpEchoFlow{ return &icmpEchoFlow{
ActivityTracker: packet.NewActivityTracker(), ActivityTracker: packet.NewActivityTracker(),
closeCallback: closeCallback, closeCallback: closeCallback,
closed: &atomic.Bool{},
src: src, src: src,
originConn: originConn, originConn: originConn,
responder: responder, responder: responder,
@ -86,9 +89,14 @@ func (ief *icmpEchoFlow) Equal(other packet.Funnel) bool {
} }
func (ief *icmpEchoFlow) Close() error { func (ief *icmpEchoFlow) Close() error {
ief.closed.Store(true)
return ief.closeCallback() return ief.closeCallback()
} }
func (ief *icmpEchoFlow) IsClosed() bool {
return ief.closed.Load()
}
// sendToDst rewrites the echo ID to the one assigned to this flow // sendToDst rewrites the echo ID to the one assigned to this flow
func (ief *icmpEchoFlow) sendToDst(dst netip.Addr, msg *icmp.Message) error { func (ief *icmpEchoFlow) sendToDst(dst netip.Addr, msg *icmp.Message) error {
ief.UpdateLastActive() ief.UpdateLastActive()

View File

@ -8,6 +8,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/fortytw2/leaktest"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -18,6 +19,8 @@ import (
) )
func TestFunnelIdleTimeout(t *testing.T) { func TestFunnelIdleTimeout(t *testing.T) {
defer leaktest.Check(t)()
const ( const (
idleTimeout = time.Second idleTimeout = time.Second
echoID = 42573 echoID = 42573
@ -73,13 +76,16 @@ func TestFunnelIdleTimeout(t *testing.T) {
require.NoError(t, proxy.Request(ctx, &pk, &newResponder)) require.NoError(t, proxy.Request(ctx, &pk, &newResponder))
validateEchoFlow(t, <-newMuxer.cfdToEdge, &pk) validateEchoFlow(t, <-newMuxer.cfdToEdge, &pk)
time.Sleep(idleTimeout * 2)
cancel() cancel()
<-proxyDone <-proxyDone
} }
func TestReuseFunnel(t *testing.T) { func TestReuseFunnel(t *testing.T) {
defer leaktest.Check(t)()
const ( const (
idleTimeout = time.Second idleTimeout = time.Millisecond * 100
echoID = 42573 echoID = 42573
startSeq = 8129 startSeq = 8129
) )
@ -135,6 +141,8 @@ func TestReuseFunnel(t *testing.T) {
require.True(t, found) require.True(t, found)
require.Equal(t, funnel1, funnel2) require.Equal(t, funnel1, funnel2)
time.Sleep(idleTimeout * 2)
cancel() cancel()
<-proxyDone <-proxyDone
} }

View File

@ -281,10 +281,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
if err != nil { if err != nil {
return err return err
} }
requestSpan.SetAttributes( observeICMPRequest(ip.logger, requestSpan, pk.Src.String(), pk.Dst.String(), echo.ID, echo.Seq)
attribute.Int("originalEchoID", echo.ID),
attribute.Int("seq", echo.Seq),
)
resp, err := ip.icmpEchoRoundtrip(pk.Dst, echo) resp, err := ip.icmpEchoRoundtrip(pk.Dst, echo)
if err != nil { if err != nil {
@ -296,17 +293,17 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
responder.exportSpan() responder.exportSpan()
_, replySpan := responder.replySpan(ctx, ip.logger) _, replySpan := responder.replySpan(ctx, ip.logger)
replySpan.SetAttributes(
attribute.Int("originalEchoID", echo.ID),
attribute.Int("seq", echo.Seq),
attribute.Int64("rtt", int64(resp.rtt())),
attribute.String("status", resp.status().String()),
)
err = ip.handleEchoReply(pk, echo, resp, responder) err = ip.handleEchoReply(pk, echo, resp, responder)
if err != nil { if err != nil {
ip.logger.Err(err).Msg("Failed to send ICMP reply")
tracing.EndWithErrorStatus(replySpan, err) tracing.EndWithErrorStatus(replySpan, err)
return errors.Wrap(err, "failed to handle ICMP echo reply") return errors.Wrap(err, "failed to handle ICMP echo reply")
} }
observeICMPReply(ip.logger, replySpan, pk.Dst.String(), echo.ID, echo.Seq)
replySpan.SetAttributes(
attribute.Int64("rtt", int64(resp.rtt())),
attribute.String("status", resp.status().String()),
)
tracing.End(replySpan) tracing.End(replySpan)
return nil return nil
} }

View File

@ -0,0 +1,126 @@
package middleware
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/json"
"fmt"
"net/http/httptest"
"testing"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
issuer = fmt.Sprintf(cloudflareAccessCertsURL, "testteam")
)
type accessTokenClaims struct {
Email string `json:"email"`
Type string `json:"type"`
jwt.Claims
}
func TestJWTValidator(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
issued := time.Now()
claims := accessTokenClaims{
Email: "test@example.com",
Type: "app",
Claims: jwt.Claims{
Issuer: issuer,
Subject: "ee239b7a-e3e6-4173-972a-8fbe9d99c04f",
Audience: []string{""},
Expiry: jwt.NewNumericDate(issued.Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(issued),
},
}
token := signToken(t, claims, key)
req.Header.Add(headerKeyAccessJWTAssertion, token)
keySet := oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.Public()}}
config := &oidc.Config{
SkipClientIDCheck: true,
SupportedSigningAlgs: []string{string(jose.ES256)},
}
verifier := oidc.NewVerifier(issuer, &keySet, config)
tests := []struct {
name string
audTags []string
aud jwt.Audience
error bool
}{
{
name: "valid",
audTags: []string{
"0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
},
aud: jwt.Audience{"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba"},
error: false,
},
{
name: "invalid no match",
audTags: []string{
"0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
},
aud: jwt.Audience{"09dc377143841843ecca28b196bdb1ec1675af38c8b7b60c7def5876c8877157"},
error: true,
},
{
name: "invalid empty check",
audTags: []string{},
aud: jwt.Audience{"09dc377143841843ecca28b196bdb1ec1675af38c8b7b60c7def5876c8877157"},
error: true,
},
{
name: "invalid absent aud",
audTags: []string{
"0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
},
aud: jwt.Audience{""},
error: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
validator := JWTValidator{
IDTokenVerifier: verifier,
audTags: test.audTags,
}
claims.Audience = test.aud
token := signToken(t, claims, key)
req.Header.Set(headerKeyAccessJWTAssertion, token)
result, err := validator.Handle(context.Background(), req)
assert.NoError(t, err)
assert.Equal(t, test.error, result.ShouldFilterRequest)
})
}
}
func signToken(t *testing.T, token accessTokenClaims, key *ecdsa.PrivateKey) string {
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, &jose.SignerOptions{})
require.NoError(t, err)
payload, err := json.Marshal(token)
require.NoError(t, err)
jws, err := signer.Sign(payload)
require.NoError(t, err)
jwt, err := jws.CompactSerialize()
require.NoError(t, err)
return jwt
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"io" "io"
"net" "net"
"time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -31,15 +32,32 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze
// tcpConnection is an OriginConnection that directly streams to raw TCP. // tcpConnection is an OriginConnection that directly streams to raw TCP.
type tcpConnection struct { type tcpConnection struct {
conn net.Conn net.Conn
writeTimeout time.Duration
logger *zerolog.Logger
} }
func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) { func (tc *tcpConnection) Stream(_ context.Context, tunnelConn io.ReadWriter, _ *zerolog.Logger) {
stream.Pipe(tunnelConn, tc.conn, log) stream.Pipe(tunnelConn, tc, tc.logger)
}
func (tc *tcpConnection) Write(b []byte) (int, error) {
if tc.writeTimeout > 0 {
if err := tc.Conn.SetWriteDeadline(time.Now().Add(tc.writeTimeout)); err != nil {
tc.logger.Err(err).Msg("Error setting write deadline for TCP connection")
}
}
nBytes, err := tc.Conn.Write(b)
if err != nil {
tc.logger.Err(err).Msg("Error writing to the TCP connection")
}
return nBytes, err
} }
func (tc *tcpConnection) Close() { func (tc *tcpConnection) Close() {
tc.conn.Close() tc.Conn.Close()
} }
// tcpOverWSConnection is an OriginConnection that streams to TCP over WS. // tcpOverWSConnection is an OriginConnection that streams to TCP over WS.

View File

@ -19,7 +19,6 @@ import (
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/socks" "github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/stream" "github.com/cloudflare/cloudflared/stream"
"github.com/cloudflare/cloudflared/websocket" "github.com/cloudflare/cloudflared/websocket"
@ -31,7 +30,6 @@ const (
) )
var ( var (
testLogger = logger.Create(nil)
testMessage = []byte("TestStreamOriginConnection") testMessage = []byte("TestStreamOriginConnection")
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage)) testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
) )
@ -39,7 +37,8 @@ var (
func TestStreamTCPConnection(t *testing.T) { func TestStreamTCPConnection(t *testing.T) {
cfdConn, originConn := net.Pipe() cfdConn, originConn := net.Pipe()
tcpConn := tcpConnection{ tcpConn := tcpConnection{
conn: cfdConn, Conn: cfdConn,
writeTimeout: 30 * time.Second,
} }
eyeballConn, edgeConn := net.Pipe() eyeballConn, edgeConn := net.Pipe()
@ -66,7 +65,7 @@ func TestStreamTCPConnection(t *testing.T) {
return nil return nil
}) })
tcpConn.Stream(ctx, edgeConn, testLogger) tcpConn.Stream(ctx, edgeConn, TestLogger)
require.NoError(t, errGroup.Wait()) require.NoError(t, errGroup.Wait())
} }
@ -93,7 +92,7 @@ func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
return nil return nil
}) })
tcpOverWSConn.Stream(ctx, edgeConn, testLogger) tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
require.NoError(t, errGroup.Wait()) require.NoError(t, errGroup.Wait())
} }
@ -147,7 +146,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
errGroup, ctx := errgroup.WithContext(ctx) errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error { errGroup.Go(func() error {
tcpOverWSConn.Stream(ctx, edgeConn, testLogger) tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
return nil return nil
}) })
@ -159,7 +158,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer wsForwarderInConn.Close() defer wsForwarderInConn.Close()
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger) stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, TestLogger)
return nil return nil
}) })
@ -209,7 +208,7 @@ func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
originConn.Close() originConn.Close()
}() }()
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond) ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
tcpOverWSConn.Stream(ctx, eyeballConn, testLogger) tcpOverWSConn.Stream(ctx, eyeballConn, TestLogger)
}) })
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
defer server.Close() defer server.Close()

View File

@ -7,6 +7,8 @@ import (
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/net/icmp" "golang.org/x/net/icmp"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
@ -15,8 +17,6 @@ import (
) )
const ( const (
// funnelIdleTimeout controls how long to wait to close a funnel without send/return
funnelIdleTimeout = time.Second * 10
mtu = 1500 mtu = 1500
// icmpRequestTimeoutMs controls how long to wait for a reply // icmpRequestTimeoutMs controls how long to wait for a reply
icmpRequestTimeoutMs = 1000 icmpRequestTimeoutMs = 1000
@ -32,8 +32,9 @@ type icmpRouter struct {
} }
// NewICMPRouter doesn't return an error if either ipv4 proxy or ipv6 proxy can be created. The machine might only // NewICMPRouter doesn't return an error if either ipv4 proxy or ipv6 proxy can be created. The machine might only
// support one of them // support one of them.
func NewICMPRouter(ipv4Addr, ipv6Addr netip.Addr, ipv6Zone string, logger *zerolog.Logger) (*icmpRouter, error) { // funnelIdleTimeout controls how long to wait to close a funnel without send/return
func NewICMPRouter(ipv4Addr, ipv6Addr netip.Addr, ipv6Zone string, logger *zerolog.Logger, funnelIdleTimeout time.Duration) (*icmpRouter, error) {
ipv4Proxy, ipv4Err := newICMPProxy(ipv4Addr, "", logger, funnelIdleTimeout) ipv4Proxy, ipv4Err := newICMPProxy(ipv4Addr, "", logger, funnelIdleTimeout)
ipv6Proxy, ipv6Err := newICMPProxy(ipv6Addr, ipv6Zone, logger, funnelIdleTimeout) ipv6Proxy, ipv6Err := newICMPProxy(ipv6Addr, ipv6Zone, logger, funnelIdleTimeout)
if ipv4Err != nil && ipv6Err != nil { if ipv4Err != nil && ipv6Err != nil {
@ -102,3 +103,25 @@ func getICMPEcho(msg *icmp.Message) (*icmp.Echo, error) {
func isEchoReply(msg *icmp.Message) bool { func isEchoReply(msg *icmp.Message) bool {
return msg.Type == ipv4.ICMPTypeEchoReply || msg.Type == ipv6.ICMPTypeEchoReply return msg.Type == ipv4.ICMPTypeEchoReply || msg.Type == ipv6.ICMPTypeEchoReply
} }
func observeICMPRequest(logger *zerolog.Logger, span trace.Span, src string, dst string, echoID int, seq int) {
logger.Debug().
Str("src", src).
Str("dst", dst).
Int("originalEchoID", echoID).
Int("originalEchoSeq", seq).
Msg("Received ICMP request")
span.SetAttributes(
attribute.Int("originalEchoID", echoID),
attribute.Int("seq", seq),
)
}
func observeICMPReply(logger *zerolog.Logger, span trace.Span, dst string, echoID int, seq int) {
logger.Debug().Str("dst", dst).Int("echoID", echoID).Int("seq", seq).Msg("Sent ICMP reply to edge")
span.SetAttributes(
attribute.String("dst", dst),
attribute.Int("echoID", echoID),
attribute.Int("seq", seq),
)
}

View File

@ -9,7 +9,9 @@ import (
"strings" "strings"
"sync" "sync"
"testing" "testing"
"time"
"github.com/fortytw2/leaktest"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -26,6 +28,7 @@ var (
noopLogger = zerolog.Nop() noopLogger = zerolog.Nop()
localhostIP = netip.MustParseAddr("127.0.0.1") localhostIP = netip.MustParseAddr("127.0.0.1")
localhostIPv6 = netip.MustParseAddr("::1") localhostIPv6 = netip.MustParseAddr("::1")
testFunnelIdleTimeout = time.Millisecond * 10
) )
// TestICMPProxyEcho makes sure we can send ICMP echo via the Request method and receives response via the // TestICMPProxyEcho makes sure we can send ICMP echo via the Request method and receives response via the
@ -40,12 +43,14 @@ func TestICMPRouterEcho(t *testing.T) {
} }
func testICMPRouterEcho(t *testing.T, sendIPv4 bool) { func testICMPRouterEcho(t *testing.T, sendIPv4 bool) {
defer leaktest.Check(t)()
const ( const (
echoID = 36571 echoID = 36571
endSeq = 20 endSeq = 20
) )
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger) router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger, testFunnelIdleTimeout)
require.NoError(t, err) require.NoError(t, err)
proxyDone := make(chan struct{}) proxyDone := make(chan struct{})
@ -97,14 +102,19 @@ func testICMPRouterEcho(t *testing.T, sendIPv4 bool) {
validateEchoFlow(t, <-muxer.cfdToEdge, &pk) validateEchoFlow(t, <-muxer.cfdToEdge, &pk)
} }
} }
// Make sure funnel cleanup kicks in
time.Sleep(testFunnelIdleTimeout * 2)
cancel() cancel()
<-proxyDone <-proxyDone
} }
func TestTraceICMPRouterEcho(t *testing.T) { func TestTraceICMPRouterEcho(t *testing.T) {
defer leaktest.Check(t)()
tracingCtx := "ec31ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0:1" tracingCtx := "ec31ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0:1"
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger) router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger, testFunnelIdleTimeout)
require.NoError(t, err) require.NoError(t, err)
proxyDone := make(chan struct{}) proxyDone := make(chan struct{})
@ -196,6 +206,7 @@ func TestTraceICMPRouterEcho(t *testing.T) {
default: default:
} }
time.Sleep(testFunnelIdleTimeout * 2)
cancel() cancel()
<-proxyDone <-proxyDone
} }
@ -203,12 +214,14 @@ func TestTraceICMPRouterEcho(t *testing.T) {
// TestConcurrentRequests makes sure icmpRouter can send concurrent requests to the same destination with different // TestConcurrentRequests makes sure icmpRouter can send concurrent requests to the same destination with different
// echo ID. This simulates concurrent ping to the same destination. // echo ID. This simulates concurrent ping to the same destination.
func TestConcurrentRequestsToSameDst(t *testing.T) { func TestConcurrentRequestsToSameDst(t *testing.T) {
defer leaktest.Check(t)()
const ( const (
concurrentPings = 5 concurrentPings = 5
endSeq = 5 endSeq = 5
) )
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger) router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger, testFunnelIdleTimeout)
require.NoError(t, err) require.NoError(t, err)
proxyDone := make(chan struct{}) proxyDone := make(chan struct{})
@ -282,12 +295,16 @@ func TestConcurrentRequestsToSameDst(t *testing.T) {
}() }()
} }
wg.Wait() wg.Wait()
time.Sleep(testFunnelIdleTimeout * 2)
cancel() cancel()
<-proxyDone <-proxyDone
} }
// TestICMPProxyRejectNotEcho makes sure it rejects messages other than echo // TestICMPProxyRejectNotEcho makes sure it rejects messages other than echo
func TestICMPRouterRejectNotEcho(t *testing.T) { func TestICMPRouterRejectNotEcho(t *testing.T) {
defer leaktest.Check(t)()
msgs := []icmp.Message{ msgs := []icmp.Message{
{ {
Type: ipv4.ICMPTypeDestinationUnreachable, Type: ipv4.ICMPTypeDestinationUnreachable,
@ -341,7 +358,7 @@ func TestICMPRouterRejectNotEcho(t *testing.T) {
} }
func testICMPRouterRejectNotEcho(t *testing.T, srcDstIP netip.Addr, msgs []icmp.Message) { func testICMPRouterRejectNotEcho(t *testing.T, srcDstIP netip.Addr, msgs []icmp.Message) {
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger) router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger, testFunnelIdleTimeout)
require.NoError(t, err) require.NoError(t, err)
muxer := newMockMuxer(1) muxer := newMockMuxer(1)

View File

@ -2,8 +2,12 @@ package ingress
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net"
"net/http" "net/http"
"github.com/rs/zerolog"
) )
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests. // HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
@ -14,7 +18,7 @@ type HTTPOriginProxy interface {
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP. // StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
type StreamBasedOriginProxy interface { type StreamBasedOriginProxy interface {
EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) EstablishConnection(ctx context.Context, dest string, log *zerolog.Logger) (OriginConnection, error)
} }
// HTTPLocalProxy can be implemented by cloudflared services that want to handle incoming http requests. // HTTPLocalProxy can be implemented by cloudflared services that want to handle incoming http requests.
@ -46,9 +50,28 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("X-Forwarded-Host", req.Host) req.Header.Set("X-Forwarded-Host", req.Host)
req.Host = o.hostHeader req.Host = o.hostHeader
} }
if o.matchSNIToHost {
o.SetOriginServerName(req)
}
return o.transport.RoundTrip(req) return o.transport.RoundTrip(req)
} }
func (o *httpService) SetOriginServerName(req *http.Request) {
o.transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := o.transport.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
return tls.Client(conn, &tls.Config{
RootCAs: o.transport.TLSClientConfig.RootCAs,
InsecureSkipVerify: o.transport.TLSClientConfig.InsecureSkipVerify,
ServerName: req.Host,
}), nil
}
}
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) { func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
if o.defaultResp { if o.defaultResp {
o.log.Warn().Msgf(ErrNoIngressRulesCLI.Error()) o.log.Warn().Msgf(ErrNoIngressRulesCLI.Error())
@ -62,19 +85,21 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return resp, nil return resp, nil
} }
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) { func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string, logger *zerolog.Logger) (OriginConnection, error) {
conn, err := o.dialer.DialContext(ctx, "tcp", dest) conn, err := o.dialer.DialContext(ctx, "tcp", dest)
if err != nil { if err != nil {
return nil, err return nil, err
} }
originConn := &tcpConnection{ originConn := &tcpConnection{
conn: conn, Conn: conn,
writeTimeout: o.writeTimeout,
logger: logger,
} }
return originConn, nil return originConn, nil
} }
func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) { func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string, _ *zerolog.Logger) (OriginConnection, error) {
var err error var err error
if !o.isBastion { if !o.isBastion {
dest = o.dest dest = o.dest
@ -92,6 +117,6 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string)
} }
func (o *socksProxyOverWSService) EstablishConnection(_ctx context.Context, _dest string) (OriginConnection, error) { func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
return o.conn, nil return o.conn, nil
} }

View File

@ -36,7 +36,7 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Origin not listening for new connection, should return an error // Origin not listening for new connection, should return an error
_, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String()) _, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String(), TestLogger)
require.Error(t, err) require.Error(t, err)
} }
@ -87,7 +87,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
t.Run(test.testCase, func(t *testing.T) { t.Run(test.testCase, func(t *testing.T) {
if test.expectErr { if test.expectErr {
bastionHost, _ := carrier.ResolveBastionDest(test.req) bastionHost, _ := carrier.ResolveBastionDest(test.req)
_, err := test.service.EstablishConnection(context.Background(), bastionHost) _, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
assert.Error(t, err) assert.Error(t, err)
} }
}) })
@ -99,7 +99,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} { for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
// Origin not listening for new connection, should return an error // Origin not listening for new connection, should return an error
bastionHost, _ := carrier.ResolveBastionDest(bastionReq) bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
_, err := service.EstablishConnection(context.Background(), bastionHost) _, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
assert.Error(t, err) assert.Error(t, err)
} }
} }
@ -132,7 +132,7 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
url: originURL, url: originURL,
} }
shutdownC := make(chan struct{}) shutdownC := make(chan struct{})
require.NoError(t, httpService.start(testLogger, shutdownC, cfg)) require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil) req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
require.NoError(t, err) require.NoError(t, err)
@ -167,7 +167,7 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
url: originURL, url: originURL,
} }
shutdownC := make(chan struct{}) shutdownC := make(chan struct{})
require.NoError(t, httpService.start(testLogger, shutdownC, cfg)) require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
// Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header // Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
protos := []string{"https", "http", "dne"} protos := []string{"https", "http", "dne"}

View File

@ -71,6 +71,7 @@ type httpService struct {
url *url.URL url *url.URL
hostHeader string hostHeader string
transport *http.Transport transport *http.Transport
matchSNIToHost bool
} }
func (o *httpService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { func (o *httpService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
@ -80,6 +81,7 @@ func (o *httpService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRe
} }
o.hostHeader = cfg.HTTPHostHeader o.hostHeader = cfg.HTTPHostHeader
o.transport = transport o.transport = transport
o.matchSNIToHost = cfg.MatchSNIToHost
return nil return nil
} }
@ -96,13 +98,15 @@ func (o httpService) MarshalJSON() ([]byte, error) {
type rawTCPService struct { type rawTCPService struct {
name string name string
dialer net.Dialer dialer net.Dialer
writeTimeout time.Duration
logger *zerolog.Logger
} }
func (o *rawTCPService) String() string { func (o *rawTCPService) String() string {
return o.name return o.name
} }
func (o *rawTCPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { func (o *rawTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, _ OriginRequestConfig) error {
return nil return nil
} }
@ -285,13 +289,14 @@ type WarpRoutingService struct {
Proxy StreamBasedOriginProxy Proxy StreamBasedOriginProxy
} }
func NewWarpRoutingService(config WarpRoutingConfig) *WarpRoutingService { func NewWarpRoutingService(config WarpRoutingConfig, writeTimeout time.Duration) *WarpRoutingService {
svc := &rawTCPService{ svc := &rawTCPService{
name: ServiceWarpRouting, name: ServiceWarpRouting,
dialer: net.Dialer{ dialer: net.Dialer{
Timeout: config.ConnectTimeout.Duration, Timeout: config.ConnectTimeout.Duration,
KeepAlive: config.TCPKeepAlive.Duration, KeepAlive: config.TCPKeepAlive.Duration,
}, },
writeTimeout: writeTimeout,
} }
return &WarpRoutingService{Proxy: svc} return &WarpRoutingService{Proxy: svc}

View File

@ -204,25 +204,25 @@ func TestMarshalJSON(t *testing.T) {
{ {
name: "Nil", name: "Nil",
path: nil, path: nil,
expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`, expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
want: true, want: true,
}, },
{ {
name: "Nil regex", name: "Nil regex",
path: &Regexp{Regexp: nil}, path: &Regexp{Regexp: nil},
expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`, expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
want: true, want: true,
}, },
{ {
name: "Empty", name: "Empty",
path: &Regexp{Regexp: regexp.MustCompile("")}, path: &Regexp{Regexp: regexp.MustCompile("")},
expected: `{"hostname":"example.com","path":"","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`, expected: `{"hostname":"example.com","path":"","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
want: true, want: true,
}, },
{ {
name: "Basic", name: "Basic",
path: &Regexp{Regexp: regexp.MustCompile("/echo")}, path: &Regexp{Regexp: regexp.MustCompile("/echo")},
expected: `{"hostname":"example.com","path":"/echo","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`, expected: `{"hostname":"example.com","path":"/echo","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
want: true, want: true,
}, },
} }

View File

@ -3,7 +3,8 @@ package management
import ( import (
"fmt" "fmt"
"github.com/go-jose/go-jose/v3/jwt" "github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
) )
type managementTokenClaims struct { type managementTokenClaims struct {
@ -37,7 +38,7 @@ func (t *actor) verify() bool {
} }
func parseToken(token string) (*managementTokenClaims, error) { func parseToken(token string) (*managementTokenClaims, error) {
jwt, err := jwt.ParseSigned(token) jwt, err := jwt.ParseSigned(token, []jose.SignatureAlgorithm{jose.ES256})
if err != nil { if err != nil {
return nil, fmt.Errorf("malformed jwt: %v", err) return nil, fmt.Errorf("malformed jwt: %v", err)
} }

View File

@ -7,7 +7,7 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v4"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )

View File

@ -2,6 +2,7 @@ package orchestration
import ( import (
"encoding/json" "encoding/json"
"time"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
@ -21,6 +22,7 @@ type newLocalConfig struct {
type Config struct { type Config struct {
Ingress *ingress.Ingress Ingress *ingress.Ingress
WarpRouting ingress.WarpRoutingConfig WarpRouting ingress.WarpRoutingConfig
WriteTimeout time.Duration
// Extra settings used to configure this instance but that are not eligible for remotely management // Extra settings used to configure this instance but that are not eligible for remotely management
// ie. (--protocol, --loglevel, ...) // ie. (--protocol, --loglevel, ...)

View File

@ -17,10 +17,10 @@ import (
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
) )
// Orchestrator manages configurations so they can be updatable during runtime // Orchestrator manages configurations, so they can be updatable during runtime
// properties are static, so it can be read without lock // properties are static, so it can be read without lock
// currentVersion and config are read/write infrequently, so their access are synchronized with RWMutex // currentVersion and config are read/write infrequently, so their access are synchronized with RWMutex
// access to proxy is synchronized with atmoic.Value, because it uses copy-on-write to provide scalable frequently // access to proxy is synchronized with atomic.Value, because it uses copy-on-write to provide scalable frequently
// read when update is infrequent // read when update is infrequent
type Orchestrator struct { type Orchestrator struct {
currentVersion int32 currentVersion int32
@ -30,6 +30,7 @@ type Orchestrator struct {
proxy atomic.Value proxy atomic.Value
// Set of internal ingress rules defined at cloudflared startup (separate from user-defined ingress rules) // Set of internal ingress rules defined at cloudflared startup (separate from user-defined ingress rules)
internalRules []ingress.Rule internalRules []ingress.Rule
// cloudflared Configuration
config *Config config *Config
tags []tunnelpogs.Tag tags []tunnelpogs.Tag
log *zerolog.Logger log *zerolog.Logger
@ -40,7 +41,11 @@ type Orchestrator struct {
proxyShutdownC chan<- struct{} proxyShutdownC chan<- struct{}
} }
func NewOrchestrator(ctx context.Context, config *Config, tags []tunnelpogs.Tag, internalRules []ingress.Rule, log *zerolog.Logger) (*Orchestrator, error) { func NewOrchestrator(ctx context.Context,
config *Config,
tags []tunnelpogs.Tag,
internalRules []ingress.Rule,
log *zerolog.Logger) (*Orchestrator, error) {
o := &Orchestrator{ o := &Orchestrator{
// Lowest possible version, any remote configuration will have version higher than this // Lowest possible version, any remote configuration will have version higher than this
// Starting at -1 allows a configuration migration (local to remote) to override the current configuration as it // Starting at -1 allows a configuration migration (local to remote) to override the current configuration as it
@ -131,7 +136,7 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil { if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil {
return errors.Wrap(err, "failed to start origin") return errors.Wrap(err, "failed to start origin")
} }
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.log) proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.config.WriteTimeout, o.log)
o.proxy.Store(proxy) o.proxy.Store(proxy)
o.config.Ingress = &ingressRules o.config.Ingress = &ingressRules
o.config.WarpRouting = warpRouting o.config.WarpRouting = warpRouting

View File

@ -27,7 +27,7 @@ import (
) )
var ( var (
testLogger = zerolog.Logger{} testLogger = zerolog.Nop()
testTags = []tunnelpogs.Tag{ testTags = []tunnelpogs.Tag{
{ {
Name: "package", Name: "package",

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/bin/bash
set -eu set -eu
rm /usr/local/bin/cloudflared rm -f /usr/local/bin/cloudflared
rm /usr/local/etc/cloudflared/.installedFromPackageManager || true rm -f /usr/local/etc/cloudflared/.installedFromPackageManager

78
proxy/logger.go Normal file
View File

@ -0,0 +1,78 @@
package proxy
import (
"net/http"
"strconv"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management"
)
const (
logFieldCFRay = "cfRay"
logFieldLBProbe = "lbProbe"
logFieldRule = "ingressRule"
logFieldOriginService = "originService"
logFieldFlowID = "flowID"
logFieldConnIndex = "connIndex"
logFieldDestAddr = "destAddr"
)
// newHTTPLogger creates a child zerolog.Logger from the provided with added context from the HTTP request, ingress
// services, and connection index.
func newHTTPLogger(logger *zerolog.Logger, connIndex uint8, req *http.Request, rule int, serviceName string) zerolog.Logger {
ctx := logger.With().
Int(management.EventTypeKey, int(management.HTTP)).
Uint8(logFieldConnIndex, connIndex)
cfRay := connection.FindCfRayHeader(req)
lbProbe := connection.IsLBProbeRequest(req)
if cfRay != "" {
ctx.Str(logFieldCFRay, cfRay)
}
if lbProbe {
ctx.Bool(logFieldLBProbe, lbProbe)
}
return ctx.
Str(logFieldOriginService, serviceName).
Interface(logFieldRule, rule).
Logger()
}
// newTCPLogger creates a child zerolog.Logger from the provided with added context from the TCPRequest.
func newTCPLogger(logger *zerolog.Logger, req *connection.TCPRequest) zerolog.Logger {
return logger.With().
Int(management.EventTypeKey, int(management.TCP)).
Uint8(logFieldConnIndex, req.ConnIndex).
Str(logFieldOriginService, ingress.ServiceWarpRouting).
Str(logFieldFlowID, req.FlowID).
Str(logFieldDestAddr, req.Dest).
Uint8(logFieldConnIndex, req.ConnIndex).
Logger()
}
// logHTTPRequest logs a Debug message with the corresponding HTTP request details from the eyeball.
func logHTTPRequest(logger *zerolog.Logger, r *http.Request) {
logger.Debug().
Str("host", r.Host).
Str("path", r.URL.Path).
Interface("headers", r.Header).
Int64("content-length", r.ContentLength).
Msgf("%s %s %s", r.Method, r.URL, r.Proto)
}
// logOriginHTTPResponse logs a Debug message of the origin response.
func logOriginHTTPResponse(logger *zerolog.Logger, resp *http.Response) {
responseByCode.WithLabelValues(strconv.Itoa(resp.StatusCode)).Inc()
logger.Debug().
Int64("content-length", resp.ContentLength).
Msgf("%s", resp.Status)
}
// logRequestError logs an error for the proxied request.
func logRequestError(logger *zerolog.Logger, err error) {
requestErrors.Inc()
logger.Error().Err(err).Send()
}

View File

@ -59,6 +59,23 @@ var (
Help: "Total count of TCP sessions that have been proxied to any origin", Help: "Total count of TCP sessions that have been proxied to any origin",
}, },
) )
connectLatency = prometheus.NewHistogram(
prometheus.HistogramOpts{
Namespace: connection.MetricsNamespace,
Subsystem: "proxy",
Name: "connect_latency",
Help: "Time it takes to establish and acknowledge connections in milliseconds",
Buckets: []float64{1, 10, 25, 50, 100, 500, 1000, 5000},
},
)
connectStreamErrors = prometheus.NewCounter(
prometheus.CounterOpts{
Namespace: connection.MetricsNamespace,
Subsystem: "proxy",
Name: "connect_streams_errors",
Help: "Total count of failure to establish and acknowledge connections",
},
)
) )
func init() { func init() {
@ -69,6 +86,8 @@ func init() {
requestErrors, requestErrors,
activeTCPSessions, activeTCPSessions,
totalTCPSessions, totalTCPSessions,
connectLatency,
connectStreamErrors,
) )
} }

View File

@ -6,6 +6,7 @@ import (
"io" "io"
"net/http" "net/http"
"strconv" "strconv"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -16,7 +17,6 @@ import (
"github.com/cloudflare/cloudflared/cfio" "github.com/cloudflare/cloudflared/cfio"
"github.com/cloudflare/cloudflared/connection" "github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress" "github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management"
"github.com/cloudflare/cloudflared/stream" "github.com/cloudflare/cloudflared/stream"
"github.com/cloudflare/cloudflared/tracing" "github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs" tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -25,14 +25,6 @@ import (
const ( const (
// TagHeaderNamePrefix indicates a Cloudflared Warp Tag prefix that gets appended for warp traffic stream headers. // TagHeaderNamePrefix indicates a Cloudflared Warp Tag prefix that gets appended for warp traffic stream headers.
TagHeaderNamePrefix = "Cf-Warp-Tag-" TagHeaderNamePrefix = "Cf-Warp-Tag-"
LogFieldCFRay = "cfRay"
LogFieldLBProbe = "lbProbe"
LogFieldRule = "ingressRule"
LogFieldOriginService = "originService"
LogFieldFlowID = "flowID"
LogFieldConnIndex = "connIndex"
LogFieldDestAddr = "destAddr"
trailerHeaderName = "Trailer" trailerHeaderName = "Trailer"
) )
@ -50,6 +42,7 @@ func NewOriginProxy(
ingressRules ingress.Ingress, ingressRules ingress.Ingress,
warpRouting ingress.WarpRoutingConfig, warpRouting ingress.WarpRoutingConfig,
tags []tunnelpogs.Tag, tags []tunnelpogs.Tag,
writeTimeout time.Duration,
log *zerolog.Logger, log *zerolog.Logger,
) *Proxy { ) *Proxy {
proxy := &Proxy{ proxy := &Proxy{
@ -58,7 +51,7 @@ func NewOriginProxy(
log: log, log: log,
} }
proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting) proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout)
return proxy return proxy
} }
@ -89,26 +82,18 @@ func (p *Proxy) ProxyHTTP(
defer decrementConcurrentRequests() defer decrementConcurrentRequests()
req := tr.Request req := tr.Request
cfRay := connection.FindCfRayHeader(req)
lbProbe := connection.IsLBProbeRequest(req)
p.appendTagHeaders(req) p.appendTagHeaders(req)
_, ruleSpan := tr.Tracer().Start(req.Context(), "ingress_match", _, ruleSpan := tr.Tracer().Start(req.Context(), "ingress_match",
trace.WithAttributes(attribute.String("req-host", req.Host))) trace.WithAttributes(attribute.String("req-host", req.Host)))
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path) rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
logFields := logFields{
cfRay: cfRay,
lbProbe: lbProbe,
rule: ruleNum,
connIndex: tr.ConnIndex,
}
p.logRequest(req, logFields)
ruleSpan.SetAttributes(attribute.Int("rule-num", ruleNum)) ruleSpan.SetAttributes(attribute.Int("rule-num", ruleNum))
ruleSpan.End() ruleSpan.End()
logger := newHTTPLogger(p.log, tr.ConnIndex, req, ruleNum, rule.Service.String())
logHTTPRequest(&logger, req)
if err, applied := p.applyIngressMiddleware(rule, req, w); err != nil { if err, applied := p.applyIngressMiddleware(rule, req, w); err != nil {
if applied { if applied {
rule, srv := ruleField(p.ingressRules, ruleNum) logRequestError(&logger, err)
p.logRequestError(err, cfRay, "", rule, srv)
return nil return nil
} }
return err return err
@ -122,10 +107,9 @@ func (p *Proxy) ProxyHTTP(
originProxy, originProxy,
isWebsocket, isWebsocket,
rule.Config.DisableChunkedEncoding, rule.Config.DisableChunkedEncoding,
logFields, &logger,
); err != nil { ); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum) logRequestError(&logger, err)
p.logRequestError(err, cfRay, "", rule, srv)
return err return err
} }
return nil return nil
@ -139,9 +123,9 @@ func (p *Proxy) ProxyHTTP(
return fmt.Errorf("response writer is not a flusher") return fmt.Errorf("response writer is not a flusher")
} }
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req) rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy); err != nil { logger := logger.With().Str(logFieldDestAddr, dest).Logger()
rule, srv := ruleField(p.ingressRules, ruleNum) if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy, &logger); err != nil {
p.logRequestError(err, cfRay, "", rule, srv) logRequestError(&logger, err)
return err return err
} }
return nil return nil
@ -171,38 +155,20 @@ func (p *Proxy) ProxyTCP(
serveCtx, cancel := context.WithCancel(ctx) serveCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, p.log) logger := newTCPLogger(p.log, req)
tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, &logger)
logger.Debug().Msg("tcp proxy stream started")
p.log.Debug(). if err := p.proxyStream(tracedCtx, rwa, req.Dest, p.warpRouting.Proxy, &logger); err != nil {
Int(management.EventTypeKey, int(management.TCP)). logRequestError(&logger, err)
Str(LogFieldFlowID, req.FlowID).
Str(LogFieldDestAddr, req.Dest).
Uint8(LogFieldConnIndex, req.ConnIndex).
Msg("tcp proxy stream started")
if err := p.proxyStream(tracedCtx, rwa, req.Dest, p.warpRouting.Proxy); err != nil {
p.logRequestError(err, req.CFRay, req.FlowID, "", ingress.ServiceWarpRouting)
return err return err
} }
p.log.Debug(). logger.Debug().Msg("tcp proxy stream finished successfully")
Int(management.EventTypeKey, int(management.TCP)).
Str(LogFieldFlowID, req.FlowID).
Str(LogFieldDestAddr, req.Dest).
Uint8(LogFieldConnIndex, req.ConnIndex).
Msg("tcp proxy stream finished successfully")
return nil return nil
} }
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
srv = ing.Rules[ruleNum].Service.String()
if ing.IsSingleRule() {
return "", srv
}
return fmt.Sprintf("%d", ruleNum), srv
}
// ProxyHTTPRequest proxies requests of underlying type http and websocket to the origin service. // ProxyHTTPRequest proxies requests of underlying type http and websocket to the origin service.
func (p *Proxy) proxyHTTPRequest( func (p *Proxy) proxyHTTPRequest(
w connection.ResponseWriter, w connection.ResponseWriter,
@ -210,7 +176,7 @@ func (p *Proxy) proxyHTTPRequest(
httpService ingress.HTTPOriginProxy, httpService ingress.HTTPOriginProxy,
isWebsocket bool, isWebsocket bool,
disableChunkedEncoding bool, disableChunkedEncoding bool,
fields logFields, logger *zerolog.Logger,
) error { ) error {
roundTripReq := tr.Request roundTripReq := tr.Request
if isWebsocket { if isWebsocket {
@ -277,7 +243,7 @@ func (p *Proxy) proxyHTTPRequest(
reader: tr.Request.Body, reader: tr.Request.Body,
} }
stream.Pipe(eyeballStream, rwc, p.log) stream.Pipe(eyeballStream, rwc, logger)
return nil return nil
} }
@ -288,35 +254,45 @@ func (p *Proxy) proxyHTTPRequest(
// copy trailers // copy trailers
copyTrailers(w, resp) copyTrailers(w, resp)
p.logOriginResponse(resp, fields) logOriginHTTPResponse(logger, resp)
return nil return nil
} }
// proxyStream proxies type TCP and other underlying types if the connection is defined as a stream oriented // proxyStream proxies type TCP and other underlying types if the connection is defined as a stream oriented
// ingress rule. // ingress rule.
// connectedLogger is used to log when the connection is acknowledged
func (p *Proxy) proxyStream( func (p *Proxy) proxyStream(
tr *tracing.TracedContext, tr *tracing.TracedContext,
rwa connection.ReadWriteAcker, rwa connection.ReadWriteAcker,
dest string, dest string,
connectionProxy ingress.StreamBasedOriginProxy, connectionProxy ingress.StreamBasedOriginProxy,
logger *zerolog.Logger,
) error { ) error {
ctx := tr.Context ctx := tr.Context
_, connectSpan := tr.Tracer().Start(ctx, "stream-connect") _, connectSpan := tr.Tracer().Start(ctx, "stream-connect")
originConn, err := connectionProxy.EstablishConnection(ctx, dest)
start := time.Now()
originConn, err := connectionProxy.EstablishConnection(ctx, dest, logger)
if err != nil { if err != nil {
connectStreamErrors.Inc()
tracing.EndWithErrorStatus(connectSpan, err) tracing.EndWithErrorStatus(connectSpan, err)
return err return err
} }
connectSpan.End() connectSpan.End()
defer originConn.Close() defer originConn.Close()
logger.Debug().Msg("origin connection established")
encodedSpans := tr.GetSpans() encodedSpans := tr.GetSpans()
if err := rwa.AckConnection(encodedSpans); err != nil { if err := rwa.AckConnection(encodedSpans); err != nil {
connectStreamErrors.Inc()
return err return err
} }
originConn.Stream(ctx, rwa, p.log) connectLatency.Observe(float64(time.Since(start).Milliseconds()))
logger.Debug().Msg("proxy stream acknowledged")
originConn.Stream(ctx, rwa, logger)
return nil return nil
} }
@ -350,14 +326,6 @@ func (p *Proxy) appendTagHeaders(r *http.Request) {
} }
} }
type logFields struct {
cfRay string
lbProbe bool
rule int
flowID string
connIndex uint8
}
func copyTrailers(w connection.ResponseWriter, response *http.Response) { func copyTrailers(w connection.ResponseWriter, response *http.Response) {
for trailerHeader, trailerValues := range response.Trailer { for trailerHeader, trailerValues := range response.Trailer {
for _, trailerValue := range trailerValues { for _, trailerValue := range trailerValues {
@ -366,64 +334,6 @@ func copyTrailers(w connection.ResponseWriter, response *http.Response) {
} }
} }
func (p *Proxy) logRequest(r *http.Request, fields logFields) {
log := p.log.With().Int(management.EventTypeKey, int(management.HTTP)).Logger()
event := log.Debug()
if fields.cfRay != "" {
event = event.Str(LogFieldCFRay, fields.cfRay)
}
if fields.lbProbe {
event = event.Bool(LogFieldLBProbe, fields.lbProbe)
}
if fields.cfRay == "" && !fields.lbProbe {
log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto)
}
event.
Uint8(LogFieldConnIndex, fields.connIndex).
Str("host", r.Host).
Str("path", r.URL.Path).
Interface(LogFieldRule, fields.rule).
Interface("headers", r.Header).
Int64("content-length", r.ContentLength).
Msgf("%s %s %s", r.Method, r.URL, r.Proto)
}
func (p *Proxy) logOriginResponse(resp *http.Response, fields logFields) {
responseByCode.WithLabelValues(strconv.Itoa(resp.StatusCode)).Inc()
event := p.log.Debug()
if fields.cfRay != "" {
event = event.Str(LogFieldCFRay, fields.cfRay)
}
if fields.lbProbe {
event = event.Bool(LogFieldLBProbe, fields.lbProbe)
}
event.
Int(management.EventTypeKey, int(management.HTTP)).
Uint8(LogFieldConnIndex, fields.connIndex).
Int64("content-length", resp.ContentLength).
Msgf("%s", resp.Status)
}
func (p *Proxy) logRequestError(err error, cfRay string, flowID string, rule, service string) {
requestErrors.Inc()
log := p.log.Error().Err(err)
if cfRay != "" {
log = log.Str(LogFieldCFRay, cfRay)
}
if flowID != "" {
log = log.Str(LogFieldFlowID, flowID).Int(management.EventTypeKey, int(management.TCP))
} else {
log = log.Int(management.EventTypeKey, int(management.HTTP))
}
if rule != "" {
log = log.Str(LogFieldRule, rule)
}
if service != "" {
log = log.Str(LogFieldOriginService, service)
}
log.Send()
}
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) { func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
switch rule.Service.String() { switch rule.Service.String() {
case ingress.ServiceBastion: case ingress.ServiceBastion:

View File

@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done())) require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, &log) proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, time.Duration(0), &log)
t.Run("testProxyHTTP", testProxyHTTP(proxy)) t.Run("testProxyHTTP", testProxyHTTP(proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(proxy)) t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
t.Run("testProxySSE", testProxySSE(proxy)) t.Run("testProxySSE", testProxySSE(proxy))
@ -366,7 +366,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
require.NoError(t, ingress.StartOrigins(&log, ctx.Done())) require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, &log) proxy := NewOriginProxy(ingress, noWarpRouting, testTags, time.Duration(0), &log)
for _, test := range tests { for _, test := range tests {
responseWriter := newMockHTTPRespWriter() responseWriter := newMockHTTPRespWriter()
@ -414,7 +414,7 @@ func TestProxyError(t *testing.T) {
log := zerolog.Nop() log := zerolog.Nop()
proxy := NewOriginProxy(ing, noWarpRouting, testTags, &log) proxy := NewOriginProxy(ing, noWarpRouting, testTags, time.Duration(0), &log)
responseWriter := newMockHTTPRespWriter() responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
@ -530,7 +530,7 @@ func TestConnections(t *testing.T) {
originService: runEchoTCPService, originService: runEchoTCPService,
eyeballResponseWriter: newTCPRespWriter(replayer), eyeballResponseWriter: newTCPRespWriter(replayer),
eyeballRequestBody: newTCPRequestBody([]byte("test2")), eyeballRequestBody: newTCPRequestBody([]byte("test2")),
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting), warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
connectionType: connection.TypeTCP, connectionType: connection.TypeTCP,
requestHeaders: map[string][]string{ requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
@ -548,7 +548,7 @@ func TestConnections(t *testing.T) {
originService: runEchoWSService, originService: runEchoWSService,
// eyeballResponseWriter gets set after roundtrip dial. // eyeballResponseWriter gets set after roundtrip dial.
eyeballRequestBody: newPipedWSRequestBody([]byte("test3")), eyeballRequestBody: newPipedWSRequestBody([]byte("test3")),
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting), warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
requestHeaders: map[string][]string{ requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"}, "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
}, },
@ -675,7 +675,7 @@ func TestConnections(t *testing.T) {
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String()) ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
ingressRule.StartOrigins(logger, ctx.Done()) ingressRule.StartOrigins(logger, ctx.Done())
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, logger) proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, time.Duration(0), logger)
proxy.warpRouting = test.args.warpRoutingService proxy.warpRouting = test.args.warpRoutingService
dest := ln.Addr().String() dest := ln.Addr().String()

View File

@ -7,17 +7,6 @@ import (
"github.com/quic-go/quic-go/logging" "github.com/quic-go/quic-go/logging"
) )
func perspectiveString(p logging.Perspective) string {
switch p {
case logging.PerspectiveClient:
return "client"
case logging.PerspectiveServer:
return "server"
default:
return ""
}
}
// Helper to convert logging.ByteCount(alias for int64) to float64 used in prometheus // Helper to convert logging.ByteCount(alias for int64) to float64 used in prometheus
func byteCountToPromCount(count logging.ByteCount) float64 { func byteCountToPromCount(count logging.ByteCount) float64 {
return float64(count) return float64(count)

View File

@ -1,6 +1,8 @@
package quic package quic
import ( import (
"reflect"
"strings"
"sync" "sync"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -16,10 +18,10 @@ var (
clientMetrics = struct { clientMetrics = struct {
totalConnections prometheus.Counter totalConnections prometheus.Counter
closedConnections prometheus.Counter closedConnections prometheus.Counter
sentPackets *prometheus.CounterVec sentFrames *prometheus.CounterVec
sentBytes *prometheus.CounterVec sentBytes *prometheus.CounterVec
receivePackets *prometheus.CounterVec receivedFrames *prometheus.CounterVec
receiveBytes *prometheus.CounterVec receivedBytes *prometheus.CounterVec
bufferedPackets *prometheus.CounterVec bufferedPackets *prometheus.CounterVec
droppedPackets *prometheus.CounterVec droppedPackets *prometheus.CounterVec
lostPackets *prometheus.CounterVec lostPackets *prometheus.CounterVec
@ -28,43 +30,88 @@ var (
smoothedRTT *prometheus.GaugeVec smoothedRTT *prometheus.GaugeVec
}{ }{
totalConnections: prometheus.NewCounter( totalConnections: prometheus.NewCounter(
totalConnectionsOpts(logging.PerspectiveClient), prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "total_connections",
Help: "Number of connections initiated",
},
), ),
closedConnections: prometheus.NewCounter( closedConnections: prometheus.NewCounter(
closedConnectionsOpts(logging.PerspectiveClient), prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "closed_connections",
Help: "Number of connections that has been closed",
},
), ),
sentPackets: prometheus.NewCounterVec( sentFrames: prometheus.NewCounterVec(
sentPacketsOpts(logging.PerspectiveClient), prometheus.CounterOpts{
clientConnLabels, Namespace: namespace,
Subsystem: "client",
Name: "sent_frames",
Help: "Number of frames that have been sent through a connection",
},
append(clientConnLabels, "frame_type"),
), ),
sentBytes: prometheus.NewCounterVec( sentBytes: prometheus.NewCounterVec(
sentBytesOpts(logging.PerspectiveClient), prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "sent_bytes",
Help: "Number of bytes that have been sent through a connection",
},
clientConnLabels, clientConnLabels,
), ),
receivePackets: prometheus.NewCounterVec( receivedFrames: prometheus.NewCounterVec(
receivePacketsOpts(logging.PerspectiveClient), prometheus.CounterOpts{
clientConnLabels, Namespace: namespace,
Subsystem: "client",
Name: "received_frames",
Help: "Number of frames that have been received through a connection",
},
append(clientConnLabels, "frame_type"),
), ),
receiveBytes: prometheus.NewCounterVec( receivedBytes: prometheus.NewCounterVec(
receiveBytesOpts(logging.PerspectiveClient), prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "receive_bytes",
Help: "Number of bytes that have been received through a connection",
},
clientConnLabels, clientConnLabels,
), ),
bufferedPackets: prometheus.NewCounterVec( bufferedPackets: prometheus.NewCounterVec(
bufferedPacketsOpts(logging.PerspectiveClient), prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "buffered_packets",
Help: "Number of bytes that have been buffered on a connection",
},
append(clientConnLabels, "packet_type"), append(clientConnLabels, "packet_type"),
), ),
droppedPackets: prometheus.NewCounterVec( droppedPackets: prometheus.NewCounterVec(
droppedPacketsOpts(logging.PerspectiveClient), prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "dropped_packets",
Help: "Number of bytes that have been dropped on a connection",
},
append(clientConnLabels, "packet_type", "reason"), append(clientConnLabels, "packet_type", "reason"),
), ),
lostPackets: prometheus.NewCounterVec( lostPackets: prometheus.NewCounterVec(
lostPacketsOpts(logging.PerspectiveClient), prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "lost_packets",
Help: "Number of packets that have been lost from a connection",
},
append(clientConnLabels, "reason"), append(clientConnLabels, "reason"),
), ),
minRTT: prometheus.NewGaugeVec( minRTT: prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Namespace: namespace, Namespace: namespace,
Subsystem: perspectiveString(logging.PerspectiveClient), Subsystem: "client",
Name: "min_rtt", Name: "min_rtt",
Help: "Lowest RTT measured on a connection in millisec", Help: "Lowest RTT measured on a connection in millisec",
}, },
@ -73,7 +120,7 @@ var (
latestRTT: prometheus.NewGaugeVec( latestRTT: prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Namespace: namespace, Namespace: namespace,
Subsystem: perspectiveString(logging.PerspectiveClient), Subsystem: "client",
Name: "latest_rtt", Name: "latest_rtt",
Help: "Latest RTT measured on a connection", Help: "Latest RTT measured on a connection",
}, },
@ -82,188 +129,37 @@ var (
smoothedRTT: prometheus.NewGaugeVec( smoothedRTT: prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Namespace: namespace, Namespace: namespace,
Subsystem: perspectiveString(logging.PerspectiveClient), Subsystem: "client",
Name: "smoothed_rtt", Name: "smoothed_rtt",
Help: "Calculated smoothed RTT measured on a connection in millisec", Help: "Calculated smoothed RTT measured on a connection in millisec",
}, },
clientConnLabels, clientConnLabels,
), ),
} }
// The server has many QUIC connections. Adding per connection label incurs high memory cost
serverMetrics = struct {
totalConnections prometheus.Counter
closedConnections prometheus.Counter
sentPackets prometheus.Counter
sentBytes prometheus.Counter
receivePackets prometheus.Counter
receiveBytes prometheus.Counter
bufferedPackets *prometheus.CounterVec
droppedPackets *prometheus.CounterVec
lostPackets *prometheus.CounterVec
rtt prometheus.Histogram
}{
totalConnections: prometheus.NewCounter(
totalConnectionsOpts(logging.PerspectiveServer),
),
closedConnections: prometheus.NewCounter(
closedConnectionsOpts(logging.PerspectiveServer),
),
sentPackets: prometheus.NewCounter(
sentPacketsOpts(logging.PerspectiveServer),
),
sentBytes: prometheus.NewCounter(
sentBytesOpts(logging.PerspectiveServer),
),
receivePackets: prometheus.NewCounter(
receivePacketsOpts(logging.PerspectiveServer),
),
receiveBytes: prometheus.NewCounter(
receiveBytesOpts(logging.PerspectiveServer),
),
bufferedPackets: prometheus.NewCounterVec(
bufferedPacketsOpts(logging.PerspectiveServer),
[]string{"packet_type"},
),
droppedPackets: prometheus.NewCounterVec(
droppedPacketsOpts(logging.PerspectiveServer),
[]string{"packet_type", "reason"},
),
lostPackets: prometheus.NewCounterVec(
lostPacketsOpts(logging.PerspectiveServer),
[]string{"reason"},
),
rtt: prometheus.NewHistogram(
prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: perspectiveString(logging.PerspectiveServer),
Name: "rtt",
Buckets: []float64{5, 10, 20, 30, 40, 50, 75, 100},
},
),
}
registerClient = sync.Once{} registerClient = sync.Once{}
registerServer = sync.Once{}
packetTooBigDropped = prometheus.NewCounter(prometheus.CounterOpts{ packetTooBigDropped = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: namespace, Namespace: namespace,
Subsystem: perspectiveString(logging.PerspectiveClient), Subsystem: "client",
Name: "packet_too_big_dropped", Name: "packet_too_big_dropped",
Help: "Count of packets received from origin that are too big to send to the edge and are dropped as a result", Help: "Count of packets received from origin that are too big to send to the edge and are dropped as a result",
}) })
) )
// MetricsCollector abstracts the difference between client and server metrics from connTracer
type MetricsCollector interface {
startedConnection()
closedConnection(err error)
sentPackets(logging.ByteCount)
receivedPackets(logging.ByteCount)
bufferedPackets(logging.PacketType)
droppedPackets(logging.PacketType, logging.ByteCount, logging.PacketDropReason)
lostPackets(logging.PacketLossReason)
updatedRTT(*logging.RTTStats)
}
func totalConnectionsOpts(p logging.Perspective) prometheus.CounterOpts {
var help string
if p == logging.PerspectiveClient {
help = "Number of connections initiated. For all quic metrics, client means the side initiating the connection"
} else {
help = "Number of connections accepted. For all quic metrics, server means the side accepting connections"
}
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "total_connections",
Help: help,
}
}
func closedConnectionsOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "closed_connections",
Help: "Number of connections that has been closed",
}
}
func sentPacketsOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "sent_packets",
Help: "Number of packets that have been sent through a connection",
}
}
func sentBytesOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "sent_bytes",
Help: "Number of bytes that have been sent through a connection",
}
}
func receivePacketsOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "receive_packets",
Help: "Number of packets that have been received through a connection",
}
}
func receiveBytesOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "receive_bytes",
Help: "Number of bytes that have been received through a connection",
}
}
func bufferedPacketsOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "buffered_packets",
Help: "Number of bytes that have been buffered on a connection",
}
}
func droppedPacketsOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "dropped_packets",
Help: "Number of bytes that have been dropped on a connection",
}
}
func lostPacketsOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "lost_packets",
Help: "Number of packets that have been lost from a connection",
}
}
type clientCollector struct { type clientCollector struct {
index string index string
} }
func newClientCollector(index uint8) MetricsCollector { func newClientCollector(index uint8) *clientCollector {
registerClient.Do(func() { registerClient.Do(func() {
prometheus.MustRegister( prometheus.MustRegister(
clientMetrics.totalConnections, clientMetrics.totalConnections,
clientMetrics.closedConnections, clientMetrics.closedConnections,
clientMetrics.sentPackets, clientMetrics.sentFrames,
clientMetrics.sentBytes, clientMetrics.sentBytes,
clientMetrics.receivePackets, clientMetrics.receivedFrames,
clientMetrics.receiveBytes, clientMetrics.receivedBytes,
clientMetrics.bufferedPackets, clientMetrics.bufferedPackets,
clientMetrics.droppedPackets, clientMetrics.droppedPackets,
clientMetrics.lostPackets, clientMetrics.lostPackets,
@ -286,14 +182,12 @@ func (cc *clientCollector) closedConnection(err error) {
clientMetrics.closedConnections.Inc() clientMetrics.closedConnections.Inc()
} }
func (cc *clientCollector) sentPackets(size logging.ByteCount) { func (cc *clientCollector) sentPackets(size logging.ByteCount, frames []logging.Frame) {
clientMetrics.sentPackets.WithLabelValues(cc.index).Inc() cc.collectPackets(size, frames, clientMetrics.sentFrames, clientMetrics.sentBytes)
clientMetrics.sentBytes.WithLabelValues(cc.index).Add(byteCountToPromCount(size))
} }
func (cc *clientCollector) receivedPackets(size logging.ByteCount) { func (cc *clientCollector) receivedPackets(size logging.ByteCount, frames []logging.Frame) {
clientMetrics.receivePackets.WithLabelValues(cc.index).Inc() cc.collectPackets(size, frames, clientMetrics.receivedFrames, clientMetrics.receivedBytes)
clientMetrics.receiveBytes.WithLabelValues(cc.index).Add(byteCountToPromCount(size))
} }
func (cc *clientCollector) bufferedPackets(packetType logging.PacketType) { func (cc *clientCollector) bufferedPackets(packetType logging.PacketType) {
@ -318,63 +212,18 @@ func (cc *clientCollector) updatedRTT(rtt *logging.RTTStats) {
clientMetrics.smoothedRTT.WithLabelValues(cc.index).Set(durationToPromGauge(rtt.SmoothedRTT())) clientMetrics.smoothedRTT.WithLabelValues(cc.index).Set(durationToPromGauge(rtt.SmoothedRTT()))
} }
type serverCollector struct{} func (cc *clientCollector) collectPackets(size logging.ByteCount, frames []logging.Frame, counter, bandwidth *prometheus.CounterVec) {
for _, frame := range frames {
func newServiceCollector() MetricsCollector { counter.WithLabelValues(cc.index, frameName(frame)).Inc()
registerServer.Do(func() { }
prometheus.MustRegister( bandwidth.WithLabelValues(cc.index).Add(byteCountToPromCount(size))
serverMetrics.totalConnections,
serverMetrics.closedConnections,
serverMetrics.sentPackets,
serverMetrics.sentBytes,
serverMetrics.receivePackets,
serverMetrics.receiveBytes,
serverMetrics.bufferedPackets,
serverMetrics.droppedPackets,
serverMetrics.lostPackets,
serverMetrics.rtt,
)
})
return &serverCollector{}
} }
func (sc *serverCollector) startedConnection() { func frameName(frame logging.Frame) string {
serverMetrics.totalConnections.Inc() if frame == nil {
} return "nil"
} else {
func (sc *serverCollector) closedConnection(err error) { name := reflect.TypeOf(frame).Elem().Name()
serverMetrics.closedConnections.Inc() return strings.TrimSuffix(name, "Frame")
}
func (sc *serverCollector) sentPackets(size logging.ByteCount) {
serverMetrics.sentPackets.Inc()
serverMetrics.sentBytes.Add(byteCountToPromCount(size))
}
func (sc *serverCollector) receivedPackets(size logging.ByteCount) {
serverMetrics.receivePackets.Inc()
serverMetrics.receiveBytes.Add(byteCountToPromCount(size))
}
func (sc *serverCollector) bufferedPackets(packetType logging.PacketType) {
serverMetrics.bufferedPackets.WithLabelValues(packetTypeString(packetType)).Inc()
}
func (sc *serverCollector) droppedPackets(packetType logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) {
serverMetrics.droppedPackets.WithLabelValues(
packetTypeString(packetType),
packetDropReasonString(reason),
).Add(byteCountToPromCount(size))
}
func (sc *serverCollector) lostPackets(reason logging.PacketLossReason) {
serverMetrics.lostPackets.WithLabelValues(packetLossReasonString(reason)).Inc()
}
func (sc *serverCollector) updatedRTT(rtt *logging.RTTStats) {
latestRTT := rtt.LatestRTT()
// May return 0 if no valid updates have occurred
if latestRTT > 0 {
serverMetrics.rtt.Observe(durationToPromGauge(latestRTT))
} }
} }

View File

@ -109,63 +109,6 @@ func TestConnectResponseMeta(t *testing.T) {
} }
} }
func TestUnregisterUdpSession(t *testing.T) {
unregisterMessage := "closed by eyeball"
var tests = []struct {
name string
sessionRPCServer mockSessionRPCServer
timeout time.Duration
}{
{
name: "UnregisterUdpSessionTimesout if the RPC server does not respond",
sessionRPCServer: mockSessionRPCServer{
sessionID: uuid.New(),
dstIP: net.IP{172, 16, 0, 1},
dstPort: 8000,
closeIdleAfter: testCloseIdleAfterHint,
unregisterMessage: unregisterMessage,
traceContext: "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1",
},
// very very low value so we trigger the timeout every time.
timeout: time.Nanosecond * 1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
logger := zerolog.Nop()
clientStream, serverStream := newMockRPCStreams()
sessionRegisteredChan := make(chan struct{})
go func() {
protocol, err := DetermineProtocol(serverStream)
assert.NoError(t, err)
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
assert.NoError(t, err)
err = rpcServerStream.Serve(test.sessionRPCServer, nil, &logger)
assert.NoError(t, err)
serverStream.Close()
close(sessionRegisteredChan)
}()
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, test.timeout, &logger)
assert.NoError(t, err)
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
assert.NoError(t, err)
assert.NoError(t, reg.Err)
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
rpcClientStream.Close()
<-sessionRegisteredChan
})
}
}
func TestRegisterUdpSession(t *testing.T) { func TestRegisterUdpSession(t *testing.T) {
unregisterMessage := "closed by eyeball" unregisterMessage := "closed by eyeball"

View File

@ -1,20 +1,33 @@
package quic package quic
import ( import (
"errors"
"net"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
) )
// The error that is throw by the writer when there is `no network activity`.
var idleTimeoutError = quic.IdleTimeoutError{}
type SafeStreamCloser struct { type SafeStreamCloser struct {
lock sync.Mutex lock sync.Mutex
stream quic.Stream stream quic.Stream
writeTimeout time.Duration
log *zerolog.Logger
closing atomic.Bool
} }
func NewSafeStreamCloser(stream quic.Stream) *SafeStreamCloser { func NewSafeStreamCloser(stream quic.Stream, writeTimeout time.Duration, log *zerolog.Logger) *SafeStreamCloser {
return &SafeStreamCloser{ return &SafeStreamCloser{
stream: stream, stream: stream,
writeTimeout: writeTimeout,
log: log,
} }
} }
@ -25,10 +38,43 @@ func (s *SafeStreamCloser) Read(p []byte) (n int, err error) {
func (s *SafeStreamCloser) Write(p []byte) (n int, err error) { func (s *SafeStreamCloser) Write(p []byte) (n int, err error) {
s.lock.Lock() s.lock.Lock()
defer s.lock.Unlock() defer s.lock.Unlock()
return s.stream.Write(p) if s.writeTimeout > 0 {
err = s.stream.SetWriteDeadline(time.Now().Add(s.writeTimeout))
if err != nil {
log.Err(err).Msg("Error setting write deadline for QUIC stream")
}
}
nBytes, err := s.stream.Write(p)
if err != nil {
s.handleWriteError(err)
}
return nBytes, err
}
// Handles the timeout error in case it happened, by canceling the stream write.
func (s *SafeStreamCloser) handleWriteError(err error) {
// If we are closing the stream we just ignore any write error.
if s.closing.Load() {
return
}
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() {
// We don't need to log if what cause the timeout was no network activity.
if !errors.Is(netErr, &idleTimeoutError) {
s.log.Error().Err(netErr).Msg("Closing quic stream due to timeout while writing")
}
// We need to explicitly cancel the write so that it frees all buffers.
s.stream.CancelWrite(0)
}
}
} }
func (s *SafeStreamCloser) Close() error { func (s *SafeStreamCloser) Close() error {
// Set this stream to a closing state.
s.closing.Store(true)
// Make sure a possible writer does not block the lock forever. We need it, so we can close the writer // Make sure a possible writer does not block the lock forever. We need it, so we can close the writer
// side of the stream safely. // side of the stream safely.
_ = s.stream.SetWriteDeadline(time.Now()) _ = s.stream.SetWriteDeadline(time.Now())

View File

@ -9,6 +9,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/rs/zerolog"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -70,7 +72,8 @@ func quicClient(t *testing.T, addr net.Addr) {
go func(iter int) { go func(iter int) {
defer wg.Done() defer wg.Done()
stream := NewSafeStreamCloser(quicStream) log := zerolog.Nop()
stream := NewSafeStreamCloser(quicStream, 30*time.Second, &log)
defer stream.Close() defer stream.Close()
// Do a bunch of round trips over this stream that should work. // Do a bunch of round trips over this stream that should work.
@ -107,7 +110,8 @@ func quicServer(t *testing.T, serverReady *sync.WaitGroup, conn net.PacketConn)
go func(iter int) { go func(iter int) {
defer wg.Done() defer wg.Done()
stream := NewSafeStreamCloser(quicStream) log := zerolog.Nop()
stream := NewSafeStreamCloser(quicStream, 30*time.Second, &log)
defer stream.Close() defer stream.Close()
// Do a bunch of round trips over this stream that should work. // Do a bunch of round trips over this stream that should work.

View File

@ -3,7 +3,6 @@ package quic
import ( import (
"context" "context"
"net" "net"
"time"
"github.com/quic-go/quic-go/logging" "github.com/quic-go/quic-go/logging"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -16,8 +15,6 @@ type tracer struct {
} }
type tracerConfig struct { type tracerConfig struct {
isClient bool
// Only client has an index
index uint8 index uint8
} }
@ -25,67 +22,36 @@ func NewClientTracer(logger *zerolog.Logger, index uint8) func(context.Context,
t := &tracer{ t := &tracer{
logger: logger, logger: logger,
config: &tracerConfig{ config: &tracerConfig{
isClient: true,
index: index, index: index,
}, },
} }
return t.TracerForConnection return t.TracerForConnection
} }
func NewServerTracer(logger *zerolog.Logger) *logging.Tracer {
return &logging.Tracer{
SentPacket: func(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {},
SentVersionNegotiationPacket: func(_ net.Addr, dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {},
DroppedPacket: func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) {},
}
}
func (t *tracer) TracerForConnection(_ctx context.Context, _p logging.Perspective, _odcid logging.ConnectionID) *logging.ConnectionTracer { func (t *tracer) TracerForConnection(_ctx context.Context, _p logging.Perspective, _odcid logging.ConnectionID) *logging.ConnectionTracer {
if t.config.isClient {
return newConnTracer(newClientCollector(t.config.index)) return newConnTracer(newClientCollector(t.config.index))
} }
return newConnTracer(newServiceCollector())
}
// connTracer collects connection level metrics // connTracer collects connection level metrics
type connTracer struct { type connTracer struct {
metricsCollector MetricsCollector metricsCollector *clientCollector
} }
func newConnTracer(metricsCollector MetricsCollector) *logging.ConnectionTracer { func newConnTracer(metricsCollector *clientCollector) *logging.ConnectionTracer {
tracer := connTracer{ tracer := connTracer{
metricsCollector: metricsCollector, metricsCollector: metricsCollector,
} }
return &logging.ConnectionTracer{ return &logging.ConnectionTracer{
StartedConnection: tracer.StartedConnection, StartedConnection: tracer.StartedConnection,
NegotiatedVersion: tracer.NegotiatedVersion,
ClosedConnection: tracer.ClosedConnection, ClosedConnection: tracer.ClosedConnection,
SentTransportParameters: tracer.SentTransportParameters,
ReceivedTransportParameters: tracer.ReceivedTransportParameters,
RestoredTransportParameters: tracer.RestoredTransportParameters,
SentLongHeaderPacket: tracer.SentLongHeaderPacket, SentLongHeaderPacket: tracer.SentLongHeaderPacket,
SentShortHeaderPacket: tracer.SentShortHeaderPacket, SentShortHeaderPacket: tracer.SentShortHeaderPacket,
ReceivedVersionNegotiationPacket: tracer.ReceivedVersionNegotiationPacket,
ReceivedRetry: tracer.ReceivedRetry,
ReceivedLongHeaderPacket: tracer.ReceivedLongHeaderPacket, ReceivedLongHeaderPacket: tracer.ReceivedLongHeaderPacket,
ReceivedShortHeaderPacket: tracer.ReceivedShortHeaderPacket, ReceivedShortHeaderPacket: tracer.ReceivedShortHeaderPacket,
BufferedPacket: tracer.BufferedPacket, BufferedPacket: tracer.BufferedPacket,
DroppedPacket: tracer.DroppedPacket, DroppedPacket: tracer.DroppedPacket,
UpdatedMetrics: tracer.UpdatedMetrics, UpdatedMetrics: tracer.UpdatedMetrics,
AcknowledgedPacket: tracer.AcknowledgedPacket,
LostPacket: tracer.LostPacket, LostPacket: tracer.LostPacket,
UpdatedCongestionState: tracer.UpdatedCongestionState,
UpdatedPTOCount: tracer.UpdatedPTOCount,
UpdatedKeyFromTLS: tracer.UpdatedKeyFromTLS,
UpdatedKey: tracer.UpdatedKey,
DroppedEncryptionLevel: tracer.DroppedEncryptionLevel,
DroppedKey: tracer.DroppedKey,
SetLossTimer: tracer.SetLossTimer,
LossTimerExpired: tracer.LossTimerExpired,
LossTimerCanceled: tracer.LossTimerCanceled,
ECNStateUpdated: tracer.ECNStateUpdated,
Close: tracer.Close,
Debug: tracer.Debug,
} }
} }
@ -97,14 +63,6 @@ func (ct *connTracer) ClosedConnection(err error) {
ct.metricsCollector.closedConnection(err) ct.metricsCollector.closedConnection(err)
} }
func (ct *connTracer) SentPacket(hdr *logging.ExtendedHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
ct.metricsCollector.sentPackets(packetSize)
}
func (ct *connTracer) ReceivedPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, frames []logging.Frame) {
ct.metricsCollector.receivedPackets(size)
}
func (ct *connTracer) BufferedPacket(pt logging.PacketType, size logging.ByteCount) { func (ct *connTracer) BufferedPacket(pt logging.PacketType, size logging.ByteCount) {
ct.metricsCollector.bufferedPackets(pt) ct.metricsCollector.bufferedPackets(pt)
} }
@ -121,74 +79,20 @@ func (ct *connTracer) UpdatedMetrics(rttStats *logging.RTTStats, cwnd, bytesInFl
ct.metricsCollector.updatedRTT(rttStats) ct.metricsCollector.updatedRTT(rttStats)
} }
func (ct *connTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
}
func (ct *connTracer) SentTransportParameters(parameters *logging.TransportParameters) {
}
func (ct *connTracer) ReceivedTransportParameters(parameters *logging.TransportParameters) {
}
func (ct *connTracer) RestoredTransportParameters(parameters *logging.TransportParameters) {
}
func (ct *connTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { func (ct *connTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) {
ct.metricsCollector.sentPackets(size, frames)
} }
func (ct *connTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) { func (ct *connTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) {
} ct.metricsCollector.sentPackets(size, frames)
func (ct *connTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
}
func (ct *connTracer) ReceivedRetry(header *logging.Header) {
} }
func (ct *connTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { func (ct *connTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) {
ct.metricsCollector.receivedPackets(size, frames)
} }
func (ct *connTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) { func (ct *connTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) {
} ct.metricsCollector.receivedPackets(size, frames)
func (ct *connTracer) AcknowledgedPacket(level logging.EncryptionLevel, number logging.PacketNumber) {
}
func (ct *connTracer) UpdatedCongestionState(state logging.CongestionState) {
}
func (ct *connTracer) UpdatedPTOCount(value uint32) {
}
func (ct *connTracer) UpdatedKeyFromTLS(level logging.EncryptionLevel, perspective logging.Perspective) {
}
func (ct *connTracer) UpdatedKey(generation logging.KeyPhase, remote bool) {
}
func (ct *connTracer) DroppedEncryptionLevel(level logging.EncryptionLevel) {
}
func (ct *connTracer) DroppedKey(generation logging.KeyPhase) {
}
func (ct *connTracer) SetLossTimer(timerType logging.TimerType, level logging.EncryptionLevel, time time.Time) {
}
func (ct *connTracer) LossTimerExpired(timerType logging.TimerType, level logging.EncryptionLevel) {
}
func (ct *connTracer) LossTimerCanceled() {
}
func (ct *connTracer) ECNStateUpdated(state logging.ECNState, trigger logging.ECNStateTrigger) {
}
func (ct *connTracer) Close() {
}
func (ct *connTracer) Debug(name, msg string) {
} }
type quicLogger struct { type quicLogger struct {

View File

@ -15,7 +15,8 @@ import (
"os" "os"
"time" "time"
"github.com/go-jose/go-jose/v3/jwt" "github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
homedir "github.com/mitchellh/go-homedir" homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors" "github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh" gossh "golang.org/x/crypto/ssh"
@ -51,6 +52,8 @@ type errorResponse struct {
var mockRequest func(url, contentType string, body io.Reader) (*http.Response, error) = nil var mockRequest func(url, contentType string, body io.Reader) (*http.Response, error) = nil
var signatureAlgs = []jose.SignatureAlgorithm{jose.RS256}
// GenerateShortLivedCertificate generates and stores a keypair for short lived certs // GenerateShortLivedCertificate generates and stores a keypair for short lived certs
func GenerateShortLivedCertificate(appURL *url.URL, token string) error { func GenerateShortLivedCertificate(appURL *url.URL, token string) error {
fullName, err := cfpath.GenerateSSHCertFilePathFromURL(appURL, keyName) fullName, err := cfpath.GenerateSSHCertFilePathFromURL(appURL, keyName)
@ -87,7 +90,7 @@ func SignCert(token, pubKey string) (string, error) {
return "", errors.New("invalid token") return "", errors.New("invalid token")
} }
parsedToken, err := jwt.ParseSigned(token) parsedToken, err := jwt.ParseSigned(token, signatureAlgs)
if err != nil { if err != nil {
return "", errors.Wrap(err, "failed to parse JWT") return "", errors.Wrap(err, "failed to parse JWT")
} }

View File

@ -3,6 +3,8 @@
package sshgen package sshgen
import ( import (
"crypto/rand"
"crypto/rsa"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -14,8 +16,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v3/jwt" "github.com/go-jose/go-jose/v4/jwt"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/config" "github.com/cloudflare/cloudflared/config"
@ -103,13 +105,16 @@ func tokenGenerator() string {
Expiry: jwt.NewNumericDate(exp), Expiry: jwt.NewNumericDate(exp),
} }
key := []byte("secret") key, err := rsa.GenerateKey(rand.Reader, 4096)
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, (&jose.SignerOptions{}).WithType("JWT")) if err != nil {
panic(err)
}
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: key}, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil { if err != nil {
panic(err) panic(err)
} }
signedToken, err := jwt.Signed(signer).Claims(claims).CompactSerialize() signedToken, err := jwt.Signed(signer).Claims(claims).Serialize()
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -66,6 +66,7 @@ type TunnelConfig struct {
PacketConfig *ingress.GlobalRouterConfig PacketConfig *ingress.GlobalRouterConfig
UDPUnregisterSessionTimeout time.Duration UDPUnregisterSessionTimeout time.Duration
WriteStreamTimeout time.Duration
DisableQUICPathMTUDiscovery bool DisableQUICPathMTUDiscovery bool
@ -614,6 +615,7 @@ func (e *EdgeTunnelServer) serveQUIC(
connLogger.Logger(), connLogger.Logger(),
e.config.PacketConfig, e.config.PacketConfig,
e.config.UDPUnregisterSessionTimeout, e.config.UDPUnregisterSessionTimeout,
e.config.WriteStreamTimeout,
) )
if err != nil { if err != nil {
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection") connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")

View File

@ -12,7 +12,7 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v4"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@ -32,6 +32,7 @@ const (
var ( var (
userAgent = "DEV" userAgent = "DEV"
signatureAlgs = []jose.SignatureAlgorithm{jose.RS256}
) )
type AppInfo struct { type AppInfo struct {
@ -415,7 +416,7 @@ func getTokenIfExists(path string) (*jose.JSONWebSignature, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
token, err := jose.ParseSigned(string(content)) token, err := jose.ParseSigned(string(content), signatureAlgs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,5 +1,7 @@
package oidc package oidc
import jose "github.com/go-jose/go-jose/v4"
// JOSE asymmetric signing algorithm values as defined by RFC 7518 // JOSE asymmetric signing algorithm values as defined by RFC 7518
// //
// see: https://tools.ietf.org/html/rfc7518#section-3.1 // see: https://tools.ietf.org/html/rfc7518#section-3.1
@ -15,3 +17,16 @@ const (
PS512 = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512 PS512 = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
EdDSA = "EdDSA" // Ed25519 using SHA-512 EdDSA = "EdDSA" // Ed25519 using SHA-512
) )
var allAlgs = []jose.SignatureAlgorithm{
jose.RS256,
jose.RS384,
jose.RS512,
jose.ES256,
jose.ES384,
jose.ES512,
jose.PS256,
jose.PS384,
jose.PS512,
jose.EdDSA,
}

View File

@ -8,12 +8,12 @@ import (
"crypto/rsa" "crypto/rsa"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"sync" "sync"
"time" "time"
jose "github.com/go-jose/go-jose/v3" jose "github.com/go-jose/go-jose/v4"
) )
// StaticKeySet is a verifier that validates JWT against a static set of public keys. // StaticKeySet is a verifier that validates JWT against a static set of public keys.
@ -25,7 +25,9 @@ type StaticKeySet struct {
// VerifySignature compares the signature against a static set of public keys. // VerifySignature compares the signature against a static set of public keys.
func (s *StaticKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) { func (s *StaticKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
jws, err := jose.ParseSigned(jwt) // Algorithms are already checked by Verifier, so this parse method accepts
// any algorithm.
jws, err := jose.ParseSigned(jwt, allAlgs)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing jwt: %v", err) return nil, fmt.Errorf("parsing jwt: %v", err)
} }
@ -127,8 +129,13 @@ var parsedJWTKey contextKey
func (r *RemoteKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) { func (r *RemoteKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
jws, ok := ctx.Value(parsedJWTKey).(*jose.JSONWebSignature) jws, ok := ctx.Value(parsedJWTKey).(*jose.JSONWebSignature)
if !ok { if !ok {
// The algorithm values are already enforced by the Validator, which also sets
// the context value above to pre-parsed signature.
//
// Practically, this codepath isn't called in normal use of this package, but
// if it is, the algorithms have already been checked.
var err error var err error
jws, err = jose.ParseSigned(jwt) jws, err = jose.ParseSigned(jwt, allAlgs)
if err != nil { if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err) return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
} }
@ -159,7 +166,7 @@ func (r *RemoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) (
// https://openid.net/specs/openid-connect-core-1_0.html#RotateSigKeys // https://openid.net/specs/openid-connect-core-1_0.html#RotateSigKeys
keys, err := r.keysFromRemote(ctx) keys, err := r.keysFromRemote(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("fetching keys %v", err) return nil, fmt.Errorf("fetching keys %w", err)
} }
for _, key := range keys { for _, key := range keys {
@ -228,11 +235,11 @@ func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) {
resp, err := doRequest(r.ctx, req) resp, err := doRequest(r.ctx, req)
if err != nil { if err != nil {
return nil, fmt.Errorf("oidc: get keys failed %v", err) return nil, fmt.Errorf("oidc: get keys failed %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err) return nil, fmt.Errorf("unable to read response body: %v", err)
} }

View File

@ -10,7 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"hash" "hash"
"io/ioutil" "io"
"mime" "mime"
"net/http" "net/http"
"strings" "strings"
@ -79,7 +79,7 @@ func getClient(ctx context.Context) *http.Client {
// provider, err := oidc.NewProvider(ctx, discoveryBaseURL) // provider, err := oidc.NewProvider(ctx, discoveryBaseURL)
// //
// This is insecure because validating the correct issuer is critical for multi-tenant // This is insecure because validating the correct issuer is critical for multi-tenant
// proivders. Any overrides here MUST be carefully reviewed. // providers. Any overrides here MUST be carefully reviewed.
func InsecureIssuerURLContext(ctx context.Context, issuerURL string) context.Context { func InsecureIssuerURLContext(ctx context.Context, issuerURL string) context.Context {
return context.WithValue(ctx, issuerURLKey, issuerURL) return context.WithValue(ctx, issuerURLKey, issuerURL)
} }
@ -97,6 +97,7 @@ type Provider struct {
issuer string issuer string
authURL string authURL string
tokenURL string tokenURL string
deviceAuthURL string
userInfoURL string userInfoURL string
jwksURL string jwksURL string
algorithms []string algorithms []string
@ -131,6 +132,7 @@ type providerJSON struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"` AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"` TokenURL string `json:"token_endpoint"`
DeviceAuthURL string `json:"device_authorization_endpoint"`
JWKSURL string `json:"jwks_uri"` JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"` UserInfoURL string `json:"userinfo_endpoint"`
Algorithms []string `json:"id_token_signing_alg_values_supported"` Algorithms []string `json:"id_token_signing_alg_values_supported"`
@ -165,6 +167,9 @@ type ProviderConfig struct {
// TokenURL is the endpoint used by the provider to support the OAuth 2.0 // TokenURL is the endpoint used by the provider to support the OAuth 2.0
// token endpoint. // token endpoint.
TokenURL string TokenURL string
// DeviceAuthURL is the endpoint used by the provider to support the OAuth 2.0
// device authorization endpoint.
DeviceAuthURL string
// UserInfoURL is the endpoint used by the provider to support the OpenID // UserInfoURL is the endpoint used by the provider to support the OpenID
// Connect UserInfo flow. // Connect UserInfo flow.
// //
@ -188,6 +193,7 @@ func (p *ProviderConfig) NewProvider(ctx context.Context) *Provider {
issuer: p.IssuerURL, issuer: p.IssuerURL,
authURL: p.AuthURL, authURL: p.AuthURL,
tokenURL: p.TokenURL, tokenURL: p.TokenURL,
deviceAuthURL: p.DeviceAuthURL,
userInfoURL: p.UserInfoURL, userInfoURL: p.UserInfoURL,
jwksURL: p.JWKSURL, jwksURL: p.JWKSURL,
algorithms: p.Algorithms, algorithms: p.Algorithms,
@ -211,7 +217,7 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err) return nil, fmt.Errorf("unable to read response body: %v", err)
} }
@ -243,6 +249,7 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
issuer: issuerURL, issuer: issuerURL,
authURL: p.AuthURL, authURL: p.AuthURL,
tokenURL: p.TokenURL, tokenURL: p.TokenURL,
deviceAuthURL: p.DeviceAuthURL,
userInfoURL: p.UserInfoURL, userInfoURL: p.UserInfoURL,
jwksURL: p.JWKSURL, jwksURL: p.JWKSURL,
algorithms: algs, algorithms: algs,
@ -273,7 +280,7 @@ func (p *Provider) Claims(v interface{}) error {
// Endpoint returns the OAuth2 auth and token endpoints for the given provider. // Endpoint returns the OAuth2 auth and token endpoints for the given provider.
func (p *Provider) Endpoint() oauth2.Endpoint { func (p *Provider) Endpoint() oauth2.Endpoint {
return oauth2.Endpoint{AuthURL: p.authURL, TokenURL: p.tokenURL} return oauth2.Endpoint{AuthURL: p.authURL, DeviceAuthURL: p.deviceAuthURL, TokenURL: p.tokenURL}
} }
// UserInfoEndpoint returns the OpenID Connect userinfo endpoint for the given // UserInfoEndpoint returns the OpenID Connect userinfo endpoint for the given
@ -332,7 +339,7 @@ func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource)
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -7,12 +7,12 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"strings" "strings"
"time" "time"
jose "github.com/go-jose/go-jose/v3" jose "github.com/go-jose/go-jose/v4"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -182,7 +182,7 @@ func resolveDistributedClaim(ctx context.Context, verifier *IDTokenVerifier, src
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err) return nil, fmt.Errorf("unable to read response body: %v", err)
} }
@ -310,7 +310,16 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
return t, nil return t, nil
} }
jws, err := jose.ParseSigned(rawIDToken) var supportedSigAlgs []jose.SignatureAlgorithm
for _, alg := range v.config.SupportedSigningAlgs {
supportedSigAlgs = append(supportedSigAlgs, jose.SignatureAlgorithm(alg))
}
if len(supportedSigAlgs) == 0 {
// If no algorithms were specified by both the config and discovery, default
// to the one mandatory algorithm "RS256".
supportedSigAlgs = []jose.SignatureAlgorithm{jose.RS256}
}
jws, err := jose.ParseSigned(rawIDToken, supportedSigAlgs)
if err != nil { if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err) return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
} }
@ -322,17 +331,7 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
default: default:
return nil, fmt.Errorf("oidc: multiple signatures on id token not supported") return nil, fmt.Errorf("oidc: multiple signatures on id token not supported")
} }
sig := jws.Signatures[0] sig := jws.Signatures[0]
supportedSigAlgs := v.config.SupportedSigningAlgs
if len(supportedSigAlgs) == 0 {
supportedSigAlgs = []string{RS256}
}
if !contains(supportedSigAlgs, sig.Header.Algorithm) {
return nil, fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm)
}
t.sigAlgorithm = sig.Header.Algorithm t.sigAlgorithm = sig.Header.Algorithm
ctx = context.WithValue(ctx, parsedJWTKey, jws) ctx = context.WithValue(ctx, parsedJWTKey, jws)

16
vendor/github.com/fortytw2/leaktest/.travis.yml generated vendored Normal file
View File

@ -0,0 +1,16 @@
language: go
go:
- 1.8
- 1.9
- "1.10"
- "1.11"
- tip
script:
- go test -v -race -parallel 5 -coverprofile=coverage.txt -covermode=atomic ./
- go test github.com/fortytw2/leaktest -run ^TestEmptyLeak$
before_install:
- pip install --user codecov
after_success:
- codecov

64
vendor/github.com/fortytw2/leaktest/README.md generated vendored Normal file
View File

@ -0,0 +1,64 @@
## Leaktest [![Build Status](https://travis-ci.org/fortytw2/leaktest.svg?branch=master)](https://travis-ci.org/fortytw2/leaktest) [![codecov](https://codecov.io/gh/fortytw2/leaktest/branch/master/graph/badge.svg)](https://codecov.io/gh/fortytw2/leaktest) [![Sourcegraph](https://sourcegraph.com/github.com/fortytw2/leaktest/-/badge.svg)](https://sourcegraph.com/github.com/fortytw2/leaktest?badge) [![Documentation](https://godoc.org/github.com/fortytw2/gpt?status.svg)](http://godoc.org/github.com/fortytw2/leaktest)
Refactored, tested variant of the goroutine leak detector found in both
`net/http` tests and the `cockroachdb` source tree.
Takes a snapshot of running goroutines at the start of a test, and at the end -
compares the two and _voila_. Ignores runtime/sys goroutines. Doesn't play nice
with `t.Parallel()` right now, but there are plans to do so.
### Installation
Go 1.7+
```
go get -u github.com/fortytw2/leaktest
```
Go 1.5/1.6 need to use the tag `v1.0.0`, as newer versions depend on
`context.Context`.
### Example
These tests fail, because they leak a goroutine
```go
// Default "Check" will poll for 5 seconds to check that all
// goroutines are cleaned up
func TestPool(t *testing.T) {
defer leaktest.Check(t)()
go func() {
for {
time.Sleep(time.Second)
}
}()
}
// Helper function to timeout after X duration
func TestPoolTimeout(t *testing.T) {
defer leaktest.CheckTimeout(t, time.Second)()
go func() {
for {
time.Sleep(time.Second)
}
}()
}
// Use Go 1.7+ context.Context for cancellation
func TestPoolContext(t *testing.T) {
ctx, _ := context.WithTimeout(context.Background(), time.Second)
defer leaktest.CheckContext(ctx, t)()
go func() {
for {
time.Sleep(time.Second)
}
}()
}
```
## LICENSE
Same BSD-style as Go, see LICENSE

153
vendor/github.com/fortytw2/leaktest/leaktest.go generated vendored Normal file
View File

@ -0,0 +1,153 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package leaktest provides tools to detect leaked goroutines in tests.
// To use it, call "defer leaktest.Check(t)()" at the beginning of each
// test that may use goroutines.
// copied out of the cockroachdb source tree with slight modifications to be
// more re-useable
package leaktest
import (
"context"
"fmt"
"runtime"
"sort"
"strconv"
"strings"
"time"
)
type goroutine struct {
id uint64
stack string
}
type goroutineByID []*goroutine
func (g goroutineByID) Len() int { return len(g) }
func (g goroutineByID) Less(i, j int) bool { return g[i].id < g[j].id }
func (g goroutineByID) Swap(i, j int) { g[i], g[j] = g[j], g[i] }
func interestingGoroutine(g string) (*goroutine, error) {
sl := strings.SplitN(g, "\n", 2)
if len(sl) != 2 {
return nil, fmt.Errorf("error parsing stack: %q", g)
}
stack := strings.TrimSpace(sl[1])
if strings.HasPrefix(stack, "testing.RunTests") {
return nil, nil
}
if stack == "" ||
// Ignore HTTP keep alives
strings.Contains(stack, ").readLoop(") ||
strings.Contains(stack, ").writeLoop(") ||
// Below are the stacks ignored by the upstream leaktest code.
strings.Contains(stack, "testing.Main(") ||
strings.Contains(stack, "testing.(*T).Run(") ||
strings.Contains(stack, "runtime.goexit") ||
strings.Contains(stack, "created by runtime.gc") ||
strings.Contains(stack, "interestingGoroutines") ||
strings.Contains(stack, "runtime.MHeap_Scavenger") ||
strings.Contains(stack, "signal.signal_recv") ||
strings.Contains(stack, "sigterm.handler") ||
strings.Contains(stack, "runtime_mcall") ||
strings.Contains(stack, "goroutine in C code") {
return nil, nil
}
// Parse the goroutine's ID from the header line.
h := strings.SplitN(sl[0], " ", 3)
if len(h) < 3 {
return nil, fmt.Errorf("error parsing stack header: %q", sl[0])
}
id, err := strconv.ParseUint(h[1], 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing goroutine id: %s", err)
}
return &goroutine{id: id, stack: strings.TrimSpace(g)}, nil
}
// interestingGoroutines returns all goroutines we care about for the purpose
// of leak checking. It excludes testing or runtime ones.
func interestingGoroutines(t ErrorReporter) []*goroutine {
buf := make([]byte, 2<<20)
buf = buf[:runtime.Stack(buf, true)]
var gs []*goroutine
for _, g := range strings.Split(string(buf), "\n\n") {
gr, err := interestingGoroutine(g)
if err != nil {
t.Errorf("leaktest: %s", err)
continue
} else if gr == nil {
continue
}
gs = append(gs, gr)
}
sort.Sort(goroutineByID(gs))
return gs
}
// ErrorReporter is a tiny subset of a testing.TB to make testing not such a
// massive pain
type ErrorReporter interface {
Errorf(format string, args ...interface{})
}
// Check snapshots the currently-running goroutines and returns a
// function to be run at the end of tests to see whether any
// goroutines leaked, waiting up to 5 seconds in error conditions
func Check(t ErrorReporter) func() {
return CheckTimeout(t, 5*time.Second)
}
// CheckTimeout is the same as Check, but with a configurable timeout
func CheckTimeout(t ErrorReporter, dur time.Duration) func() {
ctx, cancel := context.WithCancel(context.Background())
fn := CheckContext(ctx, t)
return func() {
timer := time.AfterFunc(dur, cancel)
fn()
// Remember to clean up the timer and context
timer.Stop()
cancel()
}
}
// CheckContext is the same as Check, but uses a context.Context for
// cancellation and timeout control
func CheckContext(ctx context.Context, t ErrorReporter) func() {
orig := map[uint64]bool{}
for _, g := range interestingGoroutines(t) {
orig[g.id] = true
}
return func() {
var leaked []string
for {
select {
case <-ctx.Done():
t.Errorf("leaktest: timed out checking goroutines")
default:
leaked = make([]string, 0)
for _, g := range interestingGoroutines(t) {
if !orig[g.id] {
leaked = append(leaked, g.stack)
}
}
if len(leaked) == 0 {
return
}
// don't spin needlessly
time.Sleep(time.Millisecond * 50)
continue
}
break
}
for _, g := range leaked {
t.Errorf("leaktest: leaked goroutine: %v", g)
}
}
}

View File

@ -1,10 +0,0 @@
Serious about security
======================
Square recognizes the important contributions the security research community
can make. We therefore encourage reporting security issues with the code
contained in this repository.
If you believe you have discovered a security vulnerability, please follow the
guidelines at <https://bugcrowd.com/squareopensource>.

View File

@ -1,133 +0,0 @@
/*-
* Copyright 2016 Zbigniew Mandziejewicz
* Copyright 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jwt
import (
"fmt"
"strings"
jose "github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/json"
)
// JSONWebToken represents a JSON Web Token (as specified in RFC7519).
type JSONWebToken struct {
payload func(k interface{}) ([]byte, error)
unverifiedPayload func() []byte
Headers []jose.Header
}
type NestedJSONWebToken struct {
enc *jose.JSONWebEncryption
Headers []jose.Header
}
// Claims deserializes a JSONWebToken into dest using the provided key.
func (t *JSONWebToken) Claims(key interface{}, dest ...interface{}) error {
b, err := t.payload(key)
if err != nil {
return err
}
for _, d := range dest {
if err := json.Unmarshal(b, d); err != nil {
return err
}
}
return nil
}
// UnsafeClaimsWithoutVerification deserializes the claims of a
// JSONWebToken into the dests. For signed JWTs, the claims are not
// verified. This function won't work for encrypted JWTs.
func (t *JSONWebToken) UnsafeClaimsWithoutVerification(dest ...interface{}) error {
if t.unverifiedPayload == nil {
return fmt.Errorf("go-jose/go-jose: Cannot get unverified claims")
}
claims := t.unverifiedPayload()
for _, d := range dest {
if err := json.Unmarshal(claims, d); err != nil {
return err
}
}
return nil
}
func (t *NestedJSONWebToken) Decrypt(decryptionKey interface{}) (*JSONWebToken, error) {
b, err := t.enc.Decrypt(decryptionKey)
if err != nil {
return nil, err
}
sig, err := ParseSigned(string(b))
if err != nil {
return nil, err
}
return sig, nil
}
// ParseSigned parses token from JWS form.
func ParseSigned(s string) (*JSONWebToken, error) {
sig, err := jose.ParseSigned(s)
if err != nil {
return nil, err
}
headers := make([]jose.Header, len(sig.Signatures))
for i, signature := range sig.Signatures {
headers[i] = signature.Header
}
return &JSONWebToken{
payload: sig.Verify,
unverifiedPayload: sig.UnsafePayloadWithoutVerification,
Headers: headers,
}, nil
}
// ParseEncrypted parses token from JWE form.
func ParseEncrypted(s string) (*JSONWebToken, error) {
enc, err := jose.ParseEncrypted(s)
if err != nil {
return nil, err
}
return &JSONWebToken{
payload: enc.Decrypt,
Headers: []jose.Header{enc.Header},
}, nil
}
// ParseSignedAndEncrypted parses signed-then-encrypted token from JWE form.
func ParseSignedAndEncrypted(s string) (*NestedJSONWebToken, error) {
enc, err := jose.ParseEncrypted(s)
if err != nil {
return nil, err
}
contentType, _ := enc.Header.ExtraHeaders[jose.HeaderContentType].(string)
if strings.ToUpper(contentType) != "JWT" {
return nil, ErrInvalidContentType
}
return &NestedJSONWebToken{
enc: enc,
Headers: []jose.Header{enc.Header},
}, nil
}

72
vendor/github.com/go-jose/go-jose/v4/CHANGELOG.md generated vendored Normal file
View File

@ -0,0 +1,72 @@
# v4.0.1
## Fixed
- An attacker could send a JWE containing compressed data that used large
amounts of memory and CPU when decompressed by `Decrypt` or `DecryptMulti`.
Those functions now return an error if the decompressed data would exceed
250kB or 10x the compressed size (whichever is larger). Thanks to
Enze Wang@Alioth and Jianjun Chen@Zhongguancun Lab (@zer0yu and @chenjj)
for reporting.
# v4.0.0
This release makes some breaking changes in order to more thoroughly
address the vulnerabilities discussed in [Three New Attacks Against JSON Web
Tokens][1], "Sign/encrypt confusion", "Billion hash attack", and "Polyglot
token".
## Changed
- Limit JWT encryption types (exclude password or public key types) (#78)
- Enforce minimum length for HMAC keys (#85)
- jwt: match any audience in a list, rather than requiring all audiences (#81)
- jwt: accept only Compact Serialization (#75)
- jws: Add expected algorithms for signatures (#74)
- Require specifying expected algorithms for ParseEncrypted,
ParseSigned, ParseDetached, jwt.ParseEncrypted, jwt.ParseSigned,
jwt.ParseSignedAndEncrypted (#69, #74)
- Usually there is a small, known set of appropriate algorithms for a program
to use and it's a mistake to allow unexpected algorithms. For instance the
"billion hash attack" relies in part on programs accepting the PBES2
encryption algorithm and doing the necessary work even if they weren't
specifically configured to allow PBES2.
- Revert "Strip padding off base64 strings" (#82)
- The specs require base64url encoding without padding.
- Minimum supported Go version is now 1.21
## Added
- ParseSignedCompact, ParseSignedJSON, ParseEncryptedCompact, ParseEncryptedJSON.
- These allow parsing a specific serialization, as opposed to ParseSigned and
ParseEncrypted, which try to automatically detect which serialization was
provided. It's common to require a specific serialization for a specific
protocol - for instance JWT requires Compact serialization.
[1]: https://i.blackhat.com/BH-US-23/Presentations/US-23-Tervoort-Three-New-Attacks-Against-JSON-Web-Tokens.pdf
# v3.0.2
## Fixed
- DecryptMulti: handle decompression error (#19)
## Changed
- jwe/CompactSerialize: improve performance (#67)
- Increase the default number of PBKDF2 iterations to 600k (#48)
- Return the proper algorithm for ECDSA keys (#45)
## Added
- Add Thumbprint support for opaque signers (#38)
# v3.0.1
## Fixed
- Security issue: an attacker specifying a large "p2c" value can cause
JSONWebEncryption.Decrypt and JSONWebEncryption.DecryptMulti to consume large
amounts of CPU, causing a DoS. Thanks to Matt Schwager (@mschwager) for the
disclosure and to Tom Tervoort for originally publishing the category of attack.
https://i.blackhat.com/BH-US-23/Presentations/US-23-Tervoort-Three-New-Attacks-Against-JSON-Web-Tokens.pdf

View File

@ -1,10 +1,9 @@
# Go JOSE # Go JOSE
[![godoc](http://img.shields.io/badge/godoc-jose_package-blue.svg?style=flat)](https://godoc.org/gopkg.in/go-jose/go-jose.v2) [![godoc](https://pkg.go.dev/badge/github.com/go-jose/go-jose/v4.svg)](https://pkg.go.dev/github.com/go-jose/go-jose/v4)
[![godoc](http://img.shields.io/badge/godoc-jwt_package-blue.svg?style=flat)](https://godoc.org/gopkg.in/go-jose/go-jose.v2/jwt) [![godoc](https://pkg.go.dev/badge/github.com/go-jose/go-jose/v4/jwt.svg)](https://pkg.go.dev/github.com/go-jose/go-jose/v4/jwt)
[![license](http://img.shields.io/badge/license-apache_2.0-blue.svg?style=flat)](https://raw.githubusercontent.com/go-jose/go-jose/master/LICENSE) [![license](https://img.shields.io/badge/license-apache_2.0-blue.svg?style=flat)](https://raw.githubusercontent.com/go-jose/go-jose/master/LICENSE)
[![build](https://travis-ci.org/go-jose/go-jose.svg?branch=master)](https://travis-ci.org/go-jose/go-jose) [![test](https://img.shields.io/github/checks-status/go-jose/go-jose/v4)](https://github.com/go-jose/go-jose/actions)
[![coverage](https://coveralls.io/repos/github/go-jose/go-jose/badge.svg?branch=master)](https://coveralls.io/r/go-jose/go-jose)
Package jose aims to provide an implementation of the Javascript Object Signing Package jose aims to provide an implementation of the Javascript Object Signing
and Encryption set of standards. This includes support for JSON Web Encryption, and Encryption set of standards. This includes support for JSON Web Encryption,
@ -21,13 +20,13 @@ US maintained blocked list.
## Overview ## Overview
The implementation follows the The implementation follows the
[JSON Web Encryption](http://dx.doi.org/10.17487/RFC7516) (RFC 7516), [JSON Web Encryption](https://dx.doi.org/10.17487/RFC7516) (RFC 7516),
[JSON Web Signature](http://dx.doi.org/10.17487/RFC7515) (RFC 7515), and [JSON Web Signature](https://dx.doi.org/10.17487/RFC7515) (RFC 7515), and
[JSON Web Token](http://dx.doi.org/10.17487/RFC7519) (RFC 7519) specifications. [JSON Web Token](https://dx.doi.org/10.17487/RFC7519) (RFC 7519) specifications.
Tables of supported algorithms are shown below. The library supports both Tables of supported algorithms are shown below. The library supports both
the compact and JWS/JWE JSON Serialization formats, and has optional support for the compact and JWS/JWE JSON Serialization formats, and has optional support for
multiple recipients. It also comes with a small command-line utility multiple recipients. It also comes with a small command-line utility
([`jose-util`](https://github.com/go-jose/go-jose/tree/master/jose-util)) ([`jose-util`](https://pkg.go.dev/github.com/go-jose/go-jose/jose-util))
for dealing with JOSE messages in a shell. for dealing with JOSE messages in a shell.
**Note**: We use a forked version of the `encoding/json` package from the Go **Note**: We use a forked version of the `encoding/json` package from the Go
@ -38,29 +37,22 @@ libraries in other languages.
### Versions ### Versions
[Version 2](https://gopkg.in/go-jose/go-jose.v2) [Version 4](https://github.com/go-jose/go-jose)
([branch](https://github.com/go-jose/go-jose/tree/v2), ([branch](https://github.com/go-jose/go-jose/tree/main),
[doc](https://godoc.org/gopkg.in/go-jose/go-jose.v2)) is the current stable version: [doc](https://pkg.go.dev/github.com/go-jose/go-jose/v4), [releases](https://github.com/go-jose/go-jose/releases)) is the current stable version:
import "gopkg.in/go-jose/go-jose.v2" import "github.com/go-jose/go-jose/v4"
[Version 3](https://github.com/go-jose/go-jose) The old [square/go-jose](https://github.com/square/go-jose) repo contains the prior v1 and v2 versions, which
([branch](https://github.com/go-jose/go-jose/tree/master), are still useable but not actively developed anymore.
[doc](https://godoc.org/github.com/go-jose/go-jose)) is the under development/unstable version (not released yet):
import "github.com/go-jose/go-jose/v3" Version 3, in this repo, is still receiving security fixes but not functionality
updates.
All new feature development takes place on the `master` branch, which we are
preparing to release as version 3 soon. Version 2 will continue to receive
critical bug and security fixes. Note that starting with version 3 we are
using Go modules for versioning instead of `gopkg.in` as before. Version 3 also will require Go version 1.13 or higher.
Version 1 (on the `v1` branch) is frozen and not supported anymore.
### Supported algorithms ### Supported algorithms
See below for a table of supported algorithms. Algorithm identifiers match See below for a table of supported algorithms. Algorithm identifiers match
the names in the [JSON Web Algorithms](http://dx.doi.org/10.17487/RFC7518) the names in the [JSON Web Algorithms](https://dx.doi.org/10.17487/RFC7518)
standard where possible. The Godoc reference has a list of constants. standard where possible. The Godoc reference has a list of constants.
Key encryption | Algorithm identifier(s) Key encryption | Algorithm identifier(s)
@ -103,20 +95,20 @@ allows attaching a key id.
Algorithm(s) | Corresponding types Algorithm(s) | Corresponding types
:------------------------- | ------------------------------- :------------------------- | -------------------------------
RSA | *[rsa.PublicKey](http://golang.org/pkg/crypto/rsa/#PublicKey), *[rsa.PrivateKey](http://golang.org/pkg/crypto/rsa/#PrivateKey) RSA | *[rsa.PublicKey](https://pkg.go.dev/crypto/rsa/#PublicKey), *[rsa.PrivateKey](https://pkg.go.dev/crypto/rsa/#PrivateKey)
ECDH, ECDSA | *[ecdsa.PublicKey](http://golang.org/pkg/crypto/ecdsa/#PublicKey), *[ecdsa.PrivateKey](http://golang.org/pkg/crypto/ecdsa/#PrivateKey) ECDH, ECDSA | *[ecdsa.PublicKey](https://pkg.go.dev/crypto/ecdsa/#PublicKey), *[ecdsa.PrivateKey](https://pkg.go.dev/crypto/ecdsa/#PrivateKey)
EdDSA<sup>1</sup> | [ed25519.PublicKey](https://godoc.org/pkg/crypto/ed25519#PublicKey), [ed25519.PrivateKey](https://godoc.org/pkg/crypto/ed25519#PrivateKey) EdDSA<sup>1</sup> | [ed25519.PublicKey](https://pkg.go.dev/crypto/ed25519#PublicKey), [ed25519.PrivateKey](https://pkg.go.dev/crypto/ed25519#PrivateKey)
AES, HMAC | []byte AES, HMAC | []byte
<sup>1. Only available in version 2 or later of the package</sup> <sup>1. Only available in version 2 or later of the package</sup>
## Examples ## Examples
[![godoc](http://img.shields.io/badge/godoc-jose_package-blue.svg?style=flat)](https://godoc.org/gopkg.in/go-jose/go-jose.v2) [![godoc](https://pkg.go.dev/badge/github.com/go-jose/go-jose/v4.svg)](https://pkg.go.dev/github.com/go-jose/go-jose/v4)
[![godoc](http://img.shields.io/badge/godoc-jwt_package-blue.svg?style=flat)](https://godoc.org/gopkg.in/go-jose/go-jose.v2/jwt) [![godoc](https://pkg.go.dev/badge/github.com/go-jose/go-jose/v4/jwt.svg)](https://pkg.go.dev/github.com/go-jose/go-jose/v4/jwt)
Examples can be found in the Godoc Examples can be found in the Godoc
reference for this package. The reference for this package. The
[`jose-util`](https://github.com/go-jose/go-jose/tree/master/jose-util) [`jose-util`](https://github.com/go-jose/go-jose/tree/v4/jose-util)
subdirectory also contains a small command-line utility which might be useful subdirectory also contains a small command-line utility which might be useful
as an example as well. as an example as well.

13
vendor/github.com/go-jose/go-jose/v4/SECURITY.md generated vendored Normal file
View File

@ -0,0 +1,13 @@
# Security Policy
This document explains how to contact the Let's Encrypt security team to report security vulnerabilities.
## Supported Versions
| Version | Supported |
| ------- | ----------|
| >= v3 | &check; |
| v2 | &cross; |
| v1 | &cross; |
## Reporting a vulnerability
Please see [https://letsencrypt.org/contact/#security](https://letsencrypt.org/contact/#security) for the email address to report a vulnerability. Ensure that the subject line for your report contains the word `vulnerability` and is descriptive. Your email should be acknowledged within 24 hours. If you do not receive a response within 24 hours, please follow-up again with another email.

View File

@ -29,8 +29,8 @@ import (
"fmt" "fmt"
"math/big" "math/big"
josecipher "github.com/go-jose/go-jose/v3/cipher" josecipher "github.com/go-jose/go-jose/v4/cipher"
"github.com/go-jose/go-jose/v3/json" "github.com/go-jose/go-jose/v4/json"
) )
// A generic RSA-based encrypter/verifier // A generic RSA-based encrypter/verifier
@ -285,6 +285,9 @@ func (ctx rsaDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm
switch alg { switch alg {
case RS256, RS384, RS512: case RS256, RS384, RS512:
// TODO(https://github.com/go-jose/go-jose/issues/40): As of go1.20, the
// random parameter is legacy and ignored, and it can be nil.
// https://cs.opensource.google/go/go/+/refs/tags/go1.20:src/crypto/rsa/pkcs1v15.go;l=263;bpv=0;bpt=1
out, err = rsa.SignPKCS1v15(RandReader, ctx.privateKey, hash, hashed) out, err = rsa.SignPKCS1v15(RandReader, ctx.privateKey, hash, hashed)
case PS256, PS384, PS512: case PS256, PS384, PS512:
out, err = rsa.SignPSS(RandReader, ctx.privateKey, hash, hashed, &rsa.PSSOptions{ out, err = rsa.SignPSS(RandReader, ctx.privateKey, hash, hashed, &rsa.PSSOptions{

View File

@ -21,9 +21,8 @@ import (
"crypto/rsa" "crypto/rsa"
"errors" "errors"
"fmt" "fmt"
"reflect"
"github.com/go-jose/go-jose/v3/json" "github.com/go-jose/go-jose/v4/json"
) )
// Encrypter represents an encrypter which produces an encrypted JWE object. // Encrypter represents an encrypter which produces an encrypted JWE object.
@ -76,14 +75,24 @@ type recipientKeyInfo struct {
type EncrypterOptions struct { type EncrypterOptions struct {
Compression CompressionAlgorithm Compression CompressionAlgorithm
// Optional map of additional keys to be inserted into the protected header // Optional map of name/value pairs to be inserted into the protected
// of a JWS object. Some specifications which make use of JWS like to insert // header of a JWS object. Some specifications which make use of
// additional values here. All values must be JSON-serializable. // JWS require additional values here.
//
// Values will be serialized by [json.Marshal] and must be valid inputs to
// that function.
//
// [json.Marshal]: https://pkg.go.dev/encoding/json#Marshal
ExtraHeaders map[HeaderKey]interface{} ExtraHeaders map[HeaderKey]interface{}
} }
// WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it // WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it
// if necessary. It returns itself and so can be used in a fluent style. // if necessary, and returns the updated EncrypterOptions.
//
// The v parameter will be serialized by [json.Marshal] and must be a valid
// input to that function.
//
// [json.Marshal]: https://pkg.go.dev/encoding/json#Marshal
func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions { func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions {
if eo.ExtraHeaders == nil { if eo.ExtraHeaders == nil {
eo.ExtraHeaders = map[HeaderKey]interface{}{} eo.ExtraHeaders = map[HeaderKey]interface{}{}
@ -112,6 +121,16 @@ func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions {
// be generated. // be generated.
type Recipient struct { type Recipient struct {
Algorithm KeyAlgorithm Algorithm KeyAlgorithm
// Key must have one of these types:
// - ed25519.PublicKey
// - *ecdsa.PublicKey
// - *rsa.PublicKey
// - *JSONWebKey
// - JSONWebKey
// - []byte (a symmetric key)
// - Any type that satisfies the OpaqueKeyEncrypter interface
//
// The type of Key must match the value of Algorithm.
Key interface{} Key interface{}
KeyID string KeyID string
PBES2Count int PBES2Count int
@ -150,16 +169,17 @@ func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions)
switch rcpt.Algorithm { switch rcpt.Algorithm {
case DIRECT: case DIRECT:
// Direct encryption mode must be treated differently // Direct encryption mode must be treated differently
if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) { keyBytes, ok := rawKey.([]byte)
if !ok {
return nil, ErrUnsupportedKeyType return nil, ErrUnsupportedKeyType
} }
if encrypter.cipher.keySize() != len(rawKey.([]byte)) { if encrypter.cipher.keySize() != len(keyBytes) {
return nil, ErrInvalidKeySize return nil, ErrInvalidKeySize
} }
encrypter.keyGenerator = staticKeyGenerator{ encrypter.keyGenerator = staticKeyGenerator{
key: rawKey.([]byte), key: keyBytes,
} }
recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte)) recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, keyBytes)
recipientInfo.keyID = keyID recipientInfo.keyID = keyID
if rcpt.KeyID != "" { if rcpt.KeyID != "" {
recipientInfo.keyID = rcpt.KeyID recipientInfo.keyID = rcpt.KeyID
@ -168,16 +188,16 @@ func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions)
return encrypter, nil return encrypter, nil
case ECDH_ES: case ECDH_ES:
// ECDH-ES (w/o key wrapping) is similar to DIRECT mode // ECDH-ES (w/o key wrapping) is similar to DIRECT mode
typeOf := reflect.TypeOf(rawKey) keyDSA, ok := rawKey.(*ecdsa.PublicKey)
if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) { if !ok {
return nil, ErrUnsupportedKeyType return nil, ErrUnsupportedKeyType
} }
encrypter.keyGenerator = ecKeyGenerator{ encrypter.keyGenerator = ecKeyGenerator{
size: encrypter.cipher.keySize(), size: encrypter.cipher.keySize(),
algID: string(enc), algID: string(enc),
publicKey: rawKey.(*ecdsa.PublicKey), publicKey: keyDSA,
} }
recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey)) recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, keyDSA)
recipientInfo.keyID = keyID recipientInfo.keyID = keyID
if rcpt.KeyID != "" { if rcpt.KeyID != "" {
recipientInfo.keyID = rcpt.KeyID recipientInfo.keyID = rcpt.KeyID
@ -270,9 +290,8 @@ func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKey
recipient, err := makeJWERecipient(alg, encryptionKey.Key) recipient, err := makeJWERecipient(alg, encryptionKey.Key)
recipient.keyID = encryptionKey.KeyID recipient.keyID = encryptionKey.KeyID
return recipient, err return recipient, err
} case OpaqueKeyEncrypter:
if encrypter, ok := encryptionKey.(OpaqueKeyEncrypter); ok { return newOpaqueKeyEncrypter(alg, encryptionKey)
return newOpaqueKeyEncrypter(alg, encrypter)
} }
return recipientKeyInfo{}, ErrUnsupportedKeyType return recipientKeyInfo{}, ErrUnsupportedKeyType
} }
@ -300,12 +319,12 @@ func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
return newDecrypter(decryptionKey.Key) return newDecrypter(decryptionKey.Key)
case *JSONWebKey: case *JSONWebKey:
return newDecrypter(decryptionKey.Key) return newDecrypter(decryptionKey.Key)
} case OpaqueKeyDecrypter:
if okd, ok := decryptionKey.(OpaqueKeyDecrypter); ok { return &opaqueKeyDecrypter{decrypter: decryptionKey}, nil
return &opaqueKeyDecrypter{decrypter: okd}, nil default:
}
return nil, ErrUnsupportedKeyType return nil, ErrUnsupportedKeyType
} }
}
// Implementation of encrypt method producing a JWE object. // Implementation of encrypt method producing a JWE object.
func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) { func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) {
@ -403,9 +422,27 @@ func (ctx *genericEncrypter) Options() EncrypterOptions {
} }
} }
// Decrypt and validate the object and return the plaintext. Note that this // Decrypt and validate the object and return the plaintext. This
// function does not support multi-recipient, if you desire multi-recipient // function does not support multi-recipient. If you desire multi-recipient
// decryption use DecryptMulti instead. // decryption use DecryptMulti instead.
//
// The decryptionKey argument must contain a private or symmetric key
// and must have one of these types:
// - *ecdsa.PrivateKey
// - *rsa.PrivateKey
// - *JSONWebKey
// - JSONWebKey
// - *JSONWebKeySet
// - JSONWebKeySet
// - []byte (a symmetric key)
// - string (a symmetric key)
// - Any type that satisfies the OpaqueKeyDecrypter interface.
//
// Note that ed25519 is only available for signatures, not encryption, so is
// not an option here.
//
// Automatically decompresses plaintext, but returns an error if the decompressed
// data would be >250kB or >10x the size of the compressed data, whichever is larger.
func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) { func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
headers := obj.mergedHeaders(nil) headers := obj.mergedHeaders(nil)
@ -462,15 +499,24 @@ func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error)
// The "zip" header parameter may only be present in the protected header. // The "zip" header parameter may only be present in the protected header.
if comp := obj.protected.getCompression(); comp != "" { if comp := obj.protected.getCompression(); comp != "" {
plaintext, err = decompress(comp, plaintext) plaintext, err = decompress(comp, plaintext)
if err != nil {
return nil, fmt.Errorf("go-jose/go-jose: failed to decompress plaintext: %v", err)
}
} }
return plaintext, err return plaintext, nil
} }
// DecryptMulti decrypts and validates the object and returns the plaintexts, // DecryptMulti decrypts and validates the object and returns the plaintexts,
// with support for multiple recipients. It returns the index of the recipient // with support for multiple recipients. It returns the index of the recipient
// for which the decryption was successful, the merged headers for that recipient, // for which the decryption was successful, the merged headers for that recipient,
// and the plaintext. // and the plaintext.
//
// The decryptionKey argument must have one of the types allowed for the
// decryptionKey argument of Decrypt().
//
// Automatically decompresses plaintext, but returns an error if the decompressed
// data would be >250kB or >3x the size of the compressed data, whichever is larger.
func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) { func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) {
globalHeaders := obj.mergedHeaders(nil) globalHeaders := obj.mergedHeaders(nil)
@ -532,7 +578,10 @@ func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Heade
// The "zip" header parameter may only be present in the protected header. // The "zip" header parameter may only be present in the protected header.
if comp := obj.protected.getCompression(); comp != "" { if comp := obj.protected.getCompression(); comp != "" {
plaintext, _ = decompress(comp, plaintext) plaintext, err = decompress(comp, plaintext)
if err != nil {
return -1, Header{}, nil, fmt.Errorf("go-jose/go-jose: failed to decompress plaintext: %v", err)
}
} }
sanitized, err := headers.sanitized() sanitized, err := headers.sanitized()

View File

@ -15,13 +15,11 @@
*/ */
/* /*
Package jose aims to provide an implementation of the Javascript Object Signing Package jose aims to provide an implementation of the Javascript Object Signing
and Encryption set of standards. It implements encryption and signing based on and Encryption set of standards. It implements encryption and signing based on
the JSON Web Encryption and JSON Web Signature standards, with optional JSON Web the JSON Web Encryption and JSON Web Signature standards, with optional JSON Web
Token support available in a sub-package. The library supports both the compact Token support available in a sub-package. The library supports both the compact
and JWS/JWE JSON Serialization formats, and has optional support for multiple and JWS/JWE JSON Serialization formats, and has optional support for multiple
recipients. recipients.
*/ */
package jose package jose

View File

@ -21,12 +21,13 @@ import (
"compress/flate" "compress/flate"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"math/big" "math/big"
"strings" "strings"
"unicode" "unicode"
"github.com/go-jose/go-jose/v3/json" "github.com/go-jose/go-jose/v4/json"
) )
// Helper function to serialize known-good objects. // Helper function to serialize known-good objects.
@ -85,7 +86,7 @@ func decompress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) {
} }
} }
// Compress with DEFLATE // deflate compresses the input.
func deflate(input []byte) ([]byte, error) { func deflate(input []byte) ([]byte, error) {
output := new(bytes.Buffer) output := new(bytes.Buffer)
@ -97,15 +98,24 @@ func deflate(input []byte) ([]byte, error) {
return output.Bytes(), err return output.Bytes(), err
} }
// Decompress with DEFLATE // inflate decompresses the input.
//
// Errors if the decompressed data would be >250kB or >10x the size of the
// compressed data, whichever is larger.
func inflate(input []byte) ([]byte, error) { func inflate(input []byte) ([]byte, error) {
output := new(bytes.Buffer) output := new(bytes.Buffer)
reader := flate.NewReader(bytes.NewBuffer(input)) reader := flate.NewReader(bytes.NewBuffer(input))
_, err := io.Copy(output, reader) maxCompressedSize := max(250_000, 10*int64(len(input)))
if err != nil {
limit := maxCompressedSize + 1
n, err := io.CopyN(output, reader, limit)
if err != nil && err != io.EOF {
return nil, err return nil, err
} }
if n == limit {
return nil, fmt.Errorf("uncompressed data would be too large (>%d bytes)", maxCompressedSize)
}
err = reader.Close() err = reader.Close()
return output.Bytes(), err return output.Bytes(), err
@ -154,7 +164,7 @@ func (b *byteBuffer) UnmarshalJSON(data []byte) error {
return nil return nil
} }
decoded, err := base64URLDecode(encoded) decoded, err := base64.RawURLEncoding.DecodeString(encoded)
if err != nil { if err != nil {
return err return err
} }
@ -184,8 +194,35 @@ func (b byteBuffer) toInt() int {
return int(b.bigInt().Int64()) return int(b.bigInt().Int64())
} }
// base64URLDecode is implemented as defined in https://www.rfc-editor.org/rfc/rfc7515.html#appendix-C func base64EncodeLen(sl []byte) int {
func base64URLDecode(value string) ([]byte, error) { return base64.RawURLEncoding.EncodedLen(len(sl))
value = strings.TrimRight(value, "=") }
return base64.RawURLEncoding.DecodeString(value)
func base64JoinWithDots(inputs ...[]byte) string {
if len(inputs) == 0 {
return ""
}
// Count of dots.
totalCount := len(inputs) - 1
for _, input := range inputs {
totalCount += base64EncodeLen(input)
}
out := make([]byte, totalCount)
startEncode := 0
for i, input := range inputs {
base64.RawURLEncoding.Encode(out[startEncode:], input)
if i == len(inputs)-1 {
continue
}
startEncode += base64EncodeLen(input)
out[startEncode] = '.'
startEncode++
}
return string(out)
} }

View File

@ -1,4 +1,4 @@
Copyright (c) 2009 The Go Authors. All rights reserved. Copyright (c) 2012 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are modification, are permitted provided that the following conditions are

View File

@ -75,14 +75,13 @@ import (
// //
// The JSON null value unmarshals into an interface, map, pointer, or slice // The JSON null value unmarshals into an interface, map, pointer, or slice
// by setting that Go value to nil. Because null is often used in JSON to mean // by setting that Go value to nil. Because null is often used in JSON to mean
// ``not present,'' unmarshaling a JSON null into any other Go type has no effect // “not present,” unmarshaling a JSON null into any other Go type has no effect
// on the value and produces no error. // on the value and produces no error.
// //
// When unmarshaling quoted strings, invalid UTF-8 or // When unmarshaling quoted strings, invalid UTF-8 or
// invalid UTF-16 surrogate pairs are not treated as an error. // invalid UTF-16 surrogate pairs are not treated as an error.
// Instead, they are replaced by the Unicode replacement // Instead, they are replaced by the Unicode replacement
// character U+FFFD. // character U+FFFD.
//
func Unmarshal(data []byte, v interface{}) error { func Unmarshal(data []byte, v interface{}) error {
// Check for well-formedness. // Check for well-formedness.
// Avoids filling out half a data structure // Avoids filling out half a data structure

View File

@ -58,6 +58,7 @@ import (
// becomes a member of the object unless // becomes a member of the object unless
// - the field's tag is "-", or // - the field's tag is "-", or
// - the field is empty and its tag specifies the "omitempty" option. // - the field is empty and its tag specifies the "omitempty" option.
//
// The empty values are false, 0, any // The empty values are false, 0, any
// nil pointer or interface value, and any array, slice, map, or string of // nil pointer or interface value, and any array, slice, map, or string of
// length zero. The object's default key string is the struct field name // length zero. The object's default key string is the struct field name
@ -133,7 +134,6 @@ import (
// JSON cannot represent cyclic data structures and Marshal does not // JSON cannot represent cyclic data structures and Marshal does not
// handle them. Passing cyclic structures to Marshal will result in // handle them. Passing cyclic structures to Marshal will result in
// an infinite recursion. // an infinite recursion.
//
func Marshal(v interface{}) ([]byte, error) { func Marshal(v interface{}) ([]byte, error) {
e := &encodeState{} e := &encodeState{}
err := e.marshal(v) err := e.marshal(v)

Some files were not shown because too many files have changed in this diff Show More