Merge branch 'cloudflare:master' into tunnel-health

This commit is contained in:
Mads Jon Nielsen 2024-04-23 08:08:04 +02:00 committed by GitHub
commit 6db3cb2f1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
300 changed files with 3724 additions and 16101 deletions

View File

@ -9,10 +9,10 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Install Go
uses: actions/setup-go@v3
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
- name: Test
run: make test

8
.teamcity/install-cloudflare-go.sh vendored Executable file
View File

@ -0,0 +1,8 @@
# !/usr/bin/env bash
cd /tmp
git clone -q https://github.com/cloudflare/go
cd go/src
# https://github.com/cloudflare/go/tree/34129e47042e214121b6bbff0ded4712debed18e is version go1.21.5-devel-cf
git checkout -q 34129e47042e214121b6bbff0ded4712debed18e
./make.bash

View File

@ -1,13 +1,8 @@
cd /tmp/
rm -rf go
rm -rf gocache
rm -rf /tmp/go
export GOCACHE=/tmp/gocache
rm -rf $GOCACHE
git clone -q https://github.com/cloudflare/go
cd go/src
# https://github.com/cloudflare/go/tree/34129e47042e214121b6bbff0ded4712debed18e is version go1.21.5-devel-cf
git checkout -q 34129e47042e214121b6bbff0ded4712debed18e
./make.bash
./.teamcity/install-cloudflare-go.sh
export PATH="/tmp/go/bin:$PATH"
go version

View File

@ -1,26 +0,0 @@
#!/bin/bash
set -euo pipefail
if ! VERSION="$(git describe --tags --exact-match 2>/dev/null)" ; then
echo "Skipping public release for an untagged commit."
echo "##teamcity[buildStatus status='SUCCESS' text='Skipped due to lack of tag']"
exit 0
fi
if [[ "${HOMEBREW_GITHUB_API_TOKEN:-}" == "" ]] ; then
echo "Missing GITHUB_API_TOKEN"
exit 1
fi
# "install" Homebrew
git clone https://github.com/Homebrew/brew tmp/homebrew
eval "$(tmp/homebrew/bin/brew shellenv)"
brew update --force --quiet
chmod -R go-w "$(brew --prefix)/share/zsh"
git config --global user.name "cloudflare-warp-bot"
git config --global user.email "warp-bot@cloudflare.com"
# bump formula pr
brew bump-formula-pr cloudflared --version="$VERSION" --no-browse --no-audit

View File

@ -1,66 +0,0 @@
#!/bin/bash
set -euo pipefail
FILENAME="${PWD}/artifacts/cloudflared-darwin-amd64.tgz"
if ! VERSION="$(git describe --tags --exact-match 2>/dev/null)" ; then
echo "Skipping public release for an untagged commit."
echo "##teamcity[buildStatus status='SUCCESS' text='Skipped due to lack of tag']"
exit 0
fi
if [[ ! -f "$FILENAME" ]] ; then
echo "Missing $FILENAME"
exit 1
fi
if [[ "${GITHUB_PRIVATE_KEY_B64:-}" == "" ]] ; then
echo "Missing GITHUB_PRIVATE_KEY_B64"
exit 1
fi
# upload to s3 bucket for use by Homebrew formula
s3cmd \
--acl-public --access_key="$AWS_ACCESS_KEY_ID" --secret_key="$AWS_SECRET_ACCESS_KEY" --host-bucket="%(bucket)s.s3.cfdata.org" \
put "$FILENAME" "s3://cftunnel-docs/dl/cloudflared-$VERSION-darwin-amd64.tgz"
s3cmd \
--acl-public --access_key="$AWS_ACCESS_KEY_ID" --secret_key="$AWS_SECRET_ACCESS_KEY" --host-bucket="%(bucket)s.s3.cfdata.org" \
cp "s3://cftunnel-docs/dl/cloudflared-$VERSION-darwin-amd64.tgz" "s3://cftunnel-docs/dl/cloudflared-stable-darwin-amd64.tgz"
SHA256=$(sha256sum "$FILENAME" | cut -b1-64)
# set up git (note that UserKnownHostsFile is an absolute path so we can cd wherever)
mkdir -p tmp
ssh-keyscan -t rsa github.com > tmp/github.txt
echo "$GITHUB_PRIVATE_KEY_B64" | base64 --decode > tmp/private.key
chmod 0400 tmp/private.key
export GIT_SSH_COMMAND="ssh -o UserKnownHostsFile=$PWD/tmp/github.txt -i $PWD/tmp/private.key -o IdentitiesOnly=yes"
# clone Homebrew repo into tmp/homebrew-cloudflare
git clone git@github.com:cloudflare/homebrew-cloudflare.git tmp/homebrew-cloudflare
cd tmp/homebrew-cloudflare
git checkout -f master
git reset --hard origin/master
# modify cloudflared.rb
URL="https://packages.argotunnel.com/dl/cloudflared-$VERSION-darwin-amd64.tgz"
tee cloudflared.rb <<EOF
class Cloudflared < Formula
desc 'Cloudflare Tunnel'
homepage 'https://developers.cloudflare.com/cloudflare-one/connections/connect-apps'
url '$URL'
sha256 '$SHA256'
version '$VERSION'
def install
bin.install 'cloudflared'
end
end
EOF
# push cloudflared.rb
git add cloudflared.rb
git diff
git config user.name "cloudflare-warp-bot"
git config user.email "warp-bot@cloudflare.com"
git commit -m "Release Cloudflare Tunnel $VERSION"
git push -v origin master

View File

@ -1,3 +1,7 @@
## 2024.2.1
### Notices
- Starting from this version, tunnel diagnostics will be enabled by default. This will allow the engineering team to remotely get diagnostics from cloudflared during debug activities. Users still have the capability to opt-out of this feature by defining `--management-diagnostics=false` (or env `TUNNEL_MANAGEMENT_DIAGNOSTICS`).
## 2023.9.0
### Notices
- The `warp-routing` `enabled: boolean` flag is no longer supported in the configuration file. Warp Routing traffic (eg TCP, UDP, ICMP) traffic is proxied to cloudflared if routes to the target tunnel are configured. This change does not affect remotely managed tunnels, but for locally managed tunnels, users that might be relying on this feature flag to block traffic should instead guarantee that tunnel has no Private Routes configured for the tunnel.

View File

@ -6,14 +6,16 @@ ENV GO111MODULE=on \
CGO_ENABLED=0 \
TARGET_GOOS=${TARGET_GOOS} \
TARGET_GOARCH=${TARGET_GOARCH}
WORKDIR /go/src/github.com/cloudflare/cloudflared/
# copy our sources into the builder image
COPY . .
RUN .teamcity/install-cloudflare-go.sh
# compile cloudflared
RUN make cloudflared
RUN PATH="/tmp/go/bin:$PATH" make cloudflared
# use a distroless base image with glibc
FROM gcr.io/distroless/base-debian11:nonroot

View File

@ -8,8 +8,10 @@ WORKDIR /go/src/github.com/cloudflare/cloudflared/
# copy our sources into the builder image
COPY . .
RUN .teamcity/install-cloudflare-go.sh
# compile cloudflared
RUN GOOS=linux GOARCH=amd64 make cloudflared
RUN GOOS=linux GOARCH=amd64 PATH="/tmp/go/bin:$PATH" make cloudflared
# use a distroless base image with glibc
FROM gcr.io/distroless/base-debian11:nonroot

View File

@ -8,8 +8,10 @@ WORKDIR /go/src/github.com/cloudflare/cloudflared/
# copy our sources into the builder image
COPY . .
RUN .teamcity/install-cloudflare-go.sh
# compile cloudflared
RUN GOOS=linux GOARCH=arm64 make cloudflared
RUN GOOS=linux GOARCH=arm64 PATH="/tmp/go/bin:$PATH" make cloudflared
# use a distroless base image with glibc
FROM gcr.io/distroless/base-debian11:nonroot-arm64

View File

@ -1,3 +1,6 @@
# The targets cannot be run in parallel
.NOTPARALLEL:
VERSION := $(shell git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*")
MSI_VERSION := $(shell git tag -l --sort=v:refname | grep "w" | tail -1 | cut -c2-)
#MSI_VERSION expects the format of the tag to be: (wX.X.X). Starts with the w character to not break cfsetup.
@ -49,6 +52,8 @@ PACKAGE_DIR := $(CURDIR)/packaging
PREFIX := /usr
INSTALL_BINDIR := $(PREFIX)/bin/
INSTALL_MANDIR := $(PREFIX)/share/man/man1/
CF_GO_PATH := /tmp/go
PATH := $(CF_GO_PATH)/bin:$(PATH)
LOCAL_ARCH ?= $(shell uname -m)
ifneq ($(GOARCH),)
@ -164,10 +169,19 @@ cover:
test-ssh-server:
docker-compose -f ssh_server_tests/docker-compose.yml up
.PHONY: install-go
install-go:
rm -rf ${CF_GO_PATH}
./.teamcity/install-cloudflare-go.sh
.PHONY: cleanup-go
cleanup-go:
rm -rf ${CF_GO_PATH}
cloudflared.1: cloudflared_man_template
sed -e 's/\$${VERSION}/$(VERSION)/; s/\$${DATE}/$(DATE)/' cloudflared_man_template > cloudflared.1
install: cloudflared cloudflared.1
install: install-go cloudflared cloudflared.1 cleanup-go
mkdir -p $(DESTDIR)$(INSTALL_BINDIR) $(DESTDIR)$(INSTALL_MANDIR)
install -m755 cloudflared $(DESTDIR)$(INSTALL_BINDIR)/cloudflared
install -m644 cloudflared.1 $(DESTDIR)$(INSTALL_MANDIR)/cloudflared.1
@ -209,15 +223,6 @@ cloudflared-darwin-amd64.tgz: cloudflared
tar czf cloudflared-darwin-amd64.tgz cloudflared
rm cloudflared
.PHONY: homebrew-upload
homebrew-upload: cloudflared-darwin-amd64.tgz
aws s3 --endpoint-url $(S3_ENDPOINT) cp --acl public-read $$^ $(S3_URI)/cloudflared-$$(VERSION)-$1.tgz
aws s3 --endpoint-url $(S3_ENDPOINT) cp --acl public-read $(S3_URI)/cloudflared-$$(VERSION)-$1.tgz $(S3_URI)/cloudflared-stable-$1.tgz
.PHONY: homebrew-release
homebrew-release: homebrew-upload
./publish-homebrew-formula.sh cloudflared-darwin-amd64.tgz $(VERSION) homebrew-cloudflare
.PHONY: github-release
github-release: cloudflared
python3 github_release.py --path $(EXECUTABLE_PATH) --release-version $(VERSION)

View File

@ -31,7 +31,7 @@ Downloads are available as standalone binaries, a Docker image, and Debian, RPM,
* Binaries, Debian, and RPM packages for Linux [can be found here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation#linux)
* A Docker image of `cloudflared` is [available on DockerHub](https://hub.docker.com/r/cloudflare/cloudflared)
* You can install on Windows machines with the [steps here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation#windows)
* Build from source with the [instructions here](https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup/installation#build-from-source)
* To build from source, first you need to download the go toolchain by running `./.teamcity/install-cloudflare-go.sh` and follow the output. Then you can run `make cloudflared`
User documentation for Cloudflare Tunnel can be found at https://developers.cloudflare.com/cloudflare-one/connections/connect-apps

View File

@ -1,3 +1,78 @@
2024.4.0
- 2024-04-02 feat: provide short version (#1206)
- 2024-04-02 Format code
- 2024-01-18 feat: auto tls sni
- 2023-12-24 fix checkInPingGroup bugs
- 2023-12-15 Add environment variables for TCP tunnel hostname / destination / URL.
2024.3.0
- 2024-03-14 TUN-8281: Run cloudflared query list tunnels/routes endpoint in a paginated way
- 2024-03-13 TUN-8297: Improve write timeout logging on safe_stream.go
- 2024-03-07 TUN-8290: Remove `|| true` from postrm.sh
- 2024-03-05 TUN-8275: Skip write timeout log on "no network activity"
- 2024-01-23 Update postrm.sh to fix incomplete uninstall
- 2024-01-05 fix typo in errcheck for response parsing logic in CreateTunnel routine
- 2023-12-23 Update linux_service.go
- 2023-12-07 ci: bump actions/checkout to v4
- 2023-12-07 ci/check: bump actions/setup-go to v5
- 2023-04-28 check.yaml: bump actions/setup-go to v4
2024.2.1
- 2024-02-20 TUN-8242: Update Changes.md file with new remote diagnostics behaviour
- 2024-02-19 TUN-8238: Fix type mismatch introduced by fast-forward
- 2024-02-16 TUN-8243: Collect metrics on the number of QUIC frames sent/received
- 2024-02-15 TUN-8238: Refactor proxy logging
- 2024-02-14 TUN-8242: Enable remote diagnostics by default
- 2024-02-12 TUN-8236: Add write timeout to quic and tcp connections
- 2024-02-09 TUN-8224: Fix safety of TCP stream logging, separate connect and ack log messages
2024.2.0
- 2024-02-07 TUN-8224: Count and collect metrics on stream connect successes/errors
2024.1.5
- 2024-01-22 TUN-8176: Support ARM platforms that don't have an FPU or have it enabled in kernel
- 2024-01-15 TUN-8158: Bring back commit e6537418859afcac29e56a39daa08bcabc09e048 and fixes infinite loop on linux when the socket is closed
2024.1.4
- 2024-01-19 Revert "TUN-8158: Add logging to confirm when ICMP reply is returned to the edge"
2024.1.3
- 2024-01-15 TUN-8161: Fix broken ARM build for armv6
- 2024-01-15 TUN-8158: Add logging to confirm when ICMP reply is returned to the edge
2024.1.2
- 2024-01-11 TUN-8147: Disable ECN usage due to bugs in detecting if supported
- 2024-01-11 TUN-8146: Fix export path for install-go command
- 2024-01-11 TUN-8146: Fix Makefile targets should not be run in parallel and install-go script was missing shebang
- 2024-01-10 TUN-8140: Remove homebrew scripts
2024.1.1
- 2024-01-10 TUN-8134: Revert installed prefix to /usr
- 2024-01-09 TUN-8130: Fix path to install go for mac build
- 2024-01-09 TUN-8129: Use the same build command between branch and release builds
- 2024-01-09 TUN-8130: Install go tool chain in /tmp on build agents
- 2024-01-09 TUN-8134: Install cloudflare go as part of make install
- 2024-01-08 TUN-8118: Disable FIPS module to build with go-boring without CGO_ENABLED
2024.1.0
- 2024-01-01 TUN-7934: Update quic-go to a version that queues datagrams for better throughput and drops large datagram
- 2023-12-20 TUN-8072: Need to set GOCACHE in mac go installation script
- 2023-12-17 TUN-8072: Add script to download cloudflare go for Mac build agents
- 2023-12-15 Fix nil pointer dereference segfault when passing "null" config json to cloudflared tunnel ingress validate (#1070)
- 2023-12-15 configuration.go: fix developerPortal link (#960)
- 2023-12-14 tunnelrpc/pogs: fix dropped test errors (#1106)
- 2023-12-14 cmd/cloudflared/updater: fix dropped error (#1055)
- 2023-12-14 use os.Executable to discover the path to cloudflared (#1040)
- 2023-12-14 Remove extraneous `period` from Path Environment Variable (#1009)
- 2023-12-14 Use CLI context when running tunnel (#597)
- 2023-12-14 TUN-8066: Define scripts to build on Windows agents
- 2023-12-11 TUN-8052: Update go to 1.21.5
- 2023-12-07 TUN-7970: Default to enable post quantum encryption for quic transport
- 2023-12-04 TUN-8006: Update quic-go to latest upstream
- 2023-11-15 VULN-44842 Add a flag that allows users to not send the Access JWT to stdout
- 2023-11-13 TUN-7965: Remove legacy incident status page check
- 2023-11-13 AUTH-5682 Org token flow in Access logins should pass CF_AppSession cookie
2023.10.0
- 2023-10-06 TUN-7864: Document cloudflared versions support
- 2023-10-03 CUSTESC-33731: Make rule match test report rule in 0-index base

View File

@ -1,3 +1,4 @@
#!/bin/bash
VERSION=$(git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*")
echo $VERSION

View File

@ -1,7 +1,9 @@
#!/bin/bash
VERSION=$(git describe --tags --always --match "[0-9][0-9][0-9][0-9].*.*")
echo $VERSION
# Avoid depending on C code since we don't need it.
# Disable FIPS module in go-boring
export GOEXPERIMENT=noboringcrypto
export CGO_ENABLED=0
# This controls the directory the built artifacts go into
@ -13,6 +15,12 @@ export TARGET_OS=linux
for arch in ${linuxArchs[@]}; do
unset TARGET_ARM
export TARGET_ARCH=$arch
## Support for arm platforms without hardware FPU enabled
if [[ $arch == arm ]] ; then
export TARGET_ARCH=arm
export TARGET_ARM=5
fi
## Support for armhf builds
if [[ $arch == armhf ]] ; then

View File

@ -109,20 +109,34 @@ func (r *RESTClient) sendRequest(method string, url url.URL, body interface{}) (
return r.client.Do(req)
}
func parseResponse(reader io.Reader, data interface{}) error {
func parseResponseEnvelope(reader io.Reader) (*response, error) {
// Schema for Tunnelstore responses in the v1 API.
// Roughly, it's a wrapper around a particular result that adds failures/errors/etc
var result response
// First, parse the wrapper and check the API call succeeded
if err := json.NewDecoder(reader).Decode(&result); err != nil {
return errors.Wrap(err, "failed to decode response")
return nil, errors.Wrap(err, "failed to decode response")
}
if err := result.checkErrors(); err != nil {
return err
return nil, err
}
if !result.Success {
return ErrAPINoSuccess
return nil, ErrAPINoSuccess
}
return &result, nil
}
func parseResponse(reader io.Reader, data interface{}) error {
result, err := parseResponseEnvelope(reader)
if err != nil {
return err
}
return parseResponseBody(result, data)
}
func parseResponseBody(result *response, data interface{}) error {
// At this point we know the API call succeeded, so, parse out the inner
// result into the datatype provided as a parameter.
if err := json.Unmarshal(result.Result, &data); err != nil {
@ -131,11 +145,58 @@ func parseResponse(reader io.Reader, data interface{}) error {
return nil
}
func fetchExhaustively[T any](requestFn func(int) (*http.Response, error)) ([]*T, error) {
page := 0
var fullResponse []*T
for {
page += 1
envelope, parsedBody, err := fetchPage[T](requestFn, page)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("Error Parsing page %d", page))
}
fullResponse = append(fullResponse, parsedBody...)
if envelope.Pagination.Count < envelope.Pagination.PerPage || len(fullResponse) >= envelope.Pagination.TotalCount {
break
}
}
return fullResponse, nil
}
func fetchPage[T any](requestFn func(int) (*http.Response, error), page int) (*response, []*T, error) {
pageResp, err := requestFn(page)
if err != nil {
return nil, nil, errors.Wrap(err, "REST request failed")
}
defer pageResp.Body.Close()
if pageResp.StatusCode == http.StatusOK {
envelope, err := parseResponseEnvelope(pageResp.Body)
if err != nil {
return nil, nil, err
}
var parsedRspBody []*T
return envelope, parsedRspBody, parseResponseBody(envelope, &parsedRspBody)
}
return nil, nil, errors.New(fmt.Sprintf("Failed to fetch page. Server returned: %d", pageResp.StatusCode))
}
type response struct {
Success bool `json:"success,omitempty"`
Errors []apiErr `json:"errors,omitempty"`
Messages []string `json:"messages,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Success bool `json:"success,omitempty"`
Errors []apiErr `json:"errors,omitempty"`
Messages []string `json:"messages,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Pagination Pagination `json:"result_info,omitempty"`
}
type Pagination struct {
Count int `json:"count,omitempty"`
Page int `json:"page,omitempty"`
PerPage int `json:"per_page,omitempty"`
TotalCount int `json:"total_count,omitempty"`
}
func (r *response) checkErrors() error {

View File

@ -137,20 +137,24 @@ type GetRouteByIpParams struct {
}
// ListRoutes calls the Tunnelstore GET endpoint for all routes under an account.
// Due to pagination on the server side it will call the endpoint multiple times if needed.
func (r *RESTClient) ListRoutes(filter *IpRouteFilter) ([]*DetailedRoute, error) {
endpoint := r.baseEndpoints.accountRoutes
endpoint.RawQuery = filter.Encode()
resp, err := r.sendRequest("GET", endpoint, nil)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
fetchFn := func(page int) (*http.Response, error) {
endpoint := r.baseEndpoints.accountRoutes
filter.Page(page)
endpoint.RawQuery = filter.Encode()
rsp, err := r.sendRequest("GET", endpoint, nil)
if resp.StatusCode == http.StatusOK {
return parseListDetailedRoutes(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
if rsp.StatusCode != http.StatusOK {
rsp.Body.Close()
return nil, r.statusCodeToError("list routes", rsp)
}
return rsp, nil
}
return nil, r.statusCodeToError("list routes", resp)
return fetchExhaustively[DetailedRoute](fetchFn)
}
// AddRoute calls the Tunnelstore POST endpoint for a given route.
@ -208,12 +212,6 @@ func (r *RESTClient) GetByIP(params GetRouteByIpParams) (DetailedRoute, error) {
return DetailedRoute{}, r.statusCodeToError("get route by IP", resp)
}
func parseListDetailedRoutes(body io.ReadCloser) ([]*DetailedRoute, error) {
var routes []*DetailedRoute
err := parseResponse(body, &routes)
return routes, err
}
func parseRoute(body io.ReadCloser) (Route, error) {
var route Route
err := parseResponse(body, &route)

View File

@ -167,6 +167,10 @@ func (f *IpRouteFilter) MaxFetchSize(max uint) {
f.queryParams.Set("per_page", strconv.Itoa(int(max)))
}
func (f *IpRouteFilter) Page(page int) {
f.queryParams.Set("page", strconv.Itoa(page))
}
func (f IpRouteFilter) Encode() string {
return f.queryParams.Encode()
}

View File

@ -93,7 +93,7 @@ func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*TunnelWith
switch resp.StatusCode {
case http.StatusOK:
var tunnel TunnelWithToken
if serdeErr := parseResponse(resp.Body, &tunnel); err != nil {
if serdeErr := parseResponse(resp.Body, &tunnel); serdeErr != nil {
return nil, serdeErr
}
return &tunnel, nil
@ -177,25 +177,22 @@ func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID, cascade bool) error {
}
func (r *RESTClient) ListTunnels(filter *TunnelFilter) ([]*Tunnel, error) {
endpoint := r.baseEndpoints.accountLevel
endpoint.RawQuery = filter.encode()
resp, err := r.sendRequest("GET", endpoint, nil)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return parseListTunnels(resp.Body)
fetchFn := func(page int) (*http.Response, error) {
endpoint := r.baseEndpoints.accountLevel
filter.Page(page)
endpoint.RawQuery = filter.encode()
rsp, err := r.sendRequest("GET", endpoint, nil)
if err != nil {
return nil, errors.Wrap(err, "REST request failed")
}
if rsp.StatusCode != http.StatusOK {
rsp.Body.Close()
return nil, r.statusCodeToError("list tunnels", rsp)
}
return rsp, nil
}
return nil, r.statusCodeToError("list tunnels", resp)
}
func parseListTunnels(body io.ReadCloser) ([]*Tunnel, error) {
var tunnels []*Tunnel
err := parseResponse(body, &tunnels)
return tunnels, err
return fetchExhaustively[Tunnel](fetchFn)
}
func (r *RESTClient) ListActiveClients(tunnelID uuid.UUID) ([]*ActiveClient, error) {

View File

@ -50,6 +50,10 @@ func (f *TunnelFilter) MaxFetchSize(max uint) {
f.queryParams.Set("per_page", strconv.Itoa(int(max)))
}
func (f *TunnelFilter) Page(page int) {
f.queryParams.Set("page", strconv.Itoa(page))
}
func (f TunnelFilter) encode() string {
return f.queryParams.Encode()
}

View File

@ -3,7 +3,6 @@ package cfapi
import (
"bytes"
"fmt"
"io"
"net"
"reflect"
"strings"
@ -16,52 +15,6 @@ import (
var loc, _ = time.LoadLocation("UTC")
func Test_parseListTunnels(t *testing.T) {
type args struct {
body string
}
tests := []struct {
name string
args args
want []*Tunnel
wantErr bool
}{
{
name: "empty list",
args: args{body: `{"success": true, "result": []}`},
want: []*Tunnel{},
},
{
name: "success is false",
args: args{body: `{"success": false, "result": []}`},
wantErr: true,
},
{
name: "errors are present",
args: args{body: `{"errors": [{"code": 1003, "message":"An A, AAAA or CNAME record already exists with that host"}], "result": []}`},
wantErr: true,
},
{
name: "invalid response",
args: args{body: `abc`},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body := io.NopCloser(bytes.NewReader([]byte(tt.args.body)))
got, err := parseListTunnels(body)
if (err != nil) != tt.wantErr {
t.Errorf("parseListTunnels() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseListTunnels() = %v, want %v", got, tt.want)
}
})
}
}
func Test_unmarshalTunnel(t *testing.T) {
type args struct {
body string

View File

@ -9,22 +9,30 @@ buster: &buster
- *pinned_go
- build-essential
- gotest-to-teamcity
- fakeroot
- rubygem-fpm
- rpm
- libffi-dev
- reprepro
- createrepo
pre-cache: &build_pre_cache
- export GOCACHE=/cfsetup_build/.cache/go-build
- go install golang.org/x/tools/cmd/goimports@latest
post-cache:
- export GOOS=linux
- export GOARCH=amd64
- make cloudflared
# TODO: TUN-8126 this is temporary to make sure packages can be built before release
- ./build-packages.sh
# Build binary for component test
- GOOS=linux GOARCH=amd64 make cloudflared
build-fips:
build_dir: *build_dir
builddeps: *build_deps
pre-cache: *build_pre_cache
post-cache:
- export GOOS=linux
- export GOARCH=amd64
- make cloudflared
- export FIPS=true
# TODO: TUN-8126 this is temporary to make sure packages can be built before release
- ./build-packages-fips.sh
# Build binary for component test
- GOOS=linux GOARCH=amd64 make cloudflared
cover:
build_dir: *build_dir
builddeps: *build_deps
@ -234,16 +242,6 @@ buster: &buster
- component-tests/requirements.txt
pre-cache: *component_test_pre_cache
post-cache: *component_test_post_cache
update-homebrew:
builddeps:
- openssh-client
- s3cmd
- jq
- build-essential
- procps
post-cache:
- .teamcity/update-homebrew.sh
- .teamcity/update-homebrew-core.sh
github-message-release:
build_dir: *build_dir
builddeps: *build_pygithub

View File

@ -132,15 +132,18 @@ func Commands() []*cli.Command {
Name: sshHostnameFlag,
Aliases: []string{"tunnel-host", "T"},
Usage: "specify the hostname of your application.",
EnvVars: []string{"TUNNEL_SERVICE_HOSTNAME"},
},
&cli.StringFlag{
Name: sshDestinationFlag,
Usage: "specify the destination address of your SSH server.",
Name: sshDestinationFlag,
Usage: "specify the destination address of your SSH server.",
EnvVars: []string{"TUNNEL_SERVICE_DESTINATION"},
},
&cli.StringFlag{
Name: sshURLFlag,
Aliases: []string{"listener", "L"},
Usage: "specify the host:port to forward data to Cloudflare edge.",
EnvVars: []string{"TUNNEL_SERVICE_URL"},
},
&cli.StringSliceFlag{
Name: sshHeaderFlag,

View File

@ -55,7 +55,8 @@ var systemdAllTemplates = map[string]ServiceTemplate{
Path: fmt.Sprintf("/etc/systemd/system/%s", cloudflaredService),
Content: `[Unit]
Description=cloudflared
After=network.target
After=network-online.target
Wants=network-online.target
[Service]
TimeoutStartSec=0
@ -72,7 +73,8 @@ WantedBy=multi-user.target
Path: fmt.Sprintf("/etc/systemd/system/%s", cloudflaredUpdateService),
Content: `[Unit]
Description=Update cloudflared
After=network.target
After=network-online.target
Wants=network-online.target
[Service]
ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 11 ]; then systemctl restart cloudflared; exit 0; fi; exit $code'

View File

@ -3,6 +3,7 @@ package main
import (
"fmt"
"math/rand"
"os"
"strings"
"time"
@ -49,6 +50,9 @@ var (
)
func main() {
// FIXME: TUN-8148: Disable QUIC_GO ECN due to bugs in proper detection if supported
os.Setenv("QUIC_GO_DISABLE_ECN", "1")
rand.Seed(time.Now().UnixNano())
metrics.RegisterBuildInfo(BuildType, BuildTime, Version)
maxprocs.Set()
@ -130,11 +134,22 @@ To determine if an update happened in a script, check for error code 11.`,
{
Name: "version",
Action: func(c *cli.Context) (err error) {
if c.Bool("short") {
fmt.Println(strings.Split(c.App.Version, " ")[0])
return nil
}
version(c)
return nil
},
Usage: versionText,
Description: versionText,
Flags: []cli.Flag{
&cli.BoolFlag{
Name: "short",
Aliases: []string{"s"},
Usage: "print just the version number",
},
},
},
}
cmds = append(cmds, tunnel.Commands()...)

View File

@ -81,6 +81,9 @@ const (
// udpUnregisterSessionTimeout is how long we wait before we stop trying to unregister a UDP session from the edge
udpUnregisterSessionTimeoutFlag = "udp-unregister-session-timeout"
// writeStreamTimeout sets if we should have a timeout when writing data to a stream towards the destination (edge/origin).
writeStreamTimeout = "write-stream-timeout"
// quicDisablePathMTUDiscovery sets if QUIC should not perform PTMU discovery and use a smaller (safe) packet size.
// Packets will then be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
// Note that this may result in packet drops for UDP proxying, since we expect being able to send at least 1280 bytes of inner packets.
@ -697,6 +700,13 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Value: 5 * time.Second,
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: writeStreamTimeout,
EnvVars: []string{"TUNNEL_STREAM_WRITE_TIMEOUT"},
Usage: "Use this option to add a stream write timeout for connections when writing towards the origin or edge. Default is 0 which disables the write timeout.",
Value: 0 * time.Second,
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: quicDisablePathMTUDiscovery,
EnvVars: []string{"TUNNEL_DISABLE_QUIC_PMTU"},
@ -781,7 +791,7 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
Name: "management-diagnostics",
Usage: "Enables the in-depth diagnostic routes to be made available over the management service (/debug/pprof, /metrics, etc.)",
EnvVars: []string{"TUNNEL_MANAGEMENT_DIAGNOSTICS"},
Value: false,
Value: true,
}),
selectProtocolFlag,
overwriteDNSFlag,

View File

@ -30,7 +30,10 @@ import (
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const secretValue = "*****"
const (
secretValue = "*****"
icmpFunnelTimeout = time.Second * 10
)
var (
developerPortal = "https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/install-and-setup"
@ -244,6 +247,7 @@ func prepareTunnelConfig(
FeatureSelector: featureSelector,
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag),
WriteStreamTimeout: c.Duration(writeStreamTimeout),
DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
}
packetConfig, err := newPacketConfig(c, log)
@ -256,6 +260,7 @@ func prepareTunnelConfig(
Ingress: &ingressRules,
WarpRouting: ingress.NewWarpRoutingConfig(&cfg.WarpRouting),
ConfigurationFlags: parseConfigFlags(c),
WriteTimeout: c.Duration(writeStreamTimeout),
}
return tunnelConfig, orchestratorConfig, nil
}
@ -361,7 +366,7 @@ func newPacketConfig(c *cli.Context, logger *zerolog.Logger) (*ingress.GlobalRou
logger.Info().Msgf("ICMP proxy will use %s as source for IPv6", ipv6Src)
}
icmpRouter, err := ingress.NewICMPRouter(ipv4Src, ipv6Src, zone, logger)
icmpRouter, err := ingress.NewICMPRouter(ipv4Src, ipv6Src, zone, logger, icmpFunnelTimeout)
if err != nil {
return nil, err
}

View File

@ -55,7 +55,7 @@ class TestManagement:
config = component_tests_config(cfd_mode=CfdModes.NAMED, run_proxy_dns=False, provide_ingress=False)
LOGGER.debug(config)
config_path = write_config(tmp_path, config.full_config)
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1", "--management-diagnostics"], new_process=True):
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], new_process=True):
wait_tunnel_ready(require_min_connections=1)
cfd_cli = CloudflaredCli(config, config_path, LOGGER)
url = cfd_cli.get_management_url("metrics", config, config_path)
@ -76,7 +76,7 @@ class TestManagement:
config = component_tests_config(cfd_mode=CfdModes.NAMED, run_proxy_dns=False, provide_ingress=False)
LOGGER.debug(config)
config_path = write_config(tmp_path, config.full_config)
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1", "--management-diagnostics"], new_process=True):
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], new_process=True):
wait_tunnel_ready(require_min_connections=1)
cfd_cli = CloudflaredCli(config, config_path, LOGGER)
url = cfd_cli.get_management_url("debug/pprof/heap", config, config_path)
@ -97,7 +97,7 @@ class TestManagement:
config = component_tests_config(cfd_mode=CfdModes.NAMED, run_proxy_dns=False, provide_ingress=False)
LOGGER.debug(config)
config_path = write_config(tmp_path, config.full_config)
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], new_process=True):
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1", "--management-diagnostics=false"], new_process=True):
wait_tunnel_ready(require_min_connections=1)
cfd_cli = CloudflaredCli(config, config_path, LOGGER)
url = cfd_cli.get_management_url("metrics", config, config_path)

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python
from conftest import CfdModes
from constants import METRICS_PORT
import time
from util import LOGGER, start_cloudflared, wait_tunnel_ready, get_quicktunnel_url, send_requests
class TestQuickTunnels:
@ -9,6 +10,7 @@ class TestQuickTunnels:
LOGGER.debug(config)
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], cfd_args=["--hello-world"], new_process=True):
wait_tunnel_ready(require_min_connections=1)
time.sleep(10)
url = get_quicktunnel_url()
send_requests(url, 3, True)
@ -17,6 +19,7 @@ class TestQuickTunnels:
LOGGER.debug(config)
with start_cloudflared(tmp_path, config, cfd_pre_args=["tunnel", "--ha-connections", "1"], cfd_args=["--url", f"http://localhost:{METRICS_PORT}/"], new_process=True):
wait_tunnel_ready(require_min_connections=1)
time.sleep(10)
url = get_quicktunnel_url()
send_requests(url+"/ready", 3, True)

View File

@ -205,6 +205,8 @@ type OriginRequestConfig struct {
HTTPHostHeader *string `yaml:"httpHostHeader" json:"httpHostHeader,omitempty"`
// Hostname on the origin server certificate.
OriginServerName *string `yaml:"originServerName" json:"originServerName,omitempty"`
// Auto configure the Hostname on the origin server certificate.
MatchSNIToHost *bool `yaml:"matchSNItoHost" json:"matchSNItoHost,omitempty"`
// Path to the CA for the certificate of your origin.
// This option should be used only if your certificate is not signed by Cloudflare.
CAPool *string `yaml:"caPool" json:"caPool,omitempty"`

View File

@ -66,6 +66,7 @@ type QUICConnection struct {
connIndex uint8
udpUnregisterTimeout time.Duration
streamWriteTimeout time.Duration
}
// NewQUICConnection returns a new instance of QUICConnection.
@ -82,6 +83,7 @@ func NewQUICConnection(
logger *zerolog.Logger,
packetRouterConfig *ingress.GlobalRouterConfig,
udpUnregisterTimeout time.Duration,
streamWriteTimeout time.Duration,
) (*QUICConnection, error) {
udpConn, err := createUDPConnForConnIndex(connIndex, localAddr, logger)
if err != nil {
@ -117,6 +119,7 @@ func NewQUICConnection(
connOptions: connOptions,
connIndex: connIndex,
udpUnregisterTimeout: udpUnregisterTimeout,
streamWriteTimeout: streamWriteTimeout,
}, nil
}
@ -195,7 +198,7 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
func (q *QUICConnection) runStream(quicStream quic.Stream) {
ctx := quicStream.Context()
stream := quicpogs.NewSafeStreamCloser(quicStream)
stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close()
// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
@ -321,6 +324,7 @@ func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.
session, err := q.sessionManager.RegisterSession(ctx, sessionID, originProxy)
if err != nil {
originProxy.Close()
log.Err(err).Str("sessionID", sessionID.String()).Msgf("Failed to register udp session")
tracing.EndWithErrorStatus(registerSpan, err)
return nil, err
@ -373,7 +377,7 @@ func (q *QUICConnection) closeUDPSession(ctx context.Context, sessionID uuid.UUI
return
}
stream := quicpogs.NewSafeStreamCloser(quicStream)
stream := quicpogs.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close()
rpcClientStream, err := quicpogs.NewRPCClientStream(ctx, stream, q.udpUnregisterTimeout, q.logger)
if err != nil {

View File

@ -35,6 +35,7 @@ var (
KeepAlivePeriod: 5 * time.Second,
EnableDatagrams: true,
}
defaultQUICTimeout = 30 * time.Second
)
var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
@ -197,7 +198,7 @@ func quicServer(
quicStream, err := session.OpenStreamSync(context.Background())
require.NoError(t, err)
stream := quicpogs.NewSafeStreamCloser(quicStream)
stream := quicpogs.NewSafeStreamCloser(quicStream, defaultQUICTimeout, &log)
reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream}
err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...)
@ -726,6 +727,7 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T, index uint8) *QU
&log,
nil,
5*time.Second,
0*time.Second,
)
require.NoError(t, err)
return qc

View File

@ -4,6 +4,7 @@ ENV GO111MODULE=on \
WORKDIR /go/src/github.com/cloudflare/cloudflared/
RUN apt-get update
COPY . .
RUN .teamcity/install-cloudflare-go.sh
# compile cloudflared
RUN make cloudflared
RUN PATH="/tmp/go/bin:$PATH" make cloudflared
RUN cp /go/src/github.com/cloudflare/cloudflared/cloudflared /usr/local/bin/

18
go.mod
View File

@ -4,14 +4,15 @@ go 1.21
require (
github.com/coredns/coredns v1.10.0
github.com/coreos/go-oidc/v3 v3.6.0
github.com/coreos/go-oidc/v3 v3.10.0
github.com/coreos/go-systemd/v22 v22.5.0
github.com/facebookgo/grace v0.0.0-20180706040059-75cf19382434
github.com/fortytw2/leaktest v1.3.0
github.com/fsnotify/fsnotify v1.4.9
github.com/getsentry/sentry-go v0.16.0
github.com/go-chi/chi/v5 v5.0.8
github.com/go-chi/cors v1.2.1
github.com/go-jose/go-jose/v3 v3.0.0
github.com/go-jose/go-jose/v4 v4.0.1
github.com/gobwas/ws v1.0.4
github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3
github.com/google/gopacket v1.1.19
@ -24,7 +25,7 @@ require (
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.13.0
github.com/prometheus/client_model v0.2.0
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55
github.com/quic-go/quic-go v0.42.0
github.com/rs/zerolog v1.20.0
github.com/stretchr/testify v1.8.4
github.com/urfave/cli/v2 v2.3.0
@ -35,11 +36,11 @@ require (
go.opentelemetry.io/otel/trace v1.21.0
go.opentelemetry.io/proto/otlp v1.0.0
go.uber.org/automaxprocs v1.4.0
golang.org/x/crypto v0.16.0
golang.org/x/net v0.19.0
golang.org/x/crypto v0.21.0
golang.org/x/net v0.21.0
golang.org/x/sync v0.4.0
golang.org/x/sys v0.15.0
golang.org/x/term v0.15.0
golang.org/x/sys v0.18.0
golang.org/x/term v0.18.0
google.golang.org/protobuf v1.31.0
gopkg.in/natefinch/lumberjack.v2 v2.0.0
gopkg.in/yaml.v3 v3.0.1
@ -81,10 +82,9 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/common v0.37.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect
github.com/quic-go/qtls-go1-20 v0.4.1 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
go.opentelemetry.io/otel/metric v1.21.0 // indirect
go.uber.org/mock v0.3.0 // indirect
go.uber.org/mock v0.4.0 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.11.0 // indirect
golang.org/x/oauth2 v0.13.0 // indirect

39
go.sum
View File

@ -60,8 +60,8 @@ github.com/coredns/caddy v1.1.1 h1:2eYKZT7i6yxIfGP3qLJoJ7HAsDJqYB+X68g4NYjSrE0=
github.com/coredns/caddy v1.1.1/go.mod h1:A6ntJQlAWuQfFlsd9hvigKbo2WS0VUs2l1e2F+BawD4=
github.com/coredns/coredns v1.10.0 h1:jCfuWsBjTs0dapkkhISfPCzn5LqvSRtrFtaf/Tjj4DI=
github.com/coredns/coredns v1.10.0/go.mod h1:CIfRU5TgpuoIiJBJ4XrofQzfFQpPFh32ERpUevrSlaw=
github.com/coreos/go-oidc/v3 v3.6.0 h1:AKVxfYw1Gmkn/w96z0DbT/B/xFnzTd3MkZvWLjF4n/o=
github.com/coreos/go-oidc/v3 v3.6.0/go.mod h1:ZpHUsHBucTUj6WOkrP4E20UPynbLZzhTQ1XKCXkxyPc=
github.com/coreos/go-oidc/v3 v3.10.0 h1:tDnXHnLyiTVyT/2zLDGj09pFPkhND8Gl8lnTRhoEaJU=
github.com/coreos/go-oidc/v3 v3.10.0/go.mod h1:5j11xcw0D3+SGxn6Z/WFADsgcWVMyNAlSQupk0KK3ac=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
@ -88,6 +88,8 @@ github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870 h1:E2s37DuLxFhQD
github.com/facebookgo/subset v0.0.0-20150612182917-8dac2c3c4870/go.mod h1:5tD+neXqOorC30/tWg0LCSkrqj/AR6gu8yY8/fpw1q0=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/getsentry/sentry-go v0.16.0 h1:owk+S+5XcgJLlGR/3+3s6N4d+uKwqYvh/eS0AIMjPWo=
@ -106,8 +108,8 @@ github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3Bop
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-jose/go-jose/v3 v3.0.0 h1:s6rrhirfEP/CGIoc6p+PZAeogN2SxKav6Wp7+dyMWVo=
github.com/go-jose/go-jose/v3 v3.0.0/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8=
github.com/go-jose/go-jose/v4 v4.0.1 h1:QVEPDE3OluqXBQZDcnNvQrInro2h0e4eqNbnZSWqS6U=
github.com/go-jose/go-jose/v4 v4.0.1/go.mod h1:WVf9LFMHh/QVrmqrOfqun0C45tMe3RoiKJMPvgWwLfY=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
@ -320,10 +322,8 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo=
github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4=
github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5nfFs=
github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k=
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55 h1:I4N3ZRnkZPbDN935Tg8QDf8fRpHp3bZ0U0/L42jBgNE=
github.com/quic-go/quic-go v0.40.1-0.20231203135336-87ef8ec48d55/go.mod h1:PeN7kuVJ4xZbxSv/4OX6S1USOX8MJvydwpTx31vx60c=
github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
@ -381,18 +381,17 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/automaxprocs v1.4.0 h1:CpDZl6aOlLhReez+8S3eEotD7Jx0Os++lemPlMULQP0=
go.uber.org/automaxprocs v1.4.0/go.mod h1:/mTEdr7LvHhs0v7mjdxDreTz1OG5zdZGqgOnhWiR/+Q=
go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo=
go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -464,8 +463,8 @@ golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -534,12 +533,12 @@ golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4=
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -553,6 +552,8 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=

View File

@ -32,6 +32,7 @@ const (
ProxyKeepAliveTimeoutFlag = "proxy-keepalive-timeout"
HTTPHostHeaderFlag = "http-host-header"
OriginServerNameFlag = "origin-server-name"
MatchSNIToHostFlag = "match-sni-to-host"
NoTLSVerifyFlag = "no-tls-verify"
NoChunkedEncodingFlag = "no-chunked-encoding"
ProxyAddressFlag = "proxy-address"
@ -118,6 +119,7 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
var keepAliveTimeout = defaultKeepAliveTimeout
var httpHostHeader string
var originServerName string
var matchSNItoHost bool
var caPool string
var noTLSVerify bool
var disableChunkedEncoding bool
@ -150,6 +152,9 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
if flag := OriginServerNameFlag; c.IsSet(flag) {
originServerName = c.String(flag)
}
if flag := MatchSNIToHostFlag; c.IsSet(flag) {
matchSNItoHost = c.Bool(flag)
}
if flag := tlsconfig.OriginCAPoolFlag; c.IsSet(flag) {
caPool = c.String(flag)
}
@ -185,6 +190,7 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
KeepAliveTimeout: keepAliveTimeout,
HTTPHostHeader: httpHostHeader,
OriginServerName: originServerName,
MatchSNIToHost: matchSNItoHost,
CAPool: caPool,
NoTLSVerify: noTLSVerify,
DisableChunkedEncoding: disableChunkedEncoding,
@ -229,6 +235,9 @@ func originRequestFromConfig(c config.OriginRequestConfig) OriginRequestConfig {
if c.OriginServerName != nil {
out.OriginServerName = *c.OriginServerName
}
if c.MatchSNIToHost != nil {
out.MatchSNIToHost = *c.MatchSNIToHost
}
if c.CAPool != nil {
out.CAPool = *c.CAPool
}
@ -287,6 +296,8 @@ type OriginRequestConfig struct {
HTTPHostHeader string `yaml:"httpHostHeader" json:"httpHostHeader"`
// Hostname on the origin server certificate.
OriginServerName string `yaml:"originServerName" json:"originServerName"`
// Auto configure the Hostname on the origin server certificate.
MatchSNIToHost bool `yaml:"matchSNItoHost" json:"matchSNItoHost"`
// Path to the CA for the certificate of your origin.
// This option should be used only if your certificate is not signed by Cloudflare.
CAPool string `yaml:"caPool" json:"caPool"`
@ -362,6 +373,12 @@ func (defaults *OriginRequestConfig) setOriginServerName(overrides config.Origin
}
}
func (defaults *OriginRequestConfig) setMatchSNIToHost(overrides config.OriginRequestConfig) {
if val := overrides.MatchSNIToHost; val != nil {
defaults.MatchSNIToHost = *val
}
}
func (defaults *OriginRequestConfig) setCAPool(overrides config.OriginRequestConfig) {
if val := overrides.CAPool; val != nil {
defaults.CAPool = *val
@ -447,6 +464,7 @@ func setConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfi
cfg.setTCPKeepAlive(overrides)
cfg.setHTTPHostHeader(overrides)
cfg.setOriginServerName(overrides)
cfg.setMatchSNIToHost(overrides)
cfg.setCAPool(overrides)
cfg.setNoTLSVerify(overrides)
cfg.setDisableChunkedEncoding(overrides)
@ -501,6 +519,7 @@ func ConvertToRawOriginConfig(c OriginRequestConfig) config.OriginRequestConfig
KeepAliveTimeout: keepAliveTimeout,
HTTPHostHeader: emptyStringToNil(c.HTTPHostHeader),
OriginServerName: emptyStringToNil(c.OriginServerName),
MatchSNIToHost: defaultBoolToNil(c.MatchSNIToHost),
CAPool: emptyStringToNil(c.CAPool),
NoTLSVerify: defaultBoolToNil(c.NoTLSVerify),
DisableChunkedEncoding: defaultBoolToNil(c.DisableChunkedEncoding),

View File

@ -0,0 +1,7 @@
package ingress
import "github.com/cloudflare/cloudflared/logger"
var (
TestLogger = logger.Create(nil)
)

View File

@ -131,7 +131,7 @@ func newICMPProxy(listenIP netip.Addr, zone string, logger *zerolog.Logger, idle
}
func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *packetResponder) error {
ctx, span := responder.requestSpan(ctx, pk)
_, span := responder.requestSpan(ctx, pk)
defer responder.exportSpan()
originalEcho, err := getICMPEcho(pk.Message)
@ -139,10 +139,8 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
tracing.EndWithErrorStatus(span, err)
return err
}
span.SetAttributes(
attribute.Int("originalEchoID", originalEcho.ID),
attribute.Int("seq", originalEcho.Seq),
)
observeICMPRequest(ip.logger, span, pk.Src.String(), pk.Dst.String(), originalEcho.ID, originalEcho.Seq)
echoIDTrackerKey := flow3Tuple{
srcIP: pk.Src,
dstIP: pk.Dst,
@ -189,6 +187,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
tracing.EndWithErrorStatus(span, err)
return err
}
err = icmpFlow.sendToDst(pk.Dst, pk.Message)
if err != nil {
tracing.EndWithErrorStatus(span, err)
@ -269,15 +268,12 @@ func (ip *icmpProxy) sendReply(ctx context.Context, reply *echoReply) error {
_, span := icmpFlow.responder.replySpan(ctx, ip.logger)
defer icmpFlow.responder.exportSpan()
span.SetAttributes(
attribute.String("dst", reply.from.String()),
attribute.Int("echoID", reply.echo.ID),
attribute.Int("seq", reply.echo.Seq),
attribute.Int("originalEchoID", icmpFlow.originalEchoID),
)
if err := icmpFlow.returnToSrc(reply); err != nil {
tracing.EndWithErrorStatus(span, err)
return err
}
observeICMPReply(ip.logger, span, reply.from.String(), reply.echo.ID, reply.echo.Seq)
span.SetAttributes(attribute.Int("originalEchoID", icmpFlow.originalEchoID))
tracing.End(span)
return nil
}

View File

@ -78,19 +78,19 @@ func checkInPingGroup() error {
if err != nil {
return err
}
groupID := os.Getgid()
groupID := uint64(os.Getegid())
// Example content: 999 59999
found := findGroupIDRegex.FindAll(file, 2)
if len(found) == 2 {
groupMin, err := strconv.ParseInt(string(found[0]), 10, 32)
groupMin, err := strconv.ParseUint(string(found[0]), 10, 32)
if err != nil {
return errors.Wrapf(err, "failed to determine minimum ping group ID")
}
groupMax, err := strconv.ParseInt(string(found[1]), 10, 32)
groupMax, err := strconv.ParseUint(string(found[1]), 10, 32)
if err != nil {
return errors.Wrapf(err, "failed to determine minimum ping group ID")
return errors.Wrapf(err, "failed to determine maximum ping group ID")
}
if groupID < int(groupMin) || groupID > int(groupMax) {
if groupID < groupMin || groupID > groupMax {
return fmt.Errorf("Group ID %d is not between ping group %d to %d", groupID, groupMin, groupMax)
}
return nil
@ -107,10 +107,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
tracing.EndWithErrorStatus(span, err)
return err
}
span.SetAttributes(
attribute.Int("originalEchoID", originalEcho.ID),
attribute.Int("seq", originalEcho.Seq),
)
observeICMPRequest(ip.logger, span, pk.Src.String(), pk.Dst.String(), originalEcho.ID, originalEcho.Seq)
shouldReplaceFunnelFunc := createShouldReplaceFunnelFunc(ip.logger, responder.datagramMuxer, pk, originalEcho.ID)
newFunnelFunc := func() (packet.Funnel, error) {
@ -156,14 +153,8 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
Int("originalEchoID", originalEcho.ID).
Msg("New flow")
go func() {
defer ip.srcFunnelTracker.Unregister(funnelID, icmpFlow)
if err := ip.listenResponse(ctx, icmpFlow); err != nil {
ip.logger.Debug().Err(err).
Str("src", pk.Src.String()).
Str("dst", pk.Dst.String()).
Int("originalEchoID", originalEcho.ID).
Msg("Failed to listen for ICMP echo response")
}
ip.listenResponse(ctx, icmpFlow)
ip.srcFunnelTracker.Unregister(funnelID, icmpFlow)
}()
}
if err := icmpFlow.sendToDst(pk.Dst, pk.Message); err != nil {
@ -179,17 +170,17 @@ func (ip *icmpProxy) Serve(ctx context.Context) error {
return ctx.Err()
}
func (ip *icmpProxy) listenResponse(ctx context.Context, flow *icmpEchoFlow) error {
func (ip *icmpProxy) listenResponse(ctx context.Context, flow *icmpEchoFlow) {
buf := make([]byte, mtu)
for {
retryable, err := ip.handleResponse(ctx, flow, buf)
if err != nil && !retryable {
return err
if done := ip.handleResponse(ctx, flow, buf); done {
return
}
}
}
func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf []byte) (retryableErr bool, err error) {
// Listens for ICMP response and handles error logging
func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf []byte) (done bool) {
_, span := flow.responder.replySpan(ctx, ip.logger)
defer flow.responder.exportSpan()
@ -199,33 +190,36 @@ func (ip *icmpProxy) handleResponse(ctx context.Context, flow *icmpEchoFlow, buf
n, from, err := flow.originConn.ReadFrom(buf)
if err != nil {
if flow.IsClosed() {
tracing.EndWithErrorStatus(span, fmt.Errorf("flow was closed"))
return true
}
ip.logger.Error().Err(err).Str("socket", flow.originConn.LocalAddr().String()).Msg("Failed to read from ICMP socket")
tracing.EndWithErrorStatus(span, err)
return false, err
return true
}
reply, err := parseReply(from, buf[:n])
if err != nil {
ip.logger.Error().Err(err).Str("dst", from.String()).Msg("Failed to parse ICMP reply")
tracing.EndWithErrorStatus(span, err)
return true, err
return false
}
if !isEchoReply(reply.msg) {
err := fmt.Errorf("Expect ICMP echo reply, got %s", reply.msg.Type)
ip.logger.Debug().Str("dst", from.String()).Msgf("Drop ICMP %s from reply", reply.msg.Type)
tracing.EndWithErrorStatus(span, err)
return true, err
return false
}
span.SetAttributes(
attribute.String("dst", reply.from.String()),
attribute.Int("echoID", reply.echo.ID),
attribute.Int("seq", reply.echo.Seq),
)
if err := flow.returnToSrc(reply); err != nil {
ip.logger.Debug().Err(err).Str("dst", from.String()).Msg("Failed to send ICMP reply")
ip.logger.Error().Err(err).Str("dst", from.String()).Msg("Failed to send ICMP reply")
tracing.EndWithErrorStatus(span, err)
return true, err
return false
}
observeICMPReply(ip.logger, span, from.String(), reply.echo.ID, reply.echo.Seq)
tracing.End(span)
return true, nil
return false
}
// Only linux uses flow3Tuple as FunnelID

View File

@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/netip"
"sync/atomic"
"github.com/google/gopacket/layers"
"github.com/rs/zerolog"
@ -46,6 +47,7 @@ type flow3Tuple struct {
type icmpEchoFlow struct {
*packet.ActivityTracker
closeCallback func() error
closed *atomic.Bool
src netip.Addr
originConn *icmp.PacketConn
responder *packetResponder
@ -59,6 +61,7 @@ func newICMPEchoFlow(src netip.Addr, closeCallback func() error, originConn *icm
return &icmpEchoFlow{
ActivityTracker: packet.NewActivityTracker(),
closeCallback: closeCallback,
closed: &atomic.Bool{},
src: src,
originConn: originConn,
responder: responder,
@ -86,9 +89,14 @@ func (ief *icmpEchoFlow) Equal(other packet.Funnel) bool {
}
func (ief *icmpEchoFlow) Close() error {
ief.closed.Store(true)
return ief.closeCallback()
}
func (ief *icmpEchoFlow) IsClosed() bool {
return ief.closed.Load()
}
// sendToDst rewrites the echo ID to the one assigned to this flow
func (ief *icmpEchoFlow) sendToDst(dst netip.Addr, msg *icmp.Message) error {
ief.UpdateLastActive()

View File

@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/fortytw2/leaktest"
"github.com/google/gopacket/layers"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
@ -18,6 +19,8 @@ import (
)
func TestFunnelIdleTimeout(t *testing.T) {
defer leaktest.Check(t)()
const (
idleTimeout = time.Second
echoID = 42573
@ -73,13 +76,16 @@ func TestFunnelIdleTimeout(t *testing.T) {
require.NoError(t, proxy.Request(ctx, &pk, &newResponder))
validateEchoFlow(t, <-newMuxer.cfdToEdge, &pk)
time.Sleep(idleTimeout * 2)
cancel()
<-proxyDone
}
func TestReuseFunnel(t *testing.T) {
defer leaktest.Check(t)()
const (
idleTimeout = time.Second
idleTimeout = time.Millisecond * 100
echoID = 42573
startSeq = 8129
)
@ -135,6 +141,8 @@ func TestReuseFunnel(t *testing.T) {
require.True(t, found)
require.Equal(t, funnel1, funnel2)
time.Sleep(idleTimeout * 2)
cancel()
<-proxyDone
}

View File

@ -281,10 +281,7 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
if err != nil {
return err
}
requestSpan.SetAttributes(
attribute.Int("originalEchoID", echo.ID),
attribute.Int("seq", echo.Seq),
)
observeICMPRequest(ip.logger, requestSpan, pk.Src.String(), pk.Dst.String(), echo.ID, echo.Seq)
resp, err := ip.icmpEchoRoundtrip(pk.Dst, echo)
if err != nil {
@ -296,17 +293,17 @@ func (ip *icmpProxy) Request(ctx context.Context, pk *packet.ICMP, responder *pa
responder.exportSpan()
_, replySpan := responder.replySpan(ctx, ip.logger)
replySpan.SetAttributes(
attribute.Int("originalEchoID", echo.ID),
attribute.Int("seq", echo.Seq),
attribute.Int64("rtt", int64(resp.rtt())),
attribute.String("status", resp.status().String()),
)
err = ip.handleEchoReply(pk, echo, resp, responder)
if err != nil {
ip.logger.Err(err).Msg("Failed to send ICMP reply")
tracing.EndWithErrorStatus(replySpan, err)
return errors.Wrap(err, "failed to handle ICMP echo reply")
}
observeICMPReply(ip.logger, replySpan, pk.Dst.String(), echo.ID, echo.Seq)
replySpan.SetAttributes(
attribute.Int64("rtt", int64(resp.rtt())),
attribute.String("status", resp.status().String()),
)
tracing.End(replySpan)
return nil
}

View File

@ -0,0 +1,126 @@
package middleware
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/json"
"fmt"
"net/http/httptest"
"testing"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
issuer = fmt.Sprintf(cloudflareAccessCertsURL, "testteam")
)
type accessTokenClaims struct {
Email string `json:"email"`
Type string `json:"type"`
jwt.Claims
}
func TestJWTValidator(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com", nil)
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
issued := time.Now()
claims := accessTokenClaims{
Email: "test@example.com",
Type: "app",
Claims: jwt.Claims{
Issuer: issuer,
Subject: "ee239b7a-e3e6-4173-972a-8fbe9d99c04f",
Audience: []string{""},
Expiry: jwt.NewNumericDate(issued.Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(issued),
},
}
token := signToken(t, claims, key)
req.Header.Add(headerKeyAccessJWTAssertion, token)
keySet := oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.Public()}}
config := &oidc.Config{
SkipClientIDCheck: true,
SupportedSigningAlgs: []string{string(jose.ES256)},
}
verifier := oidc.NewVerifier(issuer, &keySet, config)
tests := []struct {
name string
audTags []string
aud jwt.Audience
error bool
}{
{
name: "valid",
audTags: []string{
"0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
},
aud: jwt.Audience{"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba"},
error: false,
},
{
name: "invalid no match",
audTags: []string{
"0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
},
aud: jwt.Audience{"09dc377143841843ecca28b196bdb1ec1675af38c8b7b60c7def5876c8877157"},
error: true,
},
{
name: "invalid empty check",
audTags: []string{},
aud: jwt.Audience{"09dc377143841843ecca28b196bdb1ec1675af38c8b7b60c7def5876c8877157"},
error: true,
},
{
name: "invalid absent aud",
audTags: []string{
"0bc545634b1732494b3f9472794a549c883fabd48de9dfe0e0413e59c3f96c38",
"d7ec5b7fda23ffa8f8c8559fb37c66a2278208a78dbe376a3394b5ffec6911ba",
},
aud: jwt.Audience{""},
error: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
validator := JWTValidator{
IDTokenVerifier: verifier,
audTags: test.audTags,
}
claims.Audience = test.aud
token := signToken(t, claims, key)
req.Header.Set(headerKeyAccessJWTAssertion, token)
result, err := validator.Handle(context.Background(), req)
assert.NoError(t, err)
assert.Equal(t, test.error, result.ShouldFilterRequest)
})
}
}
func signToken(t *testing.T, token accessTokenClaims, key *ecdsa.PrivateKey) string {
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, &jose.SignerOptions{})
require.NoError(t, err)
payload, err := json.Marshal(token)
require.NoError(t, err)
jws, err := signer.Sign(payload)
require.NoError(t, err)
jwt, err := jws.CompactSerialize()
require.NoError(t, err)
return jwt
}

View File

@ -4,6 +4,7 @@ import (
"context"
"io"
"net"
"time"
"github.com/rs/zerolog"
@ -31,15 +32,32 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze
// tcpConnection is an OriginConnection that directly streams to raw TCP.
type tcpConnection struct {
conn net.Conn
net.Conn
writeTimeout time.Duration
logger *zerolog.Logger
}
func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
stream.Pipe(tunnelConn, tc.conn, log)
func (tc *tcpConnection) Stream(_ context.Context, tunnelConn io.ReadWriter, _ *zerolog.Logger) {
stream.Pipe(tunnelConn, tc, tc.logger)
}
func (tc *tcpConnection) Write(b []byte) (int, error) {
if tc.writeTimeout > 0 {
if err := tc.Conn.SetWriteDeadline(time.Now().Add(tc.writeTimeout)); err != nil {
tc.logger.Err(err).Msg("Error setting write deadline for TCP connection")
}
}
nBytes, err := tc.Conn.Write(b)
if err != nil {
tc.logger.Err(err).Msg("Error writing to the TCP connection")
}
return nBytes, err
}
func (tc *tcpConnection) Close() {
tc.conn.Close()
tc.Conn.Close()
}
// tcpOverWSConnection is an OriginConnection that streams to TCP over WS.

View File

@ -19,7 +19,6 @@ import (
"golang.org/x/net/proxy"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/stream"
"github.com/cloudflare/cloudflared/websocket"
@ -31,7 +30,6 @@ const (
)
var (
testLogger = logger.Create(nil)
testMessage = []byte("TestStreamOriginConnection")
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
)
@ -39,7 +37,8 @@ var (
func TestStreamTCPConnection(t *testing.T) {
cfdConn, originConn := net.Pipe()
tcpConn := tcpConnection{
conn: cfdConn,
Conn: cfdConn,
writeTimeout: 30 * time.Second,
}
eyeballConn, edgeConn := net.Pipe()
@ -66,7 +65,7 @@ func TestStreamTCPConnection(t *testing.T) {
return nil
})
tcpConn.Stream(ctx, edgeConn, testLogger)
tcpConn.Stream(ctx, edgeConn, TestLogger)
require.NoError(t, errGroup.Wait())
}
@ -93,7 +92,7 @@ func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
return nil
})
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
require.NoError(t, errGroup.Wait())
}
@ -147,7 +146,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
return nil
})
@ -159,7 +158,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
require.NoError(t, err)
defer wsForwarderInConn.Close()
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, TestLogger)
return nil
})
@ -209,7 +208,7 @@ func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
originConn.Close()
}()
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
tcpOverWSConn.Stream(ctx, eyeballConn, testLogger)
tcpOverWSConn.Stream(ctx, eyeballConn, TestLogger)
})
server := httptest.NewServer(handler)
defer server.Close()

View File

@ -7,6 +7,8 @@ import (
"time"
"github.com/rs/zerolog"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
@ -15,9 +17,7 @@ import (
)
const (
// funnelIdleTimeout controls how long to wait to close a funnel without send/return
funnelIdleTimeout = time.Second * 10
mtu = 1500
mtu = 1500
// icmpRequestTimeoutMs controls how long to wait for a reply
icmpRequestTimeoutMs = 1000
)
@ -32,8 +32,9 @@ type icmpRouter struct {
}
// NewICMPRouter doesn't return an error if either ipv4 proxy or ipv6 proxy can be created. The machine might only
// support one of them
func NewICMPRouter(ipv4Addr, ipv6Addr netip.Addr, ipv6Zone string, logger *zerolog.Logger) (*icmpRouter, error) {
// support one of them.
// funnelIdleTimeout controls how long to wait to close a funnel without send/return
func NewICMPRouter(ipv4Addr, ipv6Addr netip.Addr, ipv6Zone string, logger *zerolog.Logger, funnelIdleTimeout time.Duration) (*icmpRouter, error) {
ipv4Proxy, ipv4Err := newICMPProxy(ipv4Addr, "", logger, funnelIdleTimeout)
ipv6Proxy, ipv6Err := newICMPProxy(ipv6Addr, ipv6Zone, logger, funnelIdleTimeout)
if ipv4Err != nil && ipv6Err != nil {
@ -102,3 +103,25 @@ func getICMPEcho(msg *icmp.Message) (*icmp.Echo, error) {
func isEchoReply(msg *icmp.Message) bool {
return msg.Type == ipv4.ICMPTypeEchoReply || msg.Type == ipv6.ICMPTypeEchoReply
}
func observeICMPRequest(logger *zerolog.Logger, span trace.Span, src string, dst string, echoID int, seq int) {
logger.Debug().
Str("src", src).
Str("dst", dst).
Int("originalEchoID", echoID).
Int("originalEchoSeq", seq).
Msg("Received ICMP request")
span.SetAttributes(
attribute.Int("originalEchoID", echoID),
attribute.Int("seq", seq),
)
}
func observeICMPReply(logger *zerolog.Logger, span trace.Span, dst string, echoID int, seq int) {
logger.Debug().Str("dst", dst).Int("echoID", echoID).Int("seq", seq).Msg("Sent ICMP reply to edge")
span.SetAttributes(
attribute.String("dst", dst),
attribute.Int("echoID", echoID),
attribute.Int("seq", seq),
)
}

View File

@ -9,7 +9,9 @@ import (
"strings"
"sync"
"testing"
"time"
"github.com/fortytw2/leaktest"
"github.com/google/gopacket/layers"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
@ -23,9 +25,10 @@ import (
)
var (
noopLogger = zerolog.Nop()
localhostIP = netip.MustParseAddr("127.0.0.1")
localhostIPv6 = netip.MustParseAddr("::1")
noopLogger = zerolog.Nop()
localhostIP = netip.MustParseAddr("127.0.0.1")
localhostIPv6 = netip.MustParseAddr("::1")
testFunnelIdleTimeout = time.Millisecond * 10
)
// TestICMPProxyEcho makes sure we can send ICMP echo via the Request method and receives response via the
@ -40,12 +43,14 @@ func TestICMPRouterEcho(t *testing.T) {
}
func testICMPRouterEcho(t *testing.T, sendIPv4 bool) {
defer leaktest.Check(t)()
const (
echoID = 36571
endSeq = 20
)
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger)
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger, testFunnelIdleTimeout)
require.NoError(t, err)
proxyDone := make(chan struct{})
@ -97,14 +102,19 @@ func testICMPRouterEcho(t *testing.T, sendIPv4 bool) {
validateEchoFlow(t, <-muxer.cfdToEdge, &pk)
}
}
// Make sure funnel cleanup kicks in
time.Sleep(testFunnelIdleTimeout * 2)
cancel()
<-proxyDone
}
func TestTraceICMPRouterEcho(t *testing.T) {
defer leaktest.Check(t)()
tracingCtx := "ec31ad8a01fde11fdcabe2efdce36873:52726f6cabc144f5:0:1"
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger)
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger, testFunnelIdleTimeout)
require.NoError(t, err)
proxyDone := make(chan struct{})
@ -196,6 +206,7 @@ func TestTraceICMPRouterEcho(t *testing.T) {
default:
}
time.Sleep(testFunnelIdleTimeout * 2)
cancel()
<-proxyDone
}
@ -203,12 +214,14 @@ func TestTraceICMPRouterEcho(t *testing.T) {
// TestConcurrentRequests makes sure icmpRouter can send concurrent requests to the same destination with different
// echo ID. This simulates concurrent ping to the same destination.
func TestConcurrentRequestsToSameDst(t *testing.T) {
defer leaktest.Check(t)()
const (
concurrentPings = 5
endSeq = 5
)
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger)
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger, testFunnelIdleTimeout)
require.NoError(t, err)
proxyDone := make(chan struct{})
@ -282,12 +295,16 @@ func TestConcurrentRequestsToSameDst(t *testing.T) {
}()
}
wg.Wait()
time.Sleep(testFunnelIdleTimeout * 2)
cancel()
<-proxyDone
}
// TestICMPProxyRejectNotEcho makes sure it rejects messages other than echo
func TestICMPRouterRejectNotEcho(t *testing.T) {
defer leaktest.Check(t)()
msgs := []icmp.Message{
{
Type: ipv4.ICMPTypeDestinationUnreachable,
@ -341,7 +358,7 @@ func TestICMPRouterRejectNotEcho(t *testing.T) {
}
func testICMPRouterRejectNotEcho(t *testing.T, srcDstIP netip.Addr, msgs []icmp.Message) {
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger)
router, err := NewICMPRouter(localhostIP, localhostIPv6, "", &noopLogger, testFunnelIdleTimeout)
require.NoError(t, err)
muxer := newMockMuxer(1)

View File

@ -2,8 +2,12 @@ package ingress
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"github.com/rs/zerolog"
)
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
@ -14,7 +18,7 @@ type HTTPOriginProxy interface {
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
type StreamBasedOriginProxy interface {
EstablishConnection(ctx context.Context, dest string) (OriginConnection, error)
EstablishConnection(ctx context.Context, dest string, log *zerolog.Logger) (OriginConnection, error)
}
// HTTPLocalProxy can be implemented by cloudflared services that want to handle incoming http requests.
@ -46,9 +50,28 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("X-Forwarded-Host", req.Host)
req.Host = o.hostHeader
}
if o.matchSNIToHost {
o.SetOriginServerName(req)
}
return o.transport.RoundTrip(req)
}
func (o *httpService) SetOriginServerName(req *http.Request) {
o.transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := o.transport.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
return tls.Client(conn, &tls.Config{
RootCAs: o.transport.TLSClientConfig.RootCAs,
InsecureSkipVerify: o.transport.TLSClientConfig.InsecureSkipVerify,
ServerName: req.Host,
}), nil
}
}
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
if o.defaultResp {
o.log.Warn().Msgf(ErrNoIngressRulesCLI.Error())
@ -62,19 +85,21 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return resp, nil
}
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) {
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string, logger *zerolog.Logger) (OriginConnection, error) {
conn, err := o.dialer.DialContext(ctx, "tcp", dest)
if err != nil {
return nil, err
}
originConn := &tcpConnection{
conn: conn,
Conn: conn,
writeTimeout: o.writeTimeout,
logger: logger,
}
return originConn, nil
}
func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) {
func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string, _ *zerolog.Logger) (OriginConnection, error) {
var err error
if !o.isBastion {
dest = o.dest
@ -92,6 +117,6 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string)
}
func (o *socksProxyOverWSService) EstablishConnection(_ctx context.Context, _dest string) (OriginConnection, error) {
func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
return o.conn, nil
}

View File

@ -36,7 +36,7 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
require.NoError(t, err)
// Origin not listening for new connection, should return an error
_, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String())
_, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String(), TestLogger)
require.Error(t, err)
}
@ -87,7 +87,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
t.Run(test.testCase, func(t *testing.T) {
if test.expectErr {
bastionHost, _ := carrier.ResolveBastionDest(test.req)
_, err := test.service.EstablishConnection(context.Background(), bastionHost)
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
assert.Error(t, err)
}
})
@ -99,7 +99,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
// Origin not listening for new connection, should return an error
bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
_, err := service.EstablishConnection(context.Background(), bastionHost)
_, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
assert.Error(t, err)
}
}
@ -132,7 +132,7 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
url: originURL,
}
shutdownC := make(chan struct{})
require.NoError(t, httpService.start(testLogger, shutdownC, cfg))
require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
require.NoError(t, err)
@ -167,7 +167,7 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
url: originURL,
}
shutdownC := make(chan struct{})
require.NoError(t, httpService.start(testLogger, shutdownC, cfg))
require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
// Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
protos := []string{"https", "http", "dne"}

View File

@ -68,9 +68,10 @@ func (o unixSocketPath) MarshalJSON() ([]byte, error) {
}
type httpService struct {
url *url.URL
hostHeader string
transport *http.Transport
url *url.URL
hostHeader string
transport *http.Transport
matchSNIToHost bool
}
func (o *httpService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
@ -80,6 +81,7 @@ func (o *httpService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRe
}
o.hostHeader = cfg.HTTPHostHeader
o.transport = transport
o.matchSNIToHost = cfg.MatchSNIToHost
return nil
}
@ -94,15 +96,17 @@ func (o httpService) MarshalJSON() ([]byte, error) {
// rawTCPService dials TCP to the destination specified by the client
// It's used by warp routing
type rawTCPService struct {
name string
dialer net.Dialer
name string
dialer net.Dialer
writeTimeout time.Duration
logger *zerolog.Logger
}
func (o *rawTCPService) String() string {
return o.name
}
func (o *rawTCPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
func (o *rawTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, _ OriginRequestConfig) error {
return nil
}
@ -285,13 +289,14 @@ type WarpRoutingService struct {
Proxy StreamBasedOriginProxy
}
func NewWarpRoutingService(config WarpRoutingConfig) *WarpRoutingService {
func NewWarpRoutingService(config WarpRoutingConfig, writeTimeout time.Duration) *WarpRoutingService {
svc := &rawTCPService{
name: ServiceWarpRouting,
dialer: net.Dialer{
Timeout: config.ConnectTimeout.Duration,
KeepAlive: config.TCPKeepAlive.Duration,
},
writeTimeout: writeTimeout,
}
return &WarpRoutingService{Proxy: svc}

View File

@ -204,25 +204,25 @@ func TestMarshalJSON(t *testing.T) {
{
name: "Nil",
path: nil,
expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
want: true,
},
{
name: "Nil regex",
path: &Regexp{Regexp: nil},
expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
want: true,
},
{
name: "Empty",
path: &Regexp{Regexp: regexp.MustCompile("")},
expected: `{"hostname":"example.com","path":"","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
expected: `{"hostname":"example.com","path":"","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
want: true,
},
{
name: "Basic",
path: &Regexp{Regexp: regexp.MustCompile("/echo")},
expected: `{"hostname":"example.com","path":"/echo","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
expected: `{"hostname":"example.com","path":"/echo","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`,
want: true,
},
}

View File

@ -3,7 +3,8 @@ package management
import (
"fmt"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
)
type managementTokenClaims struct {
@ -37,7 +38,7 @@ func (t *actor) verify() bool {
}
func parseToken(token string) (*managementTokenClaims, error) {
jwt, err := jwt.ParseSigned(token)
jwt, err := jwt.ParseSigned(token, []jose.SignatureAlgorithm{jose.ES256})
if err != nil {
return nil, fmt.Errorf("malformed jwt: %v", err)
}

View File

@ -7,7 +7,7 @@ import (
"errors"
"testing"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v4"
"github.com/stretchr/testify/require"
)

View File

@ -2,6 +2,7 @@ package orchestration
import (
"encoding/json"
"time"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/ingress"
@ -19,8 +20,9 @@ type newLocalConfig struct {
// Config is the original config as read and parsed by cloudflared.
type Config struct {
Ingress *ingress.Ingress
WarpRouting ingress.WarpRoutingConfig
Ingress *ingress.Ingress
WarpRouting ingress.WarpRoutingConfig
WriteTimeout time.Duration
// Extra settings used to configure this instance but that are not eligible for remotely management
// ie. (--protocol, --loglevel, ...)

View File

@ -17,10 +17,10 @@ import (
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// Orchestrator manages configurations so they can be updatable during runtime
// Orchestrator manages configurations, so they can be updatable during runtime
// properties are static, so it can be read without lock
// currentVersion and config are read/write infrequently, so their access are synchronized with RWMutex
// access to proxy is synchronized with atmoic.Value, because it uses copy-on-write to provide scalable frequently
// access to proxy is synchronized with atomic.Value, because it uses copy-on-write to provide scalable frequently
// read when update is infrequent
type Orchestrator struct {
currentVersion int32
@ -30,9 +30,10 @@ type Orchestrator struct {
proxy atomic.Value
// Set of internal ingress rules defined at cloudflared startup (separate from user-defined ingress rules)
internalRules []ingress.Rule
config *Config
tags []tunnelpogs.Tag
log *zerolog.Logger
// cloudflared Configuration
config *Config
tags []tunnelpogs.Tag
log *zerolog.Logger
// orchestrator must not handle any more updates after shutdownC is closed
shutdownC <-chan struct{}
@ -40,7 +41,11 @@ type Orchestrator struct {
proxyShutdownC chan<- struct{}
}
func NewOrchestrator(ctx context.Context, config *Config, tags []tunnelpogs.Tag, internalRules []ingress.Rule, log *zerolog.Logger) (*Orchestrator, error) {
func NewOrchestrator(ctx context.Context,
config *Config,
tags []tunnelpogs.Tag,
internalRules []ingress.Rule,
log *zerolog.Logger) (*Orchestrator, error) {
o := &Orchestrator{
// Lowest possible version, any remote configuration will have version higher than this
// Starting at -1 allows a configuration migration (local to remote) to override the current configuration as it
@ -131,7 +136,7 @@ func (o *Orchestrator) updateIngress(ingressRules ingress.Ingress, warpRouting i
if err := ingressRules.StartOrigins(o.log, proxyShutdownC); err != nil {
return errors.Wrap(err, "failed to start origin")
}
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.log)
proxy := proxy.NewOriginProxy(ingressRules, warpRouting, o.tags, o.config.WriteTimeout, o.log)
o.proxy.Store(proxy)
o.config.Ingress = &ingressRules
o.config.WarpRouting = warpRouting

View File

@ -27,7 +27,7 @@ import (
)
var (
testLogger = zerolog.Logger{}
testLogger = zerolog.Nop()
testTags = []tunnelpogs.Tag{
{
Name: "package",

View File

@ -1,4 +1,4 @@
#!/bin/bash
set -eu
rm /usr/local/bin/cloudflared
rm /usr/local/etc/cloudflared/.installedFromPackageManager || true
rm -f /usr/local/bin/cloudflared
rm -f /usr/local/etc/cloudflared/.installedFromPackageManager

78
proxy/logger.go Normal file
View File

@ -0,0 +1,78 @@
package proxy
import (
"net/http"
"strconv"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management"
)
const (
logFieldCFRay = "cfRay"
logFieldLBProbe = "lbProbe"
logFieldRule = "ingressRule"
logFieldOriginService = "originService"
logFieldFlowID = "flowID"
logFieldConnIndex = "connIndex"
logFieldDestAddr = "destAddr"
)
// newHTTPLogger creates a child zerolog.Logger from the provided with added context from the HTTP request, ingress
// services, and connection index.
func newHTTPLogger(logger *zerolog.Logger, connIndex uint8, req *http.Request, rule int, serviceName string) zerolog.Logger {
ctx := logger.With().
Int(management.EventTypeKey, int(management.HTTP)).
Uint8(logFieldConnIndex, connIndex)
cfRay := connection.FindCfRayHeader(req)
lbProbe := connection.IsLBProbeRequest(req)
if cfRay != "" {
ctx.Str(logFieldCFRay, cfRay)
}
if lbProbe {
ctx.Bool(logFieldLBProbe, lbProbe)
}
return ctx.
Str(logFieldOriginService, serviceName).
Interface(logFieldRule, rule).
Logger()
}
// newTCPLogger creates a child zerolog.Logger from the provided with added context from the TCPRequest.
func newTCPLogger(logger *zerolog.Logger, req *connection.TCPRequest) zerolog.Logger {
return logger.With().
Int(management.EventTypeKey, int(management.TCP)).
Uint8(logFieldConnIndex, req.ConnIndex).
Str(logFieldOriginService, ingress.ServiceWarpRouting).
Str(logFieldFlowID, req.FlowID).
Str(logFieldDestAddr, req.Dest).
Uint8(logFieldConnIndex, req.ConnIndex).
Logger()
}
// logHTTPRequest logs a Debug message with the corresponding HTTP request details from the eyeball.
func logHTTPRequest(logger *zerolog.Logger, r *http.Request) {
logger.Debug().
Str("host", r.Host).
Str("path", r.URL.Path).
Interface("headers", r.Header).
Int64("content-length", r.ContentLength).
Msgf("%s %s %s", r.Method, r.URL, r.Proto)
}
// logOriginHTTPResponse logs a Debug message of the origin response.
func logOriginHTTPResponse(logger *zerolog.Logger, resp *http.Response) {
responseByCode.WithLabelValues(strconv.Itoa(resp.StatusCode)).Inc()
logger.Debug().
Int64("content-length", resp.ContentLength).
Msgf("%s", resp.Status)
}
// logRequestError logs an error for the proxied request.
func logRequestError(logger *zerolog.Logger, err error) {
requestErrors.Inc()
logger.Error().Err(err).Send()
}

View File

@ -59,6 +59,23 @@ var (
Help: "Total count of TCP sessions that have been proxied to any origin",
},
)
connectLatency = prometheus.NewHistogram(
prometheus.HistogramOpts{
Namespace: connection.MetricsNamespace,
Subsystem: "proxy",
Name: "connect_latency",
Help: "Time it takes to establish and acknowledge connections in milliseconds",
Buckets: []float64{1, 10, 25, 50, 100, 500, 1000, 5000},
},
)
connectStreamErrors = prometheus.NewCounter(
prometheus.CounterOpts{
Namespace: connection.MetricsNamespace,
Subsystem: "proxy",
Name: "connect_streams_errors",
Help: "Total count of failure to establish and acknowledge connections",
},
)
)
func init() {
@ -69,6 +86,8 @@ func init() {
requestErrors,
activeTCPSessions,
totalTCPSessions,
connectLatency,
connectStreamErrors,
)
}

View File

@ -6,6 +6,7 @@ import (
"io"
"net/http"
"strconv"
"time"
"github.com/pkg/errors"
"github.com/rs/zerolog"
@ -16,7 +17,6 @@ import (
"github.com/cloudflare/cloudflared/cfio"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/management"
"github.com/cloudflare/cloudflared/stream"
"github.com/cloudflare/cloudflared/tracing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@ -24,16 +24,8 @@ import (
const (
// TagHeaderNamePrefix indicates a Cloudflared Warp Tag prefix that gets appended for warp traffic stream headers.
TagHeaderNamePrefix = "Cf-Warp-Tag-"
LogFieldCFRay = "cfRay"
LogFieldLBProbe = "lbProbe"
LogFieldRule = "ingressRule"
LogFieldOriginService = "originService"
LogFieldFlowID = "flowID"
LogFieldConnIndex = "connIndex"
LogFieldDestAddr = "destAddr"
trailerHeaderName = "Trailer"
TagHeaderNamePrefix = "Cf-Warp-Tag-"
trailerHeaderName = "Trailer"
)
// Proxy represents a means to Proxy between cloudflared and the origin services.
@ -50,6 +42,7 @@ func NewOriginProxy(
ingressRules ingress.Ingress,
warpRouting ingress.WarpRoutingConfig,
tags []tunnelpogs.Tag,
writeTimeout time.Duration,
log *zerolog.Logger,
) *Proxy {
proxy := &Proxy{
@ -58,7 +51,7 @@ func NewOriginProxy(
log: log,
}
proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting)
proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout)
return proxy
}
@ -89,26 +82,18 @@ func (p *Proxy) ProxyHTTP(
defer decrementConcurrentRequests()
req := tr.Request
cfRay := connection.FindCfRayHeader(req)
lbProbe := connection.IsLBProbeRequest(req)
p.appendTagHeaders(req)
_, ruleSpan := tr.Tracer().Start(req.Context(), "ingress_match",
trace.WithAttributes(attribute.String("req-host", req.Host)))
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
logFields := logFields{
cfRay: cfRay,
lbProbe: lbProbe,
rule: ruleNum,
connIndex: tr.ConnIndex,
}
p.logRequest(req, logFields)
ruleSpan.SetAttributes(attribute.Int("rule-num", ruleNum))
ruleSpan.End()
logger := newHTTPLogger(p.log, tr.ConnIndex, req, ruleNum, rule.Service.String())
logHTTPRequest(&logger, req)
if err, applied := p.applyIngressMiddleware(rule, req, w); err != nil {
if applied {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, "", rule, srv)
logRequestError(&logger, err)
return nil
}
return err
@ -122,10 +107,9 @@ func (p *Proxy) ProxyHTTP(
originProxy,
isWebsocket,
rule.Config.DisableChunkedEncoding,
logFields,
&logger,
); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, "", rule, srv)
logRequestError(&logger, err)
return err
}
return nil
@ -139,9 +123,9 @@ func (p *Proxy) ProxyHTTP(
return fmt.Errorf("response writer is not a flusher")
}
rws := connection.NewHTTPResponseReadWriterAcker(w, flusher, req)
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, "", rule, srv)
logger := logger.With().Str(logFieldDestAddr, dest).Logger()
if err := p.proxyStream(tr.ToTracedContext(), rws, dest, originProxy, &logger); err != nil {
logRequestError(&logger, err)
return err
}
return nil
@ -171,38 +155,20 @@ func (p *Proxy) ProxyTCP(
serveCtx, cancel := context.WithCancel(ctx)
defer cancel()
tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, p.log)
logger := newTCPLogger(p.log, req)
tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, &logger)
logger.Debug().Msg("tcp proxy stream started")
p.log.Debug().
Int(management.EventTypeKey, int(management.TCP)).
Str(LogFieldFlowID, req.FlowID).
Str(LogFieldDestAddr, req.Dest).
Uint8(LogFieldConnIndex, req.ConnIndex).
Msg("tcp proxy stream started")
if err := p.proxyStream(tracedCtx, rwa, req.Dest, p.warpRouting.Proxy); err != nil {
p.logRequestError(err, req.CFRay, req.FlowID, "", ingress.ServiceWarpRouting)
if err := p.proxyStream(tracedCtx, rwa, req.Dest, p.warpRouting.Proxy, &logger); err != nil {
logRequestError(&logger, err)
return err
}
p.log.Debug().
Int(management.EventTypeKey, int(management.TCP)).
Str(LogFieldFlowID, req.FlowID).
Str(LogFieldDestAddr, req.Dest).
Uint8(LogFieldConnIndex, req.ConnIndex).
Msg("tcp proxy stream finished successfully")
logger.Debug().Msg("tcp proxy stream finished successfully")
return nil
}
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
srv = ing.Rules[ruleNum].Service.String()
if ing.IsSingleRule() {
return "", srv
}
return fmt.Sprintf("%d", ruleNum), srv
}
// ProxyHTTPRequest proxies requests of underlying type http and websocket to the origin service.
func (p *Proxy) proxyHTTPRequest(
w connection.ResponseWriter,
@ -210,7 +176,7 @@ func (p *Proxy) proxyHTTPRequest(
httpService ingress.HTTPOriginProxy,
isWebsocket bool,
disableChunkedEncoding bool,
fields logFields,
logger *zerolog.Logger,
) error {
roundTripReq := tr.Request
if isWebsocket {
@ -277,7 +243,7 @@ func (p *Proxy) proxyHTTPRequest(
reader: tr.Request.Body,
}
stream.Pipe(eyeballStream, rwc, p.log)
stream.Pipe(eyeballStream, rwc, logger)
return nil
}
@ -288,35 +254,45 @@ func (p *Proxy) proxyHTTPRequest(
// copy trailers
copyTrailers(w, resp)
p.logOriginResponse(resp, fields)
logOriginHTTPResponse(logger, resp)
return nil
}
// proxyStream proxies type TCP and other underlying types if the connection is defined as a stream oriented
// ingress rule.
// connectedLogger is used to log when the connection is acknowledged
func (p *Proxy) proxyStream(
tr *tracing.TracedContext,
rwa connection.ReadWriteAcker,
dest string,
connectionProxy ingress.StreamBasedOriginProxy,
logger *zerolog.Logger,
) error {
ctx := tr.Context
_, connectSpan := tr.Tracer().Start(ctx, "stream-connect")
originConn, err := connectionProxy.EstablishConnection(ctx, dest)
start := time.Now()
originConn, err := connectionProxy.EstablishConnection(ctx, dest, logger)
if err != nil {
connectStreamErrors.Inc()
tracing.EndWithErrorStatus(connectSpan, err)
return err
}
connectSpan.End()
defer originConn.Close()
logger.Debug().Msg("origin connection established")
encodedSpans := tr.GetSpans()
if err := rwa.AckConnection(encodedSpans); err != nil {
connectStreamErrors.Inc()
return err
}
originConn.Stream(ctx, rwa, p.log)
connectLatency.Observe(float64(time.Since(start).Milliseconds()))
logger.Debug().Msg("proxy stream acknowledged")
originConn.Stream(ctx, rwa, logger)
return nil
}
@ -350,14 +326,6 @@ func (p *Proxy) appendTagHeaders(r *http.Request) {
}
}
type logFields struct {
cfRay string
lbProbe bool
rule int
flowID string
connIndex uint8
}
func copyTrailers(w connection.ResponseWriter, response *http.Response) {
for trailerHeader, trailerValues := range response.Trailer {
for _, trailerValue := range trailerValues {
@ -366,64 +334,6 @@ func copyTrailers(w connection.ResponseWriter, response *http.Response) {
}
}
func (p *Proxy) logRequest(r *http.Request, fields logFields) {
log := p.log.With().Int(management.EventTypeKey, int(management.HTTP)).Logger()
event := log.Debug()
if fields.cfRay != "" {
event = event.Str(LogFieldCFRay, fields.cfRay)
}
if fields.lbProbe {
event = event.Bool(LogFieldLBProbe, fields.lbProbe)
}
if fields.cfRay == "" && !fields.lbProbe {
log.Debug().Msgf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", r.Method, r.URL, r.Proto)
}
event.
Uint8(LogFieldConnIndex, fields.connIndex).
Str("host", r.Host).
Str("path", r.URL.Path).
Interface(LogFieldRule, fields.rule).
Interface("headers", r.Header).
Int64("content-length", r.ContentLength).
Msgf("%s %s %s", r.Method, r.URL, r.Proto)
}
func (p *Proxy) logOriginResponse(resp *http.Response, fields logFields) {
responseByCode.WithLabelValues(strconv.Itoa(resp.StatusCode)).Inc()
event := p.log.Debug()
if fields.cfRay != "" {
event = event.Str(LogFieldCFRay, fields.cfRay)
}
if fields.lbProbe {
event = event.Bool(LogFieldLBProbe, fields.lbProbe)
}
event.
Int(management.EventTypeKey, int(management.HTTP)).
Uint8(LogFieldConnIndex, fields.connIndex).
Int64("content-length", resp.ContentLength).
Msgf("%s", resp.Status)
}
func (p *Proxy) logRequestError(err error, cfRay string, flowID string, rule, service string) {
requestErrors.Inc()
log := p.log.Error().Err(err)
if cfRay != "" {
log = log.Str(LogFieldCFRay, cfRay)
}
if flowID != "" {
log = log.Str(LogFieldFlowID, flowID).Int(management.EventTypeKey, int(management.TCP))
} else {
log = log.Int(management.EventTypeKey, int(management.HTTP))
}
if rule != "" {
log = log.Str(LogFieldRule, rule)
}
if service != "" {
log = log.Str(LogFieldOriginService, service)
}
log.Send()
}
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
switch rule.Service.String() {
case ingress.ServiceBastion:

View File

@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, &log)
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, time.Duration(0), &log)
t.Run("testProxyHTTP", testProxyHTTP(proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
t.Run("testProxySSE", testProxySSE(proxy))
@ -366,7 +366,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
ctx, cancel := context.WithCancel(context.Background())
require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, &log)
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, time.Duration(0), &log)
for _, test := range tests {
responseWriter := newMockHTTPRespWriter()
@ -414,7 +414,7 @@ func TestProxyError(t *testing.T) {
log := zerolog.Nop()
proxy := NewOriginProxy(ing, noWarpRouting, testTags, &log)
proxy := NewOriginProxy(ing, noWarpRouting, testTags, time.Duration(0), &log)
responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
@ -530,7 +530,7 @@ func TestConnections(t *testing.T) {
originService: runEchoTCPService,
eyeballResponseWriter: newTCPRespWriter(replayer),
eyeballRequestBody: newTCPRequestBody([]byte("test2")),
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting),
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
connectionType: connection.TypeTCP,
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
@ -548,7 +548,7 @@ func TestConnections(t *testing.T) {
originService: runEchoWSService,
// eyeballResponseWriter gets set after roundtrip dial.
eyeballRequestBody: newPipedWSRequestBody([]byte("test3")),
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting),
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
},
@ -675,7 +675,7 @@ func TestConnections(t *testing.T) {
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
ingressRule.StartOrigins(logger, ctx.Done())
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, logger)
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, time.Duration(0), logger)
proxy.warpRouting = test.args.warpRoutingService
dest := ln.Addr().String()

View File

@ -7,17 +7,6 @@ import (
"github.com/quic-go/quic-go/logging"
)
func perspectiveString(p logging.Perspective) string {
switch p {
case logging.PerspectiveClient:
return "client"
case logging.PerspectiveServer:
return "server"
default:
return ""
}
}
// Helper to convert logging.ByteCount(alias for int64) to float64 used in prometheus
func byteCountToPromCount(count logging.ByteCount) float64 {
return float64(count)

View File

@ -1,6 +1,8 @@
package quic
import (
"reflect"
"strings"
"sync"
"github.com/prometheus/client_golang/prometheus"
@ -16,10 +18,10 @@ var (
clientMetrics = struct {
totalConnections prometheus.Counter
closedConnections prometheus.Counter
sentPackets *prometheus.CounterVec
sentFrames *prometheus.CounterVec
sentBytes *prometheus.CounterVec
receivePackets *prometheus.CounterVec
receiveBytes *prometheus.CounterVec
receivedFrames *prometheus.CounterVec
receivedBytes *prometheus.CounterVec
bufferedPackets *prometheus.CounterVec
droppedPackets *prometheus.CounterVec
lostPackets *prometheus.CounterVec
@ -28,43 +30,88 @@ var (
smoothedRTT *prometheus.GaugeVec
}{
totalConnections: prometheus.NewCounter(
totalConnectionsOpts(logging.PerspectiveClient),
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "total_connections",
Help: "Number of connections initiated",
},
),
closedConnections: prometheus.NewCounter(
closedConnectionsOpts(logging.PerspectiveClient),
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "closed_connections",
Help: "Number of connections that has been closed",
},
),
sentPackets: prometheus.NewCounterVec(
sentPacketsOpts(logging.PerspectiveClient),
clientConnLabels,
sentFrames: prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "sent_frames",
Help: "Number of frames that have been sent through a connection",
},
append(clientConnLabels, "frame_type"),
),
sentBytes: prometheus.NewCounterVec(
sentBytesOpts(logging.PerspectiveClient),
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "sent_bytes",
Help: "Number of bytes that have been sent through a connection",
},
clientConnLabels,
),
receivePackets: prometheus.NewCounterVec(
receivePacketsOpts(logging.PerspectiveClient),
clientConnLabels,
receivedFrames: prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "received_frames",
Help: "Number of frames that have been received through a connection",
},
append(clientConnLabels, "frame_type"),
),
receiveBytes: prometheus.NewCounterVec(
receiveBytesOpts(logging.PerspectiveClient),
receivedBytes: prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "receive_bytes",
Help: "Number of bytes that have been received through a connection",
},
clientConnLabels,
),
bufferedPackets: prometheus.NewCounterVec(
bufferedPacketsOpts(logging.PerspectiveClient),
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "buffered_packets",
Help: "Number of bytes that have been buffered on a connection",
},
append(clientConnLabels, "packet_type"),
),
droppedPackets: prometheus.NewCounterVec(
droppedPacketsOpts(logging.PerspectiveClient),
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "dropped_packets",
Help: "Number of bytes that have been dropped on a connection",
},
append(clientConnLabels, "packet_type", "reason"),
),
lostPackets: prometheus.NewCounterVec(
lostPacketsOpts(logging.PerspectiveClient),
prometheus.CounterOpts{
Namespace: namespace,
Subsystem: "client",
Name: "lost_packets",
Help: "Number of packets that have been lost from a connection",
},
append(clientConnLabels, "reason"),
),
minRTT: prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: namespace,
Subsystem: perspectiveString(logging.PerspectiveClient),
Subsystem: "client",
Name: "min_rtt",
Help: "Lowest RTT measured on a connection in millisec",
},
@ -73,7 +120,7 @@ var (
latestRTT: prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: namespace,
Subsystem: perspectiveString(logging.PerspectiveClient),
Subsystem: "client",
Name: "latest_rtt",
Help: "Latest RTT measured on a connection",
},
@ -82,188 +129,37 @@ var (
smoothedRTT: prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: namespace,
Subsystem: perspectiveString(logging.PerspectiveClient),
Subsystem: "client",
Name: "smoothed_rtt",
Help: "Calculated smoothed RTT measured on a connection in millisec",
},
clientConnLabels,
),
}
// The server has many QUIC connections. Adding per connection label incurs high memory cost
serverMetrics = struct {
totalConnections prometheus.Counter
closedConnections prometheus.Counter
sentPackets prometheus.Counter
sentBytes prometheus.Counter
receivePackets prometheus.Counter
receiveBytes prometheus.Counter
bufferedPackets *prometheus.CounterVec
droppedPackets *prometheus.CounterVec
lostPackets *prometheus.CounterVec
rtt prometheus.Histogram
}{
totalConnections: prometheus.NewCounter(
totalConnectionsOpts(logging.PerspectiveServer),
),
closedConnections: prometheus.NewCounter(
closedConnectionsOpts(logging.PerspectiveServer),
),
sentPackets: prometheus.NewCounter(
sentPacketsOpts(logging.PerspectiveServer),
),
sentBytes: prometheus.NewCounter(
sentBytesOpts(logging.PerspectiveServer),
),
receivePackets: prometheus.NewCounter(
receivePacketsOpts(logging.PerspectiveServer),
),
receiveBytes: prometheus.NewCounter(
receiveBytesOpts(logging.PerspectiveServer),
),
bufferedPackets: prometheus.NewCounterVec(
bufferedPacketsOpts(logging.PerspectiveServer),
[]string{"packet_type"},
),
droppedPackets: prometheus.NewCounterVec(
droppedPacketsOpts(logging.PerspectiveServer),
[]string{"packet_type", "reason"},
),
lostPackets: prometheus.NewCounterVec(
lostPacketsOpts(logging.PerspectiveServer),
[]string{"reason"},
),
rtt: prometheus.NewHistogram(
prometheus.HistogramOpts{
Namespace: namespace,
Subsystem: perspectiveString(logging.PerspectiveServer),
Name: "rtt",
Buckets: []float64{5, 10, 20, 30, 40, 50, 75, 100},
},
),
}
registerClient = sync.Once{}
registerServer = sync.Once{}
packetTooBigDropped = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(logging.PerspectiveClient),
Subsystem: "client",
Name: "packet_too_big_dropped",
Help: "Count of packets received from origin that are too big to send to the edge and are dropped as a result",
})
)
// MetricsCollector abstracts the difference between client and server metrics from connTracer
type MetricsCollector interface {
startedConnection()
closedConnection(err error)
sentPackets(logging.ByteCount)
receivedPackets(logging.ByteCount)
bufferedPackets(logging.PacketType)
droppedPackets(logging.PacketType, logging.ByteCount, logging.PacketDropReason)
lostPackets(logging.PacketLossReason)
updatedRTT(*logging.RTTStats)
}
func totalConnectionsOpts(p logging.Perspective) prometheus.CounterOpts {
var help string
if p == logging.PerspectiveClient {
help = "Number of connections initiated. For all quic metrics, client means the side initiating the connection"
} else {
help = "Number of connections accepted. For all quic metrics, server means the side accepting connections"
}
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "total_connections",
Help: help,
}
}
func closedConnectionsOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "closed_connections",
Help: "Number of connections that has been closed",
}
}
func sentPacketsOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "sent_packets",
Help: "Number of packets that have been sent through a connection",
}
}
func sentBytesOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "sent_bytes",
Help: "Number of bytes that have been sent through a connection",
}
}
func receivePacketsOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "receive_packets",
Help: "Number of packets that have been received through a connection",
}
}
func receiveBytesOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "receive_bytes",
Help: "Number of bytes that have been received through a connection",
}
}
func bufferedPacketsOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "buffered_packets",
Help: "Number of bytes that have been buffered on a connection",
}
}
func droppedPacketsOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "dropped_packets",
Help: "Number of bytes that have been dropped on a connection",
}
}
func lostPacketsOpts(p logging.Perspective) prometheus.CounterOpts {
return prometheus.CounterOpts{
Namespace: namespace,
Subsystem: perspectiveString(p),
Name: "lost_packets",
Help: "Number of packets that have been lost from a connection",
}
}
type clientCollector struct {
index string
}
func newClientCollector(index uint8) MetricsCollector {
func newClientCollector(index uint8) *clientCollector {
registerClient.Do(func() {
prometheus.MustRegister(
clientMetrics.totalConnections,
clientMetrics.closedConnections,
clientMetrics.sentPackets,
clientMetrics.sentFrames,
clientMetrics.sentBytes,
clientMetrics.receivePackets,
clientMetrics.receiveBytes,
clientMetrics.receivedFrames,
clientMetrics.receivedBytes,
clientMetrics.bufferedPackets,
clientMetrics.droppedPackets,
clientMetrics.lostPackets,
@ -286,14 +182,12 @@ func (cc *clientCollector) closedConnection(err error) {
clientMetrics.closedConnections.Inc()
}
func (cc *clientCollector) sentPackets(size logging.ByteCount) {
clientMetrics.sentPackets.WithLabelValues(cc.index).Inc()
clientMetrics.sentBytes.WithLabelValues(cc.index).Add(byteCountToPromCount(size))
func (cc *clientCollector) sentPackets(size logging.ByteCount, frames []logging.Frame) {
cc.collectPackets(size, frames, clientMetrics.sentFrames, clientMetrics.sentBytes)
}
func (cc *clientCollector) receivedPackets(size logging.ByteCount) {
clientMetrics.receivePackets.WithLabelValues(cc.index).Inc()
clientMetrics.receiveBytes.WithLabelValues(cc.index).Add(byteCountToPromCount(size))
func (cc *clientCollector) receivedPackets(size logging.ByteCount, frames []logging.Frame) {
cc.collectPackets(size, frames, clientMetrics.receivedFrames, clientMetrics.receivedBytes)
}
func (cc *clientCollector) bufferedPackets(packetType logging.PacketType) {
@ -318,63 +212,18 @@ func (cc *clientCollector) updatedRTT(rtt *logging.RTTStats) {
clientMetrics.smoothedRTT.WithLabelValues(cc.index).Set(durationToPromGauge(rtt.SmoothedRTT()))
}
type serverCollector struct{}
func newServiceCollector() MetricsCollector {
registerServer.Do(func() {
prometheus.MustRegister(
serverMetrics.totalConnections,
serverMetrics.closedConnections,
serverMetrics.sentPackets,
serverMetrics.sentBytes,
serverMetrics.receivePackets,
serverMetrics.receiveBytes,
serverMetrics.bufferedPackets,
serverMetrics.droppedPackets,
serverMetrics.lostPackets,
serverMetrics.rtt,
)
})
return &serverCollector{}
func (cc *clientCollector) collectPackets(size logging.ByteCount, frames []logging.Frame, counter, bandwidth *prometheus.CounterVec) {
for _, frame := range frames {
counter.WithLabelValues(cc.index, frameName(frame)).Inc()
}
bandwidth.WithLabelValues(cc.index).Add(byteCountToPromCount(size))
}
func (sc *serverCollector) startedConnection() {
serverMetrics.totalConnections.Inc()
}
func (sc *serverCollector) closedConnection(err error) {
serverMetrics.closedConnections.Inc()
}
func (sc *serverCollector) sentPackets(size logging.ByteCount) {
serverMetrics.sentPackets.Inc()
serverMetrics.sentBytes.Add(byteCountToPromCount(size))
}
func (sc *serverCollector) receivedPackets(size logging.ByteCount) {
serverMetrics.receivePackets.Inc()
serverMetrics.receiveBytes.Add(byteCountToPromCount(size))
}
func (sc *serverCollector) bufferedPackets(packetType logging.PacketType) {
serverMetrics.bufferedPackets.WithLabelValues(packetTypeString(packetType)).Inc()
}
func (sc *serverCollector) droppedPackets(packetType logging.PacketType, size logging.ByteCount, reason logging.PacketDropReason) {
serverMetrics.droppedPackets.WithLabelValues(
packetTypeString(packetType),
packetDropReasonString(reason),
).Add(byteCountToPromCount(size))
}
func (sc *serverCollector) lostPackets(reason logging.PacketLossReason) {
serverMetrics.lostPackets.WithLabelValues(packetLossReasonString(reason)).Inc()
}
func (sc *serverCollector) updatedRTT(rtt *logging.RTTStats) {
latestRTT := rtt.LatestRTT()
// May return 0 if no valid updates have occurred
if latestRTT > 0 {
serverMetrics.rtt.Observe(durationToPromGauge(latestRTT))
func frameName(frame logging.Frame) string {
if frame == nil {
return "nil"
} else {
name := reflect.TypeOf(frame).Elem().Name()
return strings.TrimSuffix(name, "Frame")
}
}

View File

@ -109,63 +109,6 @@ func TestConnectResponseMeta(t *testing.T) {
}
}
func TestUnregisterUdpSession(t *testing.T) {
unregisterMessage := "closed by eyeball"
var tests = []struct {
name string
sessionRPCServer mockSessionRPCServer
timeout time.Duration
}{
{
name: "UnregisterUdpSessionTimesout if the RPC server does not respond",
sessionRPCServer: mockSessionRPCServer{
sessionID: uuid.New(),
dstIP: net.IP{172, 16, 0, 1},
dstPort: 8000,
closeIdleAfter: testCloseIdleAfterHint,
unregisterMessage: unregisterMessage,
traceContext: "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1",
},
// very very low value so we trigger the timeout every time.
timeout: time.Nanosecond * 1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
logger := zerolog.Nop()
clientStream, serverStream := newMockRPCStreams()
sessionRegisteredChan := make(chan struct{})
go func() {
protocol, err := DetermineProtocol(serverStream)
assert.NoError(t, err)
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
assert.NoError(t, err)
err = rpcServerStream.Serve(test.sessionRPCServer, nil, &logger)
assert.NoError(t, err)
serverStream.Close()
close(sessionRegisteredChan)
}()
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, test.timeout, &logger)
assert.NoError(t, err)
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
assert.NoError(t, err)
assert.NoError(t, reg.Err)
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
rpcClientStream.Close()
<-sessionRegisteredChan
})
}
}
func TestRegisterUdpSession(t *testing.T) {
unregisterMessage := "closed by eyeball"

View File

@ -1,20 +1,33 @@
package quic
import (
"errors"
"net"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
// The error that is throw by the writer when there is `no network activity`.
var idleTimeoutError = quic.IdleTimeoutError{}
type SafeStreamCloser struct {
lock sync.Mutex
stream quic.Stream
lock sync.Mutex
stream quic.Stream
writeTimeout time.Duration
log *zerolog.Logger
closing atomic.Bool
}
func NewSafeStreamCloser(stream quic.Stream) *SafeStreamCloser {
func NewSafeStreamCloser(stream quic.Stream, writeTimeout time.Duration, log *zerolog.Logger) *SafeStreamCloser {
return &SafeStreamCloser{
stream: stream,
stream: stream,
writeTimeout: writeTimeout,
log: log,
}
}
@ -25,10 +38,43 @@ func (s *SafeStreamCloser) Read(p []byte) (n int, err error) {
func (s *SafeStreamCloser) Write(p []byte) (n int, err error) {
s.lock.Lock()
defer s.lock.Unlock()
return s.stream.Write(p)
if s.writeTimeout > 0 {
err = s.stream.SetWriteDeadline(time.Now().Add(s.writeTimeout))
if err != nil {
log.Err(err).Msg("Error setting write deadline for QUIC stream")
}
}
nBytes, err := s.stream.Write(p)
if err != nil {
s.handleWriteError(err)
}
return nBytes, err
}
// Handles the timeout error in case it happened, by canceling the stream write.
func (s *SafeStreamCloser) handleWriteError(err error) {
// If we are closing the stream we just ignore any write error.
if s.closing.Load() {
return
}
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() {
// We don't need to log if what cause the timeout was no network activity.
if !errors.Is(netErr, &idleTimeoutError) {
s.log.Error().Err(netErr).Msg("Closing quic stream due to timeout while writing")
}
// We need to explicitly cancel the write so that it frees all buffers.
s.stream.CancelWrite(0)
}
}
}
func (s *SafeStreamCloser) Close() error {
// Set this stream to a closing state.
s.closing.Store(true)
// Make sure a possible writer does not block the lock forever. We need it, so we can close the writer
// side of the stream safely.
_ = s.stream.SetWriteDeadline(time.Now())

View File

@ -9,6 +9,8 @@ import (
"testing"
"time"
"github.com/rs/zerolog"
"github.com/quic-go/quic-go"
"github.com/stretchr/testify/require"
)
@ -70,7 +72,8 @@ func quicClient(t *testing.T, addr net.Addr) {
go func(iter int) {
defer wg.Done()
stream := NewSafeStreamCloser(quicStream)
log := zerolog.Nop()
stream := NewSafeStreamCloser(quicStream, 30*time.Second, &log)
defer stream.Close()
// Do a bunch of round trips over this stream that should work.
@ -107,7 +110,8 @@ func quicServer(t *testing.T, serverReady *sync.WaitGroup, conn net.PacketConn)
go func(iter int) {
defer wg.Done()
stream := NewSafeStreamCloser(quicStream)
log := zerolog.Nop()
stream := NewSafeStreamCloser(quicStream, 30*time.Second, &log)
defer stream.Close()
// Do a bunch of round trips over this stream that should work.

View File

@ -3,7 +3,6 @@ package quic
import (
"context"
"net"
"time"
"github.com/quic-go/quic-go/logging"
"github.com/rs/zerolog"
@ -16,8 +15,6 @@ type tracer struct {
}
type tracerConfig struct {
isClient bool
// Only client has an index
index uint8
}
@ -25,67 +22,36 @@ func NewClientTracer(logger *zerolog.Logger, index uint8) func(context.Context,
t := &tracer{
logger: logger,
config: &tracerConfig{
isClient: true,
index: index,
index: index,
},
}
return t.TracerForConnection
}
func NewServerTracer(logger *zerolog.Logger) *logging.Tracer {
return &logging.Tracer{
SentPacket: func(net.Addr, *logging.Header, logging.ByteCount, []logging.Frame) {},
SentVersionNegotiationPacket: func(_ net.Addr, dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {},
DroppedPacket: func(net.Addr, logging.PacketType, logging.ByteCount, logging.PacketDropReason) {},
}
}
func (t *tracer) TracerForConnection(_ctx context.Context, _p logging.Perspective, _odcid logging.ConnectionID) *logging.ConnectionTracer {
if t.config.isClient {
return newConnTracer(newClientCollector(t.config.index))
}
return newConnTracer(newServiceCollector())
return newConnTracer(newClientCollector(t.config.index))
}
// connTracer collects connection level metrics
type connTracer struct {
metricsCollector MetricsCollector
metricsCollector *clientCollector
}
func newConnTracer(metricsCollector MetricsCollector) *logging.ConnectionTracer {
func newConnTracer(metricsCollector *clientCollector) *logging.ConnectionTracer {
tracer := connTracer{
metricsCollector: metricsCollector,
}
return &logging.ConnectionTracer{
StartedConnection: tracer.StartedConnection,
NegotiatedVersion: tracer.NegotiatedVersion,
ClosedConnection: tracer.ClosedConnection,
SentTransportParameters: tracer.SentTransportParameters,
ReceivedTransportParameters: tracer.ReceivedTransportParameters,
RestoredTransportParameters: tracer.RestoredTransportParameters,
SentLongHeaderPacket: tracer.SentLongHeaderPacket,
SentShortHeaderPacket: tracer.SentShortHeaderPacket,
ReceivedVersionNegotiationPacket: tracer.ReceivedVersionNegotiationPacket,
ReceivedRetry: tracer.ReceivedRetry,
ReceivedLongHeaderPacket: tracer.ReceivedLongHeaderPacket,
ReceivedShortHeaderPacket: tracer.ReceivedShortHeaderPacket,
BufferedPacket: tracer.BufferedPacket,
DroppedPacket: tracer.DroppedPacket,
UpdatedMetrics: tracer.UpdatedMetrics,
AcknowledgedPacket: tracer.AcknowledgedPacket,
LostPacket: tracer.LostPacket,
UpdatedCongestionState: tracer.UpdatedCongestionState,
UpdatedPTOCount: tracer.UpdatedPTOCount,
UpdatedKeyFromTLS: tracer.UpdatedKeyFromTLS,
UpdatedKey: tracer.UpdatedKey,
DroppedEncryptionLevel: tracer.DroppedEncryptionLevel,
DroppedKey: tracer.DroppedKey,
SetLossTimer: tracer.SetLossTimer,
LossTimerExpired: tracer.LossTimerExpired,
LossTimerCanceled: tracer.LossTimerCanceled,
ECNStateUpdated: tracer.ECNStateUpdated,
Close: tracer.Close,
Debug: tracer.Debug,
StartedConnection: tracer.StartedConnection,
ClosedConnection: tracer.ClosedConnection,
SentLongHeaderPacket: tracer.SentLongHeaderPacket,
SentShortHeaderPacket: tracer.SentShortHeaderPacket,
ReceivedLongHeaderPacket: tracer.ReceivedLongHeaderPacket,
ReceivedShortHeaderPacket: tracer.ReceivedShortHeaderPacket,
BufferedPacket: tracer.BufferedPacket,
DroppedPacket: tracer.DroppedPacket,
UpdatedMetrics: tracer.UpdatedMetrics,
LostPacket: tracer.LostPacket,
}
}
@ -97,14 +63,6 @@ func (ct *connTracer) ClosedConnection(err error) {
ct.metricsCollector.closedConnection(err)
}
func (ct *connTracer) SentPacket(hdr *logging.ExtendedHeader, packetSize logging.ByteCount, ack *logging.AckFrame, frames []logging.Frame) {
ct.metricsCollector.sentPackets(packetSize)
}
func (ct *connTracer) ReceivedPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, frames []logging.Frame) {
ct.metricsCollector.receivedPackets(size)
}
func (ct *connTracer) BufferedPacket(pt logging.PacketType, size logging.ByteCount) {
ct.metricsCollector.bufferedPackets(pt)
}
@ -121,74 +79,20 @@ func (ct *connTracer) UpdatedMetrics(rttStats *logging.RTTStats, cwnd, bytesInFl
ct.metricsCollector.updatedRTT(rttStats)
}
func (ct *connTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) {
}
func (ct *connTracer) SentTransportParameters(parameters *logging.TransportParameters) {
}
func (ct *connTracer) ReceivedTransportParameters(parameters *logging.TransportParameters) {
}
func (ct *connTracer) RestoredTransportParameters(parameters *logging.TransportParameters) {
}
func (ct *connTracer) SentLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) {
ct.metricsCollector.sentPackets(size, frames)
}
func (ct *connTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, ack *logging.AckFrame, frames []logging.Frame) {
}
func (ct *connTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) {
}
func (ct *connTracer) ReceivedRetry(header *logging.Header) {
ct.metricsCollector.sentPackets(size, frames)
}
func (ct *connTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) {
ct.metricsCollector.receivedPackets(size, frames)
}
func (ct *connTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, size logging.ByteCount, ecn logging.ECN, frames []logging.Frame) {
}
func (ct *connTracer) AcknowledgedPacket(level logging.EncryptionLevel, number logging.PacketNumber) {
}
func (ct *connTracer) UpdatedCongestionState(state logging.CongestionState) {
}
func (ct *connTracer) UpdatedPTOCount(value uint32) {
}
func (ct *connTracer) UpdatedKeyFromTLS(level logging.EncryptionLevel, perspective logging.Perspective) {
}
func (ct *connTracer) UpdatedKey(generation logging.KeyPhase, remote bool) {
}
func (ct *connTracer) DroppedEncryptionLevel(level logging.EncryptionLevel) {
}
func (ct *connTracer) DroppedKey(generation logging.KeyPhase) {
}
func (ct *connTracer) SetLossTimer(timerType logging.TimerType, level logging.EncryptionLevel, time time.Time) {
}
func (ct *connTracer) LossTimerExpired(timerType logging.TimerType, level logging.EncryptionLevel) {
}
func (ct *connTracer) LossTimerCanceled() {
}
func (ct *connTracer) ECNStateUpdated(state logging.ECNState, trigger logging.ECNStateTrigger) {
}
func (ct *connTracer) Close() {
}
func (ct *connTracer) Debug(name, msg string) {
ct.metricsCollector.receivedPackets(size, frames)
}
type quicLogger struct {

View File

@ -15,7 +15,8 @@ import (
"os"
"time"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
homedir "github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh"
@ -51,6 +52,8 @@ type errorResponse struct {
var mockRequest func(url, contentType string, body io.Reader) (*http.Response, error) = nil
var signatureAlgs = []jose.SignatureAlgorithm{jose.RS256}
// GenerateShortLivedCertificate generates and stores a keypair for short lived certs
func GenerateShortLivedCertificate(appURL *url.URL, token string) error {
fullName, err := cfpath.GenerateSSHCertFilePathFromURL(appURL, keyName)
@ -87,7 +90,7 @@ func SignCert(token, pubKey string) (string, error) {
return "", errors.New("invalid token")
}
parsedToken, err := jwt.ParseSigned(token)
parsedToken, err := jwt.ParseSigned(token, signatureAlgs)
if err != nil {
return "", errors.Wrap(err, "failed to parse JWT")
}

View File

@ -3,6 +3,8 @@
package sshgen
import (
"crypto/rand"
"crypto/rsa"
"encoding/json"
"fmt"
"io"
@ -14,8 +16,8 @@ import (
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/config"
@ -103,13 +105,16 @@ func tokenGenerator() string {
Expiry: jwt.NewNumericDate(exp),
}
key := []byte("secret")
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, (&jose.SignerOptions{}).WithType("JWT"))
key, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
panic(err)
}
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: key}, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
panic(err)
}
signedToken, err := jwt.Signed(signer).Claims(claims).CompactSerialize()
signedToken, err := jwt.Signed(signer).Claims(claims).Serialize()
if err != nil {
panic(err)
}

View File

@ -66,6 +66,7 @@ type TunnelConfig struct {
PacketConfig *ingress.GlobalRouterConfig
UDPUnregisterSessionTimeout time.Duration
WriteStreamTimeout time.Duration
DisableQUICPathMTUDiscovery bool
@ -614,6 +615,7 @@ func (e *EdgeTunnelServer) serveQUIC(
connLogger.Logger(),
e.config.PacketConfig,
e.config.UDPUnregisterSessionTimeout,
e.config.WriteStreamTimeout,
)
if err != nil {
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to create new quic connection")

View File

@ -12,7 +12,7 @@ import (
"syscall"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v4"
"github.com/pkg/errors"
"github.com/rs/zerolog"
@ -31,7 +31,8 @@ const (
)
var (
userAgent = "DEV"
userAgent = "DEV"
signatureAlgs = []jose.SignatureAlgorithm{jose.RS256}
)
type AppInfo struct {
@ -415,7 +416,7 @@ func getTokenIfExists(path string) (*jose.JSONWebSignature, error) {
if err != nil {
return nil, err
}
token, err := jose.ParseSigned(string(content))
token, err := jose.ParseSigned(string(content), signatureAlgs)
if err != nil {
return nil, err
}

View File

@ -1,5 +1,7 @@
package oidc
import jose "github.com/go-jose/go-jose/v4"
// JOSE asymmetric signing algorithm values as defined by RFC 7518
//
// see: https://tools.ietf.org/html/rfc7518#section-3.1
@ -15,3 +17,16 @@ const (
PS512 = "PS512" // RSASSA-PSS using SHA512 and MGF1-SHA512
EdDSA = "EdDSA" // Ed25519 using SHA-512
)
var allAlgs = []jose.SignatureAlgorithm{
jose.RS256,
jose.RS384,
jose.RS512,
jose.ES256,
jose.ES384,
jose.ES512,
jose.PS256,
jose.PS384,
jose.PS512,
jose.EdDSA,
}

View File

@ -8,12 +8,12 @@ import (
"crypto/rsa"
"errors"
"fmt"
"io/ioutil"
"io"
"net/http"
"sync"
"time"
jose "github.com/go-jose/go-jose/v3"
jose "github.com/go-jose/go-jose/v4"
)
// StaticKeySet is a verifier that validates JWT against a static set of public keys.
@ -25,7 +25,9 @@ type StaticKeySet struct {
// VerifySignature compares the signature against a static set of public keys.
func (s *StaticKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
jws, err := jose.ParseSigned(jwt)
// Algorithms are already checked by Verifier, so this parse method accepts
// any algorithm.
jws, err := jose.ParseSigned(jwt, allAlgs)
if err != nil {
return nil, fmt.Errorf("parsing jwt: %v", err)
}
@ -127,8 +129,13 @@ var parsedJWTKey contextKey
func (r *RemoteKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
jws, ok := ctx.Value(parsedJWTKey).(*jose.JSONWebSignature)
if !ok {
// The algorithm values are already enforced by the Validator, which also sets
// the context value above to pre-parsed signature.
//
// Practically, this codepath isn't called in normal use of this package, but
// if it is, the algorithms have already been checked.
var err error
jws, err = jose.ParseSigned(jwt)
jws, err = jose.ParseSigned(jwt, allAlgs)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}
@ -159,7 +166,7 @@ func (r *RemoteKeySet) verify(ctx context.Context, jws *jose.JSONWebSignature) (
// https://openid.net/specs/openid-connect-core-1_0.html#RotateSigKeys
keys, err := r.keysFromRemote(ctx)
if err != nil {
return nil, fmt.Errorf("fetching keys %v", err)
return nil, fmt.Errorf("fetching keys %w", err)
}
for _, key := range keys {
@ -228,11 +235,11 @@ func (r *RemoteKeySet) updateKeys() ([]jose.JSONWebKey, error) {
resp, err := doRequest(r.ctx, req)
if err != nil {
return nil, fmt.Errorf("oidc: get keys failed %v", err)
return nil, fmt.Errorf("oidc: get keys failed %w", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err)
}

View File

@ -10,7 +10,7 @@ import (
"errors"
"fmt"
"hash"
"io/ioutil"
"io"
"mime"
"net/http"
"strings"
@ -79,7 +79,7 @@ func getClient(ctx context.Context) *http.Client {
// provider, err := oidc.NewProvider(ctx, discoveryBaseURL)
//
// This is insecure because validating the correct issuer is critical for multi-tenant
// proivders. Any overrides here MUST be carefully reviewed.
// providers. Any overrides here MUST be carefully reviewed.
func InsecureIssuerURLContext(ctx context.Context, issuerURL string) context.Context {
return context.WithValue(ctx, issuerURLKey, issuerURL)
}
@ -94,12 +94,13 @@ func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
// Provider represents an OpenID Connect server's configuration.
type Provider struct {
issuer string
authURL string
tokenURL string
userInfoURL string
jwksURL string
algorithms []string
issuer string
authURL string
tokenURL string
deviceAuthURL string
userInfoURL string
jwksURL string
algorithms []string
// Raw claims returned by the server.
rawClaims []byte
@ -128,12 +129,13 @@ func (p *Provider) remoteKeySet() KeySet {
}
type providerJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
Algorithms []string `json:"id_token_signing_alg_values_supported"`
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
DeviceAuthURL string `json:"device_authorization_endpoint"`
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
Algorithms []string `json:"id_token_signing_alg_values_supported"`
}
// supportedAlgorithms is a list of algorithms explicitly supported by this
@ -165,6 +167,9 @@ type ProviderConfig struct {
// TokenURL is the endpoint used by the provider to support the OAuth 2.0
// token endpoint.
TokenURL string
// DeviceAuthURL is the endpoint used by the provider to support the OAuth 2.0
// device authorization endpoint.
DeviceAuthURL string
// UserInfoURL is the endpoint used by the provider to support the OpenID
// Connect UserInfo flow.
//
@ -185,13 +190,14 @@ type ProviderConfig struct {
// through discovery.
func (p *ProviderConfig) NewProvider(ctx context.Context) *Provider {
return &Provider{
issuer: p.IssuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
userInfoURL: p.UserInfoURL,
jwksURL: p.JWKSURL,
algorithms: p.Algorithms,
client: getClient(ctx),
issuer: p.IssuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
deviceAuthURL: p.DeviceAuthURL,
userInfoURL: p.UserInfoURL,
jwksURL: p.JWKSURL,
algorithms: p.Algorithms,
client: getClient(ctx),
}
}
@ -211,7 +217,7 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err)
}
@ -240,14 +246,15 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
}
}
return &Provider{
issuer: issuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
userInfoURL: p.UserInfoURL,
jwksURL: p.JWKSURL,
algorithms: algs,
rawClaims: body,
client: getClient(ctx),
issuer: issuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
deviceAuthURL: p.DeviceAuthURL,
userInfoURL: p.UserInfoURL,
jwksURL: p.JWKSURL,
algorithms: algs,
rawClaims: body,
client: getClient(ctx),
}, nil
}
@ -273,7 +280,7 @@ func (p *Provider) Claims(v interface{}) error {
// Endpoint returns the OAuth2 auth and token endpoints for the given provider.
func (p *Provider) Endpoint() oauth2.Endpoint {
return oauth2.Endpoint{AuthURL: p.authURL, TokenURL: p.tokenURL}
return oauth2.Endpoint{AuthURL: p.authURL, DeviceAuthURL: p.deviceAuthURL, TokenURL: p.tokenURL}
}
// UserInfoEndpoint returns the OpenID Connect userinfo endpoint for the given
@ -332,7 +339,7 @@ func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource)
return nil, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}

View File

@ -7,12 +7,12 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"io"
"net/http"
"strings"
"time"
jose "github.com/go-jose/go-jose/v3"
jose "github.com/go-jose/go-jose/v4"
"golang.org/x/oauth2"
)
@ -182,7 +182,7 @@ func resolveDistributedClaim(ctx context.Context, verifier *IDTokenVerifier, src
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("unable to read response body: %v", err)
}
@ -310,7 +310,16 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
return t, nil
}
jws, err := jose.ParseSigned(rawIDToken)
var supportedSigAlgs []jose.SignatureAlgorithm
for _, alg := range v.config.SupportedSigningAlgs {
supportedSigAlgs = append(supportedSigAlgs, jose.SignatureAlgorithm(alg))
}
if len(supportedSigAlgs) == 0 {
// If no algorithms were specified by both the config and discovery, default
// to the one mandatory algorithm "RS256".
supportedSigAlgs = []jose.SignatureAlgorithm{jose.RS256}
}
jws, err := jose.ParseSigned(rawIDToken, supportedSigAlgs)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}
@ -322,17 +331,7 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
default:
return nil, fmt.Errorf("oidc: multiple signatures on id token not supported")
}
sig := jws.Signatures[0]
supportedSigAlgs := v.config.SupportedSigningAlgs
if len(supportedSigAlgs) == 0 {
supportedSigAlgs = []string{RS256}
}
if !contains(supportedSigAlgs, sig.Header.Algorithm) {
return nil, fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", supportedSigAlgs, sig.Header.Algorithm)
}
t.sigAlgorithm = sig.Header.Algorithm
ctx = context.WithValue(ctx, parsedJWTKey, jws)

16
vendor/github.com/fortytw2/leaktest/.travis.yml generated vendored Normal file
View File

@ -0,0 +1,16 @@
language: go
go:
- 1.8
- 1.9
- "1.10"
- "1.11"
- tip
script:
- go test -v -race -parallel 5 -coverprofile=coverage.txt -covermode=atomic ./
- go test github.com/fortytw2/leaktest -run ^TestEmptyLeak$
before_install:
- pip install --user codecov
after_success:
- codecov

64
vendor/github.com/fortytw2/leaktest/README.md generated vendored Normal file
View File

@ -0,0 +1,64 @@
## Leaktest [![Build Status](https://travis-ci.org/fortytw2/leaktest.svg?branch=master)](https://travis-ci.org/fortytw2/leaktest) [![codecov](https://codecov.io/gh/fortytw2/leaktest/branch/master/graph/badge.svg)](https://codecov.io/gh/fortytw2/leaktest) [![Sourcegraph](https://sourcegraph.com/github.com/fortytw2/leaktest/-/badge.svg)](https://sourcegraph.com/github.com/fortytw2/leaktest?badge) [![Documentation](https://godoc.org/github.com/fortytw2/gpt?status.svg)](http://godoc.org/github.com/fortytw2/leaktest)
Refactored, tested variant of the goroutine leak detector found in both
`net/http` tests and the `cockroachdb` source tree.
Takes a snapshot of running goroutines at the start of a test, and at the end -
compares the two and _voila_. Ignores runtime/sys goroutines. Doesn't play nice
with `t.Parallel()` right now, but there are plans to do so.
### Installation
Go 1.7+
```
go get -u github.com/fortytw2/leaktest
```
Go 1.5/1.6 need to use the tag `v1.0.0`, as newer versions depend on
`context.Context`.
### Example
These tests fail, because they leak a goroutine
```go
// Default "Check" will poll for 5 seconds to check that all
// goroutines are cleaned up
func TestPool(t *testing.T) {
defer leaktest.Check(t)()
go func() {
for {
time.Sleep(time.Second)
}
}()
}
// Helper function to timeout after X duration
func TestPoolTimeout(t *testing.T) {
defer leaktest.CheckTimeout(t, time.Second)()
go func() {
for {
time.Sleep(time.Second)
}
}()
}
// Use Go 1.7+ context.Context for cancellation
func TestPoolContext(t *testing.T) {
ctx, _ := context.WithTimeout(context.Background(), time.Second)
defer leaktest.CheckContext(ctx, t)()
go func() {
for {
time.Sleep(time.Second)
}
}()
}
```
## LICENSE
Same BSD-style as Go, see LICENSE

153
vendor/github.com/fortytw2/leaktest/leaktest.go generated vendored Normal file
View File

@ -0,0 +1,153 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package leaktest provides tools to detect leaked goroutines in tests.
// To use it, call "defer leaktest.Check(t)()" at the beginning of each
// test that may use goroutines.
// copied out of the cockroachdb source tree with slight modifications to be
// more re-useable
package leaktest
import (
"context"
"fmt"
"runtime"
"sort"
"strconv"
"strings"
"time"
)
type goroutine struct {
id uint64
stack string
}
type goroutineByID []*goroutine
func (g goroutineByID) Len() int { return len(g) }
func (g goroutineByID) Less(i, j int) bool { return g[i].id < g[j].id }
func (g goroutineByID) Swap(i, j int) { g[i], g[j] = g[j], g[i] }
func interestingGoroutine(g string) (*goroutine, error) {
sl := strings.SplitN(g, "\n", 2)
if len(sl) != 2 {
return nil, fmt.Errorf("error parsing stack: %q", g)
}
stack := strings.TrimSpace(sl[1])
if strings.HasPrefix(stack, "testing.RunTests") {
return nil, nil
}
if stack == "" ||
// Ignore HTTP keep alives
strings.Contains(stack, ").readLoop(") ||
strings.Contains(stack, ").writeLoop(") ||
// Below are the stacks ignored by the upstream leaktest code.
strings.Contains(stack, "testing.Main(") ||
strings.Contains(stack, "testing.(*T).Run(") ||
strings.Contains(stack, "runtime.goexit") ||
strings.Contains(stack, "created by runtime.gc") ||
strings.Contains(stack, "interestingGoroutines") ||
strings.Contains(stack, "runtime.MHeap_Scavenger") ||
strings.Contains(stack, "signal.signal_recv") ||
strings.Contains(stack, "sigterm.handler") ||
strings.Contains(stack, "runtime_mcall") ||
strings.Contains(stack, "goroutine in C code") {
return nil, nil
}
// Parse the goroutine's ID from the header line.
h := strings.SplitN(sl[0], " ", 3)
if len(h) < 3 {
return nil, fmt.Errorf("error parsing stack header: %q", sl[0])
}
id, err := strconv.ParseUint(h[1], 10, 64)
if err != nil {
return nil, fmt.Errorf("error parsing goroutine id: %s", err)
}
return &goroutine{id: id, stack: strings.TrimSpace(g)}, nil
}
// interestingGoroutines returns all goroutines we care about for the purpose
// of leak checking. It excludes testing or runtime ones.
func interestingGoroutines(t ErrorReporter) []*goroutine {
buf := make([]byte, 2<<20)
buf = buf[:runtime.Stack(buf, true)]
var gs []*goroutine
for _, g := range strings.Split(string(buf), "\n\n") {
gr, err := interestingGoroutine(g)
if err != nil {
t.Errorf("leaktest: %s", err)
continue
} else if gr == nil {
continue
}
gs = append(gs, gr)
}
sort.Sort(goroutineByID(gs))
return gs
}
// ErrorReporter is a tiny subset of a testing.TB to make testing not such a
// massive pain
type ErrorReporter interface {
Errorf(format string, args ...interface{})
}
// Check snapshots the currently-running goroutines and returns a
// function to be run at the end of tests to see whether any
// goroutines leaked, waiting up to 5 seconds in error conditions
func Check(t ErrorReporter) func() {
return CheckTimeout(t, 5*time.Second)
}
// CheckTimeout is the same as Check, but with a configurable timeout
func CheckTimeout(t ErrorReporter, dur time.Duration) func() {
ctx, cancel := context.WithCancel(context.Background())
fn := CheckContext(ctx, t)
return func() {
timer := time.AfterFunc(dur, cancel)
fn()
// Remember to clean up the timer and context
timer.Stop()
cancel()
}
}
// CheckContext is the same as Check, but uses a context.Context for
// cancellation and timeout control
func CheckContext(ctx context.Context, t ErrorReporter) func() {
orig := map[uint64]bool{}
for _, g := range interestingGoroutines(t) {
orig[g.id] = true
}
return func() {
var leaked []string
for {
select {
case <-ctx.Done():
t.Errorf("leaktest: timed out checking goroutines")
default:
leaked = make([]string, 0)
for _, g := range interestingGoroutines(t) {
if !orig[g.id] {
leaked = append(leaked, g.stack)
}
}
if len(leaked) == 0 {
return
}
// don't spin needlessly
time.Sleep(time.Millisecond * 50)
continue
}
break
}
for _, g := range leaked {
t.Errorf("leaktest: leaked goroutine: %v", g)
}
}
}

View File

@ -1,10 +0,0 @@
Serious about security
======================
Square recognizes the important contributions the security research community
can make. We therefore encourage reporting security issues with the code
contained in this repository.
If you believe you have discovered a security vulnerability, please follow the
guidelines at <https://bugcrowd.com/squareopensource>.

View File

@ -1,133 +0,0 @@
/*-
* Copyright 2016 Zbigniew Mandziejewicz
* Copyright 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jwt
import (
"fmt"
"strings"
jose "github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/json"
)
// JSONWebToken represents a JSON Web Token (as specified in RFC7519).
type JSONWebToken struct {
payload func(k interface{}) ([]byte, error)
unverifiedPayload func() []byte
Headers []jose.Header
}
type NestedJSONWebToken struct {
enc *jose.JSONWebEncryption
Headers []jose.Header
}
// Claims deserializes a JSONWebToken into dest using the provided key.
func (t *JSONWebToken) Claims(key interface{}, dest ...interface{}) error {
b, err := t.payload(key)
if err != nil {
return err
}
for _, d := range dest {
if err := json.Unmarshal(b, d); err != nil {
return err
}
}
return nil
}
// UnsafeClaimsWithoutVerification deserializes the claims of a
// JSONWebToken into the dests. For signed JWTs, the claims are not
// verified. This function won't work for encrypted JWTs.
func (t *JSONWebToken) UnsafeClaimsWithoutVerification(dest ...interface{}) error {
if t.unverifiedPayload == nil {
return fmt.Errorf("go-jose/go-jose: Cannot get unverified claims")
}
claims := t.unverifiedPayload()
for _, d := range dest {
if err := json.Unmarshal(claims, d); err != nil {
return err
}
}
return nil
}
func (t *NestedJSONWebToken) Decrypt(decryptionKey interface{}) (*JSONWebToken, error) {
b, err := t.enc.Decrypt(decryptionKey)
if err != nil {
return nil, err
}
sig, err := ParseSigned(string(b))
if err != nil {
return nil, err
}
return sig, nil
}
// ParseSigned parses token from JWS form.
func ParseSigned(s string) (*JSONWebToken, error) {
sig, err := jose.ParseSigned(s)
if err != nil {
return nil, err
}
headers := make([]jose.Header, len(sig.Signatures))
for i, signature := range sig.Signatures {
headers[i] = signature.Header
}
return &JSONWebToken{
payload: sig.Verify,
unverifiedPayload: sig.UnsafePayloadWithoutVerification,
Headers: headers,
}, nil
}
// ParseEncrypted parses token from JWE form.
func ParseEncrypted(s string) (*JSONWebToken, error) {
enc, err := jose.ParseEncrypted(s)
if err != nil {
return nil, err
}
return &JSONWebToken{
payload: enc.Decrypt,
Headers: []jose.Header{enc.Header},
}, nil
}
// ParseSignedAndEncrypted parses signed-then-encrypted token from JWE form.
func ParseSignedAndEncrypted(s string) (*NestedJSONWebToken, error) {
enc, err := jose.ParseEncrypted(s)
if err != nil {
return nil, err
}
contentType, _ := enc.Header.ExtraHeaders[jose.HeaderContentType].(string)
if strings.ToUpper(contentType) != "JWT" {
return nil, ErrInvalidContentType
}
return &NestedJSONWebToken{
enc: enc,
Headers: []jose.Header{enc.Header},
}, nil
}

72
vendor/github.com/go-jose/go-jose/v4/CHANGELOG.md generated vendored Normal file
View File

@ -0,0 +1,72 @@
# v4.0.1
## Fixed
- An attacker could send a JWE containing compressed data that used large
amounts of memory and CPU when decompressed by `Decrypt` or `DecryptMulti`.
Those functions now return an error if the decompressed data would exceed
250kB or 10x the compressed size (whichever is larger). Thanks to
Enze Wang@Alioth and Jianjun Chen@Zhongguancun Lab (@zer0yu and @chenjj)
for reporting.
# v4.0.0
This release makes some breaking changes in order to more thoroughly
address the vulnerabilities discussed in [Three New Attacks Against JSON Web
Tokens][1], "Sign/encrypt confusion", "Billion hash attack", and "Polyglot
token".
## Changed
- Limit JWT encryption types (exclude password or public key types) (#78)
- Enforce minimum length for HMAC keys (#85)
- jwt: match any audience in a list, rather than requiring all audiences (#81)
- jwt: accept only Compact Serialization (#75)
- jws: Add expected algorithms for signatures (#74)
- Require specifying expected algorithms for ParseEncrypted,
ParseSigned, ParseDetached, jwt.ParseEncrypted, jwt.ParseSigned,
jwt.ParseSignedAndEncrypted (#69, #74)
- Usually there is a small, known set of appropriate algorithms for a program
to use and it's a mistake to allow unexpected algorithms. For instance the
"billion hash attack" relies in part on programs accepting the PBES2
encryption algorithm and doing the necessary work even if they weren't
specifically configured to allow PBES2.
- Revert "Strip padding off base64 strings" (#82)
- The specs require base64url encoding without padding.
- Minimum supported Go version is now 1.21
## Added
- ParseSignedCompact, ParseSignedJSON, ParseEncryptedCompact, ParseEncryptedJSON.
- These allow parsing a specific serialization, as opposed to ParseSigned and
ParseEncrypted, which try to automatically detect which serialization was
provided. It's common to require a specific serialization for a specific
protocol - for instance JWT requires Compact serialization.
[1]: https://i.blackhat.com/BH-US-23/Presentations/US-23-Tervoort-Three-New-Attacks-Against-JSON-Web-Tokens.pdf
# v3.0.2
## Fixed
- DecryptMulti: handle decompression error (#19)
## Changed
- jwe/CompactSerialize: improve performance (#67)
- Increase the default number of PBKDF2 iterations to 600k (#48)
- Return the proper algorithm for ECDSA keys (#45)
## Added
- Add Thumbprint support for opaque signers (#38)
# v3.0.1
## Fixed
- Security issue: an attacker specifying a large "p2c" value can cause
JSONWebEncryption.Decrypt and JSONWebEncryption.DecryptMulti to consume large
amounts of CPU, causing a DoS. Thanks to Matt Schwager (@mschwager) for the
disclosure and to Tom Tervoort for originally publishing the category of attack.
https://i.blackhat.com/BH-US-23/Presentations/US-23-Tervoort-Three-New-Attacks-Against-JSON-Web-Tokens.pdf

View File

@ -1,10 +1,9 @@
# Go JOSE
[![godoc](http://img.shields.io/badge/godoc-jose_package-blue.svg?style=flat)](https://godoc.org/gopkg.in/go-jose/go-jose.v2)
[![godoc](http://img.shields.io/badge/godoc-jwt_package-blue.svg?style=flat)](https://godoc.org/gopkg.in/go-jose/go-jose.v2/jwt)
[![license](http://img.shields.io/badge/license-apache_2.0-blue.svg?style=flat)](https://raw.githubusercontent.com/go-jose/go-jose/master/LICENSE)
[![build](https://travis-ci.org/go-jose/go-jose.svg?branch=master)](https://travis-ci.org/go-jose/go-jose)
[![coverage](https://coveralls.io/repos/github/go-jose/go-jose/badge.svg?branch=master)](https://coveralls.io/r/go-jose/go-jose)
[![godoc](https://pkg.go.dev/badge/github.com/go-jose/go-jose/v4.svg)](https://pkg.go.dev/github.com/go-jose/go-jose/v4)
[![godoc](https://pkg.go.dev/badge/github.com/go-jose/go-jose/v4/jwt.svg)](https://pkg.go.dev/github.com/go-jose/go-jose/v4/jwt)
[![license](https://img.shields.io/badge/license-apache_2.0-blue.svg?style=flat)](https://raw.githubusercontent.com/go-jose/go-jose/master/LICENSE)
[![test](https://img.shields.io/github/checks-status/go-jose/go-jose/v4)](https://github.com/go-jose/go-jose/actions)
Package jose aims to provide an implementation of the Javascript Object Signing
and Encryption set of standards. This includes support for JSON Web Encryption,
@ -21,13 +20,13 @@ US maintained blocked list.
## Overview
The implementation follows the
[JSON Web Encryption](http://dx.doi.org/10.17487/RFC7516) (RFC 7516),
[JSON Web Signature](http://dx.doi.org/10.17487/RFC7515) (RFC 7515), and
[JSON Web Token](http://dx.doi.org/10.17487/RFC7519) (RFC 7519) specifications.
[JSON Web Encryption](https://dx.doi.org/10.17487/RFC7516) (RFC 7516),
[JSON Web Signature](https://dx.doi.org/10.17487/RFC7515) (RFC 7515), and
[JSON Web Token](https://dx.doi.org/10.17487/RFC7519) (RFC 7519) specifications.
Tables of supported algorithms are shown below. The library supports both
the compact and JWS/JWE JSON Serialization formats, and has optional support for
multiple recipients. It also comes with a small command-line utility
([`jose-util`](https://github.com/go-jose/go-jose/tree/master/jose-util))
([`jose-util`](https://pkg.go.dev/github.com/go-jose/go-jose/jose-util))
for dealing with JOSE messages in a shell.
**Note**: We use a forked version of the `encoding/json` package from the Go
@ -38,29 +37,22 @@ libraries in other languages.
### Versions
[Version 2](https://gopkg.in/go-jose/go-jose.v2)
([branch](https://github.com/go-jose/go-jose/tree/v2),
[doc](https://godoc.org/gopkg.in/go-jose/go-jose.v2)) is the current stable version:
[Version 4](https://github.com/go-jose/go-jose)
([branch](https://github.com/go-jose/go-jose/tree/main),
[doc](https://pkg.go.dev/github.com/go-jose/go-jose/v4), [releases](https://github.com/go-jose/go-jose/releases)) is the current stable version:
import "gopkg.in/go-jose/go-jose.v2"
import "github.com/go-jose/go-jose/v4"
[Version 3](https://github.com/go-jose/go-jose)
([branch](https://github.com/go-jose/go-jose/tree/master),
[doc](https://godoc.org/github.com/go-jose/go-jose)) is the under development/unstable version (not released yet):
The old [square/go-jose](https://github.com/square/go-jose) repo contains the prior v1 and v2 versions, which
are still useable but not actively developed anymore.
import "github.com/go-jose/go-jose/v3"
All new feature development takes place on the `master` branch, which we are
preparing to release as version 3 soon. Version 2 will continue to receive
critical bug and security fixes. Note that starting with version 3 we are
using Go modules for versioning instead of `gopkg.in` as before. Version 3 also will require Go version 1.13 or higher.
Version 1 (on the `v1` branch) is frozen and not supported anymore.
Version 3, in this repo, is still receiving security fixes but not functionality
updates.
### Supported algorithms
See below for a table of supported algorithms. Algorithm identifiers match
the names in the [JSON Web Algorithms](http://dx.doi.org/10.17487/RFC7518)
the names in the [JSON Web Algorithms](https://dx.doi.org/10.17487/RFC7518)
standard where possible. The Godoc reference has a list of constants.
Key encryption | Algorithm identifier(s)
@ -103,20 +95,20 @@ allows attaching a key id.
Algorithm(s) | Corresponding types
:------------------------- | -------------------------------
RSA | *[rsa.PublicKey](http://golang.org/pkg/crypto/rsa/#PublicKey), *[rsa.PrivateKey](http://golang.org/pkg/crypto/rsa/#PrivateKey)
ECDH, ECDSA | *[ecdsa.PublicKey](http://golang.org/pkg/crypto/ecdsa/#PublicKey), *[ecdsa.PrivateKey](http://golang.org/pkg/crypto/ecdsa/#PrivateKey)
EdDSA<sup>1</sup> | [ed25519.PublicKey](https://godoc.org/pkg/crypto/ed25519#PublicKey), [ed25519.PrivateKey](https://godoc.org/pkg/crypto/ed25519#PrivateKey)
RSA | *[rsa.PublicKey](https://pkg.go.dev/crypto/rsa/#PublicKey), *[rsa.PrivateKey](https://pkg.go.dev/crypto/rsa/#PrivateKey)
ECDH, ECDSA | *[ecdsa.PublicKey](https://pkg.go.dev/crypto/ecdsa/#PublicKey), *[ecdsa.PrivateKey](https://pkg.go.dev/crypto/ecdsa/#PrivateKey)
EdDSA<sup>1</sup> | [ed25519.PublicKey](https://pkg.go.dev/crypto/ed25519#PublicKey), [ed25519.PrivateKey](https://pkg.go.dev/crypto/ed25519#PrivateKey)
AES, HMAC | []byte
<sup>1. Only available in version 2 or later of the package</sup>
## Examples
[![godoc](http://img.shields.io/badge/godoc-jose_package-blue.svg?style=flat)](https://godoc.org/gopkg.in/go-jose/go-jose.v2)
[![godoc](http://img.shields.io/badge/godoc-jwt_package-blue.svg?style=flat)](https://godoc.org/gopkg.in/go-jose/go-jose.v2/jwt)
[![godoc](https://pkg.go.dev/badge/github.com/go-jose/go-jose/v4.svg)](https://pkg.go.dev/github.com/go-jose/go-jose/v4)
[![godoc](https://pkg.go.dev/badge/github.com/go-jose/go-jose/v4/jwt.svg)](https://pkg.go.dev/github.com/go-jose/go-jose/v4/jwt)
Examples can be found in the Godoc
reference for this package. The
[`jose-util`](https://github.com/go-jose/go-jose/tree/master/jose-util)
[`jose-util`](https://github.com/go-jose/go-jose/tree/v4/jose-util)
subdirectory also contains a small command-line utility which might be useful
as an example as well.

13
vendor/github.com/go-jose/go-jose/v4/SECURITY.md generated vendored Normal file
View File

@ -0,0 +1,13 @@
# Security Policy
This document explains how to contact the Let's Encrypt security team to report security vulnerabilities.
## Supported Versions
| Version | Supported |
| ------- | ----------|
| >= v3 | &check; |
| v2 | &cross; |
| v1 | &cross; |
## Reporting a vulnerability
Please see [https://letsencrypt.org/contact/#security](https://letsencrypt.org/contact/#security) for the email address to report a vulnerability. Ensure that the subject line for your report contains the word `vulnerability` and is descriptive. Your email should be acknowledged within 24 hours. If you do not receive a response within 24 hours, please follow-up again with another email.

View File

@ -29,8 +29,8 @@ import (
"fmt"
"math/big"
josecipher "github.com/go-jose/go-jose/v3/cipher"
"github.com/go-jose/go-jose/v3/json"
josecipher "github.com/go-jose/go-jose/v4/cipher"
"github.com/go-jose/go-jose/v4/json"
)
// A generic RSA-based encrypter/verifier
@ -285,6 +285,9 @@ func (ctx rsaDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm
switch alg {
case RS256, RS384, RS512:
// TODO(https://github.com/go-jose/go-jose/issues/40): As of go1.20, the
// random parameter is legacy and ignored, and it can be nil.
// https://cs.opensource.google/go/go/+/refs/tags/go1.20:src/crypto/rsa/pkcs1v15.go;l=263;bpv=0;bpt=1
out, err = rsa.SignPKCS1v15(RandReader, ctx.privateKey, hash, hashed)
case PS256, PS384, PS512:
out, err = rsa.SignPSS(RandReader, ctx.privateKey, hash, hashed, &rsa.PSSOptions{

View File

@ -21,9 +21,8 @@ import (
"crypto/rsa"
"errors"
"fmt"
"reflect"
"github.com/go-jose/go-jose/v3/json"
"github.com/go-jose/go-jose/v4/json"
)
// Encrypter represents an encrypter which produces an encrypted JWE object.
@ -76,14 +75,24 @@ type recipientKeyInfo struct {
type EncrypterOptions struct {
Compression CompressionAlgorithm
// Optional map of additional keys to be inserted into the protected header
// of a JWS object. Some specifications which make use of JWS like to insert
// additional values here. All values must be JSON-serializable.
// Optional map of name/value pairs to be inserted into the protected
// header of a JWS object. Some specifications which make use of
// JWS require additional values here.
//
// Values will be serialized by [json.Marshal] and must be valid inputs to
// that function.
//
// [json.Marshal]: https://pkg.go.dev/encoding/json#Marshal
ExtraHeaders map[HeaderKey]interface{}
}
// WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it
// if necessary. It returns itself and so can be used in a fluent style.
// if necessary, and returns the updated EncrypterOptions.
//
// The v parameter will be serialized by [json.Marshal] and must be a valid
// input to that function.
//
// [json.Marshal]: https://pkg.go.dev/encoding/json#Marshal
func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions {
if eo.ExtraHeaders == nil {
eo.ExtraHeaders = map[HeaderKey]interface{}{}
@ -111,7 +120,17 @@ func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions {
// default of 100000 will be used for the count and a 128-bit random salt will
// be generated.
type Recipient struct {
Algorithm KeyAlgorithm
Algorithm KeyAlgorithm
// Key must have one of these types:
// - ed25519.PublicKey
// - *ecdsa.PublicKey
// - *rsa.PublicKey
// - *JSONWebKey
// - JSONWebKey
// - []byte (a symmetric key)
// - Any type that satisfies the OpaqueKeyEncrypter interface
//
// The type of Key must match the value of Algorithm.
Key interface{}
KeyID string
PBES2Count int
@ -150,16 +169,17 @@ func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions)
switch rcpt.Algorithm {
case DIRECT:
// Direct encryption mode must be treated differently
if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) {
keyBytes, ok := rawKey.([]byte)
if !ok {
return nil, ErrUnsupportedKeyType
}
if encrypter.cipher.keySize() != len(rawKey.([]byte)) {
if encrypter.cipher.keySize() != len(keyBytes) {
return nil, ErrInvalidKeySize
}
encrypter.keyGenerator = staticKeyGenerator{
key: rawKey.([]byte),
key: keyBytes,
}
recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte))
recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, keyBytes)
recipientInfo.keyID = keyID
if rcpt.KeyID != "" {
recipientInfo.keyID = rcpt.KeyID
@ -168,16 +188,16 @@ func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions)
return encrypter, nil
case ECDH_ES:
// ECDH-ES (w/o key wrapping) is similar to DIRECT mode
typeOf := reflect.TypeOf(rawKey)
if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) {
keyDSA, ok := rawKey.(*ecdsa.PublicKey)
if !ok {
return nil, ErrUnsupportedKeyType
}
encrypter.keyGenerator = ecKeyGenerator{
size: encrypter.cipher.keySize(),
algID: string(enc),
publicKey: rawKey.(*ecdsa.PublicKey),
publicKey: keyDSA,
}
recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey))
recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, keyDSA)
recipientInfo.keyID = keyID
if rcpt.KeyID != "" {
recipientInfo.keyID = rcpt.KeyID
@ -270,9 +290,8 @@ func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKey
recipient, err := makeJWERecipient(alg, encryptionKey.Key)
recipient.keyID = encryptionKey.KeyID
return recipient, err
}
if encrypter, ok := encryptionKey.(OpaqueKeyEncrypter); ok {
return newOpaqueKeyEncrypter(alg, encrypter)
case OpaqueKeyEncrypter:
return newOpaqueKeyEncrypter(alg, encryptionKey)
}
return recipientKeyInfo{}, ErrUnsupportedKeyType
}
@ -300,11 +319,11 @@ func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
return newDecrypter(decryptionKey.Key)
case *JSONWebKey:
return newDecrypter(decryptionKey.Key)
case OpaqueKeyDecrypter:
return &opaqueKeyDecrypter{decrypter: decryptionKey}, nil
default:
return nil, ErrUnsupportedKeyType
}
if okd, ok := decryptionKey.(OpaqueKeyDecrypter); ok {
return &opaqueKeyDecrypter{decrypter: okd}, nil
}
return nil, ErrUnsupportedKeyType
}
// Implementation of encrypt method producing a JWE object.
@ -403,9 +422,27 @@ func (ctx *genericEncrypter) Options() EncrypterOptions {
}
}
// Decrypt and validate the object and return the plaintext. Note that this
// function does not support multi-recipient, if you desire multi-recipient
// Decrypt and validate the object and return the plaintext. This
// function does not support multi-recipient. If you desire multi-recipient
// decryption use DecryptMulti instead.
//
// The decryptionKey argument must contain a private or symmetric key
// and must have one of these types:
// - *ecdsa.PrivateKey
// - *rsa.PrivateKey
// - *JSONWebKey
// - JSONWebKey
// - *JSONWebKeySet
// - JSONWebKeySet
// - []byte (a symmetric key)
// - string (a symmetric key)
// - Any type that satisfies the OpaqueKeyDecrypter interface.
//
// Note that ed25519 is only available for signatures, not encryption, so is
// not an option here.
//
// Automatically decompresses plaintext, but returns an error if the decompressed
// data would be >250kB or >10x the size of the compressed data, whichever is larger.
func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
headers := obj.mergedHeaders(nil)
@ -462,15 +499,24 @@ func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error)
// The "zip" header parameter may only be present in the protected header.
if comp := obj.protected.getCompression(); comp != "" {
plaintext, err = decompress(comp, plaintext)
if err != nil {
return nil, fmt.Errorf("go-jose/go-jose: failed to decompress plaintext: %v", err)
}
}
return plaintext, err
return plaintext, nil
}
// DecryptMulti decrypts and validates the object and returns the plaintexts,
// with support for multiple recipients. It returns the index of the recipient
// for which the decryption was successful, the merged headers for that recipient,
// and the plaintext.
//
// The decryptionKey argument must have one of the types allowed for the
// decryptionKey argument of Decrypt().
//
// Automatically decompresses plaintext, but returns an error if the decompressed
// data would be >250kB or >3x the size of the compressed data, whichever is larger.
func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) {
globalHeaders := obj.mergedHeaders(nil)
@ -532,7 +578,10 @@ func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Heade
// The "zip" header parameter may only be present in the protected header.
if comp := obj.protected.getCompression(); comp != "" {
plaintext, _ = decompress(comp, plaintext)
plaintext, err = decompress(comp, plaintext)
if err != nil {
return -1, Header{}, nil, fmt.Errorf("go-jose/go-jose: failed to decompress plaintext: %v", err)
}
}
sanitized, err := headers.sanitized()

View File

@ -15,13 +15,11 @@
*/
/*
Package jose aims to provide an implementation of the Javascript Object Signing
and Encryption set of standards. It implements encryption and signing based on
the JSON Web Encryption and JSON Web Signature standards, with optional JSON Web
Token support available in a sub-package. The library supports both the compact
and JWS/JWE JSON Serialization formats, and has optional support for multiple
recipients.
*/
package jose

View File

@ -21,12 +21,13 @@ import (
"compress/flate"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"math/big"
"strings"
"unicode"
"github.com/go-jose/go-jose/v3/json"
"github.com/go-jose/go-jose/v4/json"
)
// Helper function to serialize known-good objects.
@ -85,7 +86,7 @@ func decompress(algorithm CompressionAlgorithm, input []byte) ([]byte, error) {
}
}
// Compress with DEFLATE
// deflate compresses the input.
func deflate(input []byte) ([]byte, error) {
output := new(bytes.Buffer)
@ -97,15 +98,24 @@ func deflate(input []byte) ([]byte, error) {
return output.Bytes(), err
}
// Decompress with DEFLATE
// inflate decompresses the input.
//
// Errors if the decompressed data would be >250kB or >10x the size of the
// compressed data, whichever is larger.
func inflate(input []byte) ([]byte, error) {
output := new(bytes.Buffer)
reader := flate.NewReader(bytes.NewBuffer(input))
_, err := io.Copy(output, reader)
if err != nil {
maxCompressedSize := max(250_000, 10*int64(len(input)))
limit := maxCompressedSize + 1
n, err := io.CopyN(output, reader, limit)
if err != nil && err != io.EOF {
return nil, err
}
if n == limit {
return nil, fmt.Errorf("uncompressed data would be too large (>%d bytes)", maxCompressedSize)
}
err = reader.Close()
return output.Bytes(), err
@ -154,7 +164,7 @@ func (b *byteBuffer) UnmarshalJSON(data []byte) error {
return nil
}
decoded, err := base64URLDecode(encoded)
decoded, err := base64.RawURLEncoding.DecodeString(encoded)
if err != nil {
return err
}
@ -184,8 +194,35 @@ func (b byteBuffer) toInt() int {
return int(b.bigInt().Int64())
}
// base64URLDecode is implemented as defined in https://www.rfc-editor.org/rfc/rfc7515.html#appendix-C
func base64URLDecode(value string) ([]byte, error) {
value = strings.TrimRight(value, "=")
return base64.RawURLEncoding.DecodeString(value)
func base64EncodeLen(sl []byte) int {
return base64.RawURLEncoding.EncodedLen(len(sl))
}
func base64JoinWithDots(inputs ...[]byte) string {
if len(inputs) == 0 {
return ""
}
// Count of dots.
totalCount := len(inputs) - 1
for _, input := range inputs {
totalCount += base64EncodeLen(input)
}
out := make([]byte, totalCount)
startEncode := 0
for i, input := range inputs {
base64.RawURLEncoding.Encode(out[startEncode:], input)
if i == len(inputs)-1 {
continue
}
startEncode += base64EncodeLen(input)
out[startEncode] = '.'
startEncode++
}
return string(out)
}

View File

@ -1,4 +1,4 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Copyright (c) 2012 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are

View File

@ -75,14 +75,13 @@ import (
//
// The JSON null value unmarshals into an interface, map, pointer, or slice
// by setting that Go value to nil. Because null is often used in JSON to mean
// ``not present,'' unmarshaling a JSON null into any other Go type has no effect
// “not present,” unmarshaling a JSON null into any other Go type has no effect
// on the value and produces no error.
//
// When unmarshaling quoted strings, invalid UTF-8 or
// invalid UTF-16 surrogate pairs are not treated as an error.
// Instead, they are replaced by the Unicode replacement
// character U+FFFD.
//
func Unmarshal(data []byte, v interface{}) error {
// Check for well-formedness.
// Avoids filling out half a data structure

View File

@ -58,6 +58,7 @@ import (
// becomes a member of the object unless
// - the field's tag is "-", or
// - the field is empty and its tag specifies the "omitempty" option.
//
// The empty values are false, 0, any
// nil pointer or interface value, and any array, slice, map, or string of
// length zero. The object's default key string is the struct field name
@ -65,28 +66,28 @@ import (
// the struct field's tag value is the key name, followed by an optional comma
// and options. Examples:
//
// // Field is ignored by this package.
// Field int `json:"-"`
// // Field is ignored by this package.
// Field int `json:"-"`
//
// // Field appears in JSON as key "myName".
// Field int `json:"myName"`
// // Field appears in JSON as key "myName".
// Field int `json:"myName"`
//
// // Field appears in JSON as key "myName" and
// // the field is omitted from the object if its value is empty,
// // as defined above.
// Field int `json:"myName,omitempty"`
// // Field appears in JSON as key "myName" and
// // the field is omitted from the object if its value is empty,
// // as defined above.
// Field int `json:"myName,omitempty"`
//
// // Field appears in JSON as key "Field" (the default), but
// // the field is skipped if empty.
// // Note the leading comma.
// Field int `json:",omitempty"`
// // Field appears in JSON as key "Field" (the default), but
// // the field is skipped if empty.
// // Note the leading comma.
// Field int `json:",omitempty"`
//
// The "string" option signals that a field is stored as JSON inside a
// JSON-encoded string. It applies only to fields of string, floating point,
// integer, or boolean types. This extra level of encoding is sometimes used
// when communicating with JavaScript programs:
//
// Int64String int64 `json:",string"`
// Int64String int64 `json:",string"`
//
// The key name will be used if it's a non-empty string consisting of
// only Unicode letters, digits, dollar signs, percent signs, hyphens,
@ -133,7 +134,6 @@ import (
// JSON cannot represent cyclic data structures and Marshal does not
// handle them. Passing cyclic structures to Marshal will result in
// an infinite recursion.
//
func Marshal(v interface{}) ([]byte, error) {
e := &encodeState{}
err := e.marshal(v)

Some files were not shown because too many files have changed in this diff Show More